diff options
author | Nicolas Vasilache <ntv@google.com> | 2020-01-11 02:22:00 -0500 |
---|---|---|
committer | Nicolas Vasilache <ntv@google.com> | 2020-01-14 17:25:28 -0500 |
commit | f52d71736b10e87b1aa1880b777dc9462a0085ce (patch) | |
tree | 3eaa824f59037e0b987abd0c39094ec999e04c3c /mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp | |
parent | 8d07f8d98c48ee0a9dca450aaf4e1cabc621ff68 (diff) | |
download | bcm5719-llvm-f52d71736b10e87b1aa1880b777dc9462a0085ce.tar.gz bcm5719-llvm-f52d71736b10e87b1aa1880b777dc9462a0085ce.zip |
[mlir][Linalg] Update the semantics, verifier and test for Linalg with tensors.
Summary:
This diff fixes issues with the semantics of linalg.generic on tensors that appeared when converting directly from HLO to linalg.generic.
The changes are self-contained within MLIR and can be captured and tested independently of XLA.
The linalg.generic and indexed_generic are updated to:
To allow progressive lowering from the value world (a.k.a tensor values) to
the buffer world (a.k.a memref values), a linalg.generic op accepts
mixing input and output ranked tensor values with input and output memrefs.
```
%1 = linalg.generic #trait_attribute %A, %B {other-attributes} :
tensor<?x?xf32>,
memref<?x?xf32, stride_specification>
-> (tensor<?x?xf32>)
```
In this case, the number of outputs (args_out) must match the sum of (1) the
number of output buffer operands and (2) the number of tensor return values.
The semantics is that the linalg.indexed_generic op produces (i.e.
allocates and fills) its return values.
Tensor values must be legalized by a buffer allocation pass before most
transformations can be applied. Such legalization moves tensor return values
into output buffer operands and updates the region argument accordingly.
Transformations that create control-flow around linalg.indexed_generic
operations are not expected to mix with tensors because SSA values do not
escape naturally. Still, transformations and rewrites that take advantage of
tensor SSA values are expected to be useful and will be added in the near
future.
Subscribers: bmahjour, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D72555
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp')
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp | 43 |
1 files changed, 31 insertions, 12 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp index f5dac8aced1..2f97b62280a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -90,6 +90,8 @@ template <typename IndexedValueType> class LinalgScopedEmitter<IndexedValueType, CopyOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) { + assert(copyOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto nPar = copyOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto inputIvs = @@ -98,7 +100,7 @@ public: permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end()); SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end()); - IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0)); + IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0)); // Emit the proper scalar assignment, whether we are dealing with a 0-D or // an n-D loop nest; with or without permutations. // clang-format off @@ -112,11 +114,13 @@ template <typename IndexedValueType> class LinalgScopedEmitter<IndexedValueType, FillOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) { + assert(fillOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto nPar = fillOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto ivs = SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar); - IndexedValueType O(fillOp.getOutput(0)); + IndexedValueType O(fillOp.getOutputBuffer(0)); // Emit the proper scalar assignment, whether we are dealing with a 0-D or // an n-D loop nest; with or without permutations. nPar > 0 ? O(ivs) = ValueHandle(fillOp.value()) @@ -128,10 +132,12 @@ template <typename IndexedValueType> class LinalgScopedEmitter<IndexedValueType, DotOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) { + assert(dotOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 1); IndexHandle r_i(allIvs[0]); IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), - C(dotOp.getOutput(0)); + C(dotOp.getOutputBuffer(0)); // Emit scalar form. C() = C() + A(r_i) * B(r_i); } @@ -142,10 +148,12 @@ class LinalgScopedEmitter<IndexedValueType, MatvecOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, MatvecOp matvecOp) { + assert(matvecOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 2); IndexHandle i(allIvs[0]), r_j(allIvs[1]); IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)), - C(matvecOp.getOutput(0)); + C(matvecOp.getOutputBuffer(0)); // Emit scalar form. C(i) = C(i) + A(i, r_j) * B(r_j); } @@ -156,10 +164,12 @@ class LinalgScopedEmitter<IndexedValueType, MatmulOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, MatmulOp matmulOp) { + assert(matmulOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 3); IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]); IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)), - C(matmulOp.getOutput(0)); + C(matmulOp.getOutputBuffer(0)); // Emit scalar form. C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j); } @@ -169,6 +179,8 @@ template <typename IndexedValueType> class LinalgScopedEmitter<IndexedValueType, ConvOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) { + assert(convOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); auto maps = loopToOperandRangesMaps(convOp); @@ -219,6 +231,8 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, GenericOp genericOp) { + assert(genericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; @@ -237,7 +251,8 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing); + indexedValues[nInputs + i] = + std_load(genericOp.getOutputBuffer(i), indexing); } auto funcOp = genericOp.getFunction(); @@ -250,7 +265,7 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), genericOp.getOutput(i), indexing); + std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing); } return; } @@ -273,8 +288,8 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i), - indexing); + std_store(map.lookup(yieldOp->getOperand(i)), + genericOp.getOutputBuffer(i), indexing); } } }; @@ -314,6 +329,8 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, IndexedGenericOp indexedGenericOp) { + assert(indexedGenericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; @@ -339,7 +356,7 @@ public: ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); indexedValues[nLoops + nInputs + i] = - std_load(indexedGenericOp.getOutput(i), indexing); + std_load(indexedGenericOp.getOutputBuffer(i), indexing); } if (auto funcOp = indexedGenericOp.getFunction()) { @@ -351,7 +368,7 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), indexedGenericOp.getOutput(i), + std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i), indexing); } return; @@ -376,7 +393,7 @@ public: ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); std_store(map.lookup(yieldOp->getOperand(i)), - indexedGenericOp.getOutput(i), indexing); + indexedGenericOp.getOutputBuffer(i), indexing); } } }; @@ -404,6 +421,8 @@ LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit( // The flattened loopToOperandRangesMaps is expected to be an invertible // permutation map (which is asserted in the inverse calculation). auto linalgOp = cast<ConcreteOpTy>(op); + assert(linalgOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto invertedMap = inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp))); if (!invertedMap) { |