diff options
| author | Andy Davis <andydavis@google.com> | 2019-12-19 12:22:35 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-19 12:27:59 -0800 |
| commit | 1d798b1d27fb150de47266b009a414db46344f5a (patch) | |
| tree | 2fa5f7e42ef4f895b25bc9789b7d3a5554761f95 /mlir/test/Dialect/VectorOps | |
| parent | 1bcd8ef32f8104cc4bbe9e7003cf8a23c51ae24f (diff) | |
| download | bcm5719-llvm-1d798b1d27fb150de47266b009a414db46344f5a.tar.gz bcm5719-llvm-1d798b1d27fb150de47266b009a414db46344f5a.zip | |
[VectorOps] Add vector ReshapeOp to the VectorOps dialect.
Adds vector ReshapeOp to the VectorOps dialect. An aggregate vector reshape operation, which aggregates multiple hardware vectors, can enable optimizations during decomposition (e.g. loading one input hardware vector and performing multiple rotate and scatter store operations to the vector output).
PiperOrigin-RevId: 286440658
Diffstat (limited to 'mlir/test/Dialect/VectorOps')
| -rw-r--r-- | mlir/test/Dialect/VectorOps/invalid.mlir | 60 | ||||
| -rw-r--r-- | mlir/test/Dialect/VectorOps/ops.mlir | 17 |
2 files changed, 77 insertions, 0 deletions
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index d79c0350910..c208c92fc23 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -826,3 +826,63 @@ func @print_no_result(%arg0 : f32) -> i32 { %0 = vector.print %arg0 : f32 return %0 } + +// ----- + +func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) { + %c2 = constant 2 : index + %c3 = constant 3 : index + %c6 = constant 6 : index + %c9 = constant 9 : index + // expected-error@+1 {{invalid input shape for vector type}} + %1 = vector.reshape %arg0, [%c3, %c6, %c3], [%c2, %c9], [4] + : vector<3x2x4xf32> to vector<2x3x4xf32> +} + +// ----- + +func @reshape_bad_output_shape(%arg0 : vector<3x2x4xf32>) { + %c2 = constant 2 : index + %c3 = constant 3 : index + %c6 = constant 6 : index + %c9 = constant 9 : index + // expected-error@+1 {{invalid output shape for vector type}} + %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9, %c3], [4] + : vector<3x2x4xf32> to vector<2x3x4xf32> +} + +// ----- + +func @reshape_bad_input_output_shape_product(%arg0 : vector<3x2x4xf32>) { + %c2 = constant 2 : index + %c3 = constant 3 : index + %c6 = constant 6 : index + %c9 = constant 9 : index + // expected-error@+1 {{product of input and output shape sizes must match}} + %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c6], [4] + : vector<3x2x4xf32> to vector<2x3x4xf32> +} + +// ----- + +func @reshape_bad_input_fixed_size(%arg0 : vector<3x2x5xf32>) { + %c2 = constant 2 : index + %c3 = constant 3 : index + %c6 = constant 6 : index + %c9 = constant 9 : index + // expected-error@+1 {{fixed vector size must match input vector for dim 0}} + %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4] + : vector<3x2x5xf32> to vector<2x3x4xf32> +} + +// ----- + +func @reshape_bad_output_fixed_size(%arg0 : vector<3x2x4xf32>) { + %c2 = constant 2 : index + %c3 = constant 3 : index + %c6 = constant 6 : index + %c9 = constant 9 : index + // expected-error@+1 {{fixed vector size must match output vector for dim 0}} + %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4] + : vector<3x2x4xf32> to vector<2x3x5xf32> +} diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index 06d57289363..e1607996cc2 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -205,3 +205,20 @@ func @vector_print(%arg0: vector<8x4xf32>) { vector.print %arg0 : vector<8x4xf32> return } + +// CHECK-LABEL: reshape +func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) { + // CHECK: %[[C2:.*]] = constant 2 : index + %c2 = constant 2 : index + // CHECK: %[[C3:.*]] = constant 3 : index + %c3 = constant 3 : index + // CHECK: %[[C6:.*]] = constant 6 : index + %c6 = constant 6 : index + // CHECK: %[[C9:.*]] = constant 9 : index + %c9 = constant 9 : index + // CHECK: vector.reshape %{{.*}}, [%[[C3]], %[[C6]]], [%[[C2]], %[[C9]]], [4] : vector<3x2x4xf32> to vector<2x3x4xf32> + %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4] + : vector<3x2x4xf32> to vector<2x3x4xf32> + + return %1 : vector<2x3x4xf32> +} |

