diff options
| author | Andy Davis <andydavis@google.com> | 2019-12-19 16:04:59 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-19 16:05:32 -0800 |
| commit | 8020ad3e396bcca8dba94cea397cece81b76b119 (patch) | |
| tree | c5836064e80ece1a666f56a6441fe51bc8631431 /mlir/test/Dialect/VectorOps | |
| parent | 6685282253c33fa2c5dc7487b04fc92d47082e78 (diff) | |
| download | bcm5719-llvm-8020ad3e396bcca8dba94cea397cece81b76b119.tar.gz bcm5719-llvm-8020ad3e396bcca8dba94cea397cece81b76b119.zip | |
[VectorOps] Update vector transfer_read/write ops to operatate on memrefs with vector element type.
Update vector transfer_read/write ops to operatate on memrefs with vector element type.
This handle cases where the memref vector element type represents the minimal memory transfer unit (or multiple of the minimal memory transfer unit).
PiperOrigin-RevId: 286482115
Diffstat (limited to 'mlir/test/Dialect/VectorOps')
| -rw-r--r-- | mlir/test/Dialect/VectorOps/invalid.mlir | 30 | ||||
| -rw-r--r-- | mlir/test/Dialect/VectorOps/ops.mlir | 19 |
2 files changed, 45 insertions, 4 deletions
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index c208c92fc23..9ef39e25144 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -308,6 +308,36 @@ func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) { // ----- +func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) { + %c3 = constant 3 : index + %f0 = constant 0.0 : f32 + %vf0 = splat %f0 : vector<4x3xf32> + // expected-error@+1 {{requires memref and vector types of the same elemental type}} + %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xi32> +} + +// ----- + +func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) { + %c3 = constant 3 : index + %f0 = constant 0.0 : f32 + %vf0 = splat %f0 : vector<4x3xf32> + // expected-error@+1 {{requires memref vector element and vector result ranks to match}} + %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<3xf32> +} + +// ----- + +func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) { + %c3 = constant 3 : index + %f0 = constant 0.0 : f32 + %vf0 = splat %f0 : vector<4x3xf32> + // expected-error@+1 {{ requires memref vector element shape to match suffix of vector result shape}} + %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<1x1x2x3xf32> +} + +// ----- + func @test_vector.transfer_write(%arg0: memref<?x?xf32>) { %c3 = constant 3 : index %cst = constant dense<3.0> : vector<128 x f32> diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index e1607996cc2..d99a7df0d2b 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -1,24 +1,35 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s +// CHECK-DAG: #[[MAP0:map[0-9]+]] = (d0, d1) -> (d0, d1) + // CHECK-LABEL: func @vector_transfer_ops( -func @vector_transfer_ops(%arg0: memref<?x?xf32>) { +func @vector_transfer_ops(%arg0: memref<?x?xf32>, + %arg1 : memref<?x?xvector<4x3xf32>>) { + // CHECK: %[[C3:.*]] = constant 3 : index %c3 = constant 3 : index %cst = constant 3.0 : f32 %f0 = constant 0.0 : f32 + %vf0 = splat %f0 : vector<4x3xf32> + // - // CHECK: %0 = vector.transfer_read + // CHECK: vector.transfer_read %0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = (d0, d1)->(d0)} : memref<?x?xf32>, vector<128xf32> - // CHECK: %1 = vector.transfer_read + // CHECK: vector.transfer_read %1 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = (d0, d1)->(d1, d0)} : memref<?x?xf32>, vector<3x7xf32> // CHECK: vector.transfer_read %2 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = (d0, d1)->(d0)} : memref<?x?xf32>, vector<128xf32> // CHECK: vector.transfer_read %3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = (d0, d1)->(d1)} : memref<?x?xf32>, vector<128xf32> - // + // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32> + %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32> + // CHECK: vector.transfer_write vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32> // CHECK: vector.transfer_write vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = (d0, d1)->(d1, d0)} : vector<3x7xf32>, memref<?x?xf32> + // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] {permutation_map = #[[MAP0]]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>> + vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = (d0, d1)->(d0, d1)} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>> + return } |

