-rw-r--r--  mlir/include/mlir/Dialect/Linalg/Passes.h                              |  11
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td |  11
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h         |  41
-rw-r--r--  mlir/lib/Dialect/Linalg/CMakeLists.txt                                 |   2
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp                   | 314
            (renamed from mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp)
-rw-r--r--  mlir/test/Dialect/Linalg/llvm.mlir                                     |   2
-rw-r--r--  mlir/test/Dialect/Linalg/loops.mlir                                    |   2
-rw-r--r--  mlir/test/Dialect/Linalg/transform-patterns.mlir                       |   8
-rw-r--r--  mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td     |   7
-rw-r--r--  mlir/test/mlir-cpu-runner/linalg_integration_test.mlir                 |   6
-rw-r--r--  mlir/test/mlir-cpu-runner/utils.mlir                                   |   8
11 files changed, 275 insertions, 137 deletions
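Orientation for the diff below: this commit renames the Linalg lowering pass (-linalg-lower-to-loops becomes -convert-linalg-to-loops), switches the pass from a dialect-conversion driver to greedy pattern application, and adds a placeholder affine variant plus directly callable per-op entry points. A minimal sketch of how a client pipeline might schedule the two new factory functions from Passes.h (the helper function and the pass-manager setup are illustrative assumptions, not part of this commit):

    #include "mlir/Dialect/Linalg/Passes.h"
    #include "mlir/Pass/PassManager.h"

    // Hypothetical helper: add Linalg lowering to a function-level pass
    // manager, choosing between loop.for and the (NYI) affine.for lowering.
    static void addLinalgLowering(mlir::OpPassManager &fpm, bool lowerToAffine) {
      fpm.addPass(lowerToAffine
                      ? mlir::linalg::createConvertLinalgToAffineLoopsPass()
                      : mlir::linalg::createConvertLinalgToLoopsPass());
    }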
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 5ecd50070da..7ae3877f01e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -39,9 +39,16 @@ createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
 std::unique_ptr<OpPassBase<FuncOp>>
 createLinalgPromotionPass(bool dynamicBuffers);
 
-std::unique_ptr<OpPassBase<FuncOp>> createLowerLinalgToLoopsPass();
+/// Create a pass to convert Linalg operations to loop.for loops and
+/// std.load/std.store accesses.
+std::unique_ptr<OpPassBase<FuncOp>> createConvertLinalgToLoopsPass();
 
-/// Create a pass to convert vector operations to the LLVMIR dialect.
+/// Create a pass to convert Linalg operations to affine.for loops and
+/// affine_load/affine_store accesses.
+/// Placeholder for now, this is NYI.
+std::unique_ptr<OpPassBase<FuncOp>> createConvertLinalgToAffineLoopsPass();
+
+/// Create a pass to convert Linalg operations to the LLVMIR dialect.
 std::unique_ptr<OpPassBase<ModuleOp>> createConvertLinalgToLLVMPass();
 
 } // namespace linalg
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
index d243bb23f2c..8bc0eaf2097 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
@@ -62,4 +62,15 @@ class TileLinalgOp<list<int> sizes, string value> : NativeCodeCall<
     StrJoinInt<sizes>.result # "}, \"" # value # "\")))" #
     " return matchFailure();">;
 
+//===----------------------------------------------------------------------===//
+// Linalg to loop patterns.
+//===----------------------------------------------------------------------===//
+class LinalgOpToLoops<string OpType> : NativeCodeCall<
+  "if (failed(linalgOpToLoops<" # OpType # ">($_builder, $0))) " #
+  "  return matchFailure();">;
+
+class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
+  "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " #
+  "  return matchFailure();">;
+
 #endif // LINALG_TRANSFORMS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
index 56ae94f32c6..966b8f93135 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
@@ -35,20 +35,6 @@ struct LinalgTransforms {
   static const StringLiteral kLinalgTransformMarker;
 };
 
-// Declarative transformation used in tablegen patterns.
-// Tiles `op` by `sizes` and sets the attribute `kLinalgTransformMarker` to
-// `linalgMarker`.
-LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op,
-                                       ArrayRef<int64_t> sizes,
-                                       StringRef linalgMarker);
-
-// Declarative transformation used in tablegen patterns.
-// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets
-// the attribute `kLinalgTransformMarker` to `linalgMarker`.
-LogicalResult tileAndFuseLinalgOpAndSetMarker(
-    PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
-    ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
-
 namespace detail {
 // Implementation detail of isProducedByOpOfType avoids the need for explicit
 // template instantiations.
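Note on the helpers moved by the hunk above: the LinalgOpToLoops/LinalgOpToAffineLoops TableGen classes expand to calls of the C++ entry points that the next hunk declares. A hand-written pattern would follow the same contract; a sketch (the DotToLoops class name is hypothetical, but the body mirrors the LinalgRewritePattern added later in this diff — the helper only emits the loop nest, and erasing the matched op stays with the enclosing pattern):

    // Sketch of a hand-rolled equivalent of the DRR pattern, assuming the
    // declarations from LinalgTransforms.h below are visible.
    struct DotToLoops : public RewritePattern {
      DotToLoops(MLIRContext *context)
          : RewritePattern(DotOp::getOperationName(), /*benefit=*/1, context) {}
      PatternMatchResult matchAndRewrite(Operation *op,
                                         PatternRewriter &rewriter) const override {
        if (failed(linalgOpToLoops<DotOp>(rewriter, op)))
          return matchFailure();
        rewriter.eraseOp(op); // the helper never erases/replaces on its own
        return matchSuccess();
      }
    };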
@@ -65,6 +51,33 @@ bool isProducedByOpOfType(Operation *consumerOp, Value *consumedView) {
       consumerOp, consumedView, [](Operation *op) { return isa<OpTy>(op); });
 }
 
+////////////////////////////////////////////////////////////////////////////////
+// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite
+// patterns. As such, they must not call into `rewriter.erase/replace` APIs and
+// it is the responsibility of the enclosing PatternRewriter to erase on
+// success.
+////////////////////////////////////////////////////////////////////////////////
+
+// Tiles `op` by `sizes` and sets the attribute `kLinalgTransformMarker` to
+// `linalgMarker`.
+LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op,
+                                       ArrayRef<int64_t> sizes,
+                                       StringRef linalgMarker);
+
+// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets
+// the attribute `kLinalgTransformMarker` to `linalgMarker`.
+LogicalResult tileAndFuseLinalgOpAndSetMarker(
+    PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
+    ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
+
+// Emits a loop nest of `loop.for` with the proper body for `op`.
+template <typename ConcreteOp>
+LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op);
+
+// Emits a loop nest of `affine.for` with the proper body for `op`.
+template <typename ConcreteOp>
+LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/CMakeLists.txt b/mlir/lib/Dialect/Linalg/CMakeLists.txt
index 4b7cd81be94..a4ce5038891 100644
--- a/mlir/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/CMakeLists.txt
@@ -5,7 +5,7 @@ add_llvm_library(MLIRLinalg
   IR/LinalgTypes.cpp
   Transforms/Fusion.cpp
   Transforms/LinalgTransforms.cpp
-  Transforms/LowerToLoops.cpp
+  Transforms/LinalgToLoops.cpp
   Transforms/Promotion.cpp
   Transforms/Tiling.cpp
   Utils/Utils.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 0bf4ceaa33b..cf0b235f57f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
 #include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
@@ -41,12 +42,14 @@ using namespace mlir::linalg;
 using namespace mlir::linalg::intrinsics;
 
 using IndexedStdValue = TemplatedIndexedValue<std_load, std_store>;
+using IndexedAffineValue = TemplatedIndexedValue<affine_load, affine_store>;
+
 using edsc::op::operator+;
 using edsc::op::operator==;
 
 static SmallVector<ValueHandle, 8>
-foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
-                    ArrayRef<Value *> vals, OperationFolder *folder) {
+makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
+                           ArrayRef<Value *> vals) {
   assert(map.getNumSymbols() == 0);
   assert(map.getNumInputs() == vals.size());
   SmallVector<ValueHandle, 8> res;
@@ -56,17 +59,16 @@ foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
     auto exprMap = AffineMap::get(dims, 0, e);
     SmallVector<Value *, 4> operands(vals.begin(), vals.end());
     canonicalizeMapAndOperands(&exprMap, &operands);
-    res.push_back(affine_apply(folder, exprMap, operands));
+    res.push_back(affine_apply(exprMap, operands));
   }
   return res;
 }
 
 static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
-                                          Optional<AffineMap> permutation,
-                                          OperationFolder *folder) {
+                                          Optional<AffineMap> permutation) {
   return permutation ? applyMapToValues(ScopedContext::getBuilder(),
                                         ScopedContext::getLocation(),
-                                        permutation.getValue(), ivs, folder)
+                                        permutation.getValue(), ivs)
                      : SmallVector<Value *, 4>(ivs.begin(), ivs.end());
 }
 
@@ -75,20 +77,17 @@ static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
 // which new loops will be created.
 static SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
                                               AffineMap map,
-                                              ArrayRef<Value *> allViewSizes,
-                                              OperationFolder *folder);
+                                              ArrayRef<Value *> allViewSizes);
 SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
                                        AffineMap map,
-                                       ArrayRef<Value *> allViewSizes,
-                                       OperationFolder *folder) {
+                                       ArrayRef<Value *> allViewSizes) {
   // Apply `map` to get view sizes in loop order.
-  auto sizes = applyMapToValues(b, loc, map, allViewSizes, folder);
+  auto sizes = applyMapToValues(b, loc, map, allViewSizes);
   // Create a new range with the applied tile sizes.
   ScopedContext scope(b, loc);
   SmallVector<Value *, 4> res;
   for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
-    res.push_back(range(constant_index(folder, 0), sizes[idx],
-                        constant_index(folder, 1)));
+    res.push_back(range(constant_index(0), sizes[idx], constant_index(1)));
   }
   return res;
 }
@@ -99,14 +98,14 @@ class LinalgScopedEmitter {};
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, CopyOp> {
 public:
-  static void emitScalarImplementation(ArrayRef<Value *> allIvs, CopyOp copyOp,
-                                       OperationFolder *folder) {
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       CopyOp copyOp) {
     auto nPar = copyOp.getNumParallelLoops();
     assert(nPar == allIvs.size());
     auto inputIvs =
-        permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation(), folder);
+        permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
     auto outputIvs =
-        permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation(), folder);
+        permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
     SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
     SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
     IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0));
@@ -122,8 +121,8 @@ public:
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, FillOp> {
 public:
-  static void emitScalarImplementation(ArrayRef<Value *> allIvs, FillOp fillOp,
-                                       OperationFolder *folder) {
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       FillOp fillOp) {
     auto nPar = fillOp.getNumParallelLoops();
     assert(nPar == allIvs.size());
     auto ivs =
@@ -139,8 +138,7 @@ public:
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, DotOp> {
 public:
-  static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp,
-                                       OperationFolder *folder) {
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp) {
     assert(allIvs.size() == 1);
     IndexHandle r_i(allIvs[0]);
     IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
@@ -154,8 +152,7 @@ template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, MatvecOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
-                                       MatvecOp matvecOp,
-                                       OperationFolder *folder) {
+                                       MatvecOp matvecOp) {
     assert(allIvs.size() == 2);
     IndexHandle i(allIvs[0]), r_j(allIvs[1]);
     IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
@@ -169,8 +166,7 @@ template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
-                                       MatmulOp matmulOp,
-                                       OperationFolder *folder) {
+                                       MatmulOp matmulOp) {
     assert(allIvs.size() == 3);
     IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
     IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
@@ -183,17 +179,17 @@ public:
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, ConvOp> {
 public:
-  static void emitScalarImplementation(ArrayRef<Value *> allIvs, ConvOp convOp,
-                                       OperationFolder *folder) {
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       ConvOp convOp) {
     auto b = ScopedContext::getBuilder();
     auto loc = ScopedContext::getLocation();
     auto maps = loopToOperandRangesMaps(convOp);
     SmallVector<ValueHandle, 8> fIdx(
-        foldedAffineApplies(b, loc, maps[0], allIvs, folder));
+        makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
     SmallVector<ValueHandle, 8> imIdx(
-        foldedAffineApplies(b, loc, maps[1], allIvs, folder));
+        makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
     SmallVector<ValueHandle, 8> oIdx(
-        foldedAffineApplies(b, loc, maps[2], allIvs, folder));
+        makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
     IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output());
     // Emit scalar form.
     O(oIdx) += F(fIdx) * I(imIdx);
@@ -234,8 +230,7 @@ template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, GenericOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
-                                       GenericOp genericOp,
-                                       OperationFolder *folder) {
+                                       GenericOp genericOp) {
     auto b = ScopedContext::getBuilder();
     auto loc = ScopedContext::getLocation();
     using edsc::intrinsics::detail::ValueHandleArray;
@@ -245,15 +240,15 @@ public:
 
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, genericOp.getInputIndexingMap(i), allIvs));
       indexedValues[i] = std_load(genericOp.getInput(i), indexing);
     }
 
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
       indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing);
     }
 
@@ -265,8 +260,8 @@ public:
 
     // 3. Emit std_store.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
       std_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
     }
     return;
@@ -288,8 +283,8 @@ public:
     auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
     assert(yieldOp->getNumOperands() == nOutputs);
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
       std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
                 indexing);
     }
@@ -330,8 +325,7 @@ template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
-                                       IndexedGenericOp indexedGenericOp,
-                                       OperationFolder *folder) {
+                                       IndexedGenericOp indexedGenericOp) {
     auto b = ScopedContext::getBuilder();
     auto loc = ScopedContext::getLocation();
     using edsc::intrinsics::detail::ValueHandleArray;
@@ -346,16 +340,16 @@ public:
 
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
       indexedValues[nLoops + i] =
           std_load(indexedGenericOp.getInput(i), indexing);
     }
 
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
       indexedValues[nLoops + nInputs + i] =
           std_load(indexedGenericOp.getOutput(i), indexing);
     }
@@ -367,8 +361,8 @@ public:
 
     // 3. Emit std_store.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
       std_store(callOp->getResult(i), indexedGenericOp.getOutput(i),
                 indexing);
     }
@@ -391,96 +385,110 @@ public:
     auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
     assert(yieldOp->getNumOperands() == nOutputs);
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
       std_store(map.lookup(yieldOp->getOperand(i)),
                 indexedGenericOp.getOutput(i), indexing);
     }
   }
 };
 
+namespace {
+// This struct is for factoring out the implementation and support template
+// instantiations in the following 2 cases:
+//   1. Appending to a list of patterns via RewritePatternList.
+//   2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`.
+// The implementation must work both in DRR and inside a RewritePattern. As a
+// consequence, (1) it is only allowed to emit new ops if the match is
+// guaranteed to be a success, (2) it is not allowed to erase/replace, and (3)
+// an encompassing pattern must take care of the erasure logic.
+template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
+class LinalgOpToLoopsImpl {
+public:
+  static LogicalResult doit(Operation *op, PatternRewriter &rewriter);
+};
+} // namespace
+
+template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
+LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
+    Operation *op, PatternRewriter &rewriter) {
+  OpBuilder b(op);
+  ScopedContext scope(b, op->getLoc());
+
+  // The flattened loopToOperandRangesMaps is expected to be an invertible
+  // permutation map (which is asserted in the inverse calculation).
+  auto linalgOp = cast<ConcreteOpTy>(op);
+  auto invertedMap =
+      inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
+  if (!invertedMap) {
+    LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
+        {}, linalgOp);
+    return success();
+  }
+
+  auto nPar = linalgOp.getNumParallelLoops();
+  auto nRed = linalgOp.getNumReductionLoops();
+  auto nWin = linalgOp.getNumWindowLoops();
+  SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
+  SmallVector<ValueHandle *, 4> allPIvs = makeIndexHandlePointers(allIvs);
+  auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(),
+                                   invertedMap, getViewSizes(linalgOp));
+  assert(loopRanges.size() == allIvs.size());
+
+  LoopNestRangeBuilder(allPIvs, loopRanges)([&] {
+    auto allIvValues = extractValues(allIvs);
+    LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
+        allIvValues, linalgOp);
+  });
+  return success();
+}
+
 template <typename LoopType, typename IndexedValueType, typename ConcreteOp>
 class LinalgRewritePattern : public RewritePattern {
 public:
   explicit LinalgRewritePattern(MLIRContext *context)
-      : RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1, context),
-        folder(context) {}
+      : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    OpBuilder b(op);
-    ScopedContext scope(b, op->getLoc());
-
-    // The flattened loopToOperandRangesMaps is expected to be an invertible
-    // permutation map (which is asserted in the inverse calculation).
-    auto linalgOp = cast<ConcreteOp>(op);
-    auto invertedMap =
-        inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
-    if (!invertedMap) {
-      LinalgScopedEmitter<IndexedValueType,
-                          ConcreteOp>::emitScalarImplementation({}, linalgOp,
-                                                                &folder);
-      rewriter.eraseOp(op);
-      return matchSuccess();
-    }
-
-    auto nPar = linalgOp.getNumParallelLoops();
-    auto nRed = linalgOp.getNumReductionLoops();
-    auto nWin = linalgOp.getNumWindowLoops();
-    SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
-    SmallVector<ValueHandle *, 4> allPIvs = makeIndexHandlePointers(allIvs);
-    auto loopRanges =
-        emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
-                       getViewSizes(linalgOp), &folder);
-    assert(loopRanges.size() == allIvs.size());
-
-    // clang-format off;
-    LoopNestRangeBuilder(allPIvs, loopRanges)([&] {
-      auto allIvValues = extractValues(allIvs);
-      LinalgScopedEmitter<IndexedValueType,
-                          ConcreteOp>::emitScalarImplementation(allIvValues,
-                                                                linalgOp,
-                                                                &folder);
-    });
-    // clang-format on
+    using Impl = LinalgOpToLoopsImpl<LoopType, IndexedValueType, ConcreteOp>;
+    if (failed(Impl::doit(op, rewriter)))
+      return matchFailure();
     rewriter.eraseOp(op);
     return matchSuccess();
   }
-
-  mutable OperationFolder folder;
 };
 
 // Helper classes for type list expansion.
 template <typename LoopType, typename IndexedValueType, typename... LinalgOps>
-class ConversionList;
+class RewritePatternList;
 
 template <typename LoopType, typename IndexedValueType>
-class ConversionList<LoopType, IndexedValueType> {
+class RewritePatternList<LoopType, IndexedValueType> {
 public:
   static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
 };
 
 template <typename LoopType, typename IndexedValueType, typename ConcreteOp,
           typename... LinalgOps>
-class ConversionList<LoopType, IndexedValueType, ConcreteOp, LinalgOps...> {
+class RewritePatternList<LoopType, IndexedValueType, ConcreteOp, LinalgOps...> {
 public:
   static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
     patterns
         .insert<LinalgRewritePattern<LoopType, IndexedValueType, ConcreteOp>>(
             ctx);
-    ConversionList<LoopType, IndexedValueType, LinalgOps...>::build(patterns,
-                                                                    ctx);
+    RewritePatternList<LoopType, IndexedValueType, LinalgOps...>::build(
+        patterns, ctx);
   }
 };
 
 /// Populate the given list with patterns that convert from Linalg to LLVM.
 template <typename LoopType, typename IndexedValueType>
-void ForOpRewritePatterns(OwningRewritePatternList &patterns,
-                          MLIRContext *ctx) {
-  ConversionList<LoopType, IndexedValueType,
+void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  RewritePatternList<LoopType, IndexedValueType,
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc"
-                 >::build(patterns, ctx);
+                     >::build(patterns, ctx);
 }
 
 namespace {
@@ -491,28 +499,114 @@ struct LowerLinalgToLoopsPass
 };
 } // namespace
 
+// Local folding pattern for AffineApplyOp that we can apply greedily.
+// This replaces AffineApplyOp by the proper value in cases where the
+// associated map is trivial. A trivial map here is defined as a map with a
+// single result and either:
+//   1. Zero operand + returns a single AffineConstantExpr
+//   2. One operand + returns a single AffineDimExpr
+//   3. One operand + returns a single AffineSymbolExpr
+//
+// In the first case, the AffineApplyOp is replaced by a new constant. In the
+// other cases, it is replaced by its unique operand.
+struct FoldAffineOp : public RewritePattern {
+  FoldAffineOp(MLIRContext *context)
+      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
+    auto map = affineApplyOp.getAffineMap();
+    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
+      return matchFailure();
+
+    AffineExpr expr = map.getResult(0);
+    if (map.getNumInputs() == 0) {
+      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
+        rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
+        return matchSuccess();
+      }
+      return matchFailure();
+    }
+    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
+      rewriter.replaceOp(op, op->getOperand(0));
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
+
 template <typename LoopType, typename IndexedValueType>
 void LowerLinalgToLoopsPass<LoopType, IndexedValueType>::runOnFunction() {
+  auto *context = &this->getContext();
   OwningRewritePatternList patterns;
-  ForOpRewritePatterns<LoopType, IndexedValueType>(patterns,
-                                                   &this->getContext());
-
-  ConversionTarget target(this->getContext());
-  target.addLegalDialect<AffineOpsDialect>();
-  target.addLegalDialect<loop::LoopOpsDialect>();
-  target.addLegalDialect<StandardOpsDialect>();
-  if (failed(applyPartialConversion(this->getFunction(), target, patterns))) {
-    this->signalPassFailure();
-  }
+  // Canonicalization and folding patterns applied greedily allow cleaning up
+  // the emitted IR on the fly.
+  // TODO(ntv) fold view and subview ops?
+  FillRewritePatterns<LoopType, IndexedValueType>(patterns, context);
+  DimOp::getCanonicalizationPatterns(patterns, context);
+  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+  patterns.insert<FoldAffineOp>(context);
+  // Just apply the patterns greedily.
+  applyPatternsGreedily(this->getFunction(), patterns);
 }
 
+/// Create a pass to convert Linalg operations to loop.for loops and
+/// std.load/std.store accesses.
 std::unique_ptr<OpPassBase<FuncOp>>
-mlir::linalg::createLowerLinalgToLoopsPass() {
+mlir::linalg::createConvertLinalgToLoopsPass() {
   return std::make_unique<
       LowerLinalgToLoopsPass<loop::ForOp, IndexedStdValue>>();
 }
 
+/// Create a pass to convert Linalg operations to affine.for loops and
+/// affine_load/affine_store accesses.
+/// Placeholder for now, this is NYI.
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::linalg::createConvertLinalgToAffineLoopsPass() {
+  return std::make_unique<
+      LowerLinalgToLoopsPass<AffineForOp, IndexedAffineValue>>();
+}
+
+// Emits a loop nest of `loop.for` with the proper body for `op`.
+template <typename ConcreteOp>
+LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
+                                            Operation *op) {
+  return LinalgOpToLoopsImpl<loop::ForOp, IndexedStdValue, ConcreteOp>::doit(
+      op, rewriter);
+}
+
+// Emits a loop nest of `affine.for` with the proper body for `op`.
+template <typename ConcreteOp>
+LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
+                                                  Operation *op) {
+  return LinalgOpToLoopsImpl<AffineForOp, IndexedAffineValue, ConcreteOp>::doit(
+      op, rewriter);
+}
+
+// TODO(ntv) Need to make these instantiations more future-proof to avoid the
+// need to update as soon as we add new ops.
+#define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE)                               \
+  template LogicalResult mlir::linalg::linalgOpToLoops<OP_TYPE>(              \
+      PatternRewriter & rewriter, Operation * op);                            \
+  template LogicalResult mlir::linalg::linalgOpToAffineLoops<OP_TYPE>(        \
+      PatternRewriter & rewriter, Operation * op);
+
+INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(DotOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp)
+
 static PassRegistration<LowerLinalgToLoopsPass<loop::ForOp, IndexedStdValue>>
     structuredLoopsPass(
-        "linalg-lower-to-loops",
+        "convert-linalg-to-loops",
         "Lower the operations from the linalg dialect into loops");
+
+static PassRegistration<LowerLinalgToLoopsPass<AffineForOp, IndexedAffineValue>>
+    affineLoopsPass(
+        "convert-linalg-to-affine-loops",
+        "Lower the operations from the linalg dialect into affine loops");
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 769141008bf..dd19d5d82cd 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -convert-linalg-to-llvm | FileCheck %s
-// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm | FileCheck %s --check-prefix=LLVM-LOOPS
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | FileCheck %s --check-prefix=LLVM-LOOPS
 
 func @range(%arg0: index) {
   %c0 = constant 0 : index
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 93cf69fe3ca..9a1c91d09e0 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-lower-to-loops | FileCheck %s
+// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s
 
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // RUN: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 2561cea04d3..d94342f52ca 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -17,7 +17,13 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
 // CHECK-DAG:   %[[c8000:.*]] = constant 8000 : index
 // CHECK:       loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] {
 // CHECK:         loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] {
-// CHECK:           linalg.dot({{.*}}, {{.*}}, {{.*}}) : memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<f32>
+// CHECK:           loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c1]] {
+// CHECK:             load
+// CHECK:             load
+// CHECK:             mulf
+// CHECK:             load
+// CHECK:             addf
+// CHECK:             store
 
 func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %x: memref<?xf32, offset: ?, strides: [1]>,
diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
index 839671c866a..97e0cb21704 100644
--- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
@@ -73,4 +73,11 @@ def : Pattern<(DotOp:$op $a, $b, $c),
           [(TileLinalgOp<[8], "REG"> $op)],
           [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
 
+//===----------------------------------------------------------------------===//
+// Linalg to loops patterns.
+//===----------------------------------------------------------------------===//
+def : Pattern<(DotOp:$op $a, $b, $c),
+          [(LinalgOpToLoops<"DotOp"> $op)],
+          [(Constraint<HasLinalgTransformMarker<"REG">> $op)]>;
+
 #endif // TEST_LINALG_TRANSFORMS_PATTERNS
diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
index b153faea5e6..d1ee472850a 100644
--- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
+++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
@@ -1,8 +1,8 @@
 // RUN: mlir-opt %s -convert-linalg-to-llvm | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
-// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
 // RUN: mlir-opt %s -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
-// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
-// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-promote-subviews -linalg-lower-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-promote-subviews -convert-linalg-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
 // RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-promote-subviews -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
 
 #strided1D = (d0) -> (d0)
diff --git a/mlir/test/mlir-cpu-runner/utils.mlir b/mlir/test/mlir-cpu-runner/utils.mlir
index 798c53959e5..ed54b902683 100644
--- a/mlir/test/mlir-cpu-runner/utils.mlir
+++ b/mlir/test/mlir-cpu-runner/utils.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e print_0d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-0D
-// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e print_1d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-1D
-// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e print_3d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-3D
-// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e vector_splat_2d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-VECTOR-SPLAT-2D
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e print_0d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-0D
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e print_1d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-1D
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e print_3d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-3D
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e vector_splat_2d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-VECTOR-SPLAT-2D
 
 func @print_0d() {
   %f = constant 2.00000e+00 : f32
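Closing note on the cleanup machinery: because the pass now runs applyPatternsGreedily instead of a dialect conversion, the FoldAffineOp pattern above is what removes the trivial affine.apply ops left behind by makeCanonicalAffineApplies. A condensed restatement of its "trivial map" test (a standalone sketch for clarity, with a hypothetical helper name, not a new API in this commit):

    // Returns true iff FoldAffineOp would fold the affine.apply with this map:
    // a single-result map that is a bare constant, dimension, or symbol.
    static bool isTrivialAffineApply(mlir::AffineMap map) {
      if (map.getNumResults() != 1 || map.getNumInputs() > 1)
        return false;
      mlir::AffineExpr expr = map.getResult(0);
      if (map.getNumInputs() == 0)
        return expr.isa<mlir::AffineConstantExpr>(); // folds to a constant
      // One operand: the op folds to that unique operand.
      return expr.isa<mlir::AffineDimExpr>() ||
             expr.isa<mlir::AffineSymbolExpr>();
    }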

