summaryrefslogtreecommitdiffstats
path: root/mlir/lib/ExecutionEngine
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2019-08-21 18:15:39 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-08-21 18:16:02 -0700
commitfe3594f745f70244c0c32b8b3287799ff2cdcbc7 (patch)
tree398a40490dd17ce6e76fd3d474fdb3051b5d15e1 /mlir/lib/ExecutionEngine
parent748edce6b831a453831bf8d8688fdbae68d44e14 (diff)
downloadbcm5719-llvm-fe3594f745f70244c0c32b8b3287799ff2cdcbc7.tar.gz
bcm5719-llvm-fe3594f745f70244c0c32b8b3287799ff2cdcbc7.zip
Reduce reliance on custom grown Jit implementation - NFC
This CL makes use of the standard LLVM LLJIT and removes the need for a custom JIT implementation within MLIR. To achieve this, one needs to clone (i.e. serde) the produced llvm::Module into a new LLVMContext. This is currently necessary because the llvm::LLVMContext is owned by the LLVMDialect, somewhat deep in the call hierarchy. In the future we should remove the reliance of serding the llvm::Module by allowing the injection of an LLVMContext from the top-level. Unfortunately this will require deeper API changes and impact multiple places. It is therefore left for future work. PiperOrigin-RevId: 264737459
Diffstat (limited to 'mlir/lib/ExecutionEngine')
-rw-r--r--mlir/lib/ExecutionEngine/CMakeLists.txt13
-rw-r--r--mlir/lib/ExecutionEngine/ExecutionEngine.cpp325
2 files changed, 130 insertions, 208 deletions
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index fd856a77f62..07061b1db11 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -7,4 +7,15 @@ add_llvm_library(MLIRExecutionEngine
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/ExecutionEngine
)
-target_link_libraries(MLIRExecutionEngine MLIRLLVMIR MLIRTargetLLVMIR LLVMExecutionEngine LLVMOrcJIT LLVMSupport ${outlibs})
+target_link_libraries(MLIRExecutionEngine
+
+ MLIRLLVMIR
+ MLIRTargetLLVMIR
+ LLVMBitReader
+ LLVMBitWriter
+ LLVMExecutionEngine
+ LLVMOrcJIT
+ LLVMSupport
+ LLVMTransformUtils
+
+ ${outlibs})
diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
index 4450bf4d403..dbc59d0383a 100644
--- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -24,6 +24,9 @@
#include "mlir/IR/Module.h"
#include "mlir/Target/LLVMIR.h"
+#include "llvm/Bitcode/BitcodeReader.h"
+#include "llvm/Bitcode/BitcodeWriter.h"
+#include "llvm/ExecutionEngine/ObjectCache.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
@@ -36,215 +39,59 @@
#include "llvm/Support/TargetRegistry.h"
using namespace mlir;
+using llvm::dbgs;
using llvm::Error;
+using llvm::errs;
using llvm::Expected;
+using llvm::LLVMContext;
+using llvm::MemoryBuffer;
+using llvm::MemoryBufferRef;
+using llvm::Module;
+using llvm::SectionMemoryManager;
+using llvm::StringError;
+using llvm::Triple;
+using llvm::orc::DynamicLibrarySearchGenerator;
+using llvm::orc::ExecutionSession;
+using llvm::orc::IRCompileLayer;
+using llvm::orc::JITTargetMachineBuilder;
+using llvm::orc::RTDyldObjectLinkingLayer;
+using llvm::orc::ThreadSafeModule;
+using llvm::orc::TMOwningSimpleCompiler;
-namespace {
-// Memory manager for the JIT's objectLayer. Its main goal is to fallback to
-// resolving functions in the current process if they cannot be resolved in the
-// JIT-compiled modules.
-class MemoryManager : public llvm::SectionMemoryManager {
-public:
- MemoryManager(llvm::orc::ExecutionSession &execSession)
- : session(execSession) {}
-
- // Resolve the named symbol. First, try looking it up in the main library of
- // the execution session. If there is no such symbol, try looking it up in
- // the current process (for example, if it is a standard library function).
- // Return `nullptr` if lookup fails.
- llvm::JITSymbol findSymbol(const std::string &name) override {
- auto mainLibSymbol = session.lookup({&session.getMainJITDylib()}, name);
- if (mainLibSymbol)
- return mainLibSymbol.get();
- auto address = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name);
- if (!address) {
- llvm::errs() << "Could not look up: " << name << '\n';
- return nullptr;
- }
- return llvm::JITSymbol(address, llvm::JITSymbolFlags::Exported);
- }
-
-private:
- llvm::orc::ExecutionSession &session;
-};
-} // end anonymous namespace
+// Wrap a string into an llvm::StringError.
+static inline Error make_string_error(const llvm::Twine &message) {
+ return llvm::make_error<StringError>(message.str(),
+ llvm::inconvertibleErrorCode());
+}
namespace mlir {
-namespace impl {
-
-/// Wrapper class around DynamicLibrarySearchGenerator to allow searching
-/// in-process symbols that have not been explicitly exported.
-/// This first tries to resolve a symbol by using DynamicLibrarySearchGenerator.
-/// For symbols that are not found this way, it then uses
-/// `llvm::sys::DynamicLibrary::SearchForAddressOfSymbol` to extract symbols
-/// that have been explicitly added with `llvm::sys::DynamicLibrary::AddSymbol`,
-/// previously.
-class SearchGenerator {
-public:
- SearchGenerator(char GlobalPrefix)
- : defaultGenerator(cantFail(
- llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
- GlobalPrefix))) {}
-
- // This function forwards to DynamicLibrarySearchGenerator::operator() and
- // adds an extra resolution for names explicitly registered via
- // `llvm::sys::DynamicLibrary::AddSymbol`.
- Expected<llvm::orc::SymbolNameSet>
- operator()(llvm::orc::JITDylib &JD, const llvm::orc::SymbolNameSet &Names) {
- auto res = defaultGenerator->tryToGenerate(JD, Names);
- if (!res)
- return res;
- llvm::orc::SymbolMap newSymbols;
- for (auto &Name : Names) {
- if (res.get().count(Name) > 0)
- continue;
- res.get().insert(Name);
- auto addedSymbolAddress =
- llvm::sys::DynamicLibrary::SearchForAddressOfSymbol(*Name);
- if (!addedSymbolAddress)
- continue;
- llvm::JITEvaluatedSymbol Sym(
- reinterpret_cast<uintptr_t>(addedSymbolAddress),
- llvm::JITSymbolFlags::Exported);
- newSymbols[Name] = Sym;
- }
- if (!newSymbols.empty())
- cantFail(JD.define(absoluteSymbols(std::move(newSymbols))));
- return res;
- }
-
-private:
- std::unique_ptr<llvm::orc::DynamicLibrarySearchGenerator> defaultGenerator;
-};
-
-// Simple layered Orc JIT compilation engine.
-class OrcJIT {
-public:
- using IRTransformer = std::function<Error(llvm::Module *)>;
-
- // Construct a JIT engine for the target host defined by `machineBuilder`,
- // using the data layout provided as `dataLayout`.
- // Setup the object layer to use our custom memory manager in order to
- // resolve calls to library functions present in the process.
- OrcJIT(llvm::orc::JITTargetMachineBuilder machineBuilder,
- llvm::DataLayout layout, IRTransformer transform,
- ArrayRef<StringRef> sharedLibPaths)
- : irTransformer(transform),
- objectLayer(
- session,
- [this]() { return std::make_unique<MemoryManager>(session); }),
- compileLayer(
- session, objectLayer,
- llvm::orc::ConcurrentIRCompiler(std::move(machineBuilder))),
- transformLayer(session, compileLayer, makeIRTransformFunction()),
- dataLayout(layout), mangler(session, this->dataLayout),
- threadSafeCtx(std::make_unique<llvm::LLVMContext>()) {
- session.getMainJITDylib().addGenerator(
- cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
- layout.getGlobalPrefix())));
- loadLibraries(sharedLibPaths);
- }
-
- // Create a JIT engine for the current host.
- static Expected<std::unique_ptr<OrcJIT>>
- createDefault(IRTransformer transformer, ArrayRef<StringRef> sharedLibPaths) {
- auto machineBuilder = llvm::orc::JITTargetMachineBuilder::detectHost();
- if (!machineBuilder)
- return machineBuilder.takeError();
-
- auto dataLayout = machineBuilder->getDefaultDataLayoutForTarget();
- if (!dataLayout)
- return dataLayout.takeError();
-
- return std::make_unique<OrcJIT>(std::move(*machineBuilder),
- std::move(*dataLayout), transformer,
- sharedLibPaths);
- }
-
- // Add an LLVM module to the main library managed by the JIT engine.
- Error addModule(std::unique_ptr<llvm::Module> M) {
- return transformLayer.add(
- session.getMainJITDylib(),
- llvm::orc::ThreadSafeModule(std::move(M), threadSafeCtx));
- }
-
- // Lookup a symbol in the main library managed by the JIT engine.
- Expected<llvm::JITEvaluatedSymbol> lookup(StringRef Name) {
- return session.lookup({&session.getMainJITDylib()}, mangler(Name.str()));
- }
-
-private:
- // Wrap the `irTransformer` into a function that can be called by the
- // IRTranformLayer. If `irTransformer` is not set up, return the module as
- // is without errors.
- llvm::orc::IRTransformLayer::TransformFunction makeIRTransformFunction() {
- return [this](llvm::orc::ThreadSafeModule module,
- const llvm::orc::MaterializationResponsibility &resp)
- -> Expected<llvm::orc::ThreadSafeModule> {
- (void)resp;
- if (!irTransformer)
- return std::move(module);
- Error err = module.withModuleDo(
- [this](llvm::Module &module) { return irTransformer(&module); });
- if (err)
- return std::move(err);
- return std::move(module);
- };
- }
-
- // Iterate over shareLibPaths and load the corresponding libraries for symbol
- // resolution.
- void loadLibraries(ArrayRef<StringRef> sharedLibPaths);
- IRTransformer irTransformer;
- llvm::orc::ExecutionSession session;
- llvm::orc::RTDyldObjectLinkingLayer objectLayer;
- llvm::orc::IRCompileLayer compileLayer;
- llvm::orc::IRTransformLayer transformLayer;
- llvm::DataLayout dataLayout;
- llvm::orc::MangleAndInterner mangler;
- llvm::orc::ThreadSafeContext threadSafeCtx;
-};
-} // end namespace impl
-} // namespace mlir
-
-void mlir::impl::OrcJIT::loadLibraries(ArrayRef<StringRef> sharedLibPaths) {
- for (auto libPath : sharedLibPaths) {
- auto mb = llvm::MemoryBuffer::getFile(libPath);
- if (!mb) {
- llvm::errs() << "Could not create MemoryBuffer for: " << libPath << " "
- << mb.getError().message() << "\n";
- continue;
- }
- auto &JD = session.createJITDylib(libPath);
- auto loaded = llvm::orc::DynamicLibrarySearchGenerator::Load(
- libPath.data(), dataLayout.getGlobalPrefix());
- if (!loaded) {
- llvm::errs() << "Could not load: " << libPath << " " << loaded.takeError()
- << "\n";
- continue;
- }
- JD.addGenerator(std::move(*loaded));
- auto res = objectLayer.add(JD, std::move(mb.get()));
- if (res)
- llvm::errs() << "Could not add: " << libPath << " " << res << "\n";
- }
+void SimpleObjectCache::notifyObjectCompiled(const Module *M,
+ MemoryBufferRef ObjBuffer) {
+ CachedObjects[M->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
+ ObjBuffer.getBuffer(), ObjBuffer.getBufferIdentifier());
}
-// Wrap a string into an llvm::StringError.
-static inline Error make_string_error(const llvm::Twine &message) {
- return llvm::make_error<llvm::StringError>(message.str(),
- llvm::inconvertibleErrorCode());
+std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *M) {
+ auto I = CachedObjects.find(M->getModuleIdentifier());
+ if (I == CachedObjects.end()) {
+ dbgs() << "No object for " << M->getModuleIdentifier()
+ << " in cache. Compiling.\n";
+ return nullptr;
+ }
+ dbgs() << "Object for " << M->getModuleIdentifier()
+ << " loaded from cache.\n";
+ return MemoryBuffer::getMemBuffer(I->second->getMemBufferRef());
}
// Setup LLVM target triple from the current machine.
-bool ExecutionEngine::setupTargetTriple(llvm::Module *llvmModule) {
+bool ExecutionEngine::setupTargetTriple(Module *llvmModule) {
// Setup the machine properties from the current architecture.
auto targetTriple = llvm::sys::getDefaultTargetTriple();
std::string errorMessage;
auto target = llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage);
if (!target) {
- llvm::errs() << "NO target: " << errorMessage << "\n";
+ errs() << "NO target: " << errorMessage << "\n";
return true;
}
auto machine =
@@ -261,7 +108,7 @@ static std::string makePackedFunctionName(StringRef name) {
// For each function in the LLVM module, define an interface function that wraps
// all the arguments of the original function and all its results into an i8**
// pointer to provide a unified invocation interface.
-void packFunctionArguments(llvm::Module *module) {
+void packFunctionArguments(Module *module) {
auto &ctx = module->getContext();
llvm::IRBuilder<> builder(ctx);
llvm::DenseSet<llvm::Function *> interfaceFunctions;
@@ -321,18 +168,13 @@ void packFunctionArguments(llvm::Module *module) {
}
}
-// Out of line for PIMPL unique_ptr.
-ExecutionEngine::~ExecutionEngine() = default;
-
Expected<std::unique_ptr<ExecutionEngine>>
ExecutionEngine::create(ModuleOp m,
- std::function<llvm::Error(llvm::Module *)> transformer,
+ std::function<Error(llvm::Module *)> transformer,
ArrayRef<StringRef> sharedLibPaths) {
auto engine = std::make_unique<ExecutionEngine>();
- auto expectedJIT = impl::OrcJIT::createDefault(transformer, sharedLibPaths);
- if (!expectedJIT)
- return expectedJIT.takeError();
+ std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
auto llvmModule = translateModuleToLLVMIR(m);
if (!llvmModule)
return make_string_error("could not convert to LLVM IR");
@@ -342,9 +184,77 @@ ExecutionEngine::create(ModuleOp m,
setupTargetTriple(llvmModule.get());
packFunctionArguments(llvmModule.get());
- if (auto err = (*expectedJIT)->addModule(std::move(llvmModule)))
- return std::move(err);
- engine->jit = std::move(*expectedJIT);
+ // Clone module in a new LLVMContext since translateModuleToLLVMIR buries
+ // ownership too deeply.
+ // TODO(zinenko): Reevaluate model of ownership of LLVMContext in LLVMDialect.
+ SmallVector<char, 1> buffer;
+ {
+ llvm::raw_svector_ostream os(buffer);
+ WriteBitcodeToFile(*llvmModule, os);
+ }
+ llvm::MemoryBufferRef bufferRef(llvm::StringRef(buffer.data(), buffer.size()),
+ "cloned module buffer");
+ auto expectedModule = parseBitcodeFile(bufferRef, *ctx);
+ if (!expectedModule)
+ return expectedModule.takeError();
+ std::unique_ptr<Module> deserModule = std::move(*expectedModule);
+
+ // Callback to create the object layer with symbol resolution to current
+ // process and dynamically linked libraries.
+ auto objectLinkingLayerCreator = [&](ExecutionSession &session,
+ const Triple &TT) {
+ auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
+ session, []() { return std::make_unique<SectionMemoryManager>(); });
+ auto dataLayout = deserModule->getDataLayout();
+
+ // Resolve symbols that are statically linked in the current process.
+ session.getMainJITDylib().addGenerator(
+ cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
+ dataLayout.getGlobalPrefix())));
+
+ // Resolve symbols from shared libraries.
+ for (auto libPath : sharedLibPaths) {
+ auto mb = llvm::MemoryBuffer::getFile(libPath);
+ if (!mb) {
+ errs() << "Fail to create MemoryBuffer for: " << libPath << "\n";
+ continue;
+ }
+ auto &JD = session.createJITDylib(libPath);
+ auto loaded = DynamicLibrarySearchGenerator::Load(
+ libPath.data(), dataLayout.getGlobalPrefix());
+ if (!loaded) {
+ errs() << "Could not load: " << libPath << "\n";
+ continue;
+ }
+ JD.addGenerator(std::move(*loaded));
+ cantFail(objectLayer->add(JD, std::move(mb.get())));
+ }
+
+ return objectLayer;
+ };
+
+ // Callback to inspect the cache and recompile on demand. This follows Lang's
+ // LLJITWithObjectCache example.
+ auto compileFunctionCreator = [&](JITTargetMachineBuilder JTMB)
+ -> Expected<IRCompileLayer::CompileFunction> {
+ auto TM = JTMB.createTargetMachine();
+ if (!TM)
+ return TM.takeError();
+ return IRCompileLayer::CompileFunction(
+ TMOwningSimpleCompiler(std::move(*TM), engine->cache.get()));
+ };
+
+ // Create the LLJIT by calling the LLJITBuilder with 2 callbacks.
+ auto jit =
+ cantFail(llvm::orc::LLJITBuilder()
+ .setCompileFunctionCreator(compileFunctionCreator)
+ .setObjectLinkingLayerCreator(objectLinkingLayerCreator)
+ .create());
+
+ // Add a ThreadSafemodule to the engine and return.
+ ThreadSafeModule tsm(std::move(deserModule), std::move(ctx));
+ cantFail(jit->addIRModule(std::move(tsm)));
+ engine->jit = std::move(jit);
return std::move(engine);
}
@@ -360,8 +270,7 @@ Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
return fptr;
}
-llvm::Error ExecutionEngine::invoke(StringRef name,
- MutableArrayRef<void *> args) {
+Error ExecutionEngine::invoke(StringRef name, MutableArrayRef<void *> args) {
auto expectedFPtr = lookup(name);
if (!expectedFPtr)
return expectedFPtr.takeError();
@@ -369,5 +278,7 @@ llvm::Error ExecutionEngine::invoke(StringRef name,
(*fptr)(args.data());
- return llvm::Error::success();
+ return Error::success();
}
+
+} // end namespace mlir
OpenPOWER on IntegriCloud