summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp187
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorOps.cpp12
-rw-r--r--mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir200
-rw-r--r--mlir/test/Dialect/VectorOps/invalid.mlir14
4 files changed, 403 insertions, 10 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 7221998ce25..c40c7c5242a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -49,6 +49,191 @@ static LLVM::LLVMType getPtrToElementType(T containerType,
.getPointerTo();
}
+class VectorBroadcastOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorBroadcastOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::BroadcastOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto broadcastOp = cast<vector::BroadcastOp>(op);
+ VectorType dstVectorType = broadcastOp.getVectorType();
+ if (lowering.convertType(dstVectorType) == nullptr)
+ return matchFailure();
+ // Rewrite when the full vector type can be lowered (which
+ // implies all 'reduced' types can be lowered too).
+ VectorType srcVectorType =
+ broadcastOp.getSourceType().dyn_cast<VectorType>();
+ rewriter.replaceOp(
+ op, expandRanks(operands[0], // source value to be expanded
+ op->getLoc(), // location of original broadcast
+ srcVectorType, dstVectorType, rewriter));
+ return matchSuccess();
+ }
+
+private:
+ // Expands the given source value over all the ranks, as defined
+ // by the source and destination type (a null source type denotes
+ // expansion from a scalar value into a vector).
+ //
+ // TODO(ajcbik): consider replacing this one-pattern lowering
+ // with a two-pattern lowering using other vector
+ // ops once all insert/extract/shuffle operations
+ // are available with lowering implemention.
+ //
+ Value *expandRanks(Value *value, Location loc, VectorType srcVectorType,
+ VectorType dstVectorType,
+ ConversionPatternRewriter &rewriter) const {
+ assert((dstVectorType != nullptr) && "invalid result type in broadcast");
+ // Determine rank of source and destination.
+ int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0;
+ int64_t dstRank = dstVectorType.getRank();
+ int64_t curDim = dstVectorType.getDimSize(0);
+ if (srcRank < dstRank)
+ // Duplicate this rank.
+ return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank,
+ curDim, rewriter);
+ // If all trailing dimensions are the same, the broadcast consists of
+ // simply passing through the source value and we are done. Otherwise,
+ // any non-matching dimension forces a stretch along this rank.
+ assert((srcVectorType != nullptr) && (srcRank > 0) &&
+ (srcRank == dstRank) && "invalid rank in broadcast");
+ for (int64_t r = 0; r < dstRank; r++) {
+ if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) {
+ return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank,
+ curDim, rewriter);
+ }
+ }
+ return value;
+ }
+
+ // Picks the best way to duplicate a single rank. For the 1-D case, a
+ // single insert-elt/shuffle is the most efficient expansion. For higher
+ // dimensions, however, we need dim x insert-values on a new broadcast
+ // with one less leading dimension, which will be lowered "recursively"
+ // to matching LLVM IR.
+ // For example:
+ // v = broadcast s : f32 to vector<4x2xf32>
+ // becomes:
+ // x = broadcast s : f32 to vector<2xf32>
+ // v = [x,x,x,x]
+ // becomes:
+ // x = [s,s]
+ // v = [x,x,x,x]
+ Value *duplicateOneRank(Value *value, Location loc, VectorType srcVectorType,
+ VectorType dstVectorType, int64_t rank, int64_t dim,
+ ConversionPatternRewriter &rewriter) const {
+ Type llvmType = lowering.convertType(dstVectorType);
+ assert((llvmType != nullptr) && "unlowerable vector type");
+ if (rank == 1) {
+ Value *undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ Value *expand = insertOne(undef, value, loc, llvmType, rank, 0, rewriter);
+ SmallVector<int32_t, 4> zeroValues(dim, 0);
+ return rewriter.create<LLVM::ShuffleVectorOp>(
+ loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
+ }
+ Value *expand = expandRanks(value, loc, srcVectorType,
+ reducedVectorType(dstVectorType), rewriter);
+ Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ for (int64_t d = 0; d < dim; ++d) {
+ result = insertOne(result, expand, loc, llvmType, rank, d, rewriter);
+ }
+ return result;
+ }
+
+ // Picks the best way to stretch a single rank. For the 1-D case, a
+ // single insert-elt/shuffle is the most efficient expansion when at
+ // a stretch. Otherwise, every dimension needs to be expanded
+ // individually and individually inserted in the resulting vector.
+ // For example:
+ // v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32>
+ // becomes:
+ // a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32>
+ // b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32>
+ // c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32>
+ // d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32>
+ // v = [a,b,c,d]
+ // becomes:
+ // x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32>
+ // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32>
+ // a = [x, y]
+ // etc.
+ Value *stretchOneRank(Value *value, Location loc, VectorType srcVectorType,
+ VectorType dstVectorType, int64_t rank, int64_t dim,
+ ConversionPatternRewriter &rewriter) const {
+ Type llvmType = lowering.convertType(dstVectorType);
+ assert((llvmType != nullptr) && "unlowerable vector type");
+ Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ bool atStretch = dim != srcVectorType.getDimSize(0);
+ if (rank == 1) {
+ Type redLlvmType = lowering.convertType(dstVectorType.getElementType());
+ if (atStretch) {
+ Value *one = extractOne(value, loc, redLlvmType, rank, 0, rewriter);
+ Value *expand =
+ insertOne(result, one, loc, llvmType, rank, 0, rewriter);
+ SmallVector<int32_t, 4> zeroValues(dim, 0);
+ return rewriter.create<LLVM::ShuffleVectorOp>(
+ loc, expand, result, rewriter.getI32ArrayAttr(zeroValues));
+ }
+ for (int64_t d = 0; d < dim; ++d) {
+ Value *one = extractOne(value, loc, redLlvmType, rank, d, rewriter);
+ result = insertOne(result, one, loc, llvmType, rank, d, rewriter);
+ }
+ } else {
+ VectorType redSrcType = reducedVectorType(srcVectorType);
+ VectorType redDstType = reducedVectorType(dstVectorType);
+ Type redLlvmType = lowering.convertType(redSrcType);
+ for (int64_t d = 0; d < dim; ++d) {
+ int64_t pos = atStretch ? 0 : d;
+ Value *one = extractOne(value, loc, redLlvmType, rank, pos, rewriter);
+ Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
+ result = insertOne(result, expand, loc, llvmType, rank, d, rewriter);
+ }
+ }
+ return result;
+ }
+
+ // Picks the proper sequence for inserting.
+ Value *insertOne(Value *val1, Value *val2, Location loc, Type llvmType,
+ int64_t rank, int64_t pos,
+ ConversionPatternRewriter &rewriter) const {
+ if (rank == 1) {
+ auto idxType = rewriter.getIndexType();
+ auto constant = rewriter.create<LLVM::ConstantOp>(
+ loc, lowering.convertType(idxType),
+ rewriter.getIntegerAttr(idxType, pos));
+ return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
+ constant);
+ }
+ return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
+ rewriter.getI64ArrayAttr(pos));
+ }
+
+ // Picks the proper sequence for extracting.
+ Value *extractOne(Value *value, Location loc, Type llvmType, int64_t rank,
+ int64_t pos, ConversionPatternRewriter &rewriter) const {
+ if (rank == 1) {
+ auto idxType = rewriter.getIndexType();
+ auto constant = rewriter.create<LLVM::ConstantOp>(
+ loc, lowering.convertType(idxType),
+ rewriter.getIntegerAttr(idxType, pos));
+ return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, value,
+ constant);
+ }
+ return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, value,
+ rewriter.getI64ArrayAttr(pos));
+ }
+
+ // Helper to reduce vector type by one rank.
+ static VectorType reducedVectorType(VectorType tp) {
+ assert((tp.getRank() > 1) && "unlowerable vector type");
+ return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
+ }
+};
+
class VectorExtractElementOpConversion : public LLVMOpLowering {
public:
explicit VectorExtractElementOpConversion(MLIRContext *context,
@@ -246,7 +431,7 @@ public:
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- patterns.insert<VectorExtractElementOpConversion,
+ patterns.insert<VectorBroadcastOpConversion, VectorExtractElementOpConversion,
VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
converter.getDialect()->getContext(), converter);
}
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 5d596f388ed..65441674165 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -416,16 +416,16 @@ static LogicalResult verify(BroadcastOp op) {
// Scalar to vector broadcast is always valid. A vector
// to vector broadcast needs some additional checking.
if (srcVectorType) {
- const int64_t srcRank = srcVectorType.getRank();
- const int64_t dstRank = dstVectorType.getRank();
+ int64_t srcRank = srcVectorType.getRank();
+ int64_t dstRank = dstVectorType.getRank();
if (srcRank > dstRank)
return op.emitOpError("source rank higher than destination rank");
// Source has an exact match or singleton value for all trailing dimensions
// (all leading dimensions are simply duplicated).
- const int64_t lead = dstRank - srcRank;
- for (int64_t i = 0; i < srcRank; i++) {
- const int64_t srcDim = srcVectorType.getDimSize(i);
- const int64_t dstDim = dstVectorType.getDimSize(lead + i);
+ int64_t lead = dstRank - srcRank;
+ for (int64_t r = 0; r < srcRank; ++r) {
+ int64_t srcDim = srcVectorType.getDimSize(r);
+ int64_t dstDim = dstVectorType.getDimSize(lead + r);
if (srcDim != 1 && srcDim != dstDim)
return op.emitOpError("dimension mismatch (")
<< srcDim << " vs. " << dstDim << ")";
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 025027dcddc..b07a8634da4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1,5 +1,205 @@
// RUN: mlir-opt %s -convert-vector-to-llvm | FileCheck %s
+func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+// CHECK-LABEL: broadcast_vec1d_from_scalar
+// CHECK: llvm.mlir.undef : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.shufflevector {{.*}}, {{.*}}[0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
+// CHECK: llvm.return {{.*}} : !llvm<"<2 x float>">
+
+func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<2x3xf32>
+ return %0 : vector<2x3xf32>
+}
+// CHECK-LABEL: broadcast_vec2d_from_scalar
+// CHECK: llvm.mlir.undef : !llvm<"<3 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>">
+// CHECK: llvm.shufflevector {{.*}}, {{.*}}[0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
+// CHECK: llvm.mlir.undef : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">
+
+func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32>
+ return %0 : vector<2x3x4xf32>
+}
+// CHECK-LABEL: broadcast_vec3d_from_scalar
+// CHECK: llvm.mlir.undef : !llvm<"<4 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
+// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
+// CHECK: llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <4 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <4 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <4 x float>]">
+// CHECK: llvm.mlir.undef : !llvm<"[2 x [3 x <4 x float>]]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x [3 x <4 x float>]]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x [3 x <4 x float>]]">
+// CHECK: llvm.return {{.*}} : !llvm<"[2 x [3 x <4 x float>]]">
+
+func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+// CHECK-LABEL: broadcast_vec1d_from_vec1d
+// CHECK: llvm.return {{.*}} : !llvm<"<2 x float>">
+
+func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
+ %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32>
+ return %0 : vector<3x2xf32>
+}
+// CHECK-LABEL: broadcast_vec2d_from_vec1d
+// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.return {{.*}} : !llvm<"[3 x <2 x float>]">
+
+func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> {
+ %0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32>
+ return %0 : vector<4x3x2xf32>
+}
+// CHECK-LABEL: broadcast_vec3d_from_vec1d
+// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]">
+
+func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> {
+ %0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32>
+ return %0 : vector<4x3x2xf32>
+}
+// CHECK-LABEL: broadcast_vec3d_from_vec2d
+// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]">
+
+func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
+ %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+// CHECK-LABEL: broadcast_stretch
+// CHECK: llvm.mlir.undef : !llvm<"<4 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
+// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
+// CHECK: llvm.return {{.*}} : !llvm<"<4 x float>">
+
+func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32>
+ return %0 : vector<3x4xf32>
+}
+// CHECK-LABEL: broadcast_stretch_at_start
+// CHECK: llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <4 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <4 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <4 x float>]">
+// CHECK: llvm.return {{.*}} : !llvm<"[3 x <4 x float>]">
+
+func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
+ %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32>
+ return %0 : vector<4x3xf32>
+}
+// CHECK-LABEL: broadcast_stretch_at_end
+// CHECK: llvm.mlir.undef : !llvm<"[4 x <3 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x <1 x float>]">
+// CHECK: llvm.mlir.undef : !llvm<"<3 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>">
+// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <3 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[4 x <1 x float>]">
+// CHECK: llvm.mlir.undef : !llvm<"<3 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>">
+// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x <3 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <1 x float>]">
+// CHECK: llvm.mlir.undef : !llvm<"<3 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>">
+// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x <3 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <1 x float>]">
+// CHECK: llvm.mlir.undef : !llvm<"<3 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>">
+// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <3 x float>]">
+// CHECK: llvm.return {{.*}} : !llvm<"[4 x <3 x float>]">
+
+func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> {
+ %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32>
+ return %0 : vector<4x3x2xf32>
+}
+// CHECK-LABEL: broadcast_stretch_in_middle
+// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]">
+
func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> {
%2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
return %2 : vector<2x3xf32>
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
index 3b521f6e9ba..b70fc23ef1d 100644
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -2,22 +2,30 @@
// -----
+func @broadcast_to_scalar(%arg0: f32) -> f32 {
+ // expected-error@+1 {{'vector.broadcast' op result #0 must be vector of any type values, but got 'f32'}}
+ %0 = vector.broadcast %arg0 : f32 to f32
+ return %0 : f32
+}
+
+// -----
+
func @broadcast_rank_too_high(%arg0: vector<4x4xf32>) {
- // expected-error@+1 {{source rank higher than destination rank}}
+ // expected-error@+1 {{'vector.broadcast' op source rank higher than destination rank}}
%1 = vector.broadcast %arg0 : vector<4x4xf32> to vector<4xf32>
}
// -----
func @broadcast_dim1_mismatch(%arg0: vector<7xf32>) {
- // expected-error@+1 {{vector.broadcast' op dimension mismatch (7 vs. 3)}}
+ // expected-error@+1 {{'vector.broadcast' op dimension mismatch (7 vs. 3)}}
%1 = vector.broadcast %arg0 : vector<7xf32> to vector<3xf32>
}
// -----
func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {
- // expected-error@+1 {{vector.broadcast' op dimension mismatch (4 vs. 1)}}
+ // expected-error@+1 {{'vector.broadcast' op dimension mismatch (4 vs. 1)}}
%1 = vector.broadcast %arg0 : vector<4x8xf32> to vector<1x8xf32>
}
OpenPOWER on IntegriCloud