diff options
| author | Nicolas Vasilache <ntv@google.com> | 2019-08-21 18:15:39 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-08-21 18:16:02 -0700 |
| commit | fe3594f745f70244c0c32b8b3287799ff2cdcbc7 (patch) | |
| tree | 398a40490dd17ce6e76fd3d474fdb3051b5d15e1 /mlir/lib/ExecutionEngine | |
| parent | 748edce6b831a453831bf8d8688fdbae68d44e14 (diff) | |
| download | bcm5719-llvm-fe3594f745f70244c0c32b8b3287799ff2cdcbc7.tar.gz bcm5719-llvm-fe3594f745f70244c0c32b8b3287799ff2cdcbc7.zip | |
Reduce reliance on custom grown Jit implementation - NFC
This CL makes use of the standard LLVM LLJIT and removes the need for a custom JIT implementation within MLIR.
To achieve this, one needs to clone (i.e. serde) the produced llvm::Module into a new LLVMContext. This is currently necessary because the llvm::LLVMContext is owned by the LLVMDialect, somewhat deep in the call hierarchy.
In the future we should remove the reliance of serding the llvm::Module by allowing the injection of an LLVMContext from the top-level. Unfortunately this will require deeper API changes and impact multiple places. It is therefore left for future work.
PiperOrigin-RevId: 264737459
Diffstat (limited to 'mlir/lib/ExecutionEngine')
| -rw-r--r-- | mlir/lib/ExecutionEngine/CMakeLists.txt | 13 | ||||
| -rw-r--r-- | mlir/lib/ExecutionEngine/ExecutionEngine.cpp | 325 |
2 files changed, 130 insertions, 208 deletions
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt index fd856a77f62..07061b1db11 100644 --- a/mlir/lib/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/ExecutionEngine/CMakeLists.txt @@ -7,4 +7,15 @@ add_llvm_library(MLIRExecutionEngine ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/ExecutionEngine ) -target_link_libraries(MLIRExecutionEngine MLIRLLVMIR MLIRTargetLLVMIR LLVMExecutionEngine LLVMOrcJIT LLVMSupport ${outlibs}) +target_link_libraries(MLIRExecutionEngine + + MLIRLLVMIR + MLIRTargetLLVMIR + LLVMBitReader + LLVMBitWriter + LLVMExecutionEngine + LLVMOrcJIT + LLVMSupport + LLVMTransformUtils + + ${outlibs}) diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index 4450bf4d403..dbc59d0383a 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -24,6 +24,9 @@ #include "mlir/IR/Module.h" #include "mlir/Target/LLVMIR.h" +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/ExecutionEngine/ObjectCache.h" #include "llvm/ExecutionEngine/Orc/CompileUtils.h" #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" @@ -36,215 +39,59 @@ #include "llvm/Support/TargetRegistry.h" using namespace mlir; +using llvm::dbgs; using llvm::Error; +using llvm::errs; using llvm::Expected; +using llvm::LLVMContext; +using llvm::MemoryBuffer; +using llvm::MemoryBufferRef; +using llvm::Module; +using llvm::SectionMemoryManager; +using llvm::StringError; +using llvm::Triple; +using llvm::orc::DynamicLibrarySearchGenerator; +using llvm::orc::ExecutionSession; +using llvm::orc::IRCompileLayer; +using llvm::orc::JITTargetMachineBuilder; +using llvm::orc::RTDyldObjectLinkingLayer; +using llvm::orc::ThreadSafeModule; +using llvm::orc::TMOwningSimpleCompiler; -namespace { -// Memory manager for the JIT's objectLayer. Its main goal is to fallback to -// resolving functions in the current process if they cannot be resolved in the -// JIT-compiled modules. -class MemoryManager : public llvm::SectionMemoryManager { -public: - MemoryManager(llvm::orc::ExecutionSession &execSession) - : session(execSession) {} - - // Resolve the named symbol. First, try looking it up in the main library of - // the execution session. If there is no such symbol, try looking it up in - // the current process (for example, if it is a standard library function). - // Return `nullptr` if lookup fails. - llvm::JITSymbol findSymbol(const std::string &name) override { - auto mainLibSymbol = session.lookup({&session.getMainJITDylib()}, name); - if (mainLibSymbol) - return mainLibSymbol.get(); - auto address = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name); - if (!address) { - llvm::errs() << "Could not look up: " << name << '\n'; - return nullptr; - } - return llvm::JITSymbol(address, llvm::JITSymbolFlags::Exported); - } - -private: - llvm::orc::ExecutionSession &session; -}; -} // end anonymous namespace +// Wrap a string into an llvm::StringError. +static inline Error make_string_error(const llvm::Twine &message) { + return llvm::make_error<StringError>(message.str(), + llvm::inconvertibleErrorCode()); +} namespace mlir { -namespace impl { - -/// Wrapper class around DynamicLibrarySearchGenerator to allow searching -/// in-process symbols that have not been explicitly exported. -/// This first tries to resolve a symbol by using DynamicLibrarySearchGenerator. -/// For symbols that are not found this way, it then uses -/// `llvm::sys::DynamicLibrary::SearchForAddressOfSymbol` to extract symbols -/// that have been explicitly added with `llvm::sys::DynamicLibrary::AddSymbol`, -/// previously. -class SearchGenerator { -public: - SearchGenerator(char GlobalPrefix) - : defaultGenerator(cantFail( - llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( - GlobalPrefix))) {} - - // This function forwards to DynamicLibrarySearchGenerator::operator() and - // adds an extra resolution for names explicitly registered via - // `llvm::sys::DynamicLibrary::AddSymbol`. - Expected<llvm::orc::SymbolNameSet> - operator()(llvm::orc::JITDylib &JD, const llvm::orc::SymbolNameSet &Names) { - auto res = defaultGenerator->tryToGenerate(JD, Names); - if (!res) - return res; - llvm::orc::SymbolMap newSymbols; - for (auto &Name : Names) { - if (res.get().count(Name) > 0) - continue; - res.get().insert(Name); - auto addedSymbolAddress = - llvm::sys::DynamicLibrary::SearchForAddressOfSymbol(*Name); - if (!addedSymbolAddress) - continue; - llvm::JITEvaluatedSymbol Sym( - reinterpret_cast<uintptr_t>(addedSymbolAddress), - llvm::JITSymbolFlags::Exported); - newSymbols[Name] = Sym; - } - if (!newSymbols.empty()) - cantFail(JD.define(absoluteSymbols(std::move(newSymbols)))); - return res; - } - -private: - std::unique_ptr<llvm::orc::DynamicLibrarySearchGenerator> defaultGenerator; -}; - -// Simple layered Orc JIT compilation engine. -class OrcJIT { -public: - using IRTransformer = std::function<Error(llvm::Module *)>; - - // Construct a JIT engine for the target host defined by `machineBuilder`, - // using the data layout provided as `dataLayout`. - // Setup the object layer to use our custom memory manager in order to - // resolve calls to library functions present in the process. - OrcJIT(llvm::orc::JITTargetMachineBuilder machineBuilder, - llvm::DataLayout layout, IRTransformer transform, - ArrayRef<StringRef> sharedLibPaths) - : irTransformer(transform), - objectLayer( - session, - [this]() { return std::make_unique<MemoryManager>(session); }), - compileLayer( - session, objectLayer, - llvm::orc::ConcurrentIRCompiler(std::move(machineBuilder))), - transformLayer(session, compileLayer, makeIRTransformFunction()), - dataLayout(layout), mangler(session, this->dataLayout), - threadSafeCtx(std::make_unique<llvm::LLVMContext>()) { - session.getMainJITDylib().addGenerator( - cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( - layout.getGlobalPrefix()))); - loadLibraries(sharedLibPaths); - } - - // Create a JIT engine for the current host. - static Expected<std::unique_ptr<OrcJIT>> - createDefault(IRTransformer transformer, ArrayRef<StringRef> sharedLibPaths) { - auto machineBuilder = llvm::orc::JITTargetMachineBuilder::detectHost(); - if (!machineBuilder) - return machineBuilder.takeError(); - - auto dataLayout = machineBuilder->getDefaultDataLayoutForTarget(); - if (!dataLayout) - return dataLayout.takeError(); - - return std::make_unique<OrcJIT>(std::move(*machineBuilder), - std::move(*dataLayout), transformer, - sharedLibPaths); - } - - // Add an LLVM module to the main library managed by the JIT engine. - Error addModule(std::unique_ptr<llvm::Module> M) { - return transformLayer.add( - session.getMainJITDylib(), - llvm::orc::ThreadSafeModule(std::move(M), threadSafeCtx)); - } - - // Lookup a symbol in the main library managed by the JIT engine. - Expected<llvm::JITEvaluatedSymbol> lookup(StringRef Name) { - return session.lookup({&session.getMainJITDylib()}, mangler(Name.str())); - } - -private: - // Wrap the `irTransformer` into a function that can be called by the - // IRTranformLayer. If `irTransformer` is not set up, return the module as - // is without errors. - llvm::orc::IRTransformLayer::TransformFunction makeIRTransformFunction() { - return [this](llvm::orc::ThreadSafeModule module, - const llvm::orc::MaterializationResponsibility &resp) - -> Expected<llvm::orc::ThreadSafeModule> { - (void)resp; - if (!irTransformer) - return std::move(module); - Error err = module.withModuleDo( - [this](llvm::Module &module) { return irTransformer(&module); }); - if (err) - return std::move(err); - return std::move(module); - }; - } - - // Iterate over shareLibPaths and load the corresponding libraries for symbol - // resolution. - void loadLibraries(ArrayRef<StringRef> sharedLibPaths); - IRTransformer irTransformer; - llvm::orc::ExecutionSession session; - llvm::orc::RTDyldObjectLinkingLayer objectLayer; - llvm::orc::IRCompileLayer compileLayer; - llvm::orc::IRTransformLayer transformLayer; - llvm::DataLayout dataLayout; - llvm::orc::MangleAndInterner mangler; - llvm::orc::ThreadSafeContext threadSafeCtx; -}; -} // end namespace impl -} // namespace mlir - -void mlir::impl::OrcJIT::loadLibraries(ArrayRef<StringRef> sharedLibPaths) { - for (auto libPath : sharedLibPaths) { - auto mb = llvm::MemoryBuffer::getFile(libPath); - if (!mb) { - llvm::errs() << "Could not create MemoryBuffer for: " << libPath << " " - << mb.getError().message() << "\n"; - continue; - } - auto &JD = session.createJITDylib(libPath); - auto loaded = llvm::orc::DynamicLibrarySearchGenerator::Load( - libPath.data(), dataLayout.getGlobalPrefix()); - if (!loaded) { - llvm::errs() << "Could not load: " << libPath << " " << loaded.takeError() - << "\n"; - continue; - } - JD.addGenerator(std::move(*loaded)); - auto res = objectLayer.add(JD, std::move(mb.get())); - if (res) - llvm::errs() << "Could not add: " << libPath << " " << res << "\n"; - } +void SimpleObjectCache::notifyObjectCompiled(const Module *M, + MemoryBufferRef ObjBuffer) { + CachedObjects[M->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy( + ObjBuffer.getBuffer(), ObjBuffer.getBufferIdentifier()); } -// Wrap a string into an llvm::StringError. -static inline Error make_string_error(const llvm::Twine &message) { - return llvm::make_error<llvm::StringError>(message.str(), - llvm::inconvertibleErrorCode()); +std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *M) { + auto I = CachedObjects.find(M->getModuleIdentifier()); + if (I == CachedObjects.end()) { + dbgs() << "No object for " << M->getModuleIdentifier() + << " in cache. Compiling.\n"; + return nullptr; + } + dbgs() << "Object for " << M->getModuleIdentifier() + << " loaded from cache.\n"; + return MemoryBuffer::getMemBuffer(I->second->getMemBufferRef()); } // Setup LLVM target triple from the current machine. -bool ExecutionEngine::setupTargetTriple(llvm::Module *llvmModule) { +bool ExecutionEngine::setupTargetTriple(Module *llvmModule) { // Setup the machine properties from the current architecture. auto targetTriple = llvm::sys::getDefaultTargetTriple(); std::string errorMessage; auto target = llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage); if (!target) { - llvm::errs() << "NO target: " << errorMessage << "\n"; + errs() << "NO target: " << errorMessage << "\n"; return true; } auto machine = @@ -261,7 +108,7 @@ static std::string makePackedFunctionName(StringRef name) { // For each function in the LLVM module, define an interface function that wraps // all the arguments of the original function and all its results into an i8** // pointer to provide a unified invocation interface. -void packFunctionArguments(llvm::Module *module) { +void packFunctionArguments(Module *module) { auto &ctx = module->getContext(); llvm::IRBuilder<> builder(ctx); llvm::DenseSet<llvm::Function *> interfaceFunctions; @@ -321,18 +168,13 @@ void packFunctionArguments(llvm::Module *module) { } } -// Out of line for PIMPL unique_ptr. -ExecutionEngine::~ExecutionEngine() = default; - Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(ModuleOp m, - std::function<llvm::Error(llvm::Module *)> transformer, + std::function<Error(llvm::Module *)> transformer, ArrayRef<StringRef> sharedLibPaths) { auto engine = std::make_unique<ExecutionEngine>(); - auto expectedJIT = impl::OrcJIT::createDefault(transformer, sharedLibPaths); - if (!expectedJIT) - return expectedJIT.takeError(); + std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext); auto llvmModule = translateModuleToLLVMIR(m); if (!llvmModule) return make_string_error("could not convert to LLVM IR"); @@ -342,9 +184,77 @@ ExecutionEngine::create(ModuleOp m, setupTargetTriple(llvmModule.get()); packFunctionArguments(llvmModule.get()); - if (auto err = (*expectedJIT)->addModule(std::move(llvmModule))) - return std::move(err); - engine->jit = std::move(*expectedJIT); + // Clone module in a new LLVMContext since translateModuleToLLVMIR buries + // ownership too deeply. + // TODO(zinenko): Reevaluate model of ownership of LLVMContext in LLVMDialect. + SmallVector<char, 1> buffer; + { + llvm::raw_svector_ostream os(buffer); + WriteBitcodeToFile(*llvmModule, os); + } + llvm::MemoryBufferRef bufferRef(llvm::StringRef(buffer.data(), buffer.size()), + "cloned module buffer"); + auto expectedModule = parseBitcodeFile(bufferRef, *ctx); + if (!expectedModule) + return expectedModule.takeError(); + std::unique_ptr<Module> deserModule = std::move(*expectedModule); + + // Callback to create the object layer with symbol resolution to current + // process and dynamically linked libraries. + auto objectLinkingLayerCreator = [&](ExecutionSession &session, + const Triple &TT) { + auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>( + session, []() { return std::make_unique<SectionMemoryManager>(); }); + auto dataLayout = deserModule->getDataLayout(); + + // Resolve symbols that are statically linked in the current process. + session.getMainJITDylib().addGenerator( + cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess( + dataLayout.getGlobalPrefix()))); + + // Resolve symbols from shared libraries. + for (auto libPath : sharedLibPaths) { + auto mb = llvm::MemoryBuffer::getFile(libPath); + if (!mb) { + errs() << "Fail to create MemoryBuffer for: " << libPath << "\n"; + continue; + } + auto &JD = session.createJITDylib(libPath); + auto loaded = DynamicLibrarySearchGenerator::Load( + libPath.data(), dataLayout.getGlobalPrefix()); + if (!loaded) { + errs() << "Could not load: " << libPath << "\n"; + continue; + } + JD.addGenerator(std::move(*loaded)); + cantFail(objectLayer->add(JD, std::move(mb.get()))); + } + + return objectLayer; + }; + + // Callback to inspect the cache and recompile on demand. This follows Lang's + // LLJITWithObjectCache example. + auto compileFunctionCreator = [&](JITTargetMachineBuilder JTMB) + -> Expected<IRCompileLayer::CompileFunction> { + auto TM = JTMB.createTargetMachine(); + if (!TM) + return TM.takeError(); + return IRCompileLayer::CompileFunction( + TMOwningSimpleCompiler(std::move(*TM), engine->cache.get())); + }; + + // Create the LLJIT by calling the LLJITBuilder with 2 callbacks. + auto jit = + cantFail(llvm::orc::LLJITBuilder() + .setCompileFunctionCreator(compileFunctionCreator) + .setObjectLinkingLayerCreator(objectLinkingLayerCreator) + .create()); + + // Add a ThreadSafemodule to the engine and return. + ThreadSafeModule tsm(std::move(deserModule), std::move(ctx)); + cantFail(jit->addIRModule(std::move(tsm))); + engine->jit = std::move(jit); return std::move(engine); } @@ -360,8 +270,7 @@ Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const { return fptr; } -llvm::Error ExecutionEngine::invoke(StringRef name, - MutableArrayRef<void *> args) { +Error ExecutionEngine::invoke(StringRef name, MutableArrayRef<void *> args) { auto expectedFPtr = lookup(name); if (!expectedFPtr) return expectedFPtr.takeError(); @@ -369,5 +278,7 @@ llvm::Error ExecutionEngine::invoke(StringRef name, (*fptr)(args.data()); - return llvm::Error::success(); + return Error::success(); } + +} // end namespace mlir |

