summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
diff options
context:
space:
mode:
authorAndy Davis <andydavis@google.com>2019-12-10 17:02:17 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-10 17:02:51 -0800
commit4d8ba886103b0022b019671bf27547d55a902b54 (patch)
treece4aebe086bad357f892feeaa1492bd7e507f2a7 /mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
parent36a415bcc543553891af6809c5256e6e2469357d (diff)
downloadbcm5719-llvm-4d8ba886103b0022b019671bf27547d55a902b54.tar.gz
bcm5719-llvm-4d8ba886103b0022b019671bf27547d55a902b54.zip
Add VectorOp transform pattern which splits vector TransferReadOps to target vector unroll size.
PiperOrigin-RevId: 284880592
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorTransforms.cpp')
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorTransforms.cpp68
1 files changed, 68 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 6b13bcf75ca..6825709334b 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -446,3 +446,71 @@ Value *mlir::vector::unrollSingleResultOpMatchingType(
return unrollSingleResultStructuredOp(op, iterationBounds, vectors,
resultIndex, targetShape, builder);
}
+
+// Splits vector TransferReadOp into smaller TransferReadOps for each user.
+struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
+ using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(vector::TransferReadOp xferReadOp,
+ PatternRewriter &rewriter) const override {
+ // TODO(andydavis, ntv) Support spliting TransferReadOp with non-identity
+ // permutation maps. Repurpose code from MaterializeVectors transformation.
+ if (!xferReadOp.permutation_map().isIdentity())
+ return matchFailure();
+ // Gather 'xferReadOp' users.
+ SmallVector<vector::StridedSliceOp, 2> sliceUsers;
+ sliceUsers.reserve(std::distance(xferReadOp.getResult()->use_begin(),
+ xferReadOp.getResult()->use_end()));
+
+ for (auto *user : xferReadOp.getResult()->getUsers()) {
+ auto sliceOp = dyn_cast<vector::StridedSliceOp>(user);
+ // Return if any user is not a vector::StridedSliceOp.
+ if (!sliceOp)
+ return matchFailure();
+ sliceUsers.push_back(sliceOp);
+ }
+ // Make zero splat into which we will insert split xferReadOp results.
+ Location loc = xferReadOp.getLoc();
+ auto *res = makeSplatZero(loc, rewriter, xferReadOp.getVectorType());
+
+ // Update each user in 'sliceUser' to use 'res'.
+ unsigned numSliceIndices = llvm::size(xferReadOp.indices());
+ for (auto sliceUser : sliceUsers) {
+ // Gather static offsets from 'sliceUser'.
+ SmallVector<int64_t, 4> sliceOffsets;
+ sliceUser.getOffsets(sliceOffsets);
+ assert(sliceOffsets.size() == numSliceIndices);
+ auto *ctx = rewriter.getContext();
+ // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
+ SmallVector<Value *, 4> sliceIndices(numSliceIndices);
+ for (auto it : llvm::enumerate(xferReadOp.indices())) {
+ auto expr = getAffineDimExpr(0, ctx) +
+ getAffineConstantExpr(sliceOffsets[it.index()], ctx);
+ auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
+ SmallVector<Value *, 1> mapOperands = {it.value()};
+ sliceIndices[it.index()] =
+ rewriter.create<AffineApplyOp>(loc, map, mapOperands);
+ }
+ // Create split TransferReadOp for 'sliceUser'.
+ auto sliceVectorType =
+ sliceUser.getResult()->getType().cast<VectorType>();
+ auto splitXferReadOp = rewriter.create<vector::TransferReadOp>(
+ loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
+ xferReadOp.permutation_map(), xferReadOp.padding());
+ // Create InsertStridedSlice into splat at same offsets as slice.
+ res = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, xferReadOp.getVectorType(), splitXferReadOp, res,
+ sliceUser.offsets(), sliceUser.strides());
+ }
+
+ // Replace 'xferReadOp' with result 'res'.
+ rewriter.replaceOp(xferReadOp, res);
+ return matchSuccess();
+ }
+};
+
+// TODO(andydavis) Add this as DRR pattern.
+void mlir::vector::populateVectorToVectorTransformationPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
+ patterns.insert<SplitTransferReadOp>(context);
+}
OpenPOWER on IntegriCloud