diff options
| author | Alex Zinenko <zinenko@google.com> | 2020-01-14 11:30:25 +0100 |
|---|---|---|
| committer | Alex Zinenko <zinenko@google.com> | 2020-01-14 12:37:47 +0100 |
| commit | d6ea8ff0d74bfe5cd181ccfe91c2c300c5f7a35d (patch) | |
| tree | 71b416c166b40ccea467300cf6075fcada5c0ccd /mlir/lib | |
| parent | 3d6c492d7a9830a1a39b85dfa215743581d52715 (diff) | |
| download | bcm5719-llvm-d6ea8ff0d74bfe5cd181ccfe91c2c300c5f7a35d.tar.gz bcm5719-llvm-d6ea8ff0d74bfe5cd181ccfe91c2c300c5f7a35d.zip | |
[mlir] Fix translation of splat constants to LLVM IR
Summary:
When converting splat constants for nested sequential LLVM IR types wrapped in
MLIR, the constant conversion was erroneously assuming it was always possible
to recursively construct a constant of a sequential type given only one value.
Instead, wait until all sequential types are unpacked recursively before
constructing a scalar constant and wrapping it into the surrounding sequential
type.
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D72688
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index e7c1862232b..c716cf33b57 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -49,7 +49,14 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, auto *sequentialType = cast<llvm::SequentialType>(llvmType); auto elementType = sequentialType->getElementType(); uint64_t numElements = sequentialType->getNumElements(); - auto *child = getLLVMConstant(elementType, splatAttr.getSplatValue(), loc); + // Splat value is a scalar. Extract it only if the element type is not + // another sequence type. The recursion terminates because each step removes + // one outer sequential type. + llvm::Constant *child = getLLVMConstant( + elementType, + isa<llvm::SequentialType>(elementType) ? splatAttr + : splatAttr.getSplatValue(), + loc); if (llvmType->isVectorTy()) return llvm::ConstantVector::getSplat(numElements, child); if (llvmType->isArrayTy()) { |

