diff options
Diffstat (limited to 'mlir/examples/toy/Ch7/mlir/MLIRGen.cpp')
-rw-r--r-- | mlir/examples/toy/Ch7/mlir/MLIRGen.cpp | 674 |
1 files changed, 674 insertions, 0 deletions
diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp new file mode 100644 index 00000000000..3d543f69bdc --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp @@ -0,0 +1,674 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/Support/raw_ostream.h" +#include <numeric> + +using namespace mlir::toy; +using namespace toy; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::makeArrayRef; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. + theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (auto &record : moduleAST) { + if (FunctionAST *funcAST = llvm::dyn_cast<FunctionAST>(record.get())) { + auto func = mlirGen(*funcAST); + if (!func) + return nullptr; + + theModule.push_back(func); + functionMap.insert({func.getName(), func}); + } else if (StructAST *str = llvm::dyn_cast<StructAST>(record.get())) { + if (failed(mlirGen(*str))) + return nullptr; + } else { + llvm_unreachable("unknown record type"); + } + } + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + +private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable<StringRef, std::pair<mlir::Value, VarDeclExprAST *>> + symbolTable; + using SymbolTableScopeT = + llvm::ScopedHashTableScope<StringRef, + std::pair<mlir::Value, VarDeclExprAST *>>; + + /// A mapping for the functions that have been code generated to MLIR. + llvm::StringMap<mlir::FuncOp> functionMap; + + /// A mapping for named struct types to the underlying MLIR type and the + /// original AST node. + llvm::StringMap<std::pair<mlir::Type, StructAST *>> structMap; + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::Location loc(Location loc) { + return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + loc.col); + } + + /// Declare a variable in the current scope, return success if the variable + /// wasn't declared yet. + mlir::LogicalResult declare(VarDeclExprAST &var, mlir::Value value) { + if (symbolTable.count(var.getName())) + return mlir::failure(); + symbolTable.insert(var.getName(), {value, &var}); + return mlir::success(); + } + + /// Create an MLIR type for the given struct. + mlir::LogicalResult mlirGen(StructAST &str) { + if (structMap.count(str.getName())) + return emitError(loc(str.loc())) << "error: struct type with name `" + << str.getName() << "' already exists"; + + auto variables = str.getVariables(); + std::vector<mlir::Type> elementTypes; + elementTypes.reserve(variables.size()); + for (auto &variable : variables) { + if (variable->getInitVal()) + return emitError(loc(variable->loc())) + << "error: variables within a struct definition must not have " + "initializers"; + if (!variable->getType().shape.empty()) + return emitError(loc(variable->loc())) + << "error: variables within a struct definition must not have " + "initializers"; + + mlir::Type type = getType(variable->getType(), variable->loc()); + if (!type) + return mlir::failure(); + elementTypes.push_back(type); + } + + structMap.try_emplace(str.getName(), StructType::get(elementTypes), &str); + return mlir::success(); + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + + // This is a generic function, the return type will be inferred later. + llvm::SmallVector<mlir::Type, 4> argTypes; + argTypes.reserve(proto.getArgs().size()); + for (auto &arg : proto.getArgs()) { + mlir::Type type = getType(arg->getType(), arg->loc()); + if (!type) + return nullptr; + argTypes.push_back(type); + } + auto func_type = builder.getFunctionType(argTypes, llvm::None); + return mlir::FuncOp::create(location, proto.getName(), func_type); + } + + /// Emit a new function and add it to the MLIR module. + mlir::FuncOp mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. + SymbolTableScopeT var_scope(symbolTable); + + // Create an MLIR function for the given prototype. + mlir::FuncOp function(mlirGen(*funcAST.getProto())); + if (!function) + return nullptr; + + // Let's start the body of the function now! + // In MLIR the entry block of the function is special: it must have the same + // argument list as the function itself. + auto &entryBlock = *function.addEntryBlock(); + auto protoArgs = funcAST.getProto()->getArgs(); + + // Declare all the function arguments in the symbol table. + for (const auto &name_value : + llvm::zip(protoArgs, entryBlock.getArguments())) { + if (failed(declare(*std::get<0>(name_value), std::get<1>(name_value)))) + return nullptr; + } + + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); + + // Emit the body of the function. + if (mlir::failed(mlirGen(*funcAST.getBody()))) { + function.erase(); + return nullptr; + } + + // Implicitly return void if no return statement was emitted. + // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast<ReturnOp>(entryBlock.back()); + if (!returnOp) { + builder.create<ReturnOp>(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType(function.getType().getInputs(), + *returnOp.operand_type_begin())); + } + + return function; + } + + /// Return the struct type that is the result of the given expression, or null + /// if it cannot be inferred. + StructAST *getStructFor(ExprAST *expr) { + llvm::StringRef structName; + if (auto *decl = llvm::dyn_cast<VariableExprAST>(expr)) { + auto varIt = symbolTable.lookup(decl->getName()); + if (!varIt.first) + return nullptr; + structName = varIt.second->getType().name; + } else if (auto *access = llvm::dyn_cast<BinaryExprAST>(expr)) { + if (access->getOp() != '.') + return nullptr; + // The name being accessed should be in the RHS. + auto *name = llvm::dyn_cast<VariableExprAST>(access->getRHS()); + if (!name) + return nullptr; + StructAST *parentStruct = getStructFor(access->getLHS()); + if (!parentStruct) + return nullptr; + + // Get the element within the struct corresponding to the name. + VarDeclExprAST *decl = nullptr; + for (auto &var : parentStruct->getVariables()) { + if (var->getName() == name->getName()) { + decl = var.get(); + break; + } + } + if (!decl) + return nullptr; + structName = decl->getType().name; + } + if (structName.empty()) + return nullptr; + + // If the struct name was valid, check for an entry in the struct map. + auto structIt = structMap.find(structName); + if (structIt == structMap.end()) + return nullptr; + return structIt->second.second; + } + + /// Return the numeric member index of the given struct access expression. + llvm::Optional<size_t> getMemberIndex(BinaryExprAST &accessOp) { + assert(accessOp.getOp() == '.' && "expected access operation"); + + // Lookup the struct node for the LHS. + StructAST *structAST = getStructFor(accessOp.getLHS()); + if (!structAST) + return llvm::None; + + // Get the name from the RHS. + VariableExprAST *name = llvm::dyn_cast<VariableExprAST>(accessOp.getRHS()); + if (!name) + return llvm::None; + + auto structVars = structAST->getVariables(); + auto it = llvm::find_if(structVars, [&](auto &var) { + return var->getName() == name->getName(); + }); + if (it == structVars.end()) + return llvm::None; + return it - structVars.begin(); + } + + /// Emit a binary operation + mlir::Value mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value lhs = mlirGen(*binop.getLHS()); + if (!lhs) + return nullptr; + auto location = loc(binop.loc()); + + // If this is an access operation, handle it immediately. + if (binop.getOp() == '.') { + llvm::Optional<size_t> accessIndex = getMemberIndex(binop); + if (!accessIndex) { + emitError(location, "invalid access into struct expression"); + return nullptr; + } + return builder.create<StructAccessOp>(location, lhs, *accessIndex); + } + + // Otherwise, this is a normal binary op. + mlir::Value rhs = mlirGen(*binop.getRHS()); + if (!rhs) + return nullptr; + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. + switch (binop.getOp()) { + case '+': + return builder.create<AddOp>(location, lhs, rhs); + case '*': + return builder.create<MulOp>(location, lhs, rhs); + } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; + } + + /// This is a reference to a variable in an expression. The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. + mlir::Value mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName()).first) + return variable; + + emitError(loc(expr.loc()), "error: unknown variable '") + << expr.getName() << "'"; + return nullptr; + } + + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + + // 'return' takes an optional expression, handle that case here. + mlir::Value expr = nullptr; + if (ret.getExpr().hasValue()) { + if (!(expr = mlirGen(*ret.getExpr().getValue()))) + return mlir::failure(); + } + + // Otherwise, this return operation has zero operands. + builder.create<ReturnOp>(location, expr ? makeArrayRef(expr) + : ArrayRef<mlir::Value>()); + return mlir::success(); + } + + /// Emit a constant for a literal/constant array. It will be emitted as a + /// flattened array of data in an Attribute attached to a `toy.constant` + /// operation. See documentation on [Attributes](LangRef.md#attributes) for + /// more details. Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. + /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// + mlir::DenseElementsAttr getConstantAttr(LiteralExprAST &lit) { + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector<double> data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies<int>())); + collectData(lit, data); + + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + return mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data)); + } + mlir::DenseElementsAttr getConstantAttr(NumberExprAST &lit) { + // The type of this attribute is tensor of 64-bit floating-point with no + // shape. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get({}, elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + return mlir::DenseElementsAttr::get(dataType, + llvm::makeArrayRef(lit.getValue())); + } + /// Emit a constant for a struct literal. It will be emitted as an array of + /// other literals in an Attribute attached to a `toy.struct_constant` + /// operation. This function returns the generated constant, along with the + /// corresponding struct type. + std::pair<mlir::ArrayAttr, mlir::Type> + getConstantAttr(StructLiteralExprAST &lit) { + std::vector<mlir::Attribute> attrElements; + std::vector<mlir::Type> typeElements; + + for (auto &var : lit.getValues()) { + if (auto *number = llvm::dyn_cast<NumberExprAST>(var.get())) { + attrElements.push_back(getConstantAttr(*number)); + typeElements.push_back(getType(llvm::None)); + } else if (auto *lit = llvm::dyn_cast<LiteralExprAST>(var.get())) { + attrElements.push_back(getConstantAttr(*lit)); + typeElements.push_back(getType(llvm::None)); + } else { + auto *structLit = llvm::cast<StructLiteralExprAST>(var.get()); + auto attrTypePair = getConstantAttr(*structLit); + attrElements.push_back(attrTypePair.first); + typeElements.push_back(attrTypePair.second); + } + } + mlir::ArrayAttr dataAttr = builder.getArrayAttr(attrElements); + mlir::Type dataType = StructType::get(typeElements); + return std::make_pair(dataAttr, dataType); + } + + /// Emit an array literal. + mlir::Value mlirGen(LiteralExprAST &lit) { + mlir::Type type = getType(lit.getDims()); + mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit); + + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute); + } + + /// Emit a struct literal. It will be emitted as an array of + /// other literals in an Attribute attached to a `toy.struct_constant` + /// operation. + mlir::Value mlirGen(StructLiteralExprAST &lit) { + mlir::ArrayAttr dataAttr; + mlir::Type dataType; + std::tie(dataAttr, dataType) = getConstantAttr(lit); + + // Build the MLIR op `toy.struct_constant`. This invokes the + // `StructConstantOp::build` method. + return builder.create<StructConstantOp>(loc(lit.loc()), dataType, dataAttr); + } + + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector<double> &data) { + if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + + assert(isa<NumberExprAST>(expr) && "expected literal or number expr"); + data.push_back(cast<NumberExprAST>(expr).getValue()); + } + + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. Other identifiers are assumed to be user-defined functions. + mlir::Value mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); + auto location = loc(call.loc()); + + // Codegen the operands first. + SmallVector<mlir::Value, 4> operands; + for (auto &expr : call.getArgs()) { + auto arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + + // Builting calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create<TransposeOp>(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to ser-defined + // functions are mapped to a custom call that takes the callee name as an + // attribute. + auto calledFuncIt = functionMap.find(callee); + if (calledFuncIt == functionMap.end()) { + emitError(location) << "no defined function found for '" << callee << "'"; + return nullptr; + } + mlir::FuncOp calledFunc = calledFuncIt->second; + return builder.create<GenericCallOp>( + location, calledFunc.getType().getResult(0), + builder.getSymbolRefAttr(callee), operands); + } + + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). + mlir::LogicalResult mlirGen(PrintExprAST &call) { + auto arg = mlirGen(*call.getArg()); + if (!arg) + return mlir::failure(); + + builder.create<PrintOp>(loc(call.loc()), arg); + return mlir::success(); + } + + /// Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value mlirGen(NumberExprAST &num) { + return builder.create<ConstantOp>(loc(num.loc()), num.getValue()); + } + + /// Dispatch codegen for the right expression subclass using RTTI. + mlir::Value mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast<BinaryExprAST>(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast<VariableExprAST>(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast<LiteralExprAST>(expr)); + case toy::ExprAST::Expr_StructLiteral: + return mlirGen(cast<StructLiteralExprAST>(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast<CallExprAST>(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast<NumberExprAST>(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; + } + } + + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. + mlir::Value mlirGen(VarDeclExprAST &vardecl) { + auto init = vardecl.getInitVal(); + if (!init) { + emitError(loc(vardecl.loc()), + "missing initializer in variable declaration"); + return nullptr; + } + + mlir::Value value = mlirGen(*init); + if (!value) + return nullptr; + + // Handle the case where we are initializing a struct value. + VarType varType = vardecl.getType(); + if (!varType.name.empty()) { + // Check that the initializer type is the same as the variable + // declaration. + mlir::Type type = getType(varType, vardecl.loc()); + if (!type) + return nullptr; + if (type != value->getType()) { + emitError(loc(vardecl.loc())) + << "struct type of initializer is different than the variable " + "declaration. Got " + << value->getType() << ", but expected " << type; + return nullptr; + } + + // Otherwise, we have the initializer value, but in case the variable was + // declared with specific shape, we emit a "reshape" operation. It will + // get optimized out later as needed. + } else if (!varType.shape.empty()) { + value = builder.create<ReshapeOp>(loc(vardecl.loc()), + getType(varType.shape), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl, value))) + return nullptr; + return value; + } + + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + SymbolTableScopeT var_scope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. + if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) { + if (!mlirGen(*vardecl)) + return mlir::failure(); + continue; + } + if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) + return mlirGen(*ret); + if (auto *print = dyn_cast<PrintExprAST>(expr.get())) { + if (mlir::failed(mlirGen(*print))) + return mlir::success(); + continue; + } + + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return mlir::failure(); + } + return mlir::success(); + } + + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef<int64_t> shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return mlir::UnrankedTensorType::get(builder.getF64Type()); + + // Otherwise, we use the given shape. + return mlir::RankedTensorType::get(shape, builder.getF64Type()); + } + + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above for non-struct types). + mlir::Type getType(const VarType &type, const Location &location) { + if (!type.name.empty()) { + auto it = structMap.find(type.name); + if (it == structMap.end()) { + emitError(loc(location)) + << "error: unknown struct type '" << type.name << "'"; + return nullptr; + } + return it->second.first; + } + + return getType(type.shape); + } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy |