diff options
| author | Nicolas Vasilache <ntv@google.com> | 2019-08-09 06:55:10 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-08-09 06:55:36 -0700 |
| commit | d2aba89f2e88d529ec203a7d9121c8451893e66b (patch) | |
| tree | f003593b5425c75d0cf902d92a34f7acbef4b963 /mlir/test/Dialect/VectorOps | |
| parent | 39f1b9a053a38c8acafbc0244028c0e9d665f63b (diff) | |
| download | bcm5719-llvm-d2aba89f2e88d529ec203a7d9121c8451893e66b.tar.gz bcm5719-llvm-d2aba89f2e88d529ec203a7d9121c8451893e66b.zip | |
Add a higher-order vector.outerproduct operation in MLIR
This CL is step 2/n towards building a simple, programmable and portable vector abstraction in MLIR that can go all the way down to generating assembly vector code via LLVM's opt and llc tools.
This CL adds the vector.outerproduct operation to the MLIR vector dialect as well as the appropriate roundtrip test. Lowering to LLVM will occur in the following CL.
PiperOrigin-RevId: 262552027
Diffstat (limited to 'mlir/test/Dialect/VectorOps')
| -rw-r--r-- | mlir/test/Dialect/VectorOps/invalid.mlir | 28 | ||||
| -rw-r--r-- | mlir/test/Dialect/VectorOps/ops.mlir | 9 |
2 files changed, 32 insertions, 5 deletions
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index 49fcefc475c..7917f14e881 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -6,7 +6,6 @@ func @position_empty(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected non-empty position attribute}} %1 = vector.extractelement %arg0[] : vector<4x8x16xf32> - return } // ----- @@ -15,7 +14,6 @@ func @position_empty(%arg0: vector<4x8x16xf32>) { func @position_rank_overflow(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute of rank smaller than vector}} %1 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<4x8x16xf32> - return } // ----- @@ -24,7 +22,6 @@ func @position_rank_overflow(%arg0: vector<4x8x16xf32>) { func @position_overflow(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute #2 to be a positive integer smaller than the corresponding vector dimension}} %1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32> - return } // ----- @@ -33,5 +30,28 @@ func @position_overflow(%arg0: vector<4x8x16xf32>) { func @position_overflow(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute #3 to be a positive integer smaller than the corresponding vector dimension}} %1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32> - return +} + +// ----- + +// CHECK-LABEL: outerproduct_non_vector_operand +func @outerproduct_non_vector_operand(%arg0: f32) { + // expected-error@+1 {{expected 2 vector types}} + %1 = vector.outerproduct %arg0, %arg0 : f32, f32 +} + +// ----- + +// CHECK-LABEL: outerproduct_operand_1 +func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) { + // expected-error@+1 {{expected 1-d vector for operand #1}} + %1 = vector.outerproduct %arg1, %arg1 : vector<4x8xf32>, vector<4x8xf32> +} + +// ----- + +// CHECK-LABEL: outerproduct_operand_2 +func @outerproduct_operand_2(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) { + // expected-error@+1 {{expected 1-d vector for operand #2}} + %1 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<4x8xf32> } diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index 11928adda8f..a072b5c0689 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -9,4 +9,11 @@ func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16x // CHECK-NEXT: vector.extractelement {{.*}}[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32> %3 = vector.extractelement %arg0[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32> return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32 -}
\ No newline at end of file +} + +// CHECK-LABEL: outerproduct +func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> { + // CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32> + %0 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32> + return %0 : vector<4x8xf32> +} |

