diff options
author | River Riddle <riverriddle@google.com> | 2019-05-25 17:22:27 -0700 |
---|---|---|
committer | Mehdi Amini <joker.eph@gmail.com> | 2019-06-01 20:03:22 -0700 |
commit | 9e21ab8f522265d37159372dbce96f66488c4e34 (patch) | |
tree | 0ec933521a5894a21543e7ef1684e437b4978380 /mlir/examples/Linalg/Linalg3/lib/Transforms.cpp | |
parent | 2f50b6c401fd4d6ff63718ef3b889a79ba32a640 (diff) | |
download | bcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.tar.gz bcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.zip |
Add a templated wrapper around RewritePattern that allows for defining match/rewrite methods with an instance of the source op instead of a raw Operation*.
--
PiperOrigin-RevId: 250003405
Diffstat (limited to 'mlir/examples/Linalg/Linalg3/lib/Transforms.cpp')
-rw-r--r-- | mlir/examples/Linalg/Linalg3/lib/Transforms.cpp | 18 |
1 files changed, 8 insertions, 10 deletions
diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index 621bc267205..0fe70e27f1b 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -248,12 +248,12 @@ namespace { /// mlir::StoreOp requires finding the proper indexing in the supporting MemRef. /// This is most easily achieved by calling emitAndReturnFullyComposedView to /// fold away all the SliceOp. -template <typename LoadOrStoreOpTy> struct Rewriter : public RewritePattern { - explicit Rewriter(MLIRContext *context) - : RewritePattern(LoadOrStoreOpTy::getOperationName(), 1, context) {} +template <typename LoadOrStoreOpTy> +struct Rewriter : public OpRewritePattern<LoadOrStoreOpTy> { + using OpRewritePattern<LoadOrStoreOpTy>::OpRewritePattern; /// Performs the rewrite. - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(LoadOrStoreOpTy op, PatternRewriter &rewriter) const override; }; @@ -270,9 +270,8 @@ struct LowerLinalgLoadStorePass template <> PatternMatchResult -Rewriter<linalg::LoadOp>::matchAndRewrite(Operation *op, +Rewriter<linalg::LoadOp>::matchAndRewrite(linalg::LoadOp load, PatternRewriter &rewriter) const { - auto load = cast<linalg::LoadOp>(op); SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp()); ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult()) : cast<ViewOp>(load.getView()->getDefiningOp()); @@ -280,15 +279,14 @@ Rewriter<linalg::LoadOp>::matchAndRewrite(Operation *op, ScopedContext scope(builder, load.getLoc()); auto *memRef = view.getSupportingMemRef(); auto operands = emitAndReturnLoadStoreOperands(load, view); - rewriter.replaceOpWithNewOp<mlir::LoadOp>(op, memRef, operands); + rewriter.replaceOpWithNewOp<mlir::LoadOp>(load, memRef, operands); return matchSuccess(); } template <> PatternMatchResult -Rewriter<linalg::StoreOp>::matchAndRewrite(Operation *op, +Rewriter<linalg::StoreOp>::matchAndRewrite(linalg::StoreOp store, PatternRewriter &rewriter) const { - auto store = cast<linalg::StoreOp>(op); SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp()); ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult()) : cast<ViewOp>(store.getView()->getDefiningOp()); @@ -297,7 +295,7 @@ Rewriter<linalg::StoreOp>::matchAndRewrite(Operation *op, auto *valueToStore = store.getValueToStore(); auto *memRef = view.getSupportingMemRef(); auto operands = emitAndReturnLoadStoreOperands(store, view); - rewriter.replaceOpWithNewOp<mlir::StoreOp>(op, valueToStore, memRef, + rewriter.replaceOpWithNewOp<mlir::StoreOp>(store, valueToStore, memRef, operands); return matchSuccess(); } |