summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorTransforms.cpp')
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorTransforms.cpp78
1 files changed, 44 insertions, 34 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 8d70f4ac83f..85f306e7834 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -511,7 +511,8 @@ Value *mlir::vector::unrollSingleResultOpMatchingType(
resultIndex, targetShape, builder);
}
-// Splits vector TransferReadOp into smaller TransferReadOps for each user.
+// Splits vector TransferReadOp into smaller TransferReadOps based on slicing
+// scheme of its unique ExtractSlicesOp user.
struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
@@ -521,54 +522,63 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
// permutation maps. Repurpose code from MaterializeVectors transformation.
if (!xferReadOp.permutation_map().isIdentity())
return matchFailure();
- // Gather 'xferReadOp' users.
- SmallVector<vector::StridedSliceOp, 2> sliceUsers;
- sliceUsers.reserve(std::distance(xferReadOp.getResult()->use_begin(),
- xferReadOp.getResult()->use_end()));
-
- for (auto *user : xferReadOp.getResult()->getUsers()) {
- auto sliceOp = dyn_cast<vector::StridedSliceOp>(user);
- // Return if any user is not a vector::StridedSliceOp.
- if (!sliceOp)
- return matchFailure();
- sliceUsers.push_back(sliceOp);
- }
- // Make zero splat into which we will insert split xferReadOp results.
- Location loc = xferReadOp.getLoc();
- auto *res = makeSplatZero(loc, rewriter, xferReadOp.getVectorType());
+ // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
+ Value *xferReadResult = xferReadOp.getResult();
+ auto extractSlicesOp =
+ dyn_cast<vector::ExtractSlicesOp>(*xferReadResult->getUsers().begin());
+ if (!xferReadResult->hasOneUse() || !extractSlicesOp)
+ return matchFailure();
- // Update each user in 'sliceUser' to use 'res'.
+ // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
+ auto sourceVectorType = extractSlicesOp.getSourceVectorType();
+ auto resultTupleType = extractSlicesOp.getResultTupleType();
+ SmallVector<int64_t, 4> sizes;
+ extractSlicesOp.getSizes(sizes);
+ SmallVector<int64_t, 4> strides;
+ extractSlicesOp.getStrides(strides);
+ assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
+
+ // Compute strides w.r.t. to slice counts in each dimension.
+ auto maybeDimSliceCounts = shapeRatio(sourceVectorType.getShape(), sizes);
+ assert(maybeDimSliceCounts.hasValue());
+ auto sliceDimCounts = *maybeDimSliceCounts;
+ auto basis = computeStrides(sliceDimCounts);
+
+ Location loc = xferReadOp.getLoc();
+ auto *ctx = rewriter.getContext();
+ int64_t numSlices = resultTupleType.size();
unsigned numSliceIndices = llvm::size(xferReadOp.indices());
- for (auto sliceUser : sliceUsers) {
- // Gather static offsets from 'sliceUser'.
- SmallVector<int64_t, 4> sliceOffsets;
- sliceUser.getOffsets(sliceOffsets);
- assert(sliceOffsets.size() == numSliceIndices);
- auto *ctx = rewriter.getContext();
+ SmallVector<Value *, 4> vectorTupleValues(numSlices);
+ for (unsigned i = 0; i < numSlices; ++i) {
+ // De-linearize w.r.t. 'basis'.
+ auto vectorOffsets = delinearize(i, basis);
+ // Convert from unrolled vector-space offsets to element-space offsets.
+ auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
+ vectorOffsets, sizes);
// Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
SmallVector<Value *, 4> sliceIndices(numSliceIndices);
for (auto it : llvm::enumerate(xferReadOp.indices())) {
auto expr = getAffineDimExpr(0, ctx) +
- getAffineConstantExpr(sliceOffsets[it.index()], ctx);
+ getAffineConstantExpr(offsets[it.index()], ctx);
auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
SmallVector<Value *, 1> mapOperands = {it.value()};
sliceIndices[it.index()] =
rewriter.create<AffineApplyOp>(loc, map, mapOperands);
}
+ // Get VectorType for slice 'i'.
+ auto sliceVectorType = resultTupleType.getType(i);
// Create split TransferReadOp for 'sliceUser'.
- auto sliceVectorType =
- sliceUser.getResult()->getType().cast<VectorType>();
- auto splitXferReadOp = rewriter.create<vector::TransferReadOp>(
+ vectorTupleValues[i] = rewriter.create<vector::TransferReadOp>(
loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
xferReadOp.permutation_map(), xferReadOp.padding());
- // Create InsertStridedSlice into splat at same offsets as slice.
- res = rewriter.create<vector::InsertStridedSliceOp>(
- loc, xferReadOp.getVectorType(), splitXferReadOp, res,
- sliceUser.offsets(), sliceUser.strides());
}
-
- // Replace 'xferReadOp' with result 'res'.
- rewriter.replaceOp(xferReadOp, res);
+ // Create tuple of splice xfer read operations.
+ Value *tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
+ vectorTupleValues);
+ // Replace 'xferReadOp' with result 'insertSlicesResult'.
+ rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>(
+ xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(),
+ extractSlicesOp.strides());
return matchSuccess();
}
};
OpenPOWER on IntegriCloud