summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAndy Davis <andydavis@google.com>2019-12-17 13:10:07 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-17 13:17:10 -0800
commit6fa3bd5b3e57806ffa34946bd36528f72bf06b58 (patch)
tree03ba267aa627593d767406a029effc9432708a06
parent319cca3bbe69b20334caee2f93aaf6fe0318ca0d (diff)
downloadbcm5719-llvm-6fa3bd5b3e57806ffa34946bd36528f72bf06b58.tar.gz
bcm5719-llvm-6fa3bd5b3e57806ffa34946bd36528f72bf06b58.zip
Add pattern rewrite which splits a vector TransferWriteOp into slices according to the unrolling/slicing scheme of its InsertSlicesOp operand.
PiperOrigin-RevId: 286042578
-rw-r--r--mlir/include/mlir/Dialect/VectorOps/VectorOps.td4
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorOps.cpp3
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorTransforms.cpp125
-rw-r--r--mlir/test/Dialect/VectorOps/vector-transforms.mlir31
4 files changed, 121 insertions, 42 deletions
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index 50bf581f5a2..e031d7cfb8c 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -980,8 +980,8 @@ def Vector_TupleGetOp :
VectorType getResultVectorType() {
return getResult()->getType().cast<VectorType>();
}
- unsigned getIndex() {
- return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
+ int64_t getIndex() {
+ return getAttrOfType<IntegerAttr>("index").getValue().getSExtValue();
}
static StringRef getIndexAttrName() { return "index"; }
}];
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 1f6a4bce49e..ff4ff2cb540 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -1505,7 +1505,8 @@ static void print(OpAsmPrinter &p, TupleGetOp op) {
static LogicalResult verify(TupleGetOp op) {
auto tupleType = op.getOperand()->getType().cast<TupleType>();
- if (op.getIndex() < 0 || op.getIndex() >= tupleType.size())
+ if (op.getIndex() < 0 ||
+ op.getIndex() >= static_cast<int64_t>(tupleType.size()))
return op.emitOpError("tuple get index out of range");
return success();
}
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 569ad443960..c4d3e9d993d 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -511,6 +511,42 @@ Value *mlir::vector::unrollSingleResultOpMatchingType(
resultIndex, targetShape, builder);
}
+// Generates slices of 'vectorType' according to 'sizes' and 'strides', and
+// calls 'fn' with linear index and indices for each slice.
+static void generateTransferOpSlices(
+ VectorType vectorType, TupleType tupleType, ArrayRef<int64_t> sizes,
+ ArrayRef<int64_t> strides, ArrayRef<Value *> indices,
+ PatternRewriter &rewriter,
+ llvm::function_ref<void(unsigned, ArrayRef<Value *>)> fn) {
+ // Compute strides w.r.t. to slice counts in each dimension.
+ auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
+ assert(maybeDimSliceCounts.hasValue());
+ auto sliceDimCounts = *maybeDimSliceCounts;
+ auto basis = computeStrides(sliceDimCounts);
+
+ int64_t numSlices = tupleType.size();
+ unsigned numSliceIndices = indices.size();
+ auto *ctx = rewriter.getContext();
+ for (unsigned i = 0; i < numSlices; ++i) {
+ // De-linearize w.r.t. 'basis'.
+ auto vectorOffsets = delinearize(i, basis);
+ // Convert from unrolled vector-space offsets to element-space offsets.
+ auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
+ vectorOffsets, sizes);
+ // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
+ SmallVector<Value *, 4> sliceIndices(numSliceIndices);
+ for (auto it : llvm::enumerate(indices)) {
+ auto expr = getAffineDimExpr(0, ctx) +
+ getAffineConstantExpr(offsets[it.index()], ctx);
+ auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
+ sliceIndices[it.index()] = rewriter.create<AffineApplyOp>(
+ it.value()->getLoc(), map, ArrayRef<Value *>(it.value()));
+ }
+ // Call 'fn' to generate slice 'i' at 'sliceIndices'.
+ fn(i, sliceIndices);
+ }
+}
+
// Splits vector TransferReadOp into smaller TransferReadOps based on slicing
// scheme of its unique ExtractSlicesOp user.
struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
@@ -538,40 +574,22 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
extractSlicesOp.getStrides(strides);
assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
- // Compute strides w.r.t. to slice counts in each dimension.
- auto maybeDimSliceCounts = shapeRatio(sourceVectorType.getShape(), sizes);
- assert(maybeDimSliceCounts.hasValue());
- auto sliceDimCounts = *maybeDimSliceCounts;
- auto basis = computeStrides(sliceDimCounts);
-
Location loc = xferReadOp.getLoc();
- auto *ctx = rewriter.getContext();
int64_t numSlices = resultTupleType.size();
- unsigned numSliceIndices = llvm::size(xferReadOp.indices());
SmallVector<Value *, 4> vectorTupleValues(numSlices);
- for (unsigned i = 0; i < numSlices; ++i) {
- // De-linearize w.r.t. 'basis'.
- auto vectorOffsets = delinearize(i, basis);
- // Convert from unrolled vector-space offsets to element-space offsets.
- auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
- vectorOffsets, sizes);
- // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
- SmallVector<Value *, 4> sliceIndices(numSliceIndices);
- for (auto it : llvm::enumerate(xferReadOp.indices())) {
- auto expr = getAffineDimExpr(0, ctx) +
- getAffineConstantExpr(offsets[it.index()], ctx);
- auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
- SmallVector<Value *, 1> mapOperands = {it.value()};
- sliceIndices[it.index()] =
- rewriter.create<AffineApplyOp>(loc, map, mapOperands);
- }
+ SmallVector<Value *, 4> indices(xferReadOp.indices().begin(),
+ xferReadOp.indices().end());
+ auto createSlice = [&](unsigned index, ArrayRef<Value *> sliceIndices) {
// Get VectorType for slice 'i'.
- auto sliceVectorType = resultTupleType.getType(i);
+ auto sliceVectorType = resultTupleType.getType(index);
// Create split TransferReadOp for 'sliceUser'.
- vectorTupleValues[i] = rewriter.create<vector::TransferReadOp>(
+ vectorTupleValues[index] = rewriter.create<vector::TransferReadOp>(
loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
xferReadOp.permutation_map(), xferReadOp.padding());
- }
+ };
+ generateTransferOpSlices(sourceVectorType, resultTupleType, sizes, strides,
+ indices, rewriter, createSlice);
+
// Create tuple of split xfer read operations.
Value *tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
vectorTupleValues);
@@ -583,6 +601,54 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
}
};
+// Splits vector TransferWriteOp into smaller TransferWriteOps for each source.
+struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
+ PatternRewriter &rewriter) const override {
+ // TODO(andydavis, ntv) Support splitting TransferWriteOp with non-identity
+ // permutation maps. Repurpose code from MaterializeVectors transformation.
+ if (!xferWriteOp.permutation_map().isIdentity())
+ return matchFailure();
+ // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
+ auto *vectorDefOp = xferWriteOp.vector()->getDefiningOp();
+ auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp);
+ if (!insertSlicesOp)
+ return matchFailure();
+
+ // Get TupleOp operand of 'insertSlicesOp'.
+ auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
+ insertSlicesOp.vectors()->getDefiningOp());
+ if (!tupleOp)
+ return matchFailure();
+
+ // Get 'sizes' and 'strides' parameters from InsertSlicesOp user.
+ auto sourceTupleType = insertSlicesOp.getSourceTupleType();
+ auto resultVectorType = insertSlicesOp.getResultVectorType();
+ SmallVector<int64_t, 4> sizes;
+ insertSlicesOp.getSizes(sizes);
+ SmallVector<int64_t, 4> strides;
+ insertSlicesOp.getStrides(strides);
+
+ Location loc = xferWriteOp.getLoc();
+ SmallVector<Value *, 4> indices(xferWriteOp.indices().begin(),
+ xferWriteOp.indices().end());
+ auto createSlice = [&](unsigned index, ArrayRef<Value *> sliceIndices) {
+ // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
+ rewriter.create<vector::TransferWriteOp>(
+ loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
+ xferWriteOp.permutation_map());
+ };
+ generateTransferOpSlices(resultVectorType, sourceTupleType, sizes, strides,
+ indices, rewriter, createSlice);
+
+ // Erase old 'xferWriteOp'.
+ rewriter.eraseOp(xferWriteOp);
+ return matchSuccess();
+ }
+};
+
// Pattern rewrite which forwards tuple elements to their users.
// User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer)))))
// -> User(Producer)
@@ -609,7 +675,7 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
if (!tupleOp)
return matchFailure();
- // Forward Value at tupleOp.getOperand(tupleGetOp.getIndex());
+ // Forward Value from 'tupleOp' at 'tupleGetOp.index'.
Value *tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
rewriter.replaceOp(tupleGetOp, tupleValue);
return matchSuccess();
@@ -620,5 +686,6 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
// TODO(andydavis) Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<SplitTransferReadOp, TupleGetFolderOp>(context);
+ patterns.insert<SplitTransferReadOp, SplitTransferWriteOp, TupleGetFolderOp>(
+ context);
}
diff --git a/mlir/test/Dialect/VectorOps/vector-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
index 978b0c2f855..b5fcbaba91c 100644
--- a/mlir/test/Dialect/VectorOps/vector-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
@@ -229,23 +229,32 @@ func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
// CHECK: %[[C2:.*]] = constant 2 : index
// Check LHS vector.transfer read is split for each user.
-// TODO(andydavis) Connect VTR results with users in subsequent CL.
// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32>
-// CHECK: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32>
-// CHECK: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
-// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
-// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+
+// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {permutation_map = #[[MAP0]]} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {permutation_map = #[[MAP0]]} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[R2]], %{{.*}}[%[[C2]], %[[C0]]] {permutation_map = #[[MAP0]]} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] {permutation_map = #[[MAP0]]} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: return
func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
%arg1 : memref<2x4xf32>,
- %arg2 : memref<4x4xf32>)
- -> (vector<4x4xf32>) {
+ %arg2 : memref<4x4xf32>) {
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
@@ -264,15 +273,17 @@ func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
%3 = vector.contract #contraction_trait1 %0, %1, %2
: vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
- return %3 : vector<4x4xf32>
+ vector.transfer_write %3, %arg2[%c0, %c0]
+ {permutation_map = (d0, d1) -> (d0, d1)}
+ : vector<4x4xf32>, memref<4x4xf32>
+ return
}
// TODO(andydavis) Update test with VTR split transform.
// CHECK-LABEL: func @vector_transfers
// CHECK-COUNT-8: vector.transfer_read
// CHECK-COUNT-4: addf
-// CHECK-COUNT-1: vector.insert_slices
-// CHECK: vector.transfer_write
+// CHECK-COUNT-4: vector.transfer_write
func @vector_transfers(%arg0: index, %arg1: index) {
%cst = constant 0.000000e+00 : f32
OpenPOWER on IntegriCloud