summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-11-07 09:53:27 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-07 09:54:04 -0800
commit6b4e30b7c80782a2e1616c739b8a598ed72b725d (patch)
tree373bd95995aebd74b65fed0eb71bb158a60ee794
parent5fbdb67b0aa7f01b17dcca62e08e3db38d021fce (diff)
downloadbcm5719-llvm-6b4e30b7c80782a2e1616c739b8a598ed72b725d.tar.gz
bcm5719-llvm-6b4e30b7c80782a2e1616c739b8a598ed72b725d.zip
Add Ch-7 of the toy tutorial detailing how to define new types.
This chapter adds a new composite type to Toy, and shows the process of adding a new type to the IR, adding and updating operations to use it, and constant folding operations producing it. PiperOrigin-RevId: 279107885
-rw-r--r--mlir/examples/toy/CMakeLists.txt1
-rw-r--r--mlir/examples/toy/Ch7/CMakeLists.txt51
-rw-r--r--mlir/examples/toy/Ch7/include/CMakeLists.txt1
-rw-r--r--mlir/examples/toy/Ch7/include/toy/AST.h317
-rw-r--r--mlir/examples/toy/Ch7/include/toy/CMakeLists.txt9
-rw-r--r--mlir/examples/toy/Ch7/include/toy/Dialect.h109
-rw-r--r--mlir/examples/toy/Ch7/include/toy/Lexer.h244
-rw-r--r--mlir/examples/toy/Ch7/include/toy/MLIRGen.h41
-rw-r--r--mlir/examples/toy/Ch7/include/toy/Ops.td314
-rw-r--r--mlir/examples/toy/Ch7/include/toy/Parser.h687
-rw-r--r--mlir/examples/toy/Ch7/include/toy/Passes.h45
-rw-r--r--mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h37
-rw-r--r--mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td41
-rw-r--r--mlir/examples/toy/Ch7/mlir/DeadFunctionEliminationPass.cpp68
-rw-r--r--mlir/examples/toy/Ch7/mlir/Dialect.cpp483
-rw-r--r--mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp318
-rw-r--r--mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp213
-rw-r--r--mlir/examples/toy/Ch7/mlir/MLIRGen.cpp683
-rw-r--r--mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp113
-rw-r--r--mlir/examples/toy/Ch7/mlir/ToyCombine.cpp101
-rw-r--r--mlir/examples/toy/Ch7/mlir/ToyCombine.td73
-rw-r--r--mlir/examples/toy/Ch7/parser/AST.cpp285
-rw-r--r--mlir/examples/toy/Ch7/toyc.cpp282
-rw-r--r--mlir/g3doc/Tutorials/Toy/Ch-1.md5
-rw-r--r--mlir/g3doc/Tutorials/Toy/Ch-7.md538
-rw-r--r--mlir/test/CMakeLists.txt1
-rw-r--r--mlir/test/Examples/Toy/Ch7/affine-lowering.mlir65
-rw-r--r--mlir/test/Examples/Toy/Ch7/ast.toy76
-rw-r--r--mlir/test/Examples/Toy/Ch7/codegen.toy31
-rw-r--r--mlir/test/Examples/Toy/Ch7/invalid.mlir9
-rw-r--r--mlir/test/Examples/Toy/Ch7/llvm-lowering.mlir23
-rw-r--r--mlir/test/Examples/Toy/Ch7/scalar.toy14
-rw-r--r--mlir/test/Examples/Toy/Ch7/shape_inference.mlir30
-rw-r--r--mlir/test/Examples/Toy/Ch7/struct-ast.toy61
-rw-r--r--mlir/test/Examples/Toy/Ch7/struct-codegen.toy44
-rw-r--r--mlir/test/Examples/Toy/Ch7/struct-opt.mlir16
36 files changed, 5428 insertions, 1 deletions
diff --git a/mlir/examples/toy/CMakeLists.txt b/mlir/examples/toy/CMakeLists.txt
index 52a22014b4e..56002b1ad2e 100644
--- a/mlir/examples/toy/CMakeLists.txt
+++ b/mlir/examples/toy/CMakeLists.txt
@@ -12,3 +12,4 @@ add_subdirectory(Ch3)
add_subdirectory(Ch4)
add_subdirectory(Ch5)
add_subdirectory(Ch6)
+add_subdirectory(Ch7)
diff --git a/mlir/examples/toy/Ch7/CMakeLists.txt b/mlir/examples/toy/Ch7/CMakeLists.txt
new file mode 100644
index 00000000000..fc26425f038
--- /dev/null
+++ b/mlir/examples/toy/Ch7/CMakeLists.txt
@@ -0,0 +1,51 @@
+add_subdirectory(include)
+
+set(LLVM_LINK_COMPONENTS
+ Core
+ Support
+ )
+
+set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
+mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include")
+add_public_tablegen_target(ToyCh7CombineIncGen)
+
+add_toy_chapter(toyc-ch7
+ toyc.cpp
+ parser/AST.cpp
+ mlir/MLIRGen.cpp
+ mlir/Dialect.cpp
+ mlir/DeadFunctionEliminationPass.cpp
+ mlir/LowerToAffineLoops.cpp
+ mlir/LowerToLLVM.cpp
+ mlir/ShapeInferencePass.cpp
+ mlir/ToyCombine.cpp
+ )
+
+add_dependencies(toyc-ch7 ToyCh7ShapeInferenceInterfaceIncGen)
+add_dependencies(toyc-ch7 ToyCh7OpsIncGen)
+add_dependencies(toyc-ch7 ToyCh7CombineIncGen)
+add_dependencies(toyc-ch7 MLIRCallOpInterfacesIncGen)
+include_directories(include/)
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
+target_link_libraries(toyc-ch7
+ PRIVATE
+ MLIRAffineOps
+ MLIRAnalysis
+ MLIRExecutionEngine
+ MLIRIR
+ MLIRLLVMIR
+ MLIRLoopToStandard
+ MLIRParser
+ MLIRPass
+ MLIRStandardOps
+ MLIRStandardToLLVM
+ MLIRTargetLLVMIR
+ MLIRTransforms
+ )
+
+whole_archive_link(toyc-ch7
+ MLIRAffineOps
+ MLIRLLVMIR
+ MLIRStandardOps
+ )
diff --git a/mlir/examples/toy/Ch7/include/CMakeLists.txt b/mlir/examples/toy/Ch7/include/CMakeLists.txt
new file mode 100644
index 00000000000..37c89d0bae9
--- /dev/null
+++ b/mlir/examples/toy/Ch7/include/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(toy)
diff --git a/mlir/examples/toy/Ch7/include/toy/AST.h b/mlir/examples/toy/Ch7/include/toy/AST.h
new file mode 100644
index 00000000000..558d9deab8e
--- /dev/null
+++ b/mlir/examples/toy/Ch7/include/toy/AST.h
@@ -0,0 +1,317 @@
+//===- AST.h - Node definition for the Toy AST ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the AST for the Toy language. It is optimized for
+// simplicity, not efficiency. The AST forms a tree structure where each node
+// references its children using std::unique_ptr<>.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_AST_H_
+#define MLIR_TUTORIAL_TOY_AST_H_
+
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <vector>
+
+namespace toy {
+
+/// A variable type with either name or shape information.
+struct VarType {
+ std::string name;
+ std::vector<int64_t> shape;
+};
+
+/// Base class for all expression nodes.
+class ExprAST {
+public:
+ enum ExprASTKind {
+ Expr_VarDecl,
+ Expr_Return,
+ Expr_Num,
+ Expr_Literal,
+ Expr_StructLiteral,
+ Expr_Var,
+ Expr_BinOp,
+ Expr_Call,
+ Expr_Print,
+ };
+
+ ExprAST(ExprASTKind kind, Location location)
+ : kind(kind), location(location) {}
+ virtual ~ExprAST() = default;
+
+ ExprASTKind getKind() const { return kind; }
+
+ const Location &loc() { return location; }
+
+private:
+ const ExprASTKind kind;
+ Location location;
+};
+
+/// A block-list of expressions.
+using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;
+
+/// Expression class for numeric literals like "1.0".
+class NumberExprAST : public ExprAST {
+ double Val;
+
+public:
+ NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {}
+
+ double getValue() { return Val; }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
+};
+
+/// Expression class for a literal value.
+class LiteralExprAST : public ExprAST {
+ std::vector<std::unique_ptr<ExprAST>> values;
+ std::vector<int64_t> dims;
+
+public:
+ LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values,
+ std::vector<int64_t> dims)
+ : ExprAST(Expr_Literal, loc), values(std::move(values)),
+ dims(std::move(dims)) {}
+
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
+ llvm::ArrayRef<int64_t> getDims() { return dims; }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
+};
+
+/// Expression class for a literal struct value.
+class StructLiteralExprAST : public ExprAST {
+ std::vector<std::unique_ptr<ExprAST>> values;
+
+public:
+ StructLiteralExprAST(Location loc,
+ std::vector<std::unique_ptr<ExprAST>> values)
+ : ExprAST(Expr_StructLiteral, loc), values(std::move(values)) {}
+
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *c) {
+ return c->getKind() == Expr_StructLiteral;
+ }
+};
+
+/// Expression class for referencing a variable, like "a".
+class VariableExprAST : public ExprAST {
+ std::string name;
+
+public:
+ VariableExprAST(Location loc, llvm::StringRef name)
+ : ExprAST(Expr_Var, loc), name(name) {}
+
+ llvm::StringRef getName() { return name; }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
+};
+
+/// Expression class for defining a variable.
+class VarDeclExprAST : public ExprAST {
+ std::string name;
+ VarType type;
+ std::unique_ptr<ExprAST> initVal;
+
+public:
+ VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
+ std::unique_ptr<ExprAST> initVal = nullptr)
+ : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
+ initVal(std::move(initVal)) {}
+
+ llvm::StringRef getName() { return name; }
+ ExprAST *getInitVal() { return initVal.get(); }
+ const VarType &getType() { return type; }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
+};
+
+/// Expression class for a return operator.
+class ReturnExprAST : public ExprAST {
+ llvm::Optional<std::unique_ptr<ExprAST>> expr;
+
+public:
+ ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr)
+ : ExprAST(Expr_Return, loc), expr(std::move(expr)) {}
+
+ llvm::Optional<ExprAST *> getExpr() {
+ if (expr.hasValue())
+ return expr->get();
+ return llvm::None;
+ }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
+};
+
+/// Expression class for a binary operator.
+class BinaryExprAST : public ExprAST {
+ char op;
+ std::unique_ptr<ExprAST> lhs, rhs;
+
+public:
+ char getOp() { return op; }
+ ExprAST *getLHS() { return lhs.get(); }
+ ExprAST *getRHS() { return rhs.get(); }
+
+ BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs,
+ std::unique_ptr<ExprAST> rhs)
+ : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)),
+ rhs(std::move(rhs)) {}
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
+};
+
+/// Expression class for function calls.
+class CallExprAST : public ExprAST {
+ std::string callee;
+ std::vector<std::unique_ptr<ExprAST>> args;
+
+public:
+ CallExprAST(Location loc, const std::string &callee,
+ std::vector<std::unique_ptr<ExprAST>> args)
+ : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
+
+ llvm::StringRef getCallee() { return callee; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
+};
+
+/// Expression class for builtin print calls.
+class PrintExprAST : public ExprAST {
+ std::unique_ptr<ExprAST> arg;
+
+public:
+ PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
+ : ExprAST(Expr_Print, loc), arg(std::move(arg)) {}
+
+ ExprAST *getArg() { return arg.get(); }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
+};
+
+/// This class represents the "prototype" for a function, which captures its
+/// name, and its argument names (thus implicitly the number of arguments the
+/// function takes).
+class PrototypeAST {
+ Location location;
+ std::string name;
+ std::vector<std::unique_ptr<VarDeclExprAST>> args;
+
+public:
+ PrototypeAST(Location location, const std::string &name,
+ std::vector<std::unique_ptr<VarDeclExprAST>> args)
+ : location(location), name(name), args(std::move(args)) {}
+
+ const Location &loc() { return location; }
+ llvm::StringRef getName() const { return name; }
+ llvm::ArrayRef<std::unique_ptr<VarDeclExprAST>> getArgs() { return args; }
+};
+
+/// This class represents a top level record in a module.
+class RecordAST {
+public:
+ enum RecordASTKind {
+ Record_Function,
+ Record_Struct,
+ };
+
+ RecordAST(RecordASTKind kind) : kind(kind) {}
+ virtual ~RecordAST() = default;
+
+ RecordASTKind getKind() const { return kind; }
+
+private:
+ const RecordASTKind kind;
+};
+
+/// This class represents a function definition itself.
+class FunctionAST : public RecordAST {
+ std::unique_ptr<PrototypeAST> proto;
+ std::unique_ptr<ExprASTList> body;
+
+public:
+ FunctionAST(std::unique_ptr<PrototypeAST> proto,
+ std::unique_ptr<ExprASTList> body)
+ : RecordAST(Record_Function), proto(std::move(proto)),
+ body(std::move(body)) {}
+ PrototypeAST *getProto() { return proto.get(); }
+ ExprASTList *getBody() { return body.get(); }
+
+ /// LLVM style RTTI
+ static bool classof(const RecordAST *R) {
+ return R->getKind() == Record_Function;
+ }
+};
+
+/// This class represents a struct definition.
+class StructAST : public RecordAST {
+ Location location;
+ std::string name;
+ std::vector<std::unique_ptr<VarDeclExprAST>> variables;
+
+public:
+ StructAST(Location location, const std::string &name,
+ std::vector<std::unique_ptr<VarDeclExprAST>> variables)
+ : RecordAST(Record_Struct), location(location), name(name),
+ variables(std::move(variables)) {}
+
+ const Location &loc() { return location; }
+ llvm::StringRef getName() const { return name; }
+ llvm::ArrayRef<std::unique_ptr<VarDeclExprAST>> getVariables() {
+ return variables;
+ }
+
+ /// LLVM style RTTI
+ static bool classof(const RecordAST *R) {
+ return R->getKind() == Record_Struct;
+ }
+};
+
+/// This class represents a list of functions to be processed together
+class ModuleAST {
+ std::vector<std::unique_ptr<RecordAST>> records;
+
+public:
+ ModuleAST(std::vector<std::unique_ptr<RecordAST>> records)
+ : records(std::move(records)) {}
+
+ auto begin() -> decltype(records.begin()) { return records.begin(); }
+ auto end() -> decltype(records.end()) { return records.end(); }
+};
+
+void dump(ModuleAST &);
+
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_AST_H_
diff --git a/mlir/examples/toy/Ch7/include/toy/CMakeLists.txt b/mlir/examples/toy/Ch7/include/toy/CMakeLists.txt
new file mode 100644
index 00000000000..fa30bd2e8e0
--- /dev/null
+++ b/mlir/examples/toy/Ch7/include/toy/CMakeLists.txt
@@ -0,0 +1,9 @@
+set(LLVM_TARGET_DEFINITIONS Ops.td)
+mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
+mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
+add_public_tablegen_target(ToyCh7OpsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
+mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(ToyCh7ShapeInferenceInterfaceIncGen)
diff --git a/mlir/examples/toy/Ch7/include/toy/Dialect.h b/mlir/examples/toy/Ch7/include/toy/Dialect.h
new file mode 100644
index 00000000000..82b677a3c2a
--- /dev/null
+++ b/mlir/examples/toy/Ch7/include/toy/Dialect.h
@@ -0,0 +1,109 @@
+//===- Dialect.h - Dialect definition for the Toy IR ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the IR Dialect for the Toy language.
+// See g3doc/Tutorials/Toy/Ch-2.md for more information.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
+#define MLIR_TUTORIAL_TOY_DIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/StandardTypes.h"
+#include "toy/ShapeInferenceInterface.h"
+
+namespace mlir {
+namespace toy {
+namespace detail {
+class StructTypeStorage;
+} // end namespace detail
+
+/// This is the definition of the Toy dialect. A dialect inherits from
+/// mlir::Dialect and registers custom attributes, operations, and types (in its
+/// constructor). It can also override some general behavior exposed via virtual
+/// methods.
+class ToyDialect : public mlir::Dialect {
+public:
+ explicit ToyDialect(mlir::MLIRContext *ctx);
+
+ /// A hook used to materialize constant values with the given type.
+ Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
+ Location loc) override;
+
+ /// Parse an instance of a type registered to the toy dialect.
+ mlir::Type parseType(mlir::DialectAsmParser &parser) const override;
+
+ /// Print an instance of a type registered to the toy dialect.
+ void printType(mlir::Type type,
+ mlir::DialectAsmPrinter &printer) const override;
+
+ /// Provide a utility accessor to the dialect namespace. This is used by
+ /// several utilities for casting between dialects.
+ static llvm::StringRef getDialectNamespace() { return "toy"; }
+};
+
+//===----------------------------------------------------------------------===//
+// Toy Operations
+//===----------------------------------------------------------------------===//
+
+/// Include the auto-generated header file containing the declarations of the
+/// toy operations.
+#define GET_OP_CLASSES
+#include "toy/Ops.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Toy Types
+//===----------------------------------------------------------------------===//
+
+/// Create a local enumeration with all of the types that are defined by Toy.
+namespace ToyTypes {
+enum Types {
+ Struct = mlir::Type::FIRST_TOY_TYPE,
+};
+} // end namespace ToyTypes
+
+/// This class defines the Toy struct type. It represents a collection of
+/// element types. All derived types in MLIR must inherit from the CRTP class
+/// 'Type::TypeBase'. It takes as template parameters the concrete type
+/// (StructType), the base class to use (Type), and the storage class
+/// (StructTypeStorage).
+class StructType : public mlir::Type::TypeBase<StructType, mlir::Type,
+ detail::StructTypeStorage> {
+public:
+ /// Inherit some necessary constructors from 'TypeBase'.
+ using Base::Base;
+
+ /// This static method is used to support type inquiry through isa, cast,
+ /// and dyn_cast.
+ static bool kindof(unsigned kind) { return kind == ToyTypes::Struct; }
+
+ /// Create an instance of a `StructType` with the given element types. There
+  /// *must* be at least one element type.
+ static StructType get(llvm::ArrayRef<mlir::Type> elementTypes);
+
+ /// Returns the element types of this struct type.
+ llvm::ArrayRef<mlir::Type> getElementTypes();
+
+  /// Returns the number of element types held by this struct.
+ size_t getNumElementTypes() { return getElementTypes().size(); }
+};
+} // end namespace toy
+} // end namespace mlir
+
+#endif // MLIR_TUTORIAL_TOY_DIALECT_H_
diff --git a/mlir/examples/toy/Ch7/include/toy/Lexer.h b/mlir/examples/toy/Ch7/include/toy/Lexer.h
new file mode 100644
index 00000000000..89dc6cba9ff
--- /dev/null
+++ b/mlir/examples/toy/Ch7/include/toy/Lexer.h
@@ -0,0 +1,244 @@
+//===- Lexer.h - Lexer for the Toy language -------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple Lexer for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_LEXER_H_
+#define MLIR_TUTORIAL_TOY_LEXER_H_
+
+#include "llvm/ADT/StringRef.h"
+
+#include <memory>
+#include <string>
+
+namespace toy {
+
+/// Structure definition a location in a file.
+struct Location {
+ std::shared_ptr<std::string> file; ///< filename.
+ int line; ///< line number.
+ int col; ///< column number.
+};
+
+// List of Token returned by the lexer.
+enum Token : int {
+ tok_semicolon = ';',
+ tok_parenthese_open = '(',
+ tok_parenthese_close = ')',
+ tok_bracket_open = '{',
+ tok_bracket_close = '}',
+ tok_sbracket_open = '[',
+ tok_sbracket_close = ']',
+
+ tok_eof = -1,
+
+ // commands
+ tok_return = -2,
+ tok_var = -3,
+ tok_def = -4,
+ tok_struct = -5,
+
+ // primary
+ tok_identifier = -6,
+ tok_number = -7,
+};
+
+/// The Lexer is an abstract base class providing all the facilities that the
+/// Parser expects. It goes through the stream one token at a time and keeps
+/// track of the location in the file for debugging purpose.
+/// It relies on a subclass to provide a `readNextLine()` method. The subclass
+/// can proceed by reading the next line from the standard input or from a
+/// memory mapped file.
+class Lexer {
+public:
+ /// Create a lexer for the given filename. The filename is kept only for
+ /// debugging purpose (attaching a location to a Token).
+ Lexer(std::string filename)
+ : lastLocation(
+ {std::make_shared<std::string>(std::move(filename)), 0, 0}) {}
+ virtual ~Lexer() = default;
+
+ /// Look at the current token in the stream.
+ Token getCurToken() { return curTok; }
+
+ /// Move to the next token in the stream and return it.
+ Token getNextToken() { return curTok = getTok(); }
+
+ /// Move to the next token in the stream, asserting on the current token
+ /// matching the expectation.
+ void consume(Token tok) {
+ assert(tok == curTok && "consume Token mismatch expectation");
+ getNextToken();
+ }
+
+ /// Return the current identifier (prereq: getCurToken() == tok_identifier)
+ llvm::StringRef getId() {
+ assert(curTok == tok_identifier);
+ return identifierStr;
+ }
+
+ /// Return the current number (prereq: getCurToken() == tok_number)
+ double getValue() {
+ assert(curTok == tok_number);
+ return numVal;
+ }
+
+ /// Return the location for the beginning of the current token.
+ Location getLastLocation() { return lastLocation; }
+
+ // Return the current line in the file.
+ int getLine() { return curLineNum; }
+
+ // Return the current column in the file.
+ int getCol() { return curCol; }
+
+private:
+ /// Delegate to a derived class fetching the next line. Returns an empty
+ /// string to signal end of file (EOF). Lines are expected to always finish
+ /// with "\n"
+ virtual llvm::StringRef readNextLine() = 0;
+
+ /// Return the next character from the stream. This manages the buffer for the
+ /// current line and request the next line buffer to the derived class as
+ /// needed.
+ int getNextChar() {
+ // The current line buffer should not be empty unless it is the end of file.
+ if (curLineBuffer.empty())
+ return EOF;
+ ++curCol;
+ auto nextchar = curLineBuffer.front();
+ curLineBuffer = curLineBuffer.drop_front();
+ if (curLineBuffer.empty())
+ curLineBuffer = readNextLine();
+ if (nextchar == '\n') {
+ ++curLineNum;
+ curCol = 0;
+ }
+ return nextchar;
+ }
+
+ /// Return the next token from standard input.
+ Token getTok() {
+ // Skip any whitespace.
+ while (isspace(lastChar))
+ lastChar = Token(getNextChar());
+
+ // Save the current location before reading the token characters.
+ lastLocation.line = curLineNum;
+ lastLocation.col = curCol;
+
+ // Identifier: [a-zA-Z][a-zA-Z0-9_]*
+ if (isalpha(lastChar)) {
+ identifierStr = (char)lastChar;
+ while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
+ identifierStr += (char)lastChar;
+
+ if (identifierStr == "return")
+ return tok_return;
+ if (identifierStr == "def")
+ return tok_def;
+ if (identifierStr == "struct")
+ return tok_struct;
+ if (identifierStr == "var")
+ return tok_var;
+ return tok_identifier;
+ }
+
+ // Number: [0-9] ([0-9.])*
+ if (isdigit(lastChar)) {
+ std::string numStr;
+ do {
+ numStr += lastChar;
+ lastChar = Token(getNextChar());
+ } while (isdigit(lastChar) || lastChar == '.');
+
+ numVal = strtod(numStr.c_str(), nullptr);
+ return tok_number;
+ }
+
+ if (lastChar == '#') {
+ // Comment until end of line.
+ do {
+ lastChar = Token(getNextChar());
+ } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
+
+ if (lastChar != EOF)
+ return getTok();
+ }
+
+ // Check for end of file. Don't eat the EOF.
+ if (lastChar == EOF)
+ return tok_eof;
+
+ // Otherwise, just return the character as its ascii value.
+ Token thisChar = Token(lastChar);
+ lastChar = Token(getNextChar());
+ return thisChar;
+ }
+
+ /// The last token read from the input.
+ Token curTok = tok_eof;
+
+ /// Location for `curTok`.
+ Location lastLocation;
+
+ /// If the current Token is an identifier, this string contains the value.
+ std::string identifierStr;
+
+ /// If the current Token is a number, this contains the value.
+ double numVal = 0;
+
+ /// The last value returned by getNextChar(). We need to keep it around as we
+ /// always need to read ahead one character to decide when to end a token and
+ /// we can't put it back in the stream after reading from it.
+ Token lastChar = Token(' ');
+
+ /// Keep track of the current line number in the input stream
+ int curLineNum = 0;
+
+ /// Keep track of the current column number in the input stream
+ int curCol = 0;
+
+ /// Buffer supplied by the derived class on calls to `readNextLine()`
+ llvm::StringRef curLineBuffer = "\n";
+};
+
+/// A lexer implementation operating on a buffer in memory.
+class LexerBuffer final : public Lexer {
+public:
+ LexerBuffer(const char *begin, const char *end, std::string filename)
+ : Lexer(std::move(filename)), current(begin), end(end) {}
+
+private:
+ /// Provide one line at a time to the Lexer, return an empty string when
+ /// reaching the end of the buffer.
+ llvm::StringRef readNextLine() override {
+ auto *begin = current;
+ while (current <= end && *current && *current != '\n')
+ ++current;
+ if (current <= end && *current)
+ ++current;
+ llvm::StringRef result{begin, static_cast<size_t>(current - begin)};
+ return result;
+ }
+ const char *current, *end;
+};
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_LEXER_H_
diff --git a/mlir/examples/toy/Ch7/include/toy/MLIRGen.h b/mlir/examples/toy/Ch7/include/toy/MLIRGen.h
new file mode 100644
index 00000000000..287f432c847
--- /dev/null
+++ b/mlir/examples/toy/Ch7/include/toy/MLIRGen.h
@@ -0,0 +1,41 @@
+//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares a simple interface to perform IR generation targeting MLIR
+// from a Module AST for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_
+#define MLIR_TUTORIAL_TOY_MLIRGEN_H_
+
+#include <memory>
+
+namespace mlir {
+class MLIRContext;
+class OwningModuleRef;
+} // namespace mlir
+
+namespace toy {
+class ModuleAST;
+
+/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
+/// or nullptr on failure.
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td
new file mode 100644
index 00000000000..5e932bb0a7e
--- /dev/null
+++ b/mlir/examples/toy/Ch7/include/toy/Ops.td
@@ -0,0 +1,314 @@
+//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines the operations of the Toy dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOY_OPS
+#define TOY_OPS
+
+#ifndef MLIR_CALLINTERFACES
+include "mlir/Analysis/CallInterfaces.td"
+#endif // MLIR_CALLINTERFACES
+
+#ifndef SHAPE_INFERENCE_INTERFACE
+include "toy/ShapeInferenceInterface.td"
+#endif // SHAPE_INFERENCE_INTERFACE
+
+// Provide a definition of the 'toy' dialect in the ODS framework so that we
+// can define our operations.
+def Toy_Dialect : Dialect {
+ let name = "toy";
+ let cppNamespace = "toy";
+}
+
+// Base class for toy dialect operations. This operation inherits from the base
+// `Op` class in OpBase.td, and provides:
+// * The parent dialect of the operation.
+// * The mnemonic for the operation, or the name without the dialect prefix.
+// * A list of traits for the operation.
+class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<Toy_Dialect, mnemonic, traits>;
+
+// Provide a definition for the Toy StructType for use in ODS. This allows for
+// using StructType in a similar way to Tensor or MemRef.
+def Toy_StructType :
+ Type<CPred<"$_self.isa<StructType>()">, "Toy struct type">;
+
+// Provide a definition of the types that are used within the Toy dialect.
+def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>;
+
+//===----------------------------------------------------------------------===//
+// Toy Operations
+//===----------------------------------------------------------------------===//
+
+// We define a toy operation by inheriting from our base 'Toy_Op' class above.
+// Here we provide the mnemonic and a list of traits for the operation. The
+// constant operation is marked as 'NoSideEffect' as it is a pure operation
+// and may be removed if dead.
+def ConstantOp : Toy_Op<"constant",
+ [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+ // Provide a summary and description for this operation. This can be used to
+ // auto-generate documentation of the operations within our dialect.
+ let summary = "constant";
+ let description = [{
+ Constant operation turns a literal into an SSA value. The data is attached
+ to the operation as an attribute. For example:
+
+ ```mlir
+ %0 = "toy.constant"()
+ { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
+ : () -> tensor<2x3xf64>
+ ```
+ }];
+
+ // The constant operation takes an attribute as the only input.
+ let arguments = (ins F64ElementsAttr:$value);
+
+ // The constant operation returns a single value of TensorType.
+ let results = (outs F64Tensor);
+
+ // Add custom build methods for the constant operation. These methods populate
+ // the `state` that MLIR uses to create operations, i.e. these are used when
+ // using `builder.create<ConstantOp>(...)`.
+ let builders = [
+ // Build a constant with a given constant tensor value.
+ OpBuilder<"Builder *builder, OperationState &state, "
+ "DenseElementsAttr value", [{
+ build(builder, state, value.getType(), value);
+ }]>,
+
+ // Build a constant with a given constant floating-point value.
+ OpBuilder<"Builder *builder, OperationState &state, double value">
+ ];
+
+ // Invoke a static verify method to verify this constant operation.
+ let verifier = [{ return ::verify(*this); }];
+
+ // Set the folder bit so that we can implement constant folders.
+ let hasFolder = 1;
+}
+
+def AddOp : Toy_Op<"add",
+ [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+ let summary = "element-wise addition operation";
+ let description = [{
+ The "add" operation performs element-wise addition between two tensors.
+ The shapes of the tensor operands are expected to match.
+ }];
+
+ let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+ let results = (outs F64Tensor);
+
+ // Allow building an AddOp from the two input operands.
+ let builders = [
+ OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
+ ];
+}
+
+def CastOp : Toy_Op<"cast",
+ [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
+ SameOperandsAndResultShape]> {
+ let summary = "shape cast operation";
+ let description = [{
+ The "cast" operation converts a tensor from one type to an equivalent type
+ without changing any data elements. The source and destination types
+ must both be tensor types with the same element type. If both are ranked
+ then the rank should be the same and static dimensions should match. The
+ operation is invalid if converting to a mismatching constant dimension.
+ }];
+
+ let arguments = (ins F64Tensor:$input);
+ let results = (outs F64Tensor:$output);
+
+ // Set the folder bit so that we can fold redundant cast operations.
+ let hasFolder = 1;
+}
+
+def GenericCallOp : Toy_Op<"generic_call",
+ [DeclareOpInterfaceMethods<CallOpInterface>]> {
+ let summary = "generic call operation";
+ let description = [{
+ Generic calls represent calls to a user defined function that needs to
+ be specialized for the shape of its arguments. The callee name is attached
+ as a symbol reference via an attribute. The arguments list must match the
+ arguments expected by the callee. For example:
+
+ ```mlir
+ %4 = "toy.generic_call"(%1, %3) {callee = @my_func}
+ : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
+ ```
+
+ This is only valid if a function named "my_func" exists and takes two
+ arguments.
+ }];
+
+ // The generic call operation takes a symbol reference attribute as the
+ // callee, and inputs for the call.
+ let arguments = (ins SymbolRefAttr:$callee, Variadic<Toy_Type>:$inputs);
+
+ // The generic call operation returns a single value of TensorType or
+ // StructType.
+ let results = (outs Toy_Type);
+
+ // Add custom build methods for the generic call operation.
+ let builders = [
+ OpBuilder<"Builder *builder, OperationState &state, "
+ "StringRef callee, ArrayRef<Value *> arguments">
+ ];
+}
+
+def MulOp : Toy_Op<"mul",
+ [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+ let summary = "element-wise multiplication operation";
+ let description = [{
+ The "mul" operation performs element-wise multiplication between two
+ tensors. The shapes of the tensor operands are expected to match.
+ }];
+
+ let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+ let results = (outs F64Tensor);
+
+ // Allow building a MulOp from the two input operands.
+ let builders = [
+ OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
+ ];
+}
+
+def PrintOp : Toy_Op<"print"> {
+ let summary = "print operation";
+ let description = [{
+ The "print" builtin operation prints a given input tensor, and produces
+ no results.
+ }];
+
+ // The print operation takes an input tensor to print.
+ // We also allow a F64MemRef to enable interop during partial lowering.
+ let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
+}
+
+def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> {
+ let summary = "tensor reshape operation";
+ let description = [{
+ Reshape operation is transforming its input tensor into a new tensor with
+ the same number of elements but different shapes. For example:
+
+ ```mlir
+ %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64>
+ ```
+ }];
+
+ let arguments = (ins F64Tensor:$input);
+ let hasCanonicalizer = 1;
+
+ // We expect that the reshape operation returns a statically shaped tensor.
+ let results = (outs StaticShapeTensorOf<[F64]>);
+}
+
+def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
+ let summary = "return operation";
+ let description = [{
+ The "return" operation represents a return operation within a function.
+ The operation takes an optional operand and produces no results.
+ The operand type must match the signature of the function that contains
+ the operation. For example:
+
+ ```mlir
+ func @foo() -> tensor<2xf64> {
+ ...
+ toy.return %0 : tensor<2xf64>
+ }
+ ```
+ }];
+
+ // The return operation takes an optional input operand to return. This
+ // value must match the return type of the enclosing function.
+ let arguments = (ins Variadic<Toy_Type>:$input);
+
+ // Allow building a ReturnOp with no return operand.
+ let builders = [OpBuilder<
+ "Builder *b, OperationState &state", [{ build(b, state, llvm::None); }]
+ >];
+
+ // Provide extra utility definitions on the c++ operation class definition.
+ let extraClassDeclaration = [{
+ bool hasOperand() { return getNumOperands() != 0; }
+ }];
+
+ // Invoke a static verify method to verify this return operation.
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def StructAccessOp : Toy_Op<"struct_access", [NoSideEffect]> {
+ let summary = "struct access";
+ let description = [{
+ Access the Nth element of a value returning a struct type.
+ }];
+
+ let arguments = (ins Toy_StructType:$input, I64Attr:$index);
+ let results = (outs Toy_Type);
+
+ // Allow building a StructAccessOp with just a struct value and an index.
+ let builders = [
+ OpBuilder<"Builder *b, OperationState &state, Value *input, size_t index">
+ ];
+
+ let verifier = [{ return ::verify(*this); }];
+
+ // Set the folder bit so that we can fold constant accesses.
+ let hasFolder = 1;
+}
+
+def StructConstantOp : Toy_Op<"struct_constant", [NoSideEffect]> {
+ let summary = "struct constant";
+ let description = [{
+ Constant operation turns a literal struct value into an SSA value. The data
+ is attached to the operation as an attribute. The struct constant is encoded
+ as an array of other constant values. For example:
+
+ ```mlir
+ %0 = "toy.struct_constant"() {
+ value = [dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>]
+ } : () -> !toy.struct<tensor<*xf64>>
+ ```
+ }];
+
+ let hasFolder = 1;
+ let arguments = (ins ArrayAttr:$value);
+ let results = (outs Toy_StructType);
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def TransposeOp : Toy_Op<"transpose",
+ [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+ let summary = "transpose operation";
+
+ let arguments = (ins F64Tensor:$input);
+ let results = (outs F64Tensor);
+ let hasCanonicalizer = 1;
+
+ // Allow building a TransposeOp from the input operand.
+ let builders = [
+ OpBuilder<"Builder *b, OperationState &state, Value *input">
+ ];
+
+ // Invoke a static verify method to verify this transpose operation.
+ let verifier = [{ return ::verify(*this); }];
+}
+
+#endif // TOY_OPS
diff --git a/mlir/examples/toy/Ch7/include/toy/Parser.h b/mlir/examples/toy/Ch7/include/toy/Parser.h
new file mode 100644
index 00000000000..e0f36028c59
--- /dev/null
+++ b/mlir/examples/toy/Ch7/include/toy/Parser.h
@@ -0,0 +1,687 @@
+//===- Parser.h - Toy Language Parser -------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the parser for the Toy language. It processes the Token
+// provided by the Lexer and returns an AST.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_PARSER_H
+#define MLIR_TUTORIAL_TOY_PARSER_H
+
+#include "toy/AST.h"
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <map>
+#include <utility>
+#include <vector>
+
+namespace toy {
+
+/// This is a simple recursive parser for the Toy language. It produces a well
+/// formed AST from a stream of Token supplied by the Lexer. No semantic checks
+/// or symbol resolution is performed. For example, variables are referenced by
+/// string and the code could reference an undeclared variable and the parsing
+/// succeeds.
+class Parser {
+public:
+ /// Create a Parser for the supplied lexer.
+ Parser(Lexer &lexer) : lexer(lexer) {}
+
+ /// Parse a full Module. A module is a list of function definitions.
+ std::unique_ptr<ModuleAST> parseModule() {
+ lexer.getNextToken(); // prime the lexer
+
+ // Parse functions and structs one at a time and accumulate in this vector.
+ std::vector<std::unique_ptr<RecordAST>> records;
+ while (true) {
+ std::unique_ptr<RecordAST> record;
+ switch (lexer.getCurToken()) {
+ case tok_eof:
+ break;
+ case tok_def:
+ record = parseDefinition();
+ break;
+ case tok_struct:
+ record = parseStruct();
+ break;
+ default:
+ return parseError<ModuleAST>("'def' or 'struct'",
+ "when parsing top level module records");
+ }
+ if (!record)
+ break;
+ records.push_back(std::move(record));
+ }
+
+ // If we didn't reach EOF, there was an error during parsing
+ if (lexer.getCurToken() != tok_eof)
+ return parseError<ModuleAST>("nothing", "at end of module");
+
+ return std::make_unique<ModuleAST>(std::move(records));
+ }
+
+private:
+ Lexer &lexer;
+
+ /// Parse a return statement.
+ /// return :== return ; | return expr ;
+ std::unique_ptr<ReturnExprAST> parseReturn() {
+ auto loc = lexer.getLastLocation();
+ lexer.consume(tok_return);
+
+ // return takes an optional argument
+ llvm::Optional<std::unique_ptr<ExprAST>> expr;
+ if (lexer.getCurToken() != ';') {
+ expr = parseExpression();
+ if (!expr)
+ return nullptr;
+ }
+ return std::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
+ }
+
+ /// Parse a literal number.
+ /// numberexpr ::= number
+ std::unique_ptr<ExprAST> parseNumberExpr() {
+ auto loc = lexer.getLastLocation();
+ auto result =
+ std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
+ lexer.consume(tok_number);
+ return std::move(result);
+ }
+
+ /// Parse a literal array expression.
+ /// tensorLiteral ::= [ literalList ] | number
+ /// literalList ::= tensorLiteral | tensorLiteral, literalList
+ std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
+ auto loc = lexer.getLastLocation();
+ lexer.consume(Token('['));
+
+ // Hold the list of values at this nesting level.
+ std::vector<std::unique_ptr<ExprAST>> values;
+ // Hold the dimensions for all the nesting inside this level.
+ std::vector<int64_t> dims;
+ do {
+ // We can have either another nested array or a number literal.
+ if (lexer.getCurToken() == '[') {
+ values.push_back(parseTensorLiteralExpr());
+ if (!values.back())
+ return nullptr; // parse error in the nested array.
+ } else {
+ if (lexer.getCurToken() != tok_number)
+ return parseError<ExprAST>("<num> or [", "in literal expression");
+ values.push_back(parseNumberExpr());
+ }
+
+ // End of this list on ']'
+ if (lexer.getCurToken() == ']')
+ break;
+
+ // Elements are separated by a comma.
+ if (lexer.getCurToken() != ',')
+ return parseError<ExprAST>("] or ,", "in literal expression");
+
+ lexer.getNextToken(); // eat ,
+ } while (true);
+ if (values.empty())
+ return parseError<ExprAST>("<something>", "to fill literal expression");
+ lexer.getNextToken(); // eat ]
+
+ /// Fill in the dimensions now. First the current nesting level:
+ dims.push_back(values.size());
+
+ /// If there is any nested array, process all of them and ensure that
+ /// dimensions are uniform.
+ if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
+ return llvm::isa<LiteralExprAST>(expr.get());
+ })) {
+ auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get());
+ if (!firstLiteral)
+ return parseError<ExprAST>("uniform well-nested dimensions",
+ "inside literal expression");
+
+ // Append the nested dimensions to the current level
+ auto firstDims = firstLiteral->getDims();
+ dims.insert(dims.end(), firstDims.begin(), firstDims.end());
+
+ // Sanity check that shape is uniform across all elements of the list.
+ for (auto &expr : values) {
+ auto *exprLiteral = llvm::cast<LiteralExprAST>(expr.get());
+ if (!exprLiteral)
+ return parseError<ExprAST>("uniform well-nested dimensions",
+ "inside literal expression");
+ if (exprLiteral->getDims() != firstDims)
+ return parseError<ExprAST>("uniform well-nested dimensions",
+ "inside literal expression");
+ }
+ }
+ return std::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
+ std::move(dims));
+ }
+
+ /// Parse a literal struct expression.
+ /// structLiteral ::= { (structLiteral | tensorLiteral)+ }
+ std::unique_ptr<ExprAST> parseStructLiteralExpr() {
+ auto loc = lexer.getLastLocation();
+ lexer.consume(Token('{'));
+
+ // Hold the list of values.
+ std::vector<std::unique_ptr<ExprAST>> values;
+ do {
+ // We can have either another nested array or a number literal.
+ if (lexer.getCurToken() == '[') {
+ values.push_back(parseTensorLiteralExpr());
+ if (!values.back())
+ return nullptr;
+ } else if (lexer.getCurToken() == tok_number) {
+ values.push_back(parseNumberExpr());
+ if (!values.back())
+ return nullptr;
+ } else {
+ if (lexer.getCurToken() != '{')
+ return parseError<ExprAST>("{, [, or number",
+ "in struct literal expression");
+ values.push_back(parseStructLiteralExpr());
+ }
+
+ // End of this list on '}'
+ if (lexer.getCurToken() == '}')
+ break;
+
+ // Elements are separated by a comma.
+ if (lexer.getCurToken() != ',')
+ return parseError<ExprAST>("} or ,", "in struct literal expression");
+
+ lexer.getNextToken(); // eat ,
+ } while (true);
+ if (values.empty())
+ return parseError<ExprAST>("<something>",
+ "to fill struct literal expression");
+ lexer.getNextToken(); // eat }
+
+ return std::make_unique<StructLiteralExprAST>(std::move(loc),
+ std::move(values));
+ }
+
+ /// parenexpr ::= '(' expression ')'
+ std::unique_ptr<ExprAST> parseParenExpr() {
+ lexer.getNextToken(); // eat (.
+ auto v = parseExpression();
+ if (!v)
+ return nullptr;
+
+ if (lexer.getCurToken() != ')')
+ return parseError<ExprAST>(")", "to close expression with parentheses");
+ lexer.consume(Token(')'));
+ return v;
+ }
+
+ /// Parse a call expression.
+ std::unique_ptr<ExprAST> parseCallExpr(llvm::StringRef name,
+ const Location &loc) {
+ lexer.consume(Token('('));
+ std::vector<std::unique_ptr<ExprAST>> args;
+ if (lexer.getCurToken() != ')') {
+ while (true) {
+ if (auto arg = parseExpression())
+ args.push_back(std::move(arg));
+ else
+ return nullptr;
+
+ if (lexer.getCurToken() == ')')
+ break;
+
+ if (lexer.getCurToken() != ',')
+ return parseError<ExprAST>(", or )", "in argument list");
+ lexer.getNextToken();
+ }
+ }
+ lexer.consume(Token(')'));
+
+ // It can be a builtin call to print
+ if (name == "print") {
+ if (args.size() != 1)
+ return parseError<ExprAST>("<single arg>", "as argument to print()");
+
+ return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
+ }
+
+ // Call to a user-defined function
+ return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
+ }
+
+ /// identifierexpr
+ /// ::= identifier
+ /// ::= identifier '(' expression ')'
+ std::unique_ptr<ExprAST> parseIdentifierExpr() {
+ std::string name = lexer.getId();
+
+ auto loc = lexer.getLastLocation();
+ lexer.getNextToken(); // eat identifier.
+
+ if (lexer.getCurToken() != '(') // Simple variable ref.
+ return std::make_unique<VariableExprAST>(std::move(loc), name);
+
+ // This is a function call.
+ return parseCallExpr(name, loc);
+ }
+
+ /// primary
+ /// ::= identifierexpr
+ /// ::= numberexpr
+ /// ::= parenexpr
+ /// ::= tensorliteral
+ std::unique_ptr<ExprAST> parsePrimary() {
+ switch (lexer.getCurToken()) {
+ default:
+ llvm::errs() << "unknown token '" << lexer.getCurToken()
+ << "' when expecting an expression\n";
+ return nullptr;
+ case tok_identifier:
+ return parseIdentifierExpr();
+ case tok_number:
+ return parseNumberExpr();
+ case '(':
+ return parseParenExpr();
+ case '[':
+ return parseTensorLiteralExpr();
+ case '{':
+ return parseStructLiteralExpr();
+ case ';':
+ return nullptr;
+ case '}':
+ return nullptr;
+ }
+ }
+
+ /// Recursively parse the right hand side of a binary expression, the ExprPrec
+ /// argument indicates the precedence of the current binary operator.
+ ///
+ /// binoprhs ::= ('+' primary)*
+ std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
+ std::unique_ptr<ExprAST> lhs) {
+ // If this is a binop, find its precedence.
+ while (true) {
+ int tokPrec = getTokPrecedence();
+
+ // If this is a binop that binds at least as tightly as the current binop,
+ // consume it, otherwise we are done.
+ if (tokPrec < exprPrec)
+ return lhs;
+
+ // Okay, we know this is a binop.
+ int binOp = lexer.getCurToken();
+ lexer.consume(Token(binOp));
+ auto loc = lexer.getLastLocation();
+
+ // Parse the primary expression after the binary operator.
+ auto rhs = parsePrimary();
+ if (!rhs)
+ return parseError<ExprAST>("expression", "to complete binary operator");
+
+ // If BinOp binds less tightly with rhs than the operator after rhs, let
+ // the pending operator take rhs as its lhs.
+ int nextPrec = getTokPrecedence();
+ if (tokPrec < nextPrec) {
+ rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
+ if (!rhs)
+ return nullptr;
+ }
+
+ // Merge lhs/RHS.
+ lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
+ std::move(lhs), std::move(rhs));
+ }
+ }
+
+ /// expression::= primary binop rhs
+ std::unique_ptr<ExprAST> parseExpression() {
+ auto lhs = parsePrimary();
+ if (!lhs)
+ return nullptr;
+
+ return parseBinOpRHS(0, std::move(lhs));
+ }
+
+ /// type ::= < shape_list >
+ /// shape_list ::= num | num , shape_list
+ std::unique_ptr<VarType> parseType() {
+ if (lexer.getCurToken() != '<')
+ return parseError<VarType>("<", "to begin type");
+ lexer.getNextToken(); // eat <
+
+ auto type = std::make_unique<VarType>();
+
+ while (lexer.getCurToken() == tok_number) {
+ type->shape.push_back(lexer.getValue());
+ lexer.getNextToken();
+ if (lexer.getCurToken() == ',')
+ lexer.getNextToken();
+ }
+
+ if (lexer.getCurToken() != '>')
+ return parseError<VarType>(">", "to end type");
+ lexer.getNextToken(); // eat >
+ return type;
+ }
+
+ /// Parse either a variable declaration or a call expression.
+ std::unique_ptr<ExprAST> parseDeclarationOrCallExpr() {
+ auto loc = lexer.getLastLocation();
+ std::string id = lexer.getId();
+ lexer.consume(tok_identifier);
+
+ // Check for a call expression.
+ if (lexer.getCurToken() == '(')
+ return parseCallExpr(id, loc);
+
+ // Otherwise, this is a variable declaration.
+ return parseTypedDeclaration(id, /*requiresInitializer=*/true, loc);
+ }
+
+ /// Parse a typed variable declaration.
+ std::unique_ptr<VarDeclExprAST>
+ parseTypedDeclaration(llvm::StringRef typeName, bool requiresInitializer,
+ const Location &loc) {
+ // Parse the variable name.
+ if (lexer.getCurToken() != tok_identifier)
+ return parseError<VarDeclExprAST>("name", "in variable declaration");
+ std::string id = lexer.getId();
+ lexer.getNextToken(); // eat id
+
+ // Parse the initializer.
+ std::unique_ptr<ExprAST> expr;
+ if (requiresInitializer) {
+ if (lexer.getCurToken() != '=')
+ return parseError<VarDeclExprAST>("initializer",
+ "in variable declaration");
+ lexer.consume(Token('='));
+ expr = parseExpression();
+ }
+
+ VarType type;
+ type.name = typeName;
+ return std::make_unique<VarDeclExprAST>(loc, std::move(id), std::move(type),
+ std::move(expr));
+ }
+
+ /// Parse a variable declaration, for either a tensor value or a struct value,
+ /// with an optionally required initializer.
+ /// decl ::= var identifier [ type ] (= expr)?
+ /// decl ::= identifier identifier (= expr)?
+ std::unique_ptr<VarDeclExprAST> parseDeclaration(bool requiresInitializer) {
+ // Check to see if this is a 'var' declaration.
+ if (lexer.getCurToken() == tok_var)
+ return parseVarDeclaration(requiresInitializer);
+
+ // Parse the type name.
+ if (lexer.getCurToken() != tok_identifier)
+ return parseError<VarDeclExprAST>("type name", "in variable declaration");
+ auto loc = lexer.getLastLocation();
+ std::string typeName = lexer.getId();
+ lexer.getNextToken(); // eat id
+
+ // Parse the rest of the declaration.
+ return parseTypedDeclaration(typeName, requiresInitializer, loc);
+ }
+
+ /// Parse a variable declaration, it starts with a `var` keyword followed by
+ /// and identifier and an optional type (shape specification) before the
+ /// optionally required initializer.
+ /// decl ::= var identifier [ type ] (= expr)?
+ std::unique_ptr<VarDeclExprAST>
+ parseVarDeclaration(bool requiresInitializer) {
+ if (lexer.getCurToken() != tok_var)
+ return parseError<VarDeclExprAST>("var", "to begin declaration");
+ auto loc = lexer.getLastLocation();
+ lexer.getNextToken(); // eat var
+
+ if (lexer.getCurToken() != tok_identifier)
+ return parseError<VarDeclExprAST>("identified",
+ "after 'var' declaration");
+ std::string id = lexer.getId();
+ lexer.getNextToken(); // eat id
+
+ std::unique_ptr<VarType> type; // Type is optional, it can be inferred
+ if (lexer.getCurToken() == '<') {
+ type = parseType();
+ if (!type)
+ return nullptr;
+ }
+ if (!type)
+ type = std::make_unique<VarType>();
+
+ std::unique_ptr<ExprAST> expr;
+ if (requiresInitializer) {
+ lexer.consume(Token('='));
+ expr = parseExpression();
+ }
+ return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
+ std::move(*type), std::move(expr));
+ }
+
+ /// Parse a block: a list of expression separated by semicolons and wrapped in
+ /// curly braces.
+ ///
+ /// block ::= { expression_list }
+ /// expression_list ::= block_expr ; expression_list
+ /// block_expr ::= decl | "return" | expr
+ std::unique_ptr<ExprASTList> parseBlock() {
+ if (lexer.getCurToken() != '{')
+ return parseError<ExprASTList>("{", "to begin block");
+ lexer.consume(Token('{'));
+
+ auto exprList = std::make_unique<ExprASTList>();
+
+ // Ignore empty expressions: swallow sequences of semicolons.
+ while (lexer.getCurToken() == ';')
+ lexer.consume(Token(';'));
+
+ while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
+ if (lexer.getCurToken() == tok_identifier) {
+ // Variable declaration or call
+ auto expr = parseDeclarationOrCallExpr();
+ if (!expr)
+ return nullptr;
+ exprList->push_back(std::move(expr));
+ } else if (lexer.getCurToken() == tok_var) {
+ // Variable declaration
+ auto varDecl = parseDeclaration(/*requiresInitializer=*/true);
+ if (!varDecl)
+ return nullptr;
+ exprList->push_back(std::move(varDecl));
+ } else if (lexer.getCurToken() == tok_return) {
+ // Return statement
+ auto ret = parseReturn();
+ if (!ret)
+ return nullptr;
+ exprList->push_back(std::move(ret));
+ } else {
+ // General expression
+ auto expr = parseExpression();
+ if (!expr)
+ return nullptr;
+ exprList->push_back(std::move(expr));
+ }
+ // Ensure that elements are separated by a semicolon.
+ if (lexer.getCurToken() != ';')
+ return parseError<ExprASTList>(";", "after expression");
+
+ // Ignore empty expressions: swallow sequences of semicolons.
+ while (lexer.getCurToken() == ';')
+ lexer.consume(Token(';'));
+ }
+
+ if (lexer.getCurToken() != '}')
+ return parseError<ExprASTList>("}", "to close block");
+
+ lexer.consume(Token('}'));
+ return exprList;
+ }
+
+ /// prototype ::= def id '(' decl_list ')'
+ /// decl_list ::= identifier | identifier, decl_list
+ std::unique_ptr<PrototypeAST> parsePrototype() {
+ auto loc = lexer.getLastLocation();
+ lexer.consume(tok_def);
+ if (lexer.getCurToken() != tok_identifier)
+ return parseError<PrototypeAST>("function name", "in prototype");
+
+ std::string fnName = lexer.getId();
+ lexer.consume(tok_identifier);
+
+ if (lexer.getCurToken() != '(')
+ return parseError<PrototypeAST>("(", "in prototype");
+ lexer.consume(Token('('));
+
+ std::vector<std::unique_ptr<VarDeclExprAST>> args;
+ if (lexer.getCurToken() != ')') {
+ do {
+ VarType type;
+ std::string name;
+
+ // Parse either the name of the variable, or its type.
+ std::string nameOrType = lexer.getId();
+ auto loc = lexer.getLastLocation();
+ lexer.consume(tok_identifier);
+
+ // If the next token is an identifier, we just parsed the type.
+ if (lexer.getCurToken() == tok_identifier) {
+ type.name = std::move(nameOrType);
+
+ // Parse the name.
+ name = lexer.getId();
+ lexer.consume(tok_identifier);
+ } else {
+ // Otherwise, we just parsed the name.
+ name = std::move(nameOrType);
+ }
+
+ args.push_back(
+ std::make_unique<VarDeclExprAST>(std::move(loc), name, type));
+ if (lexer.getCurToken() != ',')
+ break;
+ lexer.consume(Token(','));
+ if (lexer.getCurToken() != tok_identifier)
+ return parseError<PrototypeAST>(
+ "identifier", "after ',' in function parameter list");
+ } while (true);
+ }
+ if (lexer.getCurToken() != ')')
+ return parseError<PrototypeAST>("}", "to end function prototype");
+
+ // success.
+ lexer.consume(Token(')'));
+ return std::make_unique<PrototypeAST>(std::move(loc), fnName,
+ std::move(args));
+ }
+
+ /// Parse a function definition, we expect a prototype initiated with the
+ /// `def` keyword, followed by a block containing a list of expressions.
+ ///
+ /// definition ::= prototype block
+ std::unique_ptr<FunctionAST> parseDefinition() {
+ auto proto = parsePrototype();
+ if (!proto)
+ return nullptr;
+
+ if (auto block = parseBlock())
+ return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
+ return nullptr;
+ }
+
+ /// Parse a struct definition, we expect a struct initiated with the
+ /// `struct` keyword, followed by a block containing a list of variable
+ /// declarations.
+ ///
+ /// definition ::= `struct` identifier `{` decl+ `}`
+ std::unique_ptr<StructAST> parseStruct() {
+ auto loc = lexer.getLastLocation();
+ lexer.consume(tok_struct);
+ if (lexer.getCurToken() != tok_identifier)
+ return parseError<StructAST>("name", "in struct definition");
+ std::string name = lexer.getId();
+ lexer.consume(tok_identifier);
+
+ // Parse: '{'
+ if (lexer.getCurToken() != '{')
+ return parseError<StructAST>("{", "in struct definition");
+ lexer.consume(Token('{'));
+
+ // Parse: decl+
+ std::vector<std::unique_ptr<VarDeclExprAST>> decls;
+ do {
+ auto decl = parseDeclaration(/*requiresInitializer=*/false);
+ if (!decl)
+ return nullptr;
+ decls.push_back(std::move(decl));
+
+ if (lexer.getCurToken() != ';')
+ return parseError<StructAST>(";",
+ "after variable in struct definition");
+ lexer.consume(Token(';'));
+ } while (lexer.getCurToken() != '}');
+
+ // Parse: '}'
+ lexer.consume(Token('}'));
+ return std::make_unique<StructAST>(loc, name, std::move(decls));
+ }
+
+ /// Get the precedence of the pending binary operator token.
+ int getTokPrecedence() {
+ if (!isascii(lexer.getCurToken()))
+ return -1;
+
+ // 1 is lowest precedence.
+ switch (static_cast<char>(lexer.getCurToken())) {
+ case '-':
+ return 20;
+ case '+':
+ return 20;
+ case '*':
+ return 40;
+ case '.':
+ return 60;
+ default:
+ return -1;
+ }
+ }
+
+ /// Helper function to signal errors while parsing, it takes an argument
+ /// indicating the expected token and another argument giving more context.
+ /// Location is retrieved from the lexer to enrich the error message.
+ template <typename R, typename T, typename U = const char *>
+ std::unique_ptr<R> parseError(T &&expected, U &&context = "") {
+ auto curToken = lexer.getCurToken();
+ llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
+ << lexer.getLastLocation().col << "): expected '" << expected
+ << "' " << context << " but has Token " << curToken;
+ if (isprint(curToken))
+ llvm::errs() << " '" << (char)curToken << "'";
+ llvm::errs() << "\n";
+ return nullptr;
+ }
+};
+
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_PARSER_H
diff --git a/mlir/examples/toy/Ch7/include/toy/Passes.h b/mlir/examples/toy/Ch7/include/toy/Passes.h
new file mode 100644
index 00000000000..00fe4ffe49b
--- /dev/null
+++ b/mlir/examples/toy/Ch7/include/toy/Passes.h
@@ -0,0 +1,45 @@
+//===- Passes.h - Toy Passes Definition -----------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file exposes the entry points to create compiler passes for Toy.
+//
+//===----------------------------------------------------------------------===//
+
#ifndef MLIR_TUTORIAL_TOY_PASSES_H
#define MLIR_TUTORIAL_TOY_PASSES_H

#include <memory>

namespace mlir {
class Pass;

namespace toy {
/// Create a pass that erases functions left dead after inlining (i.e. every
/// function other than `main`).
std::unique_ptr<Pass> createDeadFunctionEliminationPass();

/// Create a pass that infers and propagates the shapes of tensor values.
std::unique_ptr<Pass> createShapeInferencePass();

/// Create a pass for lowering to operations in the `Affine` and `Std` dialects,
/// for a subset of the Toy IR (e.g. matmul).
std::unique_ptr<Pass> createLowerToAffinePass();

/// Create a pass for lowering the remaining `Toy` operations, as well as
/// `Affine` and `Std`, to the LLVM dialect for codegen.
std::unique_ptr<Pass> createLowerToLLVMPass();

} // end namespace toy
} // end namespace mlir

#endif // MLIR_TUTORIAL_TOY_PASSES_H
diff --git a/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h
new file mode 100644
index 00000000000..fc36b5b100d
--- /dev/null
+++ b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h
@@ -0,0 +1,37 @@
+//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains the declarations of the shape inference interfaces defined
+// in ShapeInferenceInterface.td.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
+#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace toy {
+
+/// Include the auto-generated declarations.
+#include "toy/ShapeInferenceOpInterfaces.h.inc"
+
+} // end namespace toy
+} // end namespace mlir
+
+#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
diff --git a/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td
new file mode 100644
index 00000000000..dc345907ac7
--- /dev/null
+++ b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td
@@ -0,0 +1,41 @@
+//===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines the operations of the Shape Inference Op Interface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SHAPE_INFERENCE_INTERFACE
+#define SHAPE_INFERENCE_INTERFACE
+
+#ifndef OP_BASE
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
// Op interface implemented by operations whose result shapes can be computed
// from their (already resolved) operand shapes during shape inference.
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  let description = [{
    Interface to access a registered method to infer the return types for an
    operation that can be used during type inference.
  }];

  let methods = [
    InterfaceMethod<"Infer and set the output shape for the current operation.",
                    "void", "inferShapes">
  ];
}
+
+#endif // SHAPE_INFERENCE_INTERFACE
diff --git a/mlir/examples/toy/Ch7/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch7/mlir/DeadFunctionEliminationPass.cpp
new file mode 100644
index 00000000000..b58adb5d52f
--- /dev/null
+++ b/mlir/examples/toy/Ch7/mlir/DeadFunctionEliminationPass.cpp
@@ -0,0 +1,68 @@
+//===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a Module level pass performing dead function
+// elimination. This is required as a post-processing step after function
+// inlining.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "toy/Passes.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+
+namespace {
+/// This is a simple function DCE pass that deletes all non-main functions after
+/// inlining.
+/// TODO(riverriddle) This is only necessary because MLIR currently does not
+/// have generic DCE support for functions.
+class DeadFunctionEliminationPass
+ : public mlir::ModulePass<DeadFunctionEliminationPass> {
+public:
+ void runOnModule() override {
+ mlir::ModuleOp module = getModule();
+ mlir::SymbolTable moduleSymTable(module);
+
+ // Eliminate non-main functions.
+ auto mainFn = moduleSymTable.lookup<mlir::FuncOp>("main");
+ for (mlir::FuncOp func :
+ llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
+ if (func != mainFn)
+ func.erase();
+ }
+ }
+};
+} // end anonymous namespace
+
+/// Create a pass that eliminates inlined functions in toy.
+std::unique_ptr<mlir::Pass> mlir::toy::createDeadFunctionEliminationPass() {
+ return std::make_unique<DeadFunctionEliminationPass>();
+}
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
new file mode 100644
index 00000000000..2beaa870a89
--- /dev/null
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -0,0 +1,483 @@
+//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the dialect for the Toy IR: custom type parsing and
+// operation verification.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/InliningUtils.h"
+
+using namespace mlir;
+using namespace mlir::toy;
+
+//===----------------------------------------------------------------------===//
+// ToyInlinerInterface
+//===----------------------------------------------------------------------===//
+
/// This class defines the interface for handling inlining with Toy
/// operations. It is registered on the dialect and consulted by the generic
/// MLIR inliner.
struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// All operations within toy can be inlined.
  bool isLegalToInline(Operation *, Region *,
                       BlockAndValueMapping &) const final {
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  /// Handle the given inlined terminator(toy.return) by replacing it with a new
  /// operation as necessary.
  void handleTerminator(Operation *op,
                        ArrayRef<Value *> valuesToRepl) const final {
    // Only "toy.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands: the i-th value
    // being replaced takes the i-th return operand.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
  }

  /// Attempts to materialize a conversion for a type mismatch between a call
  /// from this dialect, and a callable region. This method should generate an
  /// operation that takes 'input' as the only operand, and produces a single
  /// result of 'resultType'. If a conversion can not be generated, nullptr
  /// should be returned.
  Operation *materializeCallConversion(OpBuilder &builder, Value *input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    // toy.cast is the dialect's shape-conversion operation, so it is used to
    // bridge the type mismatch.
    return builder.create<CastOp>(conversionLoc, resultType, input);
  }
};
+
+//===----------------------------------------------------------------------===//
+// ToyDialect
+//===----------------------------------------------------------------------===//
+
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
  // Register all operations generated from Ops.td via the ODS op list.
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
  // Register the inliner hooks and the custom struct type.
  addInterfaces<ToyInlinerInterface>();
  addTypes<StructType>();
}
+
+mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
+ mlir::Attribute value,
+ mlir::Type type,
+ mlir::Location loc) {
+ if (type.isa<StructType>())
+ return builder.create<StructConstantOp>(loc, type,
+ value.cast<mlir::ArrayAttr>());
+ return builder.create<ConstantOp>(loc, type,
+ value.cast<mlir::DenseElementsAttr>());
+}
+
+//===----------------------------------------------------------------------===//
+// Toy Operations
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// ConstantOp
+
+/// Build a constant operation.
+/// The builder is passed as an argument, so is the state that this method is
+/// expected to fill in order to build the operation.
+void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
+ double value) {
+ auto dataType = RankedTensorType::get({}, builder->getF64Type());
+ auto dataAttribute = DenseElementsAttr::get(dataType, value);
+ ConstantOp::build(builder, state, dataType, dataAttribute);
+}
+
/// Verify that the given attribute value is valid for the given type.
/// Shared by ConstantOp (tensor constants) and StructConstantOp (struct
/// constants); recurses into struct elements as needed.
static mlir::LogicalResult verifyConstantForType(mlir::Type type,
                                                 mlir::Attribute opaqueValue,
                                                 mlir::Operation *op) {
  if (type.isa<mlir::TensorType>()) {
    // Check that the value is an elements attribute.
    auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>();
    if (!attrValue)
      return op->emitError("constant of TensorType must be initialized by "
                           "a DenseFPElementsAttr, got ")
             << opaqueValue;

    // If the return type of the constant is not an unranked tensor, the shape
    // must match the shape of the attribute holding the data.
    auto resultType = type.dyn_cast<mlir::RankedTensorType>();
    if (!resultType)
      return success();

    // Check that the rank of the attribute type matches the rank of the
    // constant result type.
    auto attrType = attrValue.getType().cast<mlir::TensorType>();
    if (attrType.getRank() != resultType.getRank()) {
      return op->emitOpError("return type must match the one of the attached "
                             "value attribute: ")
             << attrType.getRank() << " != " << resultType.getRank();
    }

    // Check that each of the dimensions match between the two types.
    for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
      if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
        return op->emitOpError(
                   "return type shape mismatches its attribute at dimension ")
               << dim << ": " << attrType.getShape()[dim]
               << " != " << resultType.getShape()[dim];
      }
    }
    return mlir::success();
  }
  // Otherwise this is a struct constant: the type must be a StructType and
  // the initializer must mirror its element structure.
  auto resultType = type.cast<StructType>();
  llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();

  // Verify that the initializer is an Array of the same arity.
  auto attrValue = opaqueValue.dyn_cast<ArrayAttr>();
  if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
    return op->emitError("constant of StructType must be initialized by an "
                         "ArrayAttr with the same number of elements, got ")
           << opaqueValue;

  // Check that each of the elements are valid (recursing for nested structs).
  llvm::ArrayRef<mlir::Attribute> attrElementValues = attrValue.getValue();
  for (const auto &it : llvm::zip(resultElementTypes, attrElementValues))
    if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op)))
      return mlir::failure();
  return mlir::success();
}
+
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
/// in the op definition.
static mlir::LogicalResult verify(ConstantOp op) {
  return verifyConstantForType(op.getResult()->getType(), op.value(), op);
}

