summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2019-08-16 03:52:56 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-08-16 03:53:26 -0700
commitf826ceef3ce8bfea1b78ab7bb2c60c53eb13729a (patch)
tree574e65c0dbb0f8f89c1219a7f4ccfcf0547d20ba
parentcc980aa41651c2cbfcbd9048fb0788f4aa9ae475 (diff)
downloadbcm5719-llvm-f826ceef3ce8bfea1b78ab7bb2c60c53eb13729a.tar.gz
bcm5719-llvm-f826ceef3ce8bfea1b78ab7bb2c60c53eb13729a.zip
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction. When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.

In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).

This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.

This has been independently verified to result in proper fma instructions for haswell as follows.

Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
  %2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
  return %2 : vector<17x8xf32>
}
```

Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```

Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
  ...
  vmovaps 112(%rbp), %ymm8
  vbroadcastss %xmm0, %ymm0
  ...
  vbroadcastss 64(%rbp), %ymm15
  vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
  ...
  vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
  ...
```

PiperOrigin-RevId: 263743359
-rw-r--r--mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h7
-rw-r--r--mlir/include/mlir/VectorOps/VectorOps.td22
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp90
-rw-r--r--mlir/lib/VectorOps/VectorOps.cpp57
-rw-r--r--mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir68
-rw-r--r--mlir/test/Dialect/VectorOps/invalid.mlir63
-rw-r--r--mlir/test/Dialect/VectorOps/ops.mlir6
7 files changed, 198 insertions, 115 deletions
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h
index 39b7ee2d03f..7334c67e0d3 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h
@@ -18,8 +18,15 @@
#define MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_
namespace mlir {
+class LLVMTypeConverter;
class ModulePassBase;
+class OwningRewritePatternList;
+/// Collect a set of patterns to convert from the Vector dialect to LLVM.
+void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ OwningRewritePatternList &patterns);
+
+/// Create a pass to convert vector operations to the LLVMIR dialect.
ModulePassBase *createLowerVectorToLLVMPass();
} // namespace mlir
diff --git a/mlir/include/mlir/VectorOps/VectorOps.td b/mlir/include/mlir/VectorOps/VectorOps.td
index 962e53b94c3..e6f543ff26e 100644
--- a/mlir/include/mlir/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/VectorOps/VectorOps.td
@@ -72,17 +72,25 @@ def ExtractElementOp :
}
def OuterProductOp :
Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>,
- Arguments<(ins AnyVector:$lhs, AnyVector:$rhs)>,
+ Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic<AnyVector>:$acc)>,
Results<(outs AnyVector)> {
- let summary = "outerproduct operation";
+ let summary = "vector outerproduct with optional fused add";
let description = [{
Takes 2 1-D vectors and returns the 2-D vector containing the outer product.
- Example:
- ```
+ An optional extra 2-D vector argument may be specified in which case the
+ operation returns the sum of the outer product and the extra vector. When
+ lowered to the LLVMIR dialect, this form emits `llvm.fmuladd`, which can
+ lower to actual `fma` instructions in LLVM.
+
+ Examples
+
%2 = vector.outerproduct %0, %1: vector<4xf32>, vector<8xf32>
return %2: vector<4x8xf32>
- ```
+
+ %3 = vector.outerproduct %0, %1, %2:
+ vector<4xf32>, vector<8xf32>, vector<4x8xf32>
+ return %3: vector<4x8xf32>
}];
let extraClassDeclaration = [{
VectorType getOperandVectorTypeLHS() {
@@ -91,6 +99,10 @@ def OuterProductOp :
VectorType getOperandVectorTypeRHS() {
return rhs()->getType().cast<VectorType>();
}
+ VectorType getOperandVectorTypeACC() {
+ return (llvm::size(acc()) == 0) ? VectorType() :
+ (*acc().begin())->getType().cast<VectorType>();
+ }
VectorType getVectorType() {
return getResult()->getType().cast<VectorType>();
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
index bf90edba401..1e4b8ca6419 100644
--- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
@@ -79,11 +79,8 @@ public:
auto positionArrayAttr = extractOp.position();
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
- Value *extracted =
- rewriter
- .create<LLVM::ExtractValueOp>(loc, llvmResultType,
- adaptor.vector(), positionArrayAttr)
- .getResult();
+ Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, llvmResultType, adaptor.vector(), positionArrayAttr);
rewriter.replaceOp(op, extracted);
return matchSuccess();
}
@@ -92,29 +89,24 @@ public:
auto *context = op->getContext();
Value *extracted = adaptor.vector();
auto positionAttrs = positionArrayAttr.getValue();
- auto indexType = rewriter.getIndexType();
+ auto i32Type = rewriter.getIntegerType(32);
if (positionAttrs.size() > 1) {
auto nDVectorType = vectorType;
auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(),
nDVectorType.getElementType());
auto nMinusOnePositionAttrs =
ArrayAttr::get(positionAttrs.drop_back(), context);
- extracted = rewriter
- .create<LLVM::ExtractValueOp>(
- loc, lowering.convertType(oneDVectorType), extracted,
- nMinusOnePositionAttrs)
- .getResult();
+ extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, lowering.convertType(oneDVectorType), extracted,
+ nMinusOnePositionAttrs);
}
// Remaining extraction of element from 1-D LLVM vector
auto position = positionAttrs.back().cast<IntegerAttr>();
- auto constant = rewriter
- .create<LLVM::ConstantOp>(
- loc, lowering.convertType(indexType), position)
- .getResult();
+ auto constant = rewriter.create<LLVM::ConstantOp>(
+ loc, lowering.convertType(i32Type), position);
extracted =
- rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant)
- .getResult();
+ rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(op, extracted);
return matchSuccess();
@@ -134,32 +126,38 @@ public:
auto loc = op->getLoc();
auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
auto *ctx = op->getContext();
- auto vt1 = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
- auto vt2 = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
- auto rankV1 = vt1.getUnderlyingType()->getVectorNumElements();
- auto rankV2 = vt2.getUnderlyingType()->getVectorNumElements();
+ auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
+ auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
+ auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
+ auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
auto llvmArrayOfVectType = lowering.convertType(
cast<vector::OuterProductOp>(op).getResult()->getType());
- Value *desc =
- rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType).getResult();
- for (unsigned i = 0, e = rankV1; i < e; ++i) {
- // Emit the following pattern:
- // vec(a[i]) * b -> llvmStructOfVectType[i]
- Value *a = adaptor.lhs(), *b = adaptor.rhs();
- // shufflevector explicitly requires i32 /
- auto attr = rewriter.getI32IntegerAttr(i);
- SmallVector<Attribute, 4> broadcastAttr(rankV2, attr);
- auto broadcastArrayAttr = ArrayAttr::get(broadcastAttr, ctx);
- auto *broadcasted =
- rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, broadcastArrayAttr)
- .getResult();
- auto *multiplied =
- rewriter.create<LLVM::FMulOp>(loc, broadcasted, b).getResult();
- desc = rewriter
- .create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, desc,
- multiplied,
- positionAttr(rewriter, i))
- .getResult();
+ Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
+ Value *a = adaptor.lhs(), *b = adaptor.rhs();
+ Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
+ SmallVector<Value *, 8> lhs, accs;
+ lhs.reserve(rankLHS);
+ accs.reserve(rankLHS);
+ for (unsigned d = 0, e = rankLHS; d < e; ++d) {
+ // shufflevector explicitly requires i32.
+ auto attr = rewriter.getI32IntegerAttr(d);
+ SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
+ auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
+ Value *aD = nullptr, *accD = nullptr;
+ // 1. Broadcast the element a[d] into vector aD.
+ aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
+ // 2. If acc is present, extract 1-d vector acc[d] into accD.
+ if (acc)
+ accD = rewriter.create<LLVM::ExtractValueOp>(loc, vRHS, acc,
+ positionAttr(rewriter, d));
+ // 3. Compute aD outer b (plus accD, if relevant).
+ Value *aOuterbD =
+ accD ? rewriter.create<LLVM::fmuladd>(loc, vRHS, aD, b, accD)
+ .getResult()
+ : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
+ // 4. Insert as value `d` in the descriptor.
+ desc = rewriter.create<LLVM::InsertValueOp>(
+ loc, llvmArrayOfVectType, desc, aOuterbD, positionAttr(rewriter, d));
}
rewriter.replaceOp(op, desc);
return matchSuccess();
@@ -167,12 +165,10 @@ public:
};
/// Populate the given list with patterns that convert from Vector to LLVM.
-static void
-populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
- OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
+void mlir::populateVectorToLLVMConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>(
- ctx, converter);
+ converter.getDialect()->getContext(), converter);
}
namespace {
@@ -185,7 +181,7 @@ void LowerVectorToLLVMPass::runOnModule() {
// Convert to the LLVM IR dialect using the converter defined above.
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
- populateVectorToLLVMConversionPatterns(converter, patterns, &getContext());
+ populateVectorToLLVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
ConversionTarget target(getContext());
diff --git a/mlir/lib/VectorOps/VectorOps.cpp b/mlir/lib/VectorOps/VectorOps.cpp
index 38267af32cf..0bd552ed6a9 100644
--- a/mlir/lib/VectorOps/VectorOps.cpp
+++ b/mlir/lib/VectorOps/VectorOps.cpp
@@ -116,45 +116,54 @@ static LogicalResult verify(ExtractElementOp op) {
static void print(OpAsmPrinter *p, OuterProductOp op) {
*p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
+ if (llvm::size(op.acc()) > 0)
+ *p << ", " << **op.acc().begin();
*p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
}
static ParseResult parseOuterProductOp(OpAsmParser *parser,
OperationState *result) {
- SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
- Type t0, t1;
- if (parser->parseOperandList(operandsInfo) || parser->parseColonType(t0) ||
- parser->parseComma() || parser->parseType(t1))
+ SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
+ Type tLHS, tRHS;
+ if (parser->parseOperandList(operandsInfo) || parser->parseColonType(tLHS) ||
+ parser->parseComma() || parser->parseType(tRHS))
return failure();
- VectorType v0 = t0.dyn_cast<VectorType>();
- VectorType v1 = t1.dyn_cast<VectorType>();
- if (!v0 || !v1)
+ if (operandsInfo.size() < 2)
+ return parser->emitError(parser->getNameLoc(),
+ "expected at least 2 operands");
+ VectorType vLHS = tLHS.dyn_cast<VectorType>();
+ VectorType vRHS = tRHS.dyn_cast<VectorType>();
+ if (!vLHS || !vRHS)
return parser->emitError(parser->getNameLoc(), "expected 2 vector types");
- VectorType resType = VectorType::get({v0.getDimSize(0), v1.getDimSize(0)},
- v0.getElementType());
- return failure(parser->resolveOperands(operandsInfo, {t0, t1},
- parser->getCurrentLocation(),
- result->operands) ||
- parser->addTypeToList(resType, result->types));
+ VectorType resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
+ vLHS.getElementType());
+ return failure(
+ parser->resolveOperand(operandsInfo[0], tLHS, result->operands) ||
+ parser->resolveOperand(operandsInfo[1], tRHS, result->operands) ||
+ (operandsInfo.size() > 2 &&
+ parser->resolveOperand(operandsInfo[2], resType, result->operands)) ||
+ parser->addTypeToList(resType, result->types));
}
static LogicalResult verify(OuterProductOp op) {
- VectorType v1 = op.getOperandVectorTypeLHS(),
- v2 = op.getOperandVectorTypeRHS(), res = op.getVectorType();
- if (v1.getRank() != 1)
+ VectorType vLHS = op.getOperandVectorTypeLHS(),
+ vRHS = op.getOperandVectorTypeRHS(),
+ vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
+ if (vLHS.getRank() != 1)
return op.emitOpError("expected 1-d vector for operand #1");
- if (v2.getRank() != 1)
+ if (vRHS.getRank() != 1)
return op.emitOpError("expected 1-d vector for operand #2");
- if (res.getRank() != 2)
+ if (vRES.getRank() != 2)
return op.emitOpError("expected 2-d vector result");
- if (v1.getDimSize(0) != res.getDimSize(0))
- return op.emitOpError(
- "expected first operand dim to match first result dim");
- if (v2.getDimSize(0) != res.getDimSize(1))
- return op.emitOpError(
- "expected second operand dim to match second result dim");
+ if (vLHS.getDimSize(0) != vRES.getDimSize(0))
+ return op.emitOpError("expected #1 operand dim to match result dim #1");
+ if (vRHS.getDimSize(0) != vRES.getDimSize(1))
+ return op.emitOpError("expected #2 operand dim to match result dim #2");
+ if (vACC && vACC != vRES)
+ return op.emitOpError("expected operand #3 of same type as result type");
return success();
}
+
//===----------------------------------------------------------------------===//
// VectorTransferReadOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f582de146ba..532a4c2e369 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1,33 +1,49 @@
// RUN: mlir-opt %s -vector-lower-to-llvm-dialect | FileCheck %s
-func @vec_1d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
- %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
- %3 = vector.extractelement %2[0 : i32]: vector<4x8xf32>
- return %3 : vector<8xf32>
+func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> {
+ %2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
+ return %2 : vector<2x3xf32>
}
-// CHECK-LABEL: vec_1d
-// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]">
-// CHECK-5: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x <8 x float>]">
-// CHECK: llvm.return {{.*}} : !llvm<"<8 x float>">
+// CHECK-LABEL: outerproduct
+// CHECK: llvm.undef : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
+// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>">
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
+// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">
-func @vec_2d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> {
- %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
- return %2 : vector<4x8xf32>
+func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+ %2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32>
+ return %2 : vector<2x3xf32>
}
-// CHECK-LABEL: vec_2d
-// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]">
-// CHECK-4: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]">
-// CHECK: llvm.return {{.*}} : !llvm<"[4 x <8 x float>]">
+// CHECK-LABEL: outerproduct_add
+// CHECK: llvm.undef : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
+// CHECK: "llvm.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
+// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
+// CHECK: "llvm.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">
-func @vec_3d(%arg0: vector<4x8x16xf32>) -> vector<8x16xf32> {
- %0 = vector.extractelement %arg0[0 : i32]: vector<4x8x16xf32>
- return %0 : vector<8x16xf32>
+func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
+ %0 = vector.extractelement %arg0[0 : i32]: vector<4x3x16xf32>
+ return %0 : vector<3x16xf32>
}
-// CHECK-LABEL: vec_3d
-// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
-// CHECK: llvm.return %{{.*}} : !llvm<"[8 x <16 x float>]"> \ No newline at end of file
+// CHECK-LABEL: extract_vec_2d_from_vec_3d
+// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
+// CHECK: llvm.return %{{.*}} : !llvm<"[3 x <16 x float>]">
+
+func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
+ %0 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32>
+ return %0 : f32
+}
+// CHECK-LABEL: extract_element_from_vec_3d
+// CHECK: llvm.extractvalue %{{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
+// CHECK: llvm.constant(0 : i32) : !llvm.i32
+// CHECK: llvm.extractelement %{{.*}}, %{{.*}} : !llvm<"<16 x float>">
+// CHECK: llvm.return %{{.*}} : !llvm.float \ No newline at end of file
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
index 7917f14e881..ca339e7362a 100644
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -2,39 +2,54 @@
// -----
-// CHECK-LABEL: position_empty
-func @position_empty(%arg0: vector<4x8x16xf32>) {
+func @extract_element_vector_type(%arg0: index) {
+ // expected-error@+1 {{expected vector type}}
+ %1 = vector.extractelement %arg0[] : index
+}
+
+// -----
+
+func @extractelement_position_empty(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected non-empty position attribute}}
%1 = vector.extractelement %arg0[] : vector<4x8x16xf32>
}
// -----
-// CHECK-LABEL: position_rank_overflow
-func @position_rank_overflow(%arg0: vector<4x8x16xf32>) {
+func @extractelement_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute of rank smaller than vector}}
%1 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<4x8x16xf32>
}
// -----
-// CHECK-LABEL: position_overflow
-func @position_overflow(%arg0: vector<4x8x16xf32>) {
+func @extractelement_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
+ // expected-error@+1 {{expected position attribute of rank smaller than vector}}
+ %1 = "vector.extractelement" (%arg0) { position = [0 : i32, 0 : i32, 0 : i32, 0 : i32] } : (vector<4x8x16xf32>) -> (vector<16xf32>)
+}
+
+// -----
+
+func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute #2 to be a positive integer smaller than the corresponding vector dimension}}
%1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32>
}
// -----
-// CHECK-LABEL: position_underflow
-func @position_overflow(%arg0: vector<4x8x16xf32>) {
+func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute #3 to be a positive integer smaller than the corresponding vector dimension}}
%1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32>
}
// -----
-// CHECK-LABEL: outerproduct_non_vector_operand
+func @outerproduct_num_operands(%arg0: f32) {
+ // expected-error@+1 {{expected at least 2 operands}}
+ %1 = vector.outerproduct %arg0 : f32, f32
+}
+// -----
+
func @outerproduct_non_vector_operand(%arg0: f32) {
// expected-error@+1 {{expected 2 vector types}}
%1 = vector.outerproduct %arg0, %arg0 : f32, f32
@@ -42,7 +57,6 @@ func @outerproduct_non_vector_operand(%arg0: f32) {
// -----
-// CHECK-LABEL: outerproduct_operand_1
func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
// expected-error@+1 {{expected 1-d vector for operand #1}}
%1 = vector.outerproduct %arg1, %arg1 : vector<4x8xf32>, vector<4x8xf32>
@@ -50,8 +64,35 @@ func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
// -----
-// CHECK-LABEL: outerproduct_operand_2
func @outerproduct_operand_2(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
// expected-error@+1 {{expected 1-d vector for operand #2}}
%1 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<4x8xf32>
}
+
+// -----
+
+func @outerproduct_result_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) {
+ // expected-error@+1 {{expected 2-d vector result}}
+ %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<8xf32>)
+}
+
+// -----
+
+func @outerproduct_operand_1_dim_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) {
+ // expected-error@+1 {{expected #1 operand dim to match result dim #1}}
+ %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<8x16xf32>)
+}
+
+// -----
+
+func @outerproduct_operand_2_dim_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) {
+ // expected-error@+1 {{expected #2 operand dim to match result dim #2}}
+ %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<4x16xf32>)
+}
+
+// -----
+
+func @outerproduct_operand_3_result_type_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x16xf32>) {
+ // expected-error@+1 {{expected operand #3 of same type as result type}}
+ %1 = "vector.outerproduct" (%arg0, %arg1, %arg2) : (vector<4xf32>, vector<8xf32>, vector<4x16xf32>) -> (vector<4x8xf32>)
+}
diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
index a072b5c0689..067345af0d9 100644
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -12,8 +12,10 @@ func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16x
}
// CHECK-LABEL: outerproduct
-func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> {
+func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
%0 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
- return %0 : vector<4x8xf32>
+ // CHECK: vector.outerproduct {{.*}}, {{.*}}, {{.*}} : vector<4xf32>, vector<8xf32>
+ %1 = vector.outerproduct %arg0, %arg1, %arg2 : vector<4xf32>, vector<8xf32>
+ return %1 : vector<4x8xf32>
}
OpenPOWER on IntegriCloud