summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMehdi Amini <aminim@google.com>2019-04-03 18:45:01 -0700
committerMehdi Amini <joker.eph@gmail.com>2019-04-03 19:22:32 -0700
commitf0a328b6d5e0b1e4679352c7b3c37d3fe7de80e7 (patch)
tree39c3aba8ce1e22ed086d23fe47133a599c959cb9
parent393c77c5da883e9cc53d43e0e9abc5db78bcbf69 (diff)
downloadbcm5719-llvm-f0a328b6d5e0b1e4679352c7b3c37d3fe7de80e7.tar.gz
bcm5719-llvm-f0a328b6d5e0b1e4679352c7b3c37d3fe7de80e7.zip
Chapter 3 for Toy tutorial: introduction of a dialect
-- PiperOrigin-RevId: 241849162
-rw-r--r--mlir/examples/toy/CMakeLists.txt1
-rw-r--r--mlir/examples/toy/Ch2/mlir/MLIRGen.cpp6
-rw-r--r--mlir/examples/toy/Ch3/CMakeLists.txt17
-rw-r--r--mlir/examples/toy/Ch3/include/toy/AST.h256
-rw-r--r--mlir/examples/toy/Ch3/include/toy/Dialect.h324
-rw-r--r--mlir/examples/toy/Ch3/include/toy/Lexer.h239
-rw-r--r--mlir/examples/toy/Ch3/include/toy/MLIRGen.h42
-rw-r--r--mlir/examples/toy/Ch3/include/toy/Parser.h494
-rw-r--r--mlir/examples/toy/Ch3/mlir/MLIRGen.cpp480
-rw-r--r--mlir/examples/toy/Ch3/mlir/ToyDialect.cpp393
-rw-r--r--mlir/examples/toy/Ch3/parser/AST.cpp263
-rw-r--r--mlir/examples/toy/Ch3/toyc.cpp139
-rw-r--r--mlir/g3doc/Tutorials/Toy/Ch-3.md297
-rw-r--r--mlir/include/mlir/IR/DialectTypeRegistry.def1
-rw-r--r--mlir/test/Examples/Toy/Ch2/codegen.toy4
-rw-r--r--mlir/test/Examples/Toy/Ch2/invalid.mlir2
-rw-r--r--mlir/test/Examples/Toy/Ch3/ast.toy73
-rw-r--r--mlir/test/Examples/Toy/Ch3/codegen.toy32
-rw-r--r--mlir/test/Examples/Toy/Ch3/invalid.mlir11
-rw-r--r--mlir/test/Examples/Toy/Ch3/scalar.toy12
20 files changed, 3077 insertions, 9 deletions
diff --git a/mlir/examples/toy/CMakeLists.txt b/mlir/examples/toy/CMakeLists.txt
index b70c371ad53..d50aa185d86 100644
--- a/mlir/examples/toy/CMakeLists.txt
+++ b/mlir/examples/toy/CMakeLists.txt
@@ -8,3 +8,4 @@ endmacro(add_toy_chapter name)
add_subdirectory(Ch1)
add_subdirectory(Ch2)
+add_subdirectory(Ch3)
diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
index 5bd80738a22..062f88aa34a 100644
--- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
@@ -360,12 +360,6 @@ private:
mlir::OperationState result(&context, location, "toy.generic_call");
result.types.push_back(getType(VarType{}));
result.operands = std::move(operands);
- for (auto &expr : call.getArgs()) {
- auto *arg = mlirGen(*expr);
- if (!arg)
- return nullptr;
- result.operands.push_back(arg);
- }
auto calleeAttr = builder->getStringAttr(call.getCallee());
result.attributes.push_back(builder->getNamedAttr("callee", calleeAttr));
return builder->createOperation(result)->getResult(0);
diff --git a/mlir/examples/toy/Ch3/CMakeLists.txt b/mlir/examples/toy/Ch3/CMakeLists.txt
new file mode 100644
index 00000000000..060f3dd26ec
--- /dev/null
+++ b/mlir/examples/toy/Ch3/CMakeLists.txt
@@ -0,0 +1,17 @@
+set(LLVM_LINK_COMPONENTS
+ Support
+ )
+
+add_toy_chapter(toyc-ch3
+ toyc.cpp
+ parser/AST.cpp
+ mlir/MLIRGen.cpp
+ mlir/ToyDialect.cpp
+ )
+include_directories(include/)
+target_link_libraries(toyc-ch3
+ PRIVATE
+ MLIRAnalysis
+ MLIRIR
+ MLIRParser
+ MLIRTransforms)
diff --git a/mlir/examples/toy/Ch3/include/toy/AST.h b/mlir/examples/toy/Ch3/include/toy/AST.h
new file mode 100644
index 00000000000..456a32309c4
--- /dev/null
+++ b/mlir/examples/toy/Ch3/include/toy/AST.h
@@ -0,0 +1,256 @@
+//===- AST.h - Node definition for the Toy AST ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the AST for the Toy language. It is optimized for
+// simplicity, not efficiency. The AST forms a tree structure where each node
+// references its children using std::unique_ptr<>.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_AST_H_
+#define MLIR_TUTORIAL_TOY_AST_H_
+
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <vector>
+
+namespace toy {
+
+/// A variable type: an element kind (float or int) together with a shape.
+struct VarType {
+ enum { TY_FLOAT, TY_INT } elt_ty; // element kind of the variable
+ std::vector<int> shape; // dimension sizes; NOTE(review): presumably empty when no shape was given — confirm in the parser
+};
+
+/// Base class for all expression nodes; stores the node kind and its location.
+class ExprAST {
+public:
+ enum ExprASTKind { // discriminator compared by each subclass's classof()
+ Expr_VarDecl,
+ Expr_Return,
+ Expr_Num,
+ Expr_Literal,
+ Expr_Var,
+ Expr_BinOp,
+ Expr_Call,
+ Expr_Print, // builtin
+ Expr_If, // NOTE(review): no node class in this header uses this kind
+ Expr_For, // NOTE(review): no node class in this header uses this kind
+ };
+
+ ExprAST(ExprASTKind kind, Location location)
+ : kind(kind), location(location) {}
+
+ virtual ~ExprAST() = default;
+
+ ExprASTKind getKind() const { return kind; } // kind fixed at construction
+
+ const Location &loc() { return location; } // location passed to the constructor
+
+private:
+ const ExprASTKind kind;
+ Location location;
+};
+
+/// A block-list of expressions.
+using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;
+
+/// AST node holding a numeric literal such as "1.0".
+class NumberExprAST : public ExprAST {
+public:
+  NumberExprAST(Location loc, double Val)
+      : ExprAST(Expr_Num, loc), value(Val) {}
+
+  /// Stored literal value.
+  double getValue() { return value; }
+
+  /// LLVM-style RTTI support: true when `node` was built with kind Expr_Num.
+  static bool classof(const ExprAST *node) {
+    return node->getKind() == Expr_Num;
+  }
+
+private:
+  double value;
+};
+
+/// Expression class for a literal: a list of element expressions together
+/// with a list of dimensions.
+class LiteralExprAST : public ExprAST {
+public:
+  LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values,
+                 std::vector<int64_t> dims)
+      : ExprAST(Expr_Literal, loc), elements(std::move(values)),
+        extents(std::move(dims)) {}
+
+  /// Mutable access to the owned element expressions.
+  std::vector<std::unique_ptr<ExprAST>> &getValues() { return elements; }
+
+  /// Mutable access to the stored dimensions.
+  std::vector<int64_t> &getDims() { return extents; }
+
+  /// LLVM-style RTTI support: true when `node` has kind Expr_Literal.
+  static bool classof(const ExprAST *node) {
+    return node->getKind() == Expr_Literal;
+  }
+
+private:
+  std::vector<std::unique_ptr<ExprAST>> elements;
+  std::vector<int64_t> extents;
+};
+
+/// AST node for a reference to a named variable, like "a".
+class VariableExprAST : public ExprAST {
+public:
+  VariableExprAST(Location loc, const std::string &name)
+      : ExprAST(Expr_Var, loc), varName(name) {}
+
+  /// Name of the referenced variable, as a view into this node's own string.
+  llvm::StringRef getName() { return varName; }
+
+  /// LLVM-style RTTI support: true when `node` has kind Expr_Var.
+  static bool classof(const ExprAST *node) {
+    return node->getKind() == Expr_Var;
+  }
+
+private:
+  std::string varName;
+};
+
+/// Expression class for a variable declaration: the variable name, its
+/// declared type and the expression used to initialize it.
+class VarDeclExprAST : public ExprAST {
+public:
+  VarDeclExprAST(Location loc, const std::string &name, VarType type,
+                 std::unique_ptr<ExprAST> initVal)
+      : ExprAST(Expr_VarDecl, loc), varName(name), varType(std::move(type)),
+        initializer(std::move(initVal)) {}
+
+  /// Name of the declared variable, as a view into this node's own string.
+  llvm::StringRef getName() { return varName; }
+
+  /// Non-owning pointer to the initializer; null when the node holds none.
+  ExprAST *getInitVal() { return initializer.get(); }
+
+  /// Mutable access to the declared type.
+  VarType &getType() { return varType; }
+
+  /// LLVM-style RTTI support: true when `node` has kind Expr_VarDecl.
+  static bool classof(const ExprAST *node) {
+    return node->getKind() == Expr_VarDecl;
+  }
+
+private:
+  std::string varName;
+  VarType varType;
+  std::unique_ptr<ExprAST> initializer;
+};
+
+/// Expression class for a return statement with an optional operand.
+class ReturnExprAST : public ExprAST {
+public:
+  ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr)
+      : ExprAST(Expr_Return, loc), returned(std::move(expr)) {}
+
+  /// Non-owning pointer to the returned expression, or an empty Optional when
+  /// this node was built without one.
+  llvm::Optional<ExprAST *> getExpr() {
+    if (!returned.hasValue())
+      return llvm::NoneType();
+    return returned->get();
+  }
+
+  /// LLVM-style RTTI support: true when `node` has kind Expr_Return.
+  static bool classof(const ExprAST *node) {
+    return node->getKind() == Expr_Return;
+  }
+
+private:
+  llvm::Optional<std::unique_ptr<ExprAST>> returned;
+};
+
+/// AST node for a binary operator applied to two operand expressions.
+class BinaryExprAST : public ExprAST {
+public:
+  BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
+                std::unique_ptr<ExprAST> RHS)
+      : ExprAST(Expr_BinOp, loc), opcode(Op), lhs(std::move(LHS)),
+        rhs(std::move(RHS)) {}
+
+  /// Operator character given at construction.
+  char getOp() { return opcode; }
+
+  /// Non-owning pointer to the left-hand operand.
+  ExprAST *getLHS() { return lhs.get(); }
+
+  /// Non-owning pointer to the right-hand operand.
+  ExprAST *getRHS() { return rhs.get(); }
+
+  /// LLVM-style RTTI support: true when `node` has kind Expr_BinOp.
+  static bool classof(const ExprAST *node) {
+    return node->getKind() == Expr_BinOp;
+  }
+
+private:
+  char opcode;
+  std::unique_ptr<ExprAST> lhs, rhs;
+};
+
+/// AST node for a function call: the callee name plus argument expressions.
+class CallExprAST : public ExprAST {
+public:
+  CallExprAST(Location loc, const std::string &Callee,
+              std::vector<std::unique_ptr<ExprAST>> Args)
+      : ExprAST(Expr_Call, loc), calleeName(Callee),
+        arguments(std::move(Args)) {}
+
+  /// Name of the called function, as a view into this node's own string.
+  llvm::StringRef getCallee() { return calleeName; }
+
+  /// Non-owning view over the owned argument expressions.
+  llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return arguments; }
+
+  /// LLVM-style RTTI support: true when `node` has kind Expr_Call.
+  static bool classof(const ExprAST *node) {
+    return node->getKind() == Expr_Call;
+  }
+
+private:
+  std::string calleeName;
+  std::vector<std::unique_ptr<ExprAST>> arguments;
+};
+
+/// AST node for a call to the builtin print, which has a single argument.
+class PrintExprAST : public ExprAST {
+public:
+  PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
+      : ExprAST(Expr_Print, loc), operand(std::move(Arg)) {}
+
+  /// Non-owning pointer to the printed expression.
+  ExprAST *getArg() { return operand.get(); }
+
+  /// LLVM-style RTTI support: true when `node` has kind Expr_Print.
+  static bool classof(const ExprAST *node) {
+    return node->getKind() == Expr_Print;
+  }
+
+private:
+  std::unique_ptr<ExprAST> operand;
+};
+
+/// The "prototype" of a function: its source location, its name and the
+/// variables naming its arguments (and so, implicitly, how many arguments the
+/// function takes).
+class PrototypeAST {
+public:
+  PrototypeAST(Location location, const std::string &name,
+               std::vector<std::unique_ptr<VariableExprAST>> args)
+      : protoLoc(location), protoName(name), params(std::move(args)) {}
+
+  /// Location passed to the constructor.
+  const Location &loc() { return protoLoc; }
+
+  /// Name of the function.
+  const std::string &getName() const { return protoName; }
+
+  /// The argument variables, in declaration order.
+  const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
+    return params;
+  }
+
+private:
+  Location protoLoc;
+  std::string protoName;
+  std::vector<std::unique_ptr<VariableExprAST>> params;
+};
+
+/// A function definition: a prototype plus the list of body expressions.
+class FunctionAST {
+public:
+  FunctionAST(std::unique_ptr<PrototypeAST> Proto,
+              std::unique_ptr<ExprASTList> Body)
+      : prototype(std::move(Proto)), statements(std::move(Body)) {}
+
+  /// Non-owning pointer to the prototype.
+  PrototypeAST *getProto() { return prototype.get(); }
+
+  /// Non-owning pointer to the body expression list.
+  ExprASTList *getBody() { return statements.get(); }
+
+private:
+  std::unique_ptr<PrototypeAST> prototype;
+  std::unique_ptr<ExprASTList> statements;
+};
+
+/// A list of functions to be processed together; iterable with range-for.
+class ModuleAST {
+public:
+  ModuleAST(std::vector<FunctionAST> functions)
+      : funcs(std::move(functions)) {}
+
+  /// Iterator to the first function.
+  std::vector<FunctionAST>::iterator begin() { return funcs.begin(); }
+
+  /// Iterator one past the last function.
+  std::vector<FunctionAST>::iterator end() { return funcs.end(); }
+
+private:
+  std::vector<FunctionAST> funcs;
+};
+
+void dump(ModuleAST &);
+
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_AST_H_
diff --git a/mlir/examples/toy/Ch3/include/toy/Dialect.h b/mlir/examples/toy/Ch3/include/toy/Dialect.h
new file mode 100644
index 00000000000..cc867700b68
--- /dev/null
+++ b/mlir/examples/toy/Ch3/include/toy/Dialect.h
@@ -0,0 +1,324 @@
+//===- Dialect.h - Dialect definition for the Toy IR ---------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares the Toy dialect: the ToyDialect class, its custom array
+// type (ToyArrayType) and the operation classes of the dialect. The methods
+// declared here are implemented outside of this header.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
+#define MLIR_TUTORIAL_TOY_DIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+class FuncBuilder;
+}
+
+namespace toy {
+
+/// This is the definition of the Toy dialect. A dialect inherits from
+/// mlir::Dialect and registers custom operations and types in its constructor.
+/// It can also override general dialect behavior exposed as virtual methods,
+/// for example regarding verification and parsing/printing.
+class ToyDialect : public mlir::Dialect {
+public:
+ explicit ToyDialect(mlir::MLIRContext *ctx);
+
+ /// Parse a type registered to this dialect. Overriding this method is
+ /// required for dialects that have custom types.
+ /// Technically this is only needed to be able to round-trip to textual IR.
+ mlir::Type parseType(llvm::StringRef tyData,
+ mlir::Location loc) const override;
+
+ /// Print a type registered to this dialect. Overriding this method is
+ /// only required for dialects that have custom types.
+ /// Technically this is only needed to be able to round-trip to textual IR.
+ void printType(mlir::Type type, llvm::raw_ostream &os) const override;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+/////////////////////// Custom Types for the Dialect ///////////////////////////
+////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+struct ToyArrayTypeStorage;
+}
+
+/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa.
+enum ToyTypeKind {
+  // The enum starts at the range reserved for this dialect. The per-dialect
+  // FIRST_*_TYPE markers are enumerators of mlir::Type's kind enumeration, so
+  // they are named through mlir::Type rather than through a derived type.
+  TOY_TYPE = mlir::Type::FIRST_TOY_TYPE,
+  TOY_ARRAY,
+};
+
+/// Type for Toy arrays.
+/// In MLIR, Types are references to immutable and uniqued objects owned by the
+/// MLIRContext. As such `ToyArrayType` only wraps a pointer to a uniqued
+/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and
+/// provides the public facade API to interact with the type.
+class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
+ detail::ToyArrayTypeStorage> {
+public:
+ using Base::Base;
+
+ /// Returns the dimensions for this array, or an empty range for a generic
+ /// array.
+ llvm::ArrayRef<int64_t> getShape();
+
+ /// Predicate to test if this array is generic (shape hasn't been inferred
+ /// yet).
+ bool isGeneric() { return getShape().empty(); }
+
+ /// Return the rank of this array (0 if it is generic).
+ int getRank() { return getShape().size(); }
+
+ /// Return the type of individual elements in the array.
+ mlir::Type getElementType();
+
+ /// Get the unique instance of this Type from the context.
+ /// A ToyArrayType is only defined by the shape of the array.
+ static ToyArrayType get(mlir::MLIRContext *context,
+ llvm::ArrayRef<int64_t> shape = {});
+
+ /// Support method to enable LLVM-style RTTI type casting.
+ static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+//////////////////// Custom Operations for the Dialect /////////////////////////
+////////////////////////////////////////////////////////////////////////////////
+
+/// Constant operation turns a literal into an SSA value. The data is attached
+/// to the operation as an attribute. For example:
+///
+///   %0 = "toy.constant"()
+///       {value: dense<tensor<2x3xf64>, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>}
+///     : () -> !toy<"array<2, 3>">
+///
+/// An operation inherits from `class Op` and specifies optional traits. Here we
+/// indicate that `toy.constant` does not have any operand and returns a single
+/// result. The traits make some methods available on the operation; for
+/// instance we will be able to use `getResult()` but `getOperand()` won't be
+/// available.
+class ConstantOp : public mlir::Op<ConstantOp, mlir::OpTrait::ZeroOperands,
+ mlir::OpTrait::OneResult,
+ mlir::OpTrait::HasNoSideEffect> {
+public:
+ /// This is the name used by MLIR to match an operation to this class during
+ /// parsing.
+ static llvm::StringRef getOperationName() { return "toy.constant"; }
+
+ /// Operations can have extra verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to mlir::FuncBuilder::create<ConstantOp>(...)
+ /// This method populates the `state` that MLIR uses to create operations.
+ /// The `toy.constant` operation does not have arguments but attaches a
+ /// constant array as attribute and returns it as an SSA value.
+ static void build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+ llvm::ArrayRef<int64_t> shape,
+ mlir::DenseElementsAttr value);
+
+ /// Similar to the one above, but takes a single float and returns a
+ /// !toy<"array<1>">.
+ static void build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+ mlir::FloatAttr value);
+
+ /// Inherit Constructor
+ using Op::Op;
+};
+
+/// Generic calls represent calls to a user-defined function that needs to
+/// be specialized for the shape of its arguments. The callee name is attached
+/// as a literal string as an attribute. The arguments list must match the
+/// arguments expected by the callee. For example:
+///
+///   %4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
+///     : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
+///
+/// This is only valid if a function named "my_func" exists and takes two
+/// arguments.
+class GenericCallOp
+ : public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::OneResult> {
+public:
+ /// MLIR will use this to register the operation with the parser/printer.
+ static llvm::StringRef getOperationName() { return "toy.generic_call"; }
+
+ /// Operations can add custom verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to the builder to allow:
+ ///   mlir::FuncBuilder::create<GenericCallOp>(...)
+ /// This method populates the `state` that MLIR uses to create operations.
+ /// The `toy.generic_call` operation accepts a callee name and a list of
+ /// arguments for the call.
+ static void build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+ llvm::StringRef callee,
+ llvm::ArrayRef<mlir::Value *> arguments);
+
+ /// Return the name of the callee.
+ llvm::StringRef getCalleeName();
+
+ /// Inherit Constructor
+ using Op::Op;
+};
+
+/// Return operations terminate blocks (and functions as well). They take an
+/// optional single argument whose type must match the function return type.
+class ReturnOp
+ : public mlir::Op<ReturnOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::ZeroResult, mlir::OpTrait::IsTerminator> {
+public:
+ static llvm::StringRef getOperationName() { return "toy.return"; }
+
+ /// Operations can add custom verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to mlir::FuncBuilder::create<ReturnOp>(...)
+ /// This method populates the `state` that MLIR uses to create operations.
+ /// The `toy.return` operation accepts an optional single array as argument
+ /// and does not have any returned value.
+ static void build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+ mlir::Value *value = nullptr);
+
+ /// Return true if there is a returned value.
+ bool hasOperand() { return 0 != getNumOperands(); }
+
+ /// Helper to return the optional operand. Caller must check that the operand
+ /// is present (see hasOperand()) before calling this.
+ mlir::Value *getOperand() { return getOperation()->getOperand(0); }
+
+ /// Inherit Constructor
+ using Op::Op;
+};
+
+/// The print builtin takes a single array argument and does not return a value.
+class PrintOp : public mlir::Op<PrintOp, mlir::OpTrait::OneOperand,
+ mlir::OpTrait::ZeroResult> {
+public:
+ static llvm::StringRef getOperationName() { return "toy.print"; }
+
+ /// Operations can add custom verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to mlir::FuncBuilder::create<PrintOp>(...)
+ /// This method populates the `state` that MLIR uses to create operations.
+ /// The `toy.print` operation accepts a single array as argument and does
+ /// not have any returned value.
+ static void build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+ mlir::Value *value);
+
+ /// Inherit Constructor
+ using Op::Op;
+};
+
+class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand,
+ mlir::OpTrait::OneResult> { // one array operand in, one result out
+public:
+ static llvm::StringRef getOperationName() { return "toy.transpose"; }
+
+ /// Operations can add custom verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to mlir::FuncBuilder::create<TransposeOp>(...)
+ /// This method populates the `state` that MLIR uses to create operations.
+ /// The `toy.transpose` operation accepts a single array as argument and
+ /// returns the transposed array as its only result.
+ static void build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+ mlir::Value *value);
+
+ /// Inherit Constructor
+ using Op::Op;
+};
+
+/// Reshape operation transforms its input array into a new array with the
+/// same number of elements but a different shape. For example:
+///
+///   %0 = "toy.reshape"(%arg1) : (!toy<"array<10>">) -> !toy<"array<5, 2>">
+///
+class ReshapeOp : public mlir::Op<ReshapeOp, mlir::OpTrait::OneOperand,
+ mlir::OpTrait::OneResult> {
+public:
+ static llvm::StringRef getOperationName() { return "toy.reshape"; }
+
+ /// Operations can add custom verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to mlir::FuncBuilder::create<ReshapeOp>(...)
+ /// This method populates the `state` that MLIR uses to create operations.
+ /// The `toy.reshape` operation accepts a single array as argument and
+ /// returns the array with the specified reshapedType as its only result.
+ static void build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+ mlir::Value *value, ToyArrayType reshapedType);
+
+ /// Inherit Constructor
+ using Op::Op;
+};
+
+/// Binary operation implementing a multiplication. For two-dimensional arrays
+/// a matrix multiplication is implemented, while for one-dimensional arrays a
+/// dot product is performed.
+class MulOp : public mlir::Op<MulOp, mlir::OpTrait::NOperands<2>::Impl,
+ mlir::OpTrait::OneResult> {
+public:
+ static llvm::StringRef getOperationName() { return "toy.mul"; }
+
+ /// Operations can add custom verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to mlir::FuncBuilder::create<MulOp>(...)
+ /// This method populates the `state` that MLIR uses to create operations.
+ /// The `toy.mul` operation accepts two operands as arguments and returns
+ /// a single value.
+ static void build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+ mlir::Value *lhs, mlir::Value *rhs);
+
+ /// Inherit Constructor
+ using Op::Op;
+};
+
+/// Element-wise addition of two arrays. The shapes must match.
+///
+/// The operation is binary: it is declared with exactly two operands (lhs and
+/// rhs, matching `build()` below) and produces a single result.
+class AddOp : public mlir::Op<AddOp, mlir::OpTrait::NOperands<2>::Impl,
+                              mlir::OpTrait::OneResult> {
+public:
+  static llvm::StringRef getOperationName() { return "toy.add"; }
+
+  /// Operations can add custom verification beyond the traits they define.
+  mlir::LogicalResult verify();
+
+  /// Interface to mlir::FuncBuilder::create<AddOp>(...)
+  /// This method populates the `state` that MLIR uses to create operations.
+  /// The `toy.add` operation accepts two operands as arguments and returns
+  /// a single value.
+  static void build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+                    mlir::Value *lhs, mlir::Value *rhs);
+
+  /// Inherit Constructor
+  using Op::Op;
+};
+
+} // end namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_DIALECT_H_
diff --git a/mlir/examples/toy/Ch3/include/toy/Lexer.h b/mlir/examples/toy/Ch3/include/toy/Lexer.h
new file mode 100644
index 00000000000..d73adb9706b
--- /dev/null
+++ b/mlir/examples/toy/Ch3/include/toy/Lexer.h
@@ -0,0 +1,239 @@
+//===- Lexer.h - Lexer for the Toy language -------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple Lexer for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_LEXER_H_
+#define MLIR_TUTORIAL_TOY_LEXER_H_
+
+#include "llvm/ADT/StringRef.h"
+
+#include <memory>
+#include <string>
+
+namespace toy {
+
+/// Structure defining a location in a file.
+struct Location {
+ std::shared_ptr<std::string> file; ///< filename
+ int line; ///< line number.
+ int col; ///< column number.
+};
+
+// List of Tokens returned by the lexer.
+enum Token : int {
+ tok_semicolon = ';', // punctuation tokens reuse their character value
+ tok_parenthese_open = '(',
+ tok_parenthese_close = ')',
+ tok_bracket_open = '{',
+ tok_bracket_close = '}',
+ tok_sbracket_open = '[',
+ tok_sbracket_close = ']',
+
+ tok_eof = -1, // all non-character tokens are negative
+
+ // commands
+ tok_return = -2,
+ tok_var = -3,
+ tok_def = -4,
+
+ // primary
+ tok_identifier = -5,
+ tok_number = -6,
+};
+
+/// The Lexer is an abstract base class providing all the facilities that the
+/// Parser expects. It goes through the stream one token at a time and keeps
+/// track of the location in the file for debugging purposes.
+/// It relies on a subclass to provide a `readNextLine()` method. The subclass
+/// can proceed by reading the next line from the standard input or from a
+/// memory mapped file.
+class Lexer {
+public:
+  /// Create a lexer for the given filename. The filename is kept only for
+  /// debugging purposes (attaching a location to a Token).
+  Lexer(std::string filename)
+      : lastLocation(
+            {std::make_shared<std::string>(std::move(filename)), 0, 0}) {}
+  virtual ~Lexer() = default;
+
+  /// Look at the current token in the stream.
+  Token getCurToken() { return curTok; }
+
+  /// Move to the next token in the stream and return it.
+  Token getNextToken() { return curTok = getTok(); }
+
+  /// Move to the next token in the stream, asserting on the current token
+  /// matching the expectation.
+  void consume(Token tok) {
+    assert(tok == curTok && "consume Token mismatch expectation");
+    getNextToken();
+  }
+
+  /// Return the current identifier (prereq: getCurToken() == tok_identifier)
+  llvm::StringRef getId() {
+    assert(curTok == tok_identifier);
+    return IdentifierStr;
+  }
+
+  /// Return the current number (prereq: getCurToken() == tok_number)
+  double getValue() {
+    assert(curTok == tok_number);
+    return NumVal;
+  }
+
+  /// Return the location for the beginning of the current token.
+  Location getLastLocation() { return lastLocation; }
+
+  // Return the current line in the file.
+  int getLine() { return curLineNum; }
+
+  // Return the current column in the file.
+  int getCol() { return curCol; }
+
+private:
+  /// Delegate to a derived class fetching the next line. Returns an empty
+  /// string to signal end of file (EOF). Lines are expected to always finish
+  /// with "\n"
+  virtual llvm::StringRef readNextLine() = 0;
+
+  /// Return the next character from the stream, or EOF once the input is
+  /// exhausted. This manages the buffer for the current line and requests the
+  /// next line buffer from the derived class as needed.
+  int getNextChar() {
+    // The current line buffer should not be empty unless it is the end of file.
+    if (curLineBuffer.empty())
+      return EOF;
+    ++curCol;
+    auto nextchar = curLineBuffer.front();
+    curLineBuffer = curLineBuffer.drop_front();
+    if (curLineBuffer.empty())
+      curLineBuffer = readNextLine();
+    if (nextchar == '\n') {
+      ++curLineNum;
+      curCol = 0;
+    }
+    // Widen through unsigned char so that bytes >= 0x80 come back as positive
+    // values. Returning the plain (possibly signed) char would make such bytes
+    // negative: 0xFF would be indistinguishable from EOF, other values would
+    // collide with the negative keyword tokens, and the <cctype> predicates
+    // used by getTok() would be called with out-of-range arguments.
+    return static_cast<unsigned char>(nextchar);
+  }
+
+  /// Return the next token from the input stream.
+  Token getTok() {
+    // Skip any whitespace.
+    while (isspace(LastChar))
+      LastChar = Token(getNextChar());
+
+    // Save the current location before reading the token characters.
+    lastLocation.line = curLineNum;
+    lastLocation.col = curCol;
+
+    if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
+      IdentifierStr = (char)LastChar;
+      while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
+        IdentifierStr += (char)LastChar;
+
+      // Keywords are lexed as identifiers first and recognized afterwards.
+      if (IdentifierStr == "return")
+        return tok_return;
+      if (IdentifierStr == "def")
+        return tok_def;
+      if (IdentifierStr == "var")
+        return tok_var;
+      return tok_identifier;
+    }
+
+    if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
+      std::string NumStr;
+      do {
+        NumStr += LastChar;
+        LastChar = Token(getNextChar());
+      } while (isdigit(LastChar) || LastChar == '.');
+
+      // strtod stops at the first character it cannot use, so a malformed
+      // sequence such as "1.2.3" yields the value of its valid prefix.
+      NumVal = strtod(NumStr.c_str(), nullptr);
+      return tok_number;
+    }
+
+    if (LastChar == '#') {
+      // Comment until end of line.
+      do
+        LastChar = Token(getNextChar());
+      while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
+
+      // Lex the token following the comment, unless the input ended.
+      if (LastChar != EOF)
+        return getTok();
+    }
+
+    // Check for end of file. Don't eat the EOF.
+    if (LastChar == EOF)
+      return tok_eof;
+
+    // Otherwise, just return the character as its ascii value.
+    Token ThisChar = Token(LastChar);
+    LastChar = Token(getNextChar());
+    return ThisChar;
+  }
+
+  /// The last token read from the input.
+  Token curTok = tok_eof;
+
+  /// Location for `curTok`.
+  Location lastLocation;
+
+  /// If the current Token is an identifier, this string contains the value.
+  std::string IdentifierStr;
+
+  /// If the current Token is a number, this contains the value.
+  double NumVal = 0;
+
+  /// The last value returned by getNextChar(). We need to keep it around as we
+  /// always need to read ahead one character to decide when to end a token and
+  /// we can't put it back in the stream after reading from it.
+  Token LastChar = Token(' ');
+
+  /// Keep track of the current line number in the input stream
+  int curLineNum = 0;
+
+  /// Keep track of the current column number in the input stream
+  int curCol = 0;
+
+  /// Buffer supplied by the derived class on calls to `readNextLine()`
+  llvm::StringRef curLineBuffer = "\n";
+};
+
+/// A lexer implementation operating on a buffer in memory, delimited by the
+/// half-open range [begin, end).
+class LexerBuffer final : public Lexer {
+public:
+  LexerBuffer(const char *begin, const char *end, std::string filename)
+      : Lexer(std::move(filename)), current(begin), end(end) {}
+
+private:
+  /// Provide one line at a time to the Lexer, return an empty string when
+  /// reaching the end of the buffer. A returned line includes its trailing
+  /// '\n' when the buffer has one; a NUL byte also ends the input.
+  llvm::StringRef readNextLine() override {
+    auto *begin = current;
+    // `end` is one past the last valid byte, so it must never be dereferenced:
+    // compare with `<` before reading `*current`.
+    while (current < end && *current && *current != '\n')
+      ++current;
+    // Step over the newline so that it is part of the returned line.
+    if (current < end && *current)
+      ++current;
+    llvm::StringRef result{begin, static_cast<size_t>(current - begin)};
+    return result;
+  }
+  const char *current, *end;
+};
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_LEXER_H_
diff --git a/mlir/examples/toy/Ch3/include/toy/MLIRGen.h b/mlir/examples/toy/Ch3/include/toy/MLIRGen.h
new file mode 100644
index 00000000000..21637bc19af
--- /dev/null
+++ b/mlir/examples/toy/Ch3/include/toy/MLIRGen.h
@@ -0,0 +1,42 @@
+//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares a simple interface to perform IR generation targeting MLIR
+// from a Module AST for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_
+#define MLIR_TUTORIAL_TOY_MLIRGEN_H_
+
+#include <memory>
+
+namespace mlir {
+class MLIRContext;
+class Module;
+} // namespace mlir
+
+namespace toy {
+class ModuleAST;
+
+/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
+/// or nullptr on failure.
+///
+/// `context` is the MLIRContext the module is created in, `moduleAST` is the
+/// parsed Toy module to translate. The caller owns the returned module.
+std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
+ ModuleAST &moduleAST);
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
diff --git a/mlir/examples/toy/Ch3/include/toy/Parser.h b/mlir/examples/toy/Ch3/include/toy/Parser.h
new file mode 100644
index 00000000000..bc7aa520624
--- /dev/null
+++ b/mlir/examples/toy/Ch3/include/toy/Parser.h
@@ -0,0 +1,494 @@
+//===- Parser.h - Toy Language Parser -------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the parser for the Toy language. It processes the Token
+// provided by the Lexer and returns an AST.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_PARSER_H
+#define MLIR_TUTORIAL_TOY_PARSER_H
+
+#include "toy/AST.h"
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <map>
+#include <utility>
+#include <vector>
+
+namespace toy {
+
+/// This is a simple recursive parser for the Toy language. It produces a well
+/// formed AST from a stream of Token supplied by the Lexer. No semantic checks
+/// or symbol resolution is performed. For example, variables are referenced by
+/// string and the code could reference an undeclared variable and the parsing
+/// succeeds.
+class Parser {
+public:
+ /// Create a Parser for the supplied lexer.
+ /// The lexer is stored by reference and is the source of every token the
+ /// parser consumes, so it has to outlive the parser.
+ Parser(Lexer &lexer) : lexer(lexer) {}
+
+  /// Parse a full Module. A module is a list of function definitions.
+  ///
+  /// Returns the parsed module, or nullptr when any definition failed to
+  /// parse.
+  std::unique_ptr<ModuleAST> ParseModule() {
+    lexer.getNextToken(); // prime the lexer
+
+    // Parse functions one at a time and accumulate in this vector.
+    std::vector<FunctionAST> functions;
+    while (true) {
+      auto F = ParseDefinition();
+      if (!F) {
+        // The definition could not be parsed. If we didn't reach EOF, report
+        // the unexpected trailing input.
+        if (lexer.getCurToken() != tok_eof)
+          return parseError<ModuleAST>("nothing", "at end of module");
+        // The failure happened right at EOF (for instance an unterminated
+        // block). Fail as well instead of returning a module that silently
+        // lacks the broken function.
+        return nullptr;
+      }
+      functions.push_back(std::move(*F));
+      if (lexer.getCurToken() == tok_eof)
+        break;
+    }
+
+    return llvm::make_unique<ModuleAST>(std::move(functions));
+  }
+
+private:
+ Lexer &lexer;
+
+  /// Parse a return statement.
+  /// return :== return ; | return expr ;
+  ///
+  /// Returns nullptr when the optional returned expression fails to parse.
+  std::unique_ptr<ReturnExprAST> ParseReturn() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(tok_return);
+
+    // return takes an optional argument
+    llvm::Optional<std::unique_ptr<ExprAST>> expr;
+    if (lexer.getCurToken() != ';') {
+      // Test the parsed pointer itself: once assigned, the Optional always
+      // holds a value, even when that value is a null pointer.
+      auto value = ParseExpression();
+      if (!value)
+        return nullptr;
+      expr = std::move(value);
+    }
+    return llvm::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
+  }
+
+  /// Parse a number literal into a NumberExprAST.
+  /// numberexpr ::= number
+  std::unique_ptr<ExprAST> ParseNumberExpr() {
+    auto numberLoc = lexer.getLastLocation();
+    // The value is captured from the lexer before the token is consumed.
+    std::unique_ptr<ExprAST> number =
+        llvm::make_unique<NumberExprAST>(std::move(numberLoc), lexer.getValue());
+    lexer.consume(tok_number);
+    return number;
+  }
+
+  /// Parse a literal array expression.
+  /// tensorLiteral ::= [ literalList ] | number
+  /// literalList ::= tensorLiteral | tensorLiteral, literalList
+  ///
+  /// Returns a LiteralExprAST holding the elements of this nesting level and
+  /// the dimensions of this level followed by the ones of the nested levels.
+  /// Returns nullptr after printing a diagnostic when the literal is malformed
+  /// or when its nesting is not uniform, e.g. `[[1, 2], 3]`.
+  std::unique_ptr<ExprAST> ParseTensorLitteralExpr() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(Token('['));
+
+    // Hold the list of values at this nesting level.
+    std::vector<std::unique_ptr<ExprAST>> values;
+    // Hold the dimensions for all the nesting inside this level.
+    std::vector<int64_t> dims;
+    do {
+      // We can have either another nested array or a number literal.
+      if (lexer.getCurToken() == '[') {
+        values.push_back(ParseTensorLitteralExpr());
+        if (!values.back())
+          return nullptr; // parse error in the nested array.
+      } else {
+        if (lexer.getCurToken() != tok_number)
+          return parseError<ExprAST>("<num> or [", "in literal expression");
+        values.push_back(ParseNumberExpr());
+      }
+
+      // End of this list on ']'
+      if (lexer.getCurToken() == ']')
+        break;
+
+      // Elements are separated by a comma.
+      if (lexer.getCurToken() != ',')
+        return parseError<ExprAST>("] or ,", "in literal expression");
+
+      lexer.getNextToken(); // eat ,
+    } while (true);
+    if (values.empty())
+      return parseError<ExprAST>("<something>", "to fill literal expression");
+    lexer.getNextToken(); // eat ]
+    // Fill in the dimensions now. First the current nesting level:
+    dims.push_back(values.size());
+    // If there is any nested array, process all of them and ensure that
+    // dimensions are uniform.
+    if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
+          return llvm::isa<LiteralExprAST>(expr.get());
+        })) {
+      auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get());
+      if (!firstLiteral)
+        return parseError<ExprAST>("uniform well-nested dimensions",
+                                   "inside literal expression");
+
+      // Append the nested dimensions to the current level
+      auto &firstDims = firstLiteral->getDims();
+      dims.insert(dims.end(), firstDims.begin(), firstDims.end());
+
+      // Sanity check that shape is uniform across all elements of the list.
+      // dyn_cast (and not cast) is required here: an element may be a plain
+      // number sitting next to nested arrays, which must be reported as a
+      // parse error rather than trip the assertion inside cast<>.
+      for (auto &expr : values) {
+        auto *exprLiteral = llvm::dyn_cast<LiteralExprAST>(expr.get());
+        if (!exprLiteral || exprLiteral->getDims() != firstDims)
+          return parseError<ExprAST>("uniform well-nested dimensions",
+                                     "inside literal expression");
+      }
+    }
+    return llvm::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
+                                             std::move(dims));
+  }
+
+  /// Parse an expression wrapped in parentheses and return the inner
+  /// expression, or nullptr on error.
+  /// parenexpr ::= '(' expression ')'
+  std::unique_ptr<ExprAST> ParseParenExpr() {
+    lexer.getNextToken(); // eat (.
+    auto inner = ParseExpression();
+    if (inner == nullptr)
+      return nullptr;
+
+    // A closing parenthesis is mandatory after the inner expression.
+    if (lexer.getCurToken() == ')') {
+      lexer.consume(Token(')'));
+      return inner;
+    }
+    return parseError<ExprAST>(")", "to close expression with parentheses");
+  }
+
+  /// Parse an expression starting with an identifier: either a reference to a
+  /// variable, or a call when the identifier is followed by '('.
+  /// identifierexpr
+  ///   ::= identifier
+  ///   ::= identifier '(' expression ')'
+  ///
+  /// A call to `print` is turned into a PrintExprAST and requires exactly one
+  /// argument; any other callee yields a CallExprAST.
+  std::unique_ptr<ExprAST> ParseIdentifierExpr() {
+    std::string name = lexer.getId();
+
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat identifier.
+
+    // Without an opening parenthesis this is a simple variable reference.
+    const bool isCall = lexer.getCurToken() == '(';
+    if (!isCall)
+      return llvm::make_unique<VariableExprAST>(std::move(loc), name);
+
+    // Otherwise collect the comma separated list of arguments.
+    lexer.consume(Token('('));
+    std::vector<std::unique_ptr<ExprAST>> arguments;
+    if (lexer.getCurToken() != ')') {
+      for (;;) {
+        auto argument = ParseExpression();
+        if (!argument)
+          return nullptr;
+        arguments.push_back(std::move(argument));
+
+        if (lexer.getCurToken() == ')')
+          break;
+
+        if (lexer.getCurToken() != ',')
+          return parseError<ExprAST>(", or )", "in argument list");
+        lexer.getNextToken();
+      }
+    }
+    lexer.consume(Token(')'));
+
+    // Calls to user-defined functions are the general case.
+    if (name != "print")
+      return llvm::make_unique<CallExprAST>(std::move(loc), name,
+                                            std::move(arguments));
+
+    // The print builtin takes a single argument.
+    if (arguments.size() != 1)
+      return parseError<ExprAST>("<single arg>", "as argument to print()");
+
+    return llvm::make_unique<PrintExprAST>(std::move(loc),
+                                           std::move(arguments[0]));
+  }
+
+  /// Parse a primary expression, dispatching on the current token.
+  /// primary
+  ///   ::= identifierexpr
+  ///   ::= numberexpr
+  ///   ::= parenexpr
+  ///   ::= tensorliteral
+  ///
+  /// Returns nullptr without any diagnostic on ';' and '}', and nullptr with
+  /// a diagnostic on any other unexpected token.
+  std::unique_ptr<ExprAST> ParsePrimary() {
+    const auto tok = lexer.getCurToken();
+    switch (tok) {
+    case tok_identifier:
+      return ParseIdentifierExpr();
+    case tok_number:
+      return ParseNumberExpr();
+    case '(':
+      return ParseParenExpr();
+    case '[':
+      return ParseTensorLitteralExpr();
+    case ';':
+    case '}':
+      return nullptr;
+    default:
+      llvm::errs() << "unknown token '" << tok
+                   << "' when expecting an expression\n";
+      return nullptr;
+    }
+  }
+
+ /// Recursively parse the right hand side of a binary expression, the ExprPrec
+ /// argument indicates the precedence of the current binary operator.
+ ///
+ /// binoprhs ::= (binop primary)*
+ /// where binop is one of the operators known to GetTokPrecedence().
+ ///
+ /// `LHS` is the already parsed left operand. Returns the expression tree
+ /// built from `LHS` and every following operator whose precedence is at
+ /// least `ExprPrec`, or nullptr when an operand fails to parse.
+ std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
+ std::unique_ptr<ExprAST> LHS) {
+ // If this is a binop, find its precedence.
+ while (true) {
+ int TokPrec = GetTokPrecedence();
+
+ // If this is a binop that binds at least as tightly as the current binop,
+ // consume it, otherwise we are done.
+ // Tokens that are not binary operators have a precedence of -1, so they
+ // always end the loop here.
+ if (TokPrec < ExprPrec)
+ return LHS;
+
+ // Okay, we know this is a binop.
+ int BinOp = lexer.getCurToken();
+ lexer.consume(Token(BinOp));
+ // The location is queried once the operator has been consumed.
+ auto loc = lexer.getLastLocation();
+
+ // Parse the primary expression after the binary operator.
+ auto RHS = ParsePrimary();
+ if (!RHS)
+ return parseError<ExprAST>("expression", "to complete binary operator");
+
+ // If BinOp binds less tightly with RHS than the operator after RHS, let
+ // the pending operator take RHS as its LHS.
+ int NextPrec = GetTokPrecedence();
+ if (TokPrec < NextPrec) {
+ // Only operators binding strictly tighter than BinOp are folded into RHS
+ // by the recursive call.
+ RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
+ if (!RHS)
+ return nullptr;
+ }
+
+ // Merge LHS/RHS.
+ LHS = llvm::make_unique<BinaryExprAST>(std::move(loc), BinOp,
+ std::move(LHS), std::move(RHS));
+ }
+ }
+
+  /// Parse a full expression: a primary followed by any number of binary
+  /// operators and their operands. Returns nullptr on error.
+  /// expression::= primary binoprhs
+  std::unique_ptr<ExprAST> ParseExpression() {
+    if (auto primary = ParsePrimary())
+      return ParseBinOpRHS(0, std::move(primary));
+    return nullptr;
+  }
+
+  /// Parse the shape specification of a variable.
+  /// type ::= < shape_list >
+  /// shape_list ::= num | num , shape_list
+  ///
+  /// Returns the parsed type, or nullptr when '<' or '>' is missing.
+  std::unique_ptr<VarType> ParseType() {
+    if (lexer.getCurToken() != '<')
+      return parseError<VarType>("<", "to begin type");
+    lexer.getNextToken(); // eat <
+
+    auto parsedType = llvm::make_unique<VarType>();
+
+    // Accumulate the dimensions, stepping over the comma after each of them
+    // when there is one.
+    while (lexer.getCurToken() == tok_number) {
+      parsedType->shape.push_back(lexer.getValue());
+      lexer.getNextToken();
+      if (lexer.getCurToken() == ',')
+        lexer.getNextToken();
+    }
+
+    if (lexer.getCurToken() == '>') {
+      lexer.getNextToken(); // eat >
+      return parsedType;
+    }
+    return parseError<VarType>(">", "to end type");
+  }
+
+  /// Parse a variable declaration, it starts with a `var` keyword followed by
+  /// an identifier and an optional type (shape specification) before the
+  /// initializer.
+  /// decl ::= var identifier [ type ] = expr
+  ///
+  /// Returns nullptr after printing a diagnostic when the `var` keyword, the
+  /// identifier or the '=' is missing, or when the type fails to parse. The
+  /// result of parsing the initializer is stored as-is, including a null
+  /// pointer when it could not be parsed.
+  std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
+    if (lexer.getCurToken() != tok_var)
+      return parseError<VarDeclExprAST>("var", "to begin declaration");
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat var
+
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError<VarDeclExprAST>("identifier",
+                                        "after 'var' declaration");
+    std::string id = lexer.getId();
+    lexer.getNextToken(); // eat id
+
+    // Type is optional, it can be inferred: default to an empty shape.
+    auto type = llvm::make_unique<VarType>();
+    if (lexer.getCurToken() == '<') {
+      type = ParseType();
+      if (!type)
+        return nullptr;
+    }
+
+    // Check for the '=' before consuming it, like the other call sites of
+    // consume() in this parser do for the token they expect.
+    if (lexer.getCurToken() != '=')
+      return parseError<VarDeclExprAST>("=", "in variable declaration");
+    lexer.consume(Token('='));
+    auto expr = ParseExpression();
+    return llvm::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
+                                             std::move(*type), std::move(expr));
+  }
+
+  /// Parse a block: a list of expressions separated by semicolons and wrapped
+  /// in curly braces.
+  ///
+  /// block ::= { expression_list }
+  /// expression_list ::= block_expr ; expression_list
+  /// block_expr ::= decl | "return" | expr
+  ///
+  /// Returns the list of parsed expressions, or nullptr on a parse error.
+  std::unique_ptr<ExprASTList> ParseBlock() {
+    if (lexer.getCurToken() != '{')
+      return parseError<ExprASTList>("{", "to begin block");
+    lexer.consume(Token('{'));
+
+    // Empty expressions are ignored: this swallows any run of semicolons.
+    auto skipSemicolons = [this] {
+      while (lexer.getCurToken() == ';')
+        lexer.consume(Token(';'));
+    };
+
+    auto statements = llvm::make_unique<ExprASTList>();
+    skipSemicolons();
+
+    for (;;) {
+      const auto tok = lexer.getCurToken();
+      if (tok == '}' || tok == tok_eof)
+        break;
+
+      // Pick the sub-parser from the first token of the statement: variable
+      // declaration, return statement, or general expression.
+      std::unique_ptr<ExprAST> statement;
+      if (tok == tok_var)
+        statement = ParseDeclaration();
+      else if (tok == tok_return)
+        statement = ParseReturn();
+      else
+        statement = ParseExpression();
+      if (!statement)
+        return nullptr;
+      statements->push_back(std::move(statement));
+
+      // Ensure that elements are separated by a semicolon.
+      if (lexer.getCurToken() != ';')
+        return parseError<ExprASTList>(";", "after expression");
+      skipSemicolons();
+    }
+
+    if (lexer.getCurToken() != '}')
+      return parseError<ExprASTList>("}", "to close block");
+
+    lexer.consume(Token('}'));
+    return statements;
+  }
+
+  /// Parse a function prototype: the `def` keyword, the function name and the
+  /// parenthesized list of parameter names.
+  /// prototype ::= def id '(' decl_list ')'
+  /// decl_list ::= identifier | identifier, decl_list
+  ///
+  /// Returns nullptr after printing a diagnostic on malformed input.
+  std::unique_ptr<PrototypeAST> ParsePrototype() {
+    auto loc = lexer.getLastLocation();
+
+    // Every token is checked before being consumed so that malformed input is
+    // reported as a parse error.
+    if (lexer.getCurToken() != tok_def)
+      return parseError<PrototypeAST>("def", "in prototype");
+    lexer.consume(tok_def);
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError<PrototypeAST>("function name", "in prototype");
+
+    std::string FnName = lexer.getId();
+    lexer.consume(tok_identifier);
+
+    if (lexer.getCurToken() != '(')
+      return parseError<PrototypeAST>("(", "in prototype");
+    lexer.consume(Token('('));
+
+    std::vector<std::unique_ptr<VariableExprAST>> args;
+    if (lexer.getCurToken() != ')') {
+      // The first parameter has to be an identifier too; the ones that follow
+      // are checked after each comma below.
+      if (lexer.getCurToken() != tok_identifier)
+        return parseError<PrototypeAST>("identifier",
+                                        "in function parameter list");
+      do {
+        std::string name = lexer.getId();
+        auto argLoc = lexer.getLastLocation();
+        lexer.consume(tok_identifier);
+        auto decl = llvm::make_unique<VariableExprAST>(std::move(argLoc), name);
+        args.push_back(std::move(decl));
+        if (lexer.getCurToken() != ',')
+          break;
+        lexer.consume(Token(','));
+        if (lexer.getCurToken() != tok_identifier)
+          return parseError<PrototypeAST>(
+              "identifier", "after ',' in function parameter list");
+      } while (true);
+    }
+    if (lexer.getCurToken() != ')')
+      return parseError<PrototypeAST>(")", "to end function prototype");
+
+    // success.
+    lexer.consume(Token(')'));
+    return llvm::make_unique<PrototypeAST>(std::move(loc), FnName,
+                                           std::move(args));
+  }
+
+  /// Parse a function definition: a prototype introduced by the `def` keyword
+  /// followed by a block holding the body. Returns nullptr when either part
+  /// fails to parse.
+  ///
+  /// definition ::= prototype block
+  std::unique_ptr<FunctionAST> ParseDefinition() {
+    auto prototype = ParsePrototype();
+    if (!prototype)
+      return nullptr;
+
+    auto body = ParseBlock();
+    if (!body)
+      return nullptr;
+
+    return llvm::make_unique<FunctionAST>(std::move(prototype),
+                                          std::move(body));
+  }
+
+  /// Get the precedence of the pending binary operator token, or -1 when the
+  /// current token is not a binary operator. Higher values bind tighter.
+  int GetTokPrecedence() {
+    const auto tok = lexer.getCurToken();
+    if (!isascii(tok))
+      return -1;
+
+    // 1 is lowest precedence.
+    switch (static_cast<char>(tok)) {
+    case '+':
+    case '-':
+      return 20;
+    case '*':
+      return 40;
+    default:
+      return -1;
+    }
+  }
+
+  /// Helper function to signal errors while parsing, it takes an argument
+  /// indicating the expected token and another argument giving more context.
+  /// Location is retrieved from the lexer to enrich the error message.
+  ///
+  /// The message is printed on llvm::errs(). The function always returns a
+  /// null std::unique_ptr<R>, so that callers can directly `return` it.
+  template <typename R, typename T, typename U = const char *>
+  std::unique_ptr<R> parseError(T &&expected, U &&context = "") {
+    auto curToken = lexer.getCurToken();
+    llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
+                 << lexer.getLastLocation().col << "): expected '" << expected
+                 << "' " << context << " but has Token " << curToken;
+    // Only print the character for tokens in the ASCII range: isprint() is
+    // only defined for EOF and values representable as unsigned char, and a
+    // token is not necessarily in that range (GetTokPrecedence() guards with
+    // isascii() for the same reason).
+    if (isascii(curToken) && isprint(curToken))
+      llvm::errs() << " '" << static_cast<char>(curToken) << "'";
+    llvm::errs() << "\n";
+    return nullptr;
+  }
+};
+
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_PARSER_H
diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp
new file mode 100644
index 00000000000..464a206f7f1
--- /dev/null
+++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp
@@ -0,0 +1,480 @@
+//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple IR generation targeting MLIR from a Module AST
+// for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/MLIRGen.h"
+#include "toy/AST.h"
+#include "toy/Dialect.h"
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/StandardOps/Ops.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/Support/raw_ostream.h"
+#include <numeric>
+
+using namespace toy;
+using llvm::cast;
+using llvm::dyn_cast;
+using llvm::isa;
+using llvm::make_unique;
+using llvm::ScopedHashTableScope;
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+namespace {
+
+/// Implementation of a simple MLIR emission from the Toy AST.
+///
+/// This will emit operations that are specific to the Toy language, preserving
+/// the semantics of the language and (hopefully) allow to perform accurate
+/// analysis and transformation based on these high level semantics.
+///
+/// Operations are created through the op classes declared in "toy/Dialect.h"
+/// (for instance AddOp, MulOp or ConstantOp) using the typed
+/// `create<OpTy>()` API of the builder, rather than by spelling out raw
+/// operation names.
+class MLIRGenImpl {
+public:
+ /// Create a generator emitting IR into `context`. The context is stored by
+ /// reference, so it has to outlive this object.
+ MLIRGenImpl(mlir::MLIRContext &context) : context(context) {}
+
+  /// Public API: convert the AST for a Toy module (source file) to an MLIR
+  /// Module. Returns nullptr when the codegen of a function fails or when the
+  /// resulting module does not verify.
+  std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
+    // Start from an empty MLIR module; functions are generated one at a time
+    // and appended to it.
+    theModule = make_unique<mlir::Module>(&context);
+
+    for (FunctionAST &funcAST : moduleAST) {
+      auto generated = mlirGen(funcAST);
+      if (!generated)
+        return nullptr;
+      // Ownership of the function is handed over to the module.
+      theModule->getFunctions().push_back(generated.release());
+    }
+
+    // FIXME: (in the next chapter...) without registering a dialect in MLIR,
+    // this won't do much, but it should at least check some structural
+    // properties.
+    const bool verified = !failed(theModule->verify());
+    if (!verified) {
+      context.emitError(mlir::UnknownLoc::get(&context),
+                        "Module verification error");
+      return nullptr;
+    }
+
+    return std::move(theModule);
+  }
+
+private:
+ /// In MLIR (like in LLVM) a "context" object holds the memory allocation and
+ /// the ownership of many internal structures of the IR, and provides a level
+ /// of "uniquing" across multiple modules (types for instance).
+ mlir::MLIRContext &context;
+
+ /// A "module" matches a source file: it contains a list of functions.
+ std::unique_ptr<mlir::Module> theModule;
+
+ /// The builder is a helper class to create IR inside a function. It is
+ /// re-initialized every time we enter a function and kept around as a
+ /// convenience for emitting individual operations.
+ /// The builder is stateful, in particular it keeps an "insertion point":
+ /// this is where the next operations will be introduced.
+ std::unique_ptr<mlir::FuncBuilder> builder;
+
+ /// The symbol table maps a variable name to a value in the current scope.
+ /// Entering a function creates a new scope, and the function arguments are
+ /// added to the mapping. When the processing of a function is terminated, the
+ /// scope is destroyed and the mappings created in this scope are dropped.
+ llvm::ScopedHashTable<StringRef, mlir::Value *> symbolTable;
+
+  /// Convert a Toy AST location (file, line, column) into the equivalent MLIR
+  /// location.
+  mlir::FileLineColLoc loc(Location loc) {
+    auto filename = mlir::UniquedFilename::get(*loc.file, &context);
+    return mlir::FileLineColLoc::get(filename, loc.line, loc.col, &context);
+  }
+
+  /// Record `value` under the name `var` in the symbol table. Returns true
+  /// when the name was not present yet; otherwise the table is left untouched
+  /// and false is returned.
+  bool declare(llvm::StringRef var, mlir::Value *value) {
+    const bool isNew = !symbolTable.count(var);
+    if (isNew)
+      symbolTable.insert(var, value);
+    return isNew;
+  }
+
+  /// Create an MLIR function matching the given Toy prototype: same name, one
+  /// argument per Toy parameter and no result. The caller takes ownership of
+  /// the returned function.
+  mlir::Function *mlirGen(PrototypeAST &proto) {
+    // Every argument uniformly gets the generic array type.
+    const auto numArgs = proto.getArgs().size();
+    llvm::SmallVector<mlir::Type, 4> argumentTypes(numArgs, getType(VarType{}));
+    // This is a generic function, the return type will be inferred later.
+    llvm::SmallVector<mlir::Type, 4> resultTypes;
+    auto signature =
+        mlir::FunctionType::get(argumentTypes, resultTypes, &context);
+
+    auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
+                                        signature, /* attrs = */ {});
+
+    // A function with arguments is marked as generic: it'll require type
+    // specialization for every call site.
+    if (function->getNumArguments())
+      function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
+
+    return function;
+  }
+
+  /// Emit a new MLIR function for the given Toy function and return it.
+  ///
+  /// The function arguments are declared in a fresh symbol table scope, the
+  /// body is emitted in the entry block, and a `toy.return` without operand is
+  /// appended when the body does not end with one. Returns nullptr on failure.
+  std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
+    // Create a scope in the symbol table to hold variable declarations.
+    ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
+
+    // Create an MLIR function for the given prototype.
+    std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
+    if (!function)
+      return nullptr;
+
+    // Let's start the body of the function now!
+    // In MLIR the entry block of the function is special: it must have the same
+    // argument list as the function itself.
+    function->addEntryBlock();
+
+    auto &entryBlock = function->front();
+    auto &protoArgs = funcAST.getProto()->getArgs();
+    // Declare all the function arguments in the symbol table.
+    for (const auto &name_value :
+         llvm::zip(protoArgs, entryBlock.getArguments())) {
+      declare(std::get<0>(name_value)->getName(), std::get<1>(name_value));
+    }
+
+    // Create a builder for the function, it will be used throughout the codegen
+    // to create operations in this function.
+    builder = llvm::make_unique<mlir::FuncBuilder>(function.get());
+
+    // Emit the body of the function.
+    if (!mlirGen(*funcAST.getBody()))
+      return nullptr;
+
+    // Implicitly return void if no return statement was emitted.
+    // FIXME: we may fix the parser instead to always return the last expression
+    // (this would possibly help the REPL case later)
+    // The block is empty when the Toy function has an empty body: check for it
+    // first, as there is no last operation to inspect in this case.
+    auto &lastBlock = function->getBlocks().back();
+    if (lastBlock.empty() ||
+        lastBlock.back().getName().getStringRef() != "toy.return") {
+      ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
+      if (!mlirGen(fakeRet))
+        return nullptr;
+    }
+
+    return function;
+  }
+
+  /// Emit a binary operation and return the value it produces, or nullptr on
+  /// error.
+  mlir::Value *mlirGen(BinaryExprAST &binop) {
+    // Both operands are emitted before the operation itself, left hand side
+    // first. For example with the expression `a + foo(a)`:
+    //  1) Visiting the LHS returns the value already recorded for `a` in the
+    //     symbol table, nothing is emitted. If the variable is unknown, an
+    //     error has been emitted and nullptr is returned.
+    //  2) Visiting the RHS (recursively) emits a call to `foo` and returns
+    //     its result. On error we get a nullptr and propagate it.
+    mlir::Value *lhs = mlirGen(*binop.getLHS());
+    if (!lhs)
+      return nullptr;
+    mlir::Value *rhs = mlirGen(*binop.getRHS());
+    if (!rhs)
+      return nullptr;
+    auto location = loc(binop.loc());
+
+    // Select the operation from the binary operator. At the moment we only
+    // support '+' and '*'.
+    const auto op = binop.getOp();
+    if (op == '+')
+      return builder->create<AddOp>(location, lhs, rhs).getResult();
+    if (op == '*')
+      return builder->create<MulOp>(location, lhs, rhs).getResult();
+
+    context.emitError(location, Twine("Error: invalid binary operator '") +
+                                    Twine(binop.getOp()) + "'");
+    return nullptr;
+  }
+
+  // Resolve a reference to a variable in an expression. The variable must
+  // have been declared before and thus be present in the symbol table,
+  // otherwise an error is emitted and nullptr is returned.
+  mlir::Value *mlirGen(VariableExprAST &expr) {
+    if (!symbolTable.count(expr.getName())) {
+      context.emitError(loc(expr.loc()), Twine("Error: unknown variable '") +
+                                             expr.getName() + "'");
+      return nullptr;
+    }
+    return symbolTable.lookup(expr.getName());
+  }
+
+  // Emit the operation for a return statement, with the returned value as
+  // operand when there is one. Return true on success.
+  bool mlirGen(ReturnExprAST &ret) {
+    auto location = loc(ret.loc());
+    // `return` takes an optional expression: emit it first when present.
+    if (ret.getExpr().hasValue()) {
+      auto *returned = mlirGen(*ret.getExpr().getValue());
+      if (!returned)
+        return false;
+      builder->create<ReturnOp>(location, returned);
+      return true;
+    }
+    builder->create<ReturnOp>(location);
+    return true;
+  }
+
+ // Emit a literal/constant array. It will be emitted as a flattened array of
+ // data in an Attribute attached to a `toy.constant` operation.
+ // See documentation on [Attributes](LangRef.md#attributes) for more details.
+ // Here is an excerpt:
+ //
+ // Attributes are the mechanism for specifying constant data in MLIR in
+ // places where a variable is never allowed [...]. They consist of a name
+ // and a [concrete attribute value](#attribute-values). It is possible to
+ // attach attributes to operations, functions, and function arguments. The
+ // set of expected attributes, their structure, and their interpretation
+ // are all contextually dependent on what they are attached to.
+ //
+ // Example, the source level statement:
+ // var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+ // will be converted to:
+ // %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
+ // [[1.000000e+00, 2.000000e+00, 3.000000e+00],
+ // [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64>
+ // NOTE(review): the result type printed in the example above may not match
+ // what ConstantOp produces in this chapter -- confirm against the op.
+ //
+ // Returns the result value of the created ConstantOp.
+ //
+ mlir::Value *mlirGen(LiteralExprAST &lit) {
+ auto location = loc(lit.loc());
+ // The attribute is a vector with an attribute per element (number) in the
+ // array, see `collectData()` below for more details.
+ std::vector<mlir::Attribute> data;
+ // Reserve room for every element: the product of all the dimensions.
+ data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
+ std::multiplies<int>()));
+ collectData(lit, data);
+
+ // FIXME: using a tensor type is a HACK here.
+ // Can we do differently without registering a dialect? Using a string blob?
+ mlir::Type elementType = mlir::FloatType::getF64(&context);
+ auto dataType = builder->getTensorType(lit.getDims(), elementType);
+
+ // This is the attribute that actually holds the list of values for this
+ // array literal.
+ auto dataAttribute = builder->getDenseElementsAttr(dataType, data)
+ .cast<mlir::DenseElementsAttr>();
+
+ // Build the MLIR op `toy.constant`, only boilerplate below.
+ return builder->create<ConstantOp>(location, lit.getDims(), dataAttribute)
+ .getResult();
+ }
+
+  // Recursively flatten the content of an array literal into `data`, one
+  // attribute per number and in source order. For example with this array:
+  //   [[1, 2], [3, 4]]
+  // we will generate:
+  //   [ 1, 2, 3, 4 ]
+  // Each number is wrapped in an `mlir::FloatAttr` of type f64. `expr` is
+  // expected to be either a literal or a number expression.
+  void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) {
+    // Nested literal: visit every element in order.
+    if (isa<LiteralExprAST>(expr)) {
+      for (auto &element : cast<LiteralExprAST>(expr).getValues())
+        collectData(*element, data);
+      return;
+    }
+
+    // Leaf: a single number.
+    assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
+    mlir::Type f64Type = mlir::FloatType::getF64(&context);
+    data.push_back(mlir::FloatAttr::getChecked(
+        f64Type, cast<NumberExprAST>(expr).getValue(), loc(expr.loc())));
+  }
+
+  // Emit a call expression. It emits specific operations for the `transpose`
+  // builtin. Other identifiers are assumed to be user-defined functions.
+  // Returns the result value of the emitted operation, or nullptr when the
+  // codegen of an argument fails or when `transpose` is not given exactly one
+  // argument.
+  mlir::Value *mlirGen(CallExprAST &call) {
+    auto location = loc(call.loc());
+    std::string callee = call.getCallee();
+    if (callee == "transpose") {
+      if (call.getArgs().size() != 1) {
+        context.emitError(
+            location, Twine("MLIR codegen encountered an error: toy.transpose "
+                            "does not accept multiple arguments"));
+        return nullptr;
+      }
+      mlir::Value *arg = mlirGen(*call.getArgs()[0]);
+      // Propagate the failure instead of building the operation on a null
+      // operand, as done for the operands of user-defined calls below.
+      if (!arg)
+        return nullptr;
+      return builder->create<TransposeOp>(location, arg).getResult();
+    }
+
+    // Codegen the operands first
+    SmallVector<mlir::Value *, 4> operands;
+    for (auto &expr : call.getArgs()) {
+      auto *arg = mlirGen(*expr);
+      if (!arg)
+        return nullptr;
+      operands.push_back(arg);
+    }
+    // Calls to user-defined function are mapped to a custom call that takes
+    // the callee name as an attribute.
+    return builder->create<GenericCallOp>(location, call.getCallee(), operands)
+        .getResult();
+  }
+
+  // Emit the operation for the `print(x)` builtin: the argument is emitted
+  // first, then a PrintOp taking it as operand. Return false on failure.
+  bool mlirGen(PrintExprAST &call) {
+    if (auto *arg = mlirGen(*call.getArg())) {
+      builder->create<PrintOp>(loc(call.loc()), arg);
+      return true;
+    }
+    return false;
+  }
+
+  // Emit a constant holding a single f64 number
+  // (FIXME: semantic? broadcast?)
+  mlir::Value *mlirGen(NumberExprAST &num) {
+    auto location = loc(num.loc());
+    mlir::Type f64Type = mlir::FloatType::getF64(&context);
+    // The number is wrapped in an attribute attached to the constant.
+    auto valueAttr =
+        mlir::FloatAttr::getChecked(f64Type, num.getValue(), loc(num.loc()));
+    auto constant = builder->create<ConstantOp>(location, valueAttr);
+    return constant.getResult();
+  }
+
+  // Dispatch codegen to the overload matching the dynamic kind of the
+  // expression. An error is emitted and nullptr returned for the kinds that
+  // are not handled here.
+  mlir::Value *mlirGen(ExprAST &expr) {
+    const auto kind = expr.getKind();
+    switch (kind) {
+    case toy::ExprAST::Expr_Num:
+      return mlirGen(cast<NumberExprAST>(expr));
+    case toy::ExprAST::Expr_Literal:
+      return mlirGen(cast<LiteralExprAST>(expr));
+    case toy::ExprAST::Expr_Var:
+      return mlirGen(cast<VariableExprAST>(expr));
+    case toy::ExprAST::Expr_BinOp:
+      return mlirGen(cast<BinaryExprAST>(expr));
+    case toy::ExprAST::Expr_Call:
+      return mlirGen(cast<CallExprAST>(expr));
+    default:
+      context.emitError(
+          loc(expr.loc()),
+          Twine("MLIR codegen encountered an unhandled expr kind '") +
+              Twine(kind) + "'");
+      return nullptr;
+    }
+  }
+
+  // Handle a variable declaration, we'll codegen the expression that forms the
+  // initializer and record the value in the symbol table before returning it.
+  // Future expressions will be able to reference this variable through symbol
+  // table lookup.
+  // Returns nullptr after emitting an error when the initializer is missing
+  // or when the name is already present in the symbol table; also returns
+  // nullptr when the codegen of the initializer fails.
+  mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
+    auto location = loc(vardecl.loc());
+    auto init = vardecl.getInitVal();
+    if (!init) {
+      context.emitError(loc(vardecl.loc()),
+                        "Missing initializer in variable declaration");
+      return nullptr;
+    }
+
+    mlir::Value *value = mlirGen(*init);
+    if (!value)
+      return nullptr;
+    // We have the initializer value, but in case the variable was declared
+    // with specific shape, we emit a "reshape" operation. It will get
+    // optimized out later as needed.
+    if (!vardecl.getType().shape.empty()) {
+      value = builder
+                  ->create<ReshapeOp>(
+                      location, value,
+                      getType(vardecl.getType()).cast<ToyArrayType>())
+                  .getResult();
+    }
+
+    // Register the value in the symbol table. declare() refuses a name that
+    // is already known: report it, otherwise later references would silently
+    // keep resolving to the previous value.
+    if (!declare(vardecl.getName(), value)) {
+      context.emitError(loc(vardecl.loc()),
+                        Twine("Error: redeclaration of variable '") +
+                            vardecl.getName() + "'");
+      return nullptr;
+    }
+    return value;
+  }
+
+  /// Codegen a list of expressions in a new symbol table scope, return false
+  /// if one of them hit an error.
+  /// Codegen stops at the first return statement: the expressions after it in
+  /// the list are not emitted.
+  bool mlirGen(ExprASTList &blockAST) {
+    ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
+    for (auto &expr : blockAST) {
+      // Specific handling for variable declarations, return statement, and
+      // print. These can only appear in block list and not in nested
+      // expressions.
+      if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
+        if (!mlirGen(*vardecl))
+          return false;
+        continue;
+      }
+      if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) {
+        if (!mlirGen(*ret))
+          return false;
+        return true;
+      }
+      if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
+        if (!mlirGen(*print))
+          return false;
+        // Unlike a return, a print does not end the block: keep going with
+        // the expressions that follow it.
+        continue;
+      }
+      // Generic expression dispatch codegen.
+      if (!mlirGen(*expr))
+        return false;
+    }
+    return true;
+  }
+
+  /// Build the Toy array type for a list of shape dimensions. `shape` can be
+  /// any container exposing begin()/end(); an empty list is accepted as well.
+  /// Types are `array` followed by an optional dimension list, example:
+  /// array<2, 2>
+  /// They are wrapped in a `toy` dialect (see next chapter) and get printed:
+  ///   !toy<"array<2, 2>">
+  template <typename T> mlir::Type getType(T shape) {
+    // Copy the dimensions into a vector of int64_t for ToyArrayType::get().
+    SmallVector<int64_t, 8> dims(shape.begin(), shape.end());
+    return ToyArrayType::get(&context, dims);
+  }
+
+  /// Build the MLIR type for a Toy AST variable type: only its shape is used,
+  /// it is forwarded to the generic getType(T) overload.
+  mlir::Type getType(const VarType &type) {
+    return getType(type.shape);
+  }
+};
+
+} // namespace
+
+namespace toy {
+
+// The public API for codegen: run a one-shot generator over the module AST.
+// Returns the generated module, or nullptr on failure.
+std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
+                                      ModuleAST &moduleAST) {
+  MLIRGenImpl generator(context);
+  return generator.mlirGen(moduleAST);
+}
+
+} // namespace toy
diff --git a/mlir/examples/toy/Ch3/mlir/ToyDialect.cpp b/mlir/examples/toy/Ch3/mlir/ToyDialect.cpp
new file mode 100644
index 00000000000..7910842a51b
--- /dev/null
+++ b/mlir/examples/toy/Ch3/mlir/ToyDialect.cpp
@@ -0,0 +1,393 @@
+//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the dialect for the Toy IR: custom type parsing and
+// operation verification.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/raw_ostream.h"
+
+using llvm::ArrayRef;
+using llvm::raw_ostream;
+using llvm::raw_string_ostream;
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+namespace toy {
+namespace detail {
+
+/// Storage class holding the implementation of ToyArrayType.
+/// It is intended to be uniqued based on its content and owned by the context.
+struct ToyArrayTypeStorage : public mlir::TypeStorage {
+  /// Key used to unique this type in the context. It only contains the shape;
+  /// a more complex type would have more entries in the tuple. The elements of
+  /// the tuple usually match 1-1 the arguments of the public `get()` method of
+  /// the facade.
+  using KeyTy = std::tuple<ArrayRef<int64_t>>;
+
+  /// Hash a key, i.e. hash the shape it carries.
+  static unsigned hashKey(const KeyTy &key) {
+    const ArrayRef<int64_t> &keyShape = std::get<0>(key);
+    return llvm::hash_combine(keyShape);
+  }
+
+  /// Invoked when the hash of a key hits an existing storage: compare the
+  /// shapes themselves to confirm this is the right type.
+  bool operator==(const KeyTy &key) const {
+    return std::get<0>(key) == getShape();
+  }
+
+  /// Factory method creating a new storage instance. It is only invoked after
+  /// looking up the type in the context using the key and not finding it.
+  static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
+                                        const KeyTy &key) {
+    // The key only references the shape: copy it into the allocator so that
+    // the storage refers to memory owned by the context.
+    ArrayRef<int64_t> ownedShape = allocator.copyInto(std::get<0>(key));
+
+    // Get raw memory for the storage itself, then initialize it in place.
+    auto *rawStorage = allocator.allocate<ToyArrayTypeStorage>();
+    return new (rawStorage) ToyArrayTypeStorage(ownedShape);
+  }
+
+  /// Return the shape held by this storage.
+  ArrayRef<int64_t> getShape() const { return shape; }
+
+private:
+  /// Shape of the array type.
+  ArrayRef<int64_t> shape;
+
+  /// Constructor is only invoked from the `construct()` method above.
+  ToyArrayTypeStorage(ArrayRef<int64_t> shape) : shape(shape) {}
+};
+
+} // namespace detail
+
+/// Element type of a Toy array: always the F64 float type of the context this
+/// type lives in.
+mlir::Type ToyArrayType::getElementType() {
+  mlir::MLIRContext *ctx = getContext();
+  return mlir::FloatType::getF64(ctx);
+}
+
+/// Get (or create) the Toy array type with the given shape in `context`. The
+/// shape is the only component of the key forwarded to the base class.
+ToyArrayType ToyArrayType::get(mlir::MLIRContext *context,
+                               ArrayRef<int64_t> shape) {
+  const auto kind = ToyTypeKind::TOY_ARRAY;
+  return Base::get(context, kind, shape);
+}
+
+/// Return the shape recorded in the storage backing this type.
+ArrayRef<int64_t> ToyArrayType::getShape() {
+  auto *storage = getImpl();
+  return storage->getShape();
+}
+
+/// Dialect creation, the instance will be owned by the context. This is the
+/// point of registration of custom types and operations for the dialect.
+/// The dialect is created under the "toy" name passed to the base class.
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ // Every operation class implemented in this file is registered here.
+ addOperations<ConstantOp, GenericCallOp, PrintOp, TransposeOp, ReshapeOp,
+ MulOp, AddOp, ReturnOp>();
+ // ToyArrayType is the only custom type of the dialect.
+ addTypes<ToyArrayType>();
+}
+
+/// Parse a type registered to this dialect, we expect only Toy arrays.
+///
+/// Accepted spellings are `array` (generic array, no shape) and
+/// `array<d0, d1, ...>` where every dimension is a decimal number and the
+/// dimensions are separated by ", ". On malformed input an error is emitted
+/// at `loc` and a null type is returned.
+mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const {
+  // Sanity check: we only support array or array<...>
+  if (!tyData.startswith("array")) {
+    getContext()->emitError(loc, "Invalid Toy type '" + tyData +
+                                     "', array expected");
+    return nullptr;
+  }
+  // Drop the "array" prefix from the type name, we expect either an empty
+  // string or just the shape.
+  tyData = tyData.drop_front(StringRef("array").size());
+  // This is the generic array case without shape, early return it.
+  if (tyData.empty())
+    return ToyArrayType::get(getContext());
+
+  // Use a regex to validate the syntax of the shape (for efficiency we should
+  // store this regex in the dialect itself).
+  auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
+  if (!shapeRegex.match(tyData)) {
+    getContext()->emitError(loc, "Invalid toy array shape '" + tyData + "'");
+    return nullptr;
+  }
+
+  // The captures of the regex cannot be used to collect the dimensions: a
+  // repeated group only retains its last iteration, so `<2, 3, 4>` would
+  // yield 2 and 4 only. The syntax has been validated above, so strip the
+  // angle brackets and split the dimension list explicitly instead.
+  SmallVector<StringRef, 4> dimStrs;
+  tyData.drop_front().drop_back().split(dimStrs, ", ");
+
+  SmallVector<int64_t, 4> shape;
+  for (auto dimStr : dimStrs) {
+    // Convert the dimension to an integer; this also fails when the value
+    // does not fit the int64_t used to store shapes.
+    int64_t dim;
+    if (dimStr.getAsInteger(/* Radix = */ 10, dim)) {
+      getContext()->emitError(
+          loc, "Couldn't parse dimension as integer, matched: " + dimStr);
+      return mlir::Type();
+    }
+    shape.push_back(dim);
+  }
+  // Finally we collected all the dimensions in the shape,
+  // create the array type.
+  return ToyArrayType::get(getContext(), shape);
+}
+
+/// Print a Toy array type, for example `array<2, 3, 4>`. A type without any
+/// dimension is printed as a bare `array`; anything that is not a Toy array
+/// is printed as `unknown toy type`.
+void ToyDialect::printType(mlir::Type type, raw_ostream &os) const {
+  auto arrayTy = type.dyn_cast<ToyArrayType>();
+  if (!arrayTy) {
+    os << "unknown toy type";
+    return;
+  }
+
+  os << "array";
+  ArrayRef<int64_t> dims = arrayTy.getShape();
+  if (dims.empty())
+    return;
+
+  // Emit the comma separated dimension list in between angle brackets.
+  os << "<";
+  mlir::interleaveComma(dims, os);
+  os << ">";
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//////////////////// Custom Operations for the Dialect /////////////////////////
+////////////////////////////////////////////////////////////////////////////////
+
+/// Helper to verify that the result of an operation is a Toy array type.
+/// Emits an error on the operation and returns a failure otherwise.
+template <typename T> static mlir::LogicalResult verifyToyReturnArray(T *op) {
+  auto resultType = op->getResult()->getType();
+  if (!resultType.template isa<ToyArrayType>()) {
+    std::string msg;
+    raw_string_ostream os(msg);
+    // This checks the value produced by the operation, not one of its
+    // operands: the diagnostic refers to the result accordingly.
+    os << "expects a Toy Array for its result, got " << resultType;
+    return op->emitOpError(os.str());
+  }
+  return mlir::success();
+}
+
+/// Helper to verify that the two operands of a binary operation are Toy
+/// arrays. The first offending operand is reported on the operation, naming
+/// the side (LHS or RHS) it is on and its actual type.
+template <typename T> static mlir::LogicalResult verifyToyBinOperands(T *op) {
+  const char *const sideNames[] = {"LHS", "RHS"};
+  for (unsigned idx = 0; idx < 2; ++idx) {
+    auto operandType = op->getOperand(idx)->getType();
+    if (operandType.template isa<ToyArrayType>())
+      continue;
+    std::string msg;
+    raw_string_ostream os(msg);
+    os << "expects a Toy Array for its " << sideNames[idx] << ", got "
+       << operandType;
+    return op->emitOpError(os.str());
+  }
+  return mlir::success();
+}
+
+/// Build a constant operation from a shape and the dense data to hold.
+/// The builder is passed as an argument, so is the state that this method is
+/// expected to fill in order to build the operation: the result is a Toy array
+/// of the given shape and the data is attached as the `value` attribute.
+void ConstantOp::build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+                       ArrayRef<int64_t> shape, mlir::DenseElementsAttr value) {
+  auto resultType = ToyArrayType::get(builder->getContext(), shape);
+  state->types.push_back(resultType);
+  state->attributes.push_back(builder->getNamedAttr("value", value));
+}
+
+/// Build a constant operation from a single floating point value.
+/// The builder is passed as an argument, so is the state that this method is
+/// expected to fill in order to build the operation. The value is wrapped in
+/// a one-element dense attribute and forwarded to the shape-based overload.
+void ConstantOp::build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+                       mlir::FloatAttr value) {
+  // Tensor of a single f64 element used as the type of the dense attribute.
+  mlir::Type f64Type = mlir::FloatType::getF64(builder->getContext());
+  auto tensorType = builder->getTensorType({1}, f64Type);
+  auto denseValue = builder->getDenseElementsAttr(tensorType, {value})
+                        .cast<mlir::DenseElementsAttr>();
+
+  // Forward to the other build factory with the matching shape.
+  ConstantOp::build(builder, state, {1}, denseValue);
+}
+
+/// Verifier for constant operation.
+/// Checks that the result is a Toy array, that a dense `value` attribute with
+/// a tensor type is attached, and, when the result type is not generic, that
+/// its rank and dimensions agree with the ones of the attribute.
+mlir::LogicalResult ConstantOp::verify() {
+  // Ensure that the return type is a Toy array
+  if (failed(verifyToyReturnArray(this)))
+    return mlir::failure();
+
+  // We expect the constant itself to be stored as an attribute. Check that
+  // the attribute is present before inspecting its kind, so that a missing
+  // attribute is diagnosed instead of casting a null attribute.
+  auto valueAttr = getAttr("value");
+  if (!valueAttr || !valueAttr.isa<mlir::DenseElementsAttr>()) {
+    return emitOpError(
+        "missing valid `value` DenseElementsAttribute on toy.constant()");
+  }
+  auto dataAttr = valueAttr.cast<mlir::DenseElementsAttr>();
+  auto attrType = dataAttr.getType().dyn_cast<mlir::TensorType>();
+  if (!attrType) {
+    return emitOpError(
+        "missing valid `value` DenseElementsAttribute on toy.constant()");
+  }
+
+  // If the return type of the constant is not a generic array, the shape must
+  // match the shape of the attribute holding the data.
+  auto resultType = getResult()->getType().cast<ToyArrayType>();
+  if (!resultType.isGeneric()) {
+    if (attrType.getRank() != resultType.getRank()) {
+      return emitOpError("The rank of the toy.constant return type must match "
+                         "the one of the attached value attribute: " +
+                         Twine(attrType.getRank()) +
+                         " != " + Twine(resultType.getRank()));
+    }
+    // Ranks are equal at this point, compare the dimensions one by one.
+    for (int dim = 0; dim < attrType.getRank(); ++dim) {
+      if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
+        return emitOpError(
+            "Shape mismatch between toy.constant return type and its "
+            "attribute at dimension " +
+            Twine(dim) + ": " + Twine(attrType.getShape()[dim]) +
+            " != " + Twine(resultType.getShape()[dim]));
+      }
+    }
+  }
+  return mlir::success();
+}
+
+/// Build a generic call: the arguments become the operands and the name of
+/// the callee is recorded in the `callee` string attribute.
+void GenericCallOp::build(mlir::FuncBuilder *builder,
+                          mlir::OperationState *state, StringRef callee,
+                          ArrayRef<mlir::Value *> arguments) {
+  // Generic call always returns a generic ToyArray initially
+  auto genericArrayType = ToyArrayType::get(builder->getContext());
+  state->types.push_back(genericArrayType);
+  state->operands.assign(arguments.begin(), arguments.end());
+  state->attributes.push_back(
+      builder->getNamedAttr("callee", builder->getStringAttr(callee)));
+}
+
+/// Verifier for the generic call: every operand has to be a Toy array. The
+/// first operand that is not one is reported with its index and type.
+mlir::LogicalResult GenericCallOp::verify() {
+  for (int opId = 0, num = getNumOperands(); opId < num; ++opId) {
+    auto operandType = getOperand(opId)->getType();
+    if (operandType.isa<ToyArrayType>())
+      continue;
+    std::string msg;
+    raw_string_ostream os(msg);
+    os << "expects a Toy Array for its " << opId << " operand, got "
+       << operandType;
+    return emitOpError(os.str());
+  }
+  return mlir::success();
+}
+
+/// Return the name of the callee, as stored in the `callee` string attribute.
+StringRef GenericCallOp::getCalleeName() {
+  auto calleeAttr = getAttr("callee").cast<mlir::StringAttr>();
+  return calleeAttr.getValue();
+}
+
+/// Helper to verify that the single operand of an operation is a Toy array.
+/// Emits an error on the operation and returns a failure otherwise.
+template <typename T> static mlir::LogicalResult verifyToySingleOperand(T *op) {
+  auto operandType = op->getOperand()->getType();
+  if (!operandType.template isa<ToyArrayType>()) {
+    std::string msg;
+    raw_string_ostream os(msg);
+    os << "expects a Toy Array for its argument, got " << operandType;
+    // Go through os.str() so that the stream is flushed into `msg` before the
+    // message is consumed, like the other verification helpers do.
+    return op->emitOpError(os.str());
+  }
+  return mlir::success();
+}
+
+/// Build a return operation. It does not produce any value and takes an
+/// optional single argument: a null `value` builds a return without operand.
+void ReturnOp::build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+                     mlir::Value *value) {
+  if (!value)
+    return;
+  state->operands.push_back(value);
+}
+
+/// Verifier for the return operation: at most one operand is accepted, and
+/// when there is one it has to be a Toy array.
+mlir::LogicalResult ReturnOp::verify() {
+  const auto operandCount = getNumOperands();
+  if (operandCount > 1) {
+    std::string msg;
+    raw_string_ostream os(msg);
+    os << "expects zero or one operand, got " << operandCount;
+    return emitOpError(os.str());
+  }
+  // Nothing left to check for a return without operand.
+  if (!hasOperand())
+    return mlir::success();
+  if (failed(verifyToySingleOperand(this)))
+    return mlir::failure();
+  return mlir::success();
+}
+
+/// Build a print operation: no result, and `value` as its single operand.
+void PrintOp::build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+                    mlir::Value *value) {
+  auto &operands = state->operands;
+  operands.push_back(value);
+}
+
+/// Verifier for the print operation: the operand has to be a Toy array.
+mlir::LogicalResult PrintOp::verify() {
+  const bool operandIsInvalid = failed(verifyToySingleOperand(this));
+  return operandIsInvalid ? mlir::failure() : mlir::success();
+}
+
+/// Build a transpose operation on `value`. The result is typed as a generic
+/// Toy array (no shape attached).
+void TransposeOp::build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+                        mlir::Value *value) {
+  auto genericArrayType = ToyArrayType::get(builder->getContext());
+  state->types.push_back(genericArrayType);
+  state->operands.push_back(value);
+}
+
+/// Verifier for the transpose operation: the operand has to be a Toy array.
+mlir::LogicalResult TransposeOp::verify() {
+  if (!failed(verifyToySingleOperand(this)))
+    return mlir::success();
+  return mlir::failure();
+}
+
+/// Build a reshape operation turning `value` into a value of type
+/// `reshapedType`.
+void ReshapeOp::build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+                      mlir::Value *value, ToyArrayType reshapedType) {
+  auto &resultTypes = state->types;
+  resultTypes.push_back(reshapedType);
+  auto &operands = state->operands;
+  operands.push_back(value);
+}
+
+/// Verifier for the reshape operation: the operand has to be a Toy array and
+/// the result a Toy array that is not generic.
+mlir::LogicalResult ReshapeOp::verify() {
+  if (failed(verifyToySingleOperand(this)))
+    return mlir::failure();
+
+  auto resultType = getResult()->getType().dyn_cast<ToyArrayType>();
+  if (!resultType)
+    return emitOpError("toy.reshape is expected to produce a Toy array");
+  if (!resultType.isGeneric())
+    return mlir::success();
+  return emitOpError("toy.reshape is expected to produce a shaped Toy array, "
+                     "got a generic one.");
+}
+
+/// Build an addition of `lhs` and `rhs`. The result is typed as a generic Toy
+/// array (no shape attached).
+void AddOp::build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+                  mlir::Value *lhs, mlir::Value *rhs) {
+  auto genericArrayType = ToyArrayType::get(builder->getContext());
+  state->types.push_back(genericArrayType);
+  auto &operands = state->operands;
+  operands.push_back(lhs);
+  operands.push_back(rhs);
+}
+
+/// Verifier for the addition: both operands have to be Toy arrays.
+mlir::LogicalResult AddOp::verify() {
+  const bool operandsAreInvalid = failed(verifyToyBinOperands(this));
+  return operandsAreInvalid ? mlir::failure() : mlir::success();
+}
+
+/// Build a multiplication of `lhs` and `rhs`. The result is typed as a generic
+/// Toy array (no shape attached).
+void MulOp::build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+                  mlir::Value *lhs, mlir::Value *rhs) {
+  auto genericArrayType = ToyArrayType::get(builder->getContext());
+  state->types.push_back(genericArrayType);
+  auto &operands = state->operands;
+  operands.push_back(lhs);
+  operands.push_back(rhs);
+}
+
+/// Verifier for the multiplication: both operands have to be Toy arrays.
+mlir::LogicalResult MulOp::verify() {
+  if (!failed(verifyToyBinOperands(this)))
+    return mlir::success();
+  return mlir::failure();
+}
+
+} // namespace toy
diff --git a/mlir/examples/toy/Ch3/parser/AST.cpp b/mlir/examples/toy/Ch3/parser/AST.cpp
new file mode 100644
index 00000000000..869f2ef2013
--- /dev/null
+++ b/mlir/examples/toy/Ch3/parser/AST.cpp
@@ -0,0 +1,263 @@
+//===- AST.cpp - Helper for printing out the Toy AST ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the AST dump for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/AST.h"
+
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace toy;
+
+namespace {
+
+// RAII helper for the indentation while traversing the AST: the referenced
+// level is incremented on construction and decremented on destruction.
+struct Indent {
+  Indent(int &depth) : level(depth) { ++level; }
+  ~Indent() { --level; }
+  int &level;
+};
+
+/// Helper class implementing the AST tree traversal and printing the nodes
+/// along the way. The only data member is the current indentation level.
+class ASTDumper {
+public:
+  /// Entry point: dump a whole module.
+  void dump(ModuleAST *Node);
+
+private:
+  // One overload for each kind of entity that can be printed.
+  void dump(VarType &type);
+  void dump(VarDeclExprAST *varDecl);
+  void dump(ExprAST *expr);
+  void dump(ExprASTList *exprList);
+  void dump(NumberExprAST *num);
+  void dump(LiteralExprAST *Node);
+  void dump(VariableExprAST *Node);
+  void dump(ReturnExprAST *Node);
+  void dump(BinaryExprAST *Node);
+  void dump(CallExprAST *Node);
+  void dump(PrintExprAST *Node);
+  void dump(PrototypeAST *Node);
+  void dump(FunctionAST *Node);
+
+  // Actually print spaces matching the current indentation level, one chunk
+  // for every level.
+  void indent() {
+    for (int remaining = curIndent; remaining > 0; --remaining)
+      llvm::errs() << " ";
+  }
+
+  // Current indentation level.
+  int curIndent = 0;
+};
+
+} // namespace
+
+/// Return a formatted string for the location of any node, in the form
+/// `@<file>:<line>:<col>`.
+template <typename T> static std::string loc(T *Node) {
+  const auto &where = Node->loc();
+  // Build the whole string in a single Twine expression and materialize it.
+  return (llvm::Twine("@") + *where.file + ":" + llvm::Twine(where.line) +
+          ":" + llvm::Twine(where.col))
+      .str();
+}
+
+// Helper macro to bump the indentation level and print the leading spaces for
+// the current indentation. It expands to the declaration of an `Indent` object
+// named `level_` bound to `curIndent`, followed by a call to `indent()`: the
+// level stays bumped until the scope in which the macro is used ends.
+#define INDENT() \
+ Indent level_(curIndent); \
+ indent();
+
+/// Dispatch a generic expression to the overload for its actual subclass,
+/// which is identified using LLVM-style RTTI (dyn_cast).
+void ASTDumper::dump(ExprAST *expr) {
+#define dispatch(CLASS)                                                        \
+  if (auto *typed = llvm::dyn_cast<CLASS>(expr)) {                             \
+    dump(typed);                                                               \
+    return;                                                                    \
+  }
+  dispatch(VarDeclExprAST);
+  dispatch(LiteralExprAST);
+  dispatch(NumberExprAST);
+  dispatch(VariableExprAST);
+  dispatch(ReturnExprAST);
+  dispatch(BinaryExprAST);
+  dispatch(CallExprAST);
+  dispatch(PrintExprAST);
+  // None of the known subclasses matched, fall back to a generic message.
+  INDENT();
+  llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
+}
+
+/// Print a variable declaration: the variable name, the type and the
+/// location, then recurse into the initializer value.
+void ASTDumper::dump(VarDeclExprAST *varDecl) {
+  INDENT();
+  auto &out = llvm::errs();
+  out << "VarDecl " << varDecl->getName();
+  dump(varDecl->getType());
+  out << " " << loc(varDecl) << "\n";
+  dump(varDecl->getInitVal());
+}
+
+/// Print a "block", i.e. a list of expressions, surrounded by braces.
+void ASTDumper::dump(ExprASTList *exprList) {
+  INDENT();
+  auto &out = llvm::errs();
+  out << "Block {\n";
+  for (auto &entry : *exprList) {
+    dump(entry.get());
+  }
+  // Align the closing brace with the opening line.
+  indent();
+  out << "} // Block\n";
+}
+
+/// Print a literal number: the value followed by the location.
+void ASTDumper::dump(NumberExprAST *num) {
+  INDENT();
+  auto &out = llvm::errs();
+  out << num->getValue();
+  out << " " << loc(num) << "\n";
+}
+
+/// Helper to print a literal recursively. This handles nested arrays like:
+///   [ [ 1, 2 ], [ 3, 4 ] ]
+/// Such an array is printed with the dimensions spelled out at every level:
+///   <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
+void printLitHelper(ExprAST *lit_or_num) {
+  auto &out = llvm::errs();
+
+  // Inside a literal expression we can have either a number or another
+  // literal. A number is the base case of the recursion.
+  if (auto *number = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
+    out << number->getValue();
+    return;
+  }
+  auto *nested = llvm::cast<LiteralExprAST>(lit_or_num);
+
+  // Print the dimensions for this literal first.
+  out << "<";
+  bool firstDim = true;
+  for (auto dim : nested->getDims()) {
+    if (!firstDim)
+      out << ", ";
+    out << dim;
+    firstDim = false;
+  }
+  out << ">";
+
+  // Now print the content, recursing on every element of the list.
+  out << "[ ";
+  bool firstElt = true;
+  for (auto &elt : nested->getValues()) {
+    if (!firstElt)
+      out << ", ";
+    printLitHelper(elt.get());
+    firstElt = false;
+  }
+  out << "]";
+}
+
+/// Print a literal followed by its location; the content itself is printed
+/// by the recursive printLitHelper.
+void ASTDumper::dump(LiteralExprAST *Node) {
+  INDENT();
+  auto &out = llvm::errs();
+  out << "Literal: ";
+  printLitHelper(Node);
+  out << " " << loc(Node) << "\n";
+}
+
+/// Print a variable reference: its name followed by the location.
+void ASTDumper::dump(VariableExprAST *Node) {
+  INDENT();
+  auto &out = llvm::errs();
+  out << "var: " << Node->getName();
+  out << " " << loc(Node) << "\n";
+}
+
+/// Print a return statement followed by its (optional) argument; `(void)` is
+/// printed in place of the argument when there is none.
+void ASTDumper::dump(ReturnExprAST *Node) {
+  INDENT();
+  llvm::errs() << "Return\n";
+  if (!Node->getExpr().hasValue()) {
+    INDENT();
+    llvm::errs() << "(void)\n";
+    return;
+  }
+  dump(*Node->getExpr());
+}
+
+/// Print a binary operation: the operator and the location first, then the
+/// LHS and the RHS by recursing into them.
+void ASTDumper::dump(BinaryExprAST *Node) {
+  INDENT();
+  auto &out = llvm::errs();
+  out << "BinOp: " << Node->getOp();
+  out << " " << loc(Node) << "\n";
+  dump(Node->getLHS());
+  dump(Node->getRHS());
+}
+
+/// Print a call expression: the callee name and the location first, then the
+/// list of arguments by recursing into each of them.
+void ASTDumper::dump(CallExprAST *Node) {
+  INDENT();
+  auto &out = llvm::errs();
+  out << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
+  for (auto &argument : Node->getArgs()) {
+    dump(argument.get());
+  }
+  // Align the closing bracket with the opening line.
+  indent();
+  out << "]\n";
+}
+
+/// Print a builtin print call: the builtin name and the location first, then
+/// the argument.
+void ASTDumper::dump(PrintExprAST *Node) {
+  INDENT();
+  auto &out = llvm::errs();
+  out << "Print [ " << loc(Node) << "\n";
+  dump(Node->getArg());
+  // Align the closing bracket with the opening line.
+  indent();
+  out << "]\n";
+}
+
+/// Print a type: only the shape is printed, as a comma separated list in
+/// between '<' and '>'.
+void ASTDumper::dump(VarType &type) {
+  auto &out = llvm::errs();
+  out << "<";
+  bool firstDim = true;
+  for (auto dim : type.shape) {
+    if (!firstDim)
+      out << ", ";
+    out << dim;
+    firstDim = false;
+  }
+  out << ">";
+}
+
+/// Print a function prototype: the function name and the location first, then
+/// the list of parameter names on a second line.
+void ASTDumper::dump(PrototypeAST *Node) {
+  INDENT();
+  auto &out = llvm::errs();
+  out << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
+  indent();
+  out << "Params: [";
+  bool firstParam = true;
+  for (auto &param : Node->getArgs()) {
+    if (!firstParam)
+      out << ", ";
+    out << param->getName();
+    firstParam = false;
+  }
+  out << "]\n";
+}
+
+/// Print a function: a header line, then the prototype and the body.
+void ASTDumper::dump(FunctionAST *Node) {
+  INDENT();
+  auto &out = llvm::errs();
+  out << "Function \n";
+  dump(Node->getProto());
+  dump(Node->getBody());
+}
+
+/// Print a module: a header line, then every function it contains in sequence.
+void ASTDumper::dump(ModuleAST *Node) {
+  INDENT();
+  auto &out = llvm::errs();
+  out << "Module:\n";
+  for (auto &function : *Node) {
+    dump(&function);
+  }
+}
+
+namespace toy {
+
+// Public API: dump the AST of a whole module using a fresh ASTDumper.
+void dump(ModuleAST &module) {
+  ASTDumper dumper;
+  dumper.dump(&module);
+}
+
+} // namespace toy
diff --git a/mlir/examples/toy/Ch3/toyc.cpp b/mlir/examples/toy/Ch3/toyc.cpp
new file mode 100644
index 00000000000..3d18417dc8a
--- /dev/null
+++ b/mlir/examples/toy/Ch3/toyc.cpp
@@ -0,0 +1,139 @@
+//===- toyc.cpp - The Toy Compiler ----------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the entry point for the Toy compiler.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+#include "toy/MLIRGen.h"
+#include "toy/Parser.h"
+#include <memory>
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Parser.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace toy;
+namespace cl = llvm::cl;
+
+// Positional command line argument holding the name of the input file. It is
+// initialized to "-" when not provided on the command line.
+static cl::opt<std::string> inputFilename(cl::Positional,
+ cl::desc("<input toy file>"),
+ cl::init("-"),
+ cl::value_desc("filename"));
+
+namespace {
+/// The kinds of input the compiler can load.
+enum InputType { Toy, MLIR };
+} // namespace
+// `-x` command line option selecting how the input file is loaded; Toy source
+// is the default. The help text describes the input, which is what this option
+// controls (the kind of output is selected with `-emit`).
+static cl::opt<enum InputType> inputType(
+    "x", cl::init(Toy), cl::desc("Decides the kind of input expected"),
+    cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
+    cl::values(clEnumValN(MLIR, "mlir",
+                          "load the input file as an MLIR file")));
+
+namespace {
+// The actions the compiler can perform on its input.
+enum Action { None, DumpAST, DumpMLIR };
+}
+// `-emit` command line option selecting the action to perform. Only `ast` and
+// `mlir` can be selected on the command line, no initial value is given here.
+static cl::opt<enum Action> emitAction(
+ "emit", cl::desc("Select the kind of output desired"),
+ cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
+ cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")));
+
+/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
+/// The content is read from `filename` (or the standard input, see
+/// getFileOrSTDIN); a message is printed when it cannot be opened.
+std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
+  auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename);
+  if (std::error_code openError = fileOrErr.getError()) {
+    llvm::errs() << "Could not open input file: " << openError.message()
+                 << "\n";
+    return nullptr;
+  }
+
+  // Lex and parse straight from the content of the memory buffer.
+  auto contents = fileOrErr.get()->getBuffer();
+  LexerBuffer lexer(contents.begin(), contents.end(), filename);
+  Parser parser(lexer);
+  return parser.ParseModule();
+}
+
+/// Load the input as an MLIR module and dump it. The input is either parsed
+/// as an MLIR file (`-x=mlir` or a file name ending in ".mlir") and verified,
+/// or parsed as a Toy source from which the MLIR is generated.
+/// Returns 0 on success and a non-zero value on error.
+int dumpMLIR() {
+  // Register our Dialect with MLIR
+  mlir::registerDialect<ToyDialect>();
+
+  mlir::MLIRContext context;
+  std::unique_ptr<mlir::Module> module;
+  if (inputType == InputType::MLIR ||
+      llvm::StringRef(inputFilename).endswith(".mlir")) {
+    llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+        llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+    if (std::error_code EC = fileOrErr.getError()) {
+      llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+      return -1;
+    }
+    llvm::SourceMgr sourceMgr;
+    sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
+    module.reset(mlir::parseSourceFile(sourceMgr, &context));
+    if (!module) {
+      llvm::errs() << "Error can't load file " << inputFilename << "\n";
+      return 3;
+    }
+    if (failed(module->verify())) {
+      llvm::errs() << "Error verifying MLIR module\n";
+      return 4;
+    }
+  } else {
+    auto moduleAST = parseInputFile(inputFilename);
+    // parseInputFile returns a nullptr on error: bail out instead of
+    // dereferencing it, with the same exit code as dumpAST.
+    if (!moduleAST)
+      return 1;
+    module = mlirGen(context, *moduleAST);
+  }
+  if (!module)
+    return 1;
+  module->dump();
+  return 0;
+}
+
+/// Parse the input as a Toy source and dump its AST. This is rejected when
+/// the input has been declared to be MLIR.
+/// Returns 0 on success and a non-zero value on error.
+int dumpAST() {
+  const bool inputIsMLIR = inputType == InputType::MLIR;
+  if (inputIsMLIR) {
+    llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n";
+    return 5;
+  }
+
+  auto moduleAST = parseInputFile(inputFilename);
+  if (moduleAST) {
+    dump(*moduleAST);
+    return 0;
+  }
+  return 1;
+}
+
+/// Entry point: parse the command line and run the action selected with
+/// `-emit`. Without any action a message is printed and 0 is returned.
+int main(int argc, char **argv) {
+  cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
+
+  const Action requested = emitAction;
+  if (requested == Action::DumpAST)
+    return dumpAST();
+  if (requested == Action::DumpMLIR)
+    return dumpMLIR();
+
+  llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
+  return 0;
+}
diff --git a/mlir/g3doc/Tutorials/Toy/Ch-3.md b/mlir/g3doc/Tutorials/Toy/Ch-3.md
new file mode 100644
index 00000000000..b3d79d05eb8
--- /dev/null
+++ b/mlir/g3doc/Tutorials/Toy/Ch-3.md
@@ -0,0 +1,297 @@
+# Chapter 3: Defining and Registering a Dialect in MLIR
+
+In the previous chapter, we saw how to emit a custom IR for Toy in MLIR using
+opaque operations. In this chapter we will register our Dialect with MLIR to
+start making the Toy IR more robust and friendly to use.
+
+Dialects in MLIR allow for registering operations and types with an MLIRContext.
+They also must reserve a "namespace" to avoid collision with other registered
+dialects. These registered operations are no longer opaque to MLIR: for example
+we can teach the MLIR verifier to enforce some invariants on the IR.
+
+```c++
+/// This is the definition of the Toy dialect. A dialect inherits from
+/// mlir::Dialect and registers custom operations and types (in its constructor).
+/// It can also override general behavior of dialects exposed as virtual
+/// methods, for example regarding verification and parsing/printing.
+class ToyDialect : public mlir::Dialect {
+ public:
+ explicit ToyDialect(mlir::MLIRContext *ctx);
+
+ /// Parse a type registered to this dialect. Overriding this method is
+ /// required for dialects that have custom types.
+ /// Technically this is only needed to be able to round-trip to textual IR.
+ mlir::Type parseType(llvm::StringRef tyData,
+ mlir::Location loc) const override;
+
+ /// Print a type registered to this dialect. Overriding this method is
+ /// only required for dialects that have custom types.
+ /// Technically this is only needed to be able to round-trip to textual IR.
+ void printType(mlir::Type type, llvm::raw_ostream &os) const override;
+};
+```
+
+The dialect can now be registered in the global registry:
+
+```c++
+ mlir::registerDialect<ToyDialect>();
+```
+
+Any new `MLIRContext` created from now on will recognize the `toy` prefix when
+parsing new types and invoke our `parseType` method. We will see later how to
+enable custom operations, but first let's define a custom type to handle Toy
+arrays.
+
+# Custom Type Handling
+
+As you may have noticed in the previous chapter, dialect specific types in MLIR
+are serialized as strings. In the case of Toy, an example would be
+`!toy<"array<2, 3>">`. MLIR will find the ToyDialect from the `!toy` prefix but
+it is up to the dialect itself to translate the content of the string into a
+proper type.
+
+First we need to define the class representing our type. In MLIR, types are
+references to immutable and uniqued objects owned by the MLIRContext. As such,
+our `ToyArrayType` will only be a wrapper around a pointer to a uniqued
+instance of `ToyArrayTypeStorage` in the Context and provide the public facade
+API to interact with the type.
+
+```c++
+class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
+ detail::ToyArrayTypeStorage> {
+ public:
+ /// Returns the dimensions for this Toy array, or an empty range for a generic array.
+ llvm::ArrayRef<int64_t> getShape();
+
+ /// Predicate to test if this array is generic (shape hasn't been inferred yet).
+ bool isGeneric() { return getShape().empty(); }
+
+ /// Return the rank of this array (0 if it is generic)
+ int getRank() { return getShape().size(); }
+
+ /// Get the unique instance of this Type from the context.
+ /// A ToyArrayType is only defined by the shape of the array.
+ static ToyArrayType get(mlir::MLIRContext *context,
+ llvm::ArrayRef<int64_t> shape = {});
+
+ /// Support method to enable LLVM-style RTTI type casting.
+ static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
+};
+```
+
+Implementing `getShape()` for example is just about retrieving the pointer to
+the uniqued instance and forwarding:
+
+```c++
+llvm::ArrayRef<int64_t> ToyArrayType::getShape() {
+ return getImpl()->getShape();
+}
+```
+
+The calls to `getImpl()` give access to the `ToyArrayTypeStorage` that holds the
+information for this type. For details about how the storage of the type works,
+we'll refer you to `Ch3/mlir/ToyDialect.cpp`.
+
+Finally, the Toy dialect can register the type with MLIR, and implement some
+custom parsing for our types:
+
+```c++
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+ // note the `toy` prefix that we reserve here.
+ : mlir::Dialect("toy", ctx) {
+ // Register our custom type with MLIR.
+ addTypes<ToyArrayType>();
+}
+
+/// Parse a type registered to this dialect, we expect only Toy arrays.
+mlir::Type ToyDialect::parseType(StringRef tyData,
+ mlir::Location loc) const {
+ // Sanity check: we only support array or array<...>
+ if (!tyData.startswith("array")) {
+ getContext()->emitError(loc, "Invalid Toy type '" + tyData +
+ "', array expected");
+ return nullptr;
+ }
+ // Drop the "array" prefix from the type name, we expect either an empty
+ // string or just the shape.
+ tyData = tyData.drop_front(StringRef("array").size());
+ // This is the generic array case without shape, early return it.
+ if (tyData.empty())
+ return ToyArrayType::get(getContext());
+
+ // Use a regex to parse the shape (for efficiency we should store this regex
+ // in the dialect itself).
+ SmallVector<StringRef, 4> matches;
+ auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
+ if (!shapeRegex.match(tyData, &matches)) {
+ getContext()->emitError(loc, "Invalid toy array shape '" + tyData + "'");
+ return nullptr;
+ }
+ SmallVector<int64_t, 4> shape;
+ // Iterate through the captures, skip the first one which is the full string.
+ for (auto dimStr :
+ llvm::make_range(std::next(matches.begin()), matches.end())) {
+ if (dimStr.startswith(","))
+ continue; // POSIX misses non-capturing groups.
+ if (dimStr.empty())
+ continue; // '*' makes it an optional group capture
+ // Convert the capture to an integer
+ unsigned long long dim;
+ if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) {
+ getContext()->emitError(loc, Twine("Couldn't parse dimension as integer, matched: ") + dimStr);
+ return mlir::Type();
+ }
+ shape.push_back(dim);
+ }
+ // Finally we collected all the dimensions in the shape,
+ // create the array type.
+ return ToyArrayType::get(getContext(), shape);
+}
+```
+
+And we also update our IR generation from the Toy AST to use our new type
+instead of an opaque one:
+
+```c++
+template <typename T> mlir::Type getType(T shape) {
+ SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
+ return ToyArrayType::get(&context, shape64);
+}
+```
+
+From now on, MLIR knows how to parse types that are wrapped in `!toy<...>` and
+these won't be opaque anymore. The first consequence is that bogus IR with
+respect to our type won't be loaded anymore:
+
+```bash(.sh)
+$ echo 'func @foo() -> !toy<"bla">' | toyc -emit=mlir -x mlir -
+loc("<stdin>":1:21): error: Invalid Toy type 'bla', array expected
+$ echo 'func @foo() -> !toy<"array<>">' | toyc -emit=mlir -x mlir -
+loc("<stdin>":1:21): error: Invalid toy array shape '<>'
+$ echo 'func @foo() -> !toy<"array<1, >">' | toyc -emit=mlir -x mlir -
+loc("<stdin>":1:21): error: Invalid toy array shape '<1, >'
+$ echo 'func @foo() -> !toy<"array<1, 2, 3>">' | toyc -emit=mlir -x mlir -
+func @foo() -> !toy<"array<1, 3>">
+```
+
+## Defining a C++ Class for an Operation
+
+After defining our custom type, we will register all the operations for the Toy
+language. Let's walk through the creation of the `toy.generic_call` operation:
+
+```MLIR(.mlir)
+ %4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
+ : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
+```
+
+This operation takes a variable number of operands, all of which are expected to
+be Toy arrays, and returns a single result. An operation inherits from `mlir::Op`
+and adds some optional *traits* to customize its behavior.
+
+```c++
+class GenericCallOp
+ : public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::OneResult> {
+
+ public:
+ /// MLIR will use this to register the operation with the parser/printer.
+ static llvm::StringRef getOperationName() { return "toy.generic_call"; }
+
+ /// Operations can add custom verification beyond the traits they define.
+ /// We will ensure that all the operands are Toy arrays.
+ bool verify();
+
+ /// Interface to the builder to allow:
+ /// mlir::FuncBuilder::create<GenericCallOp>(...)
+ /// This method populates the `state` that MLIR uses to create operations.
+ /// The `toy.generic_call` operation accepts a callee name and a list of
+ /// arguments for the call.
+ static void build(mlir::FuncBuilder *builder, mlir::OperationState *state,
+ llvm::StringRef callee,
+ llvm::ArrayRef<mlir::Value *> arguments);
+
+ /// Return the name of the callee by fetching it from the attribute.
+ llvm::StringRef getCalleeName();
+
+ private:
+ friend class mlir::Operation;
+ using Op::Op;
+};
+```
+
+and we register this operation in the `ToyDialect` constructor:
+
+```c++
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ addOperations<GenericCallOp>();
+ addTypes<ToyArrayType>();
+}
+```
+
+After creating classes for each of our operations, our dialect is ready and we
+now have better invariants enforced in our IR, and a nicer API to implement
+analyses and transformations in the [next chapter](Ch-4.md).
+
+## Using TableGen
+
+FIXME: complete
+
+## Revisiting the Builder API
+
+We can now update `MLIRGen.cpp`. Previously, our use of the builder was very
+generic, and creating a call operation looked like:
+
+```
+ // Calls to user-defined function are mapped to a custom call that takes
+ // the callee name as an attribute.
+ mlir::OperationState result(&context, location, "toy.generic_call");
+ result.types.push_back(getType(VarType{}));
+ result.operands = std::move(operands);
+ for (auto &expr : call.getArgs()) {
+ auto *arg = mlirGen(*expr);
+ if (!arg)
+ return nullptr;
+ result.operands.push_back(arg);
+ }
+ auto calleeAttr = builder->getStringAttr(call.getCallee());
+ result.attributes.push_back(builder->getNamedAttr("callee", calleeAttr));
+ return builder->createOperation(result)->getResult(0);
+```
+
+We replace it with this new version:
+
+```c++
+ for (auto &expr : call.getArgs()) {
+ auto *arg = mlirGen(*expr);
+ if (!arg)
+ return nullptr;
+ operands.push_back(arg);
+ }
+ return builder->create<GenericCallOp>(location, call.getCallee(), operands)->getResult();
+```
+
+This interface offers better type safety, with some invariants enforced at the
+API level. For instance the `GenericCallOp` now exposes a `getResult()` method
+that does not take any argument, while before MLIR assumed the general case and
+left open the possibility to have multiple returned values. The API was
+`getResult(int resultNum)`.
+
+# Putting It All Together
+
+After writing a class for each of our operations and implementing custom
+verifiers, we try the same example of invalid IR from the previous chapter again:
+
+```bash(.sh)
+$ cat test/invalid.mlir
+func @main() {
+ %0 = "toy.print"() : () -> !toy<"array<2, 3>">
+}
+$ toyc test/invalid.mlir -emit=mlir
+loc("test/invalid.mlir":2:8): error: 'toy.print' op requires a single operand
+```
+
+This time the IR is correctly rejected by the verifier!
+
+In the [next chapter](Ch-4.md) we will leverage our new dialect to implement
+some high-level language-specific analyses and transformations for the Toy
+language.
diff --git a/mlir/include/mlir/IR/DialectTypeRegistry.def b/mlir/include/mlir/IR/DialectTypeRegistry.def
index b212a572aa6..c664241ab5e 100644
--- a/mlir/include/mlir/IR/DialectTypeRegistry.def
+++ b/mlir/include/mlir/IR/DialectTypeRegistry.def
@@ -27,6 +27,7 @@ DEFINE_TYPE_KIND_RANGE(LLVM)
DEFINE_TYPE_KIND_RANGE(QUANTIZATION)
DEFINE_TYPE_KIND_RANGE(IREE) // IREE stands for IR Execution Engine
DEFINE_TYPE_KIND_RANGE(LINALG) // Linear Algebra Dialect
+DEFINE_TYPE_KIND_RANGE(TOY) // Toy language (tutorial) Dialect
// The following ranges are reserved for experimenting with MLIR dialects in a
// private context without having to register them here.
diff --git a/mlir/test/Examples/Toy/Ch2/codegen.toy b/mlir/test/Examples/Toy/Ch2/codegen.toy
index f2397e63ff0..e361a09528a 100644
--- a/mlir/test/Examples/Toy/Ch2/codegen.toy
+++ b/mlir/test/Examples/Toy/Ch2/codegen.toy
@@ -25,8 +25,8 @@ def main() {
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy<"array<2, 3>">) -> !toy<"array<2, 3>">
# CHECK-NEXT: %2 = "toy.constant"() {value: dense<tensor<6xf64>, [1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]>} : () -> !toy<"array<6>">
# CHECK-NEXT: %3 = "toy.reshape"(%2) : (!toy<"array<6>">) -> !toy<"array<2, 3>">
-# CHECK-NEXT: %4 = "toy.generic_call"(%1, %3, %1, %3) {callee: "multiply_transpose"} : (!toy<"array<2, 3>">, !toy<"array<2, 3>">, !toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
-# CHECK-NEXT: %5 = "toy.generic_call"(%3, %1, %3, %1) {callee: "multiply_transpose"} : (!toy<"array<2, 3>">, !toy<"array<2, 3>">, !toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
+# CHECK-NEXT: %4 = "toy.generic_call"(%1, %3) {callee: "multiply_transpose"} : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
+# CHECK-NEXT: %5 = "toy.generic_call"(%3, %1) {callee: "multiply_transpose"} : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
# CHECK-NEXT: "toy.print"(%5) : (!toy<"array">) -> ()
# CHECK-NEXT: "toy.return"() : () -> ()
diff --git a/mlir/test/Examples/Toy/Ch2/invalid.mlir b/mlir/test/Examples/Toy/Ch2/invalid.mlir
index 324d4ca2717..fe8369be982 100644
--- a/mlir/test/Examples/Toy/Ch2/invalid.mlir
+++ b/mlir/test/Examples/Toy/Ch2/invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: toyc-ch2 %s -emit=mlir 2>&1
+// RUN: toyc-ch2 %s -emit=mlir 2>&1
// This IR is not "valid":
diff --git a/mlir/test/Examples/Toy/Ch3/ast.toy b/mlir/test/Examples/Toy/Ch3/ast.toy
new file mode 100644
index 00000000000..0c904216757
--- /dev/null
+++ b/mlir/test/Examples/Toy/Ch3/ast.toy
@@ -0,0 +1,73 @@
+# RUN: toyc-ch3 %s -emit=ast 2>&1 | FileCheck %s
+
+
+# User defined generic function that operates on unknown shaped arguments
+def multiply_transpose(a, b) {
+ return a * transpose(b);
+}
+
+def main() {
+ # Define a variable `a` with shape <2, 3>, initialized with the literal value.
+ # The shape is inferred from the supplied literal.
+ var a = [[1, 2, 3], [4, 5, 6]];
+ # b is identical to a, the literal array is implicitly reshaped: defining new
+ # variables is the way to reshape arrays (element count must match).
+ var b<2, 3> = [1, 2, 3, 4, 5, 6];
+ # This call will specialize `multiply_transpose` with <2, 3> for both
+ # arguments and deduce a return type of <2, 2> in initialization of `c`.
+ var c = multiply_transpose(a, b);
+ # A second call to `multiply_transpose` with <2, 3> for both arguments will
+ # reuse the previously specialized and inferred version and return `<2, 2>`
+ var d = multiply_transpose(b, a);
+ # A new call with `<2, 2>` for both dimension will trigger another
+ # specialization of `multiply_transpose`.
+ var e = multiply_transpose(b, c);
+ # Finally, calling into `multiply_transpose` with incompatible shape will
+ # trigger a shape inference error.
+ var e = multiply_transpose(transpose(a), c);
+}
+
+
+# CHECK: Module:
+# CHECK-NEXT: Function
+# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}Toy/Ch3/ast.toy:5:1'
+# CHECK-NEXT: Params: [a, b]
+# CHECK-NEXT: Block {
+# CHECK-NEXT: Retur
+# CHECK-NEXT: BinOp: * @{{.*}}Toy/Ch3/ast.toy:6:14
+# CHECK-NEXT: var: a @{{.*}}Toy/Ch3/ast.toy:6:10
+# CHECK-NEXT: Call 'transpose' [ @{{.*}}Toy/Ch3/ast.toy:6:14
+# CHECK-NEXT: var: b @{{.*}}Toy/Ch3/ast.toy:6:24
+# CHECK-NEXT: ]
+# CHECK-NEXT: } // Block
+# CHECK-NEXT: Function
+# CHECK-NEXT: Proto 'main' @{{.*}}Toy/Ch3/ast.toy:9:1'
+# CHECK-NEXT: Params: []
+# CHECK-NEXT: Block {
+# CHECK-NEXT: VarDecl a<> @{{.*}}Toy/Ch3/ast.toy:12:3
+# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}Toy/Ch3/ast.toy:12:11
+# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}Toy/Ch3/ast.toy:15:3
+# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}Toy/Ch3/ast.toy:15:17
+# CHECK-NEXT: VarDecl c<> @{{.*}}Toy/Ch3/ast.toy:18:3
+# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}Toy/Ch3/ast.toy:18:11
+# CHECK-NEXT: var: a @{{.*}}Toy/Ch3/ast.toy:18:30
+# CHECK-NEXT: var: b @{{.*}}Toy/Ch3/ast.toy:18:33
+# CHECK-NEXT: ]
+# CHECK-NEXT: VarDecl d<> @{{.*}}Toy/Ch3/ast.toy:21:3
+# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}Toy/Ch3/ast.toy:21:11
+# CHECK-NEXT: var: b @{{.*}}Toy/Ch3/ast.toy:21:30
+# CHECK-NEXT: var: a @{{.*}}Toy/Ch3/ast.toy:21:33
+# CHECK-NEXT: ]
+# CHECK-NEXT: VarDecl e<> @{{.*}}Toy/Ch3/ast.toy:24:3
+# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}Toy/Ch3/ast.toy:24:11
+# CHECK-NEXT: var: b @{{.*}}Toy/Ch3/ast.toy:24:30
+# CHECK-NEXT: var: c @{{.*}}Toy/Ch3/ast.toy:24:33
+# CHECK-NEXT: ]
+# CHECK-NEXT: VarDecl e<> @{{.*}}Toy/Ch3/ast.toy:27:3
+# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}Toy/Ch3/ast.toy:27:11
+# CHECK-NEXT: Call 'transpose' [ @{{.*}}Toy/Ch3/ast.toy:27:30
+# CHECK-NEXT: var: a @{{.*}}Toy/Ch3/ast.toy:27:40
+# CHECK-NEXT: ]
+# CHECK-NEXT: var: c @{{.*}}Toy/Ch3/ast.toy:27:44
+# CHECK-NEXT: ]
+
diff --git a/mlir/test/Examples/Toy/Ch3/codegen.toy b/mlir/test/Examples/Toy/Ch3/codegen.toy
new file mode 100644
index 00000000000..a4d7058f8f3
--- /dev/null
+++ b/mlir/test/Examples/Toy/Ch3/codegen.toy
@@ -0,0 +1,32 @@
+# RUN: toyc-ch3 %s -emit=mlir 2>&1 | FileCheck %s
+
+# User defined generic function that operates on unknown shaped arguments
+def multiply_transpose(a, b) {
+ return a * transpose(b);
+}
+
+def main() {
+ var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+ var b<2, 3> = [1, 2, 3, 4, 5, 6];
+ var c = multiply_transpose(a, b);
+ var d = multiply_transpose(b, a);
+ print(d);
+}
+
+# CHECK-LABEL: func @multiply_transpose(%arg0: !toy<"array">, %arg1: !toy<"array">)
+# CHECK-NEXT: attributes {toy.generic: true} {
+# CHECK-NEXT: %0 = "toy.transpose"(%arg1) : (!toy<"array">) -> !toy<"array">
+# CHECK-NEXT: %1 = "toy.mul"(%arg0, %0) : (!toy<"array">, !toy<"array">) -> !toy<"array">
+# CHECK-NEXT: "toy.return"(%1) : (!toy<"array">) -> ()
+# CHECK-NEXT: }
+
+# CHECK-LABEL: func @main() {
+# CHECK-NEXT: %0 = "toy.constant"() {value: dense<tensor<2x3xf64>, {{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> !toy<"array<2, 3>">
+# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy<"array<2, 3>">) -> !toy<"array<2, 3>">
+# CHECK-NEXT: %2 = "toy.constant"() {value: dense<tensor<6xf64>, [1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]>} : () -> !toy<"array<6>">
+# CHECK-NEXT: %3 = "toy.reshape"(%2) : (!toy<"array<6>">) -> !toy<"array<2, 3>">
+# CHECK-NEXT: %4 = "toy.generic_call"(%1, %3) {callee: "multiply_transpose"} : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
+# CHECK-NEXT: %5 = "toy.generic_call"(%3, %1) {callee: "multiply_transpose"} : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
+# CHECK-NEXT: "toy.print"(%5) : (!toy<"array">) -> ()
+# CHECK-NEXT: "toy.return"() : () -> ()
+
diff --git a/mlir/test/Examples/Toy/Ch3/invalid.mlir b/mlir/test/Examples/Toy/Ch3/invalid.mlir
new file mode 100644
index 00000000000..2dd22280e76
--- /dev/null
+++ b/mlir/test/Examples/Toy/Ch3/invalid.mlir
@@ -0,0 +1,11 @@
+// RUN: not toyc-ch3 %s -emit=mlir 2>&1
+
+
+// This IR is not "valid":
+// - toy.print should not return a value.
+// - toy.print should take an argument.
+// - There should be a block terminator.
+// Unlike in Ch2, this no longer round-trips: the Toy dialect verifier rejects it.
+func @main() {
+ %0 = "toy.print"() : () -> !toy<"array<2, 3>">
+}
diff --git a/mlir/test/Examples/Toy/Ch3/scalar.toy b/mlir/test/Examples/Toy/Ch3/scalar.toy
new file mode 100644
index 00000000000..41153ec92e1
--- /dev/null
+++ b/mlir/test/Examples/Toy/Ch3/scalar.toy
@@ -0,0 +1,12 @@
+def main() {
+ var a<2, 2> = 5.5;
+ print(a);
+}
+
+# CHECK-LABEL: func @main() {
+# CHECK-NEXT: %0 = "toy.constant"() {value: dense<tensor<1xf64>, [5.500000e+00]>} : () -> !toy<"array<1>">
+# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy<"array<1>">) -> !toy<"array<2, 2>">
+# CHECK-NEXT: "toy.print"(%1) : (!toy<"array<2, 2>">) -> ()
+# CHECK-NEXT: "toy.return"() : () -> ()
+# CHECK-NEXT: }
+
OpenPOWER on IntegriCloud