diff options
author | Andy Davis <andydavis@google.com> | 2019-12-17 11:21:12 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-17 11:21:45 -0800 |
commit | d1fb285b32d107e21a022f354a09f54f38421529 (patch) | |
tree | 5197752d0a9b8705d8d688a3e4344f3914990a5e /mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | |
parent | 9f45a224412c8381de418d78935075b318352864 (diff) | |
download | bcm5719-llvm-d1fb285b32d107e21a022f354a09f54f38421529.tar.gz bcm5719-llvm-d1fb285b32d107e21a022f354a09f54f38421529.zip |
Add pattern rewrite to forward vector tuple elements to their users.
User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer))) -> User(Producer)
PiperOrigin-RevId: 286020249
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorTransforms.cpp')
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | 35 |
1 files changed, 34 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp index 85f306e7834..569ad443960 100644 --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -583,9 +583,42 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> { } }; +// Patter rewrite which forward tuple elements to their users. +// User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer))))) +// -> User(Producer) +struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> { + using OpRewritePattern<vector::TupleGetOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(vector::TupleGetOp tupleGetOp, + PatternRewriter &rewriter) const override { + // Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp. + auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>( + tupleGetOp.vectors()->getDefiningOp()); + if (!extractSlicesOp) + return matchFailure(); + + // Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp. + auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>( + extractSlicesOp.vector()->getDefiningOp()); + if (!insertSlicesOp) + return matchFailure(); + + // Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp. + auto tupleOp = dyn_cast_or_null<vector::TupleOp>( + insertSlicesOp.vectors()->getDefiningOp()); + if (!tupleOp) + return matchFailure(); + + // Forward Value at tupleOp.getOperand(tupleGetOp.getIndex()); + Value *tupleValue = tupleOp.getOperand(tupleGetOp.getIndex()); + rewriter.replaceOp(tupleGetOp, tupleValue); + return matchSuccess(); + } +}; + // TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp). // TODO(andydavis) Add this as DRR pattern. void mlir::vector::populateVectorToVectorTransformationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert<SplitTransferReadOp>(context); + patterns.insert<SplitTransferReadOp, TupleGetFolderOp>(context); } |