summaryrefslogtreecommitdiffstats
path: root/mlir/test/Dialect/VectorOps
diff options
context:
space:
mode:
authorAart Bik <ajcbik@google.com>2019-12-02 09:56:58 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-02 09:57:27 -0800
commit3126004a5a8bef0ac079869626b322c2fdbbd655 (patch)
tree1cb04680e284de453ce1777c69cb1acd53d96787 /mlir/test/Dialect/VectorOps
parentb41162b3af62668b36076baebf765044e21c04ba (diff)
downloadbcm5719-llvm-3126004a5a8bef0ac079869626b322c2fdbbd655.tar.gz
bcm5719-llvm-3126004a5a8bef0ac079869626b322c2fdbbd655.zip
[VectorOps] Add legality rules to broadcast
PiperOrigin-RevId: 283360101
Diffstat (limited to 'mlir/test/Dialect/VectorOps')
-rw-r--r--mlir/test/Dialect/VectorOps/invalid.mlir14
-rw-r--r--mlir/test/Dialect/VectorOps/ops.mlir8
2 files changed, 20 insertions, 2 deletions
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
index d672b1bf140..0fbcb56f388 100644
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -9,6 +9,20 @@ func @broadcast_rank_too_high(%arg0: vector<4x4xf32>) {
// -----
+func @broadcast_dim1_mismatch(%arg0: vector<7xf32>) {
+ // expected-error@+1 {{vector.broadcast' op dimension mismatch (7 vs. 3)}}
+ %1 = vector.broadcast %arg0 : vector<7xf32> to vector<3xf32>
+}
+
+// -----
+
+func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {
+ // expected-error@+1 {{vector.broadcast' op dimension mismatch (4 vs. 1)}}
+ %1 = vector.broadcast %arg0 : vector<4x8xf32> to vector<1x8xf32>
+}
+
+// -----
+
func @extract_element_vector_type(%arg0: index) {
// expected-error@+1 {{expected vector type}}
%1 = vector.extractelement %arg0[] : index
diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
index d167559ac0c..3824dfe20e4 100644
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -23,12 +23,16 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>) {
}
// CHECK-LABEL: @vector_broadcast
-func @vector_broadcast(%a: f32, %b: vector<16xf32>) -> vector<8x16xf32> {
+func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> {
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
%0 = vector.broadcast %a : f32 to vector<16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
%1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32>
- return %1 : vector<8x16xf32>
+ // CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32>
+ %2 = vector.broadcast %c : vector<1x16xf32> to vector<8x16xf32>
+ // CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
+ %3 = vector.broadcast %d : vector<8x1xf32> to vector<8x16xf32>
+ return %3 : vector<8x16xf32>
}
// CHECK-LABEL: @extractelement
OpenPOWER on IntegriCloud