diff options
author | Andy Davis <andydavis@google.com> | 2019-12-17 07:28:37 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-17 07:29:06 -0800 |
commit | 038ad1d8567ae2f46294e7e7fe68e09c20a309d6 (patch) | |
tree | c4724e174fa1b10ed5cfd7fe5e0401c672354186 /mlir/lib/Dialect | |
parent | 8d68fe684e6f8a4012e4d0047df45b6e200244e2 (diff) | |
download | bcm5719-llvm-038ad1d8567ae2f46294e7e7fe68e09c20a309d6.tar.gz bcm5719-llvm-038ad1d8567ae2f46294e7e7fe68e09c20a309d6.zip |
Add pattern rewrite which splits a vector TransferReadOp into slices according to the unrolling/slicing scheme of its ExtractSlicesOp user.
PiperOrigin-RevId: 285975613
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | 78 |
1 files changed, 44 insertions, 34 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp index 8d70f4ac83f..85f306e7834 100644 --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -511,7 +511,8 @@ Value *mlir::vector::unrollSingleResultOpMatchingType( resultIndex, targetShape, builder); } -// Splits vector TransferReadOp into smaller TransferReadOps for each user. +// Splits vector TransferReadOp into smaller TransferReadOps based on slicing +// scheme of its unique ExtractSlicesOp user. struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> { using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; @@ -521,54 +522,63 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> { // permutation maps. Repurpose code from MaterializeVectors transformation. if (!xferReadOp.permutation_map().isIdentity()) return matchFailure(); - // Gather 'xferReadOp' users. - SmallVector<vector::StridedSliceOp, 2> sliceUsers; - sliceUsers.reserve(std::distance(xferReadOp.getResult()->use_begin(), - xferReadOp.getResult()->use_end())); - - for (auto *user : xferReadOp.getResult()->getUsers()) { - auto sliceOp = dyn_cast<vector::StridedSliceOp>(user); - // Return if any user is not a vector::StridedSliceOp. - if (!sliceOp) - return matchFailure(); - sliceUsers.push_back(sliceOp); - } - // Make zero splat into which we will insert split xferReadOp results. - Location loc = xferReadOp.getLoc(); - auto *res = makeSplatZero(loc, rewriter, xferReadOp.getVectorType()); + // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp. + Value *xferReadResult = xferReadOp.getResult(); + auto extractSlicesOp = + dyn_cast<vector::ExtractSlicesOp>(*xferReadResult->getUsers().begin()); + if (!xferReadResult->hasOneUse() || !extractSlicesOp) + return matchFailure(); - // Update each user in 'sliceUser' to use 'res'. + // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user. + auto sourceVectorType = extractSlicesOp.getSourceVectorType(); + auto resultTupleType = extractSlicesOp.getResultTupleType(); + SmallVector<int64_t, 4> sizes; + extractSlicesOp.getSizes(sizes); + SmallVector<int64_t, 4> strides; + extractSlicesOp.getStrides(strides); + assert(llvm::all_of(strides, [](int64_t s) { return s == 1; })); + + // Compute strides w.r.t. to slice counts in each dimension. + auto maybeDimSliceCounts = shapeRatio(sourceVectorType.getShape(), sizes); + assert(maybeDimSliceCounts.hasValue()); + auto sliceDimCounts = *maybeDimSliceCounts; + auto basis = computeStrides(sliceDimCounts); + + Location loc = xferReadOp.getLoc(); + auto *ctx = rewriter.getContext(); + int64_t numSlices = resultTupleType.size(); unsigned numSliceIndices = llvm::size(xferReadOp.indices()); - for (auto sliceUser : sliceUsers) { - // Gather static offsets from 'sliceUser'. - SmallVector<int64_t, 4> sliceOffsets; - sliceUser.getOffsets(sliceOffsets); - assert(sliceOffsets.size() == numSliceIndices); - auto *ctx = rewriter.getContext(); + SmallVector<Value *, 4> vectorTupleValues(numSlices); + for (unsigned i = 0; i < numSlices; ++i) { + // De-linearize w.r.t. 'basis'. + auto vectorOffsets = delinearize(i, basis); + // Convert from unrolled vector-space offsets to element-space offsets. + auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, + vectorOffsets, sizes); // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. SmallVector<Value *, 4> sliceIndices(numSliceIndices); for (auto it : llvm::enumerate(xferReadOp.indices())) { auto expr = getAffineDimExpr(0, ctx) + - getAffineConstantExpr(sliceOffsets[it.index()], ctx); + getAffineConstantExpr(offsets[it.index()], ctx); auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); SmallVector<Value *, 1> mapOperands = {it.value()}; sliceIndices[it.index()] = rewriter.create<AffineApplyOp>(loc, map, mapOperands); } + // Get VectorType for slice 'i'. + auto sliceVectorType = resultTupleType.getType(i); // Create split TransferReadOp for 'sliceUser'. - auto sliceVectorType = - sliceUser.getResult()->getType().cast<VectorType>(); - auto splitXferReadOp = rewriter.create<vector::TransferReadOp>( + vectorTupleValues[i] = rewriter.create<vector::TransferReadOp>( loc, sliceVectorType, xferReadOp.memref(), sliceIndices, xferReadOp.permutation_map(), xferReadOp.padding()); - // Create InsertStridedSlice into splat at same offsets as slice. - res = rewriter.create<vector::InsertStridedSliceOp>( - loc, xferReadOp.getVectorType(), splitXferReadOp, res, - sliceUser.offsets(), sliceUser.strides()); } - - // Replace 'xferReadOp' with result 'res'. - rewriter.replaceOp(xferReadOp, res); + // Create tuple of splice xfer read operations. + Value *tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType, + vectorTupleValues); + // Replace 'xferReadOp' with result 'insertSlicesResult'. + rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>( + xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(), + extractSlicesOp.strides()); return matchSuccess(); } }; |