/// Verifier for the struct constant operation; shares the type/attribute
/// checking logic with the tensor constant above.
static mlir::LogicalResult verify(StructConstantOp op) {
  return verifyConstantForType(op.getResult()->getType(), op.value(), op);
}

/// Infer the output shape of the ConstantOp, this is required by the shape
/// inference interface. The result simply adopts the type of the value
/// attribute.
void ConstantOp::inferShapes() { getResult()->setType(value().getType()); }
+
+//===----------------------------------------------------------------------===//
+// AddOp
+
+void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
+ mlir::Value *lhs, mlir::Value *rhs) {
+ state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
+ state.addOperands({lhs, rhs});
+}
+
+/// Infer the output shape of the AddOp, this is required by the shape inference
+/// interface.
+void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
+
+//===----------------------------------------------------------------------===//
+// CastOp
+
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
// A cast forwards its operand unchanged, so the result adopts the operand
// type directly.
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
+
+//===----------------------------------------------------------------------===//
+// GenericCallOp
+
/// Build a generic call operation from a callee name and argument list.
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
                          StringRef callee, ArrayRef<mlir::Value *> arguments) {
  // Generic call always returns an unranked Tensor initially.
  state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
  state.addOperands(arguments);
  // The callee is stored as a symbol reference attribute, resolved later.
  state.addAttribute("callee", builder->getSymbolRefAttr(callee));
}

/// Return the callee of the generic call operation, this is required by the
/// call interface.
CallInterfaceCallable GenericCallOp::getCallableForCallee() {
  return getAttrOfType<SymbolRefAttr>("callee");
}

/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
+
+//===----------------------------------------------------------------------===//
+// MulOp
+
+void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
+ mlir::Value *lhs, mlir::Value *rhs) {
+ state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
+ state.addOperands({lhs, rhs});
+}
+
+/// Infer the output shape of the MulOp, this is required by the shape inference
+/// interface.
+void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+
/// Verifier for the return operation: checks the operand count and type
/// against the enclosing function's signature.
static mlir::LogicalResult verify(ReturnOp op) {
  // We know that the parent operation is a function, because of the 'HasParent'
  // trait attached to the operation definition.
  auto function = cast<FuncOp>(op.getParentOp());

  /// ReturnOps can only have a single optional operand.
  if (op.getNumOperands() > 1)
    return op.emitOpError() << "expects at most 1 return operand";

  // The operand number and types must match the function signature.
  const auto &results = function.getType().getResults();
  if (op.getNumOperands() != results.size())
    return op.emitOpError()
           << "does not return the same number of values ("
           << op.getNumOperands() << ") as the enclosing function ("
           << results.size() << ")";

  // If the operation does not have an input, we are done.
  if (!op.hasOperand())
    return mlir::success();

  auto inputType = *op.operand_type_begin();
  auto resultType = results.front();

  // Check that the result type of the function matches the operand type.
  // Unranked tensors are tolerated on either side since shape inference may
  // not have resolved them yet.
  if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
      resultType.isa<mlir::UnrankedTensorType>())
    return mlir::success();

  return op.emitError() << "type of return operand ("
                        << *op.operand_type_begin()
                        << ") doesn't match function result type ("
                        << results.front() << ")";
}
+
+//===----------------------------------------------------------------------===//
+// StructAccessOp
+
+void StructAccessOp::build(mlir::Builder *b, mlir::OperationState &state,
+ mlir::Value *input, size_t index) {
+ // Extract the result type from the input type.
+ StructType structTy = input->getType().cast<StructType>();
+ assert(index < structTy.getNumElementTypes());
+ mlir::Type resultType = structTy.getElementTypes()[index];
+
+ // Call into the auto-generated build method.
+ build(b, state, resultType, input, b->getI64IntegerAttr(index));
+}
+
/// Verifier for struct element access: the index must be in bounds and the
/// result type must match the accessed element's type.
static mlir::LogicalResult verify(StructAccessOp op) {
  StructType structTy = op.input()->getType().cast<StructType>();
  size_t index = op.index().getZExtValue();
  if (index >= structTy.getNumElementTypes())
    return op.emitOpError()
           << "index should be within the range of the input struct type";
  mlir::Type resultType = op.getResult()->getType();
  if (resultType != structTy.getElementTypes()[index])
    return op.emitOpError() << "must have the same result type as the struct "
                               "element referred to by the index";
  return mlir::success();
}
+
+//===----------------------------------------------------------------------===//
+// TransposeOp
+
+void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
+ mlir::Value *value) {
+ state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
+ state.addOperands(value);
+}
+
/// Infer the output shape of the TransposeOp, this is required by the shape
/// inference interface. The result shape is the operand shape reversed.
void TransposeOp::inferShapes() {
  // NOTE(review): the cast assumes the operand is ranked by the time shape
  // inference reaches this op — confirm the pass only visits ops whose
  // operands have resolved shapes.
  auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
  SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
  getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
+
+static mlir::LogicalResult verify(TransposeOp op) {
+ auto inputType = op.getOperand()->getType().dyn_cast<RankedTensorType>();
+ auto resultType = op.getType().dyn_cast<RankedTensorType>();
+ if (!inputType || !resultType)
+ return mlir::success();
+
+ auto inputShape = inputType.getShape();
+ if (!std::equal(inputShape.begin(), inputShape.end(),
+ resultType.getShape().rbegin())) {
+ return op.emitError()
+ << "expected result shape to be a transpose of the input";
+ }
+ return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// Toy Types
+//===----------------------------------------------------------------------===//
+
namespace mlir {
namespace toy {
namespace detail {
/// This class represents the internal storage of the Toy `StructType`.
struct StructTypeStorage : public mlir::TypeStorage {
  /// The `KeyTy` is a required type that provides an interface for the storage
  /// instance. This type will be used when uniquing an instance of the type
  /// storage. For our struct type, we will unique each instance structurally on
  /// the elements that it contains.
  using KeyTy = llvm::ArrayRef<mlir::Type>;

  /// A constructor for the type storage instance.
  StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
      : elementTypes(elementTypes) {}

  /// Define the comparison function for the key type with the current storage
  /// instance. This is used when constructing a new instance to ensure that we
  /// haven't already uniqued an instance of the given key.
  bool operator==(const KeyTy &key) const { return key == elementTypes; }

  /// Define a hash function for the key type. This is used when uniquing
  /// instances of the storage, see the `StructType::get` method.
  /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type
  /// have hash functions available, so we could just omit this entirely.
  static llvm::hash_code hashKey(const KeyTy &key) {
    return llvm::hash_value(key);
  }

  /// Define a construction function for the key type from a set of parameters.
  /// These parameters will be provided when constructing the storage instance
  /// itself.
  /// Note: This method isn't necessary because KeyTy can be directly
  /// constructed with the given parameters.
  static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
    return KeyTy(elementTypes);
  }

  /// Define a construction method for creating a new instance of this storage.
  /// This method takes an instance of a storage allocator, and an instance of a
  /// `KeyTy`. The given allocator must be used for *all* necessary dynamic
  /// allocations used to create the type storage and its internals.
  static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
                                      const KeyTy &key) {
    // Copy the elements from the provided `KeyTy` into the allocator, so the
    // storage never references caller-owned memory.
    llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);

    // Allocate the storage instance and construct it.
    return new (allocator.allocate<StructTypeStorage>())
        StructTypeStorage(elementTypes);
  }

  /// The following field contains the element types of the struct.
  llvm::ArrayRef<mlir::Type> elementTypes;
};
} // end namespace detail
} // end namespace toy
} // end namespace mlir
+
/// Create an instance of a `StructType` with the given element types. There
/// *must* be at least one element type.
StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
  assert(!elementTypes.empty() && "expected at least 1 element type");

  // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
  // of this type. The first two parameters are the context to unique in and the
  // kind of the type. The parameters after the type kind are forwarded to the
  // storage instance.
  mlir::MLIRContext *ctx = elementTypes.front().getContext();
  return Base::get(ctx, ToyTypes::Struct, elementTypes);
}

/// Returns the element types of this struct type. The returned ArrayRef is
/// backed by the uniqued storage and remains valid for the context lifetime.
llvm::ArrayRef<mlir::Type> StructType::getElementTypes() {
  // 'getImpl' returns a pointer to the internal storage instance.
  return getImpl()->elementTypes;
}
+
+/// Parse an instance of a type registered to the toy dialect.
+mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
+ // Parse a struct type in the following form:
+ // struct-type ::= `struct` `<` type (`,` type)* `>`
+
+ // NOTE: All MLIR parser function return a ParseResult. This is a
+ // specialization of LogicalResult that auto-converts to a `true` boolean
+ // value on failure to allow for chaining, but may be used with explicit
+ // `mlir::failed/mlir::succeeded` as desired.
+
+ // Parse: `struct` `<`
+ if (parser.parseKeyword("struct") || parser.parseLess())
+ return Type();
+
+ // Parse the element types of the struct.
+ SmallVector<mlir::Type, 1> elementTypes;
+ do {
+ // Parse the current element type.
+ llvm::SMLoc typeLoc = parser.getCurrentLocation();
+ mlir::Type elementType;
+ if (parser.parseType(elementType))
+ return nullptr;
+
+ // Check that the type is either a TensorType or another StructType.
+ if (!elementType.isa<mlir::TensorType>() &&
+ !elementType.isa<StructType>()) {
+ parser.emitError(typeLoc, "element type for a struct must either "
+ "be a TensorType or a StructType, got: ")
+ << elementType;
+ return Type();
+ }
+ elementTypes.push_back(elementType);
+
+ // Parse the optional: `,`
+ } while (succeeded(parser.parseOptionalComma()));
+
+ // Parse: `>`
+ if (parser.parseGreater())
+ return Type();
+ return StructType::get(elementTypes);
+}
+
+/// Print an instance of a type registered to the toy dialect.
+void ToyDialect::printType(mlir::Type type,
+ mlir::DialectAsmPrinter &printer) const {
+ // Currently the only toy type is a struct type.
+ StructType structType = type.cast<StructType>();
+
+ // Print the struct type according to the parser format.
+ printer << "struct<";
+ mlir::interleaveComma(structType.getElementTypes(), printer);
+ printer << '>';
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "toy/Ops.cpp.inc"
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
new file mode 100644
index 00000000000..a8e38aef7ad
--- /dev/null
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -0,0 +1,318 @@
+//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a partial lowering of Toy operations to a combination of
+// affine loops and standard operations. This lowering expects that all calls
+// have been inlined, and all shapes have been resolved.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+#include "toy/Passes.h"
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns
+//===----------------------------------------------------------------------===//
+
+/// Convert the given TensorType into the corresponding MemRefType.
+static MemRefType convertTensorToMemRef(TensorType type) {
+  // This lowering only supports ranked tensors.
+  assert(type.hasRank() && "expected only ranked shapes");
+  auto shape = type.getShape();
+  auto elementType = type.getElementType();
+  return MemRefType::get(shape, elementType);
+}
+
+/// Insert an allocation and deallocation for the given MemRefType.
+/// The alloc is hoisted to the front of the enclosing block and the matching
+/// dealloc is moved just before the last operation of that block, giving the
+/// buffer a block-wide lifetime.
+static Value *insertAllocAndDealloc(MemRefType type, Location loc,
+                                    PatternRewriter &rewriter) {
+  auto alloc = rewriter.create<AllocOp>(loc, type);
+
+  // Make sure to allocate at the beginning of the block.
+  auto *parentBlock = alloc.getOperation()->getBlock();
+  alloc.getOperation()->moveBefore(&parentBlock->front());
+
+  // Make sure to deallocate this alloc at the end of the block. This is fine
+  // as toy functions have no control flow.
+  auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
+  dealloc.getOperation()->moveBefore(&parentBlock->back());
+  return alloc;
+}
+
+/// This defines the function type used to process an iteration of a lowered
+/// loop. It takes as input a rewriter, an array of memRefOperands corresponding
+/// to the operands of the input operation, and the set of loop induction
+/// variables for the iteration. It returns a value to store at the current
+/// index of the iteration.
+using LoopIterationFn = function_ref<Value *(PatternRewriter &rewriter,
+                                             ArrayRef<Value *> memRefOperands,
+                                             ArrayRef<Value *> loopIvs)>;
+
+/// Lower a tensor-result operation to a perfect nest of affine loops that
+/// computes one element per iteration via `processIteration` and stores it
+/// into a newly allocated memref, which then replaces `op`.
+static void lowerOpToLoops(Operation *op, ArrayRef<Value *> operands,
+                           PatternRewriter &rewriter,
+                           LoopIterationFn processIteration) {
+  auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+  auto loc = op->getLoc();
+
+  // Insert an allocation and deallocation for the result of this operation.
+  auto memRefType = convertTensorToMemRef(tensorType);
+  auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
+
+  // Create an empty affine loop for each of the dimensions within the shape.
+  SmallVector<Value *, 4> loopIvs;
+  for (auto dim : tensorType.getShape()) {
+    auto loop = rewriter.create<AffineForOp>(loc, /*lb=*/0, dim, /*step=*/1);
+    // Drop the automatically created body; it is rebuilt explicitly below.
+    loop.getBody()->clear();
+    loopIvs.push_back(loop.getInductionVar());
+
+    // Terminate the loop body and update the rewriter insertion point to the
+    // beginning of the loop, so the next nested loop (or the final store) is
+    // emitted before the terminator.
+    rewriter.setInsertionPointToStart(loop.getBody());
+    rewriter.create<AffineTerminatorOp>(loc);
+    rewriter.setInsertionPointToStart(loop.getBody());
+  }
+
+  // Generate a call to the processing function with the rewriter, the memref
+  // operands, and the loop induction variables. This function will return the
+  // value to store at the current index.
+  Value *valueToStore = processIteration(rewriter, operands, loopIvs);
+  rewriter.create<AffineStoreOp>(loc, valueToStore, alloc,
+                                 llvm::makeArrayRef(loopIvs));
+
+  // Replace this operation with the generated alloc.
+  rewriter.replaceOp(op, alloc);
+}
+
+namespace {
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Binary operations
+//===----------------------------------------------------------------------===//
+
+/// Generic lowering for element-wise binary operations: `BinaryOp` is
+/// rewritten as an affine loop nest that loads both operands at the current
+/// induction variables, applies `LoweredBinaryOp`, and stores the result.
+template <typename BinaryOp, typename LoweredBinaryOp>
+struct BinaryOpLowering : public ConversionPattern {
+  BinaryOpLowering(MLIRContext *ctx)
+      : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto loc = op->getLoc();
+    lowerOpToLoops(
+        op, operands, rewriter,
+        [loc](PatternRewriter &rewriter, ArrayRef<Value *> memRefOperands,
+              ArrayRef<Value *> loopIvs) {
+          // Generate an adaptor for the remapped operands of the BinaryOp. This
+          // allows for using the nice named accessors that are generated by the
+          // ODS.
+          typename BinaryOp::OperandAdaptor binaryAdaptor(memRefOperands);
+
+          // Generate loads for the element of 'lhs' and 'rhs' at the inner
+          // loop.
+          auto loadedLhs =
+              rewriter.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
+          auto loadedRhs =
+              rewriter.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);
+
+          // Create the binary operation performed on the loaded values.
+          return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
+        });
+    return matchSuccess();
+  }
+};
+using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
+using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
+
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Constant operations
+//===----------------------------------------------------------------------===//
+
+/// Lowers `toy.constant` by allocating a memref and emitting one affine store
+/// per constant element.
+struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
+  using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(toy::ConstantOp op,
+                                     PatternRewriter &rewriter) const final {
+    DenseElementsAttr constantValue = op.value();
+    Location loc = op.getLoc();
+
+    // When lowering the constant operation, we allocate and assign the constant
+    // values to a corresponding memref allocation.
+    auto tensorType = op.getType().cast<TensorType>();
+    auto memRefType = convertTensorToMemRef(tensorType);
+    auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
+
+    // We will be generating constant indices up-to the largest dimension.
+    // Create these constants up-front to avoid large amounts of redundant
+    // operations.
+    // NOTE(review): max_element dereferences assuming a non-empty shape; a
+    // rank-0 constant would be undefined behavior here -- confirm toy
+    // constants always have at least one dimension.
+    auto valueShape = memRefType.getShape();
+    SmallVector<Value *, 8> constantIndices;
+    for (auto i : llvm::seq<int64_t>(
+             0, *std::max_element(valueShape.begin(), valueShape.end())))
+      constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
+
+    // The constant operation represents a multi-dimensional constant, so we
+    // will need to generate a store for each of the elements. The following
+    // functor recursively walks the dimensions of the constant shape,
+    // generating a store when the recursion hits the base case.
+    SmallVector<Value *, 2> indices;
+    auto valueIt = constantValue.getValues<FloatAttr>().begin();
+    std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
+      // The last dimension is the base case of the recursion, at this point
+      // we store the element at the given index.
+      if (dimension == valueShape.size()) {
+        rewriter.create<AffineStoreOp>(
+            loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
+            llvm::makeArrayRef(indices));
+        return;
+      }
+
+      // Otherwise, iterate over the current dimension and add the indices to
+      // the list.
+      for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
+        indices.push_back(constantIndices[i]);
+        storeElements(dimension + 1);
+        indices.pop_back();
+      }
+    };
+
+    // Start the element storing recursion from the first dimension.
+    storeElements(/*dimension=*/0);
+
+    // Replace this operation with the generated alloc.
+    rewriter.replaceOp(op, alloc);
+    return matchSuccess();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Return operations
+//===----------------------------------------------------------------------===//
+
+/// Lowers operand-less `toy.return` directly to `std.return`.
+struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
+  using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(toy::ReturnOp op,
+                                     PatternRewriter &rewriter) const final {
+    // During this lowering, we expect that all function calls have been
+    // inlined, so only value-less returns remain; a return that still
+    // carries an operand is left unmatched.
+    if (!op.hasOperand()) {
+      // We lower "toy.return" directly to "std.return".
+      rewriter.replaceOpWithNewOp<ReturnOp>(op);
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Transpose operations
+//===----------------------------------------------------------------------===//
+
+/// Lowers `toy.transpose` to an affine loop nest that loads from the input
+/// buffer with the induction variables reversed.
+struct TransposeOpLowering : public ConversionPattern {
+  TransposeOpLowering(MLIRContext *ctx)
+      : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto loc = op->getLoc();
+    lowerOpToLoops(
+        op, operands, rewriter,
+        [loc](PatternRewriter &rewriter, ArrayRef<Value *> memRefOperands,
+              ArrayRef<Value *> loopIvs) {
+          // The ODS-generated adaptor provides named accessors for the
+          // remapped operands of the TransposeOp.
+          toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands);
+          Value *input = transposeAdaptor.input();
+
+          // A transposed element is a load from the input at the reversed
+          // indices.
+          SmallVector<Value *, 2> reverseIvs(llvm::reverse(loopIvs));
+          return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
+        });
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace.
+
+//===----------------------------------------------------------------------===//
+// ToyToAffineLoweringPass
+//===----------------------------------------------------------------------===//
+
+/// This is a partial lowering to affine loops of the toy operations that are
+/// computationally intensive (like matmul for example...) while keeping the
+/// rest of the code in the Toy dialect.
+namespace {
+struct ToyToAffineLoweringPass : public FunctionPass<ToyToAffineLoweringPass> {
+  // Entry point invoked once per function; defined below.
+  void runOnFunction() final;
+};
+} // end anonymous namespace.
+
+void ToyToAffineLoweringPass::runOnFunction() {
+  auto function = getFunction();
+
+  // We only lower the main function as we expect that all other functions have
+  // been inlined.
+  if (function.getName() != "main")
+    return;
+
+  // Verify that the given main has no inputs and results.
+  if (function.getNumArguments() || function.getType().getNumResults()) {
+    function.emitError("expected 'main' to have 0 inputs and 0 results");
+    return signalPassFailure();
+  }
+
+  // The first thing to define is the conversion target. This will define the
+  // final target for this lowering.
+  ConversionTarget target(getContext());
+
+  // We define the specific operations, or dialects, that are legal targets for
+  // this lowering. In our case, we are lowering to a combination of the
+  // `Affine` and `Standard` dialects.
+  target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
+
+  // We also define the Toy dialect as Illegal so that the conversion will fail
+  // if any of these operations are *not* converted. Given that we actually want
+  // a partial lowering, we explicitly mark the Toy operations that we don't
+  // want to lower, `toy.print`, as `legal`.
+  target.addIllegalDialect<toy::ToyDialect>();
+  target.addLegalOp<toy::PrintOp>();
+
+  // Now that the conversion target has been defined, we just need to provide
+  // the set of patterns that will lower the Toy operations.
+  OwningRewritePatternList patterns;
+  patterns.insert<AddOpLowering, ConstantOpLowering, MulOpLowering,
+                  ReturnOpLowering, TransposeOpLowering>(&getContext());
+
+  // With the target and rewrite patterns defined, we can now attempt the
+  // conversion. The conversion will signal failure if any of our `illegal`
+  // operations were not converted successfully.
+  if (failed(applyPartialConversion(function, target, patterns)))
+    signalPassFailure();
+}
+
+/// Create a pass for lowering operations in the `Affine` and `Std` dialects,
+/// for a subset of the Toy IR (e.g. matmul).
+std::unique_ptr<Pass> mlir::toy::createLowerToAffinePass() {
+  auto pass = std::make_unique<ToyToAffineLoweringPass>();
+  return pass;
+}
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
new file mode 100644
index 00000000000..7e300fb702d
--- /dev/null
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -0,0 +1,213 @@
+//====- LowerToLLVM.cpp - Lowering from Toy+Affine+Std to LLVM ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the final lowering to the LLVM dialect: the remaining
+// `toy.print` operation is lowered to calls into printf, and the affine, loop,
+// and standard operations produced by earlier lowerings are converted to LLVM.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+#include "toy/Passes.h"
+
+#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/LowerAffine.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ToyToLLVM RewritePatterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Lowers `toy.print` to a loop nest calling `printf` on each of the individual
+/// elements of the array.
+class PrintOpLowering : public ConversionPattern {
+public:
+  explicit PrintOpLowering(MLIRContext *context)
+      : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // By this point the operand type has been lowered to a memref.
+    auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
+    auto memRefShape = memRefType.getShape();
+    auto loc = op->getLoc();
+    auto *llvmDialect =
+        op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+    assert(llvmDialect && "expected llvm dialect to be registered");
+
+    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
+
+    // Get a symbol reference to the printf function, inserting it if necessary.
+    auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect);
+    // Global format strings: "%f " for each element and a bare newline. The
+    // explicit lengths include the trailing NUL terminator.
+    Value *formatSpecifierCst = getOrCreateGlobalString(
+        loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule,
+        llvmDialect);
+    Value *newLineCst = getOrCreateGlobalString(
+        loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect);
+
+    // Create a loop for each of the dimensions within the shape.
+    SmallVector<Value *, 4> loopIvs;
+    for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) {
+      auto lowerBound = rewriter.create<ConstantIndexOp>(loc, 0);
+      auto upperBound = rewriter.create<ConstantIndexOp>(loc, memRefShape[i]);
+      auto step = rewriter.create<ConstantIndexOp>(loc, 1);
+      auto loop =
+          rewriter.create<loop::ForOp>(loc, lowerBound, upperBound, step);
+      // Drop the automatically created body so it can be rebuilt explicitly.
+      loop.getBody()->clear();
+      loopIvs.push_back(loop.getInductionVar());
+
+      // Terminate the loop body.
+      rewriter.setInsertionPointToStart(loop.getBody());
+
+      // Insert a newline after each of the inner dimensions of the shape.
+      // This call is emitted before the terminator so it ends up at the end of
+      // this loop's body once the insertion point is reset below.
+      if (i != e - 1)
+        rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
+                                newLineCst);
+      rewriter.create<loop::TerminatorOp>(loc);
+      rewriter.setInsertionPointToStart(loop.getBody());
+    }
+
+    // Generate a call to printf for the current element of the loop.
+    auto printOp = cast<toy::PrintOp>(op);
+    auto elementLoad = rewriter.create<LoadOp>(loc, printOp.input(), loopIvs);
+    rewriter.create<CallOp>(
+        loc, printfRef, rewriter.getIntegerType(32),
+        ArrayRef<Value *>({formatSpecifierCst, elementLoad}));
+
+    // Notify the rewriter that this operation has been removed.
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+
+private:
+  /// Return a symbol reference to the printf function, inserting it into the
+  /// module if necessary.
+  static SymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
+                                         ModuleOp module,
+                                         LLVM::LLVMDialect *llvmDialect) {
+    auto *context = module.getContext();
+    if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
+      return SymbolRefAttr::get("printf", context);
+
+    // Create a function declaration for printf, the signature is:
+    //   * `i32 (i8*, ...)`
+    auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
+    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+    auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
+                                                    /*isVarArg=*/true);
+
+    // Insert the printf function into the body of the parent module. The
+    // insertion guard restores the rewriter's previous insertion point on
+    // scope exit.
+    PatternRewriter::InsertionGuard insertGuard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
+    return SymbolRefAttr::get("printf", context);
+  }
+
+  /// Return a value representing an access into a global string with the given
+  /// name, creating the string if necessary.
+  static Value *getOrCreateGlobalString(Location loc, OpBuilder &builder,
+                                        StringRef name, StringRef value,
+                                        ModuleOp module,
+                                        LLVM::LLVMDialect *llvmDialect) {
+    // Create the global at the entry of the module.
+    LLVM::GlobalOp global;
+    if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
+      OpBuilder::InsertionGuard insertGuard(builder);
+      builder.setInsertionPointToStart(module.getBody());
+      auto type = LLVM::LLVMType::getArrayTy(
+          LLVM::LLVMType::getInt8Ty(llvmDialect), value.size());
+      global = builder.create<LLVM::GlobalOp>(
+          loc, type, /*isConstant=*/true, name, builder.getStringAttr(value));
+    }
+
+    // Get the pointer to the first character in the global string.
+    Value *globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
+    Value *cst0 = builder.create<LLVM::ConstantOp>(
+        loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
+        builder.getIntegerAttr(builder.getIndexType(), 0));
+    // GEP with two zero indices: step through the global pointer, then index
+    // the first array element, yielding an i8* to the string data.
+    return builder.create<LLVM::GEPOp>(
+        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
+        ArrayRef<Value *>({cst0, cst0}));
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// ToyToLLVMLoweringPass
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> {
+  // Entry point invoked once per module; defined below.
+  void runOnModule() final;
+};
+} // end anonymous namespace
+
+void ToyToLLVMLoweringPass::runOnModule() {
+  // The first thing to define is the conversion target. This will define the
+  // final target for this lowering. For this lowering, we are only targeting
+  // the LLVM dialect.
+  ConversionTarget target(getContext());
+  target.addLegalDialect<LLVM::LLVMDialect>();
+  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+
+  // During this lowering, we will also be lowering the MemRef types, that are
+  // currently being operated on, to a representation in LLVM. To perform this
+  // conversion we use a TypeConverter as part of the lowering. This converter
+  // details how one type maps to another. This is necessary now that we will be
+  // doing more complicated lowerings, involving loop region arguments.
+  LLVMTypeConverter typeConverter(&getContext());
+
+  // Now that the conversion target has been defined, we need to provide the
+  // patterns used for lowering. At this point of the compilation process, we
+  // have a combination of `toy`, `affine`, and `std` operations. Luckily, there
+  // already exists a set of patterns to transform the `affine` and `std`
+  // dialects. These patterns lower in multiple stages, relying on transitive
+  // lowerings. Transitive lowering, or A->B->C lowering, is when multiple
+  // patterns must be applied to fully transform an illegal operation into a
+  // set of legal ones.
+  OwningRewritePatternList patterns;
+  populateAffineToStdConversionPatterns(patterns, &getContext());
+  populateLoopToStdConversionPatterns(patterns, &getContext());
+  populateStdToLLVMConversionPatterns(typeConverter, patterns);
+
+  // The only remaining operation to lower from the `toy` dialect, is the
+  // PrintOp.
+  patterns.insert<PrintOpLowering>(&getContext());
+
+  // We want to completely lower to LLVM, so we use a `FullConversion`. This
+  // ensures that only legal operations will remain after the conversion.
+  auto module = getModule();
+  if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
+    signalPassFailure();
+}
+
+/// Create a pass that lowers the remaining `Toy` operations, as well as
+/// `Affine` and `Std`, to the LLVM dialect for codegen.
+std::unique_ptr<mlir::Pass> mlir::toy::createLowerToLLVMPass() {
+  auto pass = std::make_unique<ToyToLLVMLoweringPass>();
+  return pass;
+}
diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
new file mode 100644
index 00000000000..227ebcd758b
--- /dev/null
+++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
@@ -0,0 +1,683 @@
+//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple IR generation targeting MLIR from a Module AST
+// for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/MLIRGen.h"
+#include "toy/AST.h"
+#include "toy/Dialect.h"
+
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/Support/raw_ostream.h"
+#include <numeric>
+
+using namespace mlir::toy;
+using namespace toy;
+
+using llvm::ArrayRef;
+using llvm::cast;
+using llvm::dyn_cast;
+using llvm::isa;
+using llvm::makeArrayRef;
+using llvm::ScopedHashTableScope;
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+namespace {
+
+/// Implementation of a simple MLIR emission from the Toy AST.
+///
+/// This will emit operations that are specific to the Toy language, preserving
+/// the semantics of the language and (hopefully) allow to perform accurate
+/// analysis and transformation based on these high level semantics.
+class MLIRGenImpl {
+public:
+  /// Construct the generator; all IR is created through `builder` into the
+  /// given MLIR context.
+  MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
+
+  /// Public API: convert the AST for a Toy module (source file) to an MLIR
+  /// Module operation. Returns a null module on error (a diagnostic will have
+  /// been emitted).
+  mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
+    // We create an empty MLIR module and codegen functions one at a time and
+    // add them to the module.
+    theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
+
+    // A top-level record is either a function or a struct definition.
+    for (auto &record : moduleAST) {
+      if (FunctionAST *funcAST = llvm::dyn_cast<FunctionAST>(record.get())) {
+        auto func = mlirGen(*funcAST);
+        if (!func)
+          return nullptr;
+
+        theModule.push_back(func);
+        functionMap.insert({func.getName(), func});
+      } else if (StructAST *str = llvm::dyn_cast<StructAST>(record.get())) {
+        if (failed(mlirGen(*str)))
+          return nullptr;
+      } else {
+        llvm_unreachable("unknown record type");
+      }
+    }
+
+    // Verify the module after we have finished constructing it, this will check
+    // the structural properties of the IR and invoke any specific verifiers we
+    // have on the Toy operations.
+    if (failed(mlir::verify(theModule))) {
+      theModule.emitError("module verification error");
+      return nullptr;
+    }
+
+    return theModule;
+  }
+
+private:
+  /// A "module" matches a Toy source file: containing a list of functions.
+  mlir::ModuleOp theModule;
+
+  /// The builder is a helper class to create IR inside a function. The builder
+  /// is stateful, in particular it keeps an "insertion point": this is where
+  /// the next operations will be introduced.
+  mlir::OpBuilder builder;
+
+  /// The symbol table maps a variable name to a value in the current scope.
+  /// Entering a function creates a new scope, and the function arguments are
+  /// added to the mapping. When the processing of a function is terminated, the
+  /// scope is destroyed and the mappings created in this scope are dropped.
+  /// The mapped value pairs the MLIR SSA value with the originating AST
+  /// declaration, so the declared (struct) type of a variable can be recovered.
+  llvm::ScopedHashTable<StringRef, std::pair<mlir::Value *, VarDeclExprAST *>>
+      symbolTable;
+  using SymbolTableScopeT =
+      llvm::ScopedHashTableScope<StringRef,
+                                 std::pair<mlir::Value *, VarDeclExprAST *>>;
+
+  /// A mapping for the functions that have been code generated to MLIR.
+  llvm::StringMap<mlir::FuncOp> functionMap;
+
+  /// A mapping for named struct types to the underlying MLIR type and the
+  /// original AST node.
+  llvm::StringMap<std::pair<mlir::Type, StructAST *>> structMap;
+
+  /// Helper conversion for a Toy AST location to an MLIR location.
+  mlir::Location loc(Location loc) {
+    auto fileName = builder.getIdentifier(*loc.file);
+    return builder.getFileLineColLoc(fileName, loc.line, loc.col);
+  }
+
+  /// Declare a variable in the current scope; fails (without emitting a
+  /// diagnostic) when the name is already bound.
+  mlir::LogicalResult declare(VarDeclExprAST &var, mlir::Value *value) {
+    auto name = var.getName();
+    if (symbolTable.count(name))
+      return mlir::failure();
+    symbolTable.insert(name, {value, &var});
+    return mlir::success();
+  }
+
+  /// Create an MLIR type for the given struct and record it in structMap.
+  /// Emits a diagnostic and fails on redefinition or on invalid fields.
+  mlir::LogicalResult mlirGen(StructAST &str) {
+    if (structMap.count(str.getName()))
+      return emitError(loc(str.loc())) << "error: struct type with name `"
+                                       << str.getName() << "' already exists";
+
+    auto variables = str.getVariables();
+    std::vector<mlir::Type> elementTypes;
+    elementTypes.reserve(variables.size());
+    for (auto &variable : variables) {
+      // Struct fields are pure type declarations: they may carry neither an
+      // initializer nor an explicit shape.
+      if (variable->getInitVal())
+        return emitError(loc(variable->loc()))
+               << "error: variables within a struct definition must not have "
+                  "initializers";
+      if (!variable->getType().shape.empty())
+        return emitError(loc(variable->loc()))
+               << "error: variables within a struct definition must not have "
+                  "a shape specifier";
+
+      mlir::Type type = getType(variable->getType(), variable->loc());
+      if (!type)
+        return mlir::failure();
+      elementTypes.push_back(type);
+    }
+
+    structMap.try_emplace(str.getName(), StructType::get(elementTypes), &str);
+    return mlir::success();
+  }
+
+  /// Create the prototype for an MLIR function with as many arguments as the
+  /// provided Toy AST prototype.
+  mlir::FuncOp mlirGen(PrototypeAST &proto) {
+    auto location = loc(proto.loc());
+
+    // This is a generic function: map each argument's declared type to an MLIR
+    // type; the return type is left empty and inferred later.
+    llvm::SmallVector<mlir::Type, 4> argTypes;
+    argTypes.reserve(proto.getArgs().size());
+    for (auto &arg : proto.getArgs()) {
+      mlir::Type argType = getType(arg->getType(), arg->loc());
+      if (!argType)
+        return nullptr;
+      argTypes.push_back(argType);
+    }
+
+    auto funcType = builder.getFunctionType(argTypes, llvm::None);
+    return mlir::FuncOp::create(location, proto.getName(), funcType);
+  }
+
+  /// Emit a new function and add it to the MLIR module. Returns a null
+  /// function on error (the partially built function is erased).
+  mlir::FuncOp mlirGen(FunctionAST &funcAST) {
+    // Create a scope in the symbol table to hold variable declarations.
+    SymbolTableScopeT var_scope(symbolTable);
+
+    // Create an MLIR function for the given prototype.
+    mlir::FuncOp function(mlirGen(*funcAST.getProto()));
+    if (!function)
+      return nullptr;
+
+    // Let's start the body of the function now!
+    // In MLIR the entry block of the function is special: it must have the same
+    // argument list as the function itself.
+    auto &entryBlock = *function.addEntryBlock();
+    auto protoArgs = funcAST.getProto()->getArgs();
+
+    // Declare all the function arguments in the symbol table.
+    for (const auto &name_value :
+         llvm::zip(protoArgs, entryBlock.getArguments())) {
+      if (failed(declare(*std::get<0>(name_value), std::get<1>(name_value))))
+        return nullptr;
+    }
+
+    // Set the insertion point in the builder to the beginning of the function
+    // body, it will be used throughout the codegen to create operations in this
+    // function.
+    builder.setInsertionPointToStart(&entryBlock);
+
+    // Emit the body of the function.
+    if (mlir::failed(mlirGen(*funcAST.getBody()))) {
+      function.erase();
+      return nullptr;
+    }
+
+    // Implicitly return void if no return statement was emitted.
+    // FIXME: we may fix the parser instead to always return the last expression
+    // (this would possibly help the REPL case later)
+    ReturnOp returnOp;
+    if (!entryBlock.empty())
+      returnOp = dyn_cast<ReturnOp>(entryBlock.back());
+    if (!returnOp) {
+      builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
+    } else if (returnOp.hasOperand()) {
+      // Otherwise, if this return operation has an operand then add a result to
+      // the function.
+      function.setType(builder.getFunctionType(function.getType().getInputs(),
+                                               *returnOp.operand_type_begin()));
+    }
+
+    return function;
+  }
+
+  /// Return the struct type that is the result of the given expression, or
+  /// null if it cannot be inferred.
+  StructAST *getStructFor(ExprAST *expr) {
+    llvm::StringRef structName;
+    if (auto *decl = llvm::dyn_cast<VariableExprAST>(expr)) {
+      // A plain variable reference: recover the declared type name from the
+      // AST node stored alongside the value in the symbol table.
+      auto varIt = symbolTable.lookup(decl->getName());
+      if (!varIt.first)
+        return nullptr;
+      structName = varIt.second->getType().name;
+    } else if (auto *access = llvm::dyn_cast<BinaryExprAST>(expr)) {
+      if (access->getOp() != '.')
+        return nullptr;
+      // The name being accessed should be in the RHS.
+      auto *name = llvm::dyn_cast<VariableExprAST>(access->getRHS());
+      if (!name)
+        return nullptr;
+      // Recurse on the LHS to resolve the struct being accessed.
+      StructAST *parentStruct = getStructFor(access->getLHS());
+      if (!parentStruct)
+        return nullptr;
+
+      // Get the element within the struct corresponding to the name.
+      auto variables = parentStruct->getVariables();
+      auto it = llvm::find_if(variables, [&](auto &var) {
+        return var->getName() == name->getName();
+      });
+      if (it == variables.end())
+        return nullptr;
+      structName = (*it)->getType().name;
+    }
+    if (structName.empty())
+      return nullptr;
+
+    // If the struct name was valid, check for an entry in the struct map.
+    auto structIt = structMap.find(structName);
+    if (structIt == structMap.end())
+      return nullptr;
+    return structIt->second.second;
+  }
+
+  /// Return the numeric member index of the given struct access expression,
+  /// or llvm::None if the access cannot be resolved.
+  llvm::Optional<size_t> getMemberIndex(BinaryExprAST &accessOp) {
+    assert(accessOp.getOp() == '.' && "expected access operation");
+
+    // Resolve the struct declaration referenced on the LHS of the access.
+    StructAST *structDecl = getStructFor(accessOp.getLHS());
+    if (!structDecl)
+      return llvm::None;
+
+    // The RHS must be a plain variable naming the member.
+    auto *memberName = llvm::dyn_cast<VariableExprAST>(accessOp.getRHS());
+    if (!memberName)
+      return llvm::None;
+
+    // Scan the member list for the matching name and report its position.
+    auto members = structDecl->getVariables();
+    for (size_t i = 0, e = members.size(); i != e; ++i)
+      if (members[i]->getName() == memberName->getName())
+        return i;
+    return llvm::None;
+  }
+
+  /// Emit a binary operation. `.` is handled as a struct member access; `+`
+  /// and `*` map to the toy AddOp/MulOp. Returns nullptr on error.
+  mlir::Value *mlirGen(BinaryExprAST &binop) {
+    // First emit the operations for each side of the operation before emitting
+    // the operation itself. For example if the expression is `a + foo(a)`
+    // 1) First it will visit the LHS, which will return a reference to the
+    //    value holding `a`. This value should have been emitted at declaration
+    //    time and registered in the symbol table, so nothing would be
+    //    codegen'd. If the value is not in the symbol table, an error has been
+    //    emitted and nullptr is returned.
+    // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
+    //    and the result value is returned. If an error occurs we get a nullptr
+    //    and propagate.
+    //
+    mlir::Value *lhs = mlirGen(*binop.getLHS());
+    if (!lhs)
+      return nullptr;
+    auto location = loc(binop.loc());
+
+    // If this is an access operation, handle it immediately (the RHS is a
+    // member name, not a value to codegen).
+    if (binop.getOp() == '.') {
+      llvm::Optional<size_t> accessIndex = getMemberIndex(binop);
+      if (!accessIndex) {
+        emitError(location, "invalid access into struct expression");
+        return nullptr;
+      }
+      return builder.create<StructAccessOp>(location, lhs, *accessIndex);
+    }
+
+    // Otherwise, this is a normal binary op.
+    mlir::Value *rhs = mlirGen(*binop.getRHS());
+    if (!rhs)
+      return nullptr;
+
+    // Derive the operation name from the binary operator. At the moment we only
+    // support '+' and '*'.
+    switch (binop.getOp()) {
+    case '+':
+      return builder.create<AddOp>(location, lhs, rhs);
+    case '*':
+      return builder.create<MulOp>(location, lhs, rhs);
+    }
+
+    // Any other operator is a user error.
+    emitError(location, "invalid binary operator '") << binop.getOp() << "'";
+    return nullptr;
+  }
+
+  /// This is a reference to a variable in an expression. The variable is
+  /// expected to have been declared and so should have a value in the symbol
+  /// table, otherwise emit an error and return nullptr.
+  mlir::Value *mlirGen(VariableExprAST &expr) {
+    mlir::Value *variable = symbolTable.lookup(expr.getName()).first;
+    if (!variable) {
+      emitError(loc(expr.loc()), "error: unknown variable '")
+          << expr.getName() << "'";
+      return nullptr;
+    }
+    return variable;
+  }
+
+  /// Emit a return operation. This will return failure if any generation fails.
+  mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
+    auto location = loc(ret.loc());
+
+    // 'return' takes an optional expression; codegen the operand if present.
+    mlir::Value *operand = nullptr;
+    if (ret.getExpr().hasValue()) {
+      operand = mlirGen(*ret.getExpr().getValue());
+      if (!operand)
+        return mlir::failure();
+    }
+
+    // Emit the return with either zero or one operand.
+    if (operand)
+      builder.create<ReturnOp>(location, makeArrayRef(operand));
+    else
+      builder.create<ReturnOp>(location, ArrayRef<mlir::Value *>());
+    return mlir::success();
+  }
+
+  /// Emit a constant for a literal/constant array. It will be emitted as a
+  /// flattened array of data in an Attribute attached to a `toy.constant`
+  /// operation. See documentation on [Attributes](LangRef.md#attributes) for
+  /// more details. Here is an excerpt:
+  ///
+  ///   Attributes are the mechanism for specifying constant data in MLIR in
+  ///   places where a variable is never allowed [...]. They consist of a name
+  ///   and a concrete attribute value. The set of expected attributes, their
+  ///   structure, and their interpretation are all contextually dependent on
+  ///   what they are attached to.
+  ///
+  /// Example, the source level statement:
+  ///   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+  /// will be converted to:
+  ///   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
+  ///     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
+  ///      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
+  ///
+  mlir::DenseElementsAttr getConstantAttr(LiteralExprAST &lit) {
+    // The attribute is a vector with a floating point value per element
+    // (number) in the array, see `collectData()` below for more details.
+    std::vector<double> data;
+    data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
+                                 std::multiplies<int>()));
+    collectData(lit, data);
+
+    // The type of this attribute is tensor of 64-bit floating-point with the
+    // shape of the literal.
+    mlir::Type elementType = builder.getF64Type();
+    auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType);
+
+    // This is the actual attribute that holds the list of values for this
+    // tensor literal.
+    return mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
+  }
+  /// Emit a constant attribute for a single number: a shapeless (rank-0)
+  /// tensor of 64-bit floating-point holding the value.
+  mlir::DenseElementsAttr getConstantAttr(NumberExprAST &lit) {
+    // Scalars are modeled as a ranked tensor with an empty shape.
+    auto scalarType =
+        mlir::RankedTensorType::get({}, builder.getF64Type());
+
+    // Wrap the single value into the dense attribute.
+    double value = lit.getValue();
+    return mlir::DenseElementsAttr::get(scalarType, llvm::makeArrayRef(value));
+  }
+  /// Emit a constant for a struct literal. It will be emitted as an array of
+  /// other literals in an Attribute attached to a `toy.struct_constant`
+  /// operation. This function returns the generated constant, along with the
+  /// corresponding struct type.
+  std::pair<mlir::ArrayAttr, mlir::Type>
+  getConstantAttr(StructLiteralExprAST &lit) {
+    std::vector<mlir::Attribute> attrElements;
+    std::vector<mlir::Type> typeElements;
+
+    // Build attribute/type pairs element by element; nested struct literals
+    // are handled by recursing into this same function.
+    for (auto &var : lit.getValues()) {
+      if (auto *number = llvm::dyn_cast<NumberExprAST>(var.get())) {
+        attrElements.push_back(getConstantAttr(*number));
+        typeElements.push_back(getType(llvm::None));
+      } else if (auto *lit = llvm::dyn_cast<LiteralExprAST>(var.get())) {
+        attrElements.push_back(getConstantAttr(*lit));
+        typeElements.push_back(getType(llvm::None));
+      } else {
+        auto *structLit = llvm::cast<StructLiteralExprAST>(var.get());
+        auto attrTypePair = getConstantAttr(*structLit);
+        attrElements.push_back(attrTypePair.first);
+        typeElements.push_back(attrTypePair.second);
+      }
+    }
+    mlir::ArrayAttr dataAttr = builder.getArrayAttr(attrElements);
+    mlir::Type dataType = StructType::get(typeElements);
+    return std::make_pair(dataAttr, dataType);
+  }
+
+  /// Emit an array literal as a `toy.constant` operation.
+  mlir::Value *mlirGen(LiteralExprAST &lit) {
+    // Materialize the flattened payload attribute and the tensor type that
+    // describes the literal's shape.
+    mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit);
+    mlir::Type type = getType(lit.getDims());
+
+    // This invokes the `ConstantOp::build` method.
+    return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
+  }
+
+  /// Emit a struct literal. It will be emitted as an array of
+  /// other literals in an Attribute attached to a `toy.struct_constant`
+  /// operation.
+  mlir::Value *mlirGen(StructLiteralExprAST &lit) {
+    // Materialize the constant attribute together with its struct type.
+    auto attrTypePair = getConstantAttr(lit);
+
+    // This invokes the `StructConstantOp::build` method.
+    return builder.create<StructConstantOp>(
+        loc(lit.loc()), attrTypePair.second, attrTypePair.first);
+  }
+
+  /// Recursive helper function to accumulate the data that compose an array
+  /// literal. It flattens the nested structure in the supplied vector. For
+  /// example with this array:
+  ///  [[1, 2], [3, 4]]
+  /// we will generate:
+  ///  [ 1, 2, 3, 4 ]
+  /// Individual numbers are represented as doubles.
+  /// Attributes are the way MLIR attaches constant to operations.
+  void collectData(ExprAST &expr, std::vector<double> &data) {
+    // Nested literal: recurse into every element, depth-first, so the values
+    // come out in row-major order.
+    if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
+      for (auto &value : lit->getValues())
+        collectData(*value, data);
+      return;
+    }
+
+    // Leaf: must be a number at this point (the parser guarantees it).
+    assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
+    data.push_back(cast<NumberExprAST>(expr).getValue());
+  }
+
+  /// Emit a call expression. It emits specific operations for the `transpose`
+  /// builtin. Other identifiers are assumed to be user-defined functions.
+  /// Returns nullptr on error.
+  mlir::Value *mlirGen(CallExprAST &call) {
+    llvm::StringRef callee = call.getCallee();
+    auto location = loc(call.loc());
+
+    // Codegen the operands first.
+    SmallVector<mlir::Value *, 4> operands;
+    for (auto &expr : call.getArgs()) {
+      auto *arg = mlirGen(*expr);
+      if (!arg)
+        return nullptr;
+      operands.push_back(arg);
+    }
+
+    // Builtin calls have their custom operation, meaning this is a
+    // straightforward emission.
+    if (callee == "transpose") {
+      if (call.getArgs().size() != 1) {
+        emitError(location, "MLIR codegen encountered an error: toy.transpose "
+                            "does not accept multiple arguments");
+        return nullptr;
+      }
+      return builder.create<TransposeOp>(location, operands[0]);
+    }
+
+    // Otherwise this is a call to a user-defined function. Calls to
+    // user-defined functions are mapped to a custom call that takes the callee
+    // name as an attribute.
+    auto calledFuncIt = functionMap.find(callee);
+    if (calledFuncIt == functionMap.end()) {
+      emitError(location) << "no defined function found for '" << callee << "'";
+      return nullptr;
+    }
+    mlir::FuncOp calledFunc = calledFuncIt->second;
+    return builder.create<GenericCallOp>(
+        location, calledFunc.getType().getResult(0),
+        builder.getSymbolRefAttr(callee), operands);
+  }
+
+  /// Emit a print expression: codegen the argument and attach it to a
+  /// `toy.print` operation. Returns failure if argument codegen fails.
+  mlir::LogicalResult mlirGen(PrintExprAST &call) {
+    auto *arg = mlirGen(*call.getArg());
+    if (!arg)
+      return mlir::failure();
+
+    builder.create<PrintOp>(loc(call.loc()), arg);
+    return mlir::success();
+  }
+
+  /// Emit a constant for a single number (FIXME: semantic? broadcast?)
+  mlir::Value *mlirGen(NumberExprAST &num) {
+    // Emitted as a shapeless `toy.constant` holding the value.
+    return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
+  }
+
+  /// Dispatch codegen for the right expression subclass using RTTI.
+  /// Note: VarDecl/Return/Print are handled at the block level (see the
+  /// ExprASTList overload), not here.
+  mlir::Value *mlirGen(ExprAST &expr) {
+    switch (expr.getKind()) {
+    case toy::ExprAST::Expr_BinOp:
+      return mlirGen(cast<BinaryExprAST>(expr));
+    case toy::ExprAST::Expr_Var:
+      return mlirGen(cast<VariableExprAST>(expr));
+    case toy::ExprAST::Expr_Literal:
+      return mlirGen(cast<LiteralExprAST>(expr));
+    case toy::ExprAST::Expr_StructLiteral:
+      return mlirGen(cast<StructLiteralExprAST>(expr));
+    case toy::ExprAST::Expr_Call:
+      return mlirGen(cast<CallExprAST>(expr));
+    case toy::ExprAST::Expr_Num:
+      return mlirGen(cast<NumberExprAST>(expr));
+    default:
+      emitError(loc(expr.loc()))
+          << "MLIR codegen encountered an unhandled expr kind '"
+          << Twine(expr.getKind()) << "'";
+      return nullptr;
+    }
+  }
+
+  /// Handle a variable declaration, we'll codegen the expression that forms the
+  /// initializer and record the value in the symbol table before returning it.
+  /// Future expressions will be able to reference this variable through symbol
+  /// table lookup.
+  mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
+    auto init = vardecl.getInitVal();
+    if (!init) {
+      emitError(loc(vardecl.loc()),
+                "missing initializer in variable declaration");
+      return nullptr;
+    }
+
+    mlir::Value *value = mlirGen(*init);
+    if (!value)
+      return nullptr;
+
+    // Handle the case where we are initializing a struct value: the declared
+    // struct type must match the initializer's type exactly.
+    VarType varType = vardecl.getType();
+    if (!varType.name.empty()) {
+      // Check that the initializer type is the same as the variable
+      // declaration.
+      mlir::Type type = getType(varType, vardecl.loc());
+      if (!type)
+        return nullptr;
+      if (type != value->getType()) {
+        emitError(loc(vardecl.loc()))
+            << "struct type of initializer is different than the variable "
+               "declaration. Got "
+            << value->getType() << ", but expected " << type;
+        return nullptr;
+      }
+
+      // Otherwise, we have the initializer value, but in case the variable was
+      // declared with specific shape, we emit a "reshape" operation. It will
+      // get optimized out later as needed.
+    } else if (!varType.shape.empty()) {
+      value = builder.create<ReshapeOp>(loc(vardecl.loc()),
+                                        getType(varType.shape), value);
+    }
+
+    // Register the value in the symbol table.
+    if (failed(declare(vardecl, value)))
+      return nullptr;
+    return value;
+  }
+
+  /// Codegen a list of expression, return failure if one of them hit an error.
+  mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
+    // New lexical scope: declarations made in this block are dropped on exit.
+    SymbolTableScopeT var_scope(symbolTable);
+    for (auto &expr : blockAST) {
+      // Specific handling for variable declarations, return statement, and
+      // print. These can only appear in block list and not in nested
+      // expressions.
+      if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
+        if (!mlirGen(*vardecl))
+          return mlir::failure();
+        continue;
+      }
+      if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
+        return mlirGen(*ret);
+      if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
+        // Propagate the error: returning success() here (as the original did)
+        // silently swallowed codegen failures in print expressions.
+        if (mlir::failed(mlirGen(*print)))
+          return mlir::failure();
+        continue;
+      }
+
+      // Generic expression dispatch codegen.
+      if (!mlirGen(*expr))
+        return mlir::failure();
+    }
+    return mlir::success();
+  }
+
+  /// Build a tensor type from a list of shape dimensions: an empty shape
+  /// yields an unranked tensor, otherwise a ranked tensor of f64.
+  mlir::Type getType(ArrayRef<int64_t> shape) {
+    mlir::Type elementType = builder.getF64Type();
+    if (shape.empty())
+      return mlir::UnrankedTensorType::get(elementType);
+    return mlir::RankedTensorType::get(shape, elementType);
+  }
+
+  /// Build an MLIR type from a Toy AST variable type (forward to the generic
+  /// getType above for non-struct types). Emits an error and returns null for
+  /// an unknown struct name.
+  mlir::Type getType(const VarType &type, const Location &location) {
+    // Non-struct types are entirely described by their shape.
+    if (type.name.empty())
+      return getType(type.shape);
+
+    // Struct types must have been registered in the struct map beforehand.
+    auto it = structMap.find(type.name);
+    if (it != structMap.end())
+      return it->second.first;
+    emitError(loc(location)) << "error: unknown struct type '" << type.name
+                             << "'";
+    return nullptr;
+  }
+};
+
+} // namespace
+
+namespace toy {
+
+// The public API for codegen: lower the whole Toy AST module to an MLIR
+// module in the given context, or return null on error.
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
+                              ModuleAST &moduleAST) {
+  return MLIRGenImpl(context).mlirGen(moduleAST);
+}
+
+} // namespace toy
diff --git a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
new file mode 100644
index 00000000000..1f572015c39
--- /dev/null
+++ b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
@@ -0,0 +1,113 @@
+//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a Function level pass performing interprocedural
+// propagation of array shapes through function specialization.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "toy/Dialect.h"
+#include "toy/Passes.h"
+#include "toy/ShapeInferenceInterface.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "shape-inference"
+
+using namespace mlir;
+using namespace toy;
+
+/// Include the auto-generated definitions for the shape inference interfaces.
+#include "toy/ShapeInferenceOpInterfaces.cpp.inc"
+
+namespace {
+/// The ShapeInferencePass is a FunctionPass that performs intra-procedural
+/// shape inference.
+///
+/// Algorithm:
+///
+/// 1) Build a worklist containing all the operations that return a
+/// dynamically shaped tensor: these are the operations that need shape
+/// inference.
+/// 2) Iterate on the worklist:
+/// a) find an operation to process: the next ready operation in the
+/// worklist has all of its arguments non-generic,
+/// b) if no operation is found, break out of the loop,
+/// c) remove the operation from the worklist,
+/// d) infer the shape of its output from the argument types.
+/// 3) If the worklist is empty, the algorithm succeeded.
+///
+class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
+public:
+  void runOnFunction() override {
+    auto f = getFunction();
+
+    // Populate the worklist with the operations that need shape inference:
+    // these are operations that return a dynamic shape.
+    llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
+    f.walk([&](mlir::Operation *op) {
+      if (returnsDynamicShape(op))
+        opWorklist.insert(op);
+    });
+
+    // Iterate on the operations in the worklist until all operations have been
+    // inferred or no change happened (fix point).
+    while (!opWorklist.empty()) {
+      // Find the next operation ready for inference, that is an operation
+      // with all operands already resolved (non-generic). The original used
+      // returnsDynamicShape here, which is true for every worklist entry and
+      // so never implemented the documented readiness check.
+      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
+      if (nextop == opWorklist.end())
+        break;
+
+      Operation *op = *nextop;
+      opWorklist.erase(op);
+
+      // Ask the operation to infer its output shapes.
+      LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
+      if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
+        shapeOp.inferShapes();
+      } else {
+        op->emitError("unable to infer shape of operation without shape "
+                      "inference interface");
+        return signalPassFailure();
+      }
+    }
+
+    // If the operation worklist isn't empty, this indicates a failure.
+    if (!opWorklist.empty()) {
+      f.emitError("Shape inference failed, ")
+          << opWorklist.size() << " operations couldn't be inferred\n";
+      signalPassFailure();
+    }
+  }
+
+  /// A utility method that returns if the given operation has all of its
+  /// operands inferred (i.e. every operand type is a ranked tensor).
+  static bool allOperandsInferred(Operation *op) {
+    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+      return operandType.isa<RankedTensorType>();
+    });
+  }
+
+  /// A utility method that returns if the given operation has a dynamically
+  /// shaped result.
+  static bool returnsDynamicShape(Operation *op) {
+    return llvm::any_of(op->getResultTypes(), [](Type resultType) {
+      return !resultType.isa<RankedTensorType>();
+    });
+  }
+};
+} // end anonymous namespace
+
+/// Create a Shape Inference pass: factory used when building the pass
+/// pipeline (see toy/Passes.h).
+std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
+  return std::make_unique<ShapeInferencePass>();
+}
diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
new file mode 100644
index 00000000000..ebd4f5d1103
--- /dev/null
+++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
@@ -0,0 +1,101 @@
+//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a set of simple combiners for optimizing operations in
+// the Toy dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "toy/Dialect.h"
+#include <numeric>
+using namespace mlir;
+using namespace toy;
+
+namespace {
+/// Include the patterns defined in the Declarative Rewrite framework.
+#include "ToyCombine.inc"
+} // end anonymous namespace
+
+/// Fold simple cast operations that return the same type as the input:
+/// delegates to the generic MLIR cast folder.
+OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+  return mlir::impl::foldCastOp(*this);
+}
+
+/// Fold constants: a `toy.constant` always folds to its `value` attribute.
+OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }
+
+/// Fold struct constants: a `toy.struct_constant` folds to its `value`
+/// attribute.
+OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
+  return value();
+}
+
+/// Fold simple struct access operations that access into a constant.
+OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
+  // Folding is only possible when the input struct is itself a constant,
+  // which is materialized as an ArrayAttr of its element attributes.
+  auto structElements = operands.front().dyn_cast_or_null<mlir::ArrayAttr>();
+  if (!structElements)
+    return nullptr;
+
+  // Extract the accessed element directly from the constant array.
+  return structElements.getValue()[index().getZExtValue()];
+}
+
+/// This is an example of a c++ rewrite pattern for the TransposeOp. It
+/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
+struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
+  /// We register this pattern to match every toy.transpose in the IR.
+  /// The "benefit" is used by the framework to order the patterns and process
+  /// them in order of profitability.
+  SimplifyRedundantTranspose(mlir::MLIRContext *context)
+      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
+
+  /// This method attempts to match a pattern and rewrite it. The rewriter
+  /// argument is the orchestrator of the sequence of rewrites. The pattern is
+  /// expected to interact with it to perform any changes to the IR from here.
+  mlir::PatternMatchResult
+  matchAndRewrite(TransposeOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // Look through the input of the current transpose.
+    mlir::Value *transposeInput = op.getOperand();
+    TransposeOp transposeInputOp =
+        llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
+
+    // If the input is not defined by another transpose, there is nothing to
+    // simplify here.
+    if (!transposeInputOp)
+      return matchFailure();
+
+    // Bingo: replace transpose(transpose(x)) with the inner x.
+    rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
+    return matchSuccess();
+  }
+};
+
+/// Register our patterns as "canonicalization" patterns on the TransposeOp so
+/// that they can be picked up by the Canonicalization framework.
+void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  // C++-defined pattern (see SimplifyRedundantTranspose above).
+  results.insert<SimplifyRedundantTranspose>(context);
+}
+
+/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
+/// that they can be picked up by the Canonicalization framework.
+void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  // These three patterns are generated from ToyCombine.td via tablegen.
+  results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
+                 FoldConstantReshapeOptPattern>(context);
+}
diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.td b/mlir/examples/toy/Ch7/mlir/ToyCombine.td
new file mode 100644
index 00000000000..0a63861fa96
--- /dev/null
+++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.td
@@ -0,0 +1,73 @@
+//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines language-specific pattern match optimizations for Toy using
+// Declarative Rewrite Rules (DRR) specified using TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOY_COMBINE
+#define TOY_COMBINE
+
+#ifndef OP_BASE
+include "toy/Ops.td"
+#endif // OP_BASE
+
+/// Note: The DRR definition used for defining patterns is shown below:
+///
+/// class Pattern<
+/// dag sourcePattern, list<dag> resultPatterns,
+/// list<dag> additionalConstraints = [],
+/// dag benefitsAdded = (addBenefit 0)
+/// >;
+
+//===----------------------------------------------------------------------===//
+// Basic Pattern-Match and Rewrite
+//===----------------------------------------------------------------------===//
+
+// Reshape(Reshape(x)) = Reshape(x)
+// Collapses two consecutive reshapes into the outermost one.
+def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
+                                   (ReshapeOp $arg)>;
+
+//===----------------------------------------------------------------------===//
+// Pattern-Match and Rewrite using Native Code Call
+//===----------------------------------------------------------------------===//
+
+// Native Code Calls may be used for more complex transformations using inline
+// C++ and C++ helper functions.
+
+// Reshape(Constant(x)) = x'
+// Folds a reshape of a constant by reshaping the constant attribute itself:
+// in the native call, $0 is the constant attribute and $1 is the reshape
+// result (whose type provides the target shape).
+def ReshapeConstant :
+  NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
+def FoldConstantReshapeOptPattern : Pat<
+  (ReshapeOp:$res (ConstantOp $arg)),
+  (ConstantOp (ReshapeConstant $arg, $res))>;
+
+//===----------------------------------------------------------------------===//
+// Pattern-Match and Rewrite with Constraints
+//===----------------------------------------------------------------------===//
+
+// DRR allows for constraint checking when the transformation is conditional
+// on operand properties.
+
+// Reshape(x) = x, where input and output shapes are identical
+// The constraint compares the types of the matched result and operand.
+def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
+def RedundantReshapeOptPattern : Pat<
+  (ReshapeOp:$res $arg), (replaceWithValue $arg),
+  [(TypesAreIdentical $res, $arg)]>;
+
+#endif // TOY_COMBINE
diff --git a/mlir/examples/toy/Ch7/parser/AST.cpp b/mlir/examples/toy/Ch7/parser/AST.cpp
new file mode 100644
index 00000000000..74056296f2e
--- /dev/null
+++ b/mlir/examples/toy/Ch7/parser/AST.cpp
@@ -0,0 +1,285 @@
+//===- AST.cpp - Helper for printing out the Toy AST ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the AST dump for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/AST.h"
+
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace toy;
+
+namespace {
+
+// RAII helper to manage increasing/decreasing the indentation as we traverse
+// the AST: the level is bumped on construction and restored on destruction.
+struct Indent {
+  Indent(int &level) : level(level) { ++level; }
+  ~Indent() { --level; }
+  int &level;
+};
+
+/// Helper class that implements the AST tree traversal and prints the nodes
+/// along the way. The only data member is the current indentation level.
+class ASTDumper {
+public:
+  void dump(ModuleAST *node);
+
+private:
+  // One overload per AST node kind; dump(ExprAST *) dispatches via RTTI.
+  void dump(const VarType &type);
+  void dump(VarDeclExprAST *varDecl);
+  void dump(ExprAST *expr);
+  void dump(ExprASTList *exprList);
+  void dump(NumberExprAST *num);
+  void dump(LiteralExprAST *node);
+  void dump(StructLiteralExprAST *node);
+  void dump(VariableExprAST *node);
+  void dump(ReturnExprAST *node);
+  void dump(BinaryExprAST *node);
+  void dump(CallExprAST *node);
+  void dump(PrintExprAST *node);
+  void dump(PrototypeAST *node);
+  void dump(FunctionAST *node);
+  void dump(StructAST *node);
+
+  // Actually print spaces matching the current indentation level
+  void indent() {
+    for (int i = 0; i < curIndent; i++)
+      llvm::errs() << "  ";
+  }
+  int curIndent = 0;
+};
+
+} // namespace
+
+/// Return a formatted string for the location of any node, in the form
+/// "@file:line:col".
+template <typename T> static std::string loc(T *node) {
+  const auto &loc = node->loc();
+  return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
+          llvm::Twine(loc.col))
+      .str();
+}
+
+// Helper Macro to bump the indentation level and print the leading spaces for
+// the current indentations
+#define INDENT() \
+ Indent level_(curIndent); \
+ indent();
+
+/// Dispatch to a generic expressions to the appropriate subclass using RTTI
+void ASTDumper::dump(ExprAST *expr) {
+  // Try each concrete subclass in turn; the first dyn_cast that succeeds wins.
+#define dispatch(CLASS)                                                        \
+  if (CLASS *node = llvm::dyn_cast<CLASS>(expr))                               \
+    return dump(node);
+  dispatch(VarDeclExprAST);
+  dispatch(LiteralExprAST);
+  dispatch(StructLiteralExprAST);
+  dispatch(NumberExprAST);
+  dispatch(VariableExprAST);
+  dispatch(ReturnExprAST);
+  dispatch(BinaryExprAST);
+  dispatch(CallExprAST);
+  dispatch(PrintExprAST);
+  // No match, fallback to a generic message
+  INDENT();
+  llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
+}
+
+/// A variable declaration is printing the variable name, the type, and then
+/// recurse in the initializer value.
+void ASTDumper::dump(VarDeclExprAST *varDecl) {
+  INDENT();
+  llvm::errs() << "VarDecl " << varDecl->getName();
+  dump(varDecl->getType());
+  llvm::errs() << " " << loc(varDecl) << "\n";
+  // The initializer is optional; only recurse when one is present.
+  if (auto *initVal = varDecl->getInitVal())
+    dump(initVal);
+}
+
+/// A "block", or a list of expression
+void ASTDumper::dump(ExprASTList *exprList) {
+  INDENT();
+  llvm::errs() << "Block {\n";
+  for (auto &expr : *exprList)
+    dump(expr.get());
+  // Close at the block's own indentation, one level out from the children.
+  indent();
+  llvm::errs() << "} // Block\n";
+}
+
+/// A literal number, just print the value.
+void ASTDumper::dump(NumberExprAST *num) {
+  INDENT();
+  // Emits: "<value> @file:line:col".
+  llvm::errs() << num->getValue() << " " << loc(num) << "\n";
+}
+
+/// Helper to print recursively a literal. This handles nested array like:
+///    [ [ 1, 2 ], [ 3, 4 ] ]
+/// We print out such array with the dimensions spelled out at every level:
+///    <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
+void printLitHelper(ExprAST *litOrNum) {
+  // Inside a literal expression we can have either a number or another literal
+  if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
+    llvm::errs() << num->getValue();
+    return;
+  }
+  auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
+
+  // Print the dimension for this literal first
+  llvm::errs() << "<";
+  mlir::interleaveComma(literal->getDims(), llvm::errs());
+  llvm::errs() << ">";
+
+  // Now print the content, recursing on every element of the list
+  llvm::errs() << "[ ";
+  mlir::interleaveComma(literal->getValues(), llvm::errs(),
+                        [&](auto &elt) { printLitHelper(elt.get()); });
+  llvm::errs() << "]";
+}
+
+/// Print a literal, see the recursive helper above for the implementation.
+void ASTDumper::dump(LiteralExprAST *node) {
+  INDENT();
+  llvm::errs() << "Literal: ";
+  printLitHelper(node);
+  llvm::errs() << " " << loc(node) << "\n";
+}
+
+/// Print a struct literal: each element is dumped on its own line, then the
+/// location is appended.
+void ASTDumper::dump(StructLiteralExprAST *node) {
+  INDENT();
+  llvm::errs() << "Struct Literal: ";
+  for (auto &value : node->getValues())
+    dump(value.get());
+  indent();
+  llvm::errs() << " " << loc(node) << "\n";
+}
+
+/// Print a variable reference (just a name).
+void ASTDumper::dump(VariableExprAST *node) {
+  INDENT();
+  llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
+}
+
+/// Return statement print the return and its (optional) argument.
+void ASTDumper::dump(ReturnExprAST *node) {
+  INDENT();
+  llvm::errs() << "Return\n";
+  if (node->getExpr().hasValue())
+    return dump(*node->getExpr());
+  // No operand: print "(void)" one level deeper. The braces scope the RAII
+  // Indent created by INDENT().
+  {
+    INDENT();
+    llvm::errs() << "(void)\n";
+  }
+}
+
+/// Print a binary operation, first the operator, then recurse into LHS and RHS.
+void ASTDumper::dump(BinaryExprAST *node) {
+  INDENT();
+  llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
+  dump(node->getLHS());
+  dump(node->getRHS());
+}
+
+/// Print a call expression, first the callee name and the list of args by
+/// recursing into each individual argument.
+void ASTDumper::dump(CallExprAST *node) {
+  INDENT();
+  llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
+  for (auto &arg : node->getArgs())
+    dump(arg.get());
+  indent();
+  llvm::errs() << "]\n";
+}
+
+/// Print a builtin print call, first the builtin name and then the argument.
+void ASTDumper::dump(PrintExprAST *node) {
+  INDENT();
+  llvm::errs() << "Print [ " << loc(node) << "\n";
+  dump(node->getArg());
+  indent();
+  llvm::errs() << "]\n";
+}
+
+/// Print type: the struct name when present, otherwise the shape, in between
+/// '<' and '>'.
+void ASTDumper::dump(const VarType &type) {
+  llvm::errs() << "<";
+  if (!type.name.empty())
+    llvm::errs() << type.name;
+  else
+    mlir::interleaveComma(type.shape, llvm::errs());
+  llvm::errs() << ">";
+}
+
+/// Print a function prototype, first the function name, and then the list of
+/// parameters names.
+void ASTDumper::dump(PrototypeAST *node) {
+  INDENT();
+  // Note: drops the stray unbalanced quote the original emitted after the
+  // location ("...' @file:line:col'").
+  llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n";
+  indent();
+  llvm::errs() << "Params: [";
+  mlir::interleaveComma(node->getArgs(), llvm::errs(),
+                        [](auto &arg) { llvm::errs() << arg->getName(); });
+  llvm::errs() << "]\n";
+}
+
+/// Print a function, first the prototype and then the body.
+void ASTDumper::dump(FunctionAST *node) {
+  INDENT();
+  llvm::errs() << "Function \n";
+  dump(node->getProto());
+  dump(node->getBody());
+}
+
+/// Print a struct: its name, then each member declaration one level deeper.
+void ASTDumper::dump(StructAST *node) {
+  INDENT();
+  llvm::errs() << "Struct: " << node->getName() << " " << loc(node) << "\n";
+
+  // Braces scope the nested RAII Indent from INDENT().
+  {
+    INDENT();
+    llvm::errs() << "Variables: [\n";
+    for (auto &variable : node->getVariables())
+      dump(variable.get());
+    indent();
+    llvm::errs() << "]\n";
+  }
+}
+
+/// Print a module, actually loop over the records (functions and structs) and
+/// print them in sequence.
+void ASTDumper::dump(ModuleAST *node) {
+  INDENT();
+  llvm::errs() << "Module:\n";
+  for (auto &record : *node) {
+    if (FunctionAST *function = llvm::dyn_cast<FunctionAST>(record.get()))
+      dump(function);
+    else if (StructAST *str = llvm::dyn_cast<StructAST>(record.get()))
+      dump(str);
+    else
+      llvm::errs() << "<unknown Record, kind " << record->getKind() << ">\n";
+  }
+}
+
+namespace toy {
+
+// Public API: dump the AST of the given module to stderr.
+void dump(ModuleAST &module) { ASTDumper().dump(&module); }
+
+} // namespace toy
diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp
new file mode 100644
index 00000000000..26b684cbe2a
--- /dev/null
+++ b/mlir/examples/toy/Ch7/toyc.cpp
@@ -0,0 +1,282 @@
+//===- toyc.cpp - The Toy Compiler ----------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the entry point for the Toy compiler.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+#include "toy/MLIRGen.h"
+#include "toy/Parser.h"
+#include "toy/Passes.h"
+
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Target/LLVMIR.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace toy;
+namespace cl = llvm::cl;
+
+static cl::opt<std::string> inputFilename(cl::Positional,
+ cl::desc("<input toy file>"),
+ cl::init("-"),
+ cl::value_desc("filename"));
+
+namespace {
+enum InputType { Toy, MLIR };
+}
+static cl::opt<enum InputType> inputType(
+ "x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
+ cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
+ cl::values(clEnumValN(MLIR, "mlir",
+ "load the input file as an MLIR file")));
+
+namespace {
+enum Action {
+ None,
+ DumpAST,
+ DumpMLIR,
+ DumpMLIRAffine,
+ DumpMLIRLLVM,
+ DumpLLVMIR,
+ RunJIT
+};
+}
+static cl::opt<enum Action> emitAction(
+ "emit", cl::desc("Select the kind of output desired"),
+ cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
+ cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")),
+ cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine",
+ "output the MLIR dump after affine lowering")),
+ cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm",
+ "output the MLIR dump after llvm lowering")),
+ cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")),
+ cl::values(
+ clEnumValN(RunJIT, "jit",
+ "JIT the code and run it by invoking the main function")));
+
+static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
+
+/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
+std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+ llvm::MemoryBuffer::getFileOrSTDIN(filename);
+ if (std::error_code ec = fileOrErr.getError()) {
+ llvm::errs() << "Could not open input file: " << ec.message() << "\n";
+ return nullptr;
+ }
+ auto buffer = fileOrErr.get()->getBuffer();
+ LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
+ Parser parser(lexer);
+ return parser.parseModule();
+}
+
+int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
+ // Handle '.toy' input to the compiler.
+ if (inputType != InputType::MLIR &&
+ !llvm::StringRef(inputFilename).endswith(".mlir")) {
+ auto moduleAST = parseInputFile(inputFilename);
+ module = mlirGen(context, *moduleAST);
+ return !module ? 1 : 0;
+ }
+
+ // Otherwise, the input is '.mlir'.
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+ llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+ if (std::error_code EC = fileOrErr.getError()) {
+ llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+ return -1;
+ }
+
+ // Parse the input mlir.
+ llvm::SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
+ module = mlir::parseSourceFile(sourceMgr, &context);
+ if (!module) {
+ llvm::errs() << "Error can't load file " << inputFilename << "\n";
+ return 3;
+ }
+ return 0;
+}
+
+int loadAndProcessMLIR(mlir::MLIRContext &context,
+ mlir::OwningModuleRef &module) {
+ if (int error = loadMLIR(context, module))
+ return error;
+
+ mlir::PassManager pm(&context);
+ // Apply any generic pass manager command line options and run the pipeline.
+ applyPassManagerCLOptions(pm);
+
+ // Check to see what granularity of MLIR we are compiling to.
+ bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine;
+ bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM;
+
+ if (enableOpt || isLoweringToAffine) {
+ // Inline all functions into main and then delete them.
+ pm.addPass(mlir::createInlinerPass());
+ pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
+
+ // Now that there is only one function, we can infer the shapes of each of
+ // the operations.
+ mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
+ optPM.addPass(mlir::createCanonicalizerPass());
+ optPM.addPass(mlir::toy::createShapeInferencePass());
+ optPM.addPass(mlir::createCanonicalizerPass());
+ optPM.addPass(mlir::createCSEPass());
+ }
+
+ if (isLoweringToAffine) {
+ // Partially lower the toy dialect with a few cleanups afterwards.
+ pm.addPass(mlir::toy::createLowerToAffinePass());
+
+ mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
+ optPM.addPass(mlir::createCanonicalizerPass());
+ optPM.addPass(mlir::createCSEPass());
+
+ // Add optimizations if enabled.
+ if (enableOpt) {
+ optPM.addPass(mlir::createLoopFusionPass());
+ optPM.addPass(mlir::createMemRefDataFlowOptPass());
+ }
+ }
+
+ if (isLoweringToLLVM) {
+ // Finish lowering the toy IR to the LLVM dialect.
+ pm.addPass(mlir::toy::createLowerToLLVMPass());
+ }
+
+ if (mlir::failed(pm.run(*module)))
+ return 4;
+ return 0;
+}
+
+int dumpAST() {
+ if (inputType == InputType::MLIR) {
+ llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n";
+ return 5;
+ }
+
+ auto moduleAST = parseInputFile(inputFilename);
+ if (!moduleAST)
+ return 1;
+
+ dump(*moduleAST);
+ return 0;
+}
+
+int dumpLLVMIR(mlir::ModuleOp module) {
+ auto llvmModule = mlir::translateModuleToLLVMIR(module);
+ if (!llvmModule) {
+ llvm::errs() << "Failed to emit LLVM IR\n";
+ return -1;
+ }
+
+ // Initialize LLVM targets.
+ llvm::InitializeNativeTarget();
+ llvm::InitializeNativeTargetAsmPrinter();
+ mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());
+
+ /// Optionally run an optimization pipeline over the llvm module.
+ auto optPipeline = mlir::makeOptimizingTransformer(
+ /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0,
+ /*targetMachine=*/nullptr);
+ if (auto err = optPipeline(llvmModule.get())) {
+ llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
+ return -1;
+ }
+ llvm::errs() << *llvmModule << "\n";
+ return 0;
+}
+
+int runJit(mlir::ModuleOp module) {
+ // Initialize LLVM targets.
+ llvm::InitializeNativeTarget();
+ llvm::InitializeNativeTargetAsmPrinter();
+
+ // An optimization pipeline to use within the execution engine.
+ auto optPipeline = mlir::makeOptimizingTransformer(
+ /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0,
+ /*targetMachine=*/nullptr);
+
+ // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
+ // the module.
+ auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline);
+ assert(maybeEngine && "failed to construct an execution engine");
+ auto &engine = maybeEngine.get();
+
+ // Invoke the JIT-compiled function.
+ auto invocationResult = engine->invoke("main");
+ if (invocationResult) {
+ llvm::errs() << "JIT invocation failed\n";
+ return -1;
+ }
+
+ return 0;
+}
+
+int main(int argc, char **argv) {
+ mlir::registerPassManagerCLOptions();
+ cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
+
+ if (emitAction == Action::DumpAST)
+ return dumpAST();
+
+ // If we aren't dumping the AST, then we are compiling with/to MLIR.
+
+ // Register our Dialect with MLIR.
+ mlir::registerDialect<mlir::toy::ToyDialect>();
+
+ mlir::MLIRContext context;
+ mlir::OwningModuleRef module;
+ if (int error = loadAndProcessMLIR(context, module))
+ return error;
+
+ // If we aren't exporting to non-mlir, then we are done.
+ bool isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM;
+ if (isOutputingMLIR) {
+ module->dump();
+ return 0;
+ }
+
+ // Check to see if we are compiling to LLVM IR.
+ if (emitAction == Action::DumpLLVMIR)
+ return dumpLLVMIR(*module);
+
+ // Otherwise, we must be running the jit.
+ if (emitAction == Action::RunJIT)
+ return runJit(*module);
+
+ llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
+ return -1;
+}
diff --git a/mlir/g3doc/Tutorials/Toy/Ch-1.md b/mlir/g3doc/Tutorials/Toy/Ch-1.md
index fd08d4743c0..cec79a47866 100644
--- a/mlir/g3doc/Tutorials/Toy/Ch-1.md
+++ b/mlir/g3doc/Tutorials/Toy/Ch-1.md
@@ -25,6 +25,9 @@ This tutorial is divided in the following chapters:
- [Chapter #6](Ch-6.md): Lowering to LLVM and code generation. Here we'll
target LLVM IR for code generation, and detail more of the lowering
framework.
+- [Chapter #7](Ch-7.md): Extending Toy: Adding support for a composite type.
+ We'll demonstrate how to add a custom type to MLIR, and how it fits in the
+ existing pipeline.
## The Language
@@ -87,7 +90,7 @@ def main() {
# Finally, calling into `multiply_transpose` with incompatible shape will
# trigger a shape inference error.
- var e = multiply_transpose(transpose(a), c);
+ var f = multiply_transpose(transpose(a), c);
}
```
diff --git a/mlir/g3doc/Tutorials/Toy/Ch-7.md b/mlir/g3doc/Tutorials/Toy/Ch-7.md
new file mode 100644
index 00000000000..3b9896f191d
--- /dev/null
+++ b/mlir/g3doc/Tutorials/Toy/Ch-7.md
@@ -0,0 +1,538 @@
+# Chapter 7: Adding a Composite Type to Toy
+
+[TOC]
+
+In the [previous chapter](Ch-6.md), we demonstrated an end-to-end compilation
+flow from our Toy front-end to LLVM IR. In this chapter, we will extend the Toy
+language to support a new composite `struct` type.
+
+## Defining a `struct` in Toy
+
+The first thing we need to define is the interface of this type in our `toy`
+source language. The general syntax of a `struct` type in toy is as follows:
+
+```toy
+# A struct is defined by using the `struct` keyword followed by a name.
+struct MyStruct {
+ # Inside of the struct is a list of variable declarations without initializers
+ # or shapes, which may also be other previously defined structs.
+ var a;
+ var b;
+}
+```
+
+Structs may now be used in functions as variables or parameters by using the
+name of the struct instead of `var`. The members of the struct are accessed via
+a `.` access operator. Values of `struct` type may be initialized with a
+composite initializer, or a comma-separated list of other initializers
+surrounded by `{}`. An example is shown below:
+
+```toy
+struct Struct {
+ var a;
+ var b;
+}
+
+# User defined generic function may operate on struct types as well.
+def multiply_transpose(Struct value) {
+ # We can access the elements of a struct via the '.' operator.
+ return transpose(value.a) * transpose(value.b);
+}
+
+def main() {
+ # We initialize struct values using a composite initializer.
+ Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]};
+
+ # We pass these arguments to functions like we do with variables.
+ var c = multiply_transpose(value);
+ print(c);
+}
+```
+
+## Defining a `struct` in MLIR
+
+In MLIR, we will also need a representation for our struct types. MLIR does not
+provide a type that does exactly what we need, so we will need to define our
+own. We will simply define our `struct` as an unnamed container of a set of
+element types. The name of the `struct` and its elements are only useful for the
+AST of our `toy` compiler, so we don't need to encode it in the MLIR
+representation.
+
+### Defining the Type Class
+
+#### Reserving a Range of Type Kinds
+
+Types in MLIR rely on having a unique `kind` value to ensure that casting checks
+remain extremely efficient
+([rationale](Rationale.md#reserving-dialect-type-kinds)). For `toy`,
+this means we need to explicitly reserve a static range of type `kind` values in
+the symbol registry file
+[DialectSymbolRegistry](https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/DialectSymbolRegistry.def).
+
+```c++
+DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect
+DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect
+
+// The following ranges are reserved for experimenting with MLIR dialects in a
+// private context without having to register them here.
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_0)
+```
+
+These definitions will provide a range in the Type::Kind enum to use when
+defining the derived types.
+
+```c++
+/// Create a local enumeration with all of the types that are defined by Toy.
+namespace ToyTypes {
+enum Types {
+ Struct = mlir::Type::FIRST_TOY_TYPE,
+};
+} // end namespace ToyTypes
+```
+
+#### Defining the Type Class
+
+As mentioned in [chapter 2](Ch-2.md), [`Type`](../../LangRef.md#type-system)
+objects in MLIR are value-typed and rely on having an internal storage object
+that holds the actual data for the type. The `Type` class in itself acts as a
+simple wrapper around an internal `TypeStorage` object that is uniqued within an
+instance of an `MLIRContext`. When constructing a `Type`, we are internally just
+constructing and uniquing an instance of a storage class.
+
+When defining a new `Type` that requires additional information beyond just the
+`kind`, like our struct type for the element types, we will need to provide a
+derived storage class. The `primitive` types that don't have any additional
+data, like the [`index` type](../../LangRef.md#index-type), don't require a
+storage class.
+
+##### Defining the Storage Class
+
+Type storage objects contain all of the data necessary to construct and unique a
+type instance. Derived storage classes must inherit from the base
+`mlir::TypeStorage` and provide a set of aliases and hooks that will be used by
+the `MLIRContext` for uniquing. Below is the definition of the storage instance
+for our `struct` type, with each of the necessary requirements detailed inline:
+
+```c++
+/// This class represents the internal storage of the Toy `StructType`.
+struct StructTypeStorage : public mlir::TypeStorage {
+ /// The `KeyTy` is a required type that provides an interface for the storage
+ /// instance. This type will be used when uniquing an instance of the type
+ /// storage. For our struct type, we will unique each instance structurally on
+ /// the elements that it contains.
+ using KeyTy = llvm::ArrayRef<mlir::Type>;
+
+ /// A constructor for the type storage instance.
+ StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
+ : elementTypes(elementTypes) {}
+
+ /// Define the comparison function for the key type with the current storage
+ /// instance. This is used when constructing a new instance to ensure that we
+ /// haven't already uniqued an instance of the given key.
+ bool operator==(const KeyTy &key) const { return key == elementTypes; }
+
+ /// Define a hash function for the key type. This is used when uniquing
+ /// instances of the storage.
+ /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type
+ /// have hash functions available, so we could just omit this entirely.
+ static llvm::hash_code hashKey(const KeyTy &key) {
+ return llvm::hash_value(key);
+ }
+
+ /// Define a construction function for the key type from a set of parameters.
+ /// These parameters will be provided when constructing the storage instance
+ /// itself, see the `StructType::get` method further below.
+ /// Note: This method isn't necessary because KeyTy can be directly
+ /// constructed with the given parameters.
+ static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
+ return KeyTy(elementTypes);
+ }
+
+ /// Define a construction method for creating a new instance of this storage.
+ /// This method takes an instance of a storage allocator, and an instance of a
+ /// `KeyTy`. The given allocator must be used for *all* necessary dynamic
+ /// allocations used to create the type storage and its internals.
+ static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
+ const KeyTy &key) {
+ // Copy the elements from the provided `KeyTy` into the allocator.
+ llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);
+
+ // Allocate the storage instance and construct it.
+ return new (allocator.allocate<StructTypeStorage>())
+ StructTypeStorage(elementTypes);
+ }
+
+ /// The following field contains the element types of the struct.
+ llvm::ArrayRef<mlir::Type> elementTypes;
+};
+```
+
+##### Defining the Type Class
+
+With the storage class defined, we can add the definition for the user visible
+`StructType` class. This is the class that we will actually interface with.
+
+```c++
+/// This class defines the Toy struct type. It represents a collection of
+/// element types. All derived types in MLIR must inherit from the CRTP class
+/// 'Type::TypeBase'. It takes as template parameters the concrete type
+/// (StructType), the base class to use (Type), and the storage class
+/// (StructTypeStorage).
+class StructType : public mlir::Type::TypeBase<StructType, mlir::Type,
+ StructTypeStorage> {
+public:
+ /// Inherit some necessary constructors from 'TypeBase'.
+ using Base::Base;
+
+ /// This static method is used to support type inquiry through isa, cast,
+ /// and dyn_cast.
+ static bool kindof(unsigned kind) { return kind == ToyTypes::Struct; }
+
+ /// Create an instance of a `StructType` with the given element types. There
+ /// *must* be at least one element type.
+ static StructType get(llvm::ArrayRef<mlir::Type> elementTypes) {
+ assert(!elementTypes.empty() && "expected at least 1 element type");
+
+ // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
+ // of this type. The first two parameters are the context to unique in and
+ // the kind of the type. The parameters after the type kind are forwarded to
+ // the storage instance.
+ mlir::MLIRContext *ctx = elementTypes.front().getContext();
+ return Base::get(ctx, ToyTypes::Struct, elementTypes);
+ }
+
+ /// Returns the element types of this struct type.
+ llvm::ArrayRef<mlir::Type> getElementTypes() {
+ // 'getImpl' returns a pointer to the internal storage instance.
+ return getImpl()->elementTypes;
+ }
+
+ /// Returns the number of element types held by this struct.
+ size_t getNumElementTypes() { return getElementTypes().size(); }
+};
+```
+
+and we register this type in the `ToyDialect` constructor in a similar way to
+how we did with operations:
+
+```c++
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+ : mlir::Dialect(getDialectNamespace(), ctx) {
+ addTypes<StructType>();
+}
+```
+
+With this we can now use our `StructType` when generating MLIR from Toy. See
+MLIRGen.cpp for more details.
+
+### Parsing and Printing
+
+At this point we can use our `StructType` during MLIR generation and
+transformation, but we can't output or parse `.mlir`. To support this we need to
+add support for parsing and printing instances of the `StructType`. This support
+can be added by overriding the `parseType` and `printType` methods on the
+`ToyDialect`.
+
+```c++
+class ToyDialect : public mlir::Dialect {
+public:
+ /// Parse an instance of a type registered to the toy dialect.
+ mlir::Type parseType(mlir::DialectAsmParser &parser) const override;
+
+ /// Print an instance of a type registered to the toy dialect.
+ void printType(mlir::Type type,
+ mlir::DialectAsmPrinter &printer) const override;
+};
+```
+
+These methods take an instance of a high level parser or printer that allows for
+easily implementing the necessary functionality. Before going into the
+implementation, let's think about the syntax that we want for the `struct` type
+in the printed IR. As described in the
+[MLIR language reference](../../LangRef.md#dialect-types), dialect types are
+generally represented as: `!` dialect-namespace `<` type-data `>`, with a pretty
+form available under certain circumstances. The responsibility of our `Toy`
+parser and printer is to provide the `type-data` bits. We will define our
+`StructType` as having the following form:
+
+``` {.ebnf}
+ struct-type ::= `struct` `<` type (`,` type)* `>`
+```
+
+#### Parsing
+
+An implementation of the parser is shown below:
+
+```c++
+/// Parse an instance of a type registered to the toy dialect.
+mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
+ // Parse a struct type in the following form:
+ // struct-type ::= `struct` `<` type (`,` type)* `>`
+
+ // NOTE: All MLIR parser functions return a ParseResult. This is a
+ // specialization of LogicalResult that auto-converts to a `true` boolean
+ // value on failure to allow for chaining, but may be used with explicit
+ // `mlir::failed/mlir::succeeded` as desired.
+
+ // Parse: `struct` `<`
+ if (parser.parseKeyword("struct") || parser.parseLess())
+ return Type();
+
+ // Parse the element types of the struct.
+ SmallVector<mlir::Type, 1> elementTypes;
+ do {
+ // Parse the current element type.
+ llvm::SMLoc typeLoc = parser.getCurrentLocation();
+ mlir::Type elementType;
+ if (parser.parseType(elementType))
+ return nullptr;
+
+ // Check that the type is either a TensorType or another StructType.
+ if (!elementType.isa<mlir::TensorType>() &&
+ !elementType.isa<StructType>()) {
+ parser.emitError(typeLoc, "element type for a struct must either "
+ "be a TensorType or a StructType, got: ")
+ << elementType;
+ return Type();
+ }
+ elementTypes.push_back(elementType);
+
+ // Parse the optional: `,`
+ } while (succeeded(parser.parseOptionalComma()));
+
+ // Parse: `>`
+ if (parser.parseGreater())
+ return Type();
+ return StructType::get(elementTypes);
+}
+```
+
+#### Printing
+
+An implementation of the printer is shown below:
+
+```c++
+/// Print an instance of a type registered to the toy dialect.
+void ToyDialect::printType(mlir::Type type,
+ mlir::DialectAsmPrinter &printer) const {
+ // Currently the only toy type is a struct type.
+ StructType structType = type.cast<StructType>();
+
+ // Print the struct type according to the parser format.
+ printer << "struct<";
+ mlir::interleaveComma(structType.getElementTypes(), printer);
+ printer << '>';
+}
+```
+
+Before moving on, let's look at a quick example showcasing the functionality
+we have now:
+
+```toy
+struct Struct {
+ var a;
+ var b;
+}
+
+def multiply_transpose(Struct value) {
+}
+```
+
+Which generates the following:
+
+```mlir
+module {
+ func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) {
+ "toy.return"() : () -> ()
+ }
+}
+```
+
+### Operating on `StructType`
+
+Now that the `struct` type has been defined, we can roundtrip it through the
+IR. The next step is to add support for using it within our operations.
+
+#### Updating Existing Operations
+
+A few of our existing operations will need to be updated to handle `StructType`.
+The first step is to make the ODS framework aware of our Type, so that we can
+use it in the operation definitions. A simple example is shown below:
+
+```td
+// Provide a definition for the Toy StructType for use in ODS. This allows for
+// using StructType in a similar way to Tensor or MemRef.
+def Toy_StructType :
+ Type<CPred<"$_self.isa<StructType>()">, "Toy struct type">;
+
+// Provide a definition of the types that are used within the Toy dialect.
+def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>;
+```
+
+We can then update our operations, like `ReturnOp` for example, to also accept
+the `Toy_StructType`:
+
+```td
+def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
+ ...
+ let arguments = (ins Variadic<Toy_Type>:$input);
+ ...
+}
+```
+
+#### Adding New `Toy` Operations
+
+In addition to the existing operations, we will be adding a few new operations
+that will provide more specific handling of `structs`.
+
+##### `toy.struct_constant`
+
+This new operation materializes a constant value for a struct. In our current
+modeling we just use an [array attribute](../../LangRef.md#array-attribute) that
+contains a set of constant values for each of the `struct` elements.
+
+```mlir
+ %0 = "toy.struct_constant"() {
+ value = [dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>]
+ } : () -> !toy.struct<tensor<*xf64>>
+```
+
+##### `toy.struct_access`
+
+This new operation materializes the Nth element of a `struct` value.
+
+```mlir
+ %0 = "toy.struct_constant"() {
+ value = [dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>]
+ } : () -> !toy.struct<tensor<*xf64>>
+ %1 = "toy.struct_access"(%0) {index = 0 : i64} : (!toy.struct<tensor<*xf64>>) -> tensor<*xf64>
+```
+
+With these operations, we can revisit our original example:
+
+```toy
+struct Struct {
+ var a;
+ var b;
+}
+
+# User defined generic function may operate on struct types as well.
+def multiply_transpose(Struct value) {
+ # We can access the elements of a struct via the '.' operator.
+ return transpose(value.a) * transpose(value.b);
+}
+
+def main() {
+ # We initialize struct values using a composite initializer.
+ Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]};
+
+ # We pass these arguments to functions like we do with variables.
+ var c = multiply_transpose(value);
+ print(c);
+}
+```
+
+and finally get a full MLIR module:
+
+```mlir
+module {
+ func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> {
+ %0 = "toy.struct_access"(%arg0) {index = 0 : i64} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
+ %1 = "toy.transpose"(%0) : (tensor<*xf64>) -> tensor<*xf64>
+ %2 = "toy.struct_access"(%arg0) {index = 1 : i64} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
+ %3 = "toy.transpose"(%2) : (tensor<*xf64>) -> tensor<*xf64>
+ %4 = "toy.mul"(%1, %3) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
+ "toy.return"(%4) : (tensor<*xf64>) -> ()
+ }
+ func @main() {
+ %0 = "toy.struct_constant"() {value = [dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>, dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>]} : () -> !toy.struct<tensor<*xf64>, tensor<*xf64>>
+ %1 = "toy.generic_call"(%0) {callee = @multiply_transpose} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
+ "toy.print"(%1) : (tensor<*xf64>) -> ()
+ "toy.return"() : () -> ()
+ }
+}
+```
+
+#### Optimizing Operations on `StructType`
+
+Now that we have a few operations operating on `StructType`, we also have many
+new constant folding opportunities. After inlining, the MLIR module from the
+previous section looks something like:
+
+```mlir
+module {
+ func @main() {
+ %0 = "toy.struct_constant"() {value = [dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>, dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>]} : () -> !toy.struct<tensor<*xf64>, tensor<*xf64>>
+ %1 = "toy.struct_access"(%0) {index = 0 : i64} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
+ %2 = "toy.transpose"(%1) : (tensor<*xf64>) -> tensor<*xf64>
+ %3 = "toy.struct_access"(%0) {index = 1 : i64} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
+ %4 = "toy.transpose"(%3) : (tensor<*xf64>) -> tensor<*xf64>
+ %5 = "toy.mul"(%2, %4) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
+ "toy.print"(%5) : (tensor<*xf64>) -> ()
+ "toy.return"() : () -> ()
+ }
+}
+```
+
+We have several `toy.struct_access` operations that access into a
+`toy.struct_constant`. As detailed in [chapter 3](Ch-3.md), we can add folders
+for these `toy` operations by setting the `hasFolder` bit on the operation
+definition and providing a definition of the `*Op::fold` method.
+
+```c++
+/// Fold constants.
+OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }
+
+/// Fold struct constants.
+OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
+ return value();
+}
+
+/// Fold simple struct access operations that access into a constant.
+OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
+ auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>();
+ if (!structAttr)
+ return nullptr;
+
+ size_t elementIndex = index().getZExtValue();
+ return structAttr.getValue()[elementIndex];
+}
+```
+
+To ensure that MLIR generates the proper constant operations when folding our
+`Toy` operations, i.e. `ConstantOp` for `TensorType` and `StructConstantOp` for
+`StructType`, we will need to provide an override for the dialect hook
+`materializeConstant`. This allows for generic MLIR operations to create
+constants for the `Toy` dialect when necessary.
+
+```c++
+mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
+ mlir::Attribute value,
+ mlir::Type type,
+ mlir::Location loc) {
+ if (type.isa<StructType>())
+ return builder.create<StructConstantOp>(loc, type,
+ value.cast<mlir::ArrayAttr>());
+ return builder.create<ConstantOp>(loc, type,
+ value.cast<mlir::DenseElementsAttr>());
+}
+```
+
+With this, we can now generate code that can be lowered to LLVM without any
+changes to our pipeline.
+
+```mlir
+module {
+ func @main() {
+ %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+ %1 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
+ %2 = "toy.mul"(%1, %1) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
+ "toy.print"(%2) : (tensor<3x2xf64>) -> ()
+ "toy.return"() : () -> ()
+ }
+}
+```
+
+You can build `toyc-ch7` and try it yourself: `toyc-ch7 test/struct-codegen.toy
+-emit=mlir`. More details on defining custom types can be found in
+[DefiningAttributesAndTypes](../../DefiningAttributesAndTypes.md).
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 9c11adc95db..95792548221 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -50,6 +50,7 @@ if(LLVM_BUILD_EXAMPLES)
toyc-ch4
toyc-ch5
toyc-ch6
+ toyc-ch7
)
endif()
diff --git a/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir
new file mode 100644
index 00000000000..3d08d0c1d80
--- /dev/null
+++ b/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir
@@ -0,0 +1,65 @@
+// RUN: toyc-ch7 %s -emit=mlir-affine 2>&1 | FileCheck %s
+// RUN: toyc-ch7 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT
+
+func @main() {
+ %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+ %2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
+ %3 = "toy.mul"(%2, %2) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
+ "toy.print"(%3) : (tensor<3x2xf64>) -> ()
+ "toy.return"() : () -> ()
+}
+
+// CHECK-LABEL: func @main()
+// CHECK: [[VAL_0:%.*]] = constant 1.000000e+00 : f64
+// CHECK: [[VAL_1:%.*]] = constant 2.000000e+00 : f64
+// CHECK: [[VAL_2:%.*]] = constant 3.000000e+00 : f64
+// CHECK: [[VAL_3:%.*]] = constant 4.000000e+00 : f64
+// CHECK: [[VAL_4:%.*]] = constant 5.000000e+00 : f64
+// CHECK: [[VAL_5:%.*]] = constant 6.000000e+00 : f64
+// CHECK: [[VAL_6:%.*]] = alloc() : memref<3x2xf64>
+// CHECK: [[VAL_7:%.*]] = alloc() : memref<3x2xf64>
+// CHECK: [[VAL_8:%.*]] = alloc() : memref<2x3xf64>
+// CHECK: affine.store [[VAL_0]], [[VAL_8]][0, 0] : memref<2x3xf64>
+// CHECK: affine.store [[VAL_1]], [[VAL_8]][0, 1] : memref<2x3xf64>
+// CHECK: affine.store [[VAL_2]], [[VAL_8]][0, 2] : memref<2x3xf64>
+// CHECK: affine.store [[VAL_3]], [[VAL_8]][1, 0] : memref<2x3xf64>
+// CHECK: affine.store [[VAL_4]], [[VAL_8]][1, 1] : memref<2x3xf64>
+// CHECK: affine.store [[VAL_5]], [[VAL_8]][1, 2] : memref<2x3xf64>
+// CHECK: affine.for [[VAL_9:%.*]] = 0 to 3 {
+// CHECK: affine.for [[VAL_10:%.*]] = 0 to 2 {
+// CHECK: [[VAL_11:%.*]] = affine.load [[VAL_8]]{{\[}}[[VAL_10]], [[VAL_9]]] : memref<2x3xf64>
+// CHECK: affine.store [[VAL_11]], [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_10]]] : memref<3x2xf64>
+// CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 {
+// CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 {
+// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
+// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
+// CHECK: [[VAL_16:%.*]] = mulf [[VAL_14]], [[VAL_15]] : f64
+// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
+// CHECK: "toy.print"([[VAL_6]]) : (memref<3x2xf64>) -> ()
+// CHECK: dealloc [[VAL_8]] : memref<2x3xf64>
+// CHECK: dealloc [[VAL_7]] : memref<3x2xf64>
+// CHECK: dealloc [[VAL_6]] : memref<3x2xf64>
+
+// OPT-LABEL: func @main()
+// OPT: [[VAL_0:%.*]] = constant 1.000000e+00 : f64
+// OPT: [[VAL_1:%.*]] = constant 2.000000e+00 : f64
+// OPT: [[VAL_2:%.*]] = constant 3.000000e+00 : f64
+// OPT: [[VAL_3:%.*]] = constant 4.000000e+00 : f64
+// OPT: [[VAL_4:%.*]] = constant 5.000000e+00 : f64
+// OPT: [[VAL_5:%.*]] = constant 6.000000e+00 : f64
+// OPT: [[VAL_6:%.*]] = alloc() : memref<3x2xf64>
+// OPT: [[VAL_7:%.*]] = alloc() : memref<2x3xf64>
+// OPT: affine.store [[VAL_0]], [[VAL_7]][0, 0] : memref<2x3xf64>
+// OPT: affine.store [[VAL_1]], [[VAL_7]][0, 1] : memref<2x3xf64>
+// OPT: affine.store [[VAL_2]], [[VAL_7]][0, 2] : memref<2x3xf64>
+// OPT: affine.store [[VAL_3]], [[VAL_7]][1, 0] : memref<2x3xf64>
+// OPT: affine.store [[VAL_4]], [[VAL_7]][1, 1] : memref<2x3xf64>
+// OPT: affine.store [[VAL_5]], [[VAL_7]][1, 2] : memref<2x3xf64>
+// OPT: affine.for [[VAL_8:%.*]] = 0 to 3 {
+// OPT: affine.for [[VAL_9:%.*]] = 0 to 2 {
+// OPT: [[VAL_10:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_8]]] : memref<2x3xf64>
+// OPT: [[VAL_11:%.*]] = mulf [[VAL_10]], [[VAL_10]] : f64
+// OPT: affine.store [[VAL_11]], [[VAL_6]]{{\[}}[[VAL_8]], [[VAL_9]]] : memref<3x2xf64>
+// OPT: "toy.print"([[VAL_6]]) : (memref<3x2xf64>) -> ()
+// OPT: dealloc [[VAL_7]] : memref<2x3xf64>
+// OPT: dealloc [[VAL_6]] : memref<3x2xf64>
diff --git a/mlir/test/Examples/Toy/Ch7/ast.toy b/mlir/test/Examples/Toy/Ch7/ast.toy
new file mode 100644
index 00000000000..b32410e58e7
--- /dev/null
+++ b/mlir/test/Examples/Toy/Ch7/ast.toy
@@ -0,0 +1,76 @@
+# RUN: toyc-ch7 %s -emit=ast 2>&1 | FileCheck %s
+
+# User defined generic function that operates on unknown shaped arguments.
+def multiply_transpose(a, b) {
+ return transpose(a) * transpose(b);
+}
+
+def main() {
+ # Define a variable `a` with shape <2, 3>, initialized with the literal value.
+ # The shape is inferred from the supplied literal.
+ var a = [[1, 2, 3], [4, 5, 6]];
+ # b is identical to a, the literal array is implicitly reshaped: defining new
+ # variables is the way to reshape arrays (element count in literal must match
+ # the size of specified shape).
+ var b<2, 3> = [1, 2, 3, 4, 5, 6];
+
+ # This call will specialize `multiply_transpose` with <2, 3> for both
+ # arguments and deduce a return type of <2, 2> in initialization of `c`.
+ var c = multiply_transpose(a, b);
+ # A second call to `multiply_transpose` with <2, 3> for both arguments will
+ # reuse the previously specialized and inferred version and return `<2, 2>`
+ var d = multiply_transpose(b, a);
+  # A new call with `<2, 2>` for both dimensions will trigger another
+  # specialization of `multiply_transpose`.
+ var e = multiply_transpose(b, c);
+ # Finally, calling into `multiply_transpose` with incompatible shape will
+ # trigger a shape inference error.
+ var e = multiply_transpose(transpose(a), c);
+}
+
+
+# CHECK: Module:
+# CHECK-NEXT: Function
+# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1'
+# CHECK-NEXT: Params: [a, b]
+# CHECK-NEXT: Block {
+# CHECK-NEXT: Return
+# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:25
+# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:10
+# CHECK-NEXT: var: a @{{.*}}ast.toy:5:20
+# CHECK-NEXT: ]
+# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:25
+# CHECK-NEXT: var: b @{{.*}}ast.toy:5:35
+# CHECK-NEXT: ]
+# CHECK-NEXT: } // Block
+# CHECK-NEXT: Function
+# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1'
+# CHECK-NEXT: Params: []
+# CHECK-NEXT: Block {
+# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3
+# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11
+# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3
+# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17
+# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3
+# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11
+# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30
+# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33
+# CHECK-NEXT: ]
+# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3
+# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11
+# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30
+# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33
+# CHECK-NEXT: ]
+# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3
+# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11
+# CHECK-NEXT: var: b @{{.*}}ast.toy:25:30
+# CHECK-NEXT: var: c @{{.*}}ast.toy:25:33
+# CHECK-NEXT: ]
+# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:28:3
+# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11
+# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:28:30
+# CHECK-NEXT: var: a @{{.*}}ast.toy:28:40
+# CHECK-NEXT: ]
+# CHECK-NEXT: var: c @{{.*}}ast.toy:28:44
+# CHECK-NEXT: ]
+
diff --git a/mlir/test/Examples/Toy/Ch7/codegen.toy b/mlir/test/Examples/Toy/Ch7/codegen.toy
new file mode 100644
index 00000000000..e19500bd9ae
--- /dev/null
+++ b/mlir/test/Examples/Toy/Ch7/codegen.toy
@@ -0,0 +1,31 @@
+# RUN: toyc-ch7 %s -emit=mlir 2>&1 | FileCheck %s
+
+# User defined generic function that operates on unknown shaped arguments
+def multiply_transpose(a, b) {
+ return transpose(a) * transpose(b);
+}
+
+def main() {
+ var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+ var b<2, 3> = [1, 2, 3, 4, 5, 6];
+ var c = multiply_transpose(a, b);
+ var d = multiply_transpose(b, a);
+ print(d);
+}
+
+# CHECK-LABEL: func @multiply_transpose(
+# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
+# CHECK: [[VAL_2:%.*]] = "toy.transpose"([[VAL_0]]) : (tensor<*xf64>) -> tensor<*xf64>
+# CHECK-NEXT: [[VAL_3:%.*]] = "toy.transpose"([[VAL_1]]) : (tensor<*xf64>) -> tensor<*xf64>
+# CHECK-NEXT: [[VAL_4:%.*]] = "toy.mul"([[VAL_2]], [[VAL_3]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
+# CHECK-NEXT: "toy.return"([[VAL_4]]) : (tensor<*xf64>) -> ()
+
+# CHECK-LABEL: func @main()
+# CHECK-NEXT: [[VAL_5:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+# CHECK-NEXT: [[VAL_6:%.*]] = "toy.reshape"([[VAL_5]]) : (tensor<2x3xf64>) -> tensor<2x3xf64>
+# CHECK-NEXT: [[VAL_7:%.*]] = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
+# CHECK-NEXT: [[VAL_8:%.*]] = "toy.reshape"([[VAL_7]]) : (tensor<6xf64>) -> tensor<2x3xf64>
+# CHECK-NEXT: [[VAL_9:%.*]] = "toy.generic_call"([[VAL_6]], [[VAL_8]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
+# CHECK-NEXT: [[VAL_10:%.*]] = "toy.generic_call"([[VAL_8]], [[VAL_6]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
+# CHECK-NEXT: "toy.print"([[VAL_10]]) : (tensor<*xf64>) -> ()
+# CHECK-NEXT: "toy.return"() : () -> ()
diff --git a/mlir/test/Examples/Toy/Ch7/invalid.mlir b/mlir/test/Examples/Toy/Ch7/invalid.mlir
new file mode 100644
index 00000000000..5d35d95a5bf
--- /dev/null
+++ b/mlir/test/Examples/Toy/Ch7/invalid.mlir
@@ -0,0 +1,9 @@
+// RUN: not toyc-ch7 %s -emit=mlir 2>&1
+
+// The following IR is not "valid":
+// - toy.print should not return a value.
+// - toy.print should take an argument.
+// - There should be a block terminator.
+func @main() {
+ %0 = "toy.print"() : () -> tensor<2x3xf64>
+}
diff --git a/mlir/test/Examples/Toy/Ch7/llvm-lowering.mlir b/mlir/test/Examples/Toy/Ch7/llvm-lowering.mlir
new file mode 100644
index 00000000000..0009bb507eb
--- /dev/null
+++ b/mlir/test/Examples/Toy/Ch7/llvm-lowering.mlir
@@ -0,0 +1,23 @@
+// RUN: toyc-ch7 %s -emit=llvm -opt
+
+func @main() {
+ %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+ %2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
+ %3 = "toy.mul"(%2, %2) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
+ "toy.print"(%3) : (tensor<3x2xf64>) -> ()
+ "toy.return"() : () -> ()
+}
+
+// CHECK-LABEL: define void @main()
+// CHECK: @printf
+// CHECK-SAME: 1.000000e+00
+// CHECK: @printf
+// CHECK-SAME: 1.600000e+01
+// CHECK: @printf
+// CHECK-SAME: 4.000000e+00
+// CHECK: @printf
+// CHECK-SAME: 2.500000e+01
+// CHECK: @printf
+// CHECK-SAME: 9.000000e+00
+// CHECK: @printf
+// CHECK-SAME: 3.000000e+01
diff --git a/mlir/test/Examples/Toy/Ch7/scalar.toy b/mlir/test/Examples/Toy/Ch7/scalar.toy
new file mode 100644
index 00000000000..f917ea622e5
--- /dev/null
+++ b/mlir/test/Examples/Toy/Ch7/scalar.toy
@@ -0,0 +1,14 @@
+# RUN: toyc-ch7 %s -emit=mlir 2>&1 | FileCheck %s
+
+def main() {
+ var a<2, 2> = 5.5;
+ print(a);
+}
+
+# CHECK-LABEL: func @main() {
+# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<f64>} : () -> tensor<f64>
+# CHECK-NEXT: %1 = "toy.reshape"(%0) : (tensor<f64>) -> tensor<2x2xf64>
+# CHECK-NEXT: "toy.print"(%1) : (tensor<2x2xf64>) -> ()
+# CHECK-NEXT: "toy.return"() : () -> ()
+# CHECK-NEXT: }
+
diff --git a/mlir/test/Examples/Toy/Ch7/shape_inference.mlir b/mlir/test/Examples/Toy/Ch7/shape_inference.mlir
new file mode 100644
index 00000000000..b9355cf7e25
--- /dev/null
+++ b/mlir/test/Examples/Toy/Ch7/shape_inference.mlir
@@ -0,0 +1,30 @@
+// RUN: toyc-ch7 %s -emit=mlir -opt 2>&1 | FileCheck %s
+
+// Check the result of inlining+shape inference on an input module.
+
+func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
+ %0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64>
+ %1 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64>
+ %2 = "toy.mul"(%0, %1) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
+ "toy.return"(%2) : (tensor<*xf64>) -> ()
+}
+func @main() {
+ %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+ %1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>
+ %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
+ %3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64>
+ %4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
+ %5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
+ "toy.print"(%5) : (tensor<*xf64>) -> ()
+ "toy.return"() : () -> ()
+}
+
+// CHECK-NOT: func @multiply_transpose
+// CHECK-NOT: tensor<*xf64>
+
+// CHECK-LABEL: func @main()
+// CHECK: [[VAL_0:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+// CHECK: [[VAL_1:%.*]] = "toy.transpose"([[VAL_0]]) : (tensor<2x3xf64>) -> tensor<3x2xf64>
+// CHECK: [[VAL_2:%.*]] = "toy.mul"([[VAL_1]], [[VAL_1]]) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
+// CHECK: "toy.print"([[VAL_2]]) : (tensor<3x2xf64>) -> ()
+// CHECK: "toy.return"() : () -> ()
diff --git a/mlir/test/Examples/Toy/Ch7/struct-ast.toy b/mlir/test/Examples/Toy/Ch7/struct-ast.toy
new file mode 100644
index 00000000000..dee0d5b0efd
--- /dev/null
+++ b/mlir/test/Examples/Toy/Ch7/struct-ast.toy
@@ -0,0 +1,61 @@
+# RUN: toyc-ch7 %s -emit=ast 2>&1 | FileCheck %s
+
+struct Struct {
+ var a;
+ var b;
+}
+
+# User defined generic function may operate on struct types as well.
+def multiply_transpose(Struct value) {
+ # We can access the elements of a struct via the '.' operator.
+ return transpose(value.a) * transpose(value.b);
+}
+
+def main() {
+ # We initialize struct values using a composite initializer.
+ Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]};
+
+ # We pass these arguments to functions like we do with variables.
+ var c = multiply_transpose(value);
+ print(c);
+}
+
+# CHECK: Module:
+# CHECK-NEXT: Struct: Struct @{{.*}}struct-ast.toy:3:1
+# CHECK-NEXT: Variables: [
+# CHECK-NEXT: VarDecl a<> @{{.*}}struct-ast.toy:4:3
+# CHECK-NEXT: VarDecl b<> @{{.*}}struct-ast.toy:5:3
+# CHECK-NEXT: ]
+# CHECK-NEXT:Function
+# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}struct-ast.toy:9:1'
+# CHECK-NEXT: Params: [value]
+# CHECK-NEXT: Block {
+# CHECK-NEXT: Return
+# CHECK-NEXT: BinOp: * @{{.*}}struct-ast.toy:11:31
+# CHECK-NEXT: Call 'transpose' [ @{{.*}}struct-ast.toy:11:10
+# CHECK-NEXT: BinOp: . @{{.*}}struct-ast.toy:11:26
+# CHECK-NEXT: var: value @{{.*}}struct-ast.toy:11:20
+# CHECK-NEXT: var: a @{{.*}}struct-ast.toy:11:26
+# CHECK-NEXT: ]
+# CHECK-NEXT: Call 'transpose' [ @{{.*}}struct-ast.toy:11:31
+# CHECK-NEXT: BinOp: . @{{.*}}struct-ast.toy:11:47
+# CHECK-NEXT: var: value @{{.*}}struct-ast.toy:11:41
+# CHECK-NEXT: var: b @{{.*}}struct-ast.toy:11:47
+# CHECK-NEXT: ]
+# CHECK-NEXT: }
+# CHECK-NEXT:Function
+# CHECK-NEXT: Proto 'main' @{{.*}}struct-ast.toy:14:1'
+# CHECK-NEXT: Params: []
+# CHECK-NEXT: Block {
+# CHECK-NEXT: VarDecl value<Struct> @{{.*}}struct-ast.toy:16:3
+# CHECK-NEXT: Struct Literal: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}struct-ast.toy:16:19
+# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}struct-ast.toy:16:43
+# CHECK-NEXT: @{{.*}}struct-ast.toy:16:18
+# CHECK-NEXT: VarDecl c<> @{{.*}}struct-ast.toy:19:3
+# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}struct-ast.toy:19:11
+# CHECK-NEXT: var: value @{{.*}}struct-ast.toy:19:30
+# CHECK-NEXT: ]
+# CHECK-NEXT: Print [ @{{.*}}struct-ast.toy:20:3
+# CHECK-NEXT: var: c @{{.*}}struct-ast.toy:20:9
+# CHECK-NEXT: ]
+# CHECK-NEXT: } \ No newline at end of file
diff --git a/mlir/test/Examples/Toy/Ch7/struct-codegen.toy b/mlir/test/Examples/Toy/Ch7/struct-codegen.toy
new file mode 100644
index 00000000000..66eaf8a1639
--- /dev/null
+++ b/mlir/test/Examples/Toy/Ch7/struct-codegen.toy
@@ -0,0 +1,44 @@
+# RUN: toyc-ch7 %s -emit=mlir 2>&1
+# RUN: toyc-ch7 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT
+
+struct Struct {
+ var a;
+ var b;
+}
+
+# User defined generic function may operate on struct types as well.
+def multiply_transpose(Struct value) {
+ # We can access the elements of a struct via the '.' operator.
+ return transpose(value.a) * transpose(value.b);
+}
+
+def main() {
+ # We initialize struct values using a composite initializer.
+ Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]};
+
+ # We pass these arguments to functions like we do with variables.
+ var c = multiply_transpose(value);
+ print(c);
+}
+
+# CHECK-LABEL: func @multiply_transpose(
+# CHECK-SAME: [[VAL_0:%.*]]: !toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
+# CHECK-NEXT: [[VAL_1:%.*]] = "toy.struct_access"([[VAL_0]]) {index = 0 : i64} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
+# CHECK-NEXT: [[VAL_2:%.*]] = "toy.transpose"([[VAL_1]]) : (tensor<*xf64>) -> tensor<*xf64>
+# CHECK-NEXT: [[VAL_3:%.*]] = "toy.struct_access"([[VAL_0]]) {index = 1 : i64} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
+# CHECK-NEXT: [[VAL_4:%.*]] = "toy.transpose"([[VAL_3]]) : (tensor<*xf64>) -> tensor<*xf64>
+# CHECK-NEXT: [[VAL_5:%.*]] = "toy.mul"([[VAL_2]], [[VAL_4]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
+# CHECK-NEXT: "toy.return"([[VAL_5]]) : (tensor<*xf64>) -> ()
+
+# CHECK-LABEL: func @main()
+# CHECK-NEXT: [[VAL_6:%.*]] = "toy.struct_constant"() {value = [dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>, dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>]} : () -> !toy.struct<tensor<*xf64>, tensor<*xf64>>
+# CHECK-NEXT: [[VAL_7:%.*]] = "toy.generic_call"([[VAL_6]]) {callee = @multiply_transpose} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
+# CHECK-NEXT: "toy.print"([[VAL_7]]) : (tensor<*xf64>) -> ()
+# CHECK-NEXT: "toy.return"() : () -> ()
+
+# OPT-LABEL: func @main()
+# OPT-NEXT: [[VAL_0:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+# OPT-NEXT: [[VAL_1:%.*]] = "toy.transpose"([[VAL_0]]) : (tensor<2x3xf64>) -> tensor<3x2xf64>
+# OPT-NEXT: [[VAL_2:%.*]] = "toy.mul"([[VAL_1]], [[VAL_1]]) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
+# OPT-NEXT: "toy.print"([[VAL_2]]) : (tensor<3x2xf64>) -> ()
+# OPT-NEXT: "toy.return"() : () -> ()
diff --git a/mlir/test/Examples/Toy/Ch7/struct-opt.mlir b/mlir/test/Examples/Toy/Ch7/struct-opt.mlir
new file mode 100644
index 00000000000..8c4b055b4bf
--- /dev/null
+++ b/mlir/test/Examples/Toy/Ch7/struct-opt.mlir
@@ -0,0 +1,16 @@
+// RUN: toyc-ch7 %s -emit=mlir -opt 2>&1 | FileCheck %s
+
+func @main() {
+ %0 = "toy.struct_constant"() {
+ value = [[dense<4.000000e+00> : tensor<2x2xf64>], dense<4.000000e+00> : tensor<2x2xf64>]
+ } : () -> !toy.struct<!toy.struct<tensor<*xf64>>, tensor<*xf64>>
+ %1 = "toy.struct_access"(%0) {index = 0 : i64} : (!toy.struct<!toy.struct<tensor<*xf64>>, tensor<*xf64>>) -> !toy.struct<tensor<*xf64>>
+ %2 = "toy.struct_access"(%1) {index = 0 : i64} : (!toy.struct<tensor<*xf64>>) -> tensor<*xf64>
+ "toy.print"(%2) : (tensor<*xf64>) -> ()
+ "toy.return"() : () -> ()
+}
+
+// CHECK-LABEL: func @main
+// CHECK-NEXT: %[[CST:.*]] = "toy.constant"
+// CHECK-SAME: dense<4.0
+// CHECK-NEXT: "toy.print"(%[[CST]])
OpenPOWER on IntegriCloud