summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2020-01-09 03:12:41 -0500
committerNicolas Vasilache <ntv@google.com>2020-01-09 03:13:01 -0500
commit2d515e49d89c0738ccef8f1733d5f9afe00ee979 (patch)
tree46c20bc47996d20398c819dcd9cdde7e6eed63cd /mlir/lib/Conversion
parent65678d938431c90408afa8d255cbed3d8ed8273f (diff)
downloadbcm5719-llvm-2d515e49d89c0738ccef8f1733d5f9afe00ee979.tar.gz
bcm5719-llvm-2d515e49d89c0738ccef8f1733d5f9afe00ee979.zip
[mlir][VectorOps] Implement insert_strided_slice conversion
Summary: This diff implements the progressive lowering of insert_strided_slice. Two cases appear: 1. when the source and dest vectors have different ranks, extract the dest subvector at the proper offset and reduce to case 2. 2. when they have the same rank N: a. if the source and dest type are the same, the insertion is trivial: just forward the source b. otherwise, iterate over all N-1 D subvectors and create an extract/insert_strided_slice/insert replacement, reducing the problem to vecotrs of the same N-1 rank. This combines properly with the other conversion patterns to lower all the way to LLVM. Reviewers: ftynse, rriddle, AlexEichenberger, andydavis1, tetuante, nicolasvasilache Reviewed By: andydavis1 Subscribers: merge_guards_bot, mehdi_amini, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72317
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp215
1 files changed, 174 insertions, 41 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 7035c2e55bc..1fbee9742e0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -70,6 +70,17 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
rewriter.getI64ArrayAttr(pos));
}
+// Helper that picks the proper sequence for inserting.
+static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
+ Value into, int64_t offset) {
+ auto vectorType = into.getType().cast<VectorType>();
+ if (vectorType.getRank() > 1)
+ return rewriter.create<InsertOp>(loc, from, into, offset);
+ return rewriter.create<vector::InsertElementOp>(
+ loc, vectorType, from, into,
+ rewriter.create<ConstantIndexOp>(loc, offset));
+}
+
// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &lowering, Location loc, Value val,
@@ -86,6 +97,32 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
rewriter.getI64ArrayAttr(pos));
}
+// Helper that picks the proper sequence for extracting.
+static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
+ int64_t offset) {
+ auto vectorType = vector.getType().cast<VectorType>();
+ if (vectorType.getRank() > 1)
+ return rewriter.create<ExtractOp>(loc, vector, offset);
+ return rewriter.create<vector::ExtractElementOp>(
+ loc, vectorType.getElementType(), vector,
+ rewriter.create<ConstantIndexOp>(loc, offset));
+}
+
+// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
+// TODO(rriddle): Better support for attribute subtype forwarding + slicing.
+static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
+ unsigned dropFront = 0,
+ unsigned dropBack = 0) {
+ assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
+ auto range = arrayAttr.getAsRange<IntegerAttr>();
+ SmallVector<int64_t, 4> res;
+ res.reserve(arrayAttr.size() - dropFront - dropBack);
+ for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
+ it != eit; ++it)
+ res.push_back((*it).getValue().getSExtValue());
+ return res;
+}
+
class VectorBroadcastOpConversion : public LLVMOpLowering {
public:
explicit VectorBroadcastOpConversion(MLIRContext *context,
@@ -464,6 +501,139 @@ public:
}
};
+// When ranks are different, InsertStridedSlice needs to extract a properly
+// ranked vector from the destination vector into which to insert. This pattern
+// only takes care of this part and forwards the rest of the conversion to
+// another pattern that converts InsertStridedSlice for operands of the same
+// rank.
+//
+// RewritePattern for InsertStridedSliceOp where source and destination vectors
+// have different ranks. In this case:
+// 1. the proper subvector is extracted from the destination vector
+// 2. a new InsertStridedSlice op is created to insert the source in the
+// destination subvector
+// 3. the destination subvector is inserted back in the proper place
+// 4. the op is replaced by the result of step 3.
+// The new InsertStridedSlice from step 2. will be picked up by a
+// `VectorInsertStridedSliceOpSameRankRewritePattern`.
+class VectorInsertStridedSliceOpDifferentRankRewritePattern
+ : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+ using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ auto srcType = op.getSourceVectorType();
+ auto dstType = op.getDestVectorType();
+
+ if (op.offsets().getValue().empty())
+ return matchFailure();
+
+ auto loc = op.getLoc();
+ int64_t rankDiff = dstType.getRank() - srcType.getRank();
+ assert(rankDiff >= 0);
+ if (rankDiff == 0)
+ return matchFailure();
+
+ int64_t rankRest = dstType.getRank() - rankDiff;
+ // Extract / insert the subvector of matching rank and InsertStridedSlice
+ // on it.
+ Value extracted =
+ rewriter.create<ExtractOp>(loc, op.dest(),
+ getI64SubArray(op.offsets(), /*dropFront=*/0,
+ /*dropFront=*/rankRest));
+ // A different pattern will kick in for InsertStridedSlice with matching
+ // ranks.
+ auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
+ loc, op.source(), extracted,
+ getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
+ getI64SubArray(op.strides(), /*dropFront=*/rankDiff));
+ rewriter.replaceOpWithNewOp<InsertOp>(
+ op, stridedSliceInnerOp.getResult(), op.dest(),
+ getI64SubArray(op.offsets(), /*dropFront=*/0,
+ /*dropFront=*/rankRest));
+ return matchSuccess();
+ }
+};
+
+// RewritePattern for InsertStridedSliceOp where source and destination vectors
+// have the same rank. In this case, we reduce
+// 1. the proper subvector is extracted from the destination vector
+// 2. a new InsertStridedSlice op is created to insert the source in the
+// destination subvector
+// 3. the destination subvector is inserted back in the proper place
+// 4. the op is replaced by the result of step 3.
+// The new InsertStridedSlice from step 2. will be picked up by a
+// `VectorInsertStridedSliceOpSameRankRewritePattern`.
+class VectorInsertStridedSliceOpSameRankRewritePattern
+ : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+ using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ auto srcType = op.getSourceVectorType();
+ auto dstType = op.getDestVectorType();
+
+ if (op.offsets().getValue().empty())
+ return matchFailure();
+
+ int64_t rankDiff = dstType.getRank() - srcType.getRank();
+ assert(rankDiff >= 0);
+ if (rankDiff != 0)
+ return matchFailure();
+
+ if (srcType == dstType) {
+ rewriter.replaceOp(op, op.source());
+ return matchSuccess();
+ }
+
+ int64_t offset =
+ op.offsets().getValue().front().cast<IntegerAttr>().getInt();
+ int64_t size = srcType.getShape().front();
+ int64_t stride =
+ op.strides().getValue().front().cast<IntegerAttr>().getInt();
+
+ auto loc = op.getLoc();
+ Value res = op.dest();
+ // For each slice of the source vector along the most major dimension.
+ for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
+ off += stride, ++idx) {
+ // 1. extract the proper subvector (or element) from source
+ Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
+ if (extractedSource.getType().isa<VectorType>()) {
+ // 2. If we have a vector, extract the proper subvector from destination
+ // Otherwise we are at the element level and no need to recurse.
+ Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
+ // 3. Reduce the problem to lowering a new InsertStridedSlice op with
+ // smaller rank.
+ InsertStridedSliceOp insertStridedSliceOp =
+ rewriter.create<InsertStridedSliceOp>(
+ loc, extractedSource, extractedDest,
+ getI64SubArray(op.offsets(), /* dropFront=*/1),
+ getI64SubArray(op.strides(), /* dropFront=*/1));
+ // Call matchAndRewrite recursively from within the pattern. This
+ // circumvents the current limitation that a given pattern cannot
+ // be called multiple times by the PatternRewrite infrastructure (to
+ // avoid infinite recursion, but in this case, infinite recursion
+ // cannot happen because the rank is strictly decreasing).
+ // TODO(rriddle, nicolasvasilache) Implement something like a hook for
+ // a potential function that must decrease and allow the same pattern
+ // multiple times.
+ auto success = matchAndRewrite(insertStridedSliceOp, rewriter);
+ (void)success;
+ assert(success && "Unexpected failure");
+ extractedSource = insertStridedSliceOp;
+ }
+ // 4. Insert the extractedSource into the res vector.
+ res = insertOne(rewriter, loc, extractedSource, res, off);
+ }
+
+ rewriter.replaceOp(op, res);
+ return matchSuccess();
+ }
+};
+
class VectorOuterProductOpConversion : public LLVMOpLowering {
public:
explicit VectorOuterProductOpConversion(MLIRContext *context,
@@ -725,49 +895,10 @@ private:
}
};
-// TODO(rriddle): Better support for attribute subtype forwarding + slicing.
-static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
- unsigned dropFront = 0,
- unsigned dropBack = 0) {
- assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
- auto range = arrayAttr.getAsRange<IntegerAttr>();
- SmallVector<int64_t, 4> res;
- res.reserve(arrayAttr.size() - dropFront - dropBack);
- for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
- it != eit; ++it)
- res.push_back((*it).getValue().getSExtValue());
- return res;
-}
-
-/// Emit the proper `ExtractOp` or `ExtractElementOp` depending on the rank
-/// of `vector`.
-static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
- int64_t offset) {
- auto vectorType = vector.getType().cast<VectorType>();
- if (vectorType.getRank() > 1)
- return rewriter.create<ExtractOp>(loc, vector, offset);
- return rewriter.create<vector::ExtractElementOp>(
- loc, vectorType.getElementType(), vector,
- rewriter.create<ConstantIndexOp>(loc, offset));
-}
-
-/// Emit the proper `InsertOp` or `InsertElementOp` depending on the rank
-/// of `vector`.
-static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
- Value into, int64_t offset) {
- auto vectorType = into.getType().cast<VectorType>();
- if (vectorType.getRank() > 1)
- return rewriter.create<InsertOp>(loc, from, into, offset);
- return rewriter.create<vector::InsertElementOp>(
- loc, vectorType, from, into,
- rewriter.create<ConstantIndexOp>(loc, offset));
-}
-
/// Progressive lowering of StridedSliceOp to either:
/// 1. extractelement + insertelement for the 1-D case
/// 2. extract + optional strided_slice + insert for the n-D case.
-class VectorStridedSliceOpRewritePattern
- : public OpRewritePattern<StridedSliceOp> {
+class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
public:
using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
@@ -821,7 +952,9 @@ public:
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
MLIRContext *ctx = converter.getDialect()->getContext();
- patterns.insert<VectorStridedSliceOpRewritePattern>(ctx);
+ patterns.insert<VectorInsertStridedSliceOpDifferentRankRewritePattern,
+ VectorInsertStridedSliceOpSameRankRewritePattern,
+ VectorStridedSliceOpConversion>(ctx);
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorInsertElementOpConversion, VectorInsertOpConversion,
OpenPOWER on IntegriCloud