author    Nicolas Vasilache <ntv@google.com>  2020-01-09 02:58:21 -0500
committer Nicolas Vasilache <ntv@google.com>  2020-01-09 03:03:51 -0500
commit    65678d938431c90408afa8d255cbed3d8ed8273f (patch)
tree      f5efc92f66b1e1954236faa26c9c1fdf4dead892
parent    24b326cc610dfdccdd50bc78505ec228d96c8e7a (diff)
[mlir][VectorOps] Implement strided_slice conversion
Summary:
This diff implements the progressive lowering of strided_slice to either:
  1. extractelement + insertelement for the 1-D case
  2. extract + optional strided_slice + insert for the n-D case.

This combines properly with the other conversion patterns to lower all the
way to LLVM. Appropriate tests are added.

Reviewers: ftynse, rriddle, AlexEichenberger, andydavis1, tetuante

Reviewed By: andydavis1

Subscribers: merge_guards_bot, mehdi_amini, jpienaar, burmako, shauheen,
antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72310
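For intuition, the 1-D lowering reads source element offset + idx * stride
for each result index idx and writes it into a zero-initialized result
vector. A minimal scalar analogy of that access pattern in plain C++ (the
array stand-in and function name are illustrative, not part of the patch):

#include <array>
#include <cstddef>

// Scalar analogy of the 1-D case: each result element idx is read from
// offset + idx * stride in the source, mirroring the extractelement +
// insertelement pairs the pattern emits.
template <std::size_t Size, std::size_t N>
std::array<float, Size> stridedSlice1D(const std::array<float, N> &src,
                                       std::size_t offset,
                                       std::size_t stride) {
  std::array<float, Size> dst{}; // zero-initialized, like the splat of 0.0
  for (std::size_t idx = 0; idx < Size; ++idx)
    dst[idx] = src[offset + idx * stride];
  return dst;
}

The n-D case applies the same loop one dimension at a time, recursing on a
rank-reduced strided_slice until it bottoms out at rank 1.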
-rw-r--r--  mlir/include/mlir/IR/Attributes.h                         |  19
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 101
-rw-r--r--  mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir    |  61
3 files changed, 178 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index b8398580f61..64b8063bdcb 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -215,6 +215,25 @@ public:
static bool kindof(unsigned kind) {
return kind == StandardAttributes::Array;
}
+
+private:
+ /// Class for underlying value iterator support.
+ template <typename AttrTy>
+ class attr_value_iterator final
+ : public llvm::mapped_iterator<iterator, AttrTy (*)(Attribute)> {
+ public:
+ explicit attr_value_iterator(iterator it)
+ : llvm::mapped_iterator<iterator, AttrTy (*)(Attribute)>(
+ it, [](Attribute attr) { return attr.cast<AttrTy>(); }) {}
+ AttrTy operator*() { return (*this->I).template cast<AttrTy>(); }
+ };
+
+public:
+ template <typename AttrTy>
+ llvm::iterator_range<attr_value_iterator<AttrTy>> getAsRange() {
+ return llvm::make_range(attr_value_iterator<AttrTy>(begin()),
+ attr_value_iterator<AttrTy>(end()));
+ }
};
//===----------------------------------------------------------------------===//
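The getAsRange<AttrTy>() helper added above lets callers iterate an
ArrayAttr with each element already cast to a concrete attribute type,
which the conversion code below relies on. A minimal usage sketch,
assuming an ArrayAttr known to hold IntegerAttrs (the free function itself
is illustrative, not part of the patch):

#include "mlir/IR/Attributes.h"

using namespace mlir;

// Sums the values of an ArrayAttr of IntegerAttrs; the typed range removes
// the need to cast each element inside the loop body.
static int64_t sumI64Elements(ArrayAttr arrayAttr) {
  int64_t sum = 0;
  for (IntegerAttr attr : arrayAttr.getAsRange<IntegerAttr>())
    sum += attr.getInt();
  return sum;
}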
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b48930c4dda..7035c2e55bc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -6,10 +6,11 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -31,6 +32,7 @@
#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
+using namespace mlir::vector;
template <typename T>
static LLVM::LLVMType getPtrToElementType(T containerType,
@@ -723,15 +725,108 @@ private:
}
};
+// TODO(rriddle): Better support for attribute subtype forwarding + slicing.
+static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
+ unsigned dropFront = 0,
+ unsigned dropBack = 0) {
+ assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
+ auto range = arrayAttr.getAsRange<IntegerAttr>();
+ SmallVector<int64_t, 4> res;
+ res.reserve(arrayAttr.size() - dropFront - dropBack);
+ for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
+ it != eit; ++it)
+ res.push_back((*it).getValue().getSExtValue());
+ return res;
+}
+
+/// Emit the proper `ExtractOp` or `ExtractElementOp` depending on the rank
+/// of `vector`.
+static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
+ int64_t offset) {
+ auto vectorType = vector.getType().cast<VectorType>();
+ if (vectorType.getRank() > 1)
+ return rewriter.create<ExtractOp>(loc, vector, offset);
+ return rewriter.create<vector::ExtractElementOp>(
+ loc, vectorType.getElementType(), vector,
+ rewriter.create<ConstantIndexOp>(loc, offset));
+}
+
+/// Emit the proper `InsertOp` or `InsertElementOp` depending on the rank
+/// of `vector`.
+static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
+ Value into, int64_t offset) {
+ auto vectorType = into.getType().cast<VectorType>();
+ if (vectorType.getRank() > 1)
+ return rewriter.create<InsertOp>(loc, from, into, offset);
+ return rewriter.create<vector::InsertElementOp>(
+ loc, vectorType, from, into,
+ rewriter.create<ConstantIndexOp>(loc, offset));
+}
+
+/// Progressive lowering of StridedSliceOp to either:
+/// 1. extractelement + insertelement for the 1-D case
+/// 2. extract + optional strided_slice + insert for the n-D case.
+class VectorStridedSliceOpRewritePattern
+ : public OpRewritePattern<StridedSliceOp> {
+public:
+ using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(StridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ auto dstType = op.getResult().getType().cast<VectorType>();
+
+ assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
+
+ int64_t offset =
+ op.offsets().getValue().front().cast<IntegerAttr>().getInt();
+ int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
+ int64_t stride =
+ op.strides().getValue().front().cast<IntegerAttr>().getInt();
+
+ auto loc = op.getLoc();
+ auto elemType = dstType.getElementType();
+ assert(elemType.isIntOrIndexOrFloat());
+ Value zero = rewriter.create<ConstantOp>(loc, elemType,
+ rewriter.getZeroAttr(elemType));
+ Value res = rewriter.create<SplatOp>(loc, dstType, zero);
+ for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
+ off += stride, ++idx) {
+ Value extracted = extractOne(rewriter, loc, op.vector(), off);
+ if (op.offsets().getValue().size() > 1) {
+ StridedSliceOp stridedSliceOp = rewriter.create<StridedSliceOp>(
+ loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
+ getI64SubArray(op.sizes(), /* dropFront=*/1),
+ getI64SubArray(op.strides(), /* dropFront=*/1));
+ // Call matchAndRewrite recursively from within the pattern. This
+ // circumvents the current limitation that a given pattern cannot
+ // be called multiple times by the PatternRewrite infrastructure (to
+ // avoid infinite recursion, but in this case, infinite recursion
+ // cannot happen because the rank is strictly decreasing).
+ // TODO(rriddle, nicolasvasilache) Implement something like a hook for
+ // a potential function that must decrease and allow the same pattern
+ // multiple times.
+ auto success = matchAndRewrite(stridedSliceOp, rewriter);
+ (void)success;
+ assert(success && "Unexpected failure");
+ extracted = stridedSliceOp;
+ }
+ res = insertOne(rewriter, loc, extracted, res, idx);
+ }
+ rewriter.replaceOp(op, {res});
+ return matchSuccess();
+ }
+};
+
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ MLIRContext *ctx = converter.getDialect()->getContext();
+ patterns.insert<VectorStridedSliceOpRewritePattern>(ctx);
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorInsertElementOpConversion, VectorInsertOpConversion,
VectorOuterProductOpConversion, VectorTypeCastOpConversion,
- VectorPrintOpConversion>(converter.getDialect()->getContext(),
- converter);
+ VectorPrintOpConversion>(ctx, converter);
}
namespace {
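For context, a pass driving these patterns populates the list and hands it
to the dialect-conversion framework. A hedged sketch of such a call site,
assuming the dialect-conversion entry points of this era
(applyPartialConversion taking an explicit TypeConverter pointer); the
function name and pass plumbing are illustrative, not part of the patch:

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Module.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Populates the vector-to-LLVM patterns (now including the strided_slice
// rewrite) and applies them as a partial conversion over a module.
static LogicalResult lowerVectorToLLVM(ModuleOp module,
                                       LLVMTypeConverter &converter) {
  OwningRewritePatternList patterns;
  populateVectorToLLVMConversionPatterns(converter, patterns);
  ConversionTarget target(*module.getContext());
  target.addLegalDialect<LLVM::LLVMDialect>();
  return applyPartialConversion(module, target, patterns, &converter);
}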
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1725a0b7c75..3a001211430 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -423,3 +423,64 @@ func @vector_print_vector(%arg0: vector<2x2xf32>) {
// CHECK: llvm.call @print_close() : () -> ()
// CHECK: llvm.call @print_close() : () -> ()
// CHECK: llvm.call @print_newline() : () -> ()
+
+
+func @strided_slice(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>, %arg2: vector<4x8x16xf32>) {
+// CHECK-LABEL: llvm.func @strided_slice(
+
+ %0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+
+ %1 = vector.strided_slice %arg1 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32>
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x8xf32>) : !llvm<"[2 x <8 x float>]">
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"[2 x <8 x float>]">
+// CHECK: llvm.extractvalue %{{.*}}[3] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"[2 x <8 x float>]">
+
+ %2 = vector.strided_slice %arg1 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm<"[2 x <2 x float>]">
+//
+// Subvector vector<8xf32> @2
+// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x <2 x float>]">
+//
+// Subvector vector<8xf32> @3
+// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <2 x float>]">
+
+ return
+}
+
+