diff options
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp | 10 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 187 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 31 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp | 43 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp | 14 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 11 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 9 |
7 files changed, 182 insertions, 123 deletions
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index 144afa4c5e1..109a35b7611 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -139,7 +139,11 @@ LinalgDependenceGraph::getDependencesInto( } void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { - for (auto srcView : src.getOutputs()) { // W + assert(src.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(dst.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + for (auto srcView : src.getOutputBuffers()) { // W // RAW graph for (auto dstView : dst.getInputs()) { // R if (aliases.alias(srcView, dstView)) { // if alias, fill RAW @@ -149,7 +153,7 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { } } // WAW graph - for (auto dstView : dst.getOutputs()) { // W + for (auto dstView : dst.getOutputBuffers()) { // W if (aliases.alias(srcView, dstView)) { // if alias, fill WAW addDependenceElem(DependenceType::WAW, LinalgOpView{src.getOperation(), srcView}, @@ -167,7 +171,7 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { } } // WAR graph - for (auto dstView : dst.getOutputs()) { // W + for (auto dstView : dst.getOutputBuffers()) { // W if (aliases.alias(srcView, dstView)) { // if alias, fill WAR addDependenceElem(DependenceType::WAR, LinalgOpView{src.getOperation(), srcView}, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 384d24957f4..7850dd60be1 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -112,19 +112,20 @@ template <typename GenericOpType> static LogicalResult verifyBlockArgs(GenericOpType op, Block &block); template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) { - auto nViews = op.getNumInputsAndOutputs(); - auto nInputViews = op.getNumInputs(); - if (block.getNumArguments() != nViews) - return op.emitOpError( - "expected number of block arguments to match number of views"); + auto nOperands = op.getNumOperands(); + if (block.getNumArguments() != nOperands) + return op.emitOpError("expected number of block arguments to match number " + "of operands"); - for (unsigned i = 0; i < nViews; ++i) { + // Note: the number and type of yield values are checked in the YieldOp. + auto nInputViews = op.getNumInputs(); + for (unsigned i = 0; i < nOperands; ++i) { auto viewType = op.getShapedType(i); if (viewType.getElementType() != block.getArgument(i).getType()) return op.emitOpError("expected block argument ") - << i << " of the same type as elemental type of " + << (i + 1) << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") - << "view: " << viewType; + << "operand: " << viewType; } return success(); } @@ -132,27 +133,28 @@ template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) { template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); - auto nViews = op.getNumInputsAndOutputs(); - if (block.getNumArguments() != nViews + nLoops) + auto nOperands = op.getNumOperands(); + if (block.getNumArguments() != nOperands + nLoops) return op.emitOpError( - "expected number of block arguments to match number of views + " + "expected number of block arguments to match number of operands + " "number of loops"); - for (unsigned i = 0; i < nLoops; ++i) { + // Note: the number and type of yield values are checked in the YieldOp. + for (unsigned i = 0; i < nLoops; ++i) if (!block.getArgument(i).getType().isIndex()) return op.emitOpError("expected block argument ") - << i << " to be of IndexType"; - } + << (i + 1) << " to be an index"; - for (unsigned i = 0; i < nViews; ++i) { + for (unsigned i = 0; i < nOperands; ++i) { unsigned memrefArgIndex = i + nLoops; auto viewType = op.getShapedType(i); if (viewType.getElementType() != block.getArgument(memrefArgIndex).getType()) return op.emitOpError("expected block argument ") - << memrefArgIndex << " of the same type as elemental type of " + << (memrefArgIndex + 1) + << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") - << "view: " << viewType; + << "operand: " << viewType; } return success(); } @@ -160,70 +162,74 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) { template <typename GenericOpType> static LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType); +template <typename GenericOpType> +LogicalResult verifyFuncArgsGeneric(GenericOpType op, FunctionType funType) { + auto res = verifyFuncArgs(op, funType); + if (failed(res)) + return res; + + auto nInputs = op.getNumInputs(); + auto nOutputs = op.getNumOutputs(); + // linalg.generic output element types are exactly the function results. + for (unsigned idx = 0; idx < nOutputs; ++idx) { + ShapedType shapedType = op.getShapedType(nInputs + idx); + if (funType.getResult(idx) != shapedType.getElementType()) + return op.emitOpError("expected function result ") + << (idx + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of output " << (idx + 1); + } + return success(); +} + template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) { - auto nViews = op.getNumInputsAndOutputs(); - auto nInputViews = op.getNumInputs(); - if (funType.getNumInputs() != nViews) - return op.emitOpError("expected fun arguments to match number of views"); - if (funType.getNumResults() != op.getNumOutputs()) + auto nOperands = op.getNumOperands(); + if (funType.getNumInputs() != nOperands) return op.emitOpError( - "expected fun results to match number of output views"); - - for (auto en : llvm::enumerate(op.indexing_maps())) { - auto idx = en.index(); - auto view = (idx < nInputViews) ? op.getInputShapedType(idx) - : op.getOutputShapedType(idx - nInputViews); - if (funType.getInput(idx) != view.getElementType()) - return op.emitOpError("expected fun argument ") - << idx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - - if (idx >= nInputViews) { - auto resultIdx = idx - nInputViews; - if (funType.getResult(resultIdx) != view.getElementType()) - return op.emitOpError("expected fun result ") - << resultIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - } + "expected function arguments to match number of operands"); + if (funType.getNumResults() != op.getNumOutputs()) + return op.emitOpError("expected function results(") + << funType.getNumResults() << ") to match number of outputs(" + << op.getNumOutputs() << ")"; + + // linalg.generic operands element types are exactly the first function + // arguments. + for (unsigned idx = 0; idx < nOperands; ++idx) { + ShapedType shapedType = op.getShapedType(idx); + if (funType.getInput(idx) != shapedType.getElementType()) + return op.emitOpError("expected function argument ") + << (idx + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of operand " << (idx + 1); } + return success(); } template <> LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) { auto nLoops = op.getNumLoops(); - auto nInputViews = op.getNumInputs(); auto nOutputs = op.getNumOutputs(); - auto nViews = op.getNumInputsAndOutputs(); - if (funType.getNumInputs() != nViews + nLoops) - return op.emitOpError( - "expected fun arguments to match number of views + number of loops"); + auto nOperands = op.getNumOperands(); + if (funType.getNumInputs() != nOperands + nLoops) + return op.emitOpError("expected function arguments to match number of " + "loops + number of operands"); if (funType.getNumResults() != nOutputs) return op.emitOpError( - "expected fun results to match number of output views"); - for (unsigned i = 0; i < nLoops; ++i) { + "expected function results to match number of outputs"); + for (unsigned i = 0; i < nLoops; ++i) if (!funType.getInput(i).isIndex()) - return op.emitOpError("expected fun argument ") - << i << " to be of IndexType"; - } - for (auto en : llvm::enumerate(op.indexing_maps())) { - auto idx = en.index(); - auto funIdx = nLoops + idx; - auto view = (idx < nInputViews) ? op.getInputShapedType(idx) - : op.getOutputShapedType(idx - nInputViews); - if (funType.getInput(funIdx) != view.getElementType()) - return op.emitOpError("expected fun argument ") - << funIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - - if (idx >= nInputViews) { - auto resultIdx = idx - nInputViews; - if (funType.getResult(resultIdx) != view.getElementType()) - return op.emitOpError("expected fun result ") - << resultIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - } + return op.emitOpError("expected function argument ") + << (i + 1) << " to be an index"; + + // linalg.generic operands element types are exactly the first function + // arguments. + for (unsigned idx = 0; idx < nOperands; ++idx) { + ShapedType shapedType = op.getShapedType(idx); + if (funType.getInput(idx + nLoops) != shapedType.getElementType()) + return op.emitOpError("expected function argument ") + << (idx + nLoops + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of input " << (idx + 1); } + return success(); } @@ -231,9 +237,11 @@ template <typename GenericOpType> static LogicalResult verifyGenericOp(GenericOpType op) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); - auto nViews = op.getNumInputsAndOutputs(); - if (nViews != llvm::size(op.views())) - return op.emitOpError("expected exactly ") << nViews << " view operands"; + auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers(); + if (nInputsAndOutputBuffers != llvm::size(op.views())) + return op.emitOpError("expected exactly ") + << nInputsAndOutputBuffers + << " inputs (tensor or buffer) and output buffer operands"; auto ®ion = op.region(); auto funOp = op.getFunction(); @@ -246,8 +254,8 @@ static LogicalResult verifyGenericOp(GenericOpType op) { } else { if (!funOp || !funOp.getType()) return op.emitOpError( - "expected fun attribute to refer to a defined symbol"); - if (failed(verifyFuncArgs(op, funType))) + "expected function attribute to refer to a defined symbol"); + if (failed(verifyFuncArgsGeneric(op, funType))) return failure(); } @@ -287,22 +295,6 @@ static LogicalResult verifyGenericOp(GenericOpType op) { return op.emitOpError("expected the concatenation of maps in indexing_map " "to be invertible"); - auto outputTensorTypes = op.getOutputTensorTypes(); - if (outputTensorTypes.size() != op.getNumResults()) - return op.emitOpError("expected #output tensor operands (") - << outputTensorTypes.size() << ") to match #results (" - << op.getNumResults() << ")"; - - unsigned index = 0; - for (auto it : llvm::zip(op.getResultTypes(), outputTensorTypes)) { - auto resTy = std::get<0>(it); - auto outOpTy = std::get<1>(it); - if (resTy != outOpTy) - return op.emitOpError("result #") - << index << " must be " << outOpTy << ", but got " << resTy; - ++index; - } - return success(); } @@ -731,17 +723,20 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) { template <typename GenericOpType> static LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) { // The operand number and types must match the view element types. - auto nOutputViews = genericOp.getNumOutputs(); - if (op.getNumOperands() != nOutputViews) - return op.emitOpError("expected ") - << nOutputViews << " operand to match enclosing linalg.generic op"; + auto nOutputs = genericOp.getNumOutputs(); + if (op.getNumOperands() != nOutputs) + return op.emitOpError("expected number of yield values (") + << nOutputs << ") to match the number of operands of the enclosing " + << "linalg.generic op (" << op.getNumOperands() << ")"; - for (unsigned i = 0; i != nOutputViews; ++i) { + for (unsigned i = 0; i != nOutputs; ++i) { auto elementType = genericOp.getOutputShapedType(i).getElementType(); if (op.getOperand(i).getType() != elementType) - return op.emitOpError("type of return operand ") - << i << " (" << op.getOperand(i).getType() - << ") doesn't match view element type (" << elementType << ")"; + return op.emitOpError("type of yield operand ") + << (i + 1) << " (" << op.getOperand(i).getType() + << ") doesn't match " + << "the element type of the enclosing linalg.generic op (" + << elementType << ")"; } return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 043d9c0e7cd..6ad73ee759e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -67,12 +67,13 @@ static llvm::cl::list<unsigned> clTileSizes( // to the `loopRanges` in order to obtain view ranges. static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, ArrayRef<SubViewOp::Range> loopRanges) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); auto maps = loopToOperandRangesMaps(op); SmallVector<Value, 8> clonedViews; clonedViews.reserve(op.getNumInputsAndOutputs()); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. - SmallVector<Value, 8> ios(op.getInputsAndOutputs()); + SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers()); for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); auto map = maps[idx]; @@ -118,10 +119,11 @@ struct ViewDimension { // they must agree by construction (i.e. have the same size) and we just return // the first one. static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); auto maps = loopToOperandRangesMaps(op); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. - SmallVector<Value, 8> ios(op.getInputsAndOutputs()); + SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers()); for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); auto map = maps[idx]; @@ -144,6 +146,10 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer, unsigned consumerIdx, unsigned producerIdx, OperationFolder *folder) { + assert(producer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto subView = dyn_cast_or_null<SubViewOp>( consumer.getInput(consumerIdx).getDefiningOp()); auto slice = @@ -197,6 +203,10 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer, // Some of these will be lifted in the future with better analysis. static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, LinalgOp consumer) { + assert(producer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); if (producer.getNumOutputs() != 1) { LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); return false; @@ -217,6 +227,10 @@ bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, LinalgOp consumer, Value consumedView, LinalgOp producer) { + assert(producer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); // Make some simple structural checks that alleviate the need for more // complex analyses. if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { @@ -236,6 +250,10 @@ bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, Value consumedView, LinalgOp producer) { + assert(producer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) return false; // Check for any fusion-preventing dependence to any view read/written that @@ -252,6 +270,8 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, Optional<FusionInfo> mlir::linalg::fuseProducerOf( OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, const LinalgDependenceGraph &graph, OperationFolder *folder) { + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); LLVM_DEBUG(dbgs() << "\nStart examining consumer: " << *consumer.getOperation()); for (auto dependence : graph.getDependencesInto( @@ -268,7 +288,7 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf( // Consumer consumes this view, `isStructurallyFusableProducer` also checks // whether it is a strict subview of the producer view. auto producedView = dependence.dependentOpView.view; - auto producerIdx = producer.getIndexOfOutput(producedView).getValue(); + auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue(); // `consumerIdx` and `producerIdx` exist by construction. LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation() << " view: " << producedView @@ -309,7 +329,10 @@ static void fuseLinalgOpsGreedily(FuncOp f) { // Save original Linalg ops, we only want to make a pass over those. SmallVector<Operation *, 8> linalgOps; - f.walk([&](LinalgOp op) { linalgOps.push_back(op); }); + f.walk([&](LinalgOp op) { + if (op.hasBufferSemantics()) + linalgOps.push_back(op); + }); Aliases aliases; LinalgDependenceGraph G(aliases, linalgOps); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp index f5dac8aced1..2f97b62280a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -90,6 +90,8 @@ template <typename IndexedValueType> class LinalgScopedEmitter<IndexedValueType, CopyOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) { + assert(copyOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto nPar = copyOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto inputIvs = @@ -98,7 +100,7 @@ public: permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end()); SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end()); - IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0)); + IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0)); // Emit the proper scalar assignment, whether we are dealing with a 0-D or // an n-D loop nest; with or without permutations. // clang-format off @@ -112,11 +114,13 @@ template <typename IndexedValueType> class LinalgScopedEmitter<IndexedValueType, FillOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) { + assert(fillOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto nPar = fillOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto ivs = SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar); - IndexedValueType O(fillOp.getOutput(0)); + IndexedValueType O(fillOp.getOutputBuffer(0)); // Emit the proper scalar assignment, whether we are dealing with a 0-D or // an n-D loop nest; with or without permutations. nPar > 0 ? O(ivs) = ValueHandle(fillOp.value()) @@ -128,10 +132,12 @@ template <typename IndexedValueType> class LinalgScopedEmitter<IndexedValueType, DotOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) { + assert(dotOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 1); IndexHandle r_i(allIvs[0]); IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), - C(dotOp.getOutput(0)); + C(dotOp.getOutputBuffer(0)); // Emit scalar form. C() = C() + A(r_i) * B(r_i); } @@ -142,10 +148,12 @@ class LinalgScopedEmitter<IndexedValueType, MatvecOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, MatvecOp matvecOp) { + assert(matvecOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 2); IndexHandle i(allIvs[0]), r_j(allIvs[1]); IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)), - C(matvecOp.getOutput(0)); + C(matvecOp.getOutputBuffer(0)); // Emit scalar form. C(i) = C(i) + A(i, r_j) * B(r_j); } @@ -156,10 +164,12 @@ class LinalgScopedEmitter<IndexedValueType, MatmulOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, MatmulOp matmulOp) { + assert(matmulOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 3); IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]); IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)), - C(matmulOp.getOutput(0)); + C(matmulOp.getOutputBuffer(0)); // Emit scalar form. C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j); } @@ -169,6 +179,8 @@ template <typename IndexedValueType> class LinalgScopedEmitter<IndexedValueType, ConvOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) { + assert(convOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); auto maps = loopToOperandRangesMaps(convOp); @@ -219,6 +231,8 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, GenericOp genericOp) { + assert(genericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; @@ -237,7 +251,8 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing); + indexedValues[nInputs + i] = + std_load(genericOp.getOutputBuffer(i), indexing); } auto funcOp = genericOp.getFunction(); @@ -250,7 +265,7 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), genericOp.getOutput(i), indexing); + std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing); } return; } @@ -273,8 +288,8 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i), - indexing); + std_store(map.lookup(yieldOp->getOperand(i)), + genericOp.getOutputBuffer(i), indexing); } } }; @@ -314,6 +329,8 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> { public: static void emitScalarImplementation(ArrayRef<Value> allIvs, IndexedGenericOp indexedGenericOp) { + assert(indexedGenericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; @@ -339,7 +356,7 @@ public: ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); indexedValues[nLoops + nInputs + i] = - std_load(indexedGenericOp.getOutput(i), indexing); + std_load(indexedGenericOp.getOutputBuffer(i), indexing); } if (auto funcOp = indexedGenericOp.getFunction()) { @@ -351,7 +368,7 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), indexedGenericOp.getOutput(i), + std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i), indexing); } return; @@ -376,7 +393,7 @@ public: ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); std_store(map.lookup(yieldOp->getOperand(i)), - indexedGenericOp.getOutput(i), indexing); + indexedGenericOp.getOutputBuffer(i), indexing); } } }; @@ -404,6 +421,8 @@ LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit( // The flattened loopToOperandRangesMaps is expected to be an invertible // permutation map (which is asserted in the inverse calculation). auto linalgOp = cast<ConcreteOpTy>(op); + assert(linalgOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto invertedMap = inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp))); if (!invertedMap) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 9657daf9137..10c537ebd29 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -93,6 +93,8 @@ bool mlir::linalg::detail::isProducedByOpOfTypeImpl( Operation *consumerOp, Value consumedView, function_ref<bool(Operation *)> isaOpType) { LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); if (!consumer) return false; @@ -171,7 +173,7 @@ mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) { return false; return true; }; - if (!llvm::all_of(genericOp.getInputsAndOutputs(), + if (!llvm::all_of(genericOp.getInputsAndOutputBuffers(), isStaticMemRefWithIdentityLayout)) return failure(); return success(); @@ -188,6 +190,8 @@ mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter, "DRR failure case must be a precondition"); auto genericOp = cast<linalg::GenericOp>(op); + assert(genericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); edsc::ScopedContext scope(rewriter, op->getLoc()); using edsc::intrinsics::std_load; using edsc::intrinsics::std_store; @@ -195,7 +199,7 @@ mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter, using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>; auto vA = std_load(vector_type_cast(genericOp.getInput(0))); auto vB = std_load(vector_type_cast(genericOp.getInput(1))); - auto vectorMemRefC = vector_type_cast(genericOp.getOutput(0)); + auto vectorMemRefC = vector_type_cast(genericOp.getOutputBuffer(0)); auto vC = std_load(vectorMemRefC); auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(), genericOp.iterator_types()); @@ -262,7 +266,7 @@ LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(Operation *op) { // Transformation applies to buffers only. if (!linOp || !linOp.hasBufferSemantics()) return failure(); - if (llvm::none_of(linOp.getInputsAndOutputs(), [](Value v) { + if (llvm::none_of(linOp.getInputsAndOutputBuffers(), [](Value v) { return isa_and_nonnull<SubViewOp>(v.getDefiningOp()); })) return failure(); @@ -279,8 +283,10 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter, "DRR failure case must be a precondition"); LinalgOp linOp = cast<LinalgOp>(op); + assert(linOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); SetVector<Value> subViews; - for (auto it : linOp.getInputsAndOutputs()) + for (auto it : linOp.getInputsAndOutputBuffers()) if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index eb605699890..3caa8c8d1f0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -155,6 +155,8 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, SetVector<Value> subViews, bool dynamicBuffers, OperationFolder *folder) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); + // 1. Promote the specified views and use them in the new op. ScopedContext scope(b, op.getLoc()); auto promotedBufferAndViews = promoteSubViews( @@ -164,7 +166,7 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, SmallVector<std::pair<Value, Value>, 8> writebackViews; writebackViews.reserve(subViews.size()); unsigned promotedIdx = 0; - for (auto view : op.getInputsAndOutputs()) { + for (auto view : op.getInputsAndOutputBuffers()) { if (subViews.count(view) != 0) { opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView); writebackViews.emplace_back(std::make_pair( @@ -187,7 +189,7 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, // WARNING: MUST use the old op to determine whether the operand view is an // output. bool isOutput = - op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue(); + op.getIndexOfOutputBuffer(viewAndPartialLocalView.first).hasValue(); if (isOutput) copy(viewAndPartialLocalView.second, viewAndPartialLocalView.first); } @@ -203,11 +205,14 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) { SmallVector<LinalgOp, 8> toErase; OperationFolder folder(f.getContext()); f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) { + if (!op.hasBufferSemantics()) + return; + // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or // nothing. SetVector<Value> subViews; OpBuilder b(op); - for (auto it : op.getInputsAndOutputs()) + for (auto it : op.getInputsAndOutputBuffers()) if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index bcf6576abbb..ed9553d710d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -173,6 +173,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> { static void transformIndexedGenericOpIndices( OpBuilder &b, LinalgOp op, ArrayRef<ValueHandle *> pivs, const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation()); if (!indexedGenericOp) return; @@ -232,6 +233,8 @@ static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, ArrayRef<Value> ivs, ArrayRef<Value> tileSizes, ArrayRef<Value> viewSizes, OperationFolder *folder) { + assert(linalgOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(ivs.size() == static_cast<size_t>(llvm::count_if( llvm::make_range(tileSizes.begin(), tileSizes.end()), [](Value v) { return !isZero(v); })) && @@ -254,7 +257,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, SmallVector<Value, 4> res; res.reserve(op->getNumOperands()); - auto viewIteratorBegin = linalgOp.getInputsAndOutputs().begin(); + auto viewIteratorBegin = linalgOp.getInputsAndOutputBuffers().begin(); for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs(); ++viewIndex) { Value view = *(viewIteratorBegin + viewIndex); @@ -309,6 +312,7 @@ Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes, ArrayRef<unsigned> permutation, OperationFolder *folder) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); // 1. Enforce the convention that "tiling by zero" skips tiling a particular // dimension. This convention is significantly simpler to handle instead of // adjusting affine maps to account for missing dimensions. @@ -383,6 +387,7 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes, Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp( OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes, ArrayRef<unsigned> permutation, OperationFolder *folder) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); if (tileSizes.empty()) return llvm::None; @@ -419,6 +424,8 @@ static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) { OpBuilder b(f); OperationFolder folder(f.getContext()); f.walk([tileSizes, &b, &folder](LinalgOp op) { + if (!op.hasBufferSemantics()) + return; auto opLoopsPair = tileLinalgOp(b, op, tileSizes, /*permutation=*/{}, &folder); // If tiling occurred successfully, erase old op. |