diff options
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp')
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp | 43 |
1 files changed, 31 insertions, 12 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp index f5dac8aced1..2f97b62280a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -90,6 +90,8 @@ template <typename IndexedValueType> class LinalgScopedEmitter<IndexedValueType, CopyOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) { + assert(copyOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto nPar = copyOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto inputIvs = @@ -98,7 +100,7 @@ public: permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end()); SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end()); - IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0)); + IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0)); // Emit the proper scalar assignment, whether we are dealing with a 0-D or // an n-D loop nest; with or without permutations. // clang-format off @@ -112,11 +114,13 @@ template <typename IndexedValueType> class LinalgScopedEmitter<IndexedValueType, FillOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) { + assert(fillOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto nPar = fillOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto ivs = SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar); - IndexedValueType O(fillOp.getOutput(0)); + IndexedValueType O(fillOp.getOutputBuffer(0)); // Emit the proper scalar assignment, whether we are dealing with a 0-D or // an n-D loop nest; with or without permutations. nPar > 0 ? O(ivs) = ValueHandle(fillOp.value()) @@ -128,10 +132,12 @@ template <typename IndexedValueType> class LinalgScopedEmitter<IndexedValueType, DotOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) { + assert(dotOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 1); IndexHandle r_i(allIvs[0]); IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), - C(dotOp.getOutput(0)); + C(dotOp.getOutputBuffer(0)); // Emit scalar form. C() = C() + A(r_i) * B(r_i); } @@ -142,10 +148,12 @@ class LinalgScopedEmitter<IndexedValueType, MatvecOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, MatvecOp matvecOp) { + assert(matvecOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 2); IndexHandle i(allIvs[0]), r_j(allIvs[1]); IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)), - C(matvecOp.getOutput(0)); + C(matvecOp.getOutputBuffer(0)); // Emit scalar form. C(i) = C(i) + A(i, r_j) * B(r_j); } @@ -156,10 +164,12 @@ class LinalgScopedEmitter<IndexedValueType, MatmulOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, MatmulOp matmulOp) { + assert(matmulOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 3); IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]); IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)), - C(matmulOp.getOutput(0)); + C(matmulOp.getOutputBuffer(0)); // Emit scalar form. C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j); } @@ -169,6 +179,8 @@ template <typename IndexedValueType> class LinalgScopedEmitter<IndexedValueType, ConvOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) { + assert(convOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); auto maps = loopToOperandRangesMaps(convOp); @@ -219,6 +231,8 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, GenericOp genericOp) { + assert(genericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; @@ -237,7 +251,8 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing); + indexedValues[nInputs + i] = + std_load(genericOp.getOutputBuffer(i), indexing); } auto funcOp = genericOp.getFunction(); @@ -250,7 +265,7 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), genericOp.getOutput(i), indexing); + std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing); } return; } @@ -273,8 +288,8 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i), - indexing); + std_store(map.lookup(yieldOp->getOperand(i)), + genericOp.getOutputBuffer(i), indexing); } } }; @@ -314,6 +329,8 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, IndexedGenericOp indexedGenericOp) { + assert(indexedGenericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; @@ -339,7 +356,7 @@ public: ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); indexedValues[nLoops + nInputs + i] = - std_load(indexedGenericOp.getOutput(i), indexing); + std_load(indexedGenericOp.getOutputBuffer(i), indexing); } if (auto funcOp = indexedGenericOp.getFunction()) { @@ -351,7 +368,7 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), indexedGenericOp.getOutput(i), + std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i), indexing); } return; @@ -376,7 +393,7 @@ public: ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); std_store(map.lookup(yieldOp->getOperand(i)), - indexedGenericOp.getOutput(i), indexing); + indexedGenericOp.getOutputBuffer(i), indexing); } } }; @@ -404,6 +421,8 @@ LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit( // The flattened loopToOperandRangesMaps is expected to be an invertible // permutation map (which is asserted in the inverse calculation). auto linalgOp = cast<ConcreteOpTy>(op); + assert(linalgOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto invertedMap = inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp))); if (!invertedMap) { |