summaryrefslogtreecommitdiffstats
path: root/mlir/test/Dialect/VectorOps
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2019-08-16 03:52:56 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-08-16 03:53:26 -0700
commitf826ceef3ce8bfea1b78ab7bb2c60c53eb13729a (patch)
tree574e65c0dbb0f8f89c1219a7f4ccfcf0547d20ba /mlir/test/Dialect/VectorOps
parentcc980aa41651c2cbfcbd9048fb0788f4aa9ae475 (diff)
downloadbcm5719-llvm-f826ceef3ce8bfea1b78ab7bb2c60c53eb13729a.tar.gz
bcm5719-llvm-f826ceef3ce8bfea1b78ab7bb2c60c53eb13729a.zip
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction. When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it. In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...). This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage. This has been independently verified to result in proper fma instructions for haswell as follows. Input: ``` func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> { %2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32> return %2 : vector<17x8xf32> } } ``` Command: ``` mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2 ``` Output: ``` outerproduct_add: # @outerproduct_add # %bb.0: ... vmovaps 112(%rbp), %ymm8 vbroadcastss %xmm0, %ymm0 ... vbroadcastss 64(%rbp), %ymm15 vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem ... vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem ... ``` PiperOrigin-RevId: 263743359
Diffstat (limited to 'mlir/test/Dialect/VectorOps')
-rw-r--r--mlir/test/Dialect/VectorOps/invalid.mlir63
-rw-r--r--mlir/test/Dialect/VectorOps/ops.mlir6
2 files changed, 56 insertions, 13 deletions
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
index 7917f14e881..ca339e7362a 100644
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -2,39 +2,54 @@
// -----
-// CHECK-LABEL: position_empty
-func @position_empty(%arg0: vector<4x8x16xf32>) {
+func @extract_element_vector_type(%arg0: index) {
+ // expected-error@+1 {{expected vector type}}
+ %1 = vector.extractelement %arg0[] : index
+}
+
+// -----
+
+func @extractelement_position_empty(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected non-empty position attribute}}
%1 = vector.extractelement %arg0[] : vector<4x8x16xf32>
}
// -----
-// CHECK-LABEL: position_rank_overflow
-func @position_rank_overflow(%arg0: vector<4x8x16xf32>) {
+func @extractelement_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute of rank smaller than vector}}
%1 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<4x8x16xf32>
}
// -----
-// CHECK-LABEL: position_overflow
-func @position_overflow(%arg0: vector<4x8x16xf32>) {
+func @extractelement_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
+ // expected-error@+1 {{expected position attribute of rank smaller than vector}}
+ %1 = "vector.extractelement" (%arg0) { position = [0 : i32, 0 : i32, 0 : i32, 0 : i32] } : (vector<4x8x16xf32>) -> (vector<16xf32>)
+}
+
+// -----
+
+func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute #2 to be a positive integer smaller than the corresponding vector dimension}}
%1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32>
}
// -----
-// CHECK-LABEL: position_underflow
-func @position_overflow(%arg0: vector<4x8x16xf32>) {
+func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute #3 to be a positive integer smaller than the corresponding vector dimension}}
%1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32>
}
// -----
-// CHECK-LABEL: outerproduct_non_vector_operand
+func @outerproduct_num_operands(%arg0: f32) {
+ // expected-error@+1 {{expected at least 2 operands}}
+ %1 = vector.outerproduct %arg0 : f32, f32
+}
+// -----
+
func @outerproduct_non_vector_operand(%arg0: f32) {
// expected-error@+1 {{expected 2 vector types}}
%1 = vector.outerproduct %arg0, %arg0 : f32, f32
@@ -42,7 +57,6 @@ func @outerproduct_non_vector_operand(%arg0: f32) {
// -----
-// CHECK-LABEL: outerproduct_operand_1
func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
// expected-error@+1 {{expected 1-d vector for operand #1}}
%1 = vector.outerproduct %arg1, %arg1 : vector<4x8xf32>, vector<4x8xf32>
@@ -50,8 +64,35 @@ func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
// -----
-// CHECK-LABEL: outerproduct_operand_2
func @outerproduct_operand_2(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
// expected-error@+1 {{expected 1-d vector for operand #2}}
%1 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<4x8xf32>
}
+
+// -----
+
+func @outerproduct_result_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) {
+ // expected-error@+1 {{expected 2-d vector result}}
+ %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<8xf32>)
+}
+
+// -----
+
+func @outerproduct_operand_1_dim_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) {
+ // expected-error@+1 {{expected #1 operand dim to match result dim #1}}
+ %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<8x16xf32>)
+}
+
+// -----
+
+func @outerproduct_operand_2_dim_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) {
+ // expected-error@+1 {{expected #2 operand dim to match result dim #2}}
+ %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<4x16xf32>)
+}
+
+// -----
+
+func @outerproduct_operand_3_result_type_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x16xf32>) {
+ // expected-error@+1 {{expected operand #3 of same type as result type}}
+ %1 = "vector.outerproduct" (%arg0, %arg1, %arg2) : (vector<4xf32>, vector<8xf32>, vector<4x16xf32>) -> (vector<4x8xf32>)
+}
diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
index a072b5c0689..067345af0d9 100644
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -12,8 +12,10 @@ func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16x
}
// CHECK-LABEL: outerproduct
-func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> {
+func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
%0 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
- return %0 : vector<4x8xf32>
+ // CHECK: vector.outerproduct {{.*}}, {{.*}}, {{.*}} : vector<4xf32>, vector<8xf32>
+ %1 = vector.outerproduct %arg0, %arg1, %arg2 : vector<4xf32>, vector<8xf32>
+ return %1 : vector<4x8xf32>
}
OpenPOWER on IntegriCloud