diff options
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorTransforms.cpp')
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp index 6b13bcf75ca..6825709334b 100644 --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -446,3 +446,71 @@ Value *mlir::vector::unrollSingleResultOpMatchingType( return unrollSingleResultStructuredOp(op, iterationBounds, vectors, resultIndex, targetShape, builder); } + +// Splits vector TransferReadOp into smaller TransferReadOps for each user. +struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> { + using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(vector::TransferReadOp xferReadOp, + PatternRewriter &rewriter) const override { + // TODO(andydavis, ntv) Support spliting TransferReadOp with non-identity + // permutation maps. Repurpose code from MaterializeVectors transformation. + if (!xferReadOp.permutation_map().isIdentity()) + return matchFailure(); + // Gather 'xferReadOp' users. + SmallVector<vector::StridedSliceOp, 2> sliceUsers; + sliceUsers.reserve(std::distance(xferReadOp.getResult()->use_begin(), + xferReadOp.getResult()->use_end())); + + for (auto *user : xferReadOp.getResult()->getUsers()) { + auto sliceOp = dyn_cast<vector::StridedSliceOp>(user); + // Return if any user is not a vector::StridedSliceOp. + if (!sliceOp) + return matchFailure(); + sliceUsers.push_back(sliceOp); + } + // Make zero splat into which we will insert split xferReadOp results. + Location loc = xferReadOp.getLoc(); + auto *res = makeSplatZero(loc, rewriter, xferReadOp.getVectorType()); + + // Update each user in 'sliceUser' to use 'res'. + unsigned numSliceIndices = llvm::size(xferReadOp.indices()); + for (auto sliceUser : sliceUsers) { + // Gather static offsets from 'sliceUser'. + SmallVector<int64_t, 4> sliceOffsets; + sliceUser.getOffsets(sliceOffsets); + assert(sliceOffsets.size() == numSliceIndices); + auto *ctx = rewriter.getContext(); + // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. + SmallVector<Value *, 4> sliceIndices(numSliceIndices); + for (auto it : llvm::enumerate(xferReadOp.indices())) { + auto expr = getAffineDimExpr(0, ctx) + + getAffineConstantExpr(sliceOffsets[it.index()], ctx); + auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); + SmallVector<Value *, 1> mapOperands = {it.value()}; + sliceIndices[it.index()] = + rewriter.create<AffineApplyOp>(loc, map, mapOperands); + } + // Create split TransferReadOp for 'sliceUser'. + auto sliceVectorType = + sliceUser.getResult()->getType().cast<VectorType>(); + auto splitXferReadOp = rewriter.create<vector::TransferReadOp>( + loc, sliceVectorType, xferReadOp.memref(), sliceIndices, + xferReadOp.permutation_map(), xferReadOp.padding()); + // Create InsertStridedSlice into splat at same offsets as slice. + res = rewriter.create<vector::InsertStridedSliceOp>( + loc, xferReadOp.getVectorType(), splitXferReadOp, res, + sliceUser.offsets(), sliceUser.strides()); + } + + // Replace 'xferReadOp' with result 'res'. + rewriter.replaceOp(xferReadOp, res); + return matchSuccess(); + } +}; + +// TODO(andydavis) Add this as DRR pattern. +void mlir::vector::populateVectorToVectorTransformationPatterns( + OwningRewritePatternList &patterns, MLIRContext *context) { + patterns.insert<SplitTransferReadOp>(context); +} |