diff options
| author | MLIR Team <no-reply@google.com> | 2019-09-04 03:45:38 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-09-04 03:46:06 -0700 |
| commit | 2f13df13b0bb481702fc83eb50c273deadb55f20 (patch) | |
| tree | 5f26cb0df35f9df08eb5e9583c99d6e14659e239 /mlir/lib/Target | |
| parent | 71d27dfc3b242158a46993e3e061e18940d81cf2 (diff) | |
| download | bcm5719-llvm-2f13df13b0bb481702fc83eb50c273deadb55f20.tar.gz bcm5719-llvm-2f13df13b0bb481702fc83eb50c273deadb55f20.zip | |
Add support for array-typed constants.
PiperOrigin-RevId: 267121729
Diffstat (limited to 'mlir/lib/Target')
| -rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 30 |
1 files changed, 21 insertions, 9 deletions
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 198aa8fcf3c..cbdc6c27ea1 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -85,23 +85,35 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, if (auto funcAttr = attr.dyn_cast<SymbolRefAttr>()) return functionMapping.lookup(funcAttr.getValue()); if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) { - auto *vectorType = cast<llvm::VectorType>(llvmType); - auto *child = getLLVMConstant(vectorType->getElementType(), - splatAttr.getSplatValue(), loc); - return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child); + auto *sequentialType = cast<llvm::SequentialType>(llvmType); + auto elementType = sequentialType->getElementType(); + uint64_t numElements = sequentialType->getNumElements(); + auto *child = getLLVMConstant(elementType, splatAttr.getSplatValue(), loc); + if (llvmType->isVectorTy()) + return llvm::ConstantVector::getSplat(numElements, child); + if (llvmType->isArrayTy()) { + auto arrayType = llvm::ArrayType::get(elementType, numElements); + SmallVector<llvm::Constant *, 8> constants(numElements, child); + return llvm::ConstantArray::get(arrayType, constants); + } } if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) { - auto *vectorType = cast<llvm::VectorType>(llvmType); + auto *sequentialType = cast<llvm::SequentialType>(llvmType); + auto elementType = sequentialType->getElementType(); + uint64_t numElements = sequentialType->getNumElements(); SmallVector<llvm::Constant *, 8> constants; - uint64_t numElements = vectorType->getNumElements(); constants.reserve(numElements); for (auto n : elementsAttr.getValues<Attribute>()) { - constants.push_back( - getLLVMConstant(vectorType->getElementType(), n, loc)); + constants.push_back(getLLVMConstant(elementType, n, loc)); if (!constants.back()) return nullptr; } - return llvm::ConstantVector::get(constants); + if (llvmType->isVectorTy()) + return llvm::ConstantVector::get(constants); + if (llvmType->isArrayTy()) { + auto arrayType = llvm::ArrayType::get(elementType, numElements); + return llvm::ConstantArray::get(arrayType, constants); + } } if (auto stringAttr = attr.dyn_cast<StringAttr>()) { return llvm::ConstantDataArray::get( |

