diff options
author | Andy Davis <andydavis@google.com> | 2019-12-17 13:10:07 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-17 13:17:10 -0800 |
commit | 6fa3bd5b3e57806ffa34946bd36528f72bf06b58 (patch) | |
tree | 03ba267aa627593d767406a029effc9432708a06 /mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | |
parent | 319cca3bbe69b20334caee2f93aaf6fe0318ca0d (diff) | |
download | bcm5719-llvm-6fa3bd5b3e57806ffa34946bd36528f72bf06b58.tar.gz bcm5719-llvm-6fa3bd5b3e57806ffa34946bd36528f72bf06b58.zip |
Add pattern rewrite which splits a vector TransferWriteOp into slices according to the unrolling/slicing scheme of its InsertSlicesOp operand.
PiperOrigin-RevId: 286042578
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorTransforms.cpp')
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | 125 |
1 files changed, 96 insertions, 29 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp index 569ad443960..c4d3e9d993d 100644 --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -511,6 +511,42 @@ Value *mlir::vector::unrollSingleResultOpMatchingType( resultIndex, targetShape, builder); } +// Generates slices of 'vectorType' according to 'sizes' and 'strides, and +// calls 'fn' with linear index and indices for each slice. +static void generateTransferOpSlices( + VectorType vectorType, TupleType tupleType, ArrayRef<int64_t> sizes, + ArrayRef<int64_t> strides, ArrayRef<Value *> indices, + PatternRewriter &rewriter, + llvm::function_ref<void(unsigned, ArrayRef<Value *>)> fn) { + // Compute strides w.r.t. to slice counts in each dimension. + auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes); + assert(maybeDimSliceCounts.hasValue()); + auto sliceDimCounts = *maybeDimSliceCounts; + auto basis = computeStrides(sliceDimCounts); + + int64_t numSlices = tupleType.size(); + unsigned numSliceIndices = indices.size(); + auto *ctx = rewriter.getContext(); + for (unsigned i = 0; i < numSlices; ++i) { + // De-linearize w.r.t. 'basis'. + auto vectorOffsets = delinearize(i, basis); + // Convert from unrolled vector-space offsets to element-space offsets. + auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, + vectorOffsets, sizes); + // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. + SmallVector<Value *, 4> sliceIndices(numSliceIndices); + for (auto it : llvm::enumerate(indices)) { + auto expr = getAffineDimExpr(0, ctx) + + getAffineConstantExpr(offsets[it.index()], ctx); + auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); + sliceIndices[it.index()] = rewriter.create<AffineApplyOp>( + it.value()->getLoc(), map, ArrayRef<Value *>(it.value())); + } + // Call 'fn' to generate slice 'i' at 'sliceIndices'. + fn(i, sliceIndices); + } +} + // Splits vector TransferReadOp into smaller TransferReadOps based on slicing // scheme of its unique ExtractSlicesOp user. struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> { @@ -538,40 +574,22 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> { extractSlicesOp.getStrides(strides); assert(llvm::all_of(strides, [](int64_t s) { return s == 1; })); - // Compute strides w.r.t. to slice counts in each dimension. - auto maybeDimSliceCounts = shapeRatio(sourceVectorType.getShape(), sizes); - assert(maybeDimSliceCounts.hasValue()); - auto sliceDimCounts = *maybeDimSliceCounts; - auto basis = computeStrides(sliceDimCounts); - Location loc = xferReadOp.getLoc(); - auto *ctx = rewriter.getContext(); int64_t numSlices = resultTupleType.size(); - unsigned numSliceIndices = llvm::size(xferReadOp.indices()); SmallVector<Value *, 4> vectorTupleValues(numSlices); - for (unsigned i = 0; i < numSlices; ++i) { - // De-linearize w.r.t. 'basis'. - auto vectorOffsets = delinearize(i, basis); - // Convert from unrolled vector-space offsets to element-space offsets. - auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, - vectorOffsets, sizes); - // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. - SmallVector<Value *, 4> sliceIndices(numSliceIndices); - for (auto it : llvm::enumerate(xferReadOp.indices())) { - auto expr = getAffineDimExpr(0, ctx) + - getAffineConstantExpr(offsets[it.index()], ctx); - auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); - SmallVector<Value *, 1> mapOperands = {it.value()}; - sliceIndices[it.index()] = - rewriter.create<AffineApplyOp>(loc, map, mapOperands); - } + SmallVector<Value *, 4> indices(xferReadOp.indices().begin(), + xferReadOp.indices().end()); + auto createSlice = [&](unsigned index, ArrayRef<Value *> sliceIndices) { // Get VectorType for slice 'i'. - auto sliceVectorType = resultTupleType.getType(i); + auto sliceVectorType = resultTupleType.getType(index); // Create split TransferReadOp for 'sliceUser'. - vectorTupleValues[i] = rewriter.create<vector::TransferReadOp>( + vectorTupleValues[index] = rewriter.create<vector::TransferReadOp>( loc, sliceVectorType, xferReadOp.memref(), sliceIndices, xferReadOp.permutation_map(), xferReadOp.padding()); - } + }; + generateTransferOpSlices(sourceVectorType, resultTupleType, sizes, strides, + indices, rewriter, createSlice); + // Create tuple of splice xfer read operations. Value *tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType, vectorTupleValues); @@ -583,6 +601,54 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> { } }; +// Splits vector TransferWriteOp into smaller TransferWriteOps for each source. +struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> { + using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(vector::TransferWriteOp xferWriteOp, + PatternRewriter &rewriter) const override { + // TODO(andydavis, ntv) Support spliting TransferWriteOp with non-identity + // permutation maps. Repurpose code from MaterializeVectors transformation. + if (!xferWriteOp.permutation_map().isIdentity()) + return matchFailure(); + // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'. + auto *vectorDefOp = xferWriteOp.vector()->getDefiningOp(); + auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp); + if (!insertSlicesOp) + return matchFailure(); + + // Get TupleOp operand of 'insertSlicesOp'. + auto tupleOp = dyn_cast_or_null<vector::TupleOp>( + insertSlicesOp.vectors()->getDefiningOp()); + if (!tupleOp) + return matchFailure(); + + // Get 'sizes' and 'strides' parameters from InsertSlicesOp user. + auto sourceTupleType = insertSlicesOp.getSourceTupleType(); + auto resultVectorType = insertSlicesOp.getResultVectorType(); + SmallVector<int64_t, 4> sizes; + insertSlicesOp.getSizes(sizes); + SmallVector<int64_t, 4> strides; + insertSlicesOp.getStrides(strides); + + Location loc = xferWriteOp.getLoc(); + SmallVector<Value *, 4> indices(xferWriteOp.indices().begin(), + xferWriteOp.indices().end()); + auto createSlice = [&](unsigned index, ArrayRef<Value *> sliceIndices) { + // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'. + rewriter.create<vector::TransferWriteOp>( + loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices, + xferWriteOp.permutation_map()); + }; + generateTransferOpSlices(resultVectorType, sourceTupleType, sizes, strides, + indices, rewriter, createSlice); + + // Erase old 'xferWriteOp'. + rewriter.eraseOp(xferWriteOp); + return matchSuccess(); + } +}; + // Patter rewrite which forward tuple elements to their users. // User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer))))) // -> User(Producer) @@ -609,7 +675,7 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> { if (!tupleOp) return matchFailure(); - // Forward Value at tupleOp.getOperand(tupleGetOp.getIndex()); + // Forward Value from 'tupleOp' at 'tupleGetOp.index'. Value *tupleValue = tupleOp.getOperand(tupleGetOp.getIndex()); rewriter.replaceOp(tupleGetOp, tupleValue); return matchSuccess(); @@ -620,5 +686,6 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> { // TODO(andydavis) Add this as DRR pattern. void mlir::vector::populateVectorToVectorTransformationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert<SplitTransferReadOp, TupleGetFolderOp>(context); + patterns.insert<SplitTransferReadOp, SplitTransferWriteOp, TupleGetFolderOp>( + context); } |