diff options
author | Alex Zinenko <zinenko@google.com> | 2020-01-14 11:30:25 +0100 |
---|---|---|
committer | Alex Zinenko <zinenko@google.com> | 2020-01-14 12:37:47 +0100 |
commit | d6ea8ff0d74bfe5cd181ccfe91c2c300c5f7a35d (patch) | |
tree | 71b416c166b40ccea467300cf6075fcada5c0ccd /mlir | |
parent | 3d6c492d7a9830a1a39b85dfa215743581d52715 (diff) | |
download | bcm5719-llvm-d6ea8ff0d74bfe5cd181ccfe91c2c300c5f7a35d.tar.gz bcm5719-llvm-d6ea8ff0d74bfe5cd181ccfe91c2c300c5f7a35d.zip |
[mlir] Fix translation of splat constants to LLVM IR
Summary:
When converting splat constants for nested sequential LLVM IR types wrapped in
MLIR, the constant conversion was erroneously assuming it was always possible
to recursively construct a constant of a sequential type given only one value.
Instead, wait until all sequential types are unpacked recursively before
constructing a scalar constant and wrapping it into the surrounding sequential
type.
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D72688
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 9 | ||||
-rw-r--r-- | mlir/test/Target/llvmir.mlir | 28 |
2 files changed, 36 insertions, 1 deletions
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index e7c1862232b..c716cf33b57 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -49,7 +49,14 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, auto *sequentialType = cast<llvm::SequentialType>(llvmType); auto elementType = sequentialType->getElementType(); uint64_t numElements = sequentialType->getNumElements(); - auto *child = getLLVMConstant(elementType, splatAttr.getSplatValue(), loc); + // Splat value is a scalar. Extract it only if the element type is not + // another sequence type. The recursion terminates because each step removes + // one outer sequential type. + llvm::Constant *child = getLLVMConstant( + elementType, + isa<llvm::SequentialType>(elementType) ? splatAttr + : splatAttr.getSplatValue(), + loc); if (llvmType->isVectorTy()) return llvm::ConstantVector::getSplat(numElements, child); if (llvmType->isArrayTy()) { diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir index 9d0ee383046..3ce3eb20ea6 100644 --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -804,6 +804,34 @@ llvm.func @vector_ops(%arg0: !llvm<"<4 x float>">, %arg1: !llvm<"<4 x i1>">, %ar llvm.return %1 : !llvm<"<4 x float>"> } +// CHECK-LABEL: @vector_splat_1d +llvm.func @vector_splat_1d() -> !llvm<"<4 x float>"> { + // CHECK: ret <4 x float> zeroinitializer + %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<4xf32>) : !llvm<"<4 x float>"> + llvm.return %0 : !llvm<"<4 x float>"> +} + +// CHECK-LABEL: @vector_splat_2d +llvm.func @vector_splat_2d() -> !llvm<"[4 x <16 x float>]"> { + // CHECK: ret [4 x <16 x float>] zeroinitializer + %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<4x16xf32>) : !llvm<"[4 x <16 x float>]"> + llvm.return %0 : !llvm<"[4 x <16 x float>]"> +} + +// CHECK-LABEL: @vector_splat_3d +llvm.func @vector_splat_3d() -> !llvm<"[4 x [16 x <4 x float>]]"> { + // CHECK: ret [4 x [16 x <4 x float>]] zeroinitializer + %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<4x16x4xf32>) : !llvm<"[4 x [16 x <4 x float>]]"> + llvm.return %0 : !llvm<"[4 x [16 x <4 x float>]]"> +} + +// CHECK-LABEL: @vector_splat_nonzero +llvm.func @vector_splat_nonzero() -> !llvm<"<4 x float>"> { + // CHECK: ret <4 x float> <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00> + %0 = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : !llvm<"<4 x float>"> + llvm.return %0 : !llvm<"<4 x float>"> +} + // CHECK-LABEL: @ops llvm.func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3: !llvm.i32) -> !llvm<"{ float, i32 }"> { // CHECK-NEXT: fsub float %0, %1 |