diff options
Diffstat (limited to 'mlir/examples/toy/Ch7/toyc.cpp')
| -rw-r--r-- | mlir/examples/toy/Ch7/toyc.cpp | 282 |
1 files changed, 282 insertions, 0 deletions
diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp new file mode 100644 index 00000000000..26b684cbe2a --- /dev/null +++ b/mlir/examples/toy/Ch7/toyc.cpp @@ -0,0 +1,282 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the entry point for the Toy compiler. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt<std::string> inputFilename(cl::Positional, + cl::desc("<input toy file>"), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} +static cl::opt<enum InputType> inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { + None, + DumpAST, + DumpMLIR, + DumpMLIRAffine, + DumpMLIRLLVM, + DumpLLVMIR, + RunJIT +}; +} +static cl::opt<enum Action> emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), + cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", + "output the MLIR dump after affine lowering")), + cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm", + "output the MLIR dump after llvm lowering")), + cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")), + cl::values( + clEnumValN(RunJIT, "jit", + "JIT the code and run it by invoking the main function"))); + +static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations")); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), filename); + Parser parser(lexer); + return parser.parseModule(); +} + +int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) { + // Handle '.toy' input to the compiler. + if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).endswith(".mlir")) { + auto moduleAST = parseInputFile(inputFilename); + module = mlirGen(context, *moduleAST); + return !module ? 1 : 0; + } + + // Otherwise, the input is '.mlir'. + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code EC = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return -1; + } + + // Parse the input mlir. + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + return 0; +} + +int loadAndProcessMLIR(mlir::MLIRContext &context, + mlir::OwningModuleRef &module) { + if (int error = loadMLIR(context, module)) + return error; + + mlir::PassManager pm(&context); + // Apply any generic pass manager command line options and run the pipeline. + applyPassManagerCLOptions(pm); + + // Check to see what granularity of MLIR we are compiling to. + bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; + bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM; + + if (enableOpt || isLoweringToAffine) { + // Inline all functions into main and then delete them. + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::toy::createDeadFunctionEliminationPass()); + + // Now that there is only one function, we can infer the shapes of each of + // the operations. + mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>(); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::toy::createShapeInferencePass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + } + + if (isLoweringToAffine) { + // Partially lower the toy dialect with a few cleanups afterwards. + pm.addPass(mlir::toy::createLowerToAffinePass()); + + mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>(); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + + // Add optimizations if enabled. + if (enableOpt) { + optPM.addPass(mlir::createLoopFusionPass()); + optPM.addPass(mlir::createMemRefDataFlowOptPass()); + } + } + + if (isLoweringToLLVM) { + // Finish lowering the toy IR to the LLVM dialect. + pm.addPass(mlir::toy::createLowerToLLVMPass()); + } + + if (mlir::failed(pm.run(*module))) + return 4; + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int dumpLLVMIR(mlir::ModuleOp module) { + auto llvmModule = mlir::translateModuleToLLVMIR(module); + if (!llvmModule) { + llvm::errs() << "Failed to emit LLVM IR\n"; + return -1; + } + + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::ExecutionEngine::setupTargetTriple(llvmModule.get()); + + /// Optionally run an optimization pipeline over the llvm module. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + if (auto err = optPipeline(llvmModule.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + return -1; + } + llvm::errs() << *llvmModule << "\n"; + return 0; +} + +int runJit(mlir::ModuleOp module) { + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + // An optimization pipeline to use within the execution engine. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + + // Create an MLIR execution engine. The execution engine eagerly JIT-compiles + // the module. + auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline); + assert(maybeEngine && "failed to construct an execution engine"); + auto &engine = maybeEngine.get(); + + // Invoke the JIT-compiled function. + auto invocationResult = engine->invoke("main"); + if (invocationResult) { + llvm::errs() << "JIT invocation failed\n"; + return -1; + } + + return 0; +} + +int main(int argc, char **argv) { + mlir::registerPassManagerCLOptions(); + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + if (emitAction == Action::DumpAST) + return dumpAST(); + + // If we aren't dumping the AST, then we are compiling with/to MLIR. + + // Register our Dialect with MLIR. + mlir::registerDialect<mlir::toy::ToyDialect>(); + + mlir::MLIRContext context; + mlir::OwningModuleRef module; + if (int error = loadAndProcessMLIR(context, module)) + return error; + + // If we aren't exporting to non-mlir, then we are done. + bool isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM; + if (isOutputingMLIR) { + module->dump(); + return 0; + } + + // Check to see if we are compiling to LLVM IR. + if (emitAction == Action::DumpLLVMIR) + return dumpLLVMIR(*module); + + // Otherwise, we must be running the jit. + if (emitAction == Action::RunJIT) + return runJit(*module); + + llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n"; + return -1; +} |

