//===- AST.h - Node definition for the Toy AST ----------------------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements the AST for the Toy language. It is optimized for // simplicity, not efficiency. The AST forms a tree structure where each node // references its children using std::unique_ptr<>. // //===----------------------------------------------------------------------===// #ifndef MLIR_TUTORIAL_TOY_AST_H_ #define MLIR_TUTORIAL_TOY_AST_H_ #include "toy/Lexer.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include namespace toy { /// A variable type with shape information. struct VarType { std::vector shape; }; /// Base class for all expression nodes. class ExprAST { public: enum ExprASTKind { Expr_VarDecl, Expr_Return, Expr_Num, Expr_Literal, Expr_Var, Expr_BinOp, Expr_Call, Expr_Print, }; ExprAST(ExprASTKind kind, Location location) : kind(kind), location(location) {} virtual ~ExprAST() = default; ExprASTKind getKind() const { return kind; } const Location &loc() { return location; } private: const ExprASTKind kind; Location location; }; /// A block-list of expressions. using ExprASTList = std::vector>; /// Expression class for numeric literals like "1.0". class NumberExprAST : public ExprAST { double Val; public: NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} double getValue() { return Val; } /// LLVM style RTTI static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } }; /// Expression class for a literal value. class LiteralExprAST : public ExprAST { std::vector> values; std::vector dims; public: LiteralExprAST(Location loc, std::vector> values, std::vector dims) : ExprAST(Expr_Literal, loc), values(std::move(values)), dims(std::move(dims)) {} llvm::ArrayRef> getValues() { return values; } llvm::ArrayRef getDims() { return dims; } /// LLVM style RTTI static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } }; /// Expression class for referencing a variable, like "a". class VariableExprAST : public ExprAST { std::string name; public: VariableExprAST(Location loc, llvm::StringRef name) : ExprAST(Expr_Var, loc), name(name) {} llvm::StringRef getName() { return name; } /// LLVM style RTTI static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } }; /// Expression class for defining a variable. class VarDeclExprAST : public ExprAST { std::string name; VarType type; std::unique_ptr initVal; public: VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr initVal) : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), initVal(std::move(initVal)) {} llvm::StringRef getName() { return name; } ExprAST *getInitVal() { return initVal.get(); } const VarType &getType() { return type; } /// LLVM style RTTI static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } }; /// Expression class for a return operator. class ReturnExprAST : public ExprAST { llvm::Optional> expr; public: ReturnExprAST(Location loc, llvm::Optional> expr) : ExprAST(Expr_Return, loc), expr(std::move(expr)) {} llvm::Optional getExpr() { if (expr.hasValue()) return expr->get(); return llvm::None; } /// LLVM style RTTI static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } }; /// Expression class for a binary operator. class BinaryExprAST : public ExprAST { char op; std::unique_ptr lhs, rhs; public: char getOp() { return op; } ExprAST *getLHS() { return lhs.get(); } ExprAST *getRHS() { return rhs.get(); } BinaryExprAST(Location loc, char Op, std::unique_ptr lhs, std::unique_ptr rhs) : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), rhs(std::move(rhs)) {} /// LLVM style RTTI static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } }; /// Expression class for function calls. class CallExprAST : public ExprAST { std::string callee; std::vector> args; public: CallExprAST(Location loc, const std::string &callee, std::vector> args) : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} llvm::StringRef getCallee() { return callee; } llvm::ArrayRef> getArgs() { return args; } /// LLVM style RTTI static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } }; /// Expression class for builtin print calls. class PrintExprAST : public ExprAST { std::unique_ptr arg; public: PrintExprAST(Location loc, std::unique_ptr arg) : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} ExprAST *getArg() { return arg.get(); } /// LLVM style RTTI static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } }; /// This class represents the "prototype" for a function, which captures its /// name, and its argument names (thus implicitly the number of arguments the /// function takes). class PrototypeAST { Location location; std::string name; std::vector> args; public: PrototypeAST(Location location, const std::string &name, std::vector> args) : location(location), name(name), args(std::move(args)) {} const Location &loc() { return location; } llvm::StringRef getName() const { return name; } llvm::ArrayRef> getArgs() { return args; } }; /// This class represents a function definition itself. class FunctionAST { std::unique_ptr proto; std::unique_ptr body; public: FunctionAST(std::unique_ptr proto, std::unique_ptr body) : proto(std::move(proto)), body(std::move(body)) {} PrototypeAST *getProto() { return proto.get(); } ExprASTList *getBody() { return body.get(); } }; /// This class represents a list of functions to be processed together class ModuleAST { std::vector functions; public: ModuleAST(std::vector functions) : functions(std::move(functions)) {} auto begin() -> decltype(functions.begin()) { return functions.begin(); } auto end() -> decltype(functions.end()) { return functions.end(); } }; void dump(ModuleAST &); } // namespace toy #endif // MLIR_TUTORIAL_TOY_AST_H_