diff options
| -rw-r--r-- | mlir/include/mlir/Transforms/MLPatternLoweringPass.h | 64 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LowerVectorTransfers.cpp | 13 |
2 files changed, 28 insertions, 49 deletions
diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index 15e4f215c61..c9ed3a38a65 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -81,41 +81,7 @@ namespace detail { /// Owning list of ML lowering patterns. using OwningMLLoweringPatternList = std::vector<std::unique_ptr<mlir::MLLoweringPattern>>; -} // namespace detail - -/// Generic lowering pass for ML functions. The lowering details are defined as -/// a sequence of pattern matchers. The following constraints on matchers -/// apply: -/// - only one (match root) operation can be removed; -/// - the code produced by rewriters is final, it is not pattern-matched; -/// - the matchers are applied in their order of appearance in the list; -/// - if the match is found, the operation is rewritten immediately and the -/// next _original_ operation is considered. -/// In other words, for each operation, the pass applies the first matching -/// rewriter in the list and advances to the (lexically) next operation. -/// Non-operation instructions (ForInst) are ignored. -/// This is similar to greedy worklist-based pattern rewriter, except that this -/// operates on ML functions using an ML builder and does not maintain the work -/// list. Note that, as of the time of writing, worklist-based rewriter did not -/// support removing multiple operations either. -template <typename... Patterns> -class MLPatternLoweringPass : public FunctionPass { -public: - explicit MLPatternLoweringPass(const PassID *ID) : FunctionPass(ID) {} - - virtual std::unique_ptr<MLFuncGlobalLoweringState> - makeFuncWiseState(Function *f) const { - return nullptr; - } - PassResult runOnFunction(Function *f) override; -}; - -///////////////////////////////////////////////////////////////////// -// MLPatternLoweringPass template implementations -///////////////////////////////////////////////////////////////////// - -namespace detail { template <typename Pattern, typename... Patterns> struct ListAdder { static void addPatternsToList(OwningMLLoweringPatternList *list, MLIRContext *context) { @@ -134,11 +100,25 @@ template <typename Pattern> struct ListAdder<Pattern> { }; } // namespace detail +/// Generic lowering for ML patterns. The lowering details are defined as +/// a sequence of pattern matchers. The following constraints on matchers +/// apply: +/// - only one (match root) operation can be removed; +/// - the code produced by rewriters is final, it is not pattern-matched; +/// - the matchers are applied in their order of appearance in the list; +/// - if the match is found, the operation is rewritten immediately and the +/// next _original_ operation is considered. +/// In other words, for each operation, apply the first matching rewriter in the +/// list and advance to the (lexically) next operation. This is similar to +/// greedy worklist-based pattern rewriter, except that this operates on ML +/// functions using an ML builder and does not maintain the work list. Note +/// that, as of the time of writing, worklist-based rewriter did not support +/// removing multiple operations either. template <typename... Patterns> -PassResult MLPatternLoweringPass<Patterns...>::runOnFunction(Function *f) { +void applyMLPatternsGreedily( + Function *f, MLFuncGlobalLoweringState *funcWiseState = nullptr) { detail::OwningMLLoweringPatternList patterns; detail::ListAdder<Patterns...>::addPatternsToList(&patterns, f->getContext()); - auto funcWiseState = makeFuncWiseState(f); FuncBuilder builder(f); MLFuncLoweringRewriter rewriter(&builder); @@ -148,19 +128,15 @@ PassResult MLPatternLoweringPass<Patterns...>::runOnFunction(Function *f) { for (Instruction *inst : ops) { for (const auto &pattern : patterns) { - rewriter.getBuilder()->setInsertionPoint(inst); - auto matchResult = pattern->match(inst); - if (matchResult) { - pattern->rewriteOpInst(inst, funcWiseState.get(), - std::move(*matchResult), &rewriter); + builder.setInsertionPoint(inst); + if (auto matchResult = pattern->match(inst)) { + pattern->rewriteOpInst(inst, funcWiseState, std::move(*matchResult), + &rewriter); break; } } } - - return PassResult::Success; } - } // end namespace mlir #endif // MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index ac8f7e064f5..61f75ae76e6 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -424,12 +424,15 @@ public: } }; -struct LowerVectorTransfersPass - : public MLPatternLoweringPass< - VectorTransferExpander<VectorTransferReadOp>, - VectorTransferExpander<VectorTransferWriteOp>> { +struct LowerVectorTransfersPass : public FunctionPass { LowerVectorTransfersPass() - : MLPatternLoweringPass(&LowerVectorTransfersPass::passID) {} + : FunctionPass(&LowerVectorTransfersPass::passID) {} + + PassResult runOnFunction(Function *fn) override { + applyMLPatternsGreedily<VectorTransferExpander<VectorTransferReadOp>, + VectorTransferExpander<VectorTransferWriteOp>>(fn); + return success(); + } // Thread-safe RAII context with local scope. BumpPtrAllocator freed on exit. edsc::ScopedEDSCContext raiiContext; |

