summaryrefslogtreecommitdiffstats
path: root/mlir/test/Dialect/VectorOps
diff options
context:
space:
mode:
authorAndy Davis <andydavis@google.com>2019-12-19 12:22:35 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-19 12:27:59 -0800
commit1d798b1d27fb150de47266b009a414db46344f5a (patch)
tree2fa5f7e42ef4f895b25bc9789b7d3a5554761f95 /mlir/test/Dialect/VectorOps
parent1bcd8ef32f8104cc4bbe9e7003cf8a23c51ae24f (diff)
downloadbcm5719-llvm-1d798b1d27fb150de47266b009a414db46344f5a.tar.gz
bcm5719-llvm-1d798b1d27fb150de47266b009a414db46344f5a.zip
[VectorOps] Add vector ReshapeOp to the VectorOps dialect.
Adds vector ReshapeOp to the VectorOps dialect. An aggregate vector reshape operation, which aggregates multiple hardware vectors, can enable optimizations during decomposition (e.g. loading one input hardware vector and performing multiple rotate and scatter store operations to the vector output). PiperOrigin-RevId: 286440658
Diffstat (limited to 'mlir/test/Dialect/VectorOps')
-rw-r--r--mlir/test/Dialect/VectorOps/invalid.mlir60
-rw-r--r--mlir/test/Dialect/VectorOps/ops.mlir17
2 files changed, 77 insertions, 0 deletions
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
index d79c0350910..c208c92fc23 100644
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -826,3 +826,63 @@ func @print_no_result(%arg0 : f32) -> i32 {
%0 = vector.print %arg0 : f32
return %0
}
+
+// -----
+
+func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c6 = constant 6 : index
+ %c9 = constant 9 : index
+ // expected-error@+1 {{invalid input shape for vector type}}
+ %1 = vector.reshape %arg0, [%c3, %c6, %c3], [%c2, %c9], [4]
+ : vector<3x2x4xf32> to vector<2x3x4xf32>
+}
+
+// -----
+
+func @reshape_bad_output_shape(%arg0 : vector<3x2x4xf32>) {
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c6 = constant 6 : index
+ %c9 = constant 9 : index
+ // expected-error@+1 {{invalid output shape for vector type}}
+ %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9, %c3], [4]
+ : vector<3x2x4xf32> to vector<2x3x4xf32>
+}
+
+// -----
+
+func @reshape_bad_input_output_shape_product(%arg0 : vector<3x2x4xf32>) {
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c6 = constant 6 : index
+ %c9 = constant 9 : index
+ // expected-error@+1 {{product of input and output shape sizes must match}}
+ %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c6], [4]
+ : vector<3x2x4xf32> to vector<2x3x4xf32>
+}
+
+// -----
+
+func @reshape_bad_input_fixed_size(%arg0 : vector<3x2x5xf32>) {
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c6 = constant 6 : index
+ %c9 = constant 9 : index
+ // expected-error@+1 {{fixed vector size must match input vector for dim 0}}
+ %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
+ : vector<3x2x5xf32> to vector<2x3x4xf32>
+}
+
+// -----
+
+func @reshape_bad_output_fixed_size(%arg0 : vector<3x2x4xf32>) {
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c6 = constant 6 : index
+ %c9 = constant 9 : index
+ // expected-error@+1 {{fixed vector size must match output vector for dim 0}}
+ %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
+ : vector<3x2x4xf32> to vector<2x3x5xf32>
+}
diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
index 06d57289363..e1607996cc2 100644
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -205,3 +205,20 @@ func @vector_print(%arg0: vector<8x4xf32>) {
vector.print %arg0 : vector<8x4xf32>
return
}
+
+// CHECK-LABEL: reshape
+func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {
+ // CHECK: %[[C2:.*]] = constant 2 : index
+ %c2 = constant 2 : index
+ // CHECK: %[[C3:.*]] = constant 3 : index
+ %c3 = constant 3 : index
+ // CHECK: %[[C6:.*]] = constant 6 : index
+ %c6 = constant 6 : index
+ // CHECK: %[[C9:.*]] = constant 9 : index
+ %c9 = constant 9 : index
+ // CHECK: vector.reshape %{{.*}}, [%[[C3]], %[[C6]]], [%[[C2]], %[[C9]]], [4] : vector<3x2x4xf32> to vector<2x3x4xf32>
+ %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
+ : vector<3x2x4xf32> to vector<2x3x4xf32>
+
+ return %1 : vector<2x3x4xf32>
+}
OpenPOWER on IntegriCloud