summary refs log tree commit diff stats
path: root/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorTransforms.cpp')
-rw-r--r-- mlir/lib/Dialect/VectorOps/VectorTransforms.cpp 125
1 files changed, 96 insertions, 29 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 569ad443960..c4d3e9d993d 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -511,6 +511,42 @@ Value *mlir::vector::unrollSingleResultOpMatchingType(
resultIndex, targetShape, builder);
}
+// Generates slices of 'vectorType' according to 'sizes', and calls 'fn' with
+// each slice's linear index and offset indices ('strides' is not read here).
+static void generateTransferOpSlices(
+ VectorType vectorType, TupleType tupleType, ArrayRef<int64_t> sizes,
+ ArrayRef<int64_t> strides, ArrayRef<Value *> indices,
+ PatternRewriter &rewriter,
+ llvm::function_ref<void(unsigned, ArrayRef<Value *>)> fn) {
+ // Compute the linearization basis w.r.t. slice counts in each dimension.
+ auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
+ assert(maybeDimSliceCounts.hasValue());
+ auto sliceDimCounts = *maybeDimSliceCounts;
+ auto basis = computeStrides(sliceDimCounts);
+
+ int64_t numSlices = tupleType.size(); // One slice per tuple element.
+ unsigned numSliceIndices = indices.size();
+ auto *ctx = rewriter.getContext();
+ for (unsigned i = 0; i < numSlices; ++i) {
+ // De-linearize slice index 'i' w.r.t. 'basis'.
+ auto vectorOffsets = delinearize(i, basis);
+ // Scale unrolled vector-space offsets by 'sizes' to get element offsets.
+ auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
+ vectorOffsets, sizes);
+ // Compute 'sliceIndices' by adding 'offsets[j]' to each 'indices[j]'.
+ SmallVector<Value *, 4> sliceIndices(numSliceIndices);
+ for (auto it : llvm::enumerate(indices)) {
+ auto expr = getAffineDimExpr(0, ctx) +
+ getAffineConstantExpr(offsets[it.index()], ctx);
+ auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
+ sliceIndices[it.index()] = rewriter.create<AffineApplyOp>(
+ it.value()->getLoc(), map, ArrayRef<Value *>(it.value()));
+ }
+ // Call 'fn' to generate slice 'i' at 'sliceIndices'.
+ fn(i, sliceIndices);
+ }
+}
+
// Splits vector TransferReadOp into smaller TransferReadOps based on slicing
// scheme of its unique ExtractSlicesOp user.
struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
@@ -538,40 +574,22 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
extractSlicesOp.getStrides(strides);
assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
- // Compute strides w.r.t. to slice counts in each dimension.
- auto maybeDimSliceCounts = shapeRatio(sourceVectorType.getShape(), sizes);
- assert(maybeDimSliceCounts.hasValue());
- auto sliceDimCounts = *maybeDimSliceCounts;
- auto basis = computeStrides(sliceDimCounts);
-
Location loc = xferReadOp.getLoc();
- auto *ctx = rewriter.getContext();
int64_t numSlices = resultTupleType.size();
- unsigned numSliceIndices = llvm::size(xferReadOp.indices());
SmallVector<Value *, 4> vectorTupleValues(numSlices);
- for (unsigned i = 0; i < numSlices; ++i) {
- // De-linearize w.r.t. 'basis'.
- auto vectorOffsets = delinearize(i, basis);
- // Convert from unrolled vector-space offsets to element-space offsets.
- auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
- vectorOffsets, sizes);
- // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
- SmallVector<Value *, 4> sliceIndices(numSliceIndices);
- for (auto it : llvm::enumerate(xferReadOp.indices())) {
- auto expr = getAffineDimExpr(0, ctx) +
- getAffineConstantExpr(offsets[it.index()], ctx);
- auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
- SmallVector<Value *, 1> mapOperands = {it.value()};
- sliceIndices[it.index()] =
- rewriter.create<AffineApplyOp>(loc, map, mapOperands);
- }
+ SmallVector<Value *, 4> indices(xferReadOp.indices().begin(),
+ xferReadOp.indices().end());
+ auto createSlice = [&](unsigned index, ArrayRef<Value *> sliceIndices) {
// Get VectorType for slice 'i'.
- auto sliceVectorType = resultTupleType.getType(i);
+ auto sliceVectorType = resultTupleType.getType(index);
// Create split TransferReadOp for 'sliceUser'.
- vectorTupleValues[i] = rewriter.create<vector::TransferReadOp>(
+ vectorTupleValues[index] = rewriter.create<vector::TransferReadOp>(
loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
xferReadOp.permutation_map(), xferReadOp.padding());
- }
+ };
+ generateTransferOpSlices(sourceVectorType, resultTupleType, sizes, strides,
+ indices, rewriter, createSlice);
+
// Create tuple of slice xfer read operations.
Value *tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
vectorTupleValues);
@@ -583,6 +601,54 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
}
};
+// Splits a TransferWriteOp of an InsertSlicesOp into a write per tuple operand.
+struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
+ PatternRewriter &rewriter) const override {
+ // TODO(andydavis, ntv) Support splitting TransferWriteOp with non-identity
+ // permutation maps. Repurpose code from MaterializeVectors transformation.
+ if (!xferWriteOp.permutation_map().isIdentity())
+ return matchFailure();
+ // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
+ auto *vectorDefOp = xferWriteOp.vector()->getDefiningOp();
+ auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp);
+ if (!insertSlicesOp)
+ return matchFailure();
+
+ // Return unless the 'insertSlicesOp' operand is defined by a 'TupleOp'.
+ auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
+ insertSlicesOp.vectors()->getDefiningOp());
+ if (!tupleOp)
+ return matchFailure();
+
+ // Get 'sizes' and 'strides' parameters from the defining InsertSlicesOp.
+ auto sourceTupleType = insertSlicesOp.getSourceTupleType();
+ auto resultVectorType = insertSlicesOp.getResultVectorType();
+ SmallVector<int64_t, 4> sizes;
+ insertSlicesOp.getSizes(sizes);
+ SmallVector<int64_t, 4> strides;
+ insertSlicesOp.getStrides(strides);
+
+ Location loc = xferWriteOp.getLoc();
+ SmallVector<Value *, 4> indices(xferWriteOp.indices().begin(),
+ xferWriteOp.indices().end());
+ auto createSlice = [&](unsigned index, ArrayRef<Value *> sliceIndices) {
+ // Create split TransferWriteOp for source vector 'tupleOp.operand[index]'.
+ rewriter.create<vector::TransferWriteOp>(
+ loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
+ xferWriteOp.permutation_map());
+ };
+ generateTransferOpSlices(resultVectorType, sourceTupleType, sizes, strides,
+ indices, rewriter, createSlice);
+
+ // Erase old 'xferWriteOp'.
+ rewriter.eraseOp(xferWriteOp);
+ return matchSuccess();
+ }
+};
+
// Pattern rewrite which forwards tuple elements to their users.
// User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer)))))
// -> User(Producer)
@@ -609,7 +675,7 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
if (!tupleOp)
return matchFailure();
- // Forward Value at tupleOp.getOperand(tupleGetOp.getIndex());
+ // Forward Value from 'tupleOp' at 'tupleGetOp.index'.
Value *tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
rewriter.replaceOp(tupleGetOp, tupleValue);
return matchSuccess();
@@ -620,5 +686,6 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
// TODO(andydavis) Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<SplitTransferReadOp, TupleGetFolderOp>(context);
+ patterns.insert<SplitTransferReadOp, SplitTransferWriteOp, TupleGetFolderOp>(
+ context);
}
OpenPOWER on IntegriCloud