diff options
| author | Tres Popp <tpopp@google.com> | 2019-12-16 01:35:03 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-16 01:45:44 -0800 |
| commit | 44fc7d72b3cb44147394e22f1f21ad36cca7bca8 (patch) | |
| tree | 056115a82b26879275e2de6b10a759b71e0df335 /mlir/lib/Target/LLVMIR | |
| parent | 97af93227283e9252d7e497bd08ea2b78ece8f92 (diff) | |
| download | bcm5719-llvm-44fc7d72b3cb44147394e22f1f21ad36cca7bca8.tar.gz bcm5719-llvm-44fc7d72b3cb44147394e22f1f21ad36cca7bca8.zip | |
Remove LLVM dependency on mlir::Module and instead check Traits.
PiperOrigin-RevId: 285724678
Diffstat (limited to 'mlir/lib/Target/LLVMIR')
| -rw-r--r-- | mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp | 7 | ||||
| -rw-r--r-- | mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp | 7 | ||||
| -rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 17 |
3 files changed, 17 insertions, 14 deletions
diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 606e91b955f..728dc864ae5 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -58,7 +58,7 @@ static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType, class ModuleTranslation : public LLVM::ModuleTranslation { public: - explicit ModuleTranslation(ModuleOp module) + explicit ModuleTranslation(Operation *module) : LLVM::ModuleTranslation(module) {} ~ModuleTranslation() override {} @@ -73,7 +73,7 @@ protected: }; } // namespace -std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) { +std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Operation *m) { ModuleTranslation translation(m); auto llvmModule = LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m); @@ -82,7 +82,8 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) { // Insert the nvvm.annotations kernel so that the NVVM backend recognizes the // function as a kernel. - for (auto func : m.getOps<LLVM::LLVMFuncOp>()) { + for (auto func : + ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) { if (!gpu::GPUDialect::isKernel(func)) continue; diff --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp index dcd4d6c221f..7b7c3681371 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp @@ -69,7 +69,7 @@ static llvm::Value *createDeviceFunctionCall(llvm::IRBuilder<> &builder, class ModuleTranslation : public LLVM::ModuleTranslation { public: - explicit ModuleTranslation(ModuleOp module) + explicit ModuleTranslation(Operation *module) : LLVM::ModuleTranslation(module) {} ~ModuleTranslation() override {} @@ -84,7 +84,7 @@ protected: }; } // namespace -std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) { +std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(Operation *m) { ModuleTranslation translation(m); // lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics) @@ -94,7 +94,8 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) { // foreach GPU kernel // 1. Insert AMDGPU_KERNEL calling convention. // 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute. - for (auto func : m.getOps<LLVM::LLVMFuncOp>()) { + for (auto func : + ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) { if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName())) continue; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index f985fed3991..f5f9ccabd76 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -311,7 +311,7 @@ llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) { // Create named global variables that correspond to llvm.mlir.global // definitions. void ModuleTranslation::convertGlobals() { - for (auto op : mlirModule.getOps<LLVM::GlobalOp>()) { + for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) { llvm::Type *type = op.getType().getUnderlyingType(); llvm::Constant *cst = llvm::UndefValue::get(type); if (op.getValueOrNull()) { @@ -470,10 +470,10 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { return success(); } -LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) { - for (Operation &o : m.getBody()->getOperations()) +LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) { + for (Operation &o : getModuleBody(m).getOperations()) if (!isa<LLVM::LLVMFuncOp>(&o) && !isa<LLVM::GlobalOp>(&o) && - !isa<ModuleTerminatorOp>(&o)) + !o.isKnownTerminator()) return o.emitOpError("unsupported module-level operation"); return success(); } @@ -481,7 +481,7 @@ LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) { LogicalResult ModuleTranslation::convertFunctions() { // Declare all functions first because there may be function calls that form a // call graph with cycles. - for (auto function : mlirModule.getOps<LLVMFuncOp>()) { + for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction( function.getName(), llvm::cast<llvm::FunctionType>(function.getType().getUnderlyingType())); @@ -491,7 +491,7 @@ LogicalResult ModuleTranslation::convertFunctions() { } // Convert functions. - for (auto function : mlirModule.getOps<LLVMFuncOp>()) { + for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { // Ignore external functions. if (function.isExternal()) continue; @@ -503,8 +503,9 @@ LogicalResult ModuleTranslation::convertFunctions() { return success(); } -std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(ModuleOp m) { - auto *dialect = m.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); +std::unique_ptr<llvm::Module> +ModuleTranslation::prepareLLVMModule(Operation *m) { + auto *dialect = m->getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); assert(dialect && "LLVM dialect must be registered"); auto llvmModule = llvm::CloneModule(dialect->getLLVMModule()); |

