summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/VectorToLLVM
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp101
1 files changed, 98 insertions, 3 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b48930c4dda..7035c2e55bc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -6,10 +6,11 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -31,6 +32,7 @@
#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
+using namespace mlir::vector;
template <typename T>
static LLVM::LLVMType getPtrToElementType(T containerType,
@@ -723,15 +725,108 @@ private:
}
};
+// TODO(rriddle): Better support for attribute subtype forwarding + slicing.
+static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
+ unsigned dropFront = 0,
+ unsigned dropBack = 0) {
+ assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
+ auto range = arrayAttr.getAsRange<IntegerAttr>();
+ SmallVector<int64_t, 4> res;
+ res.reserve(arrayAttr.size() - dropFront - dropBack);
+ for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
+ it != eit; ++it)
+ res.push_back((*it).getValue().getSExtValue());
+ return res;
+}
+
+/// Emit the proper `ExtractOp` or `ExtractElementOp` depending on the rank
+/// of `vector`.
+static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
+ int64_t offset) {
+ auto vectorType = vector.getType().cast<VectorType>();
+ if (vectorType.getRank() > 1)
+ return rewriter.create<ExtractOp>(loc, vector, offset);
+ return rewriter.create<vector::ExtractElementOp>(
+ loc, vectorType.getElementType(), vector,
+ rewriter.create<ConstantIndexOp>(loc, offset));
+}
+
+/// Emit the proper `InsertOp` or `InsertElementOp` depending on the rank
+/// of `vector`.
+static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
+ Value into, int64_t offset) {
+ auto vectorType = into.getType().cast<VectorType>();
+ if (vectorType.getRank() > 1)
+ return rewriter.create<InsertOp>(loc, from, into, offset);
+ return rewriter.create<vector::InsertElementOp>(
+ loc, vectorType, from, into,
+ rewriter.create<ConstantIndexOp>(loc, offset));
+}
+
+/// Progressive lowering of StridedSliceOp to either:
+/// 1. extractelement + insertelement for the 1-D case
+/// 2. extract + optional strided_slice + insert for the n-D case.
+class VectorStridedSliceOpRewritePattern
+ : public OpRewritePattern<StridedSliceOp> {
+public:
+ using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(StridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ auto dstType = op.getResult().getType().cast<VectorType>();
+
+ assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
+
+ int64_t offset =
+ op.offsets().getValue().front().cast<IntegerAttr>().getInt();
+ int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
+ int64_t stride =
+ op.strides().getValue().front().cast<IntegerAttr>().getInt();
+
+ auto loc = op.getLoc();
+ auto elemType = dstType.getElementType();
+ assert(elemType.isIntOrIndexOrFloat());
+ Value zero = rewriter.create<ConstantOp>(loc, elemType,
+ rewriter.getZeroAttr(elemType));
+ Value res = rewriter.create<SplatOp>(loc, dstType, zero);
+ for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
+ off += stride, ++idx) {
+ Value extracted = extractOne(rewriter, loc, op.vector(), off);
+ if (op.offsets().getValue().size() > 1) {
+ StridedSliceOp stridedSliceOp = rewriter.create<StridedSliceOp>(
+ loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
+ getI64SubArray(op.sizes(), /* dropFront=*/1),
+ getI64SubArray(op.strides(), /* dropFront=*/1));
+ // Call matchAndRewrite recursively from within the pattern. This
+ // circumvents the current limitation that a given pattern cannot
+ // be called multiple times by the PatternRewrite infrastructure (to
+ // avoid infinite recursion, but in this case, infinite recursion
+ // cannot happen because the rank is strictly decreasing).
+ // TODO(rriddle, nicolasvasilache) Implement something like a hook for
+ // a potential function that must decrease and allow the same pattern
+ // multiple times.
+ auto success = matchAndRewrite(stridedSliceOp, rewriter);
+ (void)success;
+ assert(success && "Unexpected failure");
+ extracted = stridedSliceOp;
+ }
+ res = insertOne(rewriter, loc, extracted, res, idx);
+ }
+ rewriter.replaceOp(op, {res});
+ return matchSuccess();
+ }
+};
+
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ MLIRContext *ctx = converter.getDialect()->getContext();
+ patterns.insert<VectorStridedSliceOpRewritePattern>(ctx);
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorInsertElementOpConversion, VectorInsertOpConversion,
VectorOuterProductOpConversion, VectorTypeCastOpConversion,
- VectorPrintOpConversion>(converter.getDialect()->getContext(),
- converter);
+ VectorPrintOpConversion>(ctx, converter);
}
namespace {
OpenPOWER on IntegriCloud