diff options
-rw-r--r-- | mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h | 7 | ||||
-rw-r--r-- | mlir/include/mlir/VectorOps/VectorOps.td | 22 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp | 90 | ||||
-rw-r--r-- | mlir/lib/VectorOps/VectorOps.cpp | 57 | ||||
-rw-r--r-- | mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 68 | ||||
-rw-r--r-- | mlir/test/Dialect/VectorOps/invalid.mlir | 63 | ||||
-rw-r--r-- | mlir/test/Dialect/VectorOps/ops.mlir | 6 |
7 files changed, 198 insertions, 115 deletions
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h index 39b7ee2d03f..7334c67e0d3 100644 --- a/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h @@ -18,8 +18,15 @@ #define MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_ namespace mlir { +class LLVMTypeConverter; class ModulePassBase; +class OwningRewritePatternList; +/// Collect a set of patterns to convert from the Vector dialect to LLVM. +void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns); + +/// Create a pass to convert vector operations to the LLVMIR dialect. ModulePassBase *createLowerVectorToLLVMPass(); } // namespace mlir diff --git a/mlir/include/mlir/VectorOps/VectorOps.td b/mlir/include/mlir/VectorOps/VectorOps.td index 962e53b94c3..e6f543ff26e 100644 --- a/mlir/include/mlir/VectorOps/VectorOps.td +++ b/mlir/include/mlir/VectorOps/VectorOps.td @@ -72,17 +72,25 @@ def ExtractElementOp : } def OuterProductOp : Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>, - Arguments<(ins AnyVector:$lhs, AnyVector:$rhs)>, + Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic<AnyVector>:$acc)>, Results<(outs AnyVector)> { - let summary = "outerproduct operation"; + let summary = "vector outerproduct with optional fused add"; let description = [{ Takes 2 1-D vectors and returns the 2-D vector containing the outer product. - Example: - ``` + An optional extra 2-D vector argument may be specified in which case the + operation returns the sum of the outer product and the extra vector. When + lowered to the LLVMIR dialect, this form emits `llvm.fmuladd`, which can + lower to actual `fma` instructions in LLVM. + + Examples + %2 = vector.extractelement %0, %1: vector<4xf32>, vector<8xf32> return %2: vector<4x8xf32> - ``` + + %3 = vector.extractelement %0, %1, %2: + vector<4xf32>, vector<8xf32>, vector<4x8xf32> + return %3: vector<4x8xf32> }]; let extraClassDeclaration = [{ VectorType getOperandVectorTypeLHS() { @@ -91,6 +99,10 @@ def OuterProductOp : VectorType getOperandVectorTypeRHS() { return rhs()->getType().cast<VectorType>(); } + VectorType getOperandVectorTypeACC() { + return (llvm::size(acc()) == 0) ? VectorType() : + (*acc().begin())->getType().cast<VectorType>(); + } VectorType getVectorType() { return getResult()->getType().cast<VectorType>(); } diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp index bf90edba401..1e4b8ca6419 100644 --- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp @@ -79,11 +79,8 @@ public: auto positionArrayAttr = extractOp.position(); // One-shot extraction of vector from array (only requires extractvalue). if (resultType.isa<VectorType>()) { - Value *extracted = - rewriter - .create<LLVM::ExtractValueOp>(loc, llvmResultType, - adaptor.vector(), positionArrayAttr) - .getResult(); + Value *extracted = rewriter.create<LLVM::ExtractValueOp>( + loc, llvmResultType, adaptor.vector(), positionArrayAttr); rewriter.replaceOp(op, extracted); return matchSuccess(); } @@ -92,29 +89,24 @@ public: auto *context = op->getContext(); Value *extracted = adaptor.vector(); auto positionAttrs = positionArrayAttr.getValue(); - auto indexType = rewriter.getIndexType(); + auto i32Type = rewriter.getIntegerType(32); if (positionAttrs.size() > 1) { auto nDVectorType = vectorType; auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(), nDVectorType.getElementType()); auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); - extracted = rewriter - .create<LLVM::ExtractValueOp>( - loc, lowering.convertType(oneDVectorType), extracted, - nMinusOnePositionAttrs) - .getResult(); + extracted = rewriter.create<LLVM::ExtractValueOp>( + loc, lowering.convertType(oneDVectorType), extracted, + nMinusOnePositionAttrs); } // Remaining extraction of element from 1-D LLVM vector auto position = positionAttrs.back().cast<IntegerAttr>(); - auto constant = rewriter - .create<LLVM::ConstantOp>( - loc, lowering.convertType(indexType), position) - .getResult(); + auto constant = rewriter.create<LLVM::ConstantOp>( + loc, lowering.convertType(i32Type), position); extracted = - rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant) - .getResult(); + rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); rewriter.replaceOp(op, extracted); return matchSuccess(); @@ -134,32 +126,38 @@ public: auto loc = op->getLoc(); auto adaptor = vector::OuterProductOpOperandAdaptor(operands); auto *ctx = op->getContext(); - auto vt1 = adaptor.lhs()->getType().cast<LLVM::LLVMType>(); - auto vt2 = adaptor.rhs()->getType().cast<LLVM::LLVMType>(); - auto rankV1 = vt1.getUnderlyingType()->getVectorNumElements(); - auto rankV2 = vt2.getUnderlyingType()->getVectorNumElements(); + auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>(); + auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>(); + auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements(); + auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); auto llvmArrayOfVectType = lowering.convertType( cast<vector::OuterProductOp>(op).getResult()->getType()); - Value *desc = - rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType).getResult(); - for (unsigned i = 0, e = rankV1; i < e; ++i) { - // Emit the following pattern: - // vec(a[i]) * b -> llvmStructOfVectType[i] - Value *a = adaptor.lhs(), *b = adaptor.rhs(); - // shufflevector explicitly requires i32 / - auto attr = rewriter.getI32IntegerAttr(i); - SmallVector<Attribute, 4> broadcastAttr(rankV2, attr); - auto broadcastArrayAttr = ArrayAttr::get(broadcastAttr, ctx); - auto *broadcasted = - rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, broadcastArrayAttr) - .getResult(); - auto *multiplied = - rewriter.create<LLVM::FMulOp>(loc, broadcasted, b).getResult(); - desc = rewriter - .create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, desc, - multiplied, - positionAttr(rewriter, i)) - .getResult(); + Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType); + Value *a = adaptor.lhs(), *b = adaptor.rhs(); + Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); + SmallVector<Value *, 8> lhs, accs; + lhs.reserve(rankLHS); + accs.reserve(rankLHS); + for (unsigned d = 0, e = rankLHS; d < e; ++d) { + // shufflevector explicitly requires i32. + auto attr = rewriter.getI32IntegerAttr(d); + SmallVector<Attribute, 4> bcastAttr(rankRHS, attr); + auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); + Value *aD = nullptr, *accD = nullptr; + // 1. Broadcast the element a[d] into vector aD. + aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr); + // 2. If acc is present, extract 1-d vector acc[d] into accD. + if (acc) + accD = rewriter.create<LLVM::ExtractValueOp>(loc, vRHS, acc, + positionAttr(rewriter, d)); + // 3. Compute aD outer b (plus accD, if relevant). + Value *aOuterbD = + accD ? rewriter.create<LLVM::fmuladd>(loc, vRHS, aD, b, accD) + .getResult() + : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult(); + // 4. Insert as value `d` in the descriptor. + desc = rewriter.create<LLVM::InsertValueOp>( + loc, llvmArrayOfVectType, desc, aOuterbD, positionAttr(rewriter, d)); } rewriter.replaceOp(op, desc); return matchSuccess(); @@ -167,12 +165,10 @@ public: }; /// Populate the given list with patterns that convert from Vector to LLVM. -static void -populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, - OwningRewritePatternList &patterns, - MLIRContext *ctx) { +void mlir::populateVectorToLLVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>( - ctx, converter); + converter.getDialect()->getContext(), converter); } namespace { @@ -185,7 +181,7 @@ void LowerVectorToLLVMPass::runOnModule() { // Convert to the LLVM IR dialect using the converter defined above. OwningRewritePatternList patterns; LLVMTypeConverter converter(&getContext()); - populateVectorToLLVMConversionPatterns(converter, patterns, &getContext()); + populateVectorToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); ConversionTarget target(getContext()); diff --git a/mlir/lib/VectorOps/VectorOps.cpp b/mlir/lib/VectorOps/VectorOps.cpp index 38267af32cf..0bd552ed6a9 100644 --- a/mlir/lib/VectorOps/VectorOps.cpp +++ b/mlir/lib/VectorOps/VectorOps.cpp @@ -116,45 +116,54 @@ static LogicalResult verify(ExtractElementOp op) { static void print(OpAsmPrinter *p, OuterProductOp op) { *p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs(); + if (llvm::size(op.acc()) > 0) + *p << ", " << **op.acc().begin(); *p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType(); } static ParseResult parseOuterProductOp(OpAsmParser *parser, OperationState *result) { - SmallVector<OpAsmParser::OperandType, 2> operandsInfo; - Type t0, t1; - if (parser->parseOperandList(operandsInfo) || parser->parseColonType(t0) || - parser->parseComma() || parser->parseType(t1)) + SmallVector<OpAsmParser::OperandType, 3> operandsInfo; + Type tLHS, tRHS; + if (parser->parseOperandList(operandsInfo) || parser->parseColonType(tLHS) || + parser->parseComma() || parser->parseType(tRHS)) return failure(); - VectorType v0 = t0.dyn_cast<VectorType>(); - VectorType v1 = t1.dyn_cast<VectorType>(); - if (!v0 || !v1) + if (operandsInfo.size() < 2) + return parser->emitError(parser->getNameLoc(), + "expected at least 2 operands"); + VectorType vLHS = tLHS.dyn_cast<VectorType>(); + VectorType vRHS = tRHS.dyn_cast<VectorType>(); + if (!vLHS || !vRHS) return parser->emitError(parser->getNameLoc(), "expected 2 vector types"); - VectorType resType = VectorType::get({v0.getDimSize(0), v1.getDimSize(0)}, - v0.getElementType()); - return failure(parser->resolveOperands(operandsInfo, {t0, t1}, - parser->getCurrentLocation(), - result->operands) || - parser->addTypeToList(resType, result->types)); + VectorType resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, + vLHS.getElementType()); + return failure( + parser->resolveOperand(operandsInfo[0], tLHS, result->operands) || + parser->resolveOperand(operandsInfo[1], tRHS, result->operands) || + (operandsInfo.size() > 2 && + parser->resolveOperand(operandsInfo[2], resType, result->operands)) || + parser->addTypeToList(resType, result->types)); } static LogicalResult verify(OuterProductOp op) { - VectorType v1 = op.getOperandVectorTypeLHS(), - v2 = op.getOperandVectorTypeRHS(), res = op.getVectorType(); - if (v1.getRank() != 1) + VectorType vLHS = op.getOperandVectorTypeLHS(), + vRHS = op.getOperandVectorTypeRHS(), + vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType(); + if (vLHS.getRank() != 1) return op.emitOpError("expected 1-d vector for operand #1"); - if (v2.getRank() != 1) + if (vRHS.getRank() != 1) return op.emitOpError("expected 1-d vector for operand #2"); - if (res.getRank() != 2) + if (vRES.getRank() != 2) return op.emitOpError("expected 2-d vector result"); - if (v1.getDimSize(0) != res.getDimSize(0)) - return op.emitOpError( - "expected first operand dim to match first result dim"); - if (v2.getDimSize(0) != res.getDimSize(1)) - return op.emitOpError( - "expected second operand dim to match second result dim"); + if (vLHS.getDimSize(0) != vRES.getDimSize(0)) + return op.emitOpError("expected #1 operand dim to match result dim #1"); + if (vRHS.getDimSize(0) != vRES.getDimSize(1)) + return op.emitOpError("expected #2 operand dim to match result dim #2"); + if (vACC && vACC != vRES) + return op.emitOpError("expected operand #3 of same type as result type"); return success(); } + //===----------------------------------------------------------------------===// // VectorTransferReadOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index f582de146ba..532a4c2e369 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1,33 +1,49 @@ // RUN: mlir-opt %s -vector-lower-to-llvm-dialect | FileCheck %s -func @vec_1d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> { - %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32> - %3 = vector.extractelement %2[0 : i32]: vector<4x8xf32> - return %3 : vector<8xf32> +func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> { + %2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32> + return %2 : vector<2x3xf32> } -// CHECK-LABEL: vec_1d -// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]"> -// CHECK-5: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> -// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x <8 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"<8 x float>"> +// CHECK-LABEL: outerproduct +// CHECK: llvm.undef : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> +// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> +// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> -func @vec_2d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> { - %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32> - return %2 : vector<4x8xf32> +func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> { + %2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32> + return %2 : vector<2x3xf32> } -// CHECK-LABEL: vec_2d -// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]"> -// CHECK-4: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> -// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"[4 x <8 x float>]"> +// CHECK-LABEL: outerproduct_add +// CHECK: llvm.undef : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]"> +// CHECK: "llvm.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> +// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]"> +// CHECK: "llvm.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> -func @vec_3d(%arg0: vector<4x8x16xf32>) -> vector<8x16xf32> { - %0 = vector.extractelement %arg0[0 : i32]: vector<4x8x16xf32> - return %0 : vector<8x16xf32> +func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> { + %0 = vector.extractelement %arg0[0 : i32]: vector<4x3x16xf32> + return %0 : vector<3x16xf32> } -// CHECK-LABEL: vec_3d -// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [8 x <16 x float>]]"> -// CHECK: llvm.return %{{.*}} : !llvm<"[8 x <16 x float>]">
\ No newline at end of file +// CHECK-LABEL: extract_vec_2d_from_vec_3d +// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]"> +// CHECK: llvm.return %{{.*}} : !llvm<"[3 x <16 x float>]"> + +func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 { + %0 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32> + return %0 : f32 +} +// CHECK-LABEL: extract_element_from_vec_3d +// CHECK: llvm.extractvalue %{{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]"> +// CHECK: llvm.constant(0 : i32) : !llvm.i32 +// CHECK: llvm.extractelement %{{.*}}, %{{.*}} : !llvm<"<16 x float>"> +// CHECK: llvm.return %{{.*}} : !llvm.float
\ No newline at end of file diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index 7917f14e881..ca339e7362a 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -2,39 +2,54 @@ // ----- -// CHECK-LABEL: position_empty -func @position_empty(%arg0: vector<4x8x16xf32>) { +func @extract_element_vector_type(%arg0: index) { + // expected-error@+1 {{expected vector type}} + %1 = vector.extractelement %arg0[] : index +} + +// ----- + +func @extractelement_position_empty(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected non-empty position attribute}} %1 = vector.extractelement %arg0[] : vector<4x8x16xf32> } // ----- -// CHECK-LABEL: position_rank_overflow -func @position_rank_overflow(%arg0: vector<4x8x16xf32>) { +func @extractelement_position_rank_overflow(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute of rank smaller than vector}} %1 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<4x8x16xf32> } // ----- -// CHECK-LABEL: position_overflow -func @position_overflow(%arg0: vector<4x8x16xf32>) { +func @extractelement_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) { + // expected-error@+1 {{expected position attribute of rank smaller than vector}} + %1 = "vector.extractelement" (%arg0) { position = [0 : i32, 0 : i32, 0 : i32, 0 : i32] } : (vector<4x8x16xf32>) -> (vector<16xf32>) +} + +// ----- + +func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute #2 to be a positive integer smaller than the corresponding vector dimension}} %1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32> } // ----- -// CHECK-LABEL: position_underflow -func @position_overflow(%arg0: vector<4x8x16xf32>) { +func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute #3 to be a positive integer smaller than the corresponding vector dimension}} %1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32> } // ----- -// CHECK-LABEL: outerproduct_non_vector_operand +func @outerproduct_num_operands(%arg0: f32) { + // expected-error@+1 {{expected at least 2 operands}} + %1 = vector.outerproduct %arg0 : f32, f32 +} +// ----- + func @outerproduct_non_vector_operand(%arg0: f32) { // expected-error@+1 {{expected 2 vector types}} %1 = vector.outerproduct %arg0, %arg0 : f32, f32 @@ -42,7 +57,6 @@ func @outerproduct_non_vector_operand(%arg0: f32) { // ----- -// CHECK-LABEL: outerproduct_operand_1 func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) { // expected-error@+1 {{expected 1-d vector for operand #1}} %1 = vector.outerproduct %arg1, %arg1 : vector<4x8xf32>, vector<4x8xf32> @@ -50,8 +64,35 @@ func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) { // ----- -// CHECK-LABEL: outerproduct_operand_2 func @outerproduct_operand_2(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) { // expected-error@+1 {{expected 1-d vector for operand #2}} %1 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<4x8xf32> } + +// ----- + +func @outerproduct_result_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) { + // expected-error@+1 {{expected 2-d vector result}} + %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<8xf32>) +} + +// ----- + +func @outerproduct_operand_1_dim_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) { + // expected-error@+1 {{expected #1 operand dim to match result dim #1}} + %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<8x16xf32>) +} + +// ----- + +func @outerproduct_operand_2_dim_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) { + // expected-error@+1 {{expected #2 operand dim to match result dim #2}} + %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<4x16xf32>) +} + +// ----- + +func @outerproduct_operand_3_result_type_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x16xf32>) { + // expected-error@+1 {{expected operand #3 of same type as result type}} + %1 = "vector.outerproduct" (%arg0, %arg1, %arg2) : (vector<4xf32>, vector<8xf32>, vector<4x16xf32>) -> (vector<4x8xf32>) +} diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index a072b5c0689..067345af0d9 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -12,8 +12,10 @@ func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16x } // CHECK-LABEL: outerproduct -func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> { +func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> { // CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32> %0 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32> - return %0 : vector<4x8xf32> + // CHECK: vector.outerproduct {{.*}}, {{.*}}, {{.*}} : vector<4xf32>, vector<8xf32> + %1 = vector.outerproduct %arg0, %arg1, %arg2 : vector<4xf32>, vector<8xf32> + return %1 : vector<4x8xf32> } |