summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Transforms/MLPatternLoweringPass.h64
-rw-r--r--mlir/lib/Transforms/LowerVectorTransfers.cpp13
2 files changed, 28 insertions, 49 deletions
diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h
index 15e4f215c61..c9ed3a38a65 100644
--- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h
+++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h
@@ -81,41 +81,7 @@ namespace detail {
/// Owning list of ML lowering patterns.
using OwningMLLoweringPatternList =
std::vector<std::unique_ptr<mlir::MLLoweringPattern>>;
-} // namespace detail
-
-/// Generic lowering pass for ML functions. The lowering details are defined as
-/// a sequence of pattern matchers. The following constraints on matchers
-/// apply:
-/// - only one (match root) operation can be removed;
-/// - the code produced by rewriters is final, it is not pattern-matched;
-/// - the matchers are applied in their order of appearance in the list;
-/// - if the match is found, the operation is rewritten immediately and the
-/// next _original_ operation is considered.
-/// In other words, for each operation, the pass applies the first matching
-/// rewriter in the list and advances to the (lexically) next operation.
-/// Non-operation instructions (ForInst) are ignored.
-/// This is similar to greedy worklist-based pattern rewriter, except that this
-/// operates on ML functions using an ML builder and does not maintain the work
-/// list. Note that, as of the time of writing, worklist-based rewriter did not
-/// support removing multiple operations either.
-template <typename... Patterns>
-class MLPatternLoweringPass : public FunctionPass {
-public:
- explicit MLPatternLoweringPass(const PassID *ID) : FunctionPass(ID) {}
-
- virtual std::unique_ptr<MLFuncGlobalLoweringState>
- makeFuncWiseState(Function *f) const {
- return nullptr;
- }
- PassResult runOnFunction(Function *f) override;
-};
-
-/////////////////////////////////////////////////////////////////////
-// MLPatternLoweringPass template implementations
-/////////////////////////////////////////////////////////////////////
-
-namespace detail {
template <typename Pattern, typename... Patterns> struct ListAdder {
static void addPatternsToList(OwningMLLoweringPatternList *list,
MLIRContext *context) {
@@ -134,11 +100,25 @@ template <typename Pattern> struct ListAdder<Pattern> {
};
} // namespace detail
+/// Generic lowering for ML patterns. The lowering details are defined as
+/// a sequence of pattern matchers. The following constraints on matchers
+/// apply:
+/// - only one (match root) operation can be removed;
+/// - the code produced by rewriters is final, it is not pattern-matched;
+/// - the matchers are applied in their order of appearance in the list;
+/// - if the match is found, the operation is rewritten immediately and the
+/// next _original_ operation is considered.
+/// In other words, for each operation, apply the first matching rewriter in the
+/// list and advance to the (lexically) next operation. This is similar to
+/// greedy worklist-based pattern rewriter, except that this operates on ML
+/// functions using an ML builder and does not maintain the work list. Note
+/// that, as of the time of writing, worklist-based rewriter did not support
+/// removing multiple operations either.
template <typename... Patterns>
-PassResult MLPatternLoweringPass<Patterns...>::runOnFunction(Function *f) {
+void applyMLPatternsGreedily(
+ Function *f, MLFuncGlobalLoweringState *funcWiseState = nullptr) {
detail::OwningMLLoweringPatternList patterns;
detail::ListAdder<Patterns...>::addPatternsToList(&patterns, f->getContext());
- auto funcWiseState = makeFuncWiseState(f);
FuncBuilder builder(f);
MLFuncLoweringRewriter rewriter(&builder);
@@ -148,19 +128,15 @@ PassResult MLPatternLoweringPass<Patterns...>::runOnFunction(Function *f) {
for (Instruction *inst : ops) {
for (const auto &pattern : patterns) {
- rewriter.getBuilder()->setInsertionPoint(inst);
- auto matchResult = pattern->match(inst);
- if (matchResult) {
- pattern->rewriteOpInst(inst, funcWiseState.get(),
- std::move(*matchResult), &rewriter);
+ builder.setInsertionPoint(inst);
+ if (auto matchResult = pattern->match(inst)) {
+ pattern->rewriteOpInst(inst, funcWiseState, std::move(*matchResult),
+ &rewriter);
break;
}
}
}
-
- return PassResult::Success;
}
-
} // end namespace mlir
#endif // MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H
diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp
index ac8f7e064f5..61f75ae76e6 100644
--- a/mlir/lib/Transforms/LowerVectorTransfers.cpp
+++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -424,12 +424,15 @@ public:
}
};
-struct LowerVectorTransfersPass
- : public MLPatternLoweringPass<
- VectorTransferExpander<VectorTransferReadOp>,
- VectorTransferExpander<VectorTransferWriteOp>> {
+struct LowerVectorTransfersPass : public FunctionPass {
LowerVectorTransfersPass()
- : MLPatternLoweringPass(&LowerVectorTransfersPass::passID) {}
+ : FunctionPass(&LowerVectorTransfersPass::passID) {}
+
+ PassResult runOnFunction(Function *fn) override {
+ applyMLPatternsGreedily<VectorTransferExpander<VectorTransferReadOp>,
+ VectorTransferExpander<VectorTransferWriteOp>>(fn);
+ return success();
+ }
// Thread-safe RAII context with local scope. BumpPtrAllocator freed on exit.
edsc::ScopedEDSCContext raiiContext;
OpenPOWER on IntegriCloud