diff options
36 files changed, 5428 insertions, 1 deletions
diff --git a/mlir/examples/toy/CMakeLists.txt b/mlir/examples/toy/CMakeLists.txt index 52a22014b4e..56002b1ad2e 100644 --- a/mlir/examples/toy/CMakeLists.txt +++ b/mlir/examples/toy/CMakeLists.txt @@ -12,3 +12,4 @@ add_subdirectory(Ch3) add_subdirectory(Ch4) add_subdirectory(Ch5) add_subdirectory(Ch6) +add_subdirectory(Ch7) diff --git a/mlir/examples/toy/Ch7/CMakeLists.txt b/mlir/examples/toy/Ch7/CMakeLists.txt new file mode 100644 index 00000000000..fc26425f038 --- /dev/null +++ b/mlir/examples/toy/Ch7/CMakeLists.txt @@ -0,0 +1,51 @@ +add_subdirectory(include) + +set(LLVM_LINK_COMPONENTS + Core + Support + ) + +set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) +mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include") +add_public_tablegen_target(ToyCh7CombineIncGen) + +add_toy_chapter(toyc-ch7 + toyc.cpp + parser/AST.cpp + mlir/MLIRGen.cpp + mlir/Dialect.cpp + mlir/DeadFunctionEliminationPass.cpp + mlir/LowerToAffineLoops.cpp + mlir/LowerToLLVM.cpp + mlir/ShapeInferencePass.cpp + mlir/ToyCombine.cpp + ) + +add_dependencies(toyc-ch7 ToyCh7ShapeInferenceInterfaceIncGen) +add_dependencies(toyc-ch7 ToyCh7OpsIncGen) +add_dependencies(toyc-ch7 ToyCh7CombineIncGen) +add_dependencies(toyc-ch7 MLIRCallOpInterfacesIncGen) +include_directories(include/) +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) +target_link_libraries(toyc-ch7 + PRIVATE + MLIRAffineOps + MLIRAnalysis + MLIRExecutionEngine + MLIRIR + MLIRLLVMIR + MLIRLoopToStandard + MLIRParser + MLIRPass + MLIRStandardOps + MLIRStandardToLLVM + MLIRTargetLLVMIR + MLIRTransforms + ) + +whole_archive_link(toyc-ch7 + MLIRAffineOps + MLIRLLVMIR + MLIRStandardOps + ) diff --git a/mlir/examples/toy/Ch7/include/CMakeLists.txt b/mlir/examples/toy/Ch7/include/CMakeLists.txt new file mode 100644 index 00000000000..37c89d0bae9 --- /dev/null +++ b/mlir/examples/toy/Ch7/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy) diff --git 
a/mlir/examples/toy/Ch7/include/toy/AST.h b/mlir/examples/toy/Ch7/include/toy/AST.h new file mode 100644 index 00000000000..558d9deab8e --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/AST.h @@ -0,0 +1,317 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_AST_H_ +#define MLIR_TUTORIAL_TOY_AST_H_ + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include <vector> + +namespace toy { + +/// A variable type with either name or shape information. +struct VarType { + std::string name; + std::vector<int64_t> shape; +}; + +/// Base class for all expression nodes. 
+class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_StructLiteral, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(location) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector<std::unique_ptr<ExprAST>>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double Val; + +public: + NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} + + double getValue() { return Val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. +class LiteralExprAST : public ExprAST { + std::vector<std::unique_ptr<ExprAST>> values; + std::vector<int64_t> dims; + +public: + LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values, + std::vector<int64_t> dims) + : ExprAST(Expr_Literal, loc), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; } + llvm::ArrayRef<int64_t> getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for a literal struct value. 
+class StructLiteralExprAST : public ExprAST { + std::vector<std::unique_ptr<ExprAST>> values; + +public: + StructLiteralExprAST(Location loc, + std::vector<std::unique_ptr<ExprAST>> values) + : ExprAST(Expr_StructLiteral, loc), values(std::move(values)) {} + + llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { + return c->getKind() == Expr_StructLiteral; + } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, loc), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. +class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr<ExprAST> initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr<ExprAST> initVal = nullptr) + : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), + initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. 
+class ReturnExprAST : public ExprAST { + llvm::Optional<std::unique_ptr<ExprAST>> expr; + +public: + ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr) + : ExprAST(Expr_Return, loc), expr(std::move(expr)) {} + + llvm::Optional<ExprAST *> getExpr() { + if (expr.hasValue()) + return expr->get(); + return llvm::None; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr<ExprAST> lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs, + std::unique_ptr<ExprAST> rhs) + : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. +class CallExprAST : public ExprAST { + std::string callee; + std::vector<std::unique_ptr<ExprAST>> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector<std::unique_ptr<ExprAST>> args) + : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. 
+class PrintExprAST : public ExprAST { + std::unique_ptr<ExprAST> arg; + +public: + PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg) + : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector<std::unique_ptr<VarDeclExprAST>> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector<std::unique_ptr<VarDeclExprAST>> args) + : location(location), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef<std::unique_ptr<VarDeclExprAST>> getArgs() { return args; } +}; + +/// This class represents a top level record in a module. +class RecordAST { +public: + enum RecordASTKind { + Record_Function, + Record_Struct, + }; + + RecordAST(RecordASTKind kind) : kind(kind) {} + virtual ~RecordAST() = default; + + RecordASTKind getKind() const { return kind; } + +private: + const RecordASTKind kind; +}; + +/// This class represents a function definition itself. +class FunctionAST : public RecordAST { + std::unique_ptr<PrototypeAST> proto; + std::unique_ptr<ExprASTList> body; + +public: + FunctionAST(std::unique_ptr<PrototypeAST> proto, + std::unique_ptr<ExprASTList> body) + : RecordAST(Record_Function), proto(std::move(proto)), + body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } + + /// LLVM style RTTI + static bool classof(const RecordAST *R) { + return R->getKind() == Record_Function; + } +}; + +/// This class represents a struct definition. 
+class StructAST : public RecordAST { + Location location; + std::string name; + std::vector<std::unique_ptr<VarDeclExprAST>> variables; + +public: + StructAST(Location location, const std::string &name, + std::vector<std::unique_ptr<VarDeclExprAST>> variables) + : RecordAST(Record_Struct), location(location), name(name), + variables(std::move(variables)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef<std::unique_ptr<VarDeclExprAST>> getVariables() { + return variables; + } + + /// LLVM style RTTI + static bool classof(const RecordAST *R) { + return R->getKind() == Record_Struct; + } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector<std::unique_ptr<RecordAST>> records; + +public: + ModuleAST(std::vector<std::unique_ptr<RecordAST>> records) + : records(std::move(records)) {} + + auto begin() -> decltype(records.begin()) { return records.begin(); } + auto end() -> decltype(records.end()) { return records.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_AST_H_ diff --git a/mlir/examples/toy/Ch7/include/toy/CMakeLists.txt b/mlir/examples/toy/Ch7/include/toy/CMakeLists.txt new file mode 100644 index 00000000000..fa30bd2e8e0 --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/CMakeLists.txt @@ -0,0 +1,9 @@ +set(LLVM_TARGET_DEFINITIONS Ops.td) +mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..") +mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..") +add_public_tablegen_target(ToyCh7OpsIncGen) + +set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td) +mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(ToyCh7ShapeInferenceInterfaceIncGen) diff --git a/mlir/examples/toy/Ch7/include/toy/Dialect.h 
b/mlir/examples/toy/Ch7/include/toy/Dialect.h new file mode 100644 index 00000000000..82b677a3c2a --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/Dialect.h @@ -0,0 +1,109 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the IR Dialect for the Toy language. +// See g3doc/Tutorials/Toy/Ch-2.md for more information. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ +#define MLIR_TUTORIAL_TOY_DIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/StandardTypes.h" +#include "toy/ShapeInferenceInterface.h" + +namespace mlir { +namespace toy { +namespace detail { +class StructTypeStorage; +} // end namespace detail + +/// This is the definition of the Toy dialect. A dialect inherits from +/// mlir::Dialect and registers custom attributes, operations, and types (in its +/// constructor). It can also override some general behavior exposed via virtual +/// methods. +class ToyDialect : public mlir::Dialect { +public: + explicit ToyDialect(mlir::MLIRContext *ctx); + + /// A hook used to materialize constant values with the given type. 
+ Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; + + /// Parse an instance of a type registered to the toy dialect. + mlir::Type parseType(mlir::DialectAsmParser &parser) const override; + + /// Print an instance of a type registered to the toy dialect. + void printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const override; + + /// Provide a utility accessor to the dialect namespace. This is used by + /// several utilities for casting between dialects. + static llvm::StringRef getDialectNamespace() { return "toy"; } +}; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "toy/Ops.h.inc" + +//===----------------------------------------------------------------------===// +// Toy Types +//===----------------------------------------------------------------------===// + +/// Create a local enumeration with all of the types that are defined by Toy. +namespace ToyTypes { +enum Types { + Struct = mlir::Type::FIRST_TOY_TYPE, +}; +} // end namespace ToyTypes + +/// This class defines the Toy struct type. It represents a collection of +/// element types. All derived types in MLIR must inherit from the CRTP class +/// 'Type::TypeBase'. It takes as template parameters the concrete type +/// (StructType), the base class to use (Type), and the storage class +/// (StructTypeStorage). +class StructType : public mlir::Type::TypeBase<StructType, mlir::Type, + detail::StructTypeStorage> { +public: + /// Inherit some necessary constructors from 'TypeBase'. + using Base::Base; + + /// This static method is used to support type inquiry through isa, cast, + /// and dyn_cast. 
+ static bool kindof(unsigned kind) { return kind == ToyTypes::Struct; } + + /// Create an instance of a `StructType` with the given element types. There + /// *must* be atleast one element type. + static StructType get(llvm::ArrayRef<mlir::Type> elementTypes); + + /// Returns the element types of this struct type. + llvm::ArrayRef<mlir::Type> getElementTypes(); + + /// Returns the number of element type held by this struct. + size_t getNumElementTypes() { return getElementTypes().size(); } +}; +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/examples/toy/Ch7/include/toy/Lexer.h b/mlir/examples/toy/Ch7/include/toy/Lexer.h new file mode 100644 index 00000000000..89dc6cba9ff --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/Lexer.h @@ -0,0 +1,244 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_LEXER_H_ +#define MLIR_TUTORIAL_TOY_LEXER_H_ + +#include "llvm/ADT/StringRef.h" + +#include <memory> +#include <string> + +namespace toy { + +/// Structure definition a location in a file. 
+struct Location { + std::shared_ptr<std::string> file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + tok_struct = -5, + + // primary + tok_identifier = -6, + tok_number = -7, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared<std::string>(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. 
+ void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. 
+ lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "struct") + return tok_struct; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9] ([0-9.])* + if (isdigit(lastChar)) { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. 
+ Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast<size_t>(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_LEXER_H_ diff --git a/mlir/examples/toy/Ch7/include/toy/MLIRGen.h b/mlir/examples/toy/Ch7/include/toy/MLIRGen.h new file mode 100644 index 00000000000..287f432c847 --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/MLIRGen.h @@ -0,0 +1,41 @@ +//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_ +#define MLIR_TUTORIAL_TOY_MLIRGEN_H_ + +#include <memory> + +namespace mlir { +class MLIRContext; +class OwningModuleRef; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_ diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td new file mode 100644 index 00000000000..5e932bb0a7e --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -0,0 +1,314 @@ +//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Defines the operations of the Toy dialect. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TOY_OPS +#define TOY_OPS + +#ifndef MLIR_CALLINTERFACES +include "mlir/Analysis/CallInterfaces.td" +#endif // MLIR_CALLINTERFACES + +#ifndef SHAPE_INFERENCE_INTERFACE +include "toy/ShapeInferenceInterface.td" +#endif // SHAPE_INFERENCE_INTERFACE + +// Provide a definition of the 'toy' dialect in the ODS framework so that we +// can define our operations. +def Toy_Dialect : Dialect { + let name = "toy"; + let cppNamespace = "toy"; +} + +// Base class for toy dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. +class Toy_Op<string mnemonic, list<OpTrait> traits = []> : + Op<Toy_Dialect, mnemonic, traits>; + +// Provide a definition for the Toy StructType for use in ODS. This allows for +// using StructType in a similar way to Tensor or MemRef. +def Toy_StructType : + Type<CPred<"$_self.isa<StructType>()">, "Toy struct type">; + +// Provide a definition of the types that are used within the Toy dialect. +def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +// We define a toy operation by inherting from our base 'Toy_Op' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'NoSideEffect' as it is a pure operation +// and may be removed if dead. +def ConstantOp : Toy_Op<"constant", + [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documenatation of the operations within our dialect. 
+ let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = "toy.constant"() + { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> } + : () -> tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create<ConstantOp>(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<"Builder *builder, OperationState &state, " + "DenseElementsAttr value", [{ + build(builder, state, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<"Builder *builder, OperationState &state, double value"> + ]; + + // Invoke a static verify method to verify this constant operation. + let verifier = [{ return ::verify(*this); }]; + + // Set the folder bit so that we can implement constant folders. + let hasFolder = 1; +} + +def AddOp : Toy_Op<"add", + [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building an AddOp with from the two input operands. 
+ let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs"> + ]; +} + +def CastOp : Toy_Op<"cast", + [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect, + SameOperandsAndResultShape]> { + let summary = "shape cast operation"; + let description = [{ + The "cast" operation converts a tensor from one type to an equivalent type + without changing any data elements. The source and destination types + must both be tensor types with the same element type. If both are ranked + then the rank should be the same and static dimensions should match. The + operation is invalid if converting to a mismatching constant dimension. + }]; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$output); + + // Set the folder bit so that we can fold redundant cast operations. + let hasFolder = 1; +} + +def GenericCallOp : Toy_Op<"generic_call", + [DeclareOpInterfaceMethods<CallOpInterface>]> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. For example: + + ```mlir + %4 = "toy.generic_call"(%1, %3) {callee = @my_func} + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. + }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins SymbolRefAttr:$callee, Variadic<Toy_Type>:$inputs); + + // The generic call operation returns a single value of TensorType or + // StructType. + let results = (outs Toy_Type); + + // Add custom build methods for the generic call operation. 
+ let builders = [ + OpBuilder<"Builder *builder, OperationState &state, " + "StringRef callee, ArrayRef<Value *> arguments"> + ]; +} + +def MulOp : Toy_Op<"mul", + [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building a MulOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs"> + ]; +} + +def PrintOp : Toy_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + // We also allow a F64MemRef to enable interop during partial lowering. + let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input); +} + +def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. For example: + + ```mlir + %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + let hasCanonicalizer = 1; + + // We expect that the reshape operation returns a statically shaped tensor. + let results = (outs StaticShapeTensorOf<[F64]>); +} + +def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional operand and produces no results. 
+ The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + func @foo() -> tensor<2xf64> { + ... + toy.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic<Toy_Type>:$input); + + // Allow building a ReturnOp with no return operand. + let builders = [OpBuilder< + "Builder *b, OperationState &state", [{ build(b, state, llvm::None); }] + >]; + + // Provide extra utility definitions on the c++ operation class definition. + let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Invoke a static verify method to verify this return operation. + let verifier = [{ return ::verify(*this); }]; +} + +def StructAccessOp : Toy_Op<"struct_access", [NoSideEffect]> { + let summary = "struct access"; + let description = [{ + Access the Nth element of a value returning a struct type. + }]; + + let arguments = (ins Toy_StructType:$input, I64Attr:$index); + let results = (outs Toy_Type); + + // Allow building a StructAccessOp with just a struct value and an index. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value *input, size_t index"> + ]; + + let verifier = [{ return ::verify(*this); }]; + + // Set the folder bit so that we can fold constant accesses. + let hasFolder = 1; +} + +def StructConstantOp : Toy_Op<"struct_constant", [NoSideEffect]> { + let summary = "struct constant"; + let description = [{ + Constant operation turns a literal struct value into an SSA value. The data + is attached to the operation as an attribute. The struct constant is encoded + as an array of other constant values. 
For example: + + ```mlir + %0 = "toy.struct_constant"() { + value = [dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>] + } : () -> !toy.struct<tensor<*xf64>> + ``` + }]; + + let hasFolder = 1; + let arguments = (ins ArrayAttr:$value); + let results = (outs Toy_StructType); + let verifier = [{ return ::verify(*this); }]; +} + +def TransposeOp : Toy_Op<"transpose", + [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + let hasCanonicalizer = 1; + + // Allow building a TransposeOp with from the input operand. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value *input"> + ]; + + // Invoke a static verify method to verify this transpose operation. + let verifier = [{ return ::verify(*this); }]; +} + +#endif // TOY_OPS diff --git a/mlir/examples/toy/Ch7/include/toy/Parser.h b/mlir/examples/toy/Ch7/include/toy/Parser.h new file mode 100644 index 00000000000..e0f36028c59 --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/Parser.h @@ -0,0 +1,687 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the parser for the Toy language. 
It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PARSER_H +#define MLIR_TUTORIAL_TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include <map> +#include <utility> +#include <vector> + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr<ModuleAST> parseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions and structs one at a time and accumulate in this vector. + std::vector<std::unique_ptr<RecordAST>> records; + while (true) { + std::unique_ptr<RecordAST> record; + switch (lexer.getCurToken()) { + case tok_eof: + break; + case tok_def: + record = parseDefinition(); + break; + case tok_struct: + record = parseStruct(); + break; + default: + return parseError<ModuleAST>("'def' or 'struct'", + "when parsing top level module records"); + } + if (!record) + break; + records.push_back(std::move(record)); + } + + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError<ModuleAST>("nothing", "at end of module"); + + return std::make_unique<ModuleAST>(std::move(records)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. 
+ /// return :== return ; | return expr ; + std::unique_ptr<ReturnExprAST> parseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + llvm::Optional<std::unique_ptr<ExprAST>> expr; + if (lexer.getCurToken() != ';') { + expr = parseExpression(); + if (!expr) + return nullptr; + } + return std::make_unique<ReturnExprAST>(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr<ExprAST> parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr<ExprAST> parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector<std::unique_ptr<ExprAST>> values; + // Hold the dimensions for all the nesting inside this level. + std::vector<int64_t> dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError<ExprAST>("<num> or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. 
+ if (lexer.getCurToken() != ',') + return parseError<ExprAST>("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError<ExprAST>("<something>", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) { + return llvm::isa<LiteralExprAST>(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get()); + if (!firstLiteral) + return parseError<ExprAST>("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast<LiteralExprAST>(expr.get()); + if (!exprLiteral) + return parseError<ExprAST>("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError<ExprAST>("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique<LiteralExprAST>(std::move(loc), std::move(values), + std::move(dims)); + } + + /// Parse a literal struct expression. + /// structLiteral ::= { (structLiteral | tensorLiteral)+ } + std::unique_ptr<ExprAST> parseStructLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('{')); + + // Hold the list of values. + std::vector<std::unique_ptr<ExprAST>> values; + do { + // We can have either another nested array or a number literal. 
+ if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; + } else if (lexer.getCurToken() == tok_number) { + values.push_back(parseNumberExpr()); + if (!values.back()) + return nullptr; + } else { + if (lexer.getCurToken() != '{') + return parseError<ExprAST>("{, [, or number", + "in struct literal expression"); + values.push_back(parseStructLiteralExpr()); + } + + // End of this list on '}' + if (lexer.getCurToken() == '}') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError<ExprAST>("} or ,", "in struct literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError<ExprAST>("<something>", + "to fill struct literal expression"); + lexer.getNextToken(); // eat } + + return std::make_unique<StructLiteralExprAST>(std::move(loc), + std::move(values)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr<ExprAST> parseParenExpr() { + lexer.getNextToken(); // eat (. + auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError<ExprAST>(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// Parse a call expression. 
+ std::unique_ptr<ExprAST> parseCallExpr(llvm::StringRef name, + const Location &loc) { + lexer.consume(Token('(')); + std::vector<std::unique_ptr<ExprAST>> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError<ExprAST>(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError<ExprAST>("<single arg>", "as argument to print()"); + + return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args)); + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr<ExprAST> parseIdentifierExpr() { + std::string name = lexer.getId(); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return std::make_unique<VariableExprAST>(std::move(loc), name); + + // This is a function call. 
+ return parseCallExpr(name, loc); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr<ExprAST> parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case '{': + return parseStructLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec, + std::unique_ptr<ExprAST> lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto rhs = parsePrimary(); + if (!rhs) + return parseError<ExprAST>("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/RHS. 
+ lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr<ExprAST> parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr<VarType> parseType() { + if (lexer.getCurToken() != '<') + return parseError<VarType>("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique<VarType>(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError<VarType>(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse either a variable declaration or a call expression. + std::unique_ptr<ExprAST> parseDeclarationOrCallExpr() { + auto loc = lexer.getLastLocation(); + std::string id = lexer.getId(); + lexer.consume(tok_identifier); + + // Check for a call expression. + if (lexer.getCurToken() == '(') + return parseCallExpr(id, loc); + + // Otherwise, this is a variable declaration. + return parseTypedDeclaration(id, /*requiresInitializer=*/true, loc); + } + + /// Parse a typed variable declaration. + std::unique_ptr<VarDeclExprAST> + parseTypedDeclaration(llvm::StringRef typeName, bool requiresInitializer, + const Location &loc) { + // Parse the variable name. + if (lexer.getCurToken() != tok_identifier) + return parseError<VarDeclExprAST>("name", "in variable declaration"); + std::string id = lexer.getId(); + lexer.getNextToken(); // eat id + + // Parse the initializer. 
+ std::unique_ptr<ExprAST> expr; + if (requiresInitializer) { + if (lexer.getCurToken() != '=') + return parseError<VarDeclExprAST>("initializer", + "in variable declaration"); + lexer.consume(Token('=')); + expr = parseExpression(); + } + + VarType type; + type.name = typeName; + return std::make_unique<VarDeclExprAST>(loc, std::move(id), std::move(type), + std::move(expr)); + } + + /// Parse a variable declaration, for either a tensor value or a struct value, + /// with an optionally required initializer. + /// decl ::= var identifier [ type ] (= expr)? + /// decl ::= identifier identifier (= expr)? + std::unique_ptr<VarDeclExprAST> parseDeclaration(bool requiresInitializer) { + // Check to see if this is a 'var' declaration. + if (lexer.getCurToken() == tok_var) + return parseVarDeclaration(requiresInitializer); + + // Parse the type name. + if (lexer.getCurToken() != tok_identifier) + return parseError<VarDeclExprAST>("type name", "in variable declaration"); + auto loc = lexer.getLastLocation(); + std::string typeName = lexer.getId(); + lexer.getNextToken(); // eat id + + // Parse the rest of the declaration. + return parseTypedDeclaration(typeName, requiresInitializer, loc); + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// optionally required initializer. + /// decl ::= var identifier [ type ] (= expr)? 
+ std::unique_ptr<VarDeclExprAST> + parseVarDeclaration(bool requiresInitializer) { + if (lexer.getCurToken() != tok_var) + return parseError<VarDeclExprAST>("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError<VarDeclExprAST>("identifier", + "after 'var' declaration"); + std::string id = lexer.getId(); + lexer.getNextToken(); // eat id + + std::unique_ptr<VarType> type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = parseType(); + if (!type) + return nullptr; + } + if (!type) + type = std::make_unique<VarType>(); + + std::unique_ptr<ExprAST> expr; + if (requiresInitializer) { + lexer.consume(Token('=')); + expr = parseExpression(); + } + return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expressions separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr<ExprASTList> parseBlock() { + if (lexer.getCurToken() != '{') + return parseError<ExprASTList>("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = std::make_unique<ExprASTList>(); + + // Ignore empty expressions: swallow sequences of semicolons. 
+ while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_identifier) { + // Variable declaration or call + auto expr = parseDeclarationOrCallExpr(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } else if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = parseDeclaration(/*requiresInitializer=*/true); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = parseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = parseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. + if (lexer.getCurToken() != ';') + return parseError<ExprASTList>(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. 
+ while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError<ExprASTList>("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr<PrototypeAST> parsePrototype() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_def); + if (lexer.getCurToken() != tok_identifier) + return parseError<PrototypeAST>("function name", "in prototype"); + + std::string fnName = lexer.getId(); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError<PrototypeAST>("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector<std::unique_ptr<VarDeclExprAST>> args; + if (lexer.getCurToken() != ')') { + do { + VarType type; + std::string name; + + // Parse either the name of the variable, or its type. + std::string nameOrType = lexer.getId(); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + + // If the next token is an identifier, we just parsed the type. + if (lexer.getCurToken() == tok_identifier) { + type.name = std::move(nameOrType); + + // Parse the name. + name = lexer.getId(); + lexer.consume(tok_identifier); + } else { + // Otherwise, we just parsed the name. + name = std::move(nameOrType); + } + + args.push_back( + std::make_unique<VarDeclExprAST>(std::move(loc), name, type)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError<PrototypeAST>( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError<PrototypeAST>(")", "to end function prototype"); + + // success. 
+ lexer.consume(Token(')')); + return std::make_unique<PrototypeAST>(std::move(loc), fnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. + /// + /// definition ::= prototype block + std::unique_ptr<FunctionAST> parseDefinition() { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique<FunctionAST>(std::move(proto), std::move(block)); + return nullptr; + } + + /// Parse a struct definition, we expect a struct initiated with the + /// `struct` keyword, followed by a block containing a list of variable + /// declarations. + /// + /// definition ::= `struct` identifier `{` decl+ `}` + std::unique_ptr<StructAST> parseStruct() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_struct); + if (lexer.getCurToken() != tok_identifier) + return parseError<StructAST>("name", "in struct definition"); + std::string name = lexer.getId(); + lexer.consume(tok_identifier); + + // Parse: '{' + if (lexer.getCurToken() != '{') + return parseError<StructAST>("{", "in struct definition"); + lexer.consume(Token('{')); + + // Parse: decl+ + std::vector<std::unique_ptr<VarDeclExprAST>> decls; + do { + auto decl = parseDeclaration(/*requiresInitializer=*/false); + if (!decl) + return nullptr; + decls.push_back(std::move(decl)); + + if (lexer.getCurToken() != ';') + return parseError<StructAST>(";", + "after variable in struct definition"); + lexer.consume(Token(';')); + } while (lexer.getCurToken() != '}'); + + // Parse: '}' + lexer.consume(Token('}')); + return std::make_unique<StructAST>(loc, name, std::move(decls)); + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. 
+ switch (static_cast<char>(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + case '.': + return 60; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template <typename R, typename T, typename U = const char *> + std::unique_ptr<R> parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_PARSER_H diff --git a/mlir/examples/toy/Ch7/include/toy/Passes.h b/mlir/examples/toy/Ch7/include/toy/Passes.h new file mode 100644 index 00000000000..00fe4ffe49b --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/Passes.h @@ -0,0 +1,45 @@ +//===- Passes.h - Toy Passes Definition -----------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +// +// This file exposes the entry points to create compiler passes for Toy. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PASSES_H +#define MLIR_TUTORIAL_TOY_PASSES_H + +#include <memory> + +namespace mlir { +class Pass; + +namespace toy { +std::unique_ptr<Pass> createDeadFunctionEliminationPass(); +std::unique_ptr<Pass> createShapeInferencePass(); + +/// Create a pass for lowering to operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). +std::unique_ptr<mlir::Pass> createLowerToAffinePass(); + +/// Create a pass for lowering operations the remaining `Toy` operations, as +/// well as `Affine` and `Std`, to the LLVM dialect for codegen. +std::unique_ptr<mlir::Pass> createLowerToLLVMPass(); + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_PASSES_H diff --git a/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h new file mode 100644 index 00000000000..fc36b5b100d --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h @@ -0,0 +1,37 @@ +//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +// +// This file contains the declarations of the shape inference interfaces defined +// in ShapeInferenceInterface.td. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace toy { + +/// Include the auto-generated declarations. +#include "toy/ShapeInferenceOpInterfaces.h.inc" + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td new file mode 100644 index 00000000000..dc345907ac7 --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td @@ -0,0 +1,41 @@ +//===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Defines the operations of the Shape Inference Op Interface. 
+// +//===----------------------------------------------------------------------===// + +#ifndef SHAPE_INFERENCE_INTERFACE +#define SHAPE_INFERENCE_INTERFACE + +#ifndef OP_BASE +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that can be used during type inference. + }]; + + let methods = [ + InterfaceMethod<"Infer and set the output shape for the current operation.", + "void", "inferShapes"> + ]; +} + +#endif // SHAPE_INFERENCE_INTERFACE diff --git a/mlir/examples/toy/Ch7/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch7/mlir/DeadFunctionEliminationPass.cpp new file mode 100644 index 00000000000..b58adb5d52f --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/DeadFunctionEliminationPass.cpp @@ -0,0 +1,68 @@ +//===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a Module level pass performing dead function +// elimination. This is required as a post-processing step after function +// inlining. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/Passes.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> + +namespace { +/// This is a simple function DCE pass that deletes all non-main functions after +/// inlining. +/// TODO(riverriddle) This is only necessary because MLIR currently does not +/// have generic DCE support for functions. +class DeadFunctionEliminationPass + : public mlir::ModulePass<DeadFunctionEliminationPass> { +public: + void runOnModule() override { + mlir::ModuleOp module = getModule(); + mlir::SymbolTable moduleSymTable(module); + + // Eliminate non-main functions. + auto mainFn = moduleSymTable.lookup<mlir::FuncOp>("main"); + for (mlir::FuncOp func : + llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) { + if (func != mainFn) + func.erase(); + } + } +}; +} // end anonymous namespace + +/// Create a pass that eliminates inlined functions in toy. +std::unique_ptr<mlir::Pass> mlir::toy::createDeadFunctionEliminationPass() { + return std::make_unique<DeadFunctionEliminationPass>(); +} diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp new file mode 100644 index 00000000000..2beaa870a89 --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -0,0 +1,483 @@ +//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// +// +// Copyright 2019 The MLIR Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/InliningUtils.h" + +using namespace mlir; +using namespace mlir::toy; + +//===----------------------------------------------------------------------===// +// ToyInlinerInterface +//===----------------------------------------------------------------------===// + +/// This class defines the interface for handling inlining with Toy +/// operations. +struct ToyInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + /// All operations within toy can be inlined. 
+ bool isLegalToInline(Operation *, Region *, + BlockAndValueMapping &) const final { + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator(toy.return) by replacing it with a new + /// operation as necessary. + void handleTerminator(Operation *op, + ArrayRef<Value *> valuesToRepl) const final { + // Only "toy.return" needs to be handled here. + auto returnOp = cast<ReturnOp>(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()]->replaceAllUsesWith(it.value()); + } + + /// Attempts to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value *input, + Type resultType, + Location conversionLoc) const final { + return builder.create<CastOp>(conversionLoc, resultType, input); + } +}; + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. 
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); + addInterfaces<ToyInlinerInterface>(); + addTypes<StructType>(); +} + +mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, + mlir::Attribute value, + mlir::Type type, + mlir::Location loc) { + if (type.isa<StructType>()) + return builder.create<StructConstantOp>(loc, type, + value.cast<mlir::ArrayAttr>()); + return builder.create<ConstantOp>(loc, type, + value.cast<mlir::DenseElementsAttr>()); +} + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state, + double value) { + auto dataType = RankedTensorType::get({}, builder->getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// Verify that the given attribute value is valid for the given type. +static mlir::LogicalResult verifyConstantForType(mlir::Type type, + mlir::Attribute opaqueValue, + mlir::Operation *op) { + if (type.isa<mlir::TensorType>()) { + // Check that the value is a elements attribute. + auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>(); + if (!attrValue) + return op->emitError("constant of TensorType must be initialized by " + "a DenseFPElementsAttr, got ") + << opaqueValue; + + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. 
+ auto resultType = type.dyn_cast<mlir::RankedTensorType>(); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the + // constant result type. + auto attrType = attrValue.getType().cast<mlir::TensorType>(); + if (attrType.getRank() != resultType.getRank()) { + return op->emitOpError("return type must match the one of the attached " + "value attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. + for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return op->emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); + } + auto resultType = type.cast<StructType>(); + llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes(); + + // Verify that the initializer is an Array. + auto attrValue = opaqueValue.dyn_cast<ArrayAttr>(); + if (!attrValue || attrValue.getValue().size() != resultElementTypes.size()) + return op->emitError("constant of StructType must be initialized by an " + "ArrayAttr with the same number of elements, got ") + << opaqueValue; + + // Check that each of the elements are valid. + llvm::ArrayRef<mlir::Attribute> attrElementValues = attrValue.getValue(); + for (const auto &it : llvm::zip(resultElementTypes, attrElementValues)) + if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op))) + return mlir::failure(); + return mlir::success(); +} + +/// Verifier for the constant operation. This corresponds to the `::verify(...)` +/// in the op definition. 
+static mlir::LogicalResult verify(ConstantOp op) { + return verifyConstantForType(op.getResult()->getType(), op.value(), op); +} + +static mlir::LogicalResult verify(StructConstantOp op) { + return verifyConstantForType(op.getResult()->getType(), op.value(), op); +} + +/// Infer the output shape of the ConstantOp, this is required by the shape +/// inference interface. +void ConstantOp::inferShapes() { getResult()->setType(value().getType()); } + +//===----------------------------------------------------------------------===// +// AddOp + +void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value *lhs, mlir::Value *rhs) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +/// Infer the output shape of the AddOp, this is required by the shape inference +/// interface. +void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); } + +//===----------------------------------------------------------------------===// +// CastOp + +/// Infer the output shape of the CastOp, this is required by the shape +/// inference interface. +void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); } + +//===----------------------------------------------------------------------===// +// GenericCallOp + +void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, + StringRef callee, ArrayRef<mlir::Value *> arguments) { + // Generic call always returns an unranked Tensor initially. + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", builder->getSymbolRefAttr(callee)); +} + +/// Return the callee of the generic call operation, this is required by the +/// call interface. 
+CallInterfaceCallable GenericCallOp::getCallableForCallee() { + return getAttrOfType<SymbolRefAttr>("callee"); +} + +/// Get the argument operands to the called function, this is required by the +/// call interface. +Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } + +//===----------------------------------------------------------------------===// +// MulOp + +void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value *lhs, mlir::Value *rhs) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +/// Infer the output shape of the MulOp, this is required by the shape inference +/// interface. +void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); } + +//===----------------------------------------------------------------------===// +// ReturnOp + +static mlir::LogicalResult verify(ReturnOp op) { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast<FuncOp>(op.getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (op.getNumOperands() > 1) + return op.emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. + const auto &results = function.getType().getResults(); + if (op.getNumOperands() != results.size()) + return op.emitOpError() + << "does not return the same number of values (" + << op.getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!op.hasOperand()) + return mlir::success(); + + auto inputType = *op.operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. 
+ if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() || + resultType.isa<mlir::UnrankedTensorType>()) + return mlir::success(); + + return op.emitError() << "type of return operand (" + << *op.operand_type_begin() + << ") doesn't match function result type (" + << results.front() << ")"; +} + +//===----------------------------------------------------------------------===// +// StructAccessOp + +void StructAccessOp::build(mlir::Builder *b, mlir::OperationState &state, + mlir::Value *input, size_t index) { + // Extract the result type from the input type. + StructType structTy = input->getType().cast<StructType>(); + assert(index < structTy.getNumElementTypes()); + mlir::Type resultType = structTy.getElementTypes()[index]; + + // Call into the auto-generated build method. + build(b, state, resultType, input, b->getI64IntegerAttr(index)); +} + +static mlir::LogicalResult verify(StructAccessOp op) { + StructType structTy = op.input()->getType().cast<StructType>(); + size_t index = op.index().getZExtValue(); + if (index >= structTy.getNumElementTypes()) + return op.emitOpError() + << "index should be within the range of the input struct type"; + mlir::Type resultType = op.getResult()->getType(); + if (resultType != structTy.getElementTypes()[index]) + return op.emitOpError() << "must have the same result type as the struct " + "element referred to by the index"; + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TransposeOp + +void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value *value) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands(value); +} + +void TransposeOp::inferShapes() { + auto arrayTy = getOperand()->getType().cast<RankedTensorType>(); + SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape())); + getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); +} + +static 
mlir::LogicalResult verify(TransposeOp op) { + auto inputType = op.getOperand()->getType().dyn_cast<RankedTensorType>(); + auto resultType = op.getType().dyn_cast<RankedTensorType>(); + if (!inputType || !resultType) + return mlir::success(); + + auto inputShape = inputType.getShape(); + if (!std::equal(inputShape.begin(), inputShape.end(), + resultType.getShape().rbegin())) { + return op.emitError() + << "expected result shape to be a transpose of the input"; + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// Toy Types +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { +namespace detail { +/// This class represents the internal storage of the Toy `StructType`. +struct StructTypeStorage : public mlir::TypeStorage { + /// The `KeyTy` is a required type that provides an interface for the storage + /// instance. This type will be used when uniquing an instance of the type + /// storage. For our struct type, we will unique each instance structurally on + /// the elements that it contains. + using KeyTy = llvm::ArrayRef<mlir::Type>; + + /// A constructor for the type storage instance. + StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes) + : elementTypes(elementTypes) {} + + /// Define the comparison function for the key type with the current storage + /// instance. This is used when constructing a new instance to ensure that we + /// haven't already uniqued an instance of the given key. + bool operator==(const KeyTy &key) const { return key == elementTypes; } + + /// Define a hash function for the key type. This is used when uniquing + /// instances of the storage, see the `StructType::get` method. + /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type + /// have hash functions available, so we could just omit this entirely. 
+ static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + /// Define a construction function for the key type from a set of parameters. + /// These parameters will be provided when constructing the storage instance + /// itself. + /// Note: This method isn't necessary because KeyTy can be directly + /// constructed with the given parameters. + static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) { + return KeyTy(elementTypes); + } + + /// Define a construction method for creating a new instance of this storage. + /// This method takes an instance of a storage allocator, and an instance of a + /// `KeyTy`. The given allocator must be used for *all* necessary dynamic + /// allocations used to create the type storage and its internal. + static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator, + const KeyTy &key) { + // Copy the elements from the provided `KeyTy` into the allocator. + llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key); + + // Allocate the storage instance and construct it. + return new (allocator.allocate<StructTypeStorage>()) + StructTypeStorage(elementTypes); + } + + /// The following field contains the element types of the struct. + llvm::ArrayRef<mlir::Type> elementTypes; +}; +} // end namespace detail +} // end namespace toy +} // end namespace mlir + +/// Create an instance of a `StructType` with the given element types. There +/// *must* be at least one element type. +StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) { + assert(!elementTypes.empty() && "expected at least 1 element type"); + + // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance + // of this type. The first two parameters are the context to unique in and the + // kind of the type. The parameters after the type kind are forwarded to the + // storage instance. 
+ mlir::MLIRContext *ctx = elementTypes.front().getContext(); + return Base::get(ctx, ToyTypes::Struct, elementTypes); +} + +/// Returns the element types of this struct type. +llvm::ArrayRef<mlir::Type> StructType::getElementTypes() { + // 'getImpl' returns a pointer to the internal storage instance. + return getImpl()->elementTypes; +} + +/// Parse an instance of a type registered to the toy dialect. +mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const { + // Parse a struct type in the following form: + // struct-type ::= `struct` `<` type (`,` type)* `>` + + // NOTE: All MLIR parser function return a ParseResult. This is a + // specialization of LogicalResult that auto-converts to a `true` boolean + // value on failure to allow for chaining, but may be used with explicit + // `mlir::failed/mlir::succeeded` as desired. + + // Parse: `struct` `<` + if (parser.parseKeyword("struct") || parser.parseLess()) + return Type(); + + // Parse the element types of the struct. + SmallVector<mlir::Type, 1> elementTypes; + do { + // Parse the current element type. + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + mlir::Type elementType; + if (parser.parseType(elementType)) + return nullptr; + + // Check that the type is either a TensorType or another StructType. + if (!elementType.isa<mlir::TensorType>() && + !elementType.isa<StructType>()) { + parser.emitError(typeLoc, "element type for a struct must either " + "be a TensorType or a StructType, got: ") + << elementType; + return Type(); + } + elementTypes.push_back(elementType); + + // Parse the optional: `,` + } while (succeeded(parser.parseOptionalComma())); + + // Parse: `>` + if (parser.parseGreater()) + return Type(); + return StructType::get(elementTypes); +} + +/// Print an instance of a type registered to the toy dialect. +void ToyDialect::printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const { + // Currently the only toy type is a struct type. 
+ StructType structType = type.cast<StructType>(); + + // Print the struct type according to the parser format. + printer << "struct<"; + mlir::interleaveComma(structType.getElementTypes(), printer); + printer << '>'; +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp new file mode 100644 index 00000000000..a8e38aef7ad --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -0,0 +1,318 @@ +//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a partial lowering of Toy operations to a combination of +// affine loops and standard operations. This lowering expects that all calls +// have been inlined, and all shapes have been resolved. 
+// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/Sequence.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns +//===----------------------------------------------------------------------===// + +/// Convert the given TensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(TensorType type) { + assert(type.hasRank() && "expected only ranked shapes"); + return MemRefType::get(type.getShape(), type.getElementType()); +} + +/// Insert an allocation and deallocation for the given MemRefType. +static Value *insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) { + auto alloc = rewriter.create<AllocOp>(loc, type); + + // Make sure to allocate at the beginning of the block. + auto *parentBlock = alloc.getOperation()->getBlock(); + alloc.getOperation()->moveBefore(&parentBlock->front()); + + // Make sure to deallocate this alloc at the end of the block. This is fine + // as toy functions have no control flow. + auto dealloc = rewriter.create<DeallocOp>(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + return alloc; +} + +/// This defines the function type used to process an iteration of a lowered +/// loop. It takes as input a rewriter, an array of memRefOperands corresponding +/// to the operands of the input operation, and the set of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. 
+using LoopIterationFn = function_ref<Value *(PatternRewriter &rewriter, + ArrayRef<Value *> memRefOperands, + ArrayRef<Value *> loopIvs)>; + +static void lowerOpToLoops(Operation *op, ArrayRef<Value *> operands, + PatternRewriter &rewriter, + LoopIterationFn processIteration) { + auto tensorType = (*op->result_type_begin()).cast<TensorType>(); + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // Create an empty affine loop for each of the dimensions within the shape. + SmallVector<Value *, 4> loopIvs; + for (auto dim : tensorType.getShape()) { + auto loop = rewriter.create<AffineForOp>(loc, /*lb=*/0, dim, /*step=*/1); + loop.getBody()->clear(); + loopIvs.push_back(loop.getInductionVar()); + + // Terminate the loop body and update the rewriter insertion point to the + // beginning of the loop. + rewriter.setInsertionPointToStart(loop.getBody()); + rewriter.create<AffineTerminatorOp>(loc); + rewriter.setInsertionPointToStart(loop.getBody()); + } + + // Generate a call to the processing function with the rewriter, the memref + // operands, and the loop induction variables. This function will return the + // value to store at the current index. + Value *valueToStore = processIteration(rewriter, operands, loopIvs); + rewriter.create<AffineStoreOp>(loc, valueToStore, alloc, + llvm::makeArrayRef(loopIvs)); + + // Replace this operation with the generated alloc. 
+ rewriter.replaceOp(op, alloc); +} + +namespace { +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Binary operations +//===----------------------------------------------------------------------===// + +template <typename BinaryOp, typename LoweredBinaryOp> +struct BinaryOpLowering : public ConversionPattern { + BinaryOpLowering(MLIRContext *ctx) + : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops( + op, operands, rewriter, + [loc](PatternRewriter &rewriter, ArrayRef<Value *> memRefOperands, + ArrayRef<Value *> loopIvs) { + // Generate an adaptor for the remapped operands of the BinaryOp. This + // allows for using the nice named accessors that are generated by the + // ODS. + typename BinaryOp::OperandAdaptor binaryAdaptor(memRefOperands); + + // Generate loads for the element of 'lhs' and 'rhs' at the inner + // loop. + auto loadedLhs = + rewriter.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs); + auto loadedRhs = + rewriter.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs); + + // Create the binary operation performed on the loaded values. 
+ return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs); + }); + return matchSuccess(); + } +}; +using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>; +using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Constant operations +//===----------------------------------------------------------------------===// + +struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { + using OpRewritePattern<toy::ConstantOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(toy::ConstantOp op, + PatternRewriter &rewriter) const final { + DenseElementsAttr constantValue = op.value(); + Location loc = op.getLoc(); + + // When lowering the constant operation, we allocate and assign the constant + // values to a corresponding memref allocation. + auto tensorType = op.getType().cast<TensorType>(); + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // We will be generating constant indices up-to the largest dimension. + // Create these constants up-front to avoid large amounts of redundant + // operations. + auto valueShape = memRefType.getShape(); + SmallVector<Value *, 8> constantIndices; + for (auto i : llvm::seq<int64_t>( + 0, *std::max_element(valueShape.begin(), valueShape.end()))) + constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i)); + + // The constant operation represents a multi-dimensional constant, so we + // will need to generate a store for each of the elements. The following + // functor recursively walks the dimensions of the constant shape, + // generating a store when the recursion hits the base case. 
+ SmallVector<Value *, 2> indices; + auto valueIt = constantValue.getValues<FloatAttr>().begin(); + std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) { + // The last dimension is the base case of the recursion, at this point + // we store the element at the given index. + if (dimension == valueShape.size()) { + rewriter.create<AffineStoreOp>( + loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc, + llvm::makeArrayRef(indices)); + return; + } + + // Otherwise, iterate over the current dimension and add the indices to + // the list. + for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) { + indices.push_back(constantIndices[i]); + storeElements(dimension + 1); + indices.pop_back(); + } + }; + + // Start the element storing recursion from the first dimension. + storeElements(/*dimension=*/0); + + // Replace this operation with the generated alloc. + rewriter.replaceOp(op, alloc); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Return operations +//===----------------------------------------------------------------------===// + +struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> { + using OpRewritePattern<toy::ReturnOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(toy::ReturnOp op, + PatternRewriter &rewriter) const final { + // During this lowering, we expect that all function calls have been + // inlined. + if (op.hasOperand()) + return matchFailure(); + + // We lower "toy.return" directly to "std.return". 
+ rewriter.replaceOpWithNewOp<ReturnOp>(op); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Transpose operations +//===----------------------------------------------------------------------===// + +struct TransposeOpLowering : public ConversionPattern { + TransposeOpLowering(MLIRContext *ctx) + : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops( + op, operands, rewriter, + [loc](PatternRewriter &rewriter, ArrayRef<Value *> memRefOperands, + ArrayRef<Value *> loopIvs) { + // Generate an adaptor for the remapped operands of the TransposeOp. + // This allows for using the nice named accessors that are generated + // by the ODS. + toy::TransposeOpOperandAdaptor tranposeAdaptor(memRefOperands); + Value *input = tranposeAdaptor.input(); + + // Transpose the elements by generating a load from the reverse + // indices. + SmallVector<Value *, 2> reverseIvs(llvm::reverse(loopIvs)); + return rewriter.create<AffineLoadOp>(loc, input, reverseIvs); + }); + return matchSuccess(); + } +}; + +} // end anonymous namespace. + +//===----------------------------------------------------------------------===// +// ToyToAffineLoweringPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering to affine loops of the toy operations that are +/// computationally intensive (like matmul for example...) while keeping the +/// rest of the code in the Toy dialect. +namespace { +struct ToyToAffineLoweringPass : public FunctionPass<ToyToAffineLoweringPass> { + void runOnFunction() final; +}; +} // end anonymous namespace. 
+
+void ToyToAffineLoweringPass::runOnFunction() {
+  auto function = getFunction();
+
+  // We only lower the main function as we expect that all other functions have
+  // been inlined.
+  if (function.getName() != "main")
+    return;
+
+  // Verify that the given main has no inputs and results.
+  if (function.getNumArguments() || function.getType().getNumResults()) {
+    function.emitError("expected 'main' to have 0 inputs and 0 results");
+    return signalPassFailure();
+  }
+
+  // The first thing to define is the conversion target. This will define the
+  // final target for this lowering.
+  ConversionTarget target(getContext());
+
+  // We define the specific operations, or dialects, that are legal targets for
+  // this lowering. In our case, we are lowering to a combination of the
+  // `Affine` and `Standard` dialects.
+  target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
+
+  // We also define the Toy dialect as Illegal so that the conversion will fail
+  // if any of these operations are *not* converted. Given that we actually want
+  // a partial lowering, we explicitly mark the Toy operations that we don't
+  // want to lower, `toy.print`, as `legal`.
+  target.addIllegalDialect<toy::ToyDialect>();
+  target.addLegalOp<toy::PrintOp>();
+
+  // Now that the conversion target has been defined, we just need to provide
+  // the set of patterns that will lower the Toy operations.
+  OwningRewritePatternList patterns;
+  patterns.insert<AddOpLowering, ConstantOpLowering, MulOpLowering,
+                  ReturnOpLowering, TransposeOpLowering>(&getContext());
+
+  // With the target and rewrite patterns defined, we can now attempt the
+  // conversion. The conversion will signal failure if any of our `illegal`
+  // operations were not converted successfully.
+  if (failed(applyPartialConversion(getFunction(), target, patterns)))
+    signalPassFailure();
+}
+
+/// Create a pass lowering a subset of the Toy IR (e.g. matmul) to the
+/// `Affine` and `Std` dialects, leaving the rest (e.g. `toy.print`) intact.
+std::unique_ptr<Pass> mlir::toy::createLowerToAffinePass() {
+  return std::make_unique<ToyToAffineLoweringPass>();
+}
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
new file mode 100644
index 00000000000..7e300fb702d
--- /dev/null
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -0,0 +1,213 @@
+//====- LowerToLLVM.cpp - Lowering from Toy+Affine+Std to LLVM ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a full lowering of Toy operations to the LLVM dialect
+// (via the Loop and Standard dialects). This lowering expects that all calls
+// have been inlined, and all shapes have been resolved.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+#include "toy/Passes.h"
+
+#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/LowerAffine.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ToyToLLVM RewritePatterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Lowers `toy.print` to a loop nest calling `printf` on each of the individual
+/// elements of the array.
+class PrintOpLowering : public ConversionPattern {
+public:
+  explicit PrintOpLowering(MLIRContext *context)
+      : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
+    auto memRefShape = memRefType.getShape();
+    auto loc = op->getLoc();
+    auto *llvmDialect =
+        op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+    assert(llvmDialect && "expected llvm dialect to be registered");
+
+    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
+
+    // Get a symbol reference to the printf function, inserting it if necessary.
+    auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect);
+    Value *formatSpecifierCst = getOrCreateGlobalString(
+        loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule,
+        llvmDialect);
+    Value *newLineCst = getOrCreateGlobalString(
+        loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect);
+
+    // Create a loop for each of the dimensions within the shape.
+    SmallVector<Value *, 4> loopIvs;
+    for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) {
+      auto lowerBound = rewriter.create<ConstantIndexOp>(loc, 0);
+      auto upperBound = rewriter.create<ConstantIndexOp>(loc, memRefShape[i]);
+      auto step = rewriter.create<ConstantIndexOp>(loc, 1);
+      auto loop =
+          rewriter.create<loop::ForOp>(loc, lowerBound, upperBound, step);
+      loop.getBody()->clear();
+      loopIvs.push_back(loop.getInductionVar());
+
+      // Terminate the loop body.
+      rewriter.setInsertionPointToStart(loop.getBody());
+
+      // Insert a newline after each of the inner dimensions of the shape.
+      if (i != e - 1)
+        rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
+                                newLineCst);
+      rewriter.create<loop::TerminatorOp>(loc);
+      rewriter.setInsertionPointToStart(loop.getBody());
+    }
+
+    // Generate a call to printf for the current element of the loop.
+    auto printOp = cast<toy::PrintOp>(op);
+    auto elementLoad = rewriter.create<LoadOp>(loc, printOp.input(), loopIvs);
+    rewriter.create<CallOp>(
+        loc, printfRef, rewriter.getIntegerType(32),
+        ArrayRef<Value *>({formatSpecifierCst, elementLoad}));
+
+    // Notify the rewriter that this operation has been removed.
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+
+private:
+  /// Return a symbol reference to the printf function, inserting it into the
+  /// module if necessary.
+  static SymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
+                                         ModuleOp module,
+                                         LLVM::LLVMDialect *llvmDialect) {
+    auto *context = module.getContext();
+    if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
+      return SymbolRefAttr::get("printf", context);
+
+    // Create a function declaration for printf, the signature is:
+    //   * `i32 (i8*, ...)`
+    auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
+    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+    auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
+                                                    /*isVarArg=*/true);
+
+    // Insert the printf function into the body of the parent module.
+    PatternRewriter::InsertionGuard insertGuard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
+    return SymbolRefAttr::get("printf", context);
+  }
+
+  /// Return a value representing an access into a global string with the given
+  /// name, creating the string if necessary.
+  static Value *getOrCreateGlobalString(Location loc, OpBuilder &builder,
+                                        StringRef name, StringRef value,
+                                        ModuleOp module,
+                                        LLVM::LLVMDialect *llvmDialect) {
+    // Create the global at the entry of the module.
+    LLVM::GlobalOp global;
+    if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
+      OpBuilder::InsertionGuard insertGuard(builder);
+      builder.setInsertionPointToStart(module.getBody());
+      auto type = LLVM::LLVMType::getArrayTy(
+          LLVM::LLVMType::getInt8Ty(llvmDialect), value.size());
+      global = builder.create<LLVM::GlobalOp>(
+          loc, type, /*isConstant=*/true, name, builder.getStringAttr(value));
+    }
+
+    // Get the pointer to the first character in the global string.
+    Value *globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
+    Value *cst0 = builder.create<LLVM::ConstantOp>(
+        loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
+        builder.getIntegerAttr(builder.getIndexType(), 0));
+    return builder.create<LLVM::GEPOp>(
+        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
+        ArrayRef<Value *>({cst0, cst0}));
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// ToyToLLVMLoweringPass
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> {
+  void runOnModule() final;
+};
+} // end anonymous namespace
+
+void ToyToLLVMLoweringPass::runOnModule() {
+  // The first thing to define is the conversion target. This will define the
+  // final target for this lowering. For this lowering, we are only targeting
+  // the LLVM dialect.
+  ConversionTarget target(getContext());
+  target.addLegalDialect<LLVM::LLVMDialect>();
+  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+
+  // During this lowering, we will also be lowering the MemRef types, that are
+  // currently being operated on, to a representation in LLVM. To perform this
+  // conversion we use a TypeConverter as part of the lowering. This converter
+  // details how one type maps to another. This is necessary now that we will be
+  // doing more complicated lowerings, involving loop region arguments.
+  LLVMTypeConverter typeConverter(&getContext());
+
+  // Now that the conversion target has been defined, we need to provide the
+  // patterns used for lowering. At this point of the compilation process, we
+  // have a combination of `toy`, `affine`, and `std` operations. Luckily, there
+  // already exists a set of patterns to transform `affine` and `std`
+  // dialects. These patterns lower in multiple stages, relying on transitive
+  // lowerings. Transitive lowering, or A->B->C lowering, is when multiple
+  // patterns must be applied to fully transform an illegal operation into a
+  // set of legal ones.
+  OwningRewritePatternList patterns;
+  populateAffineToStdConversionPatterns(patterns, &getContext());
+  populateLoopToStdConversionPatterns(patterns, &getContext());
+  populateStdToLLVMConversionPatterns(typeConverter, patterns);
+
+  // The only remaining operation to lower from the `toy` dialect is the
+  // PrintOp.
+  patterns.insert<PrintOpLowering>(&getContext());
+
+  // We want to completely lower to LLVM, so we use a `FullConversion`. This
+  // ensures that only legal operations will remain after the conversion.
+  auto module = getModule();
+  if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
+    signalPassFailure();
+}
+
+/// Create a pass for lowering the remaining `Toy` operations, as
+/// well as `Affine` and `Std`, to the LLVM dialect for codegen.
+std::unique_ptr<mlir::Pass> mlir::toy::createLowerToLLVMPass() {
+  return std::make_unique<ToyToLLVMLoweringPass>();
+}
diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
new file mode 100644
index 00000000000..227ebcd758b
--- /dev/null
+++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
@@ -0,0 +1,683 @@
+//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================= +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/Support/raw_ostream.h" +#include <numeric> + +using namespace mlir::toy; +using namespace toy; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::makeArrayRef; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. 
+ theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (auto &record : moduleAST) { + if (FunctionAST *funcAST = llvm::dyn_cast<FunctionAST>(record.get())) { + auto func = mlirGen(*funcAST); + if (!func) + return nullptr; + + theModule.push_back(func); + functionMap.insert({func.getName(), func}); + } else if (StructAST *str = llvm::dyn_cast<StructAST>(record.get())) { + if (failed(mlirGen(*str))) + return nullptr; + } else { + llvm_unreachable("unknown record type"); + } + } + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + +private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable<StringRef, std::pair<mlir::Value *, VarDeclExprAST *>> + symbolTable; + using SymbolTableScopeT = + llvm::ScopedHashTableScope<StringRef, + std::pair<mlir::Value *, VarDeclExprAST *>>; + + /// A mapping for the functions that have been code generated to MLIR. + llvm::StringMap<mlir::FuncOp> functionMap; + + /// A mapping for named struct types to the underlying MLIR type and the + /// original AST node. 
+  llvm::StringMap<std::pair<mlir::Type, StructAST *>> structMap;
+
+  /// Helper conversion for a Toy AST location to an MLIR location.
+  mlir::Location loc(Location loc) {
+    return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
+                                     loc.col);
+  }
+
+  /// Declare a variable in the current scope, return success if the variable
+  /// wasn't declared yet.
+  mlir::LogicalResult declare(VarDeclExprAST &var, mlir::Value *value) {
+    if (symbolTable.count(var.getName()))
+      return mlir::failure();
+    symbolTable.insert(var.getName(), {value, &var});
+    return mlir::success();
+  }
+
+  /// Create an MLIR type for the given struct.
+  mlir::LogicalResult mlirGen(StructAST &str) {
+    if (structMap.count(str.getName()))
+      return emitError(loc(str.loc())) << "error: struct type with name `"
+                                       << str.getName() << "' already exists";
+
+    auto variables = str.getVariables();
+    std::vector<mlir::Type> elementTypes;
+    elementTypes.reserve(variables.size());
+    for (auto &variable : variables) {
+      if (variable->getInitVal())
+        return emitError(loc(variable->loc()))
+               << "error: variables within a struct definition must not have "
+                  "initializers";
+      if (!variable->getType().shape.empty())
+        return emitError(loc(variable->loc()))
+               << "error: variables within a struct definition must not have "
+                  "an initial shape";
+
+      mlir::Type type = getType(variable->getType(), variable->loc());
+      if (!type)
+        return mlir::failure();
+      elementTypes.push_back(type);
+    }
+
+    structMap.try_emplace(str.getName(), StructType::get(elementTypes), &str);
+    return mlir::success();
+  }
+
+  /// Create the prototype for an MLIR function with as many arguments as the
+  /// provided Toy AST prototype.
+  mlir::FuncOp mlirGen(PrototypeAST &proto) {
+    auto location = loc(proto.loc());
+
+    // This is a generic function, the return type will be inferred later.
+    llvm::SmallVector<mlir::Type, 4> argTypes;
+    argTypes.reserve(proto.getArgs().size());
+    for (auto &arg : proto.getArgs()) {
+      mlir::Type type = getType(arg->getType(), arg->loc());
+      if (!type)
+        return nullptr;
+      argTypes.push_back(type);
+    }
+    auto func_type = builder.getFunctionType(argTypes, llvm::None);
+    return mlir::FuncOp::create(location, proto.getName(), func_type);
+  }
+
+  /// Emit a new function and add it to the MLIR module.
+  mlir::FuncOp mlirGen(FunctionAST &funcAST) {
+    // Create a scope in the symbol table to hold variable declarations.
+    SymbolTableScopeT var_scope(symbolTable);
+
+    // Create an MLIR function for the given prototype.
+    mlir::FuncOp function(mlirGen(*funcAST.getProto()));
+    if (!function)
+      return nullptr;
+
+    // Let's start the body of the function now!
+    // In MLIR the entry block of the function is special: it must have the same
+    // argument list as the function itself.
+    auto &entryBlock = *function.addEntryBlock();
+    auto protoArgs = funcAST.getProto()->getArgs();
+
+    // Declare all the function arguments in the symbol table.
+    for (const auto &name_value :
+         llvm::zip(protoArgs, entryBlock.getArguments())) {
+      if (failed(declare(*std::get<0>(name_value), std::get<1>(name_value))))
+        return nullptr;
+    }
+
+    // Set the insertion point in the builder to the beginning of the function
+    // body, it will be used throughout the codegen to create operations in this
+    // function.
+    builder.setInsertionPointToStart(&entryBlock);
+
+    // Emit the body of the function.
+    if (mlir::failed(mlirGen(*funcAST.getBody()))) {
+      function.erase();
+      return nullptr;
+    }
+
+    // Implicitly return void if no return statement was emitted.
+    // FIXME: we may fix the parser instead to always return the last expression
+    // (this would possibly help the REPL case later)
+    ReturnOp returnOp;
+    if (!entryBlock.empty())
+      returnOp = dyn_cast<ReturnOp>(entryBlock.back());
+    if (!returnOp) {
+      builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
+    } else if (returnOp.hasOperand()) {
+      // Otherwise, if this return operation has an operand then add a result to
+      // the function.
+      function.setType(builder.getFunctionType(function.getType().getInputs(),
+                                               *returnOp.operand_type_begin()));
+    }
+
+    return function;
+  }
+
+  /// Return the struct type that is the result of the given expression, or null
+  /// if it cannot be inferred.
+  StructAST *getStructFor(ExprAST *expr) {
+    llvm::StringRef structName;
+    if (auto *decl = llvm::dyn_cast<VariableExprAST>(expr)) {
+      auto varIt = symbolTable.lookup(decl->getName());
+      if (!varIt.first)
+        return nullptr;
+      structName = varIt.second->getType().name;
+    } else if (auto *access = llvm::dyn_cast<BinaryExprAST>(expr)) {
+      if (access->getOp() != '.')
+        return nullptr;
+      // The name being accessed should be in the RHS.
+      auto *name = llvm::dyn_cast<VariableExprAST>(access->getRHS());
+      if (!name)
+        return nullptr;
+      StructAST *parentStruct = getStructFor(access->getLHS());
+      if (!parentStruct)
+        return nullptr;
+
+      // Get the element within the struct corresponding to the name.
+      VarDeclExprAST *decl = nullptr;
+      for (auto &var : parentStruct->getVariables()) {
+        if (var->getName() == name->getName()) {
+          decl = var.get();
+          break;
+        }
+      }
+      if (!decl)
+        return nullptr;
+      structName = decl->getType().name;
+    }
+    if (structName.empty())
+      return nullptr;
+
+    // If the struct name was valid, check for an entry in the struct map.
+    auto structIt = structMap.find(structName);
+    if (structIt == structMap.end())
+      return nullptr;
+    return structIt->second.second;
+  }
+
+  /// Return the numeric member index of the given struct access expression.
+  llvm::Optional<size_t> getMemberIndex(BinaryExprAST &accessOp) {
+    assert(accessOp.getOp() == '.' && "expected access operation");
+
+    // Lookup the struct node for the LHS.
+    StructAST *structAST = getStructFor(accessOp.getLHS());
+    if (!structAST)
+      return llvm::None;
+
+    // Get the name from the RHS.
+    VariableExprAST *name = llvm::dyn_cast<VariableExprAST>(accessOp.getRHS());
+    if (!name)
+      return llvm::None;
+
+    auto structVars = structAST->getVariables();
+    auto it = llvm::find_if(structVars, [&](auto &var) {
+      return var->getName() == name->getName();
+    });
+    if (it == structVars.end())
+      return llvm::None;
+    return it - structVars.begin();
+  }
+
+  /// Emit a binary operation
+  mlir::Value *mlirGen(BinaryExprAST &binop) {
+    // First emit the operations for each side of the operation before emitting
+    // the operation itself. For example if the expression is `a + foo(a)`
+    // 1) First it will visit the LHS, which will return a reference to the
+    //    value holding `a`. This value should have been emitted at declaration
+    //    time and registered in the symbol table, so nothing would be
+    //    codegen'd. If the value is not in the symbol table, an error has been
+    //    emitted and nullptr is returned.
+    // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
+    //    and the result value is returned. If an error occurs we get a nullptr
+    //    and propagate.
+    //
+    mlir::Value *lhs = mlirGen(*binop.getLHS());
+    if (!lhs)
+      return nullptr;
+    auto location = loc(binop.loc());
+
+    // If this is an access operation, handle it immediately.
+    if (binop.getOp() == '.') {
+      llvm::Optional<size_t> accessIndex = getMemberIndex(binop);
+      if (!accessIndex) {
+        emitError(location, "invalid access into struct expression");
+        return nullptr;
+      }
+      return builder.create<StructAccessOp>(location, lhs, *accessIndex);
+    }
+
+    // Otherwise, this is a normal binary op.
+    mlir::Value *rhs = mlirGen(*binop.getRHS());
+    if (!rhs)
+      return nullptr;
+
+    // Derive the operation name from the binary operator. At the moment we only
+    // support '+' and '*'.
+    switch (binop.getOp()) {
+    case '+':
+      return builder.create<AddOp>(location, lhs, rhs);
+    case '*':
+      return builder.create<MulOp>(location, lhs, rhs);
+    }
+
+    emitError(location, "invalid binary operator '") << binop.getOp() << "'";
+    return nullptr;
+  }
+
+  /// This is a reference to a variable in an expression. The variable is
+  /// expected to have been declared and so should have a value in the symbol
+  /// table, otherwise emit an error and return nullptr.
+  mlir::Value *mlirGen(VariableExprAST &expr) {
+    if (auto *variable = symbolTable.lookup(expr.getName()).first)
+      return variable;
+
+    emitError(loc(expr.loc()), "error: unknown variable '")
+        << expr.getName() << "'";
+    return nullptr;
+  }
+
+  /// Emit a return operation. This will return failure if any generation fails.
+  mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
+    auto location = loc(ret.loc());
+
+    // 'return' takes an optional expression, handle that case here.
+    mlir::Value *expr = nullptr;
+    if (ret.getExpr().hasValue()) {
+      if (!(expr = mlirGen(*ret.getExpr().getValue())))
+        return mlir::failure();
+    }
+
+    // Otherwise, this return operation has zero operands.
+    builder.create<ReturnOp>(location, expr ? makeArrayRef(expr)
+                                            : ArrayRef<mlir::Value *>());
+    return mlir::success();
+  }
+
+  /// Emit a constant for a literal/constant array. It will be emitted as a
+  /// flattened array of data in an Attribute attached to a `toy.constant`
+  /// operation. See documentation on [Attributes](LangRef.md#attributes) for
+  /// more details. Here is an excerpt:
+  ///
+  ///   Attributes are the mechanism for specifying constant data in MLIR in
+  ///   places where a variable is never allowed [...]. They consist of a name
+  ///   and a concrete attribute value. The set of expected attributes, their
+  ///   structure, and their interpretation are all contextually dependent on
+  ///   what they are attached to.
+  ///
+  /// For example, the source level statement:
+  ///   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+  /// will be converted to:
+  ///   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
+  ///     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
+  ///      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
+  ///
+  mlir::DenseElementsAttr getConstantAttr(LiteralExprAST &lit) {
+    // The attribute is a vector with a floating point value per element
+    // (number) in the array, see `collectData()` below for more details.
+    std::vector<double> data;
+    data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
+                                 std::multiplies<int>()));
+    collectData(lit, data);
+
+    // The type of this attribute is tensor of 64-bit floating-point with the
+    // shape of the literal.
+    mlir::Type elementType = builder.getF64Type();
+    auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType);
+
+    // This is the actual attribute that holds the list of values for this
+    // tensor literal.
+    return mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
+  }
+  mlir::DenseElementsAttr getConstantAttr(NumberExprAST &lit) {
+    // The type of this attribute is tensor of 64-bit floating-point with no
+    // shape.
+    mlir::Type elementType = builder.getF64Type();
+    auto dataType = mlir::RankedTensorType::get({}, elementType);
+
+    // This is the actual attribute that holds the list of values for this
+    // tensor literal.
+    return mlir::DenseElementsAttr::get(dataType,
+                                        llvm::makeArrayRef(lit.getValue()));
+  }
+  /// Emit a constant for a struct literal. It will be emitted as an array of
+  /// other literals in an Attribute attached to a `toy.struct_constant`
+  /// operation. This function returns the generated constant, along with the
+  /// corresponding struct type.
+ std::pair<mlir::ArrayAttr, mlir::Type> + getConstantAttr(StructLiteralExprAST &lit) { + std::vector<mlir::Attribute> attrElements; + std::vector<mlir::Type> typeElements; + + for (auto &var : lit.getValues()) { + if (auto *number = llvm::dyn_cast<NumberExprAST>(var.get())) { + attrElements.push_back(getConstantAttr(*number)); + typeElements.push_back(getType(llvm::None)); + } else if (auto *lit = llvm::dyn_cast<LiteralExprAST>(var.get())) { + attrElements.push_back(getConstantAttr(*lit)); + typeElements.push_back(getType(llvm::None)); + } else { + auto *structLit = llvm::cast<StructLiteralExprAST>(var.get()); + auto attrTypePair = getConstantAttr(*structLit); + attrElements.push_back(attrTypePair.first); + typeElements.push_back(attrTypePair.second); + } + } + mlir::ArrayAttr dataAttr = builder.getArrayAttr(attrElements); + mlir::Type dataType = StructType::get(typeElements); + return std::make_pair(dataAttr, dataType); + } + + /// Emit an array literal. + mlir::Value *mlirGen(LiteralExprAST &lit) { + mlir::Type type = getType(lit.getDims()); + mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit); + + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute); + } + + /// Emit a struct literal. It will be emitted as an array of + /// other literals in an Attribute attached to a `toy.struct_constant` + /// operation. + mlir::Value *mlirGen(StructLiteralExprAST &lit) { + mlir::ArrayAttr dataAttr; + mlir::Type dataType; + std::tie(dataAttr, dataType) = getConstantAttr(lit); + + // Build the MLIR op `toy.struct_constant`. This invokes the + // `StructConstantOp::build` method. + return builder.create<StructConstantOp>(loc(lit.loc()), dataType, dataAttr); + } + + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. 
For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector<double> &data) { + if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + + assert(isa<NumberExprAST>(expr) && "expected literal or number expr"); + data.push_back(cast<NumberExprAST>(expr).getValue()); + } + + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. Other identifiers are assumed to be user-defined functions. + mlir::Value *mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); + auto location = loc(call.loc()); + + // Codegen the operands first. + SmallVector<mlir::Value *, 4> operands; + for (auto &expr : call.getArgs()) { + auto *arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + + // Builting calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create<TransposeOp>(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to ser-defined + // functions are mapped to a custom call that takes the callee name as an + // attribute. + auto calledFuncIt = functionMap.find(callee); + if (calledFuncIt == functionMap.end()) { + emitError(location) << "no defined function found for '" << callee << "'"; + return nullptr; + } + mlir::FuncOp calledFunc = calledFuncIt->second; + return builder.create<GenericCallOp>( + location, calledFunc.getType().getResult(0), + builder.getSymbolRefAttr(callee), operands); + } + + /// Emit a print expression. 
It emits specific operations for two builtins: + /// transpose(x) and print(x). + mlir::LogicalResult mlirGen(PrintExprAST &call) { + auto *arg = mlirGen(*call.getArg()); + if (!arg) + return mlir::failure(); + + builder.create<PrintOp>(loc(call.loc()), arg); + return mlir::success(); + } + + /// Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value *mlirGen(NumberExprAST &num) { + return builder.create<ConstantOp>(loc(num.loc()), num.getValue()); + } + + /// Dispatch codegen for the right expression subclass using RTTI. + mlir::Value *mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast<BinaryExprAST>(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast<VariableExprAST>(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast<LiteralExprAST>(expr)); + case toy::ExprAST::Expr_StructLiteral: + return mlirGen(cast<StructLiteralExprAST>(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast<CallExprAST>(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast<NumberExprAST>(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; + } + } + + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. + mlir::Value *mlirGen(VarDeclExprAST &vardecl) { + auto init = vardecl.getInitVal(); + if (!init) { + emitError(loc(vardecl.loc()), + "missing initializer in variable declaration"); + return nullptr; + } + + mlir::Value *value = mlirGen(*init); + if (!value) + return nullptr; + + // Handle the case where we are initializing a struct value. 
+ VarType varType = vardecl.getType(); + if (!varType.name.empty()) { + // Check that the initializer type is the same as the variable + // declaration. + mlir::Type type = getType(varType, vardecl.loc()); + if (!type) + return nullptr; + if (type != value->getType()) { + emitError(loc(vardecl.loc())) + << "struct type of initializer is different than the variable " + "declaration. Got " + << value->getType() << ", but expected " << type; + return nullptr; + } + + // Otherwise, we have the initializer value, but in case the variable was + // declared with specific shape, we emit a "reshape" operation. It will + // get optimized out later as needed. + } else if (!varType.shape.empty()) { + value = builder.create<ReshapeOp>(loc(vardecl.loc()), + getType(varType.shape), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl, value))) + return nullptr; + return value; + } + + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + SymbolTableScopeT var_scope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. + if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) { + if (!mlirGen(*vardecl)) + return mlir::failure(); + continue; + } + if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) + return mlirGen(*ret); + if (auto *print = dyn_cast<PrintExprAST>(expr.get())) { + if (mlir::failed(mlirGen(*print))) + return mlir::success(); + continue; + } + + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return mlir::failure(); + } + return mlir::success(); + } + + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef<int64_t> shape) { + // If the shape is empty, then this type is unranked. 
+ if (shape.empty()) + return mlir::UnrankedTensorType::get(builder.getF64Type()); + + // Otherwise, we use the given shape. + return mlir::RankedTensorType::get(shape, builder.getF64Type()); + } + + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above for non-struct types). + mlir::Type getType(const VarType &type, const Location &location) { + if (!type.name.empty()) { + auto it = structMap.find(type.name); + if (it == structMap.end()) { + emitError(loc(location)) + << "error: unknown struct type '" << type.name << "'"; + return nullptr; + } + return it->second.first; + } + + return getType(type.shape); + } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp new file mode 100644 index 00000000000..1f572015c39 --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp @@ -0,0 +1,113 @@ +//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
// This file implements a function-level pass performing intra-procedural
+// propagation (inference) of array shapes.
+ llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist; + f.walk([&](mlir::Operation *op) { + if (returnsDynamicShape(op)) + opWorklist.insert(op); + }); + + // Iterate on the operations in the worklist until all operations have been + // inferred or no change happened (fix point). + while (!opWorklist.empty()) { + // Find the next operation ready for inference, that is an operation + // with all operands already resolved (non-generic). + auto nextop = llvm::find_if(opWorklist, returnsDynamicShape); + if (nextop == opWorklist.end()) + break; + + Operation *op = *nextop; + opWorklist.erase(op); + + // Ask the operation to infer its output shapes. + LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + if (auto shapeOp = dyn_cast<ShapeInference>(op)) { + shapeOp.inferShapes(); + } else { + op->emitError("unable to infer shape of operation without shape " + "inference interface"); + return signalPassFailure(); + } + } + + // If the operation worklist isn't empty, this indicates a failure. + if (!opWorklist.empty()) { + f.emitError("Shape inference failed, ") + << opWorklist.size() << " operations couldn't be inferred\n"; + signalPassFailure(); + } + } + + /// A utility method that returns if the given operation has a dynamically + /// shaped result. + static bool returnsDynamicShape(Operation *op) { + return llvm::any_of(op->getResultTypes(), [](Type resultType) { + return !resultType.isa<RankedTensorType>(); + }); + } +}; +} // end anonymous namespace + +/// Create a Shape Inference pass. 
+std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() { + return std::make_unique<ShapeInferencePass>(); +} diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp new file mode 100644 index 00000000000..ebd4f5d1103 --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp @@ -0,0 +1,101 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a set of simple combiners for optimizing operations in +// the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "toy/Dialect.h" +#include <numeric> +using namespace mlir; +using namespace toy; + +namespace { +/// Include the patterns defined in the Declarative Rewrite framework. +#include "ToyCombine.inc" +} // end anonymous namespace + +/// Fold simple cast operations that return the same type as the input. +OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { + return mlir::impl::foldCastOp(*this); +} + +/// Fold constants. +OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); } + +/// Fold struct constants. 
+OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) { + return value(); +} + +/// Fold simple struct access operations that access into a constant. +OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) { + auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>(); + if (!structAttr) + return nullptr; + + size_t elementIndex = index().getZExtValue(); + return structAttr.getValue()[elementIndex]; +} + +/// This is an example of a c++ rewrite pattern for the TransposeOp. It +/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x) +struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> { + /// We register this pattern to match every toy.transpose in the IR. + /// The "benefit" is used by the framework to order the patterns and process + /// them in order of profitability. + SimplifyRedundantTranspose(mlir::MLIRContext *context) + : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {} + + /// This method attempts to match a pattern and rewrite it. The rewriter + /// argument is the orchestrator of the sequence of rewrites. The pattern is + /// expected to interact with it to perform any changes to the IR from here. + mlir::PatternMatchResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { + // Look through the input of the current transpose. + mlir::Value *transposeInput = op.getOperand(); + TransposeOp transposeInputOp = + llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp()); + + // If the input is defined by another Transpose, bingo! + if (!transposeInputOp) + return matchFailure(); + + // Use the rewriter to perform the replacement. + rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); + return matchSuccess(); + } +}; + +/// Register our patterns as "canonicalization" patterns on the TransposeOp so +/// that they can be picked up by the Canonicalization framework. 
+void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert<SimplifyRedundantTranspose>(context); +} + +/// Register our patterns as "canonicalization" patterns on the ReshapeOp so +/// that they can be picked up by the Canonicalization framework. +void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern, + FoldConstantReshapeOptPattern>(context); +} diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.td b/mlir/examples/toy/Ch7/mlir/ToyCombine.td new file mode 100644 index 00000000000..0a63861fa96 --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.td @@ -0,0 +1,73 @@ +//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Defines language-specific pattern match optimizations for Toy using +// Declarative Rewrite Rules (DRR) specified using TableGen records. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TOY_COMBINE +#define TOY_COMBINE + +#ifndef OP_BASE +include "toy/Ops.td" +#endif // OP_BASE + +/// Note: The DRR definition used for defining patterns is shown below: +/// +/// class Pattern< +/// dag sourcePattern, list<dag> resultPatterns, +/// list<dag> additionalConstraints = [], +/// dag benefitsAdded = (addBenefit 0) +/// >; + +//===----------------------------------------------------------------------===// +// Basic Pattern-Match and Rewrite +//===----------------------------------------------------------------------===// + +// Reshape(Reshape(x)) = Reshape(x) +def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), + (ReshapeOp $arg)>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite using Native Code Call +//===----------------------------------------------------------------------===// + +// Native Code Calls may be used for more complex transformations using inline +// C++ and C++ helper functions. + +// Reshape(Constant(x)) = x' +def ReshapeConstant : + NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">; +def FoldConstantReshapeOptPattern : Pat< + (ReshapeOp:$res (ConstantOp $arg)), + (ConstantOp (ReshapeConstant $arg, $res))>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite with Constraints +//===----------------------------------------------------------------------===// + +// DRR allows for constraint checking when the transformation is conditional +// on operand properties. 
+ +// Reshape(x) = x, where input and output shapes are identical +def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>; +def RedundantReshapeOptPattern : Pat< + (ReshapeOp:$res $arg), (replaceWithValue $arg), + [(TypesAreIdentical $res, $arg)]>; + +#endif // TOY_COMBINE diff --git a/mlir/examples/toy/Ch7/parser/AST.cpp b/mlir/examples/toy/Ch7/parser/AST.cpp new file mode 100644 index 00000000000..74056296f2e --- /dev/null +++ b/mlir/examples/toy/Ch7/parser/AST.cpp @@ -0,0 +1,285 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the AST dump for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. 
+class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(StructLiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + void dump(StructAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template <typename T> static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { +#define dispatch(CLASS) \ + if (CLASS *node = llvm::dyn_cast<CLASS>(expr)) \ + return dump(node); + dispatch(VarDeclExprAST); + dispatch(LiteralExprAST); + dispatch(StructLiteralExprAST); + dispatch(NumberExprAST); + dispatch(VariableExprAST); + dispatch(ReturnExprAST); + dispatch(BinaryExprAST); + dispatch(CallExprAST); + dispatch(PrintExprAST); + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n"; +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. 
+void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + if (auto *initVal = varDecl->getInitVal()) + dump(initVal); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recursively a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *litOrNum) { + // Inside a literal expression we can have either a number or another literal + if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast<LiteralExprAST>(litOrNum); + + // Print the dimension for this literal first + llvm::errs() << "<"; + mlir::interleaveComma(literal->getDims(), llvm::errs()); + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + mlir::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a struct literal. 
+void ASTDumper::dump(StructLiteralExprAST *node) { + INDENT(); + llvm::errs() << "Struct Literal: "; + for (auto &value : node->getValues()) + dump(value.get()); + indent(); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a variable reference (just a name). +void ASTDumper::dump(VariableExprAST *node) { + INDENT(); + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. +void ASTDumper::dump(ReturnExprAST *node) { + INDENT(); + llvm::errs() << "Return\n"; + if (node->getExpr().hasValue()) + return dump(*node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *node) { + INDENT(); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *node) { + INDENT(); + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *node) { + INDENT(); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(const VarType &type) { + llvm::errs() << "<"; + if (!type.name.empty()) + llvm::errs() << type.name; + else + mlir::interleaveComma(type.shape, llvm::errs()); + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. 
+void ASTDumper::dump(PrototypeAST *node) { + INDENT(); + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; + indent(); + llvm::errs() << "Params: ["; + mlir::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(node->getProto()); + dump(node->getBody()); +} + +/// Print a struct. +void ASTDumper::dump(StructAST *node) { + INDENT(); + llvm::errs() << "Struct: " << node->getName() << " " << loc(node) << "\n"; + + { + INDENT(); + llvm::errs() << "Variables: [\n"; + for (auto &variable : node->getVariables()) + dump(variable.get()); + indent(); + llvm::errs() << "]\n"; + } +} + +/// Print a module, actually loop over the functions and print them in sequence. +void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &record : *node) { + if (FunctionAST *function = llvm::dyn_cast<FunctionAST>(record.get())) + dump(function); + else if (StructAST *str = llvm::dyn_cast<StructAST>(record.get())) + dump(str); + else + llvm::errs() << "<unknown Record, kind " << record->getKind() << ">\n"; + } +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp new file mode 100644 index 00000000000..26b684cbe2a --- /dev/null +++ b/mlir/examples/toy/Ch7/toyc.cpp @@ -0,0 +1,282 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the entry point for the Toy compiler. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt<std::string> inputFilename(cl::Positional, + cl::desc("<input toy file>"), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} +static cl::opt<enum InputType> inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { + None, + DumpAST, + DumpMLIR, + 
DumpMLIRAffine, + DumpMLIRLLVM, + DumpLLVMIR, + RunJIT +}; +} +static cl::opt<enum Action> emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), + cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", + "output the MLIR dump after affine lowering")), + cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm", + "output the MLIR dump after llvm lowering")), + cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")), + cl::values( + clEnumValN(RunJIT, "jit", + "JIT the code and run it by invoking the main function"))); + +static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations")); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), filename); + Parser parser(lexer); + return parser.parseModule(); +} + +int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) { + // Handle '.toy' input to the compiler. + if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).endswith(".mlir")) { + auto moduleAST = parseInputFile(inputFilename); + module = mlirGen(context, *moduleAST); + return !module ? 1 : 0; + } + + // Otherwise, the input is '.mlir'. + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code EC = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return -1; + } + + // Parse the input mlir. 
+ llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + return 0; +} + +int loadAndProcessMLIR(mlir::MLIRContext &context, + mlir::OwningModuleRef &module) { + if (int error = loadMLIR(context, module)) + return error; + + mlir::PassManager pm(&context); + // Apply any generic pass manager command line options and run the pipeline. + applyPassManagerCLOptions(pm); + + // Check to see what granularity of MLIR we are compiling to. + bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; + bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM; + + if (enableOpt || isLoweringToAffine) { + // Inline all functions into main and then delete them. + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::toy::createDeadFunctionEliminationPass()); + + // Now that there is only one function, we can infer the shapes of each of + // the operations. + mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>(); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::toy::createShapeInferencePass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + } + + if (isLoweringToAffine) { + // Partially lower the toy dialect with a few cleanups afterwards. + pm.addPass(mlir::toy::createLowerToAffinePass()); + + mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>(); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + + // Add optimizations if enabled. + if (enableOpt) { + optPM.addPass(mlir::createLoopFusionPass()); + optPM.addPass(mlir::createMemRefDataFlowOptPass()); + } + } + + if (isLoweringToLLVM) { + // Finish lowering the toy IR to the LLVM dialect. 
+ pm.addPass(mlir::toy::createLowerToLLVMPass()); + } + + if (mlir::failed(pm.run(*module))) + return 4; + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int dumpLLVMIR(mlir::ModuleOp module) { + auto llvmModule = mlir::translateModuleToLLVMIR(module); + if (!llvmModule) { + llvm::errs() << "Failed to emit LLVM IR\n"; + return -1; + } + + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::ExecutionEngine::setupTargetTriple(llvmModule.get()); + + /// Optionally run an optimization pipeline over the llvm module. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + if (auto err = optPipeline(llvmModule.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + return -1; + } + llvm::errs() << *llvmModule << "\n"; + return 0; +} + +int runJit(mlir::ModuleOp module) { + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + // An optimization pipeline to use within the execution engine. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + + // Create an MLIR execution engine. The execution engine eagerly JIT-compiles + // the module. + auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline); + assert(maybeEngine && "failed to construct an execution engine"); + auto &engine = maybeEngine.get(); + + // Invoke the JIT-compiled function. 
+ auto invocationResult = engine->invoke("main"); + if (invocationResult) { + llvm::errs() << "JIT invocation failed\n"; + return -1; + } + + return 0; +} + +int main(int argc, char **argv) { + mlir::registerPassManagerCLOptions(); + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + if (emitAction == Action::DumpAST) + return dumpAST(); + + // If we aren't dumping the AST, then we are compiling with/to MLIR. + + // Register our Dialect with MLIR. + mlir::registerDialect<mlir::toy::ToyDialect>(); + + mlir::MLIRContext context; + mlir::OwningModuleRef module; + if (int error = loadAndProcessMLIR(context, module)) + return error; + + // If we aren't exporting to non-mlir, then we are done. + bool isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM; + if (isOutputingMLIR) { + module->dump(); + return 0; + } + + // Check to see if we are compiling to LLVM IR. + if (emitAction == Action::DumpLLVMIR) + return dumpLLVMIR(*module); + + // Otherwise, we must be running the jit. + if (emitAction == Action::RunJIT) + return runJit(*module); + + llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n"; + return -1; +} diff --git a/mlir/g3doc/Tutorials/Toy/Ch-1.md b/mlir/g3doc/Tutorials/Toy/Ch-1.md index fd08d4743c0..cec79a47866 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-1.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-1.md @@ -25,6 +25,9 @@ This tutorial is divided in the following chapters: - [Chapter #6](Ch-6.md): Lowering to LLVM and code generation. Here we'll target LLVM IR for code generation, and detail more of the lowering framework. +- [Chapter #7](Ch-7.md): Extending Toy: Adding support for a composite type. + We'll demonstrate how to add a custom type to MLIR, and how it fits in the + existing pipeline. ## The Language @@ -87,7 +90,7 @@ def main() { # Finally, calling into `multiply_transpose` with incompatible shape will # trigger a shape inference error. 
- var e = multiply_transpose(transpose(a), c); + var f = multiply_transpose(transpose(a), c); } ``` diff --git a/mlir/g3doc/Tutorials/Toy/Ch-7.md b/mlir/g3doc/Tutorials/Toy/Ch-7.md new file mode 100644 index 00000000000..3b9896f191d --- /dev/null +++ b/mlir/g3doc/Tutorials/Toy/Ch-7.md @@ -0,0 +1,538 @@ +# Chapter 7: Adding a Composite Type to Toy + +[TOC] + +In the [previous chapter](Ch-6.md), we demonstrated an end-to-end compilation +flow from our Toy front-end to LLVM IR. In this chapter, we will extend the Toy +language to support a new composite `struct` type. + +## Defining a `struct` in Toy + +The first thing we need to define is the interface of this type in our `toy` +source language. The general syntax of a `struct` type in toy is as follows: + +```toy +# A struct is defined by using the `struct` keyword followed by a name. +struct MyStruct { + # Inside of the struct is a list of variable declarations without initializers + # or shapes, which may also be other previously defined structs. + var a; + var b; +} +``` + +Structs may now be used in functions as variables or parameters by using the +name of the struct instead of `var`. The members of the struct are accessed via +a `.` access operator. Values of `struct` type may be initialized with a +composite initializer, or a comma-separated list of other initializers +surrounded by `{}`. An example is shown below: + +```toy +struct Struct { + var a; + var b; +} + +# User defined generic function may operate on struct types as well. +def multiply_transpose(Struct value) { + # We can access the elements of a struct via the '.' operator. + return transpose(value.a) * transpose(value.b); +} + +def main() { + # We initialize struct values using a composite initializer. + Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]}; + + # We pass these arguments to functions like we do with variables. 
+ var c = multiply_transpose(value); + print(c); +} +``` + +## Defining a `struct` in MLIR + +In MLIR, we will also need a representation for our struct types. MLIR does not +provide a type that does exactly what we need, so we will need to define our +own. We will simply define our `struct` as an unnamed container of a set of +element types. The name of the `struct` and its elements are only useful for the +AST of our `toy` compiler, so we don't need to encode it in the MLIR +representation. + +### Defining the Type Class + +#### Reserving a Range of Type Kinds + +Types in MLIR rely on having a unique `kind` value to ensure that casting checks +remain extremely +efficient([rationale](Rationale.md#reserving-dialect-type-kinds). For `toy`, +this means we need to explicitly reserve a static range of type `kind` values in +the symbol registry file +[DialectSymbolRegistry](https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/DialectSymbolRegistry.def). + +```c++ +DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect +DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect + +// The following ranges are reserved for experimenting with MLIR dialects in a +// private context without having to register them here. +DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_0) +``` + +These definitions will provide a range in the Type::Kind enum to use when +defining the derived types. + +```c++ +/// Create a local enumeration with all of the types that are defined by Toy. +namespace ToyTypes { +enum Types { + Struct = mlir::Type::FIRST_TOY_TYPE, +}; +} // end namespace ToyTypes +``` + +#### Defining the Type Class + +As mentioned in [chapter 2](Ch-2.md), [`Type`](../../LangRef.md#type-system) +objects in MLIR are value-typed and rely on having an internal storage object +that holds the actual data for the type. The `Type` class in itself acts as a +simple wrapper around an internal `TypeStorage` object that is uniqued within an +instance of an `MLIRContext`. 
When constructing a `Type`, we are internally just +constructing and uniquing an instance of a storage class. + +When defining a new `Type` that requires additional information than just the +`kind`, like our struct type for the element types, we will need to provide a +derived storage class. The `primitive` types that don't have any additional +data, like the [`index` type](../../LangRef.md#index-type), don't require a +storage class. + +##### Defining the Storage Class + +Type storage objects contain all of the data necessary to construct and unique a +type instance. Derived storage classes must inherit from the base +`mlir::TypeStorage` and provide a set of aliases and hooks that will be used by +the `MLIRContext` for uniquing. Below is the definition of the storage instance +for our `struct` type, with each of the necessary requirements detailed inline: + +```c++ +/// This class represents the internal storage of the Toy `StructType`. +struct StructTypeStorage : public mlir::TypeStorage { + /// The `KeyTy` is a required type that provides an interface for the storage + /// instance. This type will be used when uniquing an instance of the type + /// storage. For our struct type, we will unique each instance structurally on + /// the elements that it contains. + using KeyTy = llvm::ArrayRef<mlir::Type>; + + /// A constructor for the type storage instance. + StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes) + : elementTypes(elementTypes) {} + + /// Define the comparison function for the key type with the current storage + /// instance. This is used when constructing a new instance to ensure that we + /// haven't already uniqued an instance of the given key. + bool operator==(const KeyTy &key) const { return key == elementTypes; } + + /// Define a hash function for the key type. This is used when uniquing + /// instances of the storage. 
+ /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type + /// have hash functions available, so we could just omit this entirely. + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + /// Define a construction function for the key type from a set of parameters. + /// These parameters will be provided when constructing the storage instance + /// itself, see the `StructType::get` method further below. + /// Note: This method isn't necessary because KeyTy can be directly + /// constructed with the given parameters. + static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) { + return KeyTy(elementTypes); + } + + /// Define a construction method for creating a new instance of this storage. + /// This method takes an instance of a storage allocator, and an instance of a + /// `KeyTy`. The given allocator must be used for *all* necessary dynamic + /// allocations used to create the type storage and its internal. + static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator, + const KeyTy &key) { + // Copy the elements from the provided `KeyTy` into the allocator. + llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key); + + // Allocate the storage instance and construct it. + return new (allocator.allocate<StructTypeStorage>()) + StructTypeStorage(elementTypes); + } + + /// The following field contains the element types of the struct. + llvm::ArrayRef<mlir::Type> elementTypes; +}; +``` + +##### Defining the Type Class + +With the storage class defined, we can add the definition for the user visible +`StructType` class. This is the class that we will actually interface with. + +```c++ +/// This class defines the Toy struct type. It represents a collection of +/// element types. All derived types in MLIR must inherit from the CRTP class +/// 'Type::TypeBase'. 
It takes as template parameters the concrete type +/// (StructType), the base class to use (Type), and the storage class +/// (StructTypeStorage). +class StructType : public mlir::Type::TypeBase<StructType, mlir::Type, + StructTypeStorage> { +public: + /// Inherit some necessary constructors from 'TypeBase'. + using Base::Base; + + /// This static method is used to support type inquiry through isa, cast, + /// and dyn_cast. + static bool kindof(unsigned kind) { return kind == ToyTypes::Struct; } + + /// Create an instance of a `StructType` with the given element types. There + /// *must* be at least one element type. + static StructType get(llvm::ArrayRef<mlir::Type> elementTypes) { + assert(!elementTypes.empty() && "expected at least 1 element type"); + + // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance + // of this type. The first two parameters are the context to unique in and + // the kind of the type. The parameters after the type kind are forwarded to + // the storage instance. + mlir::MLIRContext *ctx = elementTypes.front().getContext(); + return Base::get(ctx, ToyTypes::Struct, elementTypes); + } + + /// Returns the element types of this struct type. + llvm::ArrayRef<mlir::Type> getElementTypes() { + // 'getImpl' returns a pointer to the internal storage instance. + return getImpl()->elementTypes; + } + + /// Returns the number of element type held by this struct. + size_t getNumElementTypes() { return getElementTypes().size(); } +}; +``` + +and we register this type in the `ToyDialect` constructor in a similar way to +how we did with operations: + +```c++ +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) + : mlir::Dialect(getDialectNamespace(), ctx) { + addTypes<StructType>(); +} +``` + +With this we can now use our `StructType` when generating MLIR from Toy. See +MLIRGen.cpp for more details. 
+ +### Parsing and Printing + +At this point we can use our `StructType` during MLIR generation and +transformation, but we can't output or parse `.mlir`. To support this we need to +add support for parsing and printing instances of the `StructType`. This support +can be added by overriding the `parseType` and `printType` methods on the +`ToyDialect`. + +```c++ +class ToyDialect : public mlir::Dialect { +public: + /// Parse an instance of a type registered to the toy dialect. + mlir::Type parseType(mlir::DialectAsmParser &parser) const override; + + /// Print an instance of a type registered to the toy dialect. + void printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const override; +}; +``` + +These methods take an instance of a high level parser or printer that allows for +easily implementing the necessary functionality. Before going into the +implementation, let's think about the syntax that we want for the `struct` type +in the printed IR. As described in the +[MLIR language reference](../../LangRef.md#dialect-types), dialect types are +generally represented as: `!` dialect-namespace `<` type-data `>`; With a pretty +form available under certain circumstances. The responsibility of our `Toy` +parser and printer is to provide the `type-data` bits. We will define our +`StructType` as having the following form: + +``` {.ebnf} + struct-type ::= `struct` `<` type (`,` type)* `>` +``` + +#### Parsing + +An implementation of the parser is shown below: + +```c++ +/// Parse an instance of a type registered to the toy dialect. +mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const { + // Parse a struct type in the following form: + // struct-type ::= `struct` `<` type (`,` type)* `>` + + // NOTE: All MLIR parser function return a ParseResult. 
This is a + // specialization of LogicalResult that auto-converts to a `true` boolean + // value on failure to allow for chaining, but may be used with explicit + // `mlir::failed/mlir::succeeded` as desired. + + // Parse: `struct` `<` + if (parser.parseKeyword("struct") || parser.parseLess()) + return Type(); + + // Parse the element types of the struct. + SmallVector<mlir::Type, 1> elementTypes; + do { + // Parse the current element type. + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + mlir::Type elementType; + if (parser.parseType(elementType)) + return nullptr; + + // Check that the type is either a TensorType or another StructType. + if (!elementType.isa<mlir::TensorType>() && + !elementType.isa<StructType>()) { + parser.emitError(typeLoc, "element type for a struct must either " + "be a TensorType or a StructType, got: ") + << elementType; + return Type(); + } + elementTypes.push_back(elementType); + + // Parse the optional: `,` + } while (succeeded(parser.parseOptionalComma())); + + // Parse: `>` + if (parser.parseGreater()) + return Type(); + return StructType::get(elementTypes); +} +``` + +#### Printing + +As implementation of the printer is shown below: + +```c++ +/// Print an instance of a type registered to the toy dialect. +void ToyDialect::printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const { + // Currently the only toy type is a struct type. + StructType structType = type.cast<StructType>(); + + // Print the struct type according to the parser format. 
+ printer << "struct<"; + mlir::interleaveComma(structType.getElementTypes(), printer); + printer << '>'; +} +``` + +Before moving on, let's look at a quick of example showcasing the functionality +we have now: + +```toy +struct Struct { + var a; + var b; +} + +def multiply_transpose(Struct value) { +} +``` + +Which generates the following: + +```mlir +module { + func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) { + "toy.return"() : () -> () + } +} +``` + +### Operating on `StructType` + +Now that the `struct` type has been defined, and we can roundtrip it through the +IR. The next step is to add support for using it within our operations. + +#### Updating Existing Operations + +A few of our existing operations will need to be updated to handle `StructType`. +The first step is to make the ODS framework aware of our Type, so that we can +use it in the operation definitions. A simple example is shown below: + +```td +// Provide a definition for the Toy StructType for use in ODS. This allows for +// using StructType in a similar way to Tensor or MemRef. +def Toy_StructType : + Type<CPred<"$_self.isa<StructType>()">, "Toy struct type">; + +// Provide a definition of the types that are used within the Toy dialect. +def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>; +``` + +We can then update our operations, like `ReturnOp` for example, to also accept +the `Toy_StructType`: + +```td +def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> { + ... + let arguments = (ins Variadic<Toy_Type>:$input); + ... +} +``` + +#### Adding New `Toy` Operations + +In addition to the existing operations, we will be adding a few new operations +that will provide more specific handling of `structs`. + +##### `toy.struct_constant` + +This new operation materializes a constant value for a struct. 
In our current +modeling we just use an [array attribute](../../LangRef.md#array-attribute) that +contains a set of constant values for each of the `struct` elements. + +```mlir + %0 = "toy.struct_constant"() { + value = [dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>] + } : () -> !toy.struct<tensor<*xf64>> +``` + +##### `toy.struct_access` + +This new operation materializes the Nth element of a `struct` value. + +```mlir + %0 = "toy.struct_constant"() { + value = [dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>] + } : () -> !toy.struct<tensor<*xf64>> + %1 = "toy.struct_access"(%0) {index = 0 : i64} : (!toy.struct<tensor<*xf64>>) -> tensor<*xf64> +``` + +With these operations, we can revisit our original example: + +```toy +struct Struct { + var a; + var b; +} + +# User defined generic function may operate on struct types as well. +def multiply_transpose(Struct value) { + # We can access the elements of a struct via the '.' operator. + return transpose(value.a) * transpose(value.b); +} + +def main() { + # We initialize struct values using a composite initializer. + Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]}; + + # We pass these arguments to functions like we do with variables. 
+ var c = multiply_transpose(value); + print(c); +} +``` + +and finally get a full MLIR module: + +```mlir +module { + func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> { + %0 = "toy.struct_access"(%arg0) {index = 0 : i64} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> + %1 = "toy.transpose"(%0) : (tensor<*xf64>) -> tensor<*xf64> + %2 = "toy.struct_access"(%arg0) {index = 1 : i64} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> + %3 = "toy.transpose"(%2) : (tensor<*xf64>) -> tensor<*xf64> + %4 = "toy.mul"(%1, %3) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> + "toy.return"(%4) : (tensor<*xf64>) -> () + } + func @main() { + %0 = "toy.struct_constant"() {value = [dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>, dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>]} : () -> !toy.struct<tensor<*xf64>, tensor<*xf64>> + %1 = "toy.generic_call"(%0) {callee = @multiply_transpose} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> + "toy.print"(%1) : (tensor<*xf64>) -> () + "toy.return"() : () -> () + } +} +``` + +#### Optimizing Operations on `StructType` + +Now that we have a few operations operating on `StructType`, we also have many +new constant folding opportunities. 
After inlining, the MLIR module in the
`ConstantOp` for `TensorType` and `StructConstant` for +`StructType`, we will need to provide an override for the dialect hook +`materializeConstant`. This allows for generic MLIR operations to create +constants for the `Toy` dialect when necessary. + +```c++ +mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, + mlir::Attribute value, + mlir::Type type, + mlir::Location loc) { + if (type.isa<StructType>()) + return builder.create<StructConstantOp>(loc, type, + value.cast<mlir::ArrayAttr>()); + return builder.create<ConstantOp>(loc, type, + value.cast<mlir::DenseElementsAttr>()); +} +``` + +With this we can now generate code that can be generated to LLVM without any +changes to our pipeline. + +```mlir +module { + func @main() { + %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> + %1 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64> + %2 = "toy.mul"(%1, %1) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64> + "toy.print"(%2) : (tensor<3x2xf64>) -> () + "toy.return"() : () -> () + } +} +``` + +You can build `toyc-ch7` and try yourself: `toyc-ch7 test/struct-codegen.toy +-emit=mlir`. More details can on defining custom types can be found in +[DefiningAttributesAndTypes](../../DefiningAttributesAndTypes.md). 
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index 9c11adc95db..95792548221 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -50,6 +50,7 @@ if(LLVM_BUILD_EXAMPLES) toyc-ch4 toyc-ch5 toyc-ch6 + toyc-ch7 ) endif() diff --git a/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir new file mode 100644 index 00000000000..3d08d0c1d80 --- /dev/null +++ b/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir @@ -0,0 +1,65 @@ +// RUN: toyc-ch7 %s -emit=mlir-affine 2>&1 | FileCheck %s +// RUN: toyc-ch7 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT + +func @main() { + %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> + %2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64> + %3 = "toy.mul"(%2, %2) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64> + "toy.print"(%3) : (tensor<3x2xf64>) -> () + "toy.return"() : () -> () +} + +// CHECK-LABEL: func @main() +// CHECK: [[VAL_0:%.*]] = constant 1.000000e+00 : f64 +// CHECK: [[VAL_1:%.*]] = constant 2.000000e+00 : f64 +// CHECK: [[VAL_2:%.*]] = constant 3.000000e+00 : f64 +// CHECK: [[VAL_3:%.*]] = constant 4.000000e+00 : f64 +// CHECK: [[VAL_4:%.*]] = constant 5.000000e+00 : f64 +// CHECK: [[VAL_5:%.*]] = constant 6.000000e+00 : f64 +// CHECK: [[VAL_6:%.*]] = alloc() : memref<3x2xf64> +// CHECK: [[VAL_7:%.*]] = alloc() : memref<3x2xf64> +// CHECK: [[VAL_8:%.*]] = alloc() : memref<2x3xf64> +// CHECK: affine.store [[VAL_0]], [[VAL_8]][0, 0] : memref<2x3xf64> +// CHECK: affine.store [[VAL_1]], [[VAL_8]][0, 1] : memref<2x3xf64> +// CHECK: affine.store [[VAL_2]], [[VAL_8]][0, 2] : memref<2x3xf64> +// CHECK: affine.store [[VAL_3]], [[VAL_8]][1, 0] : memref<2x3xf64> +// CHECK: affine.store [[VAL_4]], [[VAL_8]][1, 1] : memref<2x3xf64> +// CHECK: affine.store [[VAL_5]], [[VAL_8]][1, 2] : memref<2x3xf64> +// 
CHECK: affine.for [[VAL_9:%.*]] = 0 to 3 { +// CHECK: affine.for [[VAL_10:%.*]] = 0 to 2 { +// CHECK: [[VAL_11:%.*]] = affine.load [[VAL_8]]{{\[}}[[VAL_10]], [[VAL_9]]] : memref<2x3xf64> +// CHECK: affine.store [[VAL_11]], [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_10]]] : memref<3x2xf64> +// CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 { +// CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 { +// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> +// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> +// CHECK: [[VAL_16:%.*]] = mulf [[VAL_14]], [[VAL_15]] : f64 +// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> +// CHECK: "toy.print"([[VAL_6]]) : (memref<3x2xf64>) -> () +// CHECK: dealloc [[VAL_8]] : memref<2x3xf64> +// CHECK: dealloc [[VAL_7]] : memref<3x2xf64> +// CHECK: dealloc [[VAL_6]] : memref<3x2xf64> + +// OPT-LABEL: func @main() +// OPT: [[VAL_0:%.*]] = constant 1.000000e+00 : f64 +// OPT: [[VAL_1:%.*]] = constant 2.000000e+00 : f64 +// OPT: [[VAL_2:%.*]] = constant 3.000000e+00 : f64 +// OPT: [[VAL_3:%.*]] = constant 4.000000e+00 : f64 +// OPT: [[VAL_4:%.*]] = constant 5.000000e+00 : f64 +// OPT: [[VAL_5:%.*]] = constant 6.000000e+00 : f64 +// OPT: [[VAL_6:%.*]] = alloc() : memref<3x2xf64> +// OPT: [[VAL_7:%.*]] = alloc() : memref<2x3xf64> +// OPT: affine.store [[VAL_0]], [[VAL_7]][0, 0] : memref<2x3xf64> +// OPT: affine.store [[VAL_1]], [[VAL_7]][0, 1] : memref<2x3xf64> +// OPT: affine.store [[VAL_2]], [[VAL_7]][0, 2] : memref<2x3xf64> +// OPT: affine.store [[VAL_3]], [[VAL_7]][1, 0] : memref<2x3xf64> +// OPT: affine.store [[VAL_4]], [[VAL_7]][1, 1] : memref<2x3xf64> +// OPT: affine.store [[VAL_5]], [[VAL_7]][1, 2] : memref<2x3xf64> +// OPT: affine.for [[VAL_8:%.*]] = 0 to 3 { +// OPT: affine.for [[VAL_9:%.*]] = 0 to 2 { +// OPT: [[VAL_10:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_8]]] : memref<2x3xf64> +// OPT: [[VAL_11:%.*]] = mulf 
[[VAL_10]], [[VAL_10]] : f64 +// OPT: affine.store [[VAL_11]], [[VAL_6]]{{\[}}[[VAL_8]], [[VAL_9]]] : memref<3x2xf64> +// OPT: "toy.print"([[VAL_6]]) : (memref<3x2xf64>) -> () +// OPT: dealloc [[VAL_7]] : memref<2x3xf64> +// OPT: dealloc [[VAL_6]] : memref<3x2xf64> diff --git a/mlir/test/Examples/Toy/Ch7/ast.toy b/mlir/test/Examples/Toy/Ch7/ast.toy new file mode 100644 index 00000000000..b32410e58e7 --- /dev/null +++ b/mlir/test/Examples/Toy/Ch7/ast.toy @@ -0,0 +1,76 @@ +# RUN: toyc-ch7 %s -emit=ast 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments. +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + # Define a variable `a` with shape <2, 3>, initialized with the literal value. + # The shape is inferred from the supplied literal. + var a = [[1, 2, 3], [4, 5, 6]]; + # b is identical to a, the literal array is implicitly reshaped: defining new + # variables is the way to reshape arrays (element count in literal must match + # the size of specified shape). + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + + # This call will specialize `multiply_transpose` with <2, 3> for both + # arguments and deduce a return type of <2, 2> in initialization of `c`. + var c = multiply_transpose(a, b); + # A second call to `multiply_transpose` with <2, 3> for both arguments will + # reuse the previously specialized and inferred version and return `<2, 2>` + var d = multiply_transpose(b, a); + # A new call with `<2, 2>` for both dimension will trigger another + # specialization of `multiply_transpose`. + var e = multiply_transpose(b, c); + # Finally, calling into `multiply_transpose` with incompatible shape will + # trigger a shape inference error. 
+ var e = multiply_transpose(transpose(a), c); +} + + +# CHECK: Module: +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1' +# CHECK-NEXT: Params: [a, b] +# CHECK-NEXT: Block { +# CHECK-NEXT: Return +# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:25 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:10 +# CHECK-NEXT: var: a @{{.*}}ast.toy:5:20 +# CHECK-NEXT: ] +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:25 +# CHECK-NEXT: var: b @{{.*}}ast.toy:5:35 +# CHECK-NEXT: ] +# CHECK-NEXT: } // Block +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1' +# CHECK-NEXT: Params: [] +# CHECK-NEXT: Block { +# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3 +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11 +# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3 +# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17 +# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11 +# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30 +# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:25:30 +# CHECK-NEXT: var: c @{{.*}}ast.toy:25:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:28:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:28:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:28:40 +# CHECK-NEXT: ] +# CHECK-NEXT: var: c @{{.*}}ast.toy:28:44 +# CHECK-NEXT: ] + diff --git 
a/mlir/test/Examples/Toy/Ch7/codegen.toy b/mlir/test/Examples/Toy/Ch7/codegen.toy new file mode 100644 index 00000000000..e19500bd9ae --- /dev/null +++ b/mlir/test/Examples/Toy/Ch7/codegen.toy @@ -0,0 +1,31 @@ +# RUN: toyc-ch7 %s -emit=mlir 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + var c = multiply_transpose(a, b); + var d = multiply_transpose(b, a); + print(d); +} + +# CHECK-LABEL: func @multiply_transpose( +# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64> +# CHECK: [[VAL_2:%.*]] = "toy.transpose"([[VAL_0]]) : (tensor<*xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_3:%.*]] = "toy.transpose"([[VAL_1]]) : (tensor<*xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_4:%.*]] = "toy.mul"([[VAL_2]], [[VAL_3]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> +# CHECK-NEXT: "toy.return"([[VAL_4]]) : (tensor<*xf64>) -> () + +# CHECK-LABEL: func @main() +# CHECK-NEXT: [[VAL_5:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> +# CHECK-NEXT: [[VAL_6:%.*]] = "toy.reshape"([[VAL_5]]) : (tensor<2x3xf64>) -> tensor<2x3xf64> +# CHECK-NEXT: [[VAL_7:%.*]] = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64> +# CHECK-NEXT: [[VAL_8:%.*]] = "toy.reshape"([[VAL_7]]) : (tensor<6xf64>) -> tensor<2x3xf64> +# CHECK-NEXT: [[VAL_9:%.*]] = "toy.generic_call"([[VAL_6]], [[VAL_8]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_10:%.*]] = "toy.generic_call"([[VAL_8]], [[VAL_6]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> 
tensor<*xf64> +# CHECK-NEXT: "toy.print"([[VAL_10]]) : (tensor<*xf64>) -> () +# CHECK-NEXT: "toy.return"() : () -> () diff --git a/mlir/test/Examples/Toy/Ch7/invalid.mlir b/mlir/test/Examples/Toy/Ch7/invalid.mlir new file mode 100644 index 00000000000..5d35d95a5bf --- /dev/null +++ b/mlir/test/Examples/Toy/Ch7/invalid.mlir @@ -0,0 +1,9 @@ +// RUN: not toyc-ch7 %s -emit=mlir 2>&1 + +// The following IR is not "valid": +// - toy.print should not return a value. +// - toy.print should take an argument. +// - There should be a block terminator. +func @main() { + %0 = "toy.print"() : () -> tensor<2x3xf64> +} diff --git a/mlir/test/Examples/Toy/Ch7/llvm-lowering.mlir b/mlir/test/Examples/Toy/Ch7/llvm-lowering.mlir new file mode 100644 index 00000000000..0009bb507eb --- /dev/null +++ b/mlir/test/Examples/Toy/Ch7/llvm-lowering.mlir @@ -0,0 +1,23 @@ +// RUN: toyc-ch7 %s -emit=llvm -opt + +func @main() { + %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> + %2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64> + %3 = "toy.mul"(%2, %2) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64> + "toy.print"(%3) : (tensor<3x2xf64>) -> () + "toy.return"() : () -> () +} + +// CHECK-LABEL: define void @main() +// CHECK: @printf +// CHECK-SAME: 1.000000e+00 +// CHECK: @printf +// CHECK-SAME: 1.600000e+01 +// CHECK: @printf +// CHECK-SAME: 4.000000e+00 +// CHECK: @printf +// CHECK-SAME: 2.500000e+01 +// CHECK: @printf +// CHECK-SAME: 9.000000e+00 +// CHECK: @printf +// CHECK-SAME: 3.000000e+01 diff --git a/mlir/test/Examples/Toy/Ch7/scalar.toy b/mlir/test/Examples/Toy/Ch7/scalar.toy new file mode 100644 index 00000000000..f917ea622e5 --- /dev/null +++ b/mlir/test/Examples/Toy/Ch7/scalar.toy @@ -0,0 +1,14 @@ +# RUN: toyc-ch7 %s -emit=mlir 2>&1 | FileCheck %s + +def main() { + var a<2, 2> = 5.5; + print(a); +} + +# CHECK-LABEL: func @main() { +# CHECK-NEXT: 
%0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<f64>} : () -> tensor<f64> +# CHECK-NEXT: %1 = "toy.reshape"(%0) : (tensor<f64>) -> tensor<2x2xf64> +# CHECK-NEXT: "toy.print"(%1) : (tensor<2x2xf64>) -> () +# CHECK-NEXT: "toy.return"() : () -> () +# CHECK-NEXT: } + diff --git a/mlir/test/Examples/Toy/Ch7/shape_inference.mlir b/mlir/test/Examples/Toy/Ch7/shape_inference.mlir new file mode 100644 index 00000000000..b9355cf7e25 --- /dev/null +++ b/mlir/test/Examples/Toy/Ch7/shape_inference.mlir @@ -0,0 +1,30 @@ +// RUN: toyc-ch7 %s -emit=mlir -opt 2>&1 | FileCheck %s + +// Check the result of inlining+shape inference on an input module. + +func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> { + %0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64> + %1 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64> + %2 = "toy.mul"(%0, %1) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> + "toy.return"(%2) : (tensor<*xf64>) -> () +} +func @main() { + %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> + %1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64> + %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64> + %3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64> + %4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + %5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + "toy.print"(%5) : (tensor<*xf64>) -> () + "toy.return"() : () -> () +} + +// CHECK-NOT: func @multiply_transpose +// CHECK-NOT: tensor<*xf64> + +// CHECK-LABEL: func @main() +// CHECK: [[VAL_0:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], 
[4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> +// CHECK: [[VAL_1:%.*]] = "toy.transpose"([[VAL_0]]) : (tensor<2x3xf64>) -> tensor<3x2xf64> +// CHECK: [[VAL_2:%.*]] = "toy.mul"([[VAL_1]], [[VAL_1]]) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64> +// CHECK: "toy.print"([[VAL_2]]) : (tensor<3x2xf64>) -> () +// CHECK: "toy.return"() : () -> () diff --git a/mlir/test/Examples/Toy/Ch7/struct-ast.toy b/mlir/test/Examples/Toy/Ch7/struct-ast.toy new file mode 100644 index 00000000000..dee0d5b0efd --- /dev/null +++ b/mlir/test/Examples/Toy/Ch7/struct-ast.toy @@ -0,0 +1,61 @@ +# RUN: toyc-ch7 %s -emit=ast 2>&1 | FileCheck %s + +struct Struct { + var a; + var b; +} + +# User defined generic function may operate on struct types as well. +def multiply_transpose(Struct value) { + # We can access the elements of a struct via the '.' operator. + return transpose(value.a) * transpose(value.b); +} + +def main() { + # We initialize struct values using a composite initializer. + Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]}; + + # We pass these arguments to functions like we do with variables. + var c = multiply_transpose(value); + print(c); +} + +# CHECK: Module: +# CHECK-NEXT: Struct: Struct @{{.*}}struct-ast.toy:3:1 +# CHECK-NEXT: Variables: [ +# CHECK-NEXT: VarDecl a<> @{{.*}}struct-ast.toy:4:3 +# CHECK-NEXT: VarDecl b<> @{{.*}}struct-ast.toy:5:3 +# CHECK-NEXT: ] +# CHECK-NEXT:Function +# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}struct-ast.toy:9:1' +# CHECK-NEXT: Params: [value] +# CHECK-NEXT: Block { +# CHECK-NEXT: Return +# CHECK-NEXT: BinOp: * @{{.*}}struct-ast.toy:11:31 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}struct-ast.toy:11:10 +# CHECK-NEXT: BinOp: . @{{.*}}struct-ast.toy:11:26 +# CHECK-NEXT: var: value @{{.*}}struct-ast.toy:11:20 +# CHECK-NEXT: var: a @{{.*}}struct-ast.toy:11:26 +# CHECK-NEXT: ] +# CHECK-NEXT: Call 'transpose' [ @{{.*}}struct-ast.toy:11:31 +# CHECK-NEXT: BinOp: . 
@{{.*}}struct-ast.toy:11:47 +# CHECK-NEXT: var: value @{{.*}}struct-ast.toy:11:41 +# CHECK-NEXT: var: b @{{.*}}struct-ast.toy:11:47 +# CHECK-NEXT: ] +# CHECK-NEXT: } +# CHECK-NEXT:Function +# CHECK-NEXT: Proto 'main' @{{.*}}struct-ast.toy:14:1' +# CHECK-NEXT: Params: [] +# CHECK-NEXT: Block { +# CHECK-NEXT: VarDecl value<Struct> @{{.*}}struct-ast.toy:16:3 +# CHECK-NEXT: Struct Literal: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}struct-ast.toy:16:19 +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}struct-ast.toy:16:43 +# CHECK-NEXT: @{{.*}}struct-ast.toy:16:18 +# CHECK-NEXT: VarDecl c<> @{{.*}}struct-ast.toy:19:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}struct-ast.toy:19:11 +# CHECK-NEXT: var: value @{{.*}}struct-ast.toy:19:30 +# CHECK-NEXT: ] +# CHECK-NEXT: Print [ @{{.*}}struct-ast.toy:20:3 +# CHECK-NEXT: var: c @{{.*}}struct-ast.toy:20:9 +# CHECK-NEXT: ] +# CHECK-NEXT: }
\ No newline at end of file diff --git a/mlir/test/Examples/Toy/Ch7/struct-codegen.toy b/mlir/test/Examples/Toy/Ch7/struct-codegen.toy new file mode 100644 index 00000000000..66eaf8a1639 --- /dev/null +++ b/mlir/test/Examples/Toy/Ch7/struct-codegen.toy @@ -0,0 +1,44 @@ +# RUN: toyc-ch7 %s -emit=mlir 2>&1 +# RUN: toyc-ch7 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT + +struct Struct { + var a; + var b; +} + +# User defined generic function may operate on struct types as well. +def multiply_transpose(Struct value) { + # We can access the elements of a struct via the '.' operator. + return transpose(value.a) * transpose(value.b); +} + +def main() { + # We initialize struct values using a composite initializer. + Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]}; + + # We pass these arguments to functions like we do with variables. + var c = multiply_transpose(value); + print(c); +} + +# CHECK-LABEL: func @multiply_transpose( +# CHECK-SAME: [[VAL_0:%.*]]: !toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_1:%.*]] = "toy.struct_access"([[VAL_0]]) {index = 0 : i64} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_2:%.*]] = "toy.transpose"([[VAL_1]]) : (tensor<*xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_3:%.*]] = "toy.struct_access"([[VAL_0]]) {index = 1 : i64} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_4:%.*]] = "toy.transpose"([[VAL_3]]) : (tensor<*xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_5:%.*]] = "toy.mul"([[VAL_2]], [[VAL_4]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> +# CHECK-NEXT: "toy.return"([[VAL_5]]) : (tensor<*xf64>) -> () + +# CHECK-LABEL: func @main() +# CHECK-NEXT: [[VAL_6:%.*]] = "toy.struct_constant"() {value = [dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>, dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 
5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>]} : () -> !toy.struct<tensor<*xf64>, tensor<*xf64>> +# CHECK-NEXT: [[VAL_7:%.*]] = "toy.generic_call"([[VAL_6]]) {callee = @multiply_transpose} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> +# CHECK-NEXT: "toy.print"([[VAL_7]]) : (tensor<*xf64>) -> () +# CHECK-NEXT: "toy.return"() : () -> () + +# OPT-LABEL: func @main() +# OPT-NEXT: [[VAL_0:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> +# OPT-NEXT: [[VAL_1:%.*]] = "toy.transpose"([[VAL_0]]) : (tensor<2x3xf64>) -> tensor<3x2xf64> +# OPT-NEXT: [[VAL_2:%.*]] = "toy.mul"([[VAL_1]], [[VAL_1]]) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64> +# OPT-NEXT: "toy.print"([[VAL_2]]) : (tensor<3x2xf64>) -> () +# OPT-NEXT: "toy.return"() : () -> () diff --git a/mlir/test/Examples/Toy/Ch7/struct-opt.mlir b/mlir/test/Examples/Toy/Ch7/struct-opt.mlir new file mode 100644 index 00000000000..8c4b055b4bf --- /dev/null +++ b/mlir/test/Examples/Toy/Ch7/struct-opt.mlir @@ -0,0 +1,16 @@ +// RUN: toyc-ch7 %s -emit=mlir -opt 2>&1 | FileCheck %s +
+func @main() { + %0 = "toy.struct_constant"() { + value = [[dense<4.000000e+00> : tensor<2x2xf64>], dense<4.000000e+00> : tensor<2x2xf64>] + } : () -> !toy.struct<!toy.struct<tensor<*xf64>>, tensor<*xf64>> + %1 = "toy.struct_access"(%0) {index = 0 : i64} : (!toy.struct<!toy.struct<tensor<*xf64>>, tensor<*xf64>>) -> !toy.struct<tensor<*xf64>> + %2 = "toy.struct_access"(%1) {index = 0 : i64} : (!toy.struct<tensor<*xf64>>) -> tensor<*xf64> + "toy.print"(%2) : (tensor<*xf64>) -> () + "toy.return"() : () -> () +} + +// CHECK-LABEL: func @main +// CHECK-NEXT: %[[CST:.*]] = "toy.constant" +// CHECK-SAME: dense<4.0 +// CHECK-NEXT: "toy.print"(%[[CST]])

