diff options
| author | Aart Bik <ajcbik@google.com> | 2019-12-02 09:56:58 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-02 09:57:27 -0800 |
| commit | 3126004a5a8bef0ac079869626b322c2fdbbd655 (patch) | |
| tree | 1cb04680e284de453ce1777c69cb1acd53d96787 /mlir/test/Dialect/VectorOps | |
| parent | b41162b3af62668b36076baebf765044e21c04ba (diff) | |
| download | bcm5719-llvm-3126004a5a8bef0ac079869626b322c2fdbbd655.tar.gz bcm5719-llvm-3126004a5a8bef0ac079869626b322c2fdbbd655.zip | |
[VectorOps] Add legality rules to broadcast
PiperOrigin-RevId: 283360101
Diffstat (limited to 'mlir/test/Dialect/VectorOps')
| -rw-r--r-- | mlir/test/Dialect/VectorOps/invalid.mlir | 14 | ||||
| -rw-r--r-- | mlir/test/Dialect/VectorOps/ops.mlir | 8 |
2 files changed, 20 insertions, 2 deletions
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index d672b1bf140..0fbcb56f388 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -9,6 +9,20 @@ func @broadcast_rank_too_high(%arg0: vector<4x4xf32>) { // ----- +func @broadcast_dim1_mismatch(%arg0: vector<7xf32>) { + // expected-error@+1 {{vector.broadcast' op dimension mismatch (7 vs. 3)}} + %1 = vector.broadcast %arg0 : vector<7xf32> to vector<3xf32> +} + +// ----- + +func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) { + // expected-error@+1 {{vector.broadcast' op dimension mismatch (4 vs. 1)}} + %1 = vector.broadcast %arg0 : vector<4x8xf32> to vector<1x8xf32> +} + +// ----- + func @extract_element_vector_type(%arg0: index) { // expected-error@+1 {{expected vector type}} %1 = vector.extractelement %arg0[] : index diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index d167559ac0c..3824dfe20e4 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -23,12 +23,16 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>) { } // CHECK-LABEL: @vector_broadcast -func @vector_broadcast(%a: f32, %b: vector<16xf32>) -> vector<8x16xf32> { +func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> { // CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32> %0 = vector.broadcast %a : f32 to vector<16xf32> // CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32> %1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32> - return %1 : vector<8x16xf32> + // CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32> + %2 = vector.broadcast %c : vector<1x16xf32> to vector<8x16xf32> + // CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32> + %3 = vector.broadcast %d : vector<8x1xf32> to vector<8x16xf32> + return %3 : vector<8x16xf32> } // CHECK-LABEL: @extractelement |

