diff options
| author | Nicolas Vasilache <ntv@google.com> | 2019-08-16 03:52:56 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-08-16 03:53:26 -0700 |
| commit | f826ceef3ce8bfea1b78ab7bb2c60c53eb13729a (patch) | |
| tree | 574e65c0dbb0f8f89c1219a7f4ccfcf0547d20ba /mlir/test/Conversion/VectorToLLVM | |
| parent | cc980aa41651c2cbfcbd9048fb0788f4aa9ae475 (diff) | |
| download | bcm5719-llvm-f826ceef3ce8bfea1b78ab7bb2c60c53eb13729a.tar.gz bcm5719-llvm-f826ceef3ce8bfea1b78ab7bb2c60c53eb13729a.zip | |
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
Diffstat (limited to 'mlir/test/Conversion/VectorToLLVM')
| -rw-r--r-- | mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 68 |
1 files changed, 42 insertions, 26 deletions
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index f582de146ba..532a4c2e369 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1,33 +1,49 @@ // RUN: mlir-opt %s -vector-lower-to-llvm-dialect | FileCheck %s -func @vec_1d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> { - %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32> - %3 = vector.extractelement %2[0 : i32]: vector<4x8xf32> - return %3 : vector<8xf32> +func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> { + %2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32> + return %2 : vector<2x3xf32> } -// CHECK-LABEL: vec_1d -// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]"> -// CHECK-5: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> -// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x <8 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"<8 x float>"> +// CHECK-LABEL: outerproduct +// CHECK: llvm.undef : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> +// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> +// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> -func @vec_2d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> { - %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32> - return %2 : vector<4x8xf32> +func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> { + %2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32> + return %2 : vector<2x3xf32> } -// CHECK-LABEL: vec_2d -// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]"> -// CHECK-4: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> -// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"[4 x <8 x float>]"> +// CHECK-LABEL: outerproduct_add +// CHECK: llvm.undef : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]"> +// CHECK: "llvm.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> +// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]"> +// CHECK: "llvm.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>"> +// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> -func @vec_3d(%arg0: vector<4x8x16xf32>) -> vector<8x16xf32> { - %0 = vector.extractelement %arg0[0 : i32]: vector<4x8x16xf32> - return %0 : vector<8x16xf32> +func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> { + %0 = vector.extractelement %arg0[0 : i32]: vector<4x3x16xf32> + return %0 : vector<3x16xf32> } -// CHECK-LABEL: vec_3d -// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [8 x <16 x float>]]"> -// CHECK: llvm.return %{{.*}} : !llvm<"[8 x <16 x float>]">
\ No newline at end of file +// CHECK-LABEL: extract_vec_2d_from_vec_3d +// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]"> +// CHECK: llvm.return %{{.*}} : !llvm<"[3 x <16 x float>]"> + +func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 { + %0 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32> + return %0 : f32 +} +// CHECK-LABEL: extract_element_from_vec_3d +// CHECK: llvm.extractvalue %{{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]"> +// CHECK: llvm.constant(0 : i32) : !llvm.i32 +// CHECK: llvm.extractelement %{{.*}}, %{{.*}} : !llvm<"<16 x float>"> +// CHECK: llvm.return %{{.*}} : !llvm.float
\ No newline at end of file |

