summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Passes.h11
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td11
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h41
-rw-r--r--mlir/lib/Dialect/Linalg/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp (renamed from mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp)314
-rw-r--r--mlir/test/Dialect/Linalg/llvm.mlir2
-rw-r--r--mlir/test/Dialect/Linalg/loops.mlir2
-rw-r--r--mlir/test/Dialect/Linalg/transform-patterns.mlir8
-rw-r--r--mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td7
-rw-r--r--mlir/test/mlir-cpu-runner/linalg_integration_test.mlir6
-rw-r--r--mlir/test/mlir-cpu-runner/utils.mlir8
11 files changed, 275 insertions, 137 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 5ecd50070da..7ae3877f01e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -39,9 +39,16 @@ createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
std::unique_ptr<OpPassBase<FuncOp>>
createLinalgPromotionPass(bool dynamicBuffers);
-std::unique_ptr<OpPassBase<FuncOp>> createLowerLinalgToLoopsPass();
+/// Create a pass to convert Linalg operations to loop.for loops and
+/// std.load/std.store accesses.
+std::unique_ptr<OpPassBase<FuncOp>> createConvertLinalgToLoopsPass();
-/// Create a pass to convert vector operations to the LLVMIR dialect.
+/// Create a pass to convert Linalg operations to affine.for loops and
+/// affine_load/affine_store accesses.
+/// Placeholder for now, this is NYI.
+std::unique_ptr<OpPassBase<FuncOp>> createConvertLinalgToAffineLoopsPass();
+
+/// Create a pass to convert Linalg operations to the LLVMIR dialect.
std::unique_ptr<OpPassBase<ModuleOp>> createConvertLinalgToLLVMPass();
} // namespace linalg
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
index d243bb23f2c..8bc0eaf2097 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
@@ -62,4 +62,15 @@ class TileLinalgOp<list<int> sizes, string value> : NativeCodeCall<
StrJoinInt<sizes>.result # "}, \"" # value # "\")))" #
" return matchFailure();">;
+//===----------------------------------------------------------------------===//
+// Linalg to loop patterns.
+//===----------------------------------------------------------------------===//
+class LinalgOpToLoops<string OpType> : NativeCodeCall<
+ "if (failed(linalgOpToLoops<" # OpType # ">($_builder, $0))) " #
+ " return matchFailure();">;
+
+class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
+ "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " #
+ " return matchFailure();">;
+
#endif // LINALG_TRANSFORMS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
index 56ae94f32c6..966b8f93135 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
@@ -35,20 +35,6 @@ struct LinalgTransforms {
static const StringLiteral kLinalgTransformMarker;
};
-// Declarative transformation used in tablegen patterns.
-// Tiles `op` by `sizes` and sets the attribute `kLinalgTransformMarker` to
-// `linalgMarker`.
-LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op,
- ArrayRef<int64_t> sizes,
- StringRef linalgMarker);
-
-// Declarative transformation used in tablegen patterns.
-// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets
-// the attribute `kLinalgTransformMarker` to `linalgMarker`.
-LogicalResult tileAndFuseLinalgOpAndSetMarker(
- PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
-
namespace detail {
// Implementation detail of isProducedByOpOfType avoids the need for explicit
// template instantiations.
@@ -65,6 +51,33 @@ bool isProducedByOpOfType(Operation *consumerOp, Value *consumedView) {
consumerOp, consumedView, [](Operation *op) { return isa<OpTy>(op); });
}
+////////////////////////////////////////////////////////////////////////////////
+// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite
+// patterns. As such, they must not call into `rewriter.erase/replace` APIs and
+// it is the responsibility of the enclosing PatternRewriter to erase on
+// success.
+////////////////////////////////////////////////////////////////////////////////
+
+// Tiles `op` by `sizes` and sets the attribute `kLinalgTransformMarker` to
+// `linalgMarker`.
+LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op,
+ ArrayRef<int64_t> sizes,
+ StringRef linalgMarker);
+
+// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets
+// the attribute `kLinalgTransformMarker` to `linalgMarker`.
+LogicalResult tileAndFuseLinalgOpAndSetMarker(
+ PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
+ ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
+
+// Emits a loop nest of `loop.for` with the proper body for `op`.
+template <typename ConcreteOp>
+LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op);
+
+// Emits a loop nest of `affine.for` with the proper body for `op`.
+template <typename ConcreteOp>
+LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/CMakeLists.txt b/mlir/lib/Dialect/Linalg/CMakeLists.txt
index 4b7cd81be94..a4ce5038891 100644
--- a/mlir/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/CMakeLists.txt
@@ -5,7 +5,7 @@ add_llvm_library(MLIRLinalg
IR/LinalgTypes.cpp
Transforms/Fusion.cpp
Transforms/LinalgTransforms.cpp
- Transforms/LowerToLoops.cpp
+ Transforms/LinalgToLoops.cpp
Transforms/Promotion.cpp
Transforms/Tiling.cpp
Utils/Utils.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 0bf4ceaa33b..cf0b235f57f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
@@ -41,12 +42,14 @@ using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics;
using IndexedStdValue = TemplatedIndexedValue<std_load, std_store>;
+using IndexedAffineValue = TemplatedIndexedValue<affine_load, affine_store>;
+
using edsc::op::operator+;
using edsc::op::operator==;
static SmallVector<ValueHandle, 8>
-foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
- ArrayRef<Value *> vals, OperationFolder *folder) {
+makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
+ ArrayRef<Value *> vals) {
assert(map.getNumSymbols() == 0);
assert(map.getNumInputs() == vals.size());
SmallVector<ValueHandle, 8> res;
@@ -56,17 +59,16 @@ foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
auto exprMap = AffineMap::get(dims, 0, e);
SmallVector<Value *, 4> operands(vals.begin(), vals.end());
canonicalizeMapAndOperands(&exprMap, &operands);
- res.push_back(affine_apply(folder, exprMap, operands));
+ res.push_back(affine_apply(exprMap, operands));
}
return res;
}
static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
- Optional<AffineMap> permutation,
- OperationFolder *folder) {
+ Optional<AffineMap> permutation) {
return permutation ? applyMapToValues(ScopedContext::getBuilder(),
ScopedContext::getLocation(),
- permutation.getValue(), ivs, folder)
+ permutation.getValue(), ivs)
: SmallVector<Value *, 4>(ivs.begin(), ivs.end());
}
@@ -75,20 +77,17 @@ static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
// which new loops will be created.
static SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
AffineMap map,
- ArrayRef<Value *> allViewSizes,
- OperationFolder *folder);
+ ArrayRef<Value *> allViewSizes);
SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
AffineMap map,
- ArrayRef<Value *> allViewSizes,
- OperationFolder *folder) {
+ ArrayRef<Value *> allViewSizes) {
// Apply `map` to get view sizes in loop order.
- auto sizes = applyMapToValues(b, loc, map, allViewSizes, folder);
+ auto sizes = applyMapToValues(b, loc, map, allViewSizes);
// Create a new range with the applied tile sizes.
ScopedContext scope(b, loc);
SmallVector<Value *, 4> res;
for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
- res.push_back(range(constant_index(folder, 0), sizes[idx],
- constant_index(folder, 1)));
+ res.push_back(range(constant_index(0), sizes[idx], constant_index(1)));
}
return res;
}
@@ -99,14 +98,14 @@ class LinalgScopedEmitter {};
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, CopyOp> {
public:
- static void emitScalarImplementation(ArrayRef<Value *> allIvs, CopyOp copyOp,
- OperationFolder *folder) {
+ static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+ CopyOp copyOp) {
auto nPar = copyOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto inputIvs =
- permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation(), folder);
+ permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
auto outputIvs =
- permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation(), folder);
+ permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0));
@@ -122,8 +121,8 @@ public:
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, FillOp> {
public:
- static void emitScalarImplementation(ArrayRef<Value *> allIvs, FillOp fillOp,
- OperationFolder *folder) {
+ static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+ FillOp fillOp) {
auto nPar = fillOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto ivs =
@@ -139,8 +138,7 @@ public:
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, DotOp> {
public:
- static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp,
- OperationFolder *folder) {
+ static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp) {
assert(allIvs.size() == 1);
IndexHandle r_i(allIvs[0]);
IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
@@ -154,8 +152,7 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, MatvecOp> {
public:
static void emitScalarImplementation(ArrayRef<Value *> allIvs,
- MatvecOp matvecOp,
- OperationFolder *folder) {
+ MatvecOp matvecOp) {
assert(allIvs.size() == 2);
IndexHandle i(allIvs[0]), r_j(allIvs[1]);
IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
@@ -169,8 +166,7 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
public:
static void emitScalarImplementation(ArrayRef<Value *> allIvs,
- MatmulOp matmulOp,
- OperationFolder *folder) {
+ MatmulOp matmulOp) {
assert(allIvs.size() == 3);
IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
@@ -183,17 +179,17 @@ public:
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, ConvOp> {
public:
- static void emitScalarImplementation(ArrayRef<Value *> allIvs, ConvOp convOp,
- OperationFolder *folder) {
+ static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+ ConvOp convOp) {
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
auto maps = loopToOperandRangesMaps(convOp);
SmallVector<ValueHandle, 8> fIdx(
- foldedAffineApplies(b, loc, maps[0], allIvs, folder));
+ makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
SmallVector<ValueHandle, 8> imIdx(
- foldedAffineApplies(b, loc, maps[1], allIvs, folder));
+ makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
SmallVector<ValueHandle, 8> oIdx(
- foldedAffineApplies(b, loc, maps[2], allIvs, folder));
+ makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output());
// Emit scalar form.
O(oIdx) += F(fIdx) * I(imIdx);
@@ -234,8 +230,7 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, GenericOp> {
public:
static void emitScalarImplementation(ArrayRef<Value *> allIvs,
- GenericOp genericOp,
- OperationFolder *folder) {
+ GenericOp genericOp) {
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
using edsc::intrinsics::detail::ValueHandleArray;
@@ -245,15 +240,15 @@ public:
// 1.a. Emit std_load from input views.
for (unsigned i = 0; i < nInputs; ++i) {
- ValueHandleArray indexing(foldedAffineApplies(
- b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, genericOp.getInputIndexingMap(i), allIvs));
indexedValues[i] = std_load(genericOp.getInput(i), indexing);
}
// 1.b. Emit std_load from output views.
for (unsigned i = 0; i < nOutputs; ++i) {
- ValueHandleArray indexing(foldedAffineApplies(
- b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, genericOp.getOutputIndexingMap(i), allIvs));
indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing);
}
@@ -265,8 +260,8 @@ public:
// 3. Emit std_store.
for (unsigned i = 0; i < nOutputs; ++i) {
- ValueHandleArray indexing(foldedAffineApplies(
- b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, genericOp.getOutputIndexingMap(i), allIvs));
std_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
}
return;
@@ -288,8 +283,8 @@ public:
auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
assert(yieldOp->getNumOperands() == nOutputs);
for (unsigned i = 0; i < nOutputs; ++i) {
- ValueHandleArray indexing(foldedAffineApplies(
- b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, genericOp.getOutputIndexingMap(i), allIvs));
std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
indexing);
}
@@ -330,8 +325,7 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
public:
static void emitScalarImplementation(ArrayRef<Value *> allIvs,
- IndexedGenericOp indexedGenericOp,
- OperationFolder *folder) {
+ IndexedGenericOp indexedGenericOp) {
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
using edsc::intrinsics::detail::ValueHandleArray;
@@ -346,16 +340,16 @@ public:
// 1.a. Emit std_load from input views.
for (unsigned i = 0; i < nInputs; ++i) {
- ValueHandleArray indexing(foldedAffineApplies(
- b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs, folder));
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
indexedValues[nLoops + i] =
std_load(indexedGenericOp.getInput(i), indexing);
}
// 1.b. Emit std_load from output views.
for (unsigned i = 0; i < nOutputs; ++i) {
- ValueHandleArray indexing(foldedAffineApplies(
- b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
indexedValues[nLoops + nInputs + i] =
std_load(indexedGenericOp.getOutput(i), indexing);
}
@@ -367,8 +361,8 @@ public:
// 3. Emit std_store.
for (unsigned i = 0; i < nOutputs; ++i) {
- ValueHandleArray indexing(foldedAffineApplies(
- b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
std_store(callOp->getResult(i), indexedGenericOp.getOutput(i),
indexing);
}
@@ -391,96 +385,110 @@ public:
auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
assert(yieldOp->getNumOperands() == nOutputs);
for (unsigned i = 0; i < nOutputs; ++i) {
- ValueHandleArray indexing(foldedAffineApplies(
- b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
std_store(map.lookup(yieldOp->getOperand(i)),
indexedGenericOp.getOutput(i), indexing);
}
}
};
+namespace {
+// This struct is for factoring out the implementation and support template
+// instantiations in the following 2 cases:
+// 1. Appending to a list of patterns via RewritePatternList.
+// 2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`.
+// The implementation must work both in DRR and inside a RewritePattern. As a
+// consequence, (1) it is only allowed to emit new ops if the match is
+// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an
+// encompassing pattern must take care of the erasure logic.
+template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
+class LinalgOpToLoopsImpl {
+public:
+ static LogicalResult doit(Operation *op, PatternRewriter &rewriter);
+};
+} // namespace
+
+template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
+LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
+ Operation *op, PatternRewriter &rewriter) {
+ OpBuilder b(op);
+ ScopedContext scope(b, op->getLoc());
+
+ // The flattened loopToOperandRangesMaps is expected to be an invertible
+ // permutation map (which is asserted in the inverse calculation).
+ auto linalgOp = cast<ConcreteOpTy>(op);
+ auto invertedMap =
+ inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
+ if (!invertedMap) {
+ LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
+ {}, linalgOp);
+ return success();
+ }
+
+ auto nPar = linalgOp.getNumParallelLoops();
+ auto nRed = linalgOp.getNumReductionLoops();
+ auto nWin = linalgOp.getNumWindowLoops();
+ SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
+ SmallVector<ValueHandle *, 4> allPIvs = makeIndexHandlePointers(allIvs);
+ auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(),
+ invertedMap, getViewSizes(linalgOp));
+ assert(loopRanges.size() == allIvs.size());
+
+ LoopNestRangeBuilder(allPIvs, loopRanges)([&] {
+ auto allIvValues = extractValues(allIvs);
+ LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
+ allIvValues, linalgOp);
+ });
+ return success();
+}
+
template <typename LoopType, typename IndexedValueType, typename ConcreteOp>
class LinalgRewritePattern : public RewritePattern {
public:
explicit LinalgRewritePattern(MLIRContext *context)
- : RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1, context),
- folder(context) {}
+ : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- OpBuilder b(op);
- ScopedContext scope(b, op->getLoc());
-
- // The flattened loopToOperandRangesMaps is expected to be an invertible
- // permutation map (which is asserted in the inverse calculation).
- auto linalgOp = cast<ConcreteOp>(op);
- auto invertedMap =
- inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
- if (!invertedMap) {
- LinalgScopedEmitter<IndexedValueType,
- ConcreteOp>::emitScalarImplementation({}, linalgOp,
- &folder);
- rewriter.eraseOp(op);
- return matchSuccess();
- }
-
- auto nPar = linalgOp.getNumParallelLoops();
- auto nRed = linalgOp.getNumReductionLoops();
- auto nWin = linalgOp.getNumWindowLoops();
- SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
- SmallVector<ValueHandle *, 4> allPIvs = makeIndexHandlePointers(allIvs);
- auto loopRanges =
- emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
- getViewSizes(linalgOp), &folder);
- assert(loopRanges.size() == allIvs.size());
-
- // clang-format off;
- LoopNestRangeBuilder(allPIvs, loopRanges)([&] {
- auto allIvValues = extractValues(allIvs);
- LinalgScopedEmitter<IndexedValueType,
- ConcreteOp>::emitScalarImplementation(allIvValues,
- linalgOp,
- &folder);
- });
- // clang-format on
+ using Impl = LinalgOpToLoopsImpl<LoopType, IndexedValueType, ConcreteOp>;
+ if (failed(Impl::doit(op, rewriter)))
+ return matchFailure();
rewriter.eraseOp(op);
return matchSuccess();
}
-
- mutable OperationFolder folder;
};
// Helper classes for type list expansion.
template <typename LoopType, typename IndexedValueType, typename... LinalgOps>
-class ConversionList;
+class RewritePatternList;
template <typename LoopType, typename IndexedValueType>
-class ConversionList<LoopType, IndexedValueType> {
+class RewritePatternList<LoopType, IndexedValueType> {
public:
static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
};
template <typename LoopType, typename IndexedValueType, typename ConcreteOp,
typename... LinalgOps>
-class ConversionList<LoopType, IndexedValueType, ConcreteOp, LinalgOps...> {
+class RewritePatternList<LoopType, IndexedValueType, ConcreteOp, LinalgOps...> {
public:
static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns
.insert<LinalgRewritePattern<LoopType, IndexedValueType, ConcreteOp>>(
ctx);
- ConversionList<LoopType, IndexedValueType, LinalgOps...>::build(patterns,
- ctx);
+ RewritePatternList<LoopType, IndexedValueType, LinalgOps...>::build(
+ patterns, ctx);
}
};
/// Populate the given list with patterns that convert from Linalg to LLVM.
template <typename LoopType, typename IndexedValueType>
-void ForOpRewritePatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
- ConversionList<LoopType, IndexedValueType,
+void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ RewritePatternList<LoopType, IndexedValueType,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc"
- >::build(patterns, ctx);
+ >::build(patterns, ctx);
}
namespace {
@@ -491,28 +499,114 @@ struct LowerLinalgToLoopsPass
};
} // namespace
+// Local folding pattern for AffineApplyOp that we can apply greedily.
+// This replaces AffineApplyOp by the proper value in cases where the associated
+// map is trivial. A trivial map here is defined as a map with a single result
+// and either:
+// 1. Zero operand + returns a single AffineConstantExpr
+// 2. One operand + returns a single AffineDimExpr
+// 3. One operands + returns a single AffineSymbolExpr
+//
+// In the first case, the AffineApplyOp is replaced by a new constant. In the
+// other cases, it is replaced by its unique operand.
+struct FoldAffineOp : public RewritePattern {
+ FoldAffineOp(MLIRContext *context)
+ : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}
+
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
+ auto map = affineApplyOp.getAffineMap();
+ if (map.getNumResults() != 1 || map.getNumInputs() > 1)
+ return matchFailure();
+
+ AffineExpr expr = map.getResult(0);
+ if (map.getNumInputs() == 0) {
+ if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
+ rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
+ return matchSuccess();
+ }
+ return matchFailure();
+ }
+ if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
+ rewriter.replaceOp(op, op->getOperand(0));
+ return matchSuccess();
+ }
+ return matchFailure();
+ }
+};
+
template <typename LoopType, typename IndexedValueType>
void LowerLinalgToLoopsPass<LoopType, IndexedValueType>::runOnFunction() {
+ auto *context = &this->getContext();
OwningRewritePatternList patterns;
- ForOpRewritePatterns<LoopType, IndexedValueType>(patterns,
- &this->getContext());
-
- ConversionTarget target(this->getContext());
- target.addLegalDialect<AffineOpsDialect>();
- target.addLegalDialect<loop::LoopOpsDialect>();
- target.addLegalDialect<StandardOpsDialect>();
- if (failed(applyPartialConversion(this->getFunction(), target, patterns))) {
- this->signalPassFailure();
- }
+ // Canonicalization and folding patterns applied greedily allow cleaning up
+ // the emitted IR on the fly.
+ // TODO(ntv) fold view and subview ops?
+ FillRewritePatterns<LoopType, IndexedValueType>(patterns, context);
+ DimOp::getCanonicalizationPatterns(patterns, context);
+ AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+ patterns.insert<FoldAffineOp>(context);
+ // Just apply the patterns greedily.
+ applyPatternsGreedily(this->getFunction(), patterns);
}
+/// Create a pass to convert Linalg operations to loop.for loops and
+/// std.load/std.store accesses.
std::unique_ptr<OpPassBase<FuncOp>>
-mlir::linalg::createLowerLinalgToLoopsPass() {
+mlir::linalg::createConvertLinalgToLoopsPass() {
return std::make_unique<
LowerLinalgToLoopsPass<loop::ForOp, IndexedStdValue>>();
}
+/// Create a pass to convert Linalg operations to affine.for loops and
+/// affine_load/affine_store accesses.
+/// Placeholder for now, this is NYI.
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::linalg::createConvertLinalgToAffineLoopsPass() {
+ return std::make_unique<
+ LowerLinalgToLoopsPass<AffineForOp, IndexedAffineValue>>();
+}
+
+// Emits a loop nest of `loop.for` with the proper body for `op`.
+template <typename ConcreteOp>
+LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
+ Operation *op) {
+ return LinalgOpToLoopsImpl<loop::ForOp, IndexedStdValue, ConcreteOp>::doit(
+ op, rewriter);
+}
+
+// Emits a loop nest of `affine.for` with the proper body for `op`.
+template <typename ConcreteOp>
+LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
+ Operation *op) {
+ return LinalgOpToLoopsImpl<AffineForOp, IndexedAffineValue, ConcreteOp>::doit(
+ op, rewriter);
+}
+
+// TODO(ntv) Need to make these instantiations more future-proof to avoid the
+// need to update as soon as we add new ops.
+#define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE) \
+ template LogicalResult mlir::linalg::linalgOpToLoops<OP_TYPE>( \
+ PatternRewriter & rewriter, Operation * op); \
+ template LogicalResult mlir::linalg::linalgOpToAffineLoops<OP_TYPE>( \
+ PatternRewriter & rewriter, Operation * op);
+
+INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(DotOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp)
+
static PassRegistration<LowerLinalgToLoopsPass<loop::ForOp, IndexedStdValue>>
structuredLoopsPass(
- "linalg-lower-to-loops",
+ "convert-linalg-to-loops",
"Lower the operations from the linalg dialect into loops");
+
+static PassRegistration<LowerLinalgToLoopsPass<AffineForOp, IndexedAffineValue>>
+ affineLoopsPass(
+ "convert-linalg-to-affine-loops",
+ "Lower the operations from the linalg dialect into affine loops");
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 769141008bf..dd19d5d82cd 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -convert-linalg-to-llvm | FileCheck %s
-// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm | FileCheck %s --check-prefix=LLVM-LOOPS
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | FileCheck %s --check-prefix=LLVM-LOOPS
func @range(%arg0: index) {
%c0 = constant 0 : index
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 93cf69fe3ca..9a1c91d09e0 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-lower-to-loops | FileCheck %s
+// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s
// Test that we can lower all the way to LLVM without crashing, don't check results here.
// RUN: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 2561cea04d3..d94342f52ca 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -17,7 +17,13 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
// CHECK-DAG : %[[c8000:.*]] = constant 8000 : index
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] {
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] {
-// CHECK : linalg.dot({{.*}}, {{.*}}, {{.*}}) : memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<f32>
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c1]] {
+// CHECK : load
+// CHECK : load
+// CHECK : mulf
+// CHECK : load
+// CHECK : addf
+// CHECK : store
func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%x: memref<?xf32, offset: ?, strides: [1]>,
diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
index 839671c866a..97e0cb21704 100644
--- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
@@ -73,4 +73,11 @@ def : Pattern<(DotOp:$op $a, $b, $c),
[(TileLinalgOp<[8], "REG"> $op)],
[(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
+//===----------------------------------------------------------------------===//
+// Linalg to loops patterns.
+//===----------------------------------------------------------------------===//
+def : Pattern<(DotOp:$op $a, $b, $c),
+ [(LinalgOpToLoops<"DotOp"> $op)],
+ [(Constraint<HasLinalgTransformMarker<"REG">> $op)]>;
+
#endif // TEST_LINALG_TRANSFORMS_PATTERNS
diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
index b153faea5e6..d1ee472850a 100644
--- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
+++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
@@ -1,8 +1,8 @@
// RUN: mlir-opt %s -convert-linalg-to-llvm | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
-// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
// RUN: mlir-opt %s -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
-// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
-// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-promote-subviews -linalg-lower-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-promote-subviews -convert-linalg-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-promote-subviews -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
#strided1D = (d0) -> (d0)
diff --git a/mlir/test/mlir-cpu-runner/utils.mlir b/mlir/test/mlir-cpu-runner/utils.mlir
index 798c53959e5..ed54b902683 100644
--- a/mlir/test/mlir-cpu-runner/utils.mlir
+++ b/mlir/test/mlir-cpu-runner/utils.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e print_0d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-0D
-// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e print_1d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-1D
-// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e print_3d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-3D
-// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e vector_splat_2d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-VECTOR-SPLAT-2D
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e print_0d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-0D
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e print_1d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-1D
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e print_3d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-3D
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e vector_splat_2d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-VECTOR-SPLAT-2D
func @print_0d() {
%f = constant 2.00000e+00 : f32
OpenPOWER on IntegriCloud