diff options
author | Nicolas Vasilache <ntv@google.com> | 2020-01-09 02:58:21 -0500 |
---|---|---|
committer | Nicolas Vasilache <ntv@google.com> | 2020-01-09 03:03:51 -0500 |
commit | 65678d938431c90408afa8d255cbed3d8ed8273f (patch) | |
tree | f5efc92f66b1e1954236faa26c9c1fdf4dead892 /mlir/lib/Conversion | |
parent | 24b326cc610dfdccdd50bc78505ec228d96c8e7a (diff) | |
download | bcm5719-llvm-65678d938431c90408afa8d255cbed3d8ed8273f.tar.gz bcm5719-llvm-65678d938431c90408afa8d255cbed3d8ed8273f.zip |
[mlir][VectorOps] Implement strided_slice conversion
Summary:
This diff implements the progressive lowering of strided_slice to either:
1. extractelement + insertelement for the 1-D case
2. extract + optional strided_slice + insert for the n-D case.
This combines properly with the other conversion patterns to lower all the way to LLVM.
Appropriate tests are added.
Reviewers: ftynse, rriddle, AlexEichenberger, andydavis1, tetuante
Reviewed By: andydavis1
Subscribers: merge_guards_bot, mehdi_amini, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D72310
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 101 |
1 files changed, 98 insertions, 3 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index b48930c4dda..7035c2e55bc 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -6,10 +6,11 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -31,6 +32,7 @@ #include "llvm/Support/ErrorHandling.h" using namespace mlir; +using namespace mlir::vector; template <typename T> static LLVM::LLVMType getPtrToElementType(T containerType, @@ -723,15 +725,108 @@ private: } }; +// TODO(rriddle): Better support for attribute subtype forwarding + slicing. +static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, + unsigned dropFront = 0, + unsigned dropBack = 0) { + assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); + auto range = arrayAttr.getAsRange<IntegerAttr>(); + SmallVector<int64_t, 4> res; + res.reserve(arrayAttr.size() - dropFront - dropBack); + for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; + it != eit; ++it) + res.push_back((*it).getValue().getSExtValue()); + return res; +} + +/// Emit the proper `ExtractOp` or `ExtractElementOp` depending on the rank +/// of `vector`. +static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, + int64_t offset) { + auto vectorType = vector.getType().cast<VectorType>(); + if (vectorType.getRank() > 1) + return rewriter.create<ExtractOp>(loc, vector, offset); + return rewriter.create<vector::ExtractElementOp>( + loc, vectorType.getElementType(), vector, + rewriter.create<ConstantIndexOp>(loc, offset)); +} + +/// Emit the proper `InsertOp` or `InsertElementOp` depending on the rank +/// of `vector`. +static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, + Value into, int64_t offset) { + auto vectorType = into.getType().cast<VectorType>(); + if (vectorType.getRank() > 1) + return rewriter.create<InsertOp>(loc, from, into, offset); + return rewriter.create<vector::InsertElementOp>( + loc, vectorType, from, into, + rewriter.create<ConstantIndexOp>(loc, offset)); +} + +/// Progressive lowering of StridedSliceOp to either: +/// 1. extractelement + insertelement for the 1-D case +/// 2. extract + optional strided_slice + insert for the n-D case. +class VectorStridedSliceOpRewritePattern + : public OpRewritePattern<StridedSliceOp> { +public: + using OpRewritePattern<StridedSliceOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(StridedSliceOp op, + PatternRewriter &rewriter) const override { + auto dstType = op.getResult().getType().cast<VectorType>(); + + assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); + + int64_t offset = + op.offsets().getValue().front().cast<IntegerAttr>().getInt(); + int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); + int64_t stride = + op.strides().getValue().front().cast<IntegerAttr>().getInt(); + + auto loc = op.getLoc(); + auto elemType = dstType.getElementType(); + assert(elemType.isIntOrIndexOrFloat()); + Value zero = rewriter.create<ConstantOp>(loc, elemType, + rewriter.getZeroAttr(elemType)); + Value res = rewriter.create<SplatOp>(loc, dstType, zero); + for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; + off += stride, ++idx) { + Value extracted = extractOne(rewriter, loc, op.vector(), off); + if (op.offsets().getValue().size() > 1) { + StridedSliceOp stridedSliceOp = rewriter.create<StridedSliceOp>( + loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1), + getI64SubArray(op.sizes(), /* dropFront=*/1), + getI64SubArray(op.strides(), /* dropFront=*/1)); + // Call matchAndRewrite recursively from within the pattern. This + // circumvents the current limitation that a given pattern cannot + // be called multiple times by the PatternRewrite infrastructure (to + // avoid infinite recursion, but in this case, infinite recursion + // cannot happen because the rank is strictly decreasing). + // TODO(rriddle, nicolasvasilache) Implement something like a hook for + // a potential function that must decrease and allow the same pattern + // multiple times. + auto success = matchAndRewrite(stridedSliceOp, rewriter); + (void)success; + assert(success && "Unexpected failure"); + extracted = stridedSliceOp; + } + res = insertOne(rewriter, loc, extracted, res, idx); + } + rewriter.replaceOp(op, {res}); + return matchSuccess(); + } +}; + /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + MLIRContext *ctx = converter.getDialect()->getContext(); + patterns.insert<VectorStridedSliceOpRewritePattern>(ctx); patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion, VectorExtractElementOpConversion, VectorExtractOpConversion, VectorInsertElementOpConversion, VectorInsertOpConversion, VectorOuterProductOpConversion, VectorTypeCastOpConversion, - VectorPrintOpConversion>(converter.getDialect()->getContext(), - converter); + VectorPrintOpConversion>(ctx, converter); } namespace { |