summaryrefslogtreecommitdiffstats
path: root/mlir/examples
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-10-17 14:21:44 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-10-17 14:22:13 -0700
commit0372eb413f1cf917106562be35c633ae428f2973 (patch)
tree92b99a9250941a36c22f80c3dd2f81062764c59f /mlir/examples
parent5b03e692f6716ff4fcc4ae0887d1792562456f4b (diff)
downloadbcm5719-llvm-0372eb413f1cf917106562be35c633ae428f2973.tar.gz
bcm5719-llvm-0372eb413f1cf917106562be35c633ae428f2973.zip
Add Ch.6 of the Toy tutorial.
This chapter introduces the notion of a full conversion, and adds support for lowering down to the LLVM dialect, LLVM IR, and thus code generation. PiperOrigin-RevId: 275337786
Diffstat (limited to 'mlir/examples')
-rw-r--r--mlir/examples/toy/CMakeLists.txt1
-rw-r--r--mlir/examples/toy/Ch6/CMakeLists.txt51
-rw-r--r--mlir/examples/toy/Ch6/include/CMakeLists.txt1
-rw-r--r--mlir/examples/toy/Ch6/include/toy/AST.h253
-rw-r--r--mlir/examples/toy/Ch6/include/toy/CMakeLists.txt9
-rw-r--r--mlir/examples/toy/Ch6/include/toy/Dialect.h55
-rw-r--r--mlir/examples/toy/Ch6/include/toy/Lexer.h239
-rw-r--r--mlir/examples/toy/Ch6/include/toy/MLIRGen.h41
-rw-r--r--mlir/examples/toy/Ch6/include/toy/Ops.td275
-rw-r--r--mlir/examples/toy/Ch6/include/toy/Parser.h492
-rw-r--r--mlir/examples/toy/Ch6/include/toy/Passes.h45
-rw-r--r--mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.h37
-rw-r--r--mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.td38
-rw-r--r--mlir/examples/toy/Ch6/mlir/DeadFunctionEliminationPass.cpp68
-rw-r--r--mlir/examples/toy/Ch6/mlir/Dialect.cpp268
-rw-r--r--mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp318
-rw-r--r--mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp213
-rw-r--r--mlir/examples/toy/Ch6/mlir/MLIRGen.cpp467
-rw-r--r--mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp113
-rw-r--r--mlir/examples/toy/Ch6/mlir/ToyCombine.cpp83
-rw-r--r--mlir/examples/toy/Ch6/mlir/ToyCombine.td73
-rw-r--r--mlir/examples/toy/Ch6/parser/AST.cpp263
-rw-r--r--mlir/examples/toy/Ch6/toyc.cpp277
23 files changed, 3680 insertions, 0 deletions
diff --git a/mlir/examples/toy/CMakeLists.txt b/mlir/examples/toy/CMakeLists.txt
index 73d5caf3792..52a22014b4e 100644
--- a/mlir/examples/toy/CMakeLists.txt
+++ b/mlir/examples/toy/CMakeLists.txt
@@ -11,3 +11,4 @@ add_subdirectory(Ch2)
add_subdirectory(Ch3)
add_subdirectory(Ch4)
add_subdirectory(Ch5)
+add_subdirectory(Ch6)
diff --git a/mlir/examples/toy/Ch6/CMakeLists.txt b/mlir/examples/toy/Ch6/CMakeLists.txt
new file mode 100644
index 00000000000..d66c921b15e
--- /dev/null
+++ b/mlir/examples/toy/Ch6/CMakeLists.txt
@@ -0,0 +1,51 @@
+add_subdirectory(include)
+
+set(LLVM_LINK_COMPONENTS
+ Core
+ Support
+ )
+
+set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
+mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include")
+add_public_tablegen_target(ToyCh6CombineIncGen)
+
+add_toy_chapter(toyc-ch6
+ toyc.cpp
+ parser/AST.cpp
+ mlir/MLIRGen.cpp
+ mlir/Dialect.cpp
+ mlir/DeadFunctionEliminationPass.cpp
+ mlir/LowerToAffineLoops.cpp
+ mlir/LowerToLLVM.cpp
+ mlir/ShapeInferencePass.cpp
+ mlir/ToyCombine.cpp
+ )
+
+add_dependencies(toyc-ch6 ToyCh6ShapeInferenceInterfaceIncGen)
+add_dependencies(toyc-ch6 ToyCh6OpsIncGen)
+add_dependencies(toyc-ch6 ToyCh6CombineIncGen)
+add_dependencies(toyc-ch6 MLIRCallOpInterfacesIncGen)
+include_directories(include/)
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
+target_link_libraries(toyc-ch6
+ PRIVATE
+ MLIRAffineOps
+ MLIRAnalysis
+ MLIRExecutionEngine
+ MLIRIR
+ MLIRLLVMIR
+ MLIRLoopToStandard
+ MLIRParser
+ MLIRPass
+ MLIRStandardOps
+ MLIRStandardToLLVM
+ MLIRTargetLLVMIR
+ MLIRTransforms
+ )
+
+whole_archive_link(toyc-ch6
+ MLIRAffineOps
+ MLIRLLVMIR
+ MLIRStandardOps
+ )
diff --git a/mlir/examples/toy/Ch6/include/CMakeLists.txt b/mlir/examples/toy/Ch6/include/CMakeLists.txt
new file mode 100644
index 00000000000..37c89d0bae9
--- /dev/null
+++ b/mlir/examples/toy/Ch6/include/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(toy)
diff --git a/mlir/examples/toy/Ch6/include/toy/AST.h b/mlir/examples/toy/Ch6/include/toy/AST.h
new file mode 100644
index 00000000000..2ad3392c11a
--- /dev/null
+++ b/mlir/examples/toy/Ch6/include/toy/AST.h
@@ -0,0 +1,253 @@
+//===- AST.h - Node definition for the Toy AST ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the AST for the Toy language. It is optimized for
+// simplicity, not efficiency. The AST forms a tree structure where each node
+// references its children using std::unique_ptr<>.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_AST_H_
+#define MLIR_TUTORIAL_TOY_AST_H_
+
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <vector>
+
+namespace toy {
+
+/// A variable type with shape information.
+struct VarType {
+ std::vector<int64_t> shape;
+};
+
+/// Base class for all expression nodes.
+class ExprAST {
+public:
+ enum ExprASTKind {
+ Expr_VarDecl,
+ Expr_Return,
+ Expr_Num,
+ Expr_Literal,
+ Expr_Var,
+ Expr_BinOp,
+ Expr_Call,
+ Expr_Print,
+ };
+
+ ExprAST(ExprASTKind kind, Location location)
+ : kind(kind), location(location) {}
+
+ virtual ~ExprAST() = default;
+
+ ExprASTKind getKind() const { return kind; }
+
+ const Location &loc() { return location; }
+
+private:
+ const ExprASTKind kind;
+ Location location;
+};
+
+/// A block-list of expressions.
+using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;
+
+/// Expression class for numeric literals like "1.0".
+class NumberExprAST : public ExprAST {
+ double Val;
+
+public:
+ NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
+
+ double getValue() { return Val; }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
+};
+
+/// Expression class for a literal value.
+class LiteralExprAST : public ExprAST {
+ std::vector<std::unique_ptr<ExprAST>> values;
+ std::vector<int64_t> dims;
+
+public:
+ LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values,
+ std::vector<int64_t> dims)
+ : ExprAST(Expr_Literal, loc), values(std::move(values)),
+ dims(std::move(dims)) {}
+
+ std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
+ std::vector<int64_t> &getDims() { return dims; }
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
+};
+
+/// Expression class for referencing a variable, like "a".
+class VariableExprAST : public ExprAST {
+ std::string name;
+
+public:
+ VariableExprAST(Location loc, const std::string &name)
+ : ExprAST(Expr_Var, loc), name(name) {}
+
+ llvm::StringRef getName() { return name; }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
+};
+
+/// Expression class for defining a variable.
+class VarDeclExprAST : public ExprAST {
+ std::string name;
+ VarType type;
+ std::unique_ptr<ExprAST> initVal;
+
+public:
+ VarDeclExprAST(Location loc, const std::string &name, VarType type,
+ std::unique_ptr<ExprAST> initVal)
+ : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
+ initVal(std::move(initVal)) {}
+
+ llvm::StringRef getName() { return name; }
+ ExprAST *getInitVal() { return initVal.get(); }
+ VarType &getType() { return type; }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
+};
+
+/// Expression class for a return operator.
+class ReturnExprAST : public ExprAST {
+ llvm::Optional<std::unique_ptr<ExprAST>> expr;
+
+public:
+ ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr)
+ : ExprAST(Expr_Return, loc), expr(std::move(expr)) {}
+
+ llvm::Optional<ExprAST *> getExpr() {
+ if (expr.hasValue())
+ return expr->get();
+ return llvm::NoneType();
+ }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
+};
+
+/// Expression class for a binary operator.
+class BinaryExprAST : public ExprAST {
+ char Op;
+ std::unique_ptr<ExprAST> LHS, RHS;
+
+public:
+ char getOp() { return Op; }
+ ExprAST *getLHS() { return LHS.get(); }
+ ExprAST *getRHS() { return RHS.get(); }
+
+ BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
+ std::unique_ptr<ExprAST> RHS)
+ : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
+ RHS(std::move(RHS)) {}
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
+};
+
+/// Expression class for function calls.
+class CallExprAST : public ExprAST {
+ std::string Callee;
+ std::vector<std::unique_ptr<ExprAST>> Args;
+
+public:
+ CallExprAST(Location loc, const std::string &Callee,
+ std::vector<std::unique_ptr<ExprAST>> Args)
+ : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
+
+ llvm::StringRef getCallee() { return Callee; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
+};
+
+/// Expression class for builtin print calls.
+class PrintExprAST : public ExprAST {
+ std::unique_ptr<ExprAST> Arg;
+
+public:
+ PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
+ : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
+
+ ExprAST *getArg() { return Arg.get(); }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
+};
+
+/// This class represents the "prototype" for a function, which captures its
+/// name, and its argument names (thus implicitly the number of arguments the
+/// function takes).
+class PrototypeAST {
+ Location location;
+ std::string name;
+ std::vector<std::unique_ptr<VariableExprAST>> args;
+
+public:
+ PrototypeAST(Location location, const std::string &name,
+ std::vector<std::unique_ptr<VariableExprAST>> args)
+ : location(location), name(name), args(std::move(args)) {}
+
+ const Location &loc() { return location; }
+ const std::string &getName() const { return name; }
+ const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
+ return args;
+ }
+};
+
+/// This class represents a function definition itself.
+class FunctionAST {
+ std::unique_ptr<PrototypeAST> Proto;
+ std::unique_ptr<ExprASTList> Body;
+
+public:
+ FunctionAST(std::unique_ptr<PrototypeAST> Proto,
+ std::unique_ptr<ExprASTList> Body)
+ : Proto(std::move(Proto)), Body(std::move(Body)) {}
+ PrototypeAST *getProto() { return Proto.get(); }
+ ExprASTList *getBody() { return Body.get(); }
+};
+
+/// This class represents a list of functions to be processed together
+class ModuleAST {
+ std::vector<FunctionAST> functions;
+
+public:
+ ModuleAST(std::vector<FunctionAST> functions)
+ : functions(std::move(functions)) {}
+
+ auto begin() -> decltype(functions.begin()) { return functions.begin(); }
+ auto end() -> decltype(functions.end()) { return functions.end(); }
+};
+
+void dump(ModuleAST &);
+
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_AST_H_
diff --git a/mlir/examples/toy/Ch6/include/toy/CMakeLists.txt b/mlir/examples/toy/Ch6/include/toy/CMakeLists.txt
new file mode 100644
index 00000000000..aecf11fab6c
--- /dev/null
+++ b/mlir/examples/toy/Ch6/include/toy/CMakeLists.txt
@@ -0,0 +1,9 @@
+set(LLVM_TARGET_DEFINITIONS Ops.td)
+mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
+mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
+add_public_tablegen_target(ToyCh6OpsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
+mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(ToyCh6ShapeInferenceInterfaceIncGen)
diff --git a/mlir/examples/toy/Ch6/include/toy/Dialect.h b/mlir/examples/toy/Ch6/include/toy/Dialect.h
new file mode 100644
index 00000000000..556ae972b84
--- /dev/null
+++ b/mlir/examples/toy/Ch6/include/toy/Dialect.h
@@ -0,0 +1,55 @@
+//===- Dialect.h - Dialect definition for the Toy IR ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the IR Dialect for the Toy language.
+// See g3doc/Tutorials/Toy/Ch-2.md for more information.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
+#define MLIR_TUTORIAL_TOY_DIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/StandardTypes.h"
+#include "toy/ShapeInferenceInterface.h"
+
+namespace mlir {
+namespace toy {
+
+/// This is the definition of the Toy dialect. A dialect inherits from
+/// mlir::Dialect and registers custom attributes, operations, and types (in its
+/// constructor). It can also override some general behavior exposed via virtual
+/// methods.
+class ToyDialect : public mlir::Dialect {
+public:
+ explicit ToyDialect(mlir::MLIRContext *ctx);
+
+ /// Provide a utility accessor to the dialect namespace. This is used by
+ /// several utilities for casting between dialects.
+ static llvm::StringRef getDialectNamespace() { return "toy"; }
+};
+
+/// Include the auto-generated header file containing the declarations of the
+/// toy operations.
+#define GET_OP_CLASSES
+#include "toy/Ops.h.inc"
+
+} // end namespace toy
+} // end namespace mlir
+
+#endif // MLIR_TUTORIAL_TOY_DIALECT_H_
diff --git a/mlir/examples/toy/Ch6/include/toy/Lexer.h b/mlir/examples/toy/Ch6/include/toy/Lexer.h
new file mode 100644
index 00000000000..21f92614912
--- /dev/null
+++ b/mlir/examples/toy/Ch6/include/toy/Lexer.h
@@ -0,0 +1,239 @@
+//===- Lexer.h - Lexer for the Toy language -------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple Lexer for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_LEXER_H_
+#define MLIR_TUTORIAL_TOY_LEXER_H_
+
+#include "llvm/ADT/StringRef.h"
+
+#include <memory>
+#include <string>
+
+namespace toy {
+
+/// Structure definition a location in a file.
+struct Location {
+ std::shared_ptr<std::string> file; ///< filename.
+ int line; ///< line number.
+ int col; ///< column number.
+};
+
+// List of Token returned by the lexer.
+enum Token : int {
+ tok_semicolon = ';',
+ tok_parenthese_open = '(',
+ tok_parenthese_close = ')',
+ tok_bracket_open = '{',
+ tok_bracket_close = '}',
+ tok_sbracket_open = '[',
+ tok_sbracket_close = ']',
+
+ tok_eof = -1,
+
+ // commands
+ tok_return = -2,
+ tok_var = -3,
+ tok_def = -4,
+
+ // primary
+ tok_identifier = -5,
+ tok_number = -6,
+};
+
+/// The Lexer is an abstract base class providing all the facilities that the
+/// Parser expects. It goes through the stream one token at a time and keeps
+/// track of the location in the file for debugging purpose.
+/// It relies on a subclass to provide a `readNextLine()` method. The subclass
+/// can proceed by reading the next line from the standard input or from a
+/// memory mapped file.
+class Lexer {
+public:
+ /// Create a lexer for the given filename. The filename is kept only for
+ /// debugging purpose (attaching a location to a Token).
+ Lexer(std::string filename)
+ : lastLocation(
+ {std::make_shared<std::string>(std::move(filename)), 0, 0}) {}
+ virtual ~Lexer() = default;
+
+ /// Look at the current token in the stream.
+ Token getCurToken() { return curTok; }
+
+ /// Move to the next token in the stream and return it.
+ Token getNextToken() { return curTok = getTok(); }
+
+ /// Move to the next token in the stream, asserting on the current token
+ /// matching the expectation.
+ void consume(Token tok) {
+ assert(tok == curTok && "consume Token mismatch expectation");
+ getNextToken();
+ }
+
+ /// Return the current identifier (prereq: getCurToken() == tok_identifier)
+ llvm::StringRef getId() {
+ assert(curTok == tok_identifier);
+ return IdentifierStr;
+ }
+
+ /// Return the current number (prereq: getCurToken() == tok_number)
+ double getValue() {
+ assert(curTok == tok_number);
+ return NumVal;
+ }
+
+ /// Return the location for the beginning of the current token.
+ Location getLastLocation() { return lastLocation; }
+
+ // Return the current line in the file.
+ int getLine() { return curLineNum; }
+
+ // Return the current column in the file.
+ int getCol() { return curCol; }
+
+private:
+ /// Delegate to a derived class fetching the next line. Returns an empty
+ /// string to signal end of file (EOF). Lines are expected to always finish
+ /// with "\n"
+ virtual llvm::StringRef readNextLine() = 0;
+
+ /// Return the next character from the stream. This manages the buffer for the
+ /// current line and request the next line buffer to the derived class as
+ /// needed.
+ int getNextChar() {
+ // The current line buffer should not be empty unless it is the end of file.
+ if (curLineBuffer.empty())
+ return EOF;
+ ++curCol;
+ auto nextchar = curLineBuffer.front();
+ curLineBuffer = curLineBuffer.drop_front();
+ if (curLineBuffer.empty())
+ curLineBuffer = readNextLine();
+ if (nextchar == '\n') {
+ ++curLineNum;
+ curCol = 0;
+ }
+ return nextchar;
+ }
+
+ /// Return the next token from standard input.
+ Token getTok() {
+ // Skip any whitespace.
+ while (isspace(LastChar))
+ LastChar = Token(getNextChar());
+
+ // Save the current location before reading the token characters.
+ lastLocation.line = curLineNum;
+ lastLocation.col = curCol;
+
+ if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
+ IdentifierStr = (char)LastChar;
+ while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
+ IdentifierStr += (char)LastChar;
+
+ if (IdentifierStr == "return")
+ return tok_return;
+ if (IdentifierStr == "def")
+ return tok_def;
+ if (IdentifierStr == "var")
+ return tok_var;
+ return tok_identifier;
+ }
+
+ if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
+ std::string NumStr;
+ do {
+ NumStr += LastChar;
+ LastChar = Token(getNextChar());
+ } while (isdigit(LastChar) || LastChar == '.');
+
+ NumVal = strtod(NumStr.c_str(), nullptr);
+ return tok_number;
+ }
+
+ if (LastChar == '#') {
+ // Comment until end of line.
+ do
+ LastChar = Token(getNextChar());
+ while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
+
+ if (LastChar != EOF)
+ return getTok();
+ }
+
+ // Check for end of file. Don't eat the EOF.
+ if (LastChar == EOF)
+ return tok_eof;
+
+ // Otherwise, just return the character as its ascii value.
+ Token ThisChar = Token(LastChar);
+ LastChar = Token(getNextChar());
+ return ThisChar;
+ }
+
+ /// The last token read from the input.
+ Token curTok = tok_eof;
+
+ /// Location for `curTok`.
+ Location lastLocation;
+
+ /// If the current Token is an identifier, this string contains the value.
+ std::string IdentifierStr;
+
+ /// If the current Token is a number, this contains the value.
+ double NumVal = 0;
+
+ /// The last value returned by getNextChar(). We need to keep it around as we
+ /// always need to read ahead one character to decide when to end a token and
+ /// we can't put it back in the stream after reading from it.
+ Token LastChar = Token(' ');
+
+ /// Keep track of the current line number in the input stream
+ int curLineNum = 0;
+
+ /// Keep track of the current column number in the input stream
+ int curCol = 0;
+
+ /// Buffer supplied by the derived class on calls to `readNextLine()`
+ llvm::StringRef curLineBuffer = "\n";
+};
+
+/// A lexer implementation operating on a buffer in memory.
+class LexerBuffer final : public Lexer {
+public:
+ LexerBuffer(const char *begin, const char *end, std::string filename)
+ : Lexer(std::move(filename)), current(begin), end(end) {}
+
+private:
+ /// Provide one line at a time to the Lexer, return an empty string when
+ /// reaching the end of the buffer.
+ llvm::StringRef readNextLine() override {
+ auto *begin = current;
+ while (current <= end && *current && *current != '\n')
+ ++current;
+ if (current <= end && *current)
+ ++current;
+ llvm::StringRef result{begin, static_cast<size_t>(current - begin)};
+ return result;
+ }
+ const char *current, *end;
+};
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_LEXER_H_
diff --git a/mlir/examples/toy/Ch6/include/toy/MLIRGen.h b/mlir/examples/toy/Ch6/include/toy/MLIRGen.h
new file mode 100644
index 00000000000..287f432c847
--- /dev/null
+++ b/mlir/examples/toy/Ch6/include/toy/MLIRGen.h
@@ -0,0 +1,41 @@
+//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares a simple interface to perform IR generation targeting MLIR
+// from a Module AST for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_
+#define MLIR_TUTORIAL_TOY_MLIRGEN_H_
+
+#include <memory>
+
+namespace mlir {
+class MLIRContext;
+class OwningModuleRef;
+} // namespace mlir
+
+namespace toy {
+class ModuleAST;
+
+/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
+/// or nullptr on failure.
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td
new file mode 100644
index 00000000000..0eb30a7e022
--- /dev/null
+++ b/mlir/examples/toy/Ch6/include/toy/Ops.td
@@ -0,0 +1,275 @@
+//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines the operations of the Toy dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef TOY_OPS
+#else
+#define TOY_OPS
+
+#ifdef MLIR_CALLINTERFACES
+#else
+include "mlir/Analysis/CallInterfaces.td"
+#endif // MLIR_CALLINTERFACES
+
+#ifdef SHAPE_INFERENCE_INTERFACE
+#else
+include "toy/ShapeInferenceInterface.td"
+#endif // SHAPE_INFERENCE_INTERFACE
+
+// Provide a definition of the 'toy' dialect in the ODS framework so that we
+// can define our operations.
+def Toy_Dialect : Dialect {
+ let name = "toy";
+ let cppNamespace = "toy";
+}
+
+// Base class for toy dialect operations. This operation inherits from the base
+// `Op` class in OpBase.td, and provides:
+// * The parent dialect of the operation.
+// * The mnemonic for the operation, or the name without the dialect prefix.
+// * A list of traits for the operation.
+class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<Toy_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// Toy Operations
+//===----------------------------------------------------------------------===//
+
+// We define a toy operation by inheriting from our base 'Toy_Op' class above.
+// Here we provide the mnemonic and a list of traits for the operation. The
+// constant operation is marked as 'NoSideEffect' as it is a pure operation
+// and may be removed if dead.
+def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
+ // Provide a summary and description for this operation. This can be used to
+  // auto-generate documentation of the operations within our dialect.
+ let summary = "constant";
+ let description = [{
+ Constant operation turns a literal into an SSA value. The data is attached
+ to the operation as an attribute. For example:
+
+ ```mlir
+ %0 = "toy.constant"()
+ { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
+ : () -> tensor<2x3xf64>
+ ```
+ }];
+
+ // The constant operation takes an attribute as the only input.
+ let arguments = (ins F64ElementsAttr:$value);
+
+ // The constant operation returns a single value of TensorType.
+ let results = (outs F64Tensor);
+
+  // Add custom build methods for the constant operation. These methods populate
+ // the `state` that MLIR uses to create operations, i.e. these are used when
+ // using `builder.create<ConstantOp>(...)`.
+ let builders = [
+ // Build a constant with a given constant tensor value.
+ OpBuilder<"Builder *builder, OperationState &result, "
+ "DenseElementsAttr value", [{
+ build(builder, result, value.getType(), value);
+ }]>,
+
+ // Build a constant with a given constant floating-point value.
+ OpBuilder<"Builder *builder, OperationState &result, double value", [{
+ buildConstantOp(builder, result, value);
+ }]>
+ ];
+
+ // Invoke a static verify method to verify this constant operation.
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def AddOp : Toy_Op<"add",
+ [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+ let summary = "element-wise addition operation";
+ let description = [{
+ The "add" operation performs element-wise addition between two tensors.
+ The shapes of the tensor operands are expected to match.
+ }];
+
+ let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+ let results = (outs F64Tensor);
+
+  // Allow building an AddOp from the two input operands.
+ let builders = [
+ OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
+ buildAddOp(b, result, lhs, rhs);
+ }]
+ >];
+}
+
+def CastOp : Toy_Op<"cast",
+ [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
+ SameOperandsAndResultShape]> {
+ let summary = "shape cast operation";
+ let description = [{
+ The "cast" operation converts a tensor from one type to an equivalent type
+ without changing any data elements. The source and destination types
+ must both be tensor types with the same element type. If both are ranked
+ then the rank should be the same and static dimensions should match. The
+ operation is invalid if converting to a mismatching constant dimension.
+ }];
+
+ let arguments = (ins F64Tensor:$input);
+ let results = (outs F64Tensor:$output);
+
+ // Set the folder bit so that we can fold redundant cast operations.
+ let hasFolder = 1;
+}
+
+def GenericCallOp : Toy_Op<"generic_call",
+ [DeclareOpInterfaceMethods<CallOpInterface>]> {
+ let summary = "generic call operation";
+ let description = [{
+ Generic calls represent calls to a user defined function that needs to
+ be specialized for the shape of its arguments. The callee name is attached
+ as a symbol reference via an attribute. The arguments list must match the
+ arguments expected by the callee. For example:
+
+ ```mlir
+ %4 = "toy.generic_call"(%1, %3) {callee = @my_func}
+ : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
+ ```
+
+ This is only valid if a function named "my_func" exists and takes two
+ arguments.
+ }];
+
+ // The generic call operation takes a symbol reference attribute as the
+ // callee, and inputs for the call.
+ let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+
+ // The generic call operation returns a single value of TensorType.
+ let results = (outs F64Tensor);
+
+ // Add custom build methods for the generic call operation.
+ let builders = [
+ // Build a constant with a given constant tensor value.
+ OpBuilder<"Builder *builder, OperationState &result, "
+ "StringRef callee, ArrayRef<Value *> arguments", [{
+ buildGenericCallOp(builder, result, callee, arguments);
+ }]>
+ ];
+}
+
+def MulOp : Toy_Op<"mul",
+ [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+ let summary = "element-wise multiplication operation";
+ let description = [{
+ The "mul" operation performs element-wise multiplication between two
+ tensors. The shapes of the tensor operands are expected to match.
+ }];
+
+ let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+ let results = (outs F64Tensor);
+
+  // Allow building a MulOp from the two input operands.
+ let builders = [
+ OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
+ buildMulOp(b, result, lhs, rhs);
+ }]
+ >];
+}
+
+def PrintOp : Toy_Op<"print"> {
+ let summary = "print operation";
+ let description = [{
+ The "print" builtin operation prints a given input tensor, and produces
+ no results.
+ }];
+
+ // The print operation takes an input tensor to print.
+ // We also allow a F64MemRef to enable interop during partial lowering.
+ let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
+}
+
+def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> {
+ let summary = "tensor reshape operation";
+ let description = [{
+ Reshape operation is transforming its input tensor into a new tensor with
+ the same number of elements but different shapes. For example:
+
+ ```mlir
+ %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64>
+ ```
+ }];
+
+ let arguments = (ins F64Tensor:$input);
+ let hasCanonicalizer = 1;
+
+ // We expect that the reshape operation returns a statically shaped tensor.
+ let results = (outs StaticShapeTensorOf<[F64]>);
+}
+
+def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
+ let summary = "return operation";
+ let description = [{
+ The "return" operation represents a return operation within a function.
+ The operation takes an optional tensor operand and produces no results.
+ The operand type must match the signature of the function that contains
+ the operation. For example:
+
+ ```mlir
+ func @foo() -> tensor<2xf64> {
+ ...
+ toy.return %0 : tensor<2xf64>
+ }
+ ```
+ }];
+
+ // The return operation takes an optional input operand to return. This
+ // value must match the return type of the enclosing function.
+ let arguments = (ins Variadic<F64Tensor>:$input);
+
+ // Allow building a ReturnOp with no return operand.
+ let builders = [OpBuilder<
+ "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
+ >];
+
+ // Provide extra utility definitions on the c++ operation class definition.
+ let extraClassDeclaration = [{
+ bool hasOperand() { return getNumOperands() != 0; }
+ }];
+
+ // Invoke a static verify method to verify this return operation.
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def TransposeOp : Toy_Op<"transpose",
+ [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+ let summary = "transpose operation";
+
+ let arguments = (ins F64Tensor:$input);
+ let results = (outs F64Tensor);
+ let hasCanonicalizer = 1;
+
+  // Allow building a TransposeOp from the input operand.
+ let builders = [
+ OpBuilder<"Builder *b, OperationState &result, Value *input", [{
+ buildTransposeOp(b, result, input);
+ }]
+ >];
+
+ // Invoke a static verify method to verify this transpose operation.
+ let verifier = [{ return ::verify(*this); }];
+}
+
+#endif // TOY_OPS
diff --git a/mlir/examples/toy/Ch6/include/toy/Parser.h b/mlir/examples/toy/Ch6/include/toy/Parser.h
new file mode 100644
index 00000000000..5962a982ff0
--- /dev/null
+++ b/mlir/examples/toy/Ch6/include/toy/Parser.h
@@ -0,0 +1,492 @@
+//===- Parser.h - Toy Language Parser -------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the parser for the Toy language. It processes the Token
+// provided by the Lexer and returns an AST.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_PARSER_H
+#define MLIR_TUTORIAL_TOY_PARSER_H
+
+#include "toy/AST.h"
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <map>
+#include <utility>
+#include <vector>
+
+namespace toy {
+
+/// This is a simple recursive parser for the Toy language. It produces a well
+/// formed AST from a stream of Token supplied by the Lexer. No semantic checks
+/// or symbol resolution is performed. For example, variables are referenced by
+/// string and the code could reference an undeclared variable and the parsing
+/// succeeds.
+class Parser {
+public:
+  /// Create a Parser for the supplied lexer.
+  Parser(Lexer &lexer) : lexer(lexer) {}
+
+  /// Parse a full Module. A module is a list of function definitions.
+  /// Returns nullptr (after reporting an error) on failure.
+  std::unique_ptr<ModuleAST> ParseModule() {
+    lexer.getNextToken(); // prime the lexer
+
+    // Parse functions one at a time and accumulate in this vector.
+    std::vector<FunctionAST> functions;
+    while (auto F = ParseDefinition()) {
+      functions.push_back(std::move(*F));
+      if (lexer.getCurToken() == tok_eof)
+        break;
+    }
+    // If we didn't reach EOF, there was an error during parsing
+    if (lexer.getCurToken() != tok_eof)
+      return parseError<ModuleAST>("nothing", "at end of module");
+
+    return std::make_unique<ModuleAST>(std::move(functions));
+  }
+
+private:
+  Lexer &lexer;
+
+  /// Parse a return statement.
+  /// return :== return ; | return expr ;
+  std::unique_ptr<ReturnExprAST> ParseReturn() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(tok_return);
+
+    // return takes an optional argument
+    llvm::Optional<std::unique_ptr<ExprAST>> expr;
+    if (lexer.getCurToken() != ';') {
+      expr = ParseExpression();
+      if (!expr)
+        return nullptr;
+    }
+    return std::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
+  }
+
+  /// Parse a literal number.
+  /// numberexpr ::= number
+  std::unique_ptr<ExprAST> ParseNumberExpr() {
+    auto loc = lexer.getLastLocation();
+    auto Result =
+        std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
+    lexer.consume(tok_number);
+    return std::move(Result);
+  }
+
+  /// Parse a literal array expression.
+  /// tensorLiteral ::= [ literalList ] | number
+  /// literalList ::= tensorLiteral | tensorLiteral, literalList
+  std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(Token('['));
+
+    // Hold the list of values at this nesting level.
+    std::vector<std::unique_ptr<ExprAST>> values;
+    // Hold the dimensions for all the nesting inside this level.
+    std::vector<int64_t> dims;
+    do {
+      // We can have either another nested array or a number literal.
+      if (lexer.getCurToken() == '[') {
+        values.push_back(ParseTensorLiteralExpr());
+        if (!values.back())
+          return nullptr; // parse error in the nested array.
+      } else {
+        if (lexer.getCurToken() != tok_number)
+          return parseError<ExprAST>("<num> or [", "in literal expression");
+        values.push_back(ParseNumberExpr());
+      }
+
+      // End of this list on ']'
+      if (lexer.getCurToken() == ']')
+        break;
+
+      // Elements are separated by a comma.
+      if (lexer.getCurToken() != ',')
+        return parseError<ExprAST>("] or ,", "in literal expression");
+
+      lexer.getNextToken(); // eat ,
+    } while (true);
+    if (values.empty())
+      return parseError<ExprAST>("<something>", "to fill literal expression");
+    lexer.getNextToken(); // eat ]
+
+    /// Fill in the dimensions now. First the current nesting level:
+    dims.push_back(values.size());
+
+    /// If there is any nested array, process all of them and ensure that
+    /// dimensions are uniform.
+    if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
+          return llvm::isa<LiteralExprAST>(expr.get());
+        })) {
+      auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get());
+      if (!firstLiteral)
+        return parseError<ExprAST>("uniform well-nested dimensions",
+                                   "inside literal expression");
+
+      // Append the nested dimensions to the current level
+      auto &firstDims = firstLiteral->getDims();
+      dims.insert(dims.end(), firstDims.begin(), firstDims.end());
+
+      // Sanity check that shape is uniform across all elements of the list.
+      for (auto &expr : values) {
+        auto *exprLiteral = llvm::dyn_cast<LiteralExprAST>(expr.get());
+        if (!exprLiteral)
+          return parseError<ExprAST>("uniform well-nested dimensions",
+                                     "inside literal expression");
+        if (exprLiteral->getDims() != firstDims)
+          return parseError<ExprAST>("uniform well-nested dimensions",
+                                     "inside literal expression");
+      }
+    }
+    return std::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
+                                            std::move(dims));
+  }
+
+  /// parenexpr ::= '(' expression ')'
+  std::unique_ptr<ExprAST> ParseParenExpr() {
+    lexer.getNextToken(); // eat (.
+    auto V = ParseExpression();
+    if (!V)
+      return nullptr;
+
+    if (lexer.getCurToken() != ')')
+      return parseError<ExprAST>(")", "to close expression with parentheses");
+    lexer.consume(Token(')'));
+    return V;
+  }
+
+  /// identifierexpr
+  ///   ::= identifier
+  ///   ::= identifier '(' expression ')'
+  std::unique_ptr<ExprAST> ParseIdentifierExpr() {
+    std::string name = lexer.getId();
+
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat identifier.
+
+    if (lexer.getCurToken() != '(') // Simple variable ref.
+      return std::make_unique<VariableExprAST>(std::move(loc), name);
+
+    // This is a function call.
+    lexer.consume(Token('('));
+    std::vector<std::unique_ptr<ExprAST>> Args;
+    if (lexer.getCurToken() != ')') {
+      while (true) {
+        if (auto Arg = ParseExpression())
+          Args.push_back(std::move(Arg));
+        else
+          return nullptr;
+
+        if (lexer.getCurToken() == ')')
+          break;
+
+        if (lexer.getCurToken() != ',')
+          return parseError<ExprAST>(", or )", "in argument list");
+        lexer.getNextToken();
+      }
+    }
+    lexer.consume(Token(')'));
+
+    // It can be a builtin call to print
+    if (name == "print") {
+      if (Args.size() != 1)
+        return parseError<ExprAST>("<single arg>", "as argument to print()");
+
+      return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0]));
+    }
+
+    // Call to a user-defined function
+    return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args));
+  }
+
+  /// primary
+  ///   ::= identifierexpr
+  ///   ::= numberexpr
+  ///   ::= parenexpr
+  ///   ::= tensorliteral
+  std::unique_ptr<ExprAST> ParsePrimary() {
+    switch (lexer.getCurToken()) {
+    default:
+      llvm::errs() << "unknown token '" << lexer.getCurToken()
+                   << "' when expecting an expression\n";
+      return nullptr;
+    case tok_identifier:
+      return ParseIdentifierExpr();
+    case tok_number:
+      return ParseNumberExpr();
+    case '(':
+      return ParseParenExpr();
+    case '[':
+      return ParseTensorLiteralExpr();
+    case ';':
+      return nullptr;
+    case '}':
+      return nullptr;
+    }
+  }
+
+  /// Recursively parse the right hand side of a binary expression, the ExprPrec
+  /// argument indicates the precedence of the current binary operator.
+  ///
+  /// binoprhs ::= ('+' primary)*
+  std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
+                                         std::unique_ptr<ExprAST> LHS) {
+    // If this is a binop, find its precedence.
+    while (true) {
+      int TokPrec = GetTokPrecedence();
+
+      // If this is a binop that binds at least as tightly as the current binop,
+      // consume it, otherwise we are done.
+      if (TokPrec < ExprPrec)
+        return LHS;
+
+      // Okay, we know this is a binop.
+      int BinOp = lexer.getCurToken();
+      lexer.consume(Token(BinOp));
+      auto loc = lexer.getLastLocation();
+
+      // Parse the primary expression after the binary operator.
+      auto RHS = ParsePrimary();
+      if (!RHS)
+        return parseError<ExprAST>("expression", "to complete binary operator");
+
+      // If BinOp binds less tightly with RHS than the operator after RHS, let
+      // the pending operator take RHS as its LHS.
+      int NextPrec = GetTokPrecedence();
+      if (TokPrec < NextPrec) {
+        RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
+        if (!RHS)
+          return nullptr;
+      }
+
+      // Merge LHS/RHS.
+      LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp,
+                                            std::move(LHS), std::move(RHS));
+    }
+  }
+
+  /// expression::= primary binoprhs
+  std::unique_ptr<ExprAST> ParseExpression() {
+    auto LHS = ParsePrimary();
+    if (!LHS)
+      return nullptr;
+
+    return ParseBinOpRHS(0, std::move(LHS));
+  }
+
+  /// type ::= < shape_list >
+  /// shape_list ::= num | num , shape_list
+  std::unique_ptr<VarType> ParseType() {
+    if (lexer.getCurToken() != '<')
+      return parseError<VarType>("<", "to begin type");
+    lexer.getNextToken(); // eat <
+
+    auto type = std::make_unique<VarType>();
+
+    while (lexer.getCurToken() == tok_number) {
+      type->shape.push_back(lexer.getValue());
+      lexer.getNextToken();
+      if (lexer.getCurToken() == ',')
+        lexer.getNextToken();
+    }
+
+    if (lexer.getCurToken() != '>')
+      return parseError<VarType>(">", "to end type");
+    lexer.getNextToken(); // eat >
+    return type;
+  }
+
+  /// Parse a variable declaration, it starts with a `var` keyword followed by
+  /// and identifier and an optional type (shape specification) before the
+  /// initializer.
+  /// decl ::= var identifier [ type ] = expr
+  std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
+    if (lexer.getCurToken() != tok_var)
+      return parseError<VarDeclExprAST>("var", "to begin declaration");
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat var
+
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError<VarDeclExprAST>("identifier",
+                                        "after 'var' declaration");
+    std::string id = lexer.getId();
+    lexer.getNextToken(); // eat id
+
+    std::unique_ptr<VarType> type; // Type is optional, it can be inferred
+    if (lexer.getCurToken() == '<') {
+      type = ParseType();
+      if (!type)
+        return nullptr;
+    }
+
+    if (!type)
+      type = std::make_unique<VarType>();
+
+    // Check for the mandatory '=' before consuming it, so malformed input
+    // produces a diagnostic instead of tripping an assertion in consume().
+    if (lexer.getCurToken() != '=')
+      return parseError<VarDeclExprAST>("=", "in variable declaration");
+    lexer.consume(Token('='));
+
+    auto expr = ParseExpression();
+    if (!expr)
+      return nullptr; // don't build a declaration with a null initializer.
+    return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
+                                            std::move(*type), std::move(expr));
+  }
+
+  /// Parse a block: a list of expression separated by semicolons and wrapped in
+  /// curly braces.
+  ///
+  /// block ::= { expression_list }
+  /// expression_list ::= block_expr ; expression_list
+  /// block_expr ::= decl | "return" | expr
+  std::unique_ptr<ExprASTList> ParseBlock() {
+    if (lexer.getCurToken() != '{')
+      return parseError<ExprASTList>("{", "to begin block");
+    lexer.consume(Token('{'));
+
+    auto exprList = std::make_unique<ExprASTList>();
+
+    // Ignore empty expressions: swallow sequences of semicolons.
+    while (lexer.getCurToken() == ';')
+      lexer.consume(Token(';'));
+
+    while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
+      if (lexer.getCurToken() == tok_var) {
+        // Variable declaration
+        auto varDecl = ParseDeclaration();
+        if (!varDecl)
+          return nullptr;
+        exprList->push_back(std::move(varDecl));
+      } else if (lexer.getCurToken() == tok_return) {
+        // Return statement
+        auto ret = ParseReturn();
+        if (!ret)
+          return nullptr;
+        exprList->push_back(std::move(ret));
+      } else {
+        // General expression
+        auto expr = ParseExpression();
+        if (!expr)
+          return nullptr;
+        exprList->push_back(std::move(expr));
+      }
+      // Ensure that elements are separated by a semicolon.
+      if (lexer.getCurToken() != ';')
+        return parseError<ExprASTList>(";", "after expression");
+
+      // Ignore empty expressions: swallow sequences of semicolons.
+      while (lexer.getCurToken() == ';')
+        lexer.consume(Token(';'));
+    }
+
+    if (lexer.getCurToken() != '}')
+      return parseError<ExprASTList>("}", "to close block");
+
+    lexer.consume(Token('}'));
+    return exprList;
+  }
+
+  /// prototype ::= def id '(' decl_list ')'
+  /// decl_list ::= identifier | identifier, decl_list
+  std::unique_ptr<PrototypeAST> ParsePrototype() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(tok_def);
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError<PrototypeAST>("function name", "in prototype");
+
+    std::string FnName = lexer.getId();
+    lexer.consume(tok_identifier);
+
+    if (lexer.getCurToken() != '(')
+      return parseError<PrototypeAST>("(", "in prototype");
+    lexer.consume(Token('('));
+
+    std::vector<std::unique_ptr<VariableExprAST>> args;
+    if (lexer.getCurToken() != ')') {
+      do {
+        std::string name = lexer.getId();
+        auto loc = lexer.getLastLocation();
+        lexer.consume(tok_identifier);
+        auto decl = std::make_unique<VariableExprAST>(std::move(loc), name);
+        args.push_back(std::move(decl));
+        if (lexer.getCurToken() != ',')
+          break;
+        lexer.consume(Token(','));
+        if (lexer.getCurToken() != tok_identifier)
+          return parseError<PrototypeAST>(
+              "identifier", "after ',' in function parameter list");
+      } while (true);
+    }
+    if (lexer.getCurToken() != ')')
+      return parseError<PrototypeAST>(")", "to end function prototype");
+
+    // success.
+    lexer.consume(Token(')'));
+    return std::make_unique<PrototypeAST>(std::move(loc), FnName,
+                                          std::move(args));
+  }
+
+  /// Parse a function definition, we expect a prototype initiated with the
+  /// `def` keyword, followed by a block containing a list of expressions.
+  ///
+  /// definition ::= prototype block
+  std::unique_ptr<FunctionAST> ParseDefinition() {
+    auto Proto = ParsePrototype();
+    if (!Proto)
+      return nullptr;
+
+    if (auto block = ParseBlock())
+      return std::make_unique<FunctionAST>(std::move(Proto), std::move(block));
+    return nullptr;
+  }
+
+  /// Get the precedence of the pending binary operator token, or -1 for
+  /// anything that is not a binary operator.
+  int GetTokPrecedence() {
+    if (!isascii(lexer.getCurToken()))
+      return -1;
+
+    // 1 is lowest precedence.
+    switch (static_cast<char>(lexer.getCurToken())) {
+    case '-':
+    case '+':
+      return 20;
+    case '*':
+      return 40;
+    default:
+      return -1;
+    }
+  }
+
+  /// Helper function to signal errors while parsing, it takes an argument
+  /// indicating the expected token and another argument giving more context.
+  /// Location is retrieved from the lexer to enrich the error message.
+  /// Always returns nullptr so callers can `return parseError<...>(...)`.
+  template <typename R, typename T, typename U = const char *>
+  std::unique_ptr<R> parseError(T &&expected, U &&context = "") {
+    auto curToken = lexer.getCurToken();
+    llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
+                 << lexer.getLastLocation().col << "): expected '" << expected
+                 << "' " << context << " but has Token " << curToken;
+    // Guard against negative keyword-token values: passing a value that is
+    // neither EOF nor representable as unsigned char to isprint() is UB.
+    if (curToken >= 0 && isprint(curToken))
+      llvm::errs() << " '" << (char)curToken << "'";
+    llvm::errs() << "\n";
+    return nullptr;
+  }
+};
+
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_PARSER_H
diff --git a/mlir/examples/toy/Ch6/include/toy/Passes.h b/mlir/examples/toy/Ch6/include/toy/Passes.h
new file mode 100644
index 00000000000..00fe4ffe49b
--- /dev/null
+++ b/mlir/examples/toy/Ch6/include/toy/Passes.h
@@ -0,0 +1,45 @@
+//===- Passes.h - Toy Passes Definition -----------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file exposes the entry points to create compiler passes for Toy.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_PASSES_H
+#define MLIR_TUTORIAL_TOY_PASSES_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+namespace toy {
+/// Create a pass that deletes functions left dead (unreachable from "main")
+/// after inlining.
+std::unique_ptr<Pass> createDeadFunctionEliminationPass();
+
+/// Create a pass that propagates static shapes through the Toy operations.
+std::unique_ptr<Pass> createShapeInferencePass();
+
+/// Create a pass for lowering to operations in the `Affine` and `Std` dialects,
+/// for a subset of the Toy IR (e.g. matmul).
+std::unique_ptr<Pass> createLowerToAffinePass();
+
+/// Create a pass for lowering the remaining `Toy` operations, as well as
+/// `Affine` and `Std`, to the LLVM dialect for codegen.
+std::unique_ptr<Pass> createLowerToLLVMPass();
+
+} // end namespace toy
+} // end namespace mlir
+
+#endif // MLIR_TUTORIAL_TOY_PASSES_H
diff --git a/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.h b/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.h
new file mode 100644
index 00000000000..fc36b5b100d
--- /dev/null
+++ b/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.h
@@ -0,0 +1,37 @@
+//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains the declarations of the shape inference interfaces defined
+// in ShapeInferenceInterface.td.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
+#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace toy {
+
+/// Include the auto-generated declarations.
+#include "toy/ShapeInferenceOpInterfaces.h.inc"
+
+} // end namespace toy
+} // end namespace mlir
+
+#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
diff --git a/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.td
new file mode 100644
index 00000000000..19e70e60327
--- /dev/null
+++ b/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.td
@@ -0,0 +1,38 @@
+//===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines the operations of the Shape Inference Op Interface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef SHAPE_INFERENCE_INTERFACE
+#else
+#define SHAPE_INFERENCE_INTERFACE
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+// Interface for operations that can compute the shape of their result from
+// the shapes of their operands. Ops implement `inferShapes` to refine their
+// result type in place (see e.g. AddOp/MulOp/TransposeOp::inferShapes in
+// Dialect.cpp); the shape inference pass drives the calls.
+def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
+  let methods = [
+    InterfaceMethod<"Infer and set the output shape for the current operation.",
+                    "void", "inferShapes">
+  ];
+}
+
+#endif // SHAPE_INFERENCE_INTERFACE
diff --git a/mlir/examples/toy/Ch6/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch6/mlir/DeadFunctionEliminationPass.cpp
new file mode 100644
index 00000000000..b58adb5d52f
--- /dev/null
+++ b/mlir/examples/toy/Ch6/mlir/DeadFunctionEliminationPass.cpp
@@ -0,0 +1,68 @@
+//===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a Module level pass performing dead function
+// elimination. This is required as a post-processing step after function
+// inlining.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "toy/Passes.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+
+namespace {
+/// A trivial function-level DCE pass: once inlining has completed, every
+/// function other than "main" is dead, so this pass simply erases them all.
+/// TODO(riverriddle) This is only necessary because MLIR currently does not
+/// have generic DCE support for functions.
+class DeadFunctionEliminationPass
+    : public mlir::ModulePass<DeadFunctionEliminationPass> {
+public:
+  void runOnModule() override {
+    mlir::ModuleOp module = getModule();
+    mlir::SymbolTable symbolTable(module);
+
+    // Gather every function that is not the entry point, then erase them in a
+    // second step so the iteration never touches an invalidated op.
+    auto entryFn = symbolTable.lookup<mlir::FuncOp>("main");
+    llvm::SmallVector<mlir::FuncOp, 8> deadFns;
+    for (mlir::FuncOp fn : module.getOps<mlir::FuncOp>())
+      if (fn != entryFn)
+        deadFns.push_back(fn);
+    for (mlir::FuncOp fn : deadFns)
+      fn.erase();
+  }
+};
+} // end anonymous namespace
+
+/// Create a pass that eliminates inlined functions in toy.
+/// Factory entry point declared in toy/Passes.h; returns an owning pointer.
+std::unique_ptr<mlir::Pass> mlir::toy::createDeadFunctionEliminationPass() {
+  return std::make_unique<DeadFunctionEliminationPass>();
+}
diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
new file mode 100644
index 00000000000..e31cb917d89
--- /dev/null
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -0,0 +1,268 @@
+//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the dialect for the Toy IR: custom type parsing and
+// operation verification.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/InliningUtils.h"
+
+using namespace mlir;
+using namespace mlir::toy;
+
+//===----------------------------------------------------------------------===//
+// ToyInlinerInterface
+//===----------------------------------------------------------------------===//
+
+/// This class defines the interface for handling inlining with Toy
+/// operations.
+struct ToyInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// All operations within toy can be inlined.
+  bool isLegalToInline(Operation *, Region *,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Handle the given inlined terminator(toy.return) by replacing it with a new
+  /// operation as necessary.
+  void handleTerminator(Operation *op,
+                        ArrayRef<Value *> valuesToRepl) const final {
+    // Only "toy.return" needs to be handled here.
+    auto returnOp = cast<ReturnOp>(op);
+
+    // Replace the values directly with the return operands: each value the
+    // call produced is rewired to the corresponding returned operand.
+    assert(returnOp.getNumOperands() == valuesToRepl.size());
+    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
+      valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
+  }
+
+  /// Attempts to materialize a conversion for a type mismatch between a call
+  /// from this dialect, and a callable region. This method should generate an
+  /// operation that takes 'input' as the only operand, and produces a single
+  /// result of 'resultType'. If a conversion can not be generated, nullptr
+  /// should be returned.
+  Operation *materializeCallConversion(OpBuilder &builder, Value *input,
+                                       Type resultType,
+                                       Location conversionLoc) const final {
+    // Bridge the type mismatch with an explicit toy.cast.
+    return builder.create<CastOp>(conversionLoc, resultType, input);
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ToyDialect
+//===----------------------------------------------------------------------===//
+
+/// Dialect creation, the instance will be owned by the context. This is the
+/// point of registration of custom types and operations for the dialect.
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+  // Register every op declared in Ops.td; the list is emitted by ODS into
+  // Ops.cpp.inc under the GET_OP_LIST guard.
+  addOperations<
+#define GET_OP_LIST
+#include "toy/Ops.cpp.inc"
+      >();
+  // Hook the dialect into the generic inliner.
+  addInterfaces<ToyInlinerInterface>();
+}
+
+//===----------------------------------------------------------------------===//
+// Toy Operations
+//===----------------------------------------------------------------------===//
+
+/// Build a constant operation.
+/// The builder is passed as an argument, so is the state that this method is
+/// expected to fill in order to build the operation.
+static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
+                            double value) {
+  // A scalar constant is modeled as a rank-0 f64 tensor holding one element.
+  auto dataType = builder->getTensorType({}, builder->getF64Type());
+  auto dataAttribute = DenseElementsAttr::get(dataType, value);
+  ConstantOp::build(builder, state, dataType, dataAttribute);
+}
+
+/// Infer the output shape of the CastOp, this is required by the shape
+/// inference interface. A cast is shape-preserving: the result simply adopts
+/// the operand's type.
+void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
+
+/// Verifier for the constant operation. This corresponds to the `::verify(...)`
+/// in the op definition.
+static mlir::LogicalResult verify(ConstantOp op) {
+  // If the return type of the constant is not a ranked tensor (i.e. it is
+  // unranked), there is no static shape to check against the attribute.
+  // NOTE: this must be dyn_cast, not cast -- cast<> asserts (and is UB in
+  // release builds) on a type mismatch, which made the null check below dead.
+  auto resultType = op.getResult()->getType().dyn_cast<RankedTensorType>();
+  if (!resultType)
+    return success();
+
+  // Check that the rank of the attribute type matches the rank of the constant
+  // result type.
+  auto attrType = op.value().getType().cast<mlir::TensorType>();
+  if (attrType.getRank() != resultType.getRank()) {
+    return op.emitOpError(
+               "return type must match the one of the attached value "
+               "attribute: ")
+           << attrType.getRank() << " != " << resultType.getRank();
+  }
+
+  // Check that each of the dimensions match between the two types.
+  for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
+    if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
+      return op.emitOpError(
+                 "return type shape mismatches its attribute at dimension ")
+             << dim << ": " << attrType.getShape()[dim]
+             << " != " << resultType.getShape()[dim];
+    }
+  }
+  return mlir::success();
+}
+
+/// Build an AddOp. The result starts out as an unranked f64 tensor; the shape
+/// inference pass refines it (see AddOp::inferShapes below).
+static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
+                       mlir::Value *lhs, mlir::Value *rhs) {
+  state.addTypes(builder->getTensorType(builder->getF64Type()));
+  state.addOperands({lhs, rhs});
+}
+
+/// Infer the output shape of the AddOp, this is required by the shape inference
+/// interface. Addition is element-wise, so the result takes the lhs type.
+void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
+
+/// Build a GenericCallOp from a callee name and argument values.
+static void buildGenericCallOp(mlir::Builder *builder,
+                               mlir::OperationState &state, StringRef callee,
+                               ArrayRef<mlir::Value *> arguments) {
+  // Generic call always returns an unranked Tensor initially.
+  state.addTypes(builder->getTensorType(builder->getF64Type()));
+  state.addOperands(arguments);
+  state.addAttribute("callee", builder->getSymbolRefAttr(callee));
+}
+
+/// Return the callee of the generic call operation, this is required by the
+/// call interface.
+CallInterfaceCallable GenericCallOp::getCallableForCallee() {
+  return getAttrOfType<SymbolRefAttr>("callee");
+}
+
+/// Get the argument operands to the called function, this is required by the
+/// call interface.
+Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
+
+/// Build a MulOp. As with AddOp, the result starts as an unranked f64 tensor
+/// until shape inference runs.
+static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
+                       mlir::Value *lhs, mlir::Value *rhs) {
+  state.addTypes(builder->getTensorType(builder->getF64Type()));
+  state.addOperands({lhs, rhs});
+}
+
+/// Infer the output shape of the MulOp, this is required by the shape inference
+/// interface. The shape logic below is matrix-multiply style: rank-1 operands
+/// produce a <1> dot product, rank-2 operands produce lhs-rows x rhs-cols.
+/// NOTE(review): assumes both operands are already ranked at this point --
+/// cast<RankedTensorType> would assert otherwise; and the inner dimensions
+/// (lhs[1] vs rhs[0]) are not checked here -- confirm that is done elsewhere.
+void MulOp::inferShapes() {
+  auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
+  auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
+  auto lhsRank = lhs.getShape().size();
+  auto rhsRank = rhs.getShape().size();
+  // Mismatched ranks: leave the type unrefined.
+  if (lhsRank != rhsRank)
+    return;
+
+  SmallVector<int64_t, 2> dims;
+  if (lhsRank == 1) {
+    // dot product, result shape is <1>
+    dims.push_back(1);
+  } else if (lhsRank == 2) {
+    dims.push_back(lhs.getShape()[0]);
+    dims.push_back(rhs.getShape()[1]);
+  } else {
+    // Ranks above 2 are not handled; leave the type unrefined.
+    return;
+  }
+  getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
+}
+
+/// Verifier for toy.return: the operand count and type must agree with the
+/// signature of the enclosing function.
+static mlir::LogicalResult verify(ReturnOp op) {
+  // We know that the parent operation is a function, because of the 'HasParent'
+  // trait attached to the operation definition.
+  auto function = cast<FuncOp>(op.getParentOp());
+
+  /// ReturnOps can only have a single optional operand.
+  if (op.getNumOperands() > 1)
+    return op.emitOpError() << "expects at most 1 return operand";
+
+  // The operand number and types must match the function signature.
+  const auto &results = function.getType().getResults();
+  if (op.getNumOperands() != results.size())
+    return op.emitOpError()
+           << "does not return the same number of values ("
+           << op.getNumOperands() << ") as the enclosing function ("
+           << results.size() << ")";
+
+  // If the operation does not have an input, we are done.
+  if (!op.hasOperand())
+    return mlir::success();
+
+  auto inputType = *op.operand_type_begin();
+  auto resultType = results.front();
+
+  // Check that the result type of the function matches the operand type.
+  // Unranked tensors on either side are accepted: shapes are resolved later
+  // by the shape inference pass.
+  if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
+      resultType.isa<mlir::UnrankedTensorType>())
+    return mlir::success();
+
+  return op.emitError() << "type of return operand ("
+                        << *op.operand_type_begin()
+                        << ") doesn't match function result type ("
+                        << results.front() << ")";
+}
+
+/// Build a TransposeOp. The result starts as an unranked f64 tensor until
+/// shape inference runs.
+static void buildTransposeOp(mlir::Builder *builder,
+                             mlir::OperationState &state, mlir::Value *value) {
+  state.addTypes(builder->getTensorType(builder->getF64Type()));
+  state.addOperands(value);
+}
+
+/// Infer the output shape of the TransposeOp: the input shape reversed.
+/// NOTE(review): assumes the operand is already ranked at this point --
+/// cast<RankedTensorType> would assert otherwise; confirm the pass only
+/// invokes inferShapes once operands are resolved.
+void TransposeOp::inferShapes() {
+  auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
+  SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
+  getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
+}
+
+/// Verifier for toy.transpose: when both the operand and result have static
+/// ranks, the result shape must be the reverse of the input shape.
+static mlir::LogicalResult verify(TransposeOp op) {
+  auto inputType = op.getOperand()->getType().dyn_cast<RankedTensorType>();
+  auto resultType = op.getType().dyn_cast<RankedTensorType>();
+  // Nothing to check until both sides are ranked.
+  if (!inputType || !resultType)
+    return mlir::success();
+
+  auto inputShape = inputType.getShape();
+  auto resultShape = resultType.getShape();
+  // Guard the ranks first: std::equal walks inputShape.size() elements of the
+  // reversed result shape and would read out of bounds on a rank mismatch.
+  if (inputShape.size() != resultShape.size() ||
+      !std::equal(inputShape.begin(), inputShape.end(),
+                  resultShape.rbegin())) {
+    return op.emitError()
+           << "expected result shape to be a transpose of the input";
+  }
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "toy/Ops.cpp.inc"
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
new file mode 100644
index 00000000000..a8e38aef7ad
--- /dev/null
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -0,0 +1,318 @@
+//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a partial lowering of Toy operations to a combination of
+// affine loops and standard operations. This lowering expects that all calls
+// have been inlined, and all shapes have been resolved.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+#include "toy/Passes.h"
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns
+//===----------------------------------------------------------------------===//
+
+/// Convert a ranked TensorType into a MemRefType with the same shape and
+/// element type. Unranked tensors are rejected by the assertion.
+static MemRefType convertTensorToMemRef(TensorType type) {
+  assert(type.hasRank() && "expected only ranked shapes");
+  auto shape = type.getShape();
+  return MemRefType::get(shape, type.getElementType());
+}
+
+/// Insert an alloc/dealloc pair for the given MemRefType. The alloc is
+/// hoisted to the start of the enclosing block and the dealloc is sunk to
+/// just before the block's last operation; this is valid because toy
+/// functions have no control flow.
+static Value *insertAllocAndDealloc(MemRefType type, Location loc,
+                                    PatternRewriter &rewriter) {
+  auto allocOp = rewriter.create<AllocOp>(loc, type);
+  auto *enclosingBlock = allocOp.getOperation()->getBlock();
+
+  // Hoist the allocation to the very beginning of the block.
+  allocOp.getOperation()->moveBefore(&enclosingBlock->front());
+
+  // Sink the matching deallocation to the end of the block.
+  auto deallocOp = rewriter.create<DeallocOp>(loc, allocOp);
+  deallocOp.getOperation()->moveBefore(&enclosingBlock->back());
+  return allocOp;
+}
+
+/// Signature of the callback used to emit the body of a lowered loop nest.
+/// It receives the rewriter (positioned inside the innermost loop), the
+/// remapped memref operands of the original operation, and one loop induction
+/// variable per tensor dimension, and returns the value to store at the
+/// current index of the iteration.
+using LoopIterationFn = function_ref<Value *(PatternRewriter &rewriter,
+                                             ArrayRef<Value *> memRefOperands,
+                                             ArrayRef<Value *> loopIvs)>;
+
+/// Lower `op`, which must produce a single ranked-tensor result, to an affine
+/// loop nest over each dimension of that result. `processIteration` is called
+/// with the rewriter positioned in the innermost loop body and returns the
+/// value to store at the current induction variables. The original operation
+/// is replaced by the memref allocated for its result.
+static void lowerOpToLoops(Operation *op, ArrayRef<Value *> operands,
+                           PatternRewriter &rewriter,
+                           LoopIterationFn processIteration) {
+  auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+  auto loc = op->getLoc();
+
+  // Insert an allocation and deallocation for the result of this operation.
+  auto memRefType = convertTensorToMemRef(tensorType);
+  auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
+
+  // Create an empty affine loop for each of the dimensions within the shape.
+  SmallVector<Value *, 4> loopIvs;
+  for (auto dim : tensorType.getShape()) {
+    auto loop = rewriter.create<AffineForOp>(loc, /*lb=*/0, dim, /*step=*/1);
+    // Drop the implicitly created body so we fully control its contents.
+    loop.getBody()->clear();
+    loopIvs.push_back(loop.getInductionVar());
+
+    // Terminate the loop body and update the rewriter insertion point to the
+    // beginning of the loop. The second setInsertionPointToStart ensures the
+    // next nesting level (or the iteration body) is emitted *before* the
+    // terminator created here.
+    rewriter.setInsertionPointToStart(loop.getBody());
+    rewriter.create<AffineTerminatorOp>(loc);
+    rewriter.setInsertionPointToStart(loop.getBody());
+  }
+
+  // Generate a call to the processing function with the rewriter, the memref
+  // operands, and the loop induction variables. This function will return the
+  // value to store at the current index.
+  Value *valueToStore = processIteration(rewriter, operands, loopIvs);
+  rewriter.create<AffineStoreOp>(loc, valueToStore, alloc,
+                                 llvm::makeArrayRef(loopIvs));
+
+  // Replace this operation with the generated alloc.
+  rewriter.replaceOp(op, alloc);
+}
+
+namespace {
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Binary operations
+//===----------------------------------------------------------------------===//
+
+/// Lower a toy element-wise binary operation to an affine loop nest that
+/// loads the corresponding element of each operand, applies
+/// `LoweredBinaryOp`, and stores the result.
+template <typename BinaryOp, typename LoweredBinaryOp>
+struct BinaryOpLowering : public ConversionPattern {
+  BinaryOpLowering(MLIRContext *ctx)
+      : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto loc = op->getLoc();
+    lowerOpToLoops(op, operands, rewriter,
+                   [loc](PatternRewriter &rewriter,
+                         ArrayRef<Value *> memRefOperands,
+                         ArrayRef<Value *> loopIvs) {
+                     // The ODS-generated adaptor gives named access to the
+                     // remapped (memref) operands of the BinaryOp.
+                     typename BinaryOp::OperandAdaptor adaptor(memRefOperands);
+
+                     // Load the current element of each side at the loop
+                     // induction variables...
+                     auto lhsElt = rewriter.create<AffineLoadOp>(
+                         loc, adaptor.lhs(), loopIvs);
+                     auto rhsElt = rewriter.create<AffineLoadOp>(
+                         loc, adaptor.rhs(), loopIvs);
+
+                     // ...and combine them with the lowered arithmetic op.
+                     return rewriter.create<LoweredBinaryOp>(loc, lhsElt,
+                                                             rhsElt);
+                   });
+    return matchSuccess();
+  }
+};
+/// Concrete pattern instantiations for the toy binary operations, lowering
+/// toy.add/toy.mul to the standard floating-point addf/mulf.
+using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
+using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
+
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Constant operations
+//===----------------------------------------------------------------------===//
+
+/// Lower `toy.constant` by materializing its dense attribute into an
+/// allocated memref via one affine store per element.
+struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
+  using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(toy::ConstantOp op,
+                                     PatternRewriter &rewriter) const final {
+    DenseElementsAttr constantValue = op.value();
+    Location loc = op.getLoc();
+
+    // When lowering the constant operation, we allocate and assign the constant
+    // values to a corresponding memref allocation.
+    auto tensorType = op.getType().cast<TensorType>();
+    auto memRefType = convertTensorToMemRef(tensorType);
+    auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
+
+    // We will be generating constant indices up-to the largest dimension.
+    // Create these constants up-front to avoid large amounts of redundant
+    // operations.
+    auto valueShape = memRefType.getShape();
+    SmallVector<Value *, 8> constantIndices;
+    for (auto i : llvm::seq<int64_t>(
+             0, *std::max_element(valueShape.begin(), valueShape.end())))
+      constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
+
+    // The constant operation represents a multi-dimensional constant, so we
+    // will need to generate a store for each of the elements. The following
+    // functor recursively walks the dimensions of the constant shape,
+    // generating a store when the recursion hits the base case.
+    // `indices` acts as a shared index stack: pushed on the way down the
+    // recursion, popped on the way back up.
+    SmallVector<Value *, 2> indices;
+    // Flat iterator over the attribute's values, advanced once per store.
+    auto valueIt = constantValue.getValues<FloatAttr>().begin();
+    std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
+      // The last dimension is the base case of the recursion, at this point
+      // we store the element at the given index.
+      if (dimension == valueShape.size()) {
+        rewriter.create<AffineStoreOp>(
+            loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
+            llvm::makeArrayRef(indices));
+        return;
+      }
+
+      // Otherwise, iterate over the current dimension and add the indices to
+      // the list.
+      for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
+        indices.push_back(constantIndices[i]);
+        storeElements(dimension + 1);
+        indices.pop_back();
+      }
+    };
+
+    // Start the element storing recursion from the first dimension.
+    storeElements(/*dimension=*/0);
+
+    // Replace this operation with the generated alloc.
+    rewriter.replaceOp(op, alloc);
+    return matchSuccess();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Return operations
+//===----------------------------------------------------------------------===//
+
+/// Lower an operand-less "toy.return" directly to the standard "std.return"
+/// terminator. Returns with operands are rejected: by this point every call
+/// is expected to have been inlined.
+struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
+  using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(toy::ReturnOp op,
+                                     PatternRewriter &rewriter) const final {
+    // A return carrying an operand means some function call was not inlined;
+    // refuse to match so the conversion reports it.
+    if (op.hasOperand())
+      return matchFailure();
+
+    rewriter.replaceOpWithNewOp<ReturnOp>(op);
+    return matchSuccess();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Transpose operations
+//===----------------------------------------------------------------------===//
+
+/// Lower toy.transpose to an affine loop nest that loads each element from
+/// the reversed induction variables, performing the transpose element-wise.
+struct TransposeOpLowering : public ConversionPattern {
+  TransposeOpLowering(MLIRContext *ctx)
+      : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto loc = op->getLoc();
+    lowerOpToLoops(op, operands, rewriter,
+                   [loc](PatternRewriter &rewriter,
+                         ArrayRef<Value *> memRefOperands,
+                         ArrayRef<Value *> loopIvs) {
+                     // The ODS-generated adaptor gives named access to the
+                     // remapped (memref) operands of the TransposeOp.
+                     toy::TransposeOpOperandAdaptor adaptor(memRefOperands);
+                     Value *input = adaptor.input();
+
+                     // Loading with the induction variables reversed yields
+                     // the transposed element.
+                     SmallVector<Value *, 2> reversedIvs(
+                         llvm::reverse(loopIvs));
+                     return rewriter.create<AffineLoadOp>(loc, input,
+                                                          reversedIvs);
+                   });
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace.
+
+//===----------------------------------------------------------------------===//
+// ToyToAffineLoweringPass
+//===----------------------------------------------------------------------===//
+
+/// This is a partial lowering to affine loops of the toy operations that are
+/// computationally intensive (like matmul for example...) while keeping the
+/// rest of the code in the Toy dialect.
+namespace {
+/// A function pass that partially lowers the toy operations in `main` to
+/// affine loops and standard ops; see runOnFunction for the conversion setup.
+struct ToyToAffineLoweringPass : public FunctionPass<ToyToAffineLoweringPass> {
+  void runOnFunction() final;
+};
+} // end anonymous namespace.
+
+/// Partially lower the `main` function: toy ops (except toy.print) are
+/// rewritten into affine loops and standard ops; anything left illegal makes
+/// the pass fail.
+void ToyToAffineLoweringPass::runOnFunction() {
+  auto function = getFunction();
+
+  // We only lower the main function as we expect that all other functions have
+  // been inlined.
+  if (function.getName() != "main")
+    return;
+
+  // Verify that the given main has no inputs and results.
+  if (function.getNumArguments() || function.getType().getNumResults()) {
+    function.emitError("expected 'main' to have 0 inputs and 0 results");
+    return signalPassFailure();
+  }
+
+  // The first thing to define is the conversion target. This will define the
+  // final target for this lowering.
+  ConversionTarget target(getContext());
+
+  // We define the specific operations, or dialects, that are legal targets for
+  // this lowering. In our case, we are lowering to a combination of the
+  // `Affine` and `Standard` dialects.
+  target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
+
+  // We also define the Toy dialect as illegal so that the conversion will fail
+  // if any of these operations are *not* converted. Given that we actually
+  // want a partial lowering, we explicitly mark the Toy operations that we
+  // don't want to lower, `toy.print`, as `legal`.
+  target.addIllegalDialect<toy::ToyDialect>();
+  target.addLegalOp<toy::PrintOp>();
+
+  // Now that the conversion target has been defined, we just need to provide
+  // the set of patterns that will lower the Toy operations.
+  OwningRewritePatternList patterns;
+  patterns.insert<AddOpLowering, ConstantOpLowering, MulOpLowering,
+                  ReturnOpLowering, TransposeOpLowering>(&getContext());
+
+  // With the target and rewrite patterns defined, we can now attempt the
+  // conversion. The conversion will signal failure if any of our `illegal`
+  // operations were not converted successfully. Reuse the `function` fetched
+  // above instead of calling getFunction() a second time.
+  if (failed(applyPartialConversion(function, target, patterns)))
+    signalPassFailure();
+}
+
+/// Create a pass that lowers the computationally intensive subset of the Toy
+/// IR (e.g. matmul-like ops) to the `Affine` and `Std` dialects.
+std::unique_ptr<Pass> mlir::toy::createLowerToAffinePass() {
+  return std::make_unique<ToyToAffineLoweringPass>();
+}
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
new file mode 100644
index 00000000000..127349f6c9d
--- /dev/null
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -0,0 +1,213 @@
+//====- LowerToLLVM.cpp - Lowering from Toy+Affine+Std to LLVM ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements full lowering of the remaining Toy operations, together
+// with the `affine`, `loop`, and `std` operations produced by earlier passes,
+// down to the LLVM dialect so the module can be translated to LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+#include "toy/Passes.h"
+
+#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/LowerAffine.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ToyToLLVM RewritePatterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Lowers `toy.print` to a loop nest calling `printf` on each of the individual
+/// elements of the array.
+class PrintOpLowering : public ConversionPattern {
+public:
+  explicit PrintOpLowering(MLIRContext *context)
+      : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // By this point of the pipeline the operand has been lowered to a memref.
+    auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
+    auto memRefShape = memRefType.getShape();
+    auto loc = op->getLoc();
+    auto *llvmDialect =
+        op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+    assert(llvmDialect && "expected llvm dialect to be registered");
+
+    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
+
+    // Get a symbol reference to the printf function, inserting it if necessary.
+    // NB: the explicit StringRef lengths keep the trailing NUL byte in the
+    // global string constants.
+    auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect);
+    Value *formatSpecifierCst = getOrCreateGlobalString(
+        loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule,
+        llvmDialect);
+    Value *newLineCst = getOrCreateGlobalString(
+        loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect);
+
+    // Create a loop for each of the dimensions within the shape.
+    SmallVector<Value *, 4> loopIvs;
+    for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) {
+      auto lowerBound = rewriter.create<ConstantIndexOp>(loc, 0);
+      auto upperBound = rewriter.create<ConstantIndexOp>(loc, memRefShape[i]);
+      auto step = rewriter.create<ConstantIndexOp>(loc, 1);
+      auto loop =
+          rewriter.create<loop::ForOp>(loc, lowerBound, upperBound, step);
+      // Drop the implicitly created body so we fully control its contents.
+      loop.getBody()->clear();
+      loopIvs.push_back(loop.getInductionVar());
+
+      // Terminate the loop body.
+      rewriter.setInsertionPointToStart(loop.getBody());
+
+      // Insert a newline after each of the inner dimensions of the shape.
+      if (i != e - 1)
+        rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
+                                newLineCst);
+      rewriter.create<loop::TerminatorOp>(loc);
+      // Reset to the start so the next nesting level (or the element print)
+      // is emitted before the newline call and terminator created above.
+      rewriter.setInsertionPointToStart(loop.getBody());
+    }
+
+    // Generate a call to printf for the current element of the loop.
+    auto printOp = cast<toy::PrintOp>(op);
+    auto elementLoad = rewriter.create<LoadOp>(loc, printOp.input(), loopIvs);
+    rewriter.create<CallOp>(
+        loc, printfRef, rewriter.getIntegerType(32),
+        ArrayRef<Value *>({formatSpecifierCst, elementLoad}));
+
+    // Notify the rewriter that this operation has been removed.
+    rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
+  }
+
+private:
+  /// Return a symbol reference to the printf function, inserting it into the
+  /// module if necessary.
+  static SymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
+                                         ModuleOp module,
+                                         LLVM::LLVMDialect *llvmDialect) {
+    auto *context = module.getContext();
+    if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
+      return SymbolRefAttr::get("printf", context);
+
+    // Create a function declaration for printf, the signature is:
+    //   * `i32 (i8*, ...)`
+    auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
+    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+    auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
+                                                    /*isVarArg=*/true);
+
+    // Insert the printf function into the body of the parent module.
+    PatternRewriter::InsertionGuard insertGuard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
+    return SymbolRefAttr::get("printf", context);
+  }
+
+  /// Return a value representing an access into a global string with the given
+  /// name, creating the string if necessary.
+  static Value *getOrCreateGlobalString(Location loc, OpBuilder &builder,
+                                        StringRef name, StringRef value,
+                                        ModuleOp module,
+                                        LLVM::LLVMDialect *llvmDialect) {
+    // Create the global at the entry of the module.
+    LLVM::GlobalOp global;
+    if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
+      OpBuilder::InsertionGuard insertGuard(builder);
+      builder.setInsertionPointToStart(module.getBody());
+      auto type = LLVM::LLVMType::getArrayTy(
+          LLVM::LLVMType::getInt8Ty(llvmDialect), value.size());
+      global = builder.create<LLVM::GlobalOp>(
+          loc, type, /*isConstant=*/true, name, builder.getStringAttr(value));
+    }
+
+    // Get the pointer to the first character in the global string.
+    // The double zero index steps through the pointer-to-array and then to
+    // its first element.
+    Value *globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
+    Value *cst0 = builder.create<LLVM::ConstantOp>(
+        loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
+        builder.getIntegerAttr(builder.getIndexType(), 0));
+    return builder.create<LLVM::GEPOp>(
+        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
+        ArrayRef<Value *>({cst0, cst0}));
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// ToyToLLVMLoweringPass
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A module pass that fully lowers the mixed toy/affine/std module produced
+/// by the earlier passes down to the LLVM dialect.
+struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> {
+  void runOnModule() final;
+};
+} // end anonymous namespace
+
+/// Fully lower the module to the LLVM dialect; any operation left outside the
+/// LLVM dialect after conversion makes the pass fail.
+void ToyToLLVMLoweringPass::runOnModule() {
+  // The first thing to define is the conversion target. This will define the
+  // final target for this lowering. For this lowering, we are only targeting
+  // the LLVM dialect.
+  ConversionTarget target(getContext());
+  target.addLegalDialect<LLVM::LLVMDialect>();
+  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+
+  // During this lowering, we will also be lowering the MemRef types, that are
+  // currently being operated on, to a representation in LLVM. To perform this
+  // conversion we use a TypeConverter as part of the lowering. This converter
+  // details how one type maps to another. This is necessary now that we will
+  // be doing more complicated lowerings, involving loop region arguments.
+  LLVMTypeConverter typeConverter(&getContext());
+
+  // Now that the conversion target has been defined, we need to provide the
+  // patterns used for lowering. At this point of the compilation process, we
+  // have a combination of `toy`, `affine`, and `std` operations. Luckily,
+  // there already exists a set of patterns to transform `affine` and `std`
+  // dialects. These patterns lower in multiple stages, relying on transitive
+  // lowerings. Transitive lowering, or A->B->C lowering, is when multiple
+  // patterns must be applied to fully transform an illegal operation into a
+  // set of legal ones.
+  OwningRewritePatternList patterns;
+  populateAffineToStdConversionPatterns(patterns, &getContext());
+  populateLoopToStdConversionPatterns(patterns, &getContext());
+  populateStdToLLVMConversionPatterns(typeConverter, patterns);
+
+  // The only remaining operation to lower from the `toy` dialect, is the
+  // PrintOp.
+  patterns.insert<PrintOpLowering>(&getContext());
+
+  // We want to completely lower to LLVM, so we use a `FullConversion`. This
+  // ensures that only legal operations will remain after the conversion.
+  auto module = getModule();
+  if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
+    signalPassFailure();
+}
+
+/// Create a pass that lowers the remaining `Toy` operations, as well as
+/// `Affine` and `Std`, to the LLVM dialect for code generation.
+std::unique_ptr<mlir::Pass> mlir::toy::createLowerToLLVMPass() {
+  return std::make_unique<ToyToLLVMLoweringPass>();
+}
diff --git a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp
new file mode 100644
index 00000000000..5f12d0a8798
--- /dev/null
+++ b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp
@@ -0,0 +1,467 @@
+//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple IR generation targeting MLIR from a Module AST
+// for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/MLIRGen.h"
+#include "toy/AST.h"
+#include "toy/Dialect.h"
+
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/Support/raw_ostream.h"
+#include <numeric>
+
+using namespace mlir::toy;
+using namespace toy;
+
+using llvm::ArrayRef;
+using llvm::cast;
+using llvm::dyn_cast;
+using llvm::isa;
+using llvm::makeArrayRef;
+using llvm::ScopedHashTableScope;
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+namespace {
+
+/// Implementation of a simple MLIR emission from the Toy AST.
+///
+/// This will emit operations that are specific to the Toy language, preserving
+/// the semantics of the language and (hopefully) allow to perform accurate
+/// analysis and transformation based on these high level semantics.
+class MLIRGenImpl {
+public:
+ MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
+
+  /// Public API: convert the AST for a Toy module (source file) to an MLIR
+  /// Module operation. Returns a null module (after emitting diagnostics)
+  /// when any function fails to generate or the result does not verify.
+  mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
+    // Codegen happens into a fresh, initially empty module.
+    theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
+
+    // Emit each function in turn, aborting on the first failure.
+    for (FunctionAST &funcAST : moduleAST) {
+      auto mlirFunc = mlirGen(funcAST);
+      if (!mlirFunc)
+        return nullptr;
+      theModule.push_back(mlirFunc);
+    }
+
+    // Verify the module after we have finished constructing it; this checks
+    // the structural properties of the IR and invokes any specific verifiers
+    // we have on the Toy operations.
+    if (failed(mlir::verify(theModule))) {
+      theModule.emitError("module verification error");
+      return nullptr;
+    }
+
+    return theModule;
+  }
+
+private:
+ /// A "module" matches a Toy source file: containing a list of functions.
+ mlir::ModuleOp theModule;
+
+ /// The builder is a helper class to create IR inside a function. The builder
+  /// is stateful; in particular it keeps an "insertion point": this is where
+ /// the next operations will be introduced.
+ mlir::OpBuilder builder;
+
+ /// The symbol table maps a variable name to a value in the current scope.
+ /// Entering a function creates a new scope, and the function arguments are
+ /// added to the mapping. When the processing of a function is terminated, the
+ /// scope is destroyed and the mappings created in this scope are dropped.
+ llvm::ScopedHashTable<StringRef, mlir::Value *> symbolTable;
+
+  /// Map a Toy AST source location onto an MLIR file/line/column location.
+  mlir::Location loc(Location loc) {
+    auto fileName = builder.getIdentifier(*loc.file);
+    return builder.getFileLineColLoc(fileName, loc.line, loc.col);
+  }
+
+  /// Declare a variable in the current scope. Fails (without overwriting) if
+  /// the name is already bound in this scope.
+  mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) {
+    if (!symbolTable.count(var)) {
+      symbolTable.insert(var, value);
+      return mlir::success();
+    }
+    return mlir::failure();
+  }
+
+  /// Create the prototype for an MLIR function with as many arguments as the
+  /// provided Toy AST prototype. Argument types are uniformly unranked
+  /// tensors; the return type is left empty and inferred later.
+  mlir::FuncOp mlirGen(PrototypeAST &proto) {
+    auto location = loc(proto.loc());
+
+    // Build the "(tensor, ..., tensor) -> ()" function type.
+    llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
+                                              getType(VarType{}));
+    auto funcType = builder.getFunctionType(argTypes, llvm::None);
+    auto function = mlir::FuncOp::create(location, proto.getName(), funcType);
+
+    // Mark the function as generic: it'll require type specialization for
+    // every call site.
+    if (function.getNumArguments())
+      function.setAttr("toy.generic", builder.getUnitAttr());
+    return function;
+  }
+
+  /// Emit a new function and add it to the MLIR module. Returns a null FuncOp
+  /// (after erasing any partially-built IR) if body codegen fails.
+  mlir::FuncOp mlirGen(FunctionAST &funcAST) {
+    // Create a scope in the symbol table to hold variable declarations.
+    ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
+
+    // Create an MLIR function for the given prototype.
+    mlir::FuncOp function(mlirGen(*funcAST.getProto()));
+    if (!function)
+      return nullptr;
+
+    // Let's start the body of the function now!
+    // In MLIR the entry block of the function is special: it must have the same
+    // argument list as the function itself.
+    auto &entryBlock = *function.addEntryBlock();
+    auto &protoArgs = funcAST.getProto()->getArgs();
+
+    // Declare all the function arguments in the symbol table.
+    for (const auto &name_value :
+         llvm::zip(protoArgs, entryBlock.getArguments())) {
+      if (failed(declare(std::get<0>(name_value)->getName(),
+                         std::get<1>(name_value))))
+        return nullptr;
+    }
+
+    // Set the insertion point in the builder to the beginning of the function
+    // body, it will be used throughout the codegen to create operations in this
+    // function.
+    builder.setInsertionPointToStart(&entryBlock);
+
+    // Emit the body of the function.
+    if (mlir::failed(mlirGen(*funcAST.getBody()))) {
+      // Discard the partially-built function rather than leaving invalid IR.
+      function.erase();
+      return nullptr;
+    }
+
+    // Implicitly return void if no return statement was emitted.
+    // FIXME: we may fix the parser instead to always return the last expression
+    // (this would possibly help the REPL case later)
+    ReturnOp returnOp;
+    if (!entryBlock.empty())
+      returnOp = dyn_cast<ReturnOp>(entryBlock.back());
+    if (!returnOp) {
+      builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
+    } else if (returnOp.hasOperand()) {
+      // Otherwise, if this return operation has an operand then add a result
+      // to the function: a single unranked tensor, refined later by shape
+      // inference.
+      function.setType(builder.getFunctionType(function.getType().getInputs(),
+                                               getType(VarType{})));
+    }
+
+    return function;
+  }
+
+  /// Emit a binary operation. Returns nullptr (after emitting an error) if
+  /// either side fails to generate or the operator is unsupported.
+  mlir::Value *mlirGen(BinaryExprAST &binop) {
+    // First emit the operations for each side of the operation before emitting
+    // the operation itself. For example if the expression is `a + foo(a)`
+    // 1) First it will visit the LHS, which will return a reference to the
+    //    value holding `a`. This value should have been emitted at declaration
+    //    time and registered in the symbol table, so nothing would be
+    //    codegen'd. If the value is not in the symbol table, an error has been
+    //    emitted and nullptr is returned.
+    // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
+    //    and the result value is returned. If an error occurs we get a nullptr
+    //    and propagate.
+    //
+    mlir::Value *lhs = mlirGen(*binop.getLHS());
+    if (!lhs)
+      return nullptr;
+    mlir::Value *rhs = mlirGen(*binop.getRHS());
+    if (!rhs)
+      return nullptr;
+    auto location = loc(binop.loc());
+
+    // Derive the operation name from the binary operator. At the moment we only
+    // support '+' and '*'.
+    switch (binop.getOp()) {
+    case '+':
+      return builder.create<AddOp>(location, lhs, rhs);
+    case '*':
+      return builder.create<MulOp>(location, lhs, rhs);
+    }
+
+    emitError(location, "invalid binary operator '") << binop.getOp() << "'";
+    return nullptr;
+  }
+
+  /// Emit a reference to a variable. The variable must already have been
+  /// declared (i.e. be present in the symbol table); otherwise an error is
+  /// emitted and nullptr returned.
+  mlir::Value *mlirGen(VariableExprAST &expr) {
+    auto *value = symbolTable.lookup(expr.getName());
+    if (!value) {
+      emitError(loc(expr.loc()), "error: unknown variable '")
+          << expr.getName() << "'";
+      return nullptr;
+    }
+    return value;
+  }
+
+  /// Emit a toy.return operation. Fails if codegen of the (optional) operand
+  /// expression fails.
+  mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
+    auto location = loc(ret.loc());
+
+    // 'return' takes an optional expression; codegen it first when present.
+    mlir::Value *operand = nullptr;
+    if (ret.getExpr().hasValue()) {
+      operand = mlirGen(*ret.getExpr().getValue());
+      if (!operand)
+        return mlir::failure();
+    }
+
+    // Emit the return with zero or one operand.
+    if (operand)
+      builder.create<ReturnOp>(location, makeArrayRef(operand));
+    else
+      builder.create<ReturnOp>(location, ArrayRef<mlir::Value *>());
+    return mlir::success();
+  }
+
+  /// Emit a literal/constant array. It will be emitted as a flattened array of
+  /// data in an Attribute attached to a `toy.constant` operation.
+  /// See documentation on [Attributes](LangRef.md#attributes) for more details.
+  /// Here is an excerpt:
+  ///
+  ///   Attributes are the mechanism for specifying constant data in MLIR in
+  ///   places where a variable is never allowed [...]. They consist of a name
+  ///   and a concrete attribute value. The set of expected attributes, their
+  ///   structure, and their interpretation are all contextually dependent on
+  ///   what they are attached to.
+  ///
+  /// Example, the source level statement:
+  ///   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+  /// will be converted to:
+  ///   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
+  ///     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
+  ///      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
+  ///
+  mlir::Value *mlirGen(LiteralExprAST &lit) {
+    auto type = getType(lit.getDims());
+
+    // The attribute is a vector with a floating point value per element
+    // (number) in the array, see `collectData()` below for more details.
+    // The total element count is the product of all the dimensions.
+    std::vector<double> data;
+    data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
+                                 std::multiplies<int>()));
+    collectData(lit, data);
+
+    // The type of this attribute is tensor of 64-bit floating-point with the
+    // shape of the literal.
+    mlir::Type elementType = builder.getF64Type();
+    auto dataType = builder.getTensorType(lit.getDims(), elementType);
+
+    // This is the actual attribute that holds the list of values for this
+    // tensor literal.
+    auto dataAttribute =
+        mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
+
+    // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
+    // method.
+    return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
+  }
+
+  /// Recursive helper function to accumulate the data that composes an array
+  /// literal. It flattens the nested structure into the supplied vector. For
+  /// example with this array:
+  ///   [[1, 2], [3, 4]]
+  /// we will generate:
+  ///   [ 1, 2, 3, 4 ]
+  /// Individual numbers are represented as doubles.
+  /// Attributes are the way MLIR attaches constants to operations.
+  void collectData(ExprAST &expr, std::vector<double> &data) {
+    if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
+      // Recurse into every element of the nested literal.
+      for (auto &value : lit->getValues())
+        collectData(*value, data);
+      return;
+    }
+
+    // Leaf of the recursion: a single number.
+    assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
+    data.push_back(cast<NumberExprAST>(expr).getValue());
+  }
+
+  /// Emit a call expression. It emits specific operations for the `transpose`
+  /// builtin. Other identifiers are assumed to be user-defined functions.
+  /// Returns nullptr on error.
+  mlir::Value *mlirGen(CallExprAST &call) {
+    llvm::StringRef callee = call.getCallee();
+    auto location = loc(call.loc());
+
+    // Codegen the operands first.
+    SmallVector<mlir::Value *, 4> operands;
+    for (auto &expr : call.getArgs()) {
+      auto *arg = mlirGen(*expr);
+      if (!arg)
+        return nullptr;
+      operands.push_back(arg);
+    }
+
+    // Builtin calls have their custom operation, meaning this is a
+    // straightforward emission.
+    if (callee == "transpose") {
+      if (call.getArgs().size() != 1) {
+        emitError(location, "MLIR codegen encountered an error: toy.transpose "
+                            "does not accept multiple arguments");
+        return nullptr;
+      }
+      return builder.create<TransposeOp>(location, operands[0]);
+    }
+
+    // Otherwise this is a call to a user-defined function. Calls to
+    // user-defined functions are mapped to a custom call that takes the
+    // callee name as an attribute.
+    return builder.create<GenericCallOp>(location, callee, operands);
+  }
+
+  /// Emit a call to the `print` builtin: codegen the argument and build a
+  /// toy.print operation. Returns failure if the argument codegen fails.
+  mlir::LogicalResult mlirGen(PrintExprAST &call) {
+    auto *arg = mlirGen(*call.getArg());
+    if (!arg)
+      return mlir::failure();
+
+    builder.create<PrintOp>(loc(call.loc()), arg);
+    return mlir::success();
+  }
+
+  /// Emit a constant for a single number (FIXME: semantic? broadcast?)
+  mlir::Value *mlirGen(NumberExprAST &num) {
+    // Builds a toy.constant from the single double value of the literal.
+    return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
+  }
+
+  /// Dispatch codegen for the right expression subclass using RTTI.
+  /// Emits an error and returns nullptr for unhandled kinds.
+  mlir::Value *mlirGen(ExprAST &expr) {
+    switch (expr.getKind()) {
+    case toy::ExprAST::Expr_BinOp:
+      return mlirGen(cast<BinaryExprAST>(expr));
+    case toy::ExprAST::Expr_Var:
+      return mlirGen(cast<VariableExprAST>(expr));
+    case toy::ExprAST::Expr_Literal:
+      return mlirGen(cast<LiteralExprAST>(expr));
+    case toy::ExprAST::Expr_Call:
+      return mlirGen(cast<CallExprAST>(expr));
+    case toy::ExprAST::Expr_Num:
+      return mlirGen(cast<NumberExprAST>(expr));
+    default:
+      emitError(loc(expr.loc()))
+          << "MLIR codegen encountered an unhandled expr kind '"
+          << Twine(expr.getKind()) << "'";
+      return nullptr;
+    }
+  }
+
+  /// Handle a variable declaration: we'll codegen the expression that forms
+  /// the initializer and record the value in the symbol table before returning
+  /// it. Future expressions will be able to reference this variable through
+  /// symbol table lookup. Returns nullptr on error (missing initializer,
+  /// failed codegen, or redeclaration).
+  mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
+    auto init = vardecl.getInitVal();
+    if (!init) {
+      emitError(loc(vardecl.loc()),
+                "missing initializer in variable declaration");
+      return nullptr;
+    }
+
+    mlir::Value *value = mlirGen(*init);
+    if (!value)
+      return nullptr;
+
+    // We have the initializer value, but in case the variable was declared
+    // with a specific shape, we emit a "reshape" operation. It will get
+    // optimized out later as needed.
+    if (!vardecl.getType().shape.empty()) {
+      value = builder.create<ReshapeOp>(loc(vardecl.loc()),
+                                        getType(vardecl.getType()), value);
+    }
+
+    // Register the value in the symbol table; declare() fails on
+    // redeclaration.
+    if (failed(declare(vardecl.getName(), value)))
+      return nullptr;
+    return value;
+  }
+
+  /// Codegen a list of expressions; return failure if any of them hit an
+  /// error.
+  mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
+    // Open a new scope: variables declared in this block are dropped when the
+    // scope object is destroyed at the end of the function.
+    ScopedHashTableScope<StringRef, mlir::Value *> var_scope(symbolTable);
+    for (auto &expr : blockAST) {
+      // Specific handling for variable declarations, return statement, and
+      // print. These can only appear in a block list and not in nested
+      // expressions.
+      if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
+        if (!mlirGen(*vardecl))
+          return mlir::failure();
+        continue;
+      }
+      if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
+        return mlirGen(*ret);
+      if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
+        // Propagate the failure: returning success() here would silently
+        // swallow a print codegen error.
+        if (mlir::failed(mlirGen(*print)))
+          return mlir::failure();
+        continue;
+      }
+
+      // Generic expression dispatch codegen.
+      if (!mlirGen(*expr))
+        return mlir::failure();
+    }
+    return mlir::success();
+  }
+
+  /// Build a tensor type (f64 element type) from a list of shape dimensions.
+  mlir::Type getType(ArrayRef<int64_t> shape) {
+    // If the shape is empty, then this type is unranked.
+    if (shape.empty())
+      return builder.getTensorType(builder.getF64Type());
+
+    // Otherwise, we use the given shape to build a ranked tensor type.
+    return builder.getTensorType(shape, builder.getF64Type());
+  }
+
+  /// Build an MLIR type from a Toy AST variable type (forwards to the generic
+  /// getType above; the element type is always f64).
+  mlir::Type getType(const VarType &type) { return getType(type.shape); }
+};
+
+} // namespace
+
+namespace toy {
+
+// The public API for codegen: convert a Toy module AST into an MLIR module.
+// Returns a null module on error (diagnostics are emitted via the context).
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
+                              ModuleAST &moduleAST) {
+  return MLIRGenImpl(context).mlirGen(moduleAST);
+}
+
+} // namespace toy
diff --git a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
new file mode 100644
index 00000000000..1f572015c39
--- /dev/null
+++ b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
@@ -0,0 +1,113 @@
+//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a function-level pass performing intra-procedural
+// propagation of array shapes through the shape inference interface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "toy/Dialect.h"
+#include "toy/Passes.h"
+#include "toy/ShapeInferenceInterface.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "shape-inference"
+
+using namespace mlir;
+using namespace toy;
+
+/// Include the auto-generated definitions for the shape inference interfaces.
+#include "toy/ShapeInferenceOpInterfaces.cpp.inc"
+
+namespace {
+/// The ShapeInferencePass is a FunctionPass that performs intra-procedural
+/// shape inference.
+///
+/// Algorithm:
+///
+///   1) Build a worklist containing all the operations that return a
+///      dynamically shaped tensor: these are the operations that need shape
+///      inference.
+///   2) Iterate on the worklist:
+///     a) find an operation to process: the next ready operation in the
+///        worklist has all of its arguments non-generic,
+///     b) if no operation is found, break out of the loop,
+///     c) remove the operation from the worklist,
+///     d) infer the shape of its output from the argument types.
+///   3) If the worklist is empty, the algorithm succeeded.
+///
+class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
+public:
+  void runOnFunction() override {
+    auto f = getFunction();
+
+    // Populate the worklist with the operations that need shape inference:
+    // these are operations that return a dynamic shape.
+    llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
+    f.walk([&](mlir::Operation *op) {
+      if (returnsDynamicShape(op))
+        opWorklist.insert(op);
+    });
+
+    // Iterate on the operations in the worklist until all operations have been
+    // inferred or no change happened (fix point).
+    while (!opWorklist.empty()) {
+      // Find the next operation ready for inference, that is an operation
+      // with all operands already resolved (non-generic). Checking the
+      // operand types (rather than the result types, which are dynamic for
+      // every op still in the worklist) is what implements step 2a above.
+      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
+      if (nextop == opWorklist.end())
+        break;
+
+      Operation *op = *nextop;
+      opWorklist.erase(op);
+
+      // Ask the operation to infer its output shapes.
+      LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
+      if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
+        shapeOp.inferShapes();
+      } else {
+        op->emitError("unable to infer shape of operation without shape "
+                      "inference interface");
+        return signalPassFailure();
+      }
+    }
+
+    // If the operation worklist isn't empty, this indicates a failure.
+    if (!opWorklist.empty()) {
+      f.emitError("Shape inference failed, ")
+          << opWorklist.size() << " operations couldn't be inferred\n";
+      signalPassFailure();
+    }
+  }
+
+  /// A utility method that returns if the given operation has all of its
+  /// operands inferred (i.e. each operand has a ranked tensor type).
+  static bool allOperandsInferred(Operation *op) {
+    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+      return operandType.isa<RankedTensorType>();
+    });
+  }
+
+  /// A utility method that returns if the given operation has a dynamically
+  /// shaped result.
+  static bool returnsDynamicShape(Operation *op) {
+    return llvm::any_of(op->getResultTypes(), [](Type resultType) {
+      return !resultType.isa<RankedTensorType>();
+    });
+  }
+};
+} // end anonymous namespace
+
+/// Create a Shape Inference pass (factory presumably declared in
+/// toy/Passes.h — it is included above).
+std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
+  return std::make_unique<ShapeInferencePass>();
+}
diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
new file mode 100644
index 00000000000..47e1abc6c74
--- /dev/null
+++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
@@ -0,0 +1,83 @@
+//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a set of simple combiners for optimizing operations in
+// the Toy dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "toy/Dialect.h"
+#include <numeric>
+using namespace mlir;
+using namespace toy;
+
+namespace {
+/// Include the patterns defined in the Declarative Rewrite framework.
+#include "ToyCombine.inc"
+} // end anonymous namespace
+
+/// Fold simple cast operations that return the same type as the input.
+/// Delegates to MLIR's generic cast-folding helper.
+OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+  return mlir::impl::foldCastOp(*this);
+}
+
+/// This is an example of a C++ rewrite pattern for the TransposeOp. It
+/// optimizes the following scenario: transpose(transpose(x)) -> x
+struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
+  /// We register this pattern to match every toy.transpose in the IR.
+  /// The "benefit" is used by the framework to order the patterns and process
+  /// them in order of profitability.
+  SimplifyRedundantTranspose(mlir::MLIRContext *context)
+      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
+
+  /// This method attempts to match a pattern and rewrite it. The rewriter
+  /// argument is the orchestrator of the sequence of rewrites. The pattern is
+  /// expected to interact with it to perform any changes to the IR from here.
+  mlir::PatternMatchResult
+  matchAndRewrite(TransposeOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // Look through the input of the current transpose.
+    mlir::Value *transposeInput = op.getOperand();
+    TransposeOp transposeInputOp =
+        llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
+
+    // If the input is not defined by another transpose, there is nothing to
+    // simplify.
+    if (!transposeInputOp)
+      return matchFailure();
+
+    // We have transpose(transpose(x)): replace the outer transpose with x,
+    // the operand of the inner transpose.
+    rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
+    return matchSuccess();
+  }
+};
+
+/// Register our patterns as "canonicalization" patterns on the TransposeOp so
+/// that they can be picked up by the Canonicalization framework.
+void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  results.insert<SimplifyRedundantTranspose>(context);
+}
+
+/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
+/// that they can be picked up by the Canonicalization framework. The three
+/// patterns below are generated from ToyCombine.td (included above).
+void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
+                 FoldConstantReshapeOptPattern>(context);
+}
diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.td b/mlir/examples/toy/Ch6/mlir/ToyCombine.td
new file mode 100644
index 00000000000..0a63861fa96
--- /dev/null
+++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.td
@@ -0,0 +1,73 @@
+//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines language-specific pattern match optimizations for Toy using
+// Declarative Rewrite Rules (DRR) specified using TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOY_COMBINE
+#define TOY_COMBINE
+
+#ifndef OP_BASE
+include "toy/Ops.td"
+#endif // OP_BASE
+
+/// Note: The DRR definition used for defining patterns is shown below:
+///
+/// class Pattern<
+///    dag sourcePattern, list<dag> resultPatterns,
+///    list<dag> additionalConstraints = [],
+///    dag benefitsAdded = (addBenefit 0)
+/// >;
+
+//===----------------------------------------------------------------------===//
+// Basic Pattern-Match and Rewrite
+//===----------------------------------------------------------------------===//
+
+// Reshape(Reshape(x)) = Reshape(x): collapse a chain of reshapes into the
+// outermost one.
+def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
+                                   (ReshapeOp $arg)>;
+
+//===----------------------------------------------------------------------===//
+// Pattern-Match and Rewrite using Native Code Call
+//===----------------------------------------------------------------------===//
+
+// Native Code Calls may be used for more complex transformations using inline
+// C++ and C++ helper functions.
+
+// Reshape(Constant(x)) = x', i.e. fold the reshape into the constant by
+// reshaping its dense attribute to the result type.
+def ReshapeConstant :
+  NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
+def FoldConstantReshapeOptPattern : Pat<
+  (ReshapeOp:$res (ConstantOp $arg)),
+  (ConstantOp (ReshapeConstant $arg, $res))>;
+
+//===----------------------------------------------------------------------===//
+// Pattern-Match and Rewrite with Constraints
+//===----------------------------------------------------------------------===//
+
+// DRR allows for constraint checking when the transformation is conditional
+// on operand properties.
+
+// Reshape(x) = x, where input and output shapes are identical
+def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
+def RedundantReshapeOptPattern : Pat<
+  (ReshapeOp:$res $arg), (replaceWithValue $arg),
+  [(TypesAreIdentical $res, $arg)]>;
+
+#endif // TOY_COMBINE
diff --git a/mlir/examples/toy/Ch6/parser/AST.cpp b/mlir/examples/toy/Ch6/parser/AST.cpp
new file mode 100644
index 00000000000..869f2ef2013
--- /dev/null
+++ b/mlir/examples/toy/Ch6/parser/AST.cpp
@@ -0,0 +1,263 @@
+//===- AST.cpp - Helper for printing out the Toy AST ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the AST dump for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/AST.h"
+
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace toy;
+
+namespace {
+
+// RAII helper to manage increasing/decreasing the indentation as we traverse
+// the AST: bumps the level on construction, restores it on destruction.
+struct Indent {
+  Indent(int &level) : level(level) { ++level; }
+  ~Indent() { --level; }
+  int &level;
+};
+
+/// Helper class that implements the AST tree traversal and prints the nodes
+/// along the way. The only data member is the current indentation level.
+class ASTDumper {
+public:
+  void dump(ModuleAST *Node);
+
+private:
+  // One overload per concrete AST node type; dump(ExprAST *) dispatches.
+  void dump(VarType &type);
+  void dump(VarDeclExprAST *varDecl);
+  void dump(ExprAST *expr);
+  void dump(ExprASTList *exprList);
+  void dump(NumberExprAST *num);
+  void dump(LiteralExprAST *Node);
+  void dump(VariableExprAST *Node);
+  void dump(ReturnExprAST *Node);
+  void dump(BinaryExprAST *Node);
+  void dump(CallExprAST *Node);
+  void dump(PrintExprAST *Node);
+  void dump(PrototypeAST *Node);
+  void dump(FunctionAST *Node);
+
+  // Actually print spaces matching the current indentation level
+  void indent() {
+    for (int i = 0; i < curIndent; i++)
+      llvm::errs() << " ";
+  }
+  int curIndent = 0;
+};
+
+} // namespace
+
+/// Return a formatted string "@file:line:col" for the location of any AST
+/// node that provides a `loc()` accessor.
+template <typename T> static std::string loc(T *Node) {
+  const auto &loc = Node->loc();
+  return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
+          llvm::Twine(loc.col))
+      .str();
+}
+
+// Helper macro to bump the indentation level (restored automatically at the
+// end of the enclosing scope via the RAII Indent helper) and print the
+// leading spaces for the current indentation.
+#define INDENT()                                                               \
+  Indent level_(curIndent);                                                    \
+  indent();
+
+/// Dispatch a generic expression to the appropriate subclass overload using
+/// RTTI.
+void ASTDumper::dump(ExprAST *expr) {
+#define dispatch(CLASS)                                                        \
+  if (CLASS *node = llvm::dyn_cast<CLASS>(expr))                               \
+    return dump(node);
+  dispatch(VarDeclExprAST);
+  dispatch(LiteralExprAST);
+  dispatch(NumberExprAST);
+  dispatch(VariableExprAST);
+  dispatch(ReturnExprAST);
+  dispatch(BinaryExprAST);
+  dispatch(CallExprAST);
+  dispatch(PrintExprAST);
+  // No match, fall back to a generic message.
+  INDENT();
+  llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
+}
+
+/// A variable declaration: print the variable name, the type, and then
+/// recurse into the initializer value.
+void ASTDumper::dump(VarDeclExprAST *varDecl) {
+  INDENT();
+  llvm::errs() << "VarDecl " << varDecl->getName();
+  dump(varDecl->getType());
+  llvm::errs() << " " << loc(varDecl) << "\n";
+  dump(varDecl->getInitVal());
+}
+
+/// A "block", or a list of expressions, printed between "Block {" and "}".
+void ASTDumper::dump(ExprASTList *exprList) {
+  INDENT();
+  llvm::errs() << "Block {\n";
+  for (auto &expr : *exprList)
+    dump(expr.get());
+  indent();
+  llvm::errs() << "} // Block\n";
+}
+
+/// A literal number: just print the value and its location.
+void ASTDumper::dump(NumberExprAST *num) {
+  INDENT();
+  llvm::errs() << num->getValue() << " " << loc(num) << "\n";
+}
+
+/// Helper to print recursively a literal. This handles nested arrays like:
+///   [ [ 1, 2 ], [ 3, 4 ] ]
+/// We print out such array with the dimensions spelled out at every level:
+///   <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
+void printLitHelper(ExprAST *lit_or_num) {
+  // Inside a literal expression we can have either a number or another literal
+  if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
+    llvm::errs() << num->getValue();
+    return;
+  }
+  auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
+
+  // Print the dimensions for this literal first
+  llvm::errs() << "<";
+  {
+    const char *sep = "";
+    for (auto dim : literal->getDims()) {
+      llvm::errs() << sep << dim;
+      sep = ", ";
+    }
+  }
+  llvm::errs() << ">";
+
+  // Now print the content, recursing on every element of the list
+  llvm::errs() << "[ ";
+  const char *sep = "";
+  for (auto &elt : literal->getValues()) {
+    llvm::errs() << sep;
+    printLitHelper(elt.get());
+    sep = ", ";
+  }
+  llvm::errs() << "]";
+}
+
+/// Print a literal; see the recursive helper above for the implementation.
+void ASTDumper::dump(LiteralExprAST *Node) {
+  INDENT();
+  llvm::errs() << "Literal: ";
+  printLitHelper(Node);
+  llvm::errs() << " " << loc(Node) << "\n";
+}
+
+/// Print a variable reference (just a name) and its location.
+void ASTDumper::dump(VariableExprAST *Node) {
+  INDENT();
+  llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
+}
+
+/// Return statement: print "Return" and its (optional) argument; a return
+/// with no expression is shown as "(void)".
+void ASTDumper::dump(ReturnExprAST *Node) {
+  INDENT();
+  llvm::errs() << "Return\n";
+  // If there is an expression, dump it and skip the "(void)" marker below.
+  if (Node->getExpr().hasValue())
+    return dump(*Node->getExpr());
+  {
+    INDENT();
+    llvm::errs() << "(void)\n";
+  }
+}
+
+/// Print a binary operation, first the operator, then recurse into LHS and
+/// RHS.
+void ASTDumper::dump(BinaryExprAST *Node) {
+  INDENT();
+  llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
+  dump(Node->getLHS());
+  dump(Node->getRHS());
+}
+
+/// Print a call expression, first the callee name and the list of args by
+/// recursing into each individual argument.
+void ASTDumper::dump(CallExprAST *Node) {
+  INDENT();
+  llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
+  for (auto &arg : Node->getArgs())
+    dump(arg.get());
+  indent();
+  llvm::errs() << "]\n";
+}
+
+/// Print a builtin print call, first the builtin name and then the argument.
+void ASTDumper::dump(PrintExprAST *Node) {
+  INDENT();
+  llvm::errs() << "Print [ " << loc(Node) << "\n";
+  dump(Node->getArg());
+  indent();
+  llvm::errs() << "]\n";
+}
+
+/// Print a type: only the shape is printed, in between '<' and '>'.
+void ASTDumper::dump(VarType &type) {
+  llvm::errs() << "<";
+  const char *sep = "";
+  for (auto shape : type.shape) {
+    llvm::errs() << sep << shape;
+    sep = ", ";
+  }
+  llvm::errs() << ">";
+}
+
+/// Print a function prototype, first the function name, and then the list of
+/// parameter names.
+void ASTDumper::dump(PrototypeAST *Node) {
+  INDENT();
+  // Note: the trailing "'\n" previously printed a stray quote after the
+  // location ("Proto 'f' @file:1:1'"); the name is already fully quoted.
+  llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "\n";
+  indent();
+  llvm::errs() << "Params: [";
+  const char *sep = "";
+  for (auto &arg : Node->getArgs()) {
+    llvm::errs() << sep << arg->getName();
+    sep = ", ";
+  }
+  llvm::errs() << "]\n";
+}
+
+/// Print a function: first the prototype and then the body.
+void ASTDumper::dump(FunctionAST *Node) {
+  INDENT();
+  llvm::errs() << "Function \n";
+  dump(Node->getProto());
+  dump(Node->getBody());
+}
+
+/// Print a module: loop over the functions and print them in sequence.
+void ASTDumper::dump(ModuleAST *Node) {
+  INDENT();
+  llvm::errs() << "Module:\n";
+  for (auto &F : *Node)
+    dump(&F);
+}
+
+namespace toy {
+
+// Public API: dump the AST of a full module to stderr.
+void dump(ModuleAST &module) { ASTDumper().dump(&module); }
+
+} // namespace toy
diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp
new file mode 100644
index 00000000000..a40056bb646
--- /dev/null
+++ b/mlir/examples/toy/Ch6/toyc.cpp
@@ -0,0 +1,277 @@
+//===- toyc.cpp - The Toy Compiler ----------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the entry point for the Toy compiler.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+#include "toy/MLIRGen.h"
+#include "toy/Parser.h"
+#include "toy/Passes.h"
+
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Target/LLVMIR.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace toy;
+namespace cl = llvm::cl;
+
+// Positional command-line argument naming the input file. The default "-"
+// reads from stdin (handled by MemoryBuffer::getFileOrSTDIN below).
+static cl::opt<std::string> inputFilename(cl::Positional,
+ cl::desc("<input toy file>"),
+ cl::init("-"),
+ cl::value_desc("filename"));
+
+namespace {
+/// The kind of input the compiler accepts: Toy source or an MLIR file.
+enum InputType { Toy, MLIR };
+} // namespace
+/// Command-line flag '-x' selecting how the input file is interpreted.
+/// (Description fixed: it previously read "Decided the kind of output
+/// desired", a copy-paste from -emit; -x selects the *input* kind.)
+static cl::opt<enum InputType> inputType(
+    "x", cl::init(Toy), cl::desc("Decide the kind of input desired"),
+    cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
+    cl::values(clEnumValN(MLIR, "mlir",
+                          "load the input file as an MLIR file")));
+
+namespace {
+// The action to perform, ordered by how far down the compilation pipeline it
+// goes: AST -> Toy MLIR -> Affine MLIR -> LLVM MLIR -> LLVM IR -> JIT.
+// The declaration order is load-bearing: loadAndProcessMLIR and main compare
+// actions with >= / <= to decide how far to lower.
+enum Action {
+ None,
+ DumpAST,
+ DumpMLIR,
+ DumpMLIRAffine,
+ DumpMLIRLLVM,
+ DumpLLVMIR,
+ RunJIT
+};
+}
+// Command-line flag '-emit' selecting the kind of output to produce.
+static cl::opt<enum Action> emitAction(
+ "emit", cl::desc("Select the kind of output desired"),
+ cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
+ cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")),
+ cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine",
+ "output the MLIR dump after affine lowering")),
+ cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm",
+ "output the MLIR dump after llvm lowering")),
+ cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")),
+ cl::values(
+ clEnumValN(RunJIT, "jit",
+ "JIT the code and run it by invoking the main function")));
+
+// Command-line flag '-opt' enabling the optimization pipelines (both the
+// MLIR passes below and the LLVM -O3 pipeline in dumpLLVMIR/runJit).
+static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations"));
+
+/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
+std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
+ llvm::MemoryBuffer::getFileOrSTDIN(filename);
+ if (std::error_code EC = FileOrErr.getError()) {
+ llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+ return nullptr;
+ }
+ auto buffer = FileOrErr.get()->getBuffer();
+ LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
+ Parser parser(lexer);
+ return parser.ParseModule();
+}
+
+/// Load the input as an MLIR module: either by parsing '.toy' source and
+/// running MLIRGen on the resulting AST, or by parsing an '.mlir' file
+/// directly. Returns 0 on success and a non-zero error code on failure.
+int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
+  // Handle '.toy' input to the compiler.
+  if (inputType != InputType::MLIR &&
+      !llvm::StringRef(inputFilename).endswith(".mlir")) {
+    auto moduleAST = parseInputFile(inputFilename);
+    // parseInputFile returns nullptr on error; dereferencing it unchecked
+    // would crash the compiler on any malformed input file.
+    if (!moduleAST)
+      return 6;
+    module = mlirGen(context, *moduleAST);
+    return !module ? 1 : 0;
+  }
+
+  // Otherwise, the input is '.mlir'.
+  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+      llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+  if (std::error_code ec = fileOrErr.getError()) {
+    llvm::errs() << "Could not open input file: " << ec.message() << "\n";
+    return -1;
+  }
+
+  // Parse the input mlir.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
+  module = mlir::parseSourceFile(sourceMgr, &context);
+  if (!module) {
+    llvm::errs() << "Error can't load file " << inputFilename << "\n";
+    return 3;
+  }
+  return 0;
+}
+
+/// Load the input (via loadMLIR) and run the optimization / lowering pipeline
+/// selected by -emit and -opt over it. Returns 0 on success, non-zero on
+/// failure (either a loadMLIR error code or 4 if a pass fails).
+int loadAndProcessMLIR(mlir::MLIRContext &context,
+ mlir::OwningModuleRef &module) {
+ if (int error = loadMLIR(context, module))
+ return error;
+
+ mlir::PassManager pm(&context);
+ // Apply any generic pass manager command line options and run the pipeline.
+ applyPassManagerCLOptions(pm);
+
+ // Check to see what granularity of MLIR we are compiling to. This relies on
+ // the declaration order of the Action enum: anything at or past
+ // DumpMLIRAffine requires the affine lowering, and likewise for LLVM.
+ bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine;
+ bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM;
+
+ if (EnableOpt || isLoweringToAffine) {
+ // Inline all functions into main and then delete them.
+ pm.addPass(mlir::createInlinerPass());
+ pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
+
+ // Now that there is only one function, we can infer the shapes of each of
+ // the operations.
+ pm.addPass(mlir::toy::createShapeInferencePass());
+ pm.addPass(mlir::createCanonicalizerPass());
+ }
+
+ if (isLoweringToAffine) {
+ // Partially lower the toy dialect with a few cleanups afterwards.
+ pm.addPass(mlir::toy::createLowerToAffinePass());
+ pm.addPass(mlir::createCanonicalizerPass());
+ pm.addPass(mlir::createCSEPass());
+
+ // Add optimizations if enabled.
+ if (EnableOpt) {
+ pm.addPass(mlir::createLoopFusionPass());
+ pm.addPass(mlir::createMemRefDataFlowOptPass());
+ }
+ }
+
+ if (isLoweringToLLVM) {
+ // Finish lowering the toy IR to the LLVM dialect. This is a full
+ // conversion: no toy/affine/std operations remain afterwards.
+ pm.addPass(mlir::toy::createLowerToLLVMPass());
+ }
+
+ if (mlir::failed(pm.run(*module)))
+ return 4;
+ return 0;
+}
+
+/// Parse the input file as Toy source and print the resulting AST to stderr.
+/// Returns 0 on success, non-zero on error.
+int dumpAST() {
+  // An AST can only be produced from Toy source, never from MLIR input.
+  if (inputType == InputType::MLIR) {
+    llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n";
+    return 5;
+  }
+
+  if (auto moduleAST = parseInputFile(inputFilename)) {
+    dump(*moduleAST);
+    return 0;
+  }
+  return 1;
+}
+
+/// Translate the (LLVM-dialect) MLIR module to LLVM IR, optionally run the
+/// LLVM optimization pipeline over it, and print it to stderr.
+/// Returns 0 on success, -1 on failure.
+int dumpLLVMIR(mlir::ModuleOp module) {
+ auto llvmModule = mlir::translateModuleToLLVMIR(module);
+ if (!llvmModule) {
+ llvm::errs() << "Failed to emit LLVM IR\n";
+ return -1;
+ }
+
+ // Initialize LLVM targets. Must happen before setupTargetTriple so the
+ // module can be configured for the host machine.
+ llvm::InitializeNativeTarget();
+ llvm::InitializeNativeTargetAsmPrinter();
+ mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());
+
+ /// Optionally run an optimization pipeline over the llvm module
+ /// (-O3 when -opt was passed, -O0 otherwise).
+ auto optPipeline = mlir::makeOptimizingTransformer(
+ /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
+ /*targetMachine=*/nullptr);
+ if (auto err = optPipeline(llvmModule.get())) {
+ llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
+ return -1;
+ }
+ llvm::errs() << *llvmModule << "\n";
+ return 0;
+}
+
+/// JIT-compile the (LLVM-dialect) MLIR module and invoke its 'main' function.
+/// Returns 0 on success, -1 if the invocation fails.
+int runJit(mlir::ModuleOp module) {
+ // Initialize LLVM targets.
+ llvm::InitializeNativeTarget();
+ llvm::InitializeNativeTargetAsmPrinter();
+
+ // An optimization pipeline to use within the execution engine
+ // (-O3 when -opt was passed, -O0 otherwise).
+ auto optPipeline = mlir::makeOptimizingTransformer(
+ /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
+ /*targetMachine=*/nullptr);
+
+ // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
+ // the module.
+ auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline);
+ // NOTE(review): assert() compiles away in release builds, leaving the
+ // Expected error unchecked; consider handling the failure explicitly.
+ assert(maybeEngine && "failed to construct an execution engine");
+ auto &engine = maybeEngine.get();
+
+ // Invoke the JIT-compiled function.
+ auto invocationResult = engine->invoke("main");
+ if (invocationResult) {
+ llvm::errs() << "JIT invocation failed\n";
+ return -1;
+ }
+
+ return 0;
+}
+
+/// Entry point: parse command-line options and dispatch on -emit.
+int main(int argc, char **argv) {
+ // Register the pass-manager options (e.g. -print-ir-*) before parsing.
+ mlir::registerPassManagerCLOptions();
+ cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
+
+ if (emitAction == Action::DumpAST)
+ return dumpAST();
+
+ // If we aren't dumping the AST, then we are compiling with/to MLIR.
+
+ // Register our Dialect with MLIR.
+ mlir::registerDialect<mlir::toy::ToyDialect>();
+
+ mlir::MLIRContext context;
+ mlir::OwningModuleRef module;
+ if (int error = loadAndProcessMLIR(context, module))
+ return error;
+
+ // If we aren't exporting to non-mlir, then we are done. Relies on the
+ // Action enum order: everything up to DumpMLIRLLVM is still MLIR.
+ bool isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM;
+ if (isOutputingMLIR) {
+ module->dump();
+ return 0;
+ }
+
+ // Check to see if we are compiling to LLVM IR.
+ if (emitAction == Action::DumpLLVMIR)
+ return dumpLLVMIR(*module);
+
+ // Otherwise, we must be running the jit.
+ if (emitAction == Action::RunJIT)
+ return runJit(*module);
+
+ llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
+ return -1;
+}
OpenPOWER on IntegriCloud