diff options
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 187 | ||||
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 12 | ||||
-rw-r--r-- | mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 200 | ||||
-rw-r--r-- | mlir/test/Dialect/VectorOps/invalid.mlir | 14 |
4 files changed, 403 insertions, 10 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 7221998ce25..c40c7c5242a 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -49,6 +49,191 @@ static LLVM::LLVMType getPtrToElementType(T containerType, .getPointerTo(); } +class VectorBroadcastOpConversion : public LLVMOpLowering { +public: + explicit VectorBroadcastOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::BroadcastOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto broadcastOp = cast<vector::BroadcastOp>(op); + VectorType dstVectorType = broadcastOp.getVectorType(); + if (lowering.convertType(dstVectorType) == nullptr) + return matchFailure(); + // Rewrite when the full vector type can be lowered (which + // implies all 'reduced' types can be lowered too). + VectorType srcVectorType = + broadcastOp.getSourceType().dyn_cast<VectorType>(); + rewriter.replaceOp( + op, expandRanks(operands[0], // source value to be expanded + op->getLoc(), // location of original broadcast + srcVectorType, dstVectorType, rewriter)); + return matchSuccess(); + } + +private: + // Expands the given source value over all the ranks, as defined + // by the source and destination type (a null source type denotes + // expansion from a scalar value into a vector). + // + // TODO(ajcbik): consider replacing this one-pattern lowering + // with a two-pattern lowering using other vector + // ops once all insert/extract/shuffle operations + // are available with lowering implementation. 
+ // + Value *expandRanks(Value *value, Location loc, VectorType srcVectorType, + VectorType dstVectorType, + ConversionPatternRewriter &rewriter) const { + assert((dstVectorType != nullptr) && "invalid result type in broadcast"); + // Determine rank of source and destination. + int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0; + int64_t dstRank = dstVectorType.getRank(); + int64_t curDim = dstVectorType.getDimSize(0); + if (srcRank < dstRank) + // Duplicate this rank. + return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank, + curDim, rewriter); + // If all trailing dimensions are the same, the broadcast consists of + // simply passing through the source value and we are done. Otherwise, + // any non-matching dimension forces a stretch along this rank. + assert((srcVectorType != nullptr) && (srcRank > 0) && + (srcRank == dstRank) && "invalid rank in broadcast"); + for (int64_t r = 0; r < dstRank; r++) { + if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) { + return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank, + curDim, rewriter); + } + } + return value; + } + + // Picks the best way to duplicate a single rank. For the 1-D case, a + // single insert-elt/shuffle is the most efficient expansion. For higher + // dimensions, however, we need dim x insert-values on a new broadcast + // with one less leading dimension, which will be lowered "recursively" + // to matching LLVM IR. 
+ // For example: + // v = broadcast s : f32 to vector<4x2xf32> + // becomes: + // x = broadcast s : f32 to vector<2xf32> + // v = [x,x,x,x] + // becomes: + // x = [s,s] + // v = [x,x,x,x] + Value *duplicateOneRank(Value *value, Location loc, VectorType srcVectorType, + VectorType dstVectorType, int64_t rank, int64_t dim, + ConversionPatternRewriter &rewriter) const { + Type llvmType = lowering.convertType(dstVectorType); + assert((llvmType != nullptr) && "unlowerable vector type"); + if (rank == 1) { + Value *undef = rewriter.create<LLVM::UndefOp>(loc, llvmType); + Value *expand = insertOne(undef, value, loc, llvmType, rank, 0, rewriter); + SmallVector<int32_t, 4> zeroValues(dim, 0); + return rewriter.create<LLVM::ShuffleVectorOp>( + loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); + } + Value *expand = expandRanks(value, loc, srcVectorType, + reducedVectorType(dstVectorType), rewriter); + Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); + for (int64_t d = 0; d < dim; ++d) { + result = insertOne(result, expand, loc, llvmType, rank, d, rewriter); + } + return result; + } + + // Picks the best way to stretch a single rank. For the 1-D case, a + // single insert-elt/shuffle is the most efficient expansion when at + // a stretch. Otherwise, every dimension needs to be expanded + // individually and individually inserted in the resulting vector. + // For example: + // v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32> + // becomes: + // a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32> + // b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32> + // c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32> + // d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32> + // v = [a,b,c,d] + // becomes: + // x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32> + // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32> + // a = [x, y] + // etc. 
+ Value *stretchOneRank(Value *value, Location loc, VectorType srcVectorType, + VectorType dstVectorType, int64_t rank, int64_t dim, + ConversionPatternRewriter &rewriter) const { + Type llvmType = lowering.convertType(dstVectorType); + assert((llvmType != nullptr) && "unlowerable vector type"); + Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); + bool atStretch = dim != srcVectorType.getDimSize(0); + if (rank == 1) { + Type redLlvmType = lowering.convertType(dstVectorType.getElementType()); + if (atStretch) { + Value *one = extractOne(value, loc, redLlvmType, rank, 0, rewriter); + Value *expand = + insertOne(result, one, loc, llvmType, rank, 0, rewriter); + SmallVector<int32_t, 4> zeroValues(dim, 0); + return rewriter.create<LLVM::ShuffleVectorOp>( + loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); + } + for (int64_t d = 0; d < dim; ++d) { + Value *one = extractOne(value, loc, redLlvmType, rank, d, rewriter); + result = insertOne(result, one, loc, llvmType, rank, d, rewriter); + } + } else { + VectorType redSrcType = reducedVectorType(srcVectorType); + VectorType redDstType = reducedVectorType(dstVectorType); + Type redLlvmType = lowering.convertType(redSrcType); + for (int64_t d = 0; d < dim; ++d) { + int64_t pos = atStretch ? 0 : d; + Value *one = extractOne(value, loc, redLlvmType, rank, pos, rewriter); + Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); + result = insertOne(result, expand, loc, llvmType, rank, d, rewriter); + } + } + return result; + } + + // Picks the proper sequence for inserting. 
+ Value *insertOne(Value *val1, Value *val2, Location loc, Type llvmType, + int64_t rank, int64_t pos, + ConversionPatternRewriter &rewriter) const { + if (rank == 1) { + auto idxType = rewriter.getIndexType(); + auto constant = rewriter.create<LLVM::ConstantOp>( + loc, lowering.convertType(idxType), + rewriter.getIntegerAttr(idxType, pos)); + return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, + constant); + } + return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, + rewriter.getI64ArrayAttr(pos)); + } + + // Picks the proper sequence for extracting. + Value *extractOne(Value *value, Location loc, Type llvmType, int64_t rank, + int64_t pos, ConversionPatternRewriter &rewriter) const { + if (rank == 1) { + auto idxType = rewriter.getIndexType(); + auto constant = rewriter.create<LLVM::ConstantOp>( + loc, lowering.convertType(idxType), + rewriter.getIntegerAttr(idxType, pos)); + return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, value, + constant); + } + return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, value, + rewriter.getI64ArrayAttr(pos)); + } + + // Helper to reduce vector type by one rank. + static VectorType reducedVectorType(VectorType tp) { + assert((tp.getRank() > 1) && "unlowerable vector type"); + return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); + } +}; + class VectorExtractElementOpConversion : public LLVMOpLowering { public: explicit VectorExtractElementOpConversion(MLIRContext *context, @@ -246,7 +431,7 @@ public: /// Populate the given list with patterns that convert from Vector to LLVM. 
void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - patterns.insert<VectorExtractElementOpConversion, + patterns.insert<VectorBroadcastOpConversion, VectorExtractElementOpConversion, VectorOuterProductOpConversion, VectorTypeCastOpConversion>( converter.getDialect()->getContext(), converter); } diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 5d596f388ed..65441674165 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -416,16 +416,16 @@ static LogicalResult verify(BroadcastOp op) { // Scalar to vector broadcast is always valid. A vector // to vector broadcast needs some additional checking. if (srcVectorType) { - const int64_t srcRank = srcVectorType.getRank(); - const int64_t dstRank = dstVectorType.getRank(); + int64_t srcRank = srcVectorType.getRank(); + int64_t dstRank = dstVectorType.getRank(); if (srcRank > dstRank) return op.emitOpError("source rank higher than destination rank"); // Source has an exact match or singleton value for all trailing dimensions // (all leading dimensions are simply duplicated). - const int64_t lead = dstRank - srcRank; - for (int64_t i = 0; i < srcRank; i++) { - const int64_t srcDim = srcVectorType.getDimSize(i); - const int64_t dstDim = dstVectorType.getDimSize(lead + i); + int64_t lead = dstRank - srcRank; + for (int64_t r = 0; r < srcRank; ++r) { + int64_t srcDim = srcVectorType.getDimSize(r); + int64_t dstDim = dstVectorType.getDimSize(lead + r); if (srcDim != 1 && srcDim != dstDim) return op.emitOpError("dimension mismatch (") << srcDim << " vs. 
" << dstDim << ")"; diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 025027dcddc..b07a8634da4 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1,5 +1,205 @@ // RUN: mlir-opt %s -convert-vector-to-llvm | FileCheck %s +func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2xf32> + return %0 : vector<2xf32> +} +// CHECK-LABEL: broadcast_vec1d_from_scalar +// CHECK: llvm.mlir.undef : !llvm<"<2 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}}[0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> +// CHECK: llvm.return {{.*}} : !llvm<"<2 x float>"> + +func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2x3xf32> + return %0 : vector<2x3xf32> +} +// CHECK-LABEL: broadcast_vec2d_from_scalar +// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}}[0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: llvm.mlir.undef : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> + +func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32> + return %0 : vector<2x3x4xf32> +} +// CHECK-LABEL: broadcast_vec3d_from_scalar +// CHECK: llvm.mlir.undef : !llvm<"<4 x float>"> +// CHECK: llvm.mlir.constant(0 : 
index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> +// CHECK: llvm.mlir.undef : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.mlir.undef : !llvm<"[2 x [3 x <4 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x [3 x <4 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x [3 x <4 x float>]]"> +// CHECK: llvm.return {{.*}} : !llvm<"[2 x [3 x <4 x float>]]"> + +func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32> + return %0 : vector<2xf32> +} +// CHECK-LABEL: broadcast_vec1d_from_vec1d +// CHECK: llvm.return {{.*}} : !llvm<"<2 x float>"> + +func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32> + return %0 : vector<3x2xf32> +} +// CHECK-LABEL: broadcast_vec2d_from_vec1d +// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.return {{.*}} : !llvm<"[3 x <2 x float>]"> + +func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +} +// CHECK-LABEL: broadcast_vec3d_from_vec1d +// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> +// 
CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]"> + +func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +} +// CHECK-LABEL: broadcast_vec3d_from_vec2d +// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]"> + +func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { + %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32> + return %0 : vector<4xf32> +} +// CHECK-LABEL: broadcast_stretch +// CHECK: llvm.mlir.undef : !llvm<"<4 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> +// CHECK: llvm.return {{.*}} : !llvm<"<4 x float>"> + +func 
@broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32> + return %0 : vector<3x4xf32> +} +// CHECK-LABEL: broadcast_stretch_at_start +// CHECK: llvm.mlir.undef : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.return {{.*}} : !llvm<"[3 x <4 x float>]"> + +func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> { + %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32> + return %0 : vector<4x3xf32> +} +// CHECK-LABEL: broadcast_stretch_at_end +// CHECK: llvm.mlir.undef : !llvm<"[4 x <3 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x <1 x float>]"> +// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <3 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[4 x <1 x float>]"> +// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : 
!llvm<"<3 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x <3 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <1 x float>]"> +// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x <3 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <1 x float>]"> +// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <3 x float>]"> +// CHECK: llvm.return {{.*}} : !llvm<"[4 x <3 x float>]"> + +func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +} +// CHECK-LABEL: broadcast_stretch_in_middle +// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 
x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x 
float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]"> + func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> { %2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32> return %2 : vector<2x3xf32> diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index 3b521f6e9ba..b70fc23ef1d 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -2,22 +2,30 @@ // ----- +func @broadcast_to_scalar(%arg0: f32) -> f32 { + // expected-error@+1 {{'vector.broadcast' op result #0 must be vector of any type values, but got 'f32'}} + %0 = vector.broadcast %arg0 : f32 to f32 + return %0 : f32 +} + +// ----- + func @broadcast_rank_too_high(%arg0: vector<4x4xf32>) { - // expected-error@+1 {{source rank higher than destination rank}} + // expected-error@+1 {{'vector.broadcast' op source rank higher than destination rank}} %1 = vector.broadcast %arg0 : vector<4x4xf32> to vector<4xf32> } // ----- func @broadcast_dim1_mismatch(%arg0: vector<7xf32>) { - // expected-error@+1 {{vector.broadcast' op dimension mismatch (7 vs. 3)}} + // expected-error@+1 {{'vector.broadcast' op dimension mismatch (7 vs. 3)}} %1 = vector.broadcast %arg0 : vector<7xf32> to vector<3xf32> } // ----- func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) { - // expected-error@+1 {{vector.broadcast' op dimension mismatch (4 vs. 1)}} + // expected-error@+1 {{'vector.broadcast' op dimension mismatch (4 vs. 1)}} %1 = vector.broadcast %arg0 : vector<4x8xf32> to vector<1x8xf32> } |