summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Support
diff options
context:
space:
mode:
authorStephan Herhut <herhut@google.com>2019-07-16 05:28:56 -0700
committerMehdi Amini <joker.eph@gmail.com>2019-07-16 13:45:16 -0700
commit6760ea53386a9515f464dd4d6e270e5620601d22 (patch)
treef63925f0822f39c4b653f4ae0506016b82d5f581 /mlir/lib/Support
parentd36dd94c752efc3a962481ea615225301582efce (diff)
downloadbcm5719-llvm-6760ea53386a9515f464dd4d6e270e5620601d22.tar.gz
bcm5719-llvm-6760ea53386a9515f464dd4d6e270e5620601d22.zip
Move shared cpu runner library to Support/JitRunner.
PiperOrigin-RevId: 258347825
Diffstat (limited to 'mlir/lib/Support')
-rw-r--r--mlir/lib/Support/CMakeLists.txt17
-rw-r--r--mlir/lib/Support/JitRunner.cpp328
2 files changed, 345 insertions, 0 deletions
diff --git a/mlir/lib/Support/CMakeLists.txt b/mlir/lib/Support/CMakeLists.txt
index e725f708f8c..6edfc76d29d 100644
--- a/mlir/lib/Support/CMakeLists.txt
+++ b/mlir/lib/Support/CMakeLists.txt
@@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
FileUtilities.cpp
+ JitRunner.cpp
MlirOptMain.cpp
StorageUniquer.cpp
TranslateClParser.cpp
@@ -38,3 +39,19 @@ add_llvm_library(MLIRTranslateClParser
${MLIR_MAIN_INCLUDE_DIR}/mlir/Support
)
target_link_libraries(MLIRTranslateClParser LLVMSupport)
+
+add_llvm_library(MLIRJitRunner
+ JitRunner.cpp
+)
+target_link_libraries(MLIRJitRunner PRIVATE
+ MLIRExecutionEngine
+ MLIRIR
+ MLIRParser
+ MLIRStandardOps
+ MLIRTargetLLVMIR
+ MLIRTransforms
+ MLIRStandardToLLVM
+ MLIRSupport
+ LLVMCore
+ LLVMSupport
+)
diff --git a/mlir/lib/Support/JitRunner.cpp b/mlir/lib/Support/JitRunner.cpp
new file mode 100644
index 00000000000..1c6df7c5be8
--- /dev/null
+++ b/mlir/lib/Support/JitRunner.cpp
@@ -0,0 +1,328 @@
+//===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is a library that provides a shared implementation for command line
+// utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
+// IR before JIT-compiling and executing the latter.
+//
+// The translation can be customized by providing an MLIR to MLIR
+// transformation.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/JitRunner.h"
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/MemRefUtils.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassNameParser.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/StringSaver.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include <numeric>
+
+using namespace mlir;
+using llvm::Error;
+
+static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+ llvm::cl::desc("<input file>"),
+ llvm::cl::init("-"));
+static llvm::cl::opt<std::string>
+ initValue("init-value", llvm::cl::desc("Initial value of MemRef elements"),
+ llvm::cl::value_desc("<float value>"), llvm::cl::init("0.0"));
+static llvm::cl::opt<std::string>
+ mainFuncName("e", llvm::cl::desc("The function to be called"),
+ llvm::cl::value_desc("<function name>"),
+ llvm::cl::init("main"));
+static llvm::cl::opt<std::string> mainFuncType(
+ "entry-point-result",
+ llvm::cl::desc("Textual description of the function type to be called"),
+ llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs"));
+
+static llvm::cl::OptionCategory optFlags("opt-like flags");
+
+// CLI list of pass information
+static llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser>
+ llvmPasses(llvm::cl::desc("LLVM optimizing passes to run"),
+ llvm::cl::cat(optFlags));
+
+// CLI variables for -On options.
+static llvm::cl::opt<bool> optO0("O0", llvm::cl::desc("Run opt O0 passes"),
+ llvm::cl::cat(optFlags));
+static llvm::cl::opt<bool> optO1("O1", llvm::cl::desc("Run opt O1 passes"),
+ llvm::cl::cat(optFlags));
+static llvm::cl::opt<bool> optO2("O2", llvm::cl::desc("Run opt O2 passes"),
+ llvm::cl::cat(optFlags));
+static llvm::cl::opt<bool> optO3("O3", llvm::cl::desc("Run opt O3 passes"),
+ llvm::cl::cat(optFlags));
+
+static llvm::cl::OptionCategory clOptionsCategory("linking options");
+static llvm::cl::list<std::string>
+ clSharedLibs("shared-libs", llvm::cl::desc("Libraries to link dynamically"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::cat(clOptionsCategory));
+
+static OwningModuleRef parseMLIRInput(StringRef inputFilename,
+ MLIRContext *context) {
+ // Set up the input file.
+ std::string errorMessage;
+ auto file = openInputFile(inputFilename, &errorMessage);
+ if (!file) {
+ llvm::errs() << errorMessage << "\n";
+ return nullptr;
+ }
+
+ llvm::SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
+ return OwningModuleRef(parseSourceFile(sourceMgr, context));
+}
+
+// Initialize the relevant subsystems of LLVM.
+static void initializeLLVM() {
+ llvm::InitializeNativeTarget();
+ llvm::InitializeNativeTargetAsmPrinter();
+}
+
+static inline Error make_string_error(const llvm::Twine &message) {
+ return llvm::make_error<llvm::StringError>(message.str(),
+ llvm::inconvertibleErrorCode());
+}
+
+static void printOneMemRef(Type t, void *val) {
+ auto memRefType = t.cast<MemRefType>();
+ auto shape = memRefType.getShape();
+ int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
+ std::multiplies<int64_t>());
+ for (int64_t i = 0; i < size; ++i) {
+ llvm::outs() << reinterpret_cast<StaticFloatMemRef *>(val)->data[i] << ' ';
+ }
+ llvm::outs() << '\n';
+}
+
+static void printMemRefArguments(ArrayRef<Type> argTypes,
+ ArrayRef<Type> resTypes,
+ ArrayRef<void *> args) {
+ auto properArgs = args.take_front(argTypes.size());
+ for (const auto &kvp : llvm::zip(argTypes, properArgs)) {
+ auto type = std::get<0>(kvp);
+ auto val = std::get<1>(kvp);
+ printOneMemRef(type, val);
+ }
+
+ auto results = args.drop_front(argTypes.size());
+ for (const auto &kvp : llvm::zip(resTypes, results)) {
+ auto type = std::get<0>(kvp);
+ auto val = std::get<1>(kvp);
+ printOneMemRef(type, val);
+ }
+}
+
+// Calls the passes necessary to convert affine and standard dialects to the
+// LLVM IR dialect.
+// Currently, these passes are:
+// - CSE
+// - canonicalization
+// - affine to standard lowering
+// - standard to llvm lowering
+static LogicalResult convertAffineStandardToLLVMIR(ModuleOp module) {
+ PassManager manager;
+ manager.addPass(mlir::createCanonicalizerPass());
+ manager.addPass(mlir::createCSEPass());
+ manager.addPass(mlir::createLowerAffinePass());
+ manager.addPass(mlir::createConvertToLLVMIRPass());
+ return manager.run(module);
+}
+
+static Error compileAndExecuteFunctionWithMemRefs(
+ ModuleOp module, StringRef entryPoint,
+ std::function<llvm::Error(llvm::Module *)> transformer) {
+ FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
+ if (!mainFunction || mainFunction.getBlocks().empty()) {
+ return make_string_error("entry point not found");
+ }
+
+ // Store argument and result types of the original function necessary to
+ // pretty print the results, because the function itself will be rewritten
+ // to use the LLVM dialect.
+ SmallVector<Type, 8> argTypes =
+ llvm::to_vector<8>(mainFunction.getType().getInputs());
+ SmallVector<Type, 8> resTypes =
+ llvm::to_vector<8>(mainFunction.getType().getResults());
+
+ float init = std::stof(initValue.getValue());
+
+ auto expectedArguments = allocateMemRefArguments(mainFunction, init);
+ if (!expectedArguments)
+ return expectedArguments.takeError();
+
+ if (failed(convertAffineStandardToLLVMIR(module)))
+ return make_string_error("conversion to the LLVM IR dialect failed");
+
+ SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
+ auto expectedEngine =
+ mlir::ExecutionEngine::create(module, transformer, libs);
+ if (!expectedEngine)
+ return expectedEngine.takeError();
+
+ auto engine = std::move(*expectedEngine);
+ auto expectedFPtr = engine->lookup(entryPoint);
+ if (!expectedFPtr)
+ return expectedFPtr.takeError();
+ void (*fptr)(void **) = *expectedFPtr;
+ (*fptr)(expectedArguments->data());
+ printMemRefArguments(argTypes, resTypes, *expectedArguments);
+ freeMemRefArguments(*expectedArguments);
+
+ return Error::success();
+}
+
+static Error compileAndExecuteSingleFloatReturnFunction(
+ ModuleOp module, StringRef entryPoint,
+ std::function<llvm::Error(llvm::Module *)> transformer) {
+ FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
+ if (!mainFunction || mainFunction.isExternal()) {
+ return make_string_error("entry point not found");
+ }
+
+ if (!mainFunction.getType().getInputs().empty())
+ return make_string_error("function inputs not supported");
+
+ if (mainFunction.getType().getResults().size() != 1)
+ return make_string_error("only single f32 function result supported");
+
+ auto t = mainFunction.getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
+ if (!t)
+ return make_string_error("only single llvm.f32 function result supported");
+ auto *llvmTy = t.getUnderlyingType();
+ if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
+ return make_string_error("only single llvm.f32 function result supported");
+
+ SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
+ auto expectedEngine =
+ mlir::ExecutionEngine::create(module, transformer, libs);
+ if (!expectedEngine)
+ return expectedEngine.takeError();
+
+ auto engine = std::move(*expectedEngine);
+ auto expectedFPtr = engine->lookup(entryPoint);
+ if (!expectedFPtr)
+ return expectedFPtr.takeError();
+ void (*fptr)(void **) = *expectedFPtr;
+
+ float res;
+ struct {
+ void *data;
+ } data;
+ data.data = &res;
+ (*fptr)((void **)&data);
+
+ // Intentional printing of the output so we can test.
+ llvm::outs() << res;
+
+ return Error::success();
+}
+
+// Entry point for all CPU runners. Expects the common argc/argv arguments for
+// standard C++ main functions and an mlirTransformer.
+// The latter is applied after parsing the input into MLIR IR and before passing
+// the MLIR module to the ExecutionEngine.
+int mlir::JitRunnerMain(
+ int argc, char **argv,
+ llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) {
+ llvm::PrettyStackTraceProgram x(argc, argv);
+ llvm::InitLLVM y(argc, argv);
+
+ initializeLLVM();
+ mlir::initializeLLVMPasses();
+
+ llvm::SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
+ optO0, optO1, optO2, optO3};
+
+ llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
+
+ llvm::SmallVector<const llvm::PassInfo *, 4> passes;
+ llvm::Optional<unsigned> optLevel;
+ unsigned optCLIPosition = 0;
+ // Determine if there is an optimization flag present, and its CLI position
+ // (optCLIPosition).
+ for (unsigned j = 0; j < 4; ++j) {
+ auto &flag = optFlags[j].get();
+ if (flag) {
+ optLevel = j;
+ optCLIPosition = flag.getPosition();
+ break;
+ }
+ }
+ // Generate vector of pass information, plus the index at which we should
+ // insert any optimization passes in that vector (optPosition).
+ unsigned optPosition = 0;
+ for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
+ passes.push_back(llvmPasses[i]);
+ if (optCLIPosition < llvmPasses.getPosition(i)) {
+ optPosition = i;
+ optCLIPosition = UINT_MAX; // To ensure we never insert again
+ }
+ }
+
+ MLIRContext context;
+ auto m = parseMLIRInput(inputFilename, &context);
+ if (!m) {
+ llvm::errs() << "could not parse the input IR\n";
+ return 1;
+ }
+
+ if (mlirTransformer)
+ if (failed(mlirTransformer(m.get())))
+ return EXIT_FAILURE;
+
+ auto transformer =
+ mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition);
+ auto error = mainFuncType.getValue() == "f32"
+ ? compileAndExecuteSingleFloatReturnFunction(
+ m.get(), mainFuncName.getValue(), transformer)
+ : compileAndExecuteFunctionWithMemRefs(
+ m.get(), mainFuncName.getValue(), transformer);
+ int exitCode = EXIT_SUCCESS;
+ llvm::handleAllErrors(std::move(error),
+ [&exitCode](const llvm::ErrorInfoBase &info) {
+ llvm::errs() << "Error: ";
+ info.log(llvm::errs());
+ llvm::errs() << '\n';
+ exitCode = EXIT_FAILURE;
+ });
+
+ return exitCode;
+}
OpenPOWER on IntegriCloud