summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Support
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2019-09-26 05:41:26 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-09-26 05:42:01 -0700
commit99be3351b874444498c03a87e2aeec6f2f8c208d (patch)
tree67e537ff443b8535a422942a24bee5674b082ea7 /mlir/lib/Support
parent116dac00baa6870aec2a2b469b2d6f95c2fbb316 (diff)
downloadbcm5719-llvm-99be3351b874444498c03a87e2aeec6f2f8c208d.tar.gz
bcm5719-llvm-99be3351b874444498c03a87e2aeec6f2f8c208d.zip
Drop support for memrefs from JitRunner
The support for functions taking and returning memrefs of floats was introduced in the first version of the runner, created before MLIR had reliable lowering of allocation/deallocation to library calls. It forcibly runs MLIR transformation convering affine, loop and standard dialects into the LLVM dialect, unlike the other runner flows that accept the LLVM dialect directly. Memref support leads to more complex layering and is generally fragile. Drop it in favor of functions returning a scalar, or library-based function calls to print memrefs and other data structures. PiperOrigin-RevId: 271330839
Diffstat (limited to 'mlir/lib/Support')
-rw-r--r--mlir/lib/Support/JitRunner.cpp113
1 files changed, 9 insertions, 104 deletions
diff --git a/mlir/lib/Support/JitRunner.cpp b/mlir/lib/Support/JitRunner.cpp
index f87664d621a..3324aa9da31 100644
--- a/mlir/lib/Support/JitRunner.cpp
+++ b/mlir/lib/Support/JitRunner.cpp
@@ -25,8 +25,6 @@
#include "mlir/Support/JitRunner.h"
-#include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
@@ -35,10 +33,7 @@
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
-#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
@@ -62,16 +57,13 @@ static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
static llvm::cl::opt<std::string>
- initValue("init-value", llvm::cl::desc("Initial value of MemRef elements"),
- llvm::cl::value_desc("<float value>"), llvm::cl::init("0.0"));
-static llvm::cl::opt<std::string>
mainFuncName("e", llvm::cl::desc("The function to be called"),
llvm::cl::value_desc("<function name>"),
llvm::cl::init("main"));
static llvm::cl::opt<std::string> mainFuncType(
"entry-point-result",
llvm::cl::desc("Textual description of the function type to be called"),
- llvm::cl::value_desc("f32 | memrefs | void"), llvm::cl::init("memrefs"));
+ llvm::cl::value_desc("f32 | void"), llvm::cl::init("f32"));
static llvm::cl::OptionCategory optFlags("opt-like flags");
@@ -136,52 +128,6 @@ static inline Error make_string_error(const llvm::Twine &message) {
llvm::inconvertibleErrorCode());
}
-static void printOneMemRef(Type t, void *val) {
- auto memRefType = t.cast<MemRefType>();
- auto shape = memRefType.getShape();
- int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
- std::multiplies<int64_t>());
- for (int64_t i = 0; i < size; ++i) {
- llvm::outs() << reinterpret_cast<StaticFloatMemRef *>(val)->data[i] << ' ';
- }
- llvm::outs() << '\n';
-}
-
-static void printMemRefArguments(ArrayRef<Type> argTypes,
- ArrayRef<Type> resTypes,
- ArrayRef<void *> args) {
- auto properArgs = args.take_front(argTypes.size());
- for (const auto &kvp : llvm::zip(argTypes, properArgs)) {
- auto type = std::get<0>(kvp);
- auto val = std::get<1>(kvp);
- printOneMemRef(type, val);
- }
-
- auto results = args.drop_front(argTypes.size());
- for (const auto &kvp : llvm::zip(resTypes, results)) {
- auto type = std::get<0>(kvp);
- auto val = std::get<1>(kvp);
- printOneMemRef(type, val);
- }
-}
-
-// Calls the passes necessary to convert affine and standard dialects to the
-// LLVM IR dialect.
-// Currently, these passes are:
-// - CSE
-// - canonicalization
-// - affine to standard lowering
-// - standard to llvm lowering
-static LogicalResult convertAffineStandardToLLVMIR(ModuleOp module) {
- PassManager manager(module.getContext());
- manager.addPass(mlir::createCanonicalizerPass());
- manager.addPass(mlir::createCSEPass());
- manager.addPass(mlir::createLowerAffinePass());
- manager.addPass(mlir::createLowerToCFGPass());
- manager.addPass(mlir::createLowerToLLVMPass());
- return manager.run(module);
-}
-
static llvm::Optional<unsigned> getCommandLineOptLevel() {
llvm::Optional<unsigned> optLevel;
llvm::SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
@@ -238,40 +184,6 @@ static Error compileAndExecuteVoidFunction(
return compileAndExecute(module, entryPoint, transformer, &empty);
}
-static Error compileAndExecuteFunctionWithMemRefs(
- ModuleOp module, StringRef entryPoint,
- std::function<llvm::Error(llvm::Module *)> transformer) {
- FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
- if (!mainFunction || mainFunction.getBlocks().empty()) {
- return make_string_error("entry point not found");
- }
-
- // Store argument and result types of the original function necessary to
- // pretty print the results, because the function itself will be rewritten
- // to use the LLVM dialect.
- SmallVector<Type, 8> argTypes =
- llvm::to_vector<8>(mainFunction.getType().getInputs());
- SmallVector<Type, 8> resTypes =
- llvm::to_vector<8>(mainFunction.getType().getResults());
-
- float init = std::stof(initValue.getValue());
-
- auto expectedArguments = allocateMemRefArguments(mainFunction, init);
- if (!expectedArguments)
- return expectedArguments.takeError();
-
- if (failed(convertAffineStandardToLLVMIR(module)))
- return make_string_error("conversion to the LLVM IR dialect failed");
-
- if (auto error = compileAndExecute(module, entryPoint, transformer,
- expectedArguments->data()))
- return error;
-
- printMemRefArguments(argTypes, resTypes, *expectedArguments);
- freeMemRefArguments(*expectedArguments);
- return Error::success();
-}
-
static Error compileAndExecuteSingleFloatReturnFunction(
ModuleOp module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
@@ -303,7 +215,7 @@ static Error compileAndExecuteSingleFloatReturnFunction(
return error;
// Intentional printing of the output so we can test.
- llvm::outs() << res;
+ llvm::outs() << res << '\n';
return Error::success();
}
@@ -372,20 +284,13 @@ int mlir::JitRunnerMain(
auto transformer = mlir::makeLLVMPassesTransformer(
passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
- // Get the function used to compile and execute the module.
- using CompileAndExecuteFnT = Error (*)(
- ModuleOp, StringRef, std::function<llvm::Error(llvm::Module *)>);
- auto compileAndExecuteFn =
- llvm::StringSwitch<CompileAndExecuteFnT>(mainFuncType.getValue())
- .Case("f32", compileAndExecuteSingleFloatReturnFunction)
- .Case("memrefs", compileAndExecuteFunctionWithMemRefs)
- .Case("void", compileAndExecuteVoidFunction)
- .Default(nullptr);
-
- Error error =
- compileAndExecuteFn
- ? compileAndExecuteFn(m.get(), mainFuncName.getValue(), transformer)
- : make_string_error("unsupported function type");
+ Error error = make_string_error("unsupported function type");
+ if (mainFuncType.getValue() == "f32")
+ error = compileAndExecuteSingleFloatReturnFunction(
+ m.get(), mainFuncName.getValue(), transformer);
+ else if (mainFuncType.getValue() == "void")
+ error = compileAndExecuteVoidFunction(m.get(), mainFuncName.getValue(),
+ transformer);
int exitCode = EXIT_SUCCESS;
llvm::handleAllErrors(std::move(error),
OpenPOWER on IntegriCloud