diff options
| author | Aart Bik <ajcbik@google.com> | 2019-12-06 11:01:54 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-06 11:02:29 -0800 |
| commit | b36aaeafb1b026213432b5a8110467e16ed3f306 (patch) | |
| tree | 00d62d7455c2fd1b9b47f19b5681492ab10e1a51 /mlir/test/Conversion/VectorToLLVM | |
| parent | 398f04aa49109fd5d1eff2c1946a2956dc6b29c6 (diff) | |
| download | bcm5719-llvm-b36aaeafb1b026213432b5a8110467e16ed3f306.tar.gz bcm5719-llvm-b36aaeafb1b026213432b5a8110467e16ed3f306.zip | |
[VectorOps] Add lowering of vector.broadcast to LLVM IR
For example, a scalar broadcast
%0 = vector.broadcast %x : f32 to vector<2xf32>
return %0 : vector<2xf32>
which expands scalar x into vector [x,x] by lowering
to the following LLVM IR dialect to implement the
duplication over the leading dimension.
%0 = llvm.mlir.undef : !llvm<"<2 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.insertelement %x, %0[%1 : !llvm.i64] : !llvm<"<2 x float>">
%3 = llvm.shufflevector %2, %0 [0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
return %3 : vector<2xf32>
In the trailing dimensions, the operand is simply
"passed through", unless a more elaborate "stretch"
is required.
For example
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
return %0 : vector<4xf32>
becomes
%0 = llvm.mlir.undef : !llvm<"<4 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.extractelement %arg0[%1 : !llvm.i64] : !llvm<"<1 x float>">
%3 = llvm.mlir.constant(0 : index) : !llvm.i64
%4 = llvm.insertelement %2, %0[%3 : !llvm.i64] : !llvm<"<4 x float>">
%5 = llvm.shufflevector %4, %0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
llvm.return %5 : !llvm<"<4 x float>">
PiperOrigin-RevId: 284219926
Diffstat (limited to 'mlir/test/Conversion/VectorToLLVM')
| -rw-r--r-- | mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 200 |
1 files changed, 200 insertions, 0 deletions
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 025027dcddc..b07a8634da4 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1,5 +1,205 @@ // RUN: mlir-opt %s -convert-vector-to-llvm | FileCheck %s +func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2xf32> + return %0 : vector<2xf32> +} +// CHECK-LABEL: broadcast_vec1d_from_scalar +// CHECK: llvm.mlir.undef : !llvm<"<2 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}}[0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> +// CHECK: llvm.return {{.*}} : !llvm<"<2 x float>"> + +func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2x3xf32> + return %0 : vector<2x3xf32> +} +// CHECK-LABEL: broadcast_vec2d_from_scalar +// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}}[0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: llvm.mlir.undef : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> + +func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32> + return %0 : vector<2x3x4xf32> +} +// CHECK-LABEL: broadcast_vec3d_from_scalar +// CHECK: llvm.mlir.undef : !llvm<"<4 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> +// CHECK: llvm.mlir.undef : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.mlir.undef : !llvm<"[2 x [3 x <4 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x [3 x <4 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x [3 x <4 x float>]]"> +// CHECK: llvm.return {{.*}} : !llvm<"[2 x [3 x <4 x float>]]"> + +func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32> + return %0 : vector<2xf32> +} +// CHECK-LABEL: broadcast_vec1d_from_vec1d +// CHECK: llvm.return {{.*}} : !llvm<"<2 x float>"> + +func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32> + return %0 : vector<3x2xf32> +} +// CHECK-LABEL: broadcast_vec2d_from_vec1d +// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.return {{.*}} : !llvm<"[3 x <2 x float>]"> + +func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +} +// CHECK-LABEL: broadcast_vec3d_from_vec1d +// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]"> + +func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +} +// CHECK-LABEL: broadcast_vec3d_from_vec2d +// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]"> + +func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { + %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32> + return %0 : vector<4xf32> +} +// CHECK-LABEL: broadcast_stretch +// CHECK: llvm.mlir.undef : !llvm<"<4 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> +// CHECK: llvm.return {{.*}} : !llvm<"<4 x float>"> + +func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32> + return %0 : vector<3x4xf32> +} +// CHECK-LABEL: broadcast_stretch_at_start +// CHECK: llvm.mlir.undef : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.return {{.*}} : !llvm<"[3 x <4 x float>]"> + +func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> { + %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32> + return %0 : vector<4x3xf32> +} +// CHECK-LABEL: broadcast_stretch_at_end +// CHECK: llvm.mlir.undef : !llvm<"[4 x <3 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x <1 x float>]"> +// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <3 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[4 x <1 x float>]"> +// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x <3 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <1 x float>]"> +// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x <3 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <1 x float>]"> +// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <3 x float>]"> +// CHECK: llvm.return {{.*}} : !llvm<"[4 x <3 x float>]"> + +func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +} +// CHECK-LABEL: broadcast_stretch_in_middle +// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]"> + func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> { %2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32> return %2 : vector<2x3xf32> |

