diff options
| author | River Riddle <riverriddle@google.com> | 2019-11-06 18:20:24 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-06 18:21:03 -0800 |
| commit | 22cfff7043daa10ca3e00afd49ff80b882bbb107 (patch) | |
| tree | c080e8627df81d30895b59cd69483ee081c5783f /mlir/examples | |
| parent | f6188b5b07418dda04743a51f0ddcbca30c7a196 (diff) | |
| download | bcm5719-llvm-22cfff7043daa10ca3e00afd49ff80b882bbb107.tar.gz bcm5719-llvm-22cfff7043daa10ca3e00afd49ff80b882bbb107.zip | |
NFC: Uniformize parser naming scheme in Toy tutorial to camelCase and tidy a bit of the implementation.
PiperOrigin-RevId: 278982817
Diffstat (limited to 'mlir/examples')
36 files changed, 1225 insertions, 1306 deletions
diff --git a/mlir/examples/toy/Ch1/CMakeLists.txt b/mlir/examples/toy/Ch1/CMakeLists.txt index dd26cf140be..f4e85556130 100644 --- a/mlir/examples/toy/Ch1/CMakeLists.txt +++ b/mlir/examples/toy/Ch1/CMakeLists.txt @@ -6,4 +6,7 @@ add_toy_chapter(toyc-ch1 toyc.cpp parser/AST.cpp ) -include_directories(include/)
\ No newline at end of file +include_directories(include/) +target_link_libraries(toyc-ch1 + PRIVATE + MLIRSupport) diff --git a/mlir/examples/toy/Ch1/include/toy/AST.h b/mlir/examples/toy/Ch1/include/toy/AST.h index 2ad3392c11a..901164b0f39 100644 --- a/mlir/examples/toy/Ch1/include/toy/AST.h +++ b/mlir/examples/toy/Ch1/include/toy/AST.h @@ -54,7 +54,6 @@ public: ExprAST(ExprASTKind kind, Location location) : kind(kind), location(location) {} - virtual ~ExprAST() = default; ExprASTKind getKind() const { return kind; } @@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST { double Val; public: - NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {} + NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} double getValue() { return Val; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } }; /// Expression class for a literal value. @@ -93,10 +92,11 @@ public: : ExprAST(Expr_Literal, loc), values(std::move(values)), dims(std::move(dims)) {} - std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; } - std::vector<int64_t> &getDims() { return dims; } + llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; } + llvm::ArrayRef<int64_t> getDims() { return dims; } + /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } }; /// Expression class for referencing a variable, like "a". @@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST { std::string name; public: - VariableExprAST(Location loc, const std::string &name) + VariableExprAST(Location loc, llvm::StringRef name) : ExprAST(Expr_Var, loc), name(name) {} llvm::StringRef getName() { return name; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } }; /// Expression class for defining a variable. @@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST { std::unique_ptr<ExprAST> initVal; public: - VarDeclExprAST(Location loc, const std::string &name, VarType type, + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr<ExprAST> initVal) : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), initVal(std::move(initVal)) {} llvm::StringRef getName() { return name; } ExprAST *getInitVal() { return initVal.get(); } - VarType &getType() { return type; } + const VarType &getType() { return type; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } }; /// Expression class for a return operator. @@ -144,61 +144,61 @@ public: llvm::Optional<ExprAST *> getExpr() { if (expr.hasValue()) return expr->get(); - return llvm::NoneType(); + return llvm::None; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } }; /// Expression class for a binary operator. class BinaryExprAST : public ExprAST { - char Op; - std::unique_ptr<ExprAST> LHS, RHS; + char op; + std::unique_ptr<ExprAST> lhs, rhs; public: - char getOp() { return Op; } - ExprAST *getLHS() { return LHS.get(); } - ExprAST *getRHS() { return RHS.get(); } + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } - BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS, - std::unique_ptr<ExprAST> RHS) - : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)), - RHS(std::move(RHS)) {} + BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs, + std::unique_ptr<ExprAST> rhs) + : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } }; /// Expression class for function calls. class CallExprAST : public ExprAST { - std::string Callee; - std::vector<std::unique_ptr<ExprAST>> Args; + std::string callee; + std::vector<std::unique_ptr<ExprAST>> args; public: - CallExprAST(Location loc, const std::string &Callee, - std::vector<std::unique_ptr<ExprAST>> Args) - : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {} + CallExprAST(Location loc, const std::string &callee, + std::vector<std::unique_ptr<ExprAST>> args) + : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} - llvm::StringRef getCallee() { return Callee; } - llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; } + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } }; /// Expression class for builtin print calls. class PrintExprAST : public ExprAST { - std::unique_ptr<ExprAST> Arg; + std::unique_ptr<ExprAST> arg; public: - PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg) - : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {} + PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg) + : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} - ExprAST *getArg() { return Arg.get(); } + ExprAST *getArg() { return arg.get(); } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } }; /// This class represents the "prototype" for a function, which captures its @@ -215,23 +215,21 @@ public: : location(location), name(name), args(std::move(args)) {} const Location &loc() { return location; } - const std::string &getName() const { return name; } - const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() { - return args; - } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; } }; /// This class represents a function definition itself. class FunctionAST { - std::unique_ptr<PrototypeAST> Proto; - std::unique_ptr<ExprASTList> Body; + std::unique_ptr<PrototypeAST> proto; + std::unique_ptr<ExprASTList> body; public: - FunctionAST(std::unique_ptr<PrototypeAST> Proto, - std::unique_ptr<ExprASTList> Body) - : Proto(std::move(Proto)), Body(std::move(Body)) {} - PrototypeAST *getProto() { return Proto.get(); } - ExprASTList *getBody() { return Body.get(); } + FunctionAST(std::unique_ptr<PrototypeAST> proto, + std::unique_ptr<ExprASTList> body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } }; /// This class represents a list of functions to be processed together diff --git a/mlir/examples/toy/Ch1/include/toy/Lexer.h b/mlir/examples/toy/Ch1/include/toy/Lexer.h index 21f92614912..144388c460c 100644 --- a/mlir/examples/toy/Ch1/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch1/include/toy/Lexer.h @@ -89,13 +89,13 @@ public: /// Return the current identifier (prereq: getCurToken() == tok_identifier) llvm::StringRef getId() { assert(curTok == tok_identifier); - return IdentifierStr; + return identifierStr; } /// Return the current number (prereq: getCurToken() == tok_number) double getValue() { assert(curTok == tok_number); - return NumVal; + return numVal; } /// Return the location for the beginning of the current token. @@ -135,56 +135,58 @@ private: /// Return the next token from standard input. Token getTok() { // Skip any whitespace. - while (isspace(LastChar)) - LastChar = Token(getNextChar()); + while (isspace(lastChar)) + lastChar = Token(getNextChar()); // Save the current location before reading the token characters. lastLocation.line = curLineNum; lastLocation.col = curCol; - if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]* - IdentifierStr = (char)LastChar; - while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_') - IdentifierStr += (char)LastChar; + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; - if (IdentifierStr == "return") + if (identifierStr == "return") return tok_return; - if (IdentifierStr == "def") + if (identifierStr == "def") return tok_def; - if (IdentifierStr == "var") + if (identifierStr == "var") return tok_var; return tok_identifier; } - if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ - std::string NumStr; + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; do { - NumStr += LastChar; - LastChar = Token(getNextChar()); - } while (isdigit(LastChar) || LastChar == '.'); + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); - NumVal = strtod(NumStr.c_str(), nullptr); + numVal = strtod(numStr.c_str(), nullptr); return tok_number; } - if (LastChar == '#') { + if (lastChar == '#') { // Comment until end of line. - do - LastChar = Token(getNextChar()); - while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - if (LastChar != EOF) + if (lastChar != EOF) return getTok(); } // Check for end of file. Don't eat the EOF. - if (LastChar == EOF) + if (lastChar == EOF) return tok_eof; // Otherwise, just return the character as its ascii value. - Token ThisChar = Token(LastChar); - LastChar = Token(getNextChar()); - return ThisChar; + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; } /// The last token read from the input. @@ -194,15 +196,15 @@ private: Location lastLocation; /// If the current Token is an identifier, this string contains the value. - std::string IdentifierStr; + std::string identifierStr; /// If the current Token is a number, this contains the value. - double NumVal = 0; + double numVal = 0; /// The last value returned by getNextChar(). We need to keep it around as we /// always need to read ahead one character to decide when to end a token and /// we can't put it back in the stream after reading from it. - Token LastChar = Token(' '); + Token lastChar = Token(' '); /// Keep track of the current line number in the input stream int curLineNum = 0; diff --git a/mlir/examples/toy/Ch1/include/toy/Parser.h b/mlir/examples/toy/Ch1/include/toy/Parser.h index ec3d7654a85..9e219e56551 100644 --- a/mlir/examples/toy/Ch1/include/toy/Parser.h +++ b/mlir/examples/toy/Ch1/include/toy/Parser.h @@ -48,13 +48,13 @@ public: Parser(Lexer &lexer) : lexer(lexer) {} /// Parse a full Module. A module is a list of function definitions. - std::unique_ptr<ModuleAST> ParseModule() { + std::unique_ptr<ModuleAST> parseModule() { lexer.getNextToken(); // prime the lexer // Parse functions one at a time and accumulate in this vector. std::vector<FunctionAST> functions; - while (auto F = ParseDefinition()) { - functions.push_back(std::move(*F)); + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); if (lexer.getCurToken() == tok_eof) break; } @@ -70,14 +70,14 @@ private: /// Parse a return statement. /// return :== return ; | return expr ; - std::unique_ptr<ReturnExprAST> ParseReturn() { + std::unique_ptr<ReturnExprAST> parseReturn() { auto loc = lexer.getLastLocation(); lexer.consume(tok_return); // return takes an optional argument llvm::Optional<std::unique_ptr<ExprAST>> expr; if (lexer.getCurToken() != ';') { - expr = ParseExpression(); + expr = parseExpression(); if (!expr) return nullptr; } @@ -86,18 +86,18 @@ private: /// Parse a literal number. /// numberexpr ::= number - std::unique_ptr<ExprAST> ParseNumberExpr() { + std::unique_ptr<ExprAST> parseNumberExpr() { auto loc = lexer.getLastLocation(); - auto Result = + auto result = std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue()); lexer.consume(tok_number); - return std::move(Result); + return std::move(result); } /// Parse a literal array expression. /// tensorLiteral ::= [ literalList ] | number /// literalList ::= tensorLiteral | tensorLiteral, literalList - std::unique_ptr<ExprAST> ParseTensorLiteralExpr() { + std::unique_ptr<ExprAST> parseTensorLiteralExpr() { auto loc = lexer.getLastLocation(); lexer.consume(Token('[')); @@ -108,13 +108,13 @@ private: do { // We can have either another nested array or a number literal. if (lexer.getCurToken() == '[') { - values.push_back(ParseTensorLiteralExpr()); + values.push_back(parseTensorLiteralExpr()); if (!values.back()) return nullptr; // parse error in the nested array. } else { if (lexer.getCurToken() != tok_number) return parseError<ExprAST>("<num> or [", "in literal expression"); - values.push_back(ParseNumberExpr()); + values.push_back(parseNumberExpr()); } // End of this list on ']' @@ -130,8 +130,10 @@ private: if (values.empty()) return parseError<ExprAST>("<something>", "to fill literal expression"); lexer.getNextToken(); // eat ] + /// Fill in the dimensions now. First the current nesting level: dims.push_back(values.size()); + /// If there is any nested array, process all of them and ensure that /// dimensions are uniform. if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) { @@ -143,7 +145,7 @@ private: "inside literal expression"); // Append the nested dimensions to the current level - auto &firstDims = firstLiteral->getDims(); + auto firstDims = firstLiteral->getDims(); dims.insert(dims.end(), firstDims.begin(), firstDims.end()); // Sanity check that shape is uniform across all elements of the list. @@ -162,22 +164,22 @@ private: } /// parenexpr ::= '(' expression ')' - std::unique_ptr<ExprAST> ParseParenExpr() { + std::unique_ptr<ExprAST> parseParenExpr() { lexer.getNextToken(); // eat (. - auto V = ParseExpression(); - if (!V) + auto v = parseExpression(); + if (!v) return nullptr; if (lexer.getCurToken() != ')') return parseError<ExprAST>(")", "to close expression with parentheses"); lexer.consume(Token(')')); - return V; + return v; } /// identifierexpr /// ::= identifier /// ::= identifier '(' expression ')' - std::unique_ptr<ExprAST> ParseIdentifierExpr() { + std::unique_ptr<ExprAST> parseIdentifierExpr() { std::string name = lexer.getId(); auto loc = lexer.getLastLocation(); @@ -188,11 +190,11 @@ private: // This is a function call. lexer.consume(Token('(')); - std::vector<std::unique_ptr<ExprAST>> Args; + std::vector<std::unique_ptr<ExprAST>> args; if (lexer.getCurToken() != ')') { while (true) { - if (auto Arg = ParseExpression()) - Args.push_back(std::move(Arg)); + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); else return nullptr; @@ -208,14 +210,14 @@ private: // It can be a builtin call to print if (name == "print") { - if (Args.size() != 1) + if (args.size() != 1) return parseError<ExprAST>("<single arg>", "as argument to print()"); - return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0])); + return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0])); } // Call to a user-defined function - return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args)); + return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args)); } /// primary @@ -223,20 +225,20 @@ private: /// ::= numberexpr /// ::= parenexpr /// ::= tensorliteral - std::unique_ptr<ExprAST> ParsePrimary() { + std::unique_ptr<ExprAST> parsePrimary() { switch (lexer.getCurToken()) { default: llvm::errs() << "unknown token '" << lexer.getCurToken() << "' when expecting an expression\n"; return nullptr; case tok_identifier: - return ParseIdentifierExpr(); + return parseIdentifierExpr(); case tok_number: - return ParseNumberExpr(); + return parseNumberExpr(); case '(': - return ParseParenExpr(); + return parseParenExpr(); case '[': - return ParseTensorLiteralExpr(); + return parseTensorLiteralExpr(); case ';': return nullptr; case '}': @@ -248,54 +250,54 @@ private: /// argument indicates the precedence of the current binary operator. /// /// binoprhs ::= ('+' primary)* - std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec, - std::unique_ptr<ExprAST> LHS) { + std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec, + std::unique_ptr<ExprAST> lhs) { // If this is a binop, find its precedence. while (true) { - int TokPrec = GetTokPrecedence(); + int tokPrec = getTokPrecedence(); // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. - if (TokPrec < ExprPrec) - return LHS; + if (tokPrec < exprPrec) + return lhs; // Okay, we know this is a binop. - int BinOp = lexer.getCurToken(); - lexer.consume(Token(BinOp)); + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); auto loc = lexer.getLastLocation(); // Parse the primary expression after the binary operator. - auto RHS = ParsePrimary(); - if (!RHS) + auto rhs = parsePrimary(); + if (!rhs) return parseError<ExprAST>("expression", "to complete binary operator"); - // If BinOp binds less tightly with RHS than the operator after RHS, let - // the pending operator take RHS as its LHS. - int NextPrec = GetTokPrecedence(); - if (TokPrec < NextPrec) { - RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); - if (!RHS) + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) return nullptr; } - // Merge LHS/RHS. - LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp, - std::move(LHS), std::move(RHS)); + // Merge lhs/RHS. + lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); } } - /// expression::= primary binoprhs - std::unique_ptr<ExprAST> ParseExpression() { - auto LHS = ParsePrimary(); - if (!LHS) + /// expression::= primary binop rhs + std::unique_ptr<ExprAST> parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) return nullptr; - return ParseBinOpRHS(0, std::move(LHS)); + return parseBinOpRHS(0, std::move(lhs)); } /// type ::= < shape_list > /// shape_list ::= num | num , shape_list - std::unique_ptr<VarType> ParseType() { + std::unique_ptr<VarType> parseType() { if (lexer.getCurToken() != '<') return parseError<VarType>("<", "to begin type"); lexer.getNextToken(); // eat < @@ -319,7 +321,7 @@ private: /// and identifier and an optional type (shape specification) before the /// initializer. /// decl ::= var identifier [ type ] = expr - std::unique_ptr<VarDeclExprAST> ParseDeclaration() { + std::unique_ptr<VarDeclExprAST> parseDeclaration() { if (lexer.getCurToken() != tok_var) return parseError<VarDeclExprAST>("var", "to begin declaration"); auto loc = lexer.getLastLocation(); @@ -333,7 +335,7 @@ private: std::unique_ptr<VarType> type; // Type is optional, it can be inferred if (lexer.getCurToken() == '<') { - type = ParseType(); + type = parseType(); if (!type) return nullptr; } @@ -341,7 +343,7 @@ private: if (!type) type = std::make_unique<VarType>(); lexer.consume(Token('=')); - auto expr = ParseExpression(); + auto expr = parseExpression(); return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id), std::move(*type), std::move(expr)); } @@ -352,7 +354,7 @@ private: /// block ::= { expression_list } /// expression_list ::= block_expr ; expression_list /// block_expr ::= decl | "return" | expr - std::unique_ptr<ExprASTList> ParseBlock() { + std::unique_ptr<ExprASTList> parseBlock() { if (lexer.getCurToken() != '{') return parseError<ExprASTList>("{", "to begin block"); lexer.consume(Token('{')); @@ -366,19 +368,19 @@ private: while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { if (lexer.getCurToken() == tok_var) { // Variable declaration - auto varDecl = ParseDeclaration(); + auto varDecl = parseDeclaration(); if (!varDecl) return nullptr; exprList->push_back(std::move(varDecl)); } else if (lexer.getCurToken() == tok_return) { // Return statement - auto ret = ParseReturn(); + auto ret = parseReturn(); if (!ret) return nullptr; exprList->push_back(std::move(ret)); } else { // General expression - auto expr = ParseExpression(); + auto expr = parseExpression(); if (!expr) return nullptr; exprList->push_back(std::move(expr)); @@ -401,13 +403,13 @@ private: /// prototype ::= def id '(' decl_list ')' /// decl_list ::= identifier | identifier, decl_list - std::unique_ptr<PrototypeAST> ParsePrototype() { + std::unique_ptr<PrototypeAST> parsePrototype() { auto loc = lexer.getLastLocation(); lexer.consume(tok_def); if (lexer.getCurToken() != tok_identifier) return parseError<PrototypeAST>("function name", "in prototype"); - std::string FnName = lexer.getId(); + std::string fnName = lexer.getId(); lexer.consume(tok_identifier); if (lexer.getCurToken() != '(') @@ -435,7 +437,7 @@ private: // success. lexer.consume(Token(')')); - return std::make_unique<PrototypeAST>(std::move(loc), FnName, + return std::make_unique<PrototypeAST>(std::move(loc), fnName, std::move(args)); } @@ -443,18 +445,18 @@ private: /// `def` keyword, followed by a block containing a list of expressions. /// /// definition ::= prototype block - std::unique_ptr<FunctionAST> ParseDefinition() { - auto Proto = ParsePrototype(); - if (!Proto) + std::unique_ptr<FunctionAST> parseDefinition() { + auto proto = parsePrototype(); + if (!proto) return nullptr; - if (auto block = ParseBlock()) - return std::make_unique<FunctionAST>(std::move(Proto), std::move(block)); + if (auto block = parseBlock()) + return std::make_unique<FunctionAST>(std::move(proto), std::move(block)); return nullptr; } /// Get the precedence of the pending binary operator token. - int GetTokPrecedence() { + int getTokPrecedence() { if (!isascii(lexer.getCurToken())) return -1; diff --git a/mlir/examples/toy/Ch1/parser/AST.cpp b/mlir/examples/toy/Ch1/parser/AST.cpp index fde8b101e83..0c7735ec9a4 100644 --- a/mlir/examples/toy/Ch1/parser/AST.cpp +++ b/mlir/examples/toy/Ch1/parser/AST.cpp @@ -21,6 +21,7 @@ #include "toy/AST.h" +#include "mlir/Support/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -40,22 +41,22 @@ struct Indent { /// the way. The only data member is the current indentation level. class ASTDumper { public: - void dump(ModuleAST *Node); + void dump(ModuleAST *node); private: - void dump(VarType &type); + void dump(const VarType &type); void dump(VarDeclExprAST *varDecl); void dump(ExprAST *expr); void dump(ExprASTList *exprList); void dump(NumberExprAST *num); - void dump(LiteralExprAST *Node); - void dump(VariableExprAST *Node); - void dump(ReturnExprAST *Node); - void dump(BinaryExprAST *Node); - void dump(CallExprAST *Node); - void dump(PrintExprAST *Node); - void dump(PrototypeAST *Node); - void dump(FunctionAST *Node); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); // Actually print spaces matching the current indentation level void indent() { @@ -68,8 +69,8 @@ private: } // namespace /// Return a formatted string for the location of any node -template <typename T> static std::string loc(T *Node) { - const auto &loc = Node->loc(); +template <typename T> static std::string loc(T *node) { + const auto &loc = node->loc(); return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + llvm::Twine(loc.col)) .str(); @@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) { /// [ [ 1, 2 ], [ 3, 4 ] ] /// We print out such array with the dimensions spelled out at every level: /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -void printLitHelper(ExprAST *lit_or_num) { +void printLitHelper(ExprAST *litOrNum) { // Inside a literal expression we can have either a number or another literal - if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) { + if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) { llvm::errs() << num->getValue(); return; } - auto *literal = llvm::cast<LiteralExprAST>(lit_or_num); + auto *literal = llvm::cast<LiteralExprAST>(litOrNum); // Print the dimension for this literal first llvm::errs() << "<"; - { - const char *sep = ""; - for (auto dim : literal->getDims()) { - llvm::errs() << sep << dim; - sep = ", "; - } - } + mlir::interleaveComma(literal->getDims(), llvm::errs()); llvm::errs() << ">"; // Now print the content, recursing on every element of the list llvm::errs() << "[ "; - const char *sep = ""; - for (auto &elt : literal->getValues()) { - llvm::errs() << sep; - printLitHelper(elt.get()); - sep = ", "; - } + mlir::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); llvm::errs() << "]"; } /// Print a literal, see the recursive helper above for the implementation. -void ASTDumper::dump(LiteralExprAST *Node) { +void ASTDumper::dump(LiteralExprAST *node) { INDENT(); llvm::errs() << "Literal: "; - printLitHelper(Node); - llvm::errs() << " " << loc(Node) << "\n"; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; } /// Print a variable reference (just a name). -void ASTDumper::dump(VariableExprAST *Node) { +void ASTDumper::dump(VariableExprAST *node) { INDENT(); - llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n"; + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; } /// Return statement print the return and its (optional) argument. -void ASTDumper::dump(ReturnExprAST *Node) { +void ASTDumper::dump(ReturnExprAST *node) { INDENT(); llvm::errs() << "Return\n"; - if (Node->getExpr().hasValue()) - return dump(*Node->getExpr()); + if (node->getExpr().hasValue()) + return dump(*node->getExpr()); { INDENT(); llvm::errs() << "(void)\n"; @@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) { } /// Print a binary operation, first the operator, then recurse into LHS and RHS. -void ASTDumper::dump(BinaryExprAST *Node) { +void ASTDumper::dump(BinaryExprAST *node) { INDENT(); - llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n"; - dump(Node->getLHS()); - dump(Node->getRHS()); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); } /// Print a call expression, first the callee name and the list of args by /// recursing into each individual argument. -void ASTDumper::dump(CallExprAST *Node) { +void ASTDumper::dump(CallExprAST *node) { INDENT(); - llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n"; - for (auto &arg : Node->getArgs()) + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) dump(arg.get()); indent(); llvm::errs() << "]\n"; } /// Print a builtin print call, first the builtin name and then the argument. -void ASTDumper::dump(PrintExprAST *Node) { +void ASTDumper::dump(PrintExprAST *node) { INDENT(); - llvm::errs() << "Print [ " << loc(Node) << "\n"; - dump(Node->getArg()); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); indent(); llvm::errs() << "]\n"; } /// Print type: only the shape is printed in between '<' and '>' -void ASTDumper::dump(VarType &type) { +void ASTDumper::dump(const VarType &type) { llvm::errs() << "<"; - const char *sep = ""; - for (auto shape : type.shape) { - llvm::errs() << sep << shape; - sep = ", "; - } + mlir::interleaveComma(type.shape, llvm::errs()); llvm::errs() << ">"; } /// Print a function prototype, first the function name, and then the list of /// parameters names. -void ASTDumper::dump(PrototypeAST *Node) { +void ASTDumper::dump(PrototypeAST *node) { INDENT(); - llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n"; + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; indent(); llvm::errs() << "Params: ["; - const char *sep = ""; - for (auto &arg : Node->getArgs()) { - llvm::errs() << sep << arg->getName(); - sep = ", "; - } + mlir::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); llvm::errs() << "]\n"; } /// Print a function, first the prototype and then the body. -void ASTDumper::dump(FunctionAST *Node) { +void ASTDumper::dump(FunctionAST *node) { INDENT(); llvm::errs() << "Function \n"; - dump(Node->getProto()); - dump(Node->getBody()); + dump(node->getProto()); + dump(node->getBody()); } /// Print a module, actually loop over the functions and print them in sequence. -void ASTDumper::dump(ModuleAST *Node) { +void ASTDumper::dump(ModuleAST *node) { INDENT(); llvm::errs() << "Module:\n"; - for (auto &F : *Node) - dump(&F); + for (auto &f : *node) + dump(&f); } namespace toy { diff --git a/mlir/examples/toy/Ch1/toyc.cpp b/mlir/examples/toy/Ch1/toyc.cpp index dd308caa24b..37794d5c4d9 100644 --- a/mlir/examples/toy/Ch1/toyc.cpp +++ b/mlir/examples/toy/Ch1/toyc.cpp @@ -30,7 +30,7 @@ using namespace toy; namespace cl = llvm::cl; -static cl::opt<std::string> InputFilename(cl::Positional, +static cl::opt<std::string> inputFilename(cl::Positional, cl::desc("<input toy file>"), cl::init("-"), cl::value_desc("filename")); @@ -44,22 +44,22 @@ static cl::opt<enum Action> /// Returns a Toy AST resulting from parsing the file or a nullptr on error. std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) { - llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr = + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); - if (std::error_code EC = FileOrErr.getError()) { - llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; return nullptr; } - auto buffer = FileOrErr.get()->getBuffer(); + auto buffer = fileOrErr.get()->getBuffer(); LexerBuffer lexer(buffer.begin(), buffer.end(), filename); Parser parser(lexer); - return parser.ParseModule(); + return parser.parseModule(); } int main(int argc, char **argv) { cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); - auto moduleAST = parseInputFile(InputFilename); + auto moduleAST = parseInputFile(inputFilename); if (!moduleAST) return 1; diff --git a/mlir/examples/toy/Ch2/include/toy/AST.h b/mlir/examples/toy/Ch2/include/toy/AST.h index 2ad3392c11a..901164b0f39 100644 --- a/mlir/examples/toy/Ch2/include/toy/AST.h +++ b/mlir/examples/toy/Ch2/include/toy/AST.h @@ -54,7 +54,6 @@ public: ExprAST(ExprASTKind kind, Location location) : kind(kind), location(location) {} - virtual ~ExprAST() = default; ExprASTKind getKind() const { return kind; } @@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST { double Val; public: - NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {} + NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} double getValue() { return Val; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } }; /// Expression class for a literal value. @@ -93,10 +92,11 @@ public: : ExprAST(Expr_Literal, loc), values(std::move(values)), dims(std::move(dims)) {} - std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; } - std::vector<int64_t> &getDims() { return dims; } + llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; } + llvm::ArrayRef<int64_t> getDims() { return dims; } + /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } }; /// Expression class for referencing a variable, like "a". @@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST { std::string name; public: - VariableExprAST(Location loc, const std::string &name) + VariableExprAST(Location loc, llvm::StringRef name) : ExprAST(Expr_Var, loc), name(name) {} llvm::StringRef getName() { return name; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } }; /// Expression class for defining a variable. @@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST { std::unique_ptr<ExprAST> initVal; public: - VarDeclExprAST(Location loc, const std::string &name, VarType type, + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr<ExprAST> initVal) : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), initVal(std::move(initVal)) {} llvm::StringRef getName() { return name; } ExprAST *getInitVal() { return initVal.get(); } - VarType &getType() { return type; } + const VarType &getType() { return type; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } }; /// Expression class for a return operator. @@ -144,61 +144,61 @@ public: llvm::Optional<ExprAST *> getExpr() { if (expr.hasValue()) return expr->get(); - return llvm::NoneType(); + return llvm::None; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } }; /// Expression class for a binary operator. class BinaryExprAST : public ExprAST { - char Op; - std::unique_ptr<ExprAST> LHS, RHS; + char op; + std::unique_ptr<ExprAST> lhs, rhs; public: - char getOp() { return Op; } - ExprAST *getLHS() { return LHS.get(); } - ExprAST *getRHS() { return RHS.get(); } + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } - BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS, - std::unique_ptr<ExprAST> RHS) - : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)), - RHS(std::move(RHS)) {} + BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs, + std::unique_ptr<ExprAST> rhs) + : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } }; /// Expression class for function calls. class CallExprAST : public ExprAST { - std::string Callee; - std::vector<std::unique_ptr<ExprAST>> Args; + std::string callee; + std::vector<std::unique_ptr<ExprAST>> args; public: - CallExprAST(Location loc, const std::string &Callee, - std::vector<std::unique_ptr<ExprAST>> Args) - : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {} + CallExprAST(Location loc, const std::string &callee, + std::vector<std::unique_ptr<ExprAST>> args) + : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} - llvm::StringRef getCallee() { return Callee; } - llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; } + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } }; /// Expression class for builtin print calls. class PrintExprAST : public ExprAST { - std::unique_ptr<ExprAST> Arg; + std::unique_ptr<ExprAST> arg; public: - PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg) - : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {} + PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg) + : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} - ExprAST *getArg() { return Arg.get(); } + ExprAST *getArg() { return arg.get(); } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } }; /// This class represents the "prototype" for a function, which captures its @@ -215,23 +215,21 @@ public: : location(location), name(name), args(std::move(args)) {} const Location &loc() { return location; } - const std::string &getName() const { return name; } - const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() { - return args; - } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; } }; /// This class represents a function definition itself. class FunctionAST { - std::unique_ptr<PrototypeAST> Proto; - std::unique_ptr<ExprASTList> Body; + std::unique_ptr<PrototypeAST> proto; + std::unique_ptr<ExprASTList> body; public: - FunctionAST(std::unique_ptr<PrototypeAST> Proto, - std::unique_ptr<ExprASTList> Body) - : Proto(std::move(Proto)), Body(std::move(Body)) {} - PrototypeAST *getProto() { return Proto.get(); } - ExprASTList *getBody() { return Body.get(); } + FunctionAST(std::unique_ptr<PrototypeAST> proto, + std::unique_ptr<ExprASTList> body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } }; /// This class represents a list of functions to be processed together diff --git a/mlir/examples/toy/Ch2/include/toy/Lexer.h b/mlir/examples/toy/Ch2/include/toy/Lexer.h index 21f92614912..144388c460c 100644 --- a/mlir/examples/toy/Ch2/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch2/include/toy/Lexer.h @@ -89,13 +89,13 @@ public: /// Return the current identifier (prereq: getCurToken() == tok_identifier) llvm::StringRef getId() { assert(curTok == tok_identifier); - return IdentifierStr; + return identifierStr; } /// Return the current number (prereq: getCurToken() == tok_number) double getValue() { assert(curTok == tok_number); - return NumVal; + return numVal; } /// Return the location for the beginning of the current token. @@ -135,56 +135,58 @@ private: /// Return the next token from standard input. Token getTok() { // Skip any whitespace. - while (isspace(LastChar)) - LastChar = Token(getNextChar()); + while (isspace(lastChar)) + lastChar = Token(getNextChar()); // Save the current location before reading the token characters. lastLocation.line = curLineNum; lastLocation.col = curCol; - if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]* - IdentifierStr = (char)LastChar; - while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_') - IdentifierStr += (char)LastChar; + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; - if (IdentifierStr == "return") + if (identifierStr == "return") return tok_return; - if (IdentifierStr == "def") + if (identifierStr == "def") return tok_def; - if (IdentifierStr == "var") + if (identifierStr == "var") return tok_var; return tok_identifier; } - if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ - std::string NumStr; + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; do { - NumStr += LastChar; - LastChar = Token(getNextChar()); - } while (isdigit(LastChar) || LastChar == '.'); + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); - NumVal = strtod(NumStr.c_str(), nullptr); + numVal = strtod(numStr.c_str(), nullptr); return tok_number; } - if (LastChar == '#') { + if (lastChar == '#') { // Comment until end of line. - do - LastChar = Token(getNextChar()); - while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - if (LastChar != EOF) + if (lastChar != EOF) return getTok(); } // Check for end of file. Don't eat the EOF. - if (LastChar == EOF) + if (lastChar == EOF) return tok_eof; // Otherwise, just return the character as its ascii value. - Token ThisChar = Token(LastChar); - LastChar = Token(getNextChar()); - return ThisChar; + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; } /// The last token read from the input. @@ -194,15 +196,15 @@ private: Location lastLocation; /// If the current Token is an identifier, this string contains the value. - std::string IdentifierStr; + std::string identifierStr; /// If the current Token is a number, this contains the value. - double NumVal = 0; + double numVal = 0; /// The last value returned by getNextChar(). We need to keep it around as we /// always need to read ahead one character to decide when to end a token and /// we can't put it back in the stream after reading from it. - Token LastChar = Token(' '); + Token lastChar = Token(' '); /// Keep track of the current line number in the input stream int curLineNum = 0; diff --git a/mlir/examples/toy/Ch2/include/toy/Parser.h b/mlir/examples/toy/Ch2/include/toy/Parser.h index ec3d7654a85..9e219e56551 100644 --- a/mlir/examples/toy/Ch2/include/toy/Parser.h +++ b/mlir/examples/toy/Ch2/include/toy/Parser.h @@ -48,13 +48,13 @@ public: Parser(Lexer &lexer) : lexer(lexer) {} /// Parse a full Module. A module is a list of function definitions. - std::unique_ptr<ModuleAST> ParseModule() { + std::unique_ptr<ModuleAST> parseModule() { lexer.getNextToken(); // prime the lexer // Parse functions one at a time and accumulate in this vector. std::vector<FunctionAST> functions; - while (auto F = ParseDefinition()) { - functions.push_back(std::move(*F)); + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); if (lexer.getCurToken() == tok_eof) break; } @@ -70,14 +70,14 @@ private: /// Parse a return statement. /// return :== return ; | return expr ; - std::unique_ptr<ReturnExprAST> ParseReturn() { + std::unique_ptr<ReturnExprAST> parseReturn() { auto loc = lexer.getLastLocation(); lexer.consume(tok_return); // return takes an optional argument llvm::Optional<std::unique_ptr<ExprAST>> expr; if (lexer.getCurToken() != ';') { - expr = ParseExpression(); + expr = parseExpression(); if (!expr) return nullptr; } @@ -86,18 +86,18 @@ private: /// Parse a literal number. /// numberexpr ::= number - std::unique_ptr<ExprAST> ParseNumberExpr() { + std::unique_ptr<ExprAST> parseNumberExpr() { auto loc = lexer.getLastLocation(); - auto Result = + auto result = std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue()); lexer.consume(tok_number); - return std::move(Result); + return std::move(result); } /// Parse a literal array expression. /// tensorLiteral ::= [ literalList ] | number /// literalList ::= tensorLiteral | tensorLiteral, literalList - std::unique_ptr<ExprAST> ParseTensorLiteralExpr() { + std::unique_ptr<ExprAST> parseTensorLiteralExpr() { auto loc = lexer.getLastLocation(); lexer.consume(Token('[')); @@ -108,13 +108,13 @@ private: do { // We can have either another nested array or a number literal. if (lexer.getCurToken() == '[') { - values.push_back(ParseTensorLiteralExpr()); + values.push_back(parseTensorLiteralExpr()); if (!values.back()) return nullptr; // parse error in the nested array. } else { if (lexer.getCurToken() != tok_number) return parseError<ExprAST>("<num> or [", "in literal expression"); - values.push_back(ParseNumberExpr()); + values.push_back(parseNumberExpr()); } // End of this list on ']' @@ -130,8 +130,10 @@ private: if (values.empty()) return parseError<ExprAST>("<something>", "to fill literal expression"); lexer.getNextToken(); // eat ] + /// Fill in the dimensions now. First the current nesting level: dims.push_back(values.size()); + /// If there is any nested array, process all of them and ensure that /// dimensions are uniform. if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) { @@ -143,7 +145,7 @@ private: "inside literal expression"); // Append the nested dimensions to the current level - auto &firstDims = firstLiteral->getDims(); + auto firstDims = firstLiteral->getDims(); dims.insert(dims.end(), firstDims.begin(), firstDims.end()); // Sanity check that shape is uniform across all elements of the list. @@ -162,22 +164,22 @@ private: } /// parenexpr ::= '(' expression ')' - std::unique_ptr<ExprAST> ParseParenExpr() { + std::unique_ptr<ExprAST> parseParenExpr() { lexer.getNextToken(); // eat (. - auto V = ParseExpression(); - if (!V) + auto v = parseExpression(); + if (!v) return nullptr; if (lexer.getCurToken() != ')') return parseError<ExprAST>(")", "to close expression with parentheses"); lexer.consume(Token(')')); - return V; + return v; } /// identifierexpr /// ::= identifier /// ::= identifier '(' expression ')' - std::unique_ptr<ExprAST> ParseIdentifierExpr() { + std::unique_ptr<ExprAST> parseIdentifierExpr() { std::string name = lexer.getId(); auto loc = lexer.getLastLocation(); @@ -188,11 +190,11 @@ private: // This is a function call. lexer.consume(Token('(')); - std::vector<std::unique_ptr<ExprAST>> Args; + std::vector<std::unique_ptr<ExprAST>> args; if (lexer.getCurToken() != ')') { while (true) { - if (auto Arg = ParseExpression()) - Args.push_back(std::move(Arg)); + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); else return nullptr; @@ -208,14 +210,14 @@ private: // It can be a builtin call to print if (name == "print") { - if (Args.size() != 1) + if (args.size() != 1) return parseError<ExprAST>("<single arg>", "as argument to print()"); - return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0])); + return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0])); } // Call to a user-defined function - return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args)); + return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args)); } /// primary @@ -223,20 +225,20 @@ private: /// ::= numberexpr /// ::= parenexpr /// ::= tensorliteral - std::unique_ptr<ExprAST> ParsePrimary() { + std::unique_ptr<ExprAST> parsePrimary() { switch (lexer.getCurToken()) { default: llvm::errs() << "unknown token '" << lexer.getCurToken() << "' when expecting an expression\n"; return nullptr; case tok_identifier: - return ParseIdentifierExpr(); + return parseIdentifierExpr(); case tok_number: - return ParseNumberExpr(); + return parseNumberExpr(); case '(': - return ParseParenExpr(); + return parseParenExpr(); case '[': - return ParseTensorLiteralExpr(); + return parseTensorLiteralExpr(); case ';': return nullptr; case '}': @@ -248,54 +250,54 @@ private: /// argument indicates the precedence of the current binary operator. /// /// binoprhs ::= ('+' primary)* - std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec, - std::unique_ptr<ExprAST> LHS) { + std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec, + std::unique_ptr<ExprAST> lhs) { // If this is a binop, find its precedence. while (true) { - int TokPrec = GetTokPrecedence(); + int tokPrec = getTokPrecedence(); // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. - if (TokPrec < ExprPrec) - return LHS; + if (tokPrec < exprPrec) + return lhs; // Okay, we know this is a binop. - int BinOp = lexer.getCurToken(); - lexer.consume(Token(BinOp)); + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); auto loc = lexer.getLastLocation(); // Parse the primary expression after the binary operator. - auto RHS = ParsePrimary(); - if (!RHS) + auto rhs = parsePrimary(); + if (!rhs) return parseError<ExprAST>("expression", "to complete binary operator"); - // If BinOp binds less tightly with RHS than the operator after RHS, let - // the pending operator take RHS as its LHS. - int NextPrec = GetTokPrecedence(); - if (TokPrec < NextPrec) { - RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); - if (!RHS) + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) return nullptr; } - // Merge LHS/RHS. - LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp, - std::move(LHS), std::move(RHS)); + // Merge lhs/RHS. + lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); } } - /// expression::= primary binoprhs - std::unique_ptr<ExprAST> ParseExpression() { - auto LHS = ParsePrimary(); - if (!LHS) + /// expression::= primary binop rhs + std::unique_ptr<ExprAST> parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) return nullptr; - return ParseBinOpRHS(0, std::move(LHS)); + return parseBinOpRHS(0, std::move(lhs)); } /// type ::= < shape_list > /// shape_list ::= num | num , shape_list - std::unique_ptr<VarType> ParseType() { + std::unique_ptr<VarType> parseType() { if (lexer.getCurToken() != '<') return parseError<VarType>("<", "to begin type"); lexer.getNextToken(); // eat < @@ -319,7 +321,7 @@ private: /// and identifier and an optional type (shape specification) before the /// initializer. /// decl ::= var identifier [ type ] = expr - std::unique_ptr<VarDeclExprAST> ParseDeclaration() { + std::unique_ptr<VarDeclExprAST> parseDeclaration() { if (lexer.getCurToken() != tok_var) return parseError<VarDeclExprAST>("var", "to begin declaration"); auto loc = lexer.getLastLocation(); @@ -333,7 +335,7 @@ private: std::unique_ptr<VarType> type; // Type is optional, it can be inferred if (lexer.getCurToken() == '<') { - type = ParseType(); + type = parseType(); if (!type) return nullptr; } @@ -341,7 +343,7 @@ private: if (!type) type = std::make_unique<VarType>(); lexer.consume(Token('=')); - auto expr = ParseExpression(); + auto expr = parseExpression(); return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id), std::move(*type), std::move(expr)); } @@ -352,7 +354,7 @@ private: /// block ::= { expression_list } /// expression_list ::= block_expr ; expression_list /// block_expr ::= decl | "return" | expr - std::unique_ptr<ExprASTList> ParseBlock() { + std::unique_ptr<ExprASTList> parseBlock() { if (lexer.getCurToken() != '{') return parseError<ExprASTList>("{", "to begin block"); lexer.consume(Token('{')); @@ -366,19 +368,19 @@ private: while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { if (lexer.getCurToken() == tok_var) { // Variable declaration - auto varDecl = ParseDeclaration(); + auto varDecl = parseDeclaration(); if (!varDecl) return nullptr; exprList->push_back(std::move(varDecl)); } else if (lexer.getCurToken() == tok_return) { // Return statement - auto ret = ParseReturn(); + auto ret = parseReturn(); if (!ret) return nullptr; exprList->push_back(std::move(ret)); } else { // General expression - auto expr = ParseExpression(); + auto expr = parseExpression(); if (!expr) return nullptr; exprList->push_back(std::move(expr)); @@ -401,13 +403,13 @@ private: /// prototype ::= def id '(' decl_list ')' /// decl_list ::= identifier | identifier, decl_list - std::unique_ptr<PrototypeAST> ParsePrototype() { + std::unique_ptr<PrototypeAST> parsePrototype() { auto loc = lexer.getLastLocation(); lexer.consume(tok_def); if (lexer.getCurToken() != tok_identifier) return parseError<PrototypeAST>("function name", "in prototype"); - std::string FnName = lexer.getId(); + std::string fnName = lexer.getId(); lexer.consume(tok_identifier); if (lexer.getCurToken() != '(') @@ -435,7 +437,7 @@ private: // success. lexer.consume(Token(')')); - return std::make_unique<PrototypeAST>(std::move(loc), FnName, + return std::make_unique<PrototypeAST>(std::move(loc), fnName, std::move(args)); } @@ -443,18 +445,18 @@ private: /// `def` keyword, followed by a block containing a list of expressions. /// /// definition ::= prototype block - std::unique_ptr<FunctionAST> ParseDefinition() { - auto Proto = ParsePrototype(); - if (!Proto) + std::unique_ptr<FunctionAST> parseDefinition() { + auto proto = parsePrototype(); + if (!proto) return nullptr; - if (auto block = ParseBlock()) - return std::make_unique<FunctionAST>(std::move(Proto), std::move(block)); + if (auto block = parseBlock()) + return std::make_unique<FunctionAST>(std::move(proto), std::move(block)); return nullptr; } /// Get the precedence of the pending binary operator token. - int GetTokPrecedence() { + int getTokPrecedence() { if (!isascii(lexer.getCurToken())) return -1; diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index 8b434b139c7..da474e809b3 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -143,7 +143,7 @@ private: // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. auto &entryBlock = *function.addEntryBlock(); - auto &protoArgs = funcAST.getProto()->getArgs(); + auto protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : diff --git a/mlir/examples/toy/Ch2/parser/AST.cpp b/mlir/examples/toy/Ch2/parser/AST.cpp index fde8b101e83..0c7735ec9a4 100644 --- a/mlir/examples/toy/Ch2/parser/AST.cpp +++ b/mlir/examples/toy/Ch2/parser/AST.cpp @@ -21,6 +21,7 @@ #include "toy/AST.h" +#include "mlir/Support/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -40,22 +41,22 @@ struct Indent { /// the way. The only data member is the current indentation level. class ASTDumper { public: - void dump(ModuleAST *Node); + void dump(ModuleAST *node); private: - void dump(VarType &type); + void dump(const VarType &type); void dump(VarDeclExprAST *varDecl); void dump(ExprAST *expr); void dump(ExprASTList *exprList); void dump(NumberExprAST *num); - void dump(LiteralExprAST *Node); - void dump(VariableExprAST *Node); - void dump(ReturnExprAST *Node); - void dump(BinaryExprAST *Node); - void dump(CallExprAST *Node); - void dump(PrintExprAST *Node); - void dump(PrototypeAST *Node); - void dump(FunctionAST *Node); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); // Actually print spaces matching the current indentation level void indent() { @@ -68,8 +69,8 @@ private: } // namespace /// Return a formatted string for the location of any node -template <typename T> static std::string loc(T *Node) { - const auto &loc = Node->loc(); +template <typename T> static std::string loc(T *node) { + const auto &loc = node->loc(); return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + llvm::Twine(loc.col)) .str(); @@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) { /// [ [ 1, 2 ], [ 3, 4 ] ] /// We print out such array with the dimensions spelled out at every level: /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -void printLitHelper(ExprAST *lit_or_num) { +void printLitHelper(ExprAST *litOrNum) { // Inside a literal expression we can have either a number or another literal - if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) { + if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) { llvm::errs() << num->getValue(); return; } - auto *literal = llvm::cast<LiteralExprAST>(lit_or_num); + auto *literal = llvm::cast<LiteralExprAST>(litOrNum); // Print the dimension for this literal first llvm::errs() << "<"; - { - const char *sep = ""; - for (auto dim : literal->getDims()) { - llvm::errs() << sep << dim; - sep = ", "; - } - } + mlir::interleaveComma(literal->getDims(), llvm::errs()); llvm::errs() << ">"; // Now print the content, recursing on every element of the list llvm::errs() << "[ "; - const char *sep = ""; - for (auto &elt : literal->getValues()) { - llvm::errs() << sep; - printLitHelper(elt.get()); - sep = ", "; - } + mlir::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); llvm::errs() << "]"; } /// Print a literal, see the recursive helper above for the implementation. -void ASTDumper::dump(LiteralExprAST *Node) { +void ASTDumper::dump(LiteralExprAST *node) { INDENT(); llvm::errs() << "Literal: "; - printLitHelper(Node); - llvm::errs() << " " << loc(Node) << "\n"; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; } /// Print a variable reference (just a name). -void ASTDumper::dump(VariableExprAST *Node) { +void ASTDumper::dump(VariableExprAST *node) { INDENT(); - llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n"; + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; } /// Return statement print the return and its (optional) argument. -void ASTDumper::dump(ReturnExprAST *Node) { +void ASTDumper::dump(ReturnExprAST *node) { INDENT(); llvm::errs() << "Return\n"; - if (Node->getExpr().hasValue()) - return dump(*Node->getExpr()); + if (node->getExpr().hasValue()) + return dump(*node->getExpr()); { INDENT(); llvm::errs() << "(void)\n"; @@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) { } /// Print a binary operation, first the operator, then recurse into LHS and RHS. -void ASTDumper::dump(BinaryExprAST *Node) { +void ASTDumper::dump(BinaryExprAST *node) { INDENT(); - llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n"; - dump(Node->getLHS()); - dump(Node->getRHS()); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); } /// Print a call expression, first the callee name and the list of args by /// recursing into each individual argument. -void ASTDumper::dump(CallExprAST *Node) { +void ASTDumper::dump(CallExprAST *node) { INDENT(); - llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n"; - for (auto &arg : Node->getArgs()) + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) dump(arg.get()); indent(); llvm::errs() << "]\n"; } /// Print a builtin print call, first the builtin name and then the argument. -void ASTDumper::dump(PrintExprAST *Node) { +void ASTDumper::dump(PrintExprAST *node) { INDENT(); - llvm::errs() << "Print [ " << loc(Node) << "\n"; - dump(Node->getArg()); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); indent(); llvm::errs() << "]\n"; } /// Print type: only the shape is printed in between '<' and '>' -void ASTDumper::dump(VarType &type) { +void ASTDumper::dump(const VarType &type) { llvm::errs() << "<"; - const char *sep = ""; - for (auto shape : type.shape) { - llvm::errs() << sep << shape; - sep = ", "; - } + mlir::interleaveComma(type.shape, llvm::errs()); llvm::errs() << ">"; } /// Print a function prototype, first the function name, and then the list of /// parameters names. -void ASTDumper::dump(PrototypeAST *Node) { +void ASTDumper::dump(PrototypeAST *node) { INDENT(); - llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n"; + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; indent(); llvm::errs() << "Params: ["; - const char *sep = ""; - for (auto &arg : Node->getArgs()) { - llvm::errs() << sep << arg->getName(); - sep = ", "; - } + mlir::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); llvm::errs() << "]\n"; } /// Print a function, first the prototype and then the body. -void ASTDumper::dump(FunctionAST *Node) { +void ASTDumper::dump(FunctionAST *node) { INDENT(); llvm::errs() << "Function \n"; - dump(Node->getProto()); - dump(Node->getBody()); + dump(node->getProto()); + dump(node->getBody()); } /// Print a module, actually loop over the functions and print them in sequence. -void ASTDumper::dump(ModuleAST *Node) { +void ASTDumper::dump(ModuleAST *node) { INDENT(); llvm::errs() << "Module:\n"; - for (auto &F : *Node) - dump(&F); + for (auto &f : *node) + dump(&f); } namespace toy { diff --git a/mlir/examples/toy/Ch2/toyc.cpp b/mlir/examples/toy/Ch2/toyc.cpp index 547ac9e65b9..e6a69b92832 100644 --- a/mlir/examples/toy/Ch2/toyc.cpp +++ b/mlir/examples/toy/Ch2/toyc.cpp @@ -63,16 +63,16 @@ static cl::opt<enum Action> emitAction( /// Returns a Toy AST resulting from parsing the file or a nullptr on error. std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) { - llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr = + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); - if (std::error_code EC = FileOrErr.getError()) { - llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; return nullptr; } - auto buffer = FileOrErr.get()->getBuffer(); + auto buffer = fileOrErr.get()->getBuffer(); LexerBuffer lexer(buffer.begin(), buffer.end(), filename); Parser parser(lexer); - return parser.ParseModule(); + return parser.parseModule(); } int dumpMLIR() { diff --git a/mlir/examples/toy/Ch3/include/toy/AST.h b/mlir/examples/toy/Ch3/include/toy/AST.h index 2ad3392c11a..901164b0f39 100644 --- a/mlir/examples/toy/Ch3/include/toy/AST.h +++ b/mlir/examples/toy/Ch3/include/toy/AST.h @@ -54,7 +54,6 @@ public: ExprAST(ExprASTKind kind, Location location) : kind(kind), location(location) {} - virtual ~ExprAST() = default; ExprASTKind getKind() const { return kind; } @@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST { double Val; public: - NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {} + NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} double getValue() { return Val; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } }; /// Expression class for a literal value. @@ -93,10 +92,11 @@ public: : ExprAST(Expr_Literal, loc), values(std::move(values)), dims(std::move(dims)) {} - std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; } - std::vector<int64_t> &getDims() { return dims; } + llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; } + llvm::ArrayRef<int64_t> getDims() { return dims; } + /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } }; /// Expression class for referencing a variable, like "a". @@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST { std::string name; public: - VariableExprAST(Location loc, const std::string &name) + VariableExprAST(Location loc, llvm::StringRef name) : ExprAST(Expr_Var, loc), name(name) {} llvm::StringRef getName() { return name; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } }; /// Expression class for defining a variable. @@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST { std::unique_ptr<ExprAST> initVal; public: - VarDeclExprAST(Location loc, const std::string &name, VarType type, + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr<ExprAST> initVal) : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), initVal(std::move(initVal)) {} llvm::StringRef getName() { return name; } ExprAST *getInitVal() { return initVal.get(); } - VarType &getType() { return type; } + const VarType &getType() { return type; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } }; /// Expression class for a return operator. @@ -144,61 +144,61 @@ public: llvm::Optional<ExprAST *> getExpr() { if (expr.hasValue()) return expr->get(); - return llvm::NoneType(); + return llvm::None; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } }; /// Expression class for a binary operator. class BinaryExprAST : public ExprAST { - char Op; - std::unique_ptr<ExprAST> LHS, RHS; + char op; + std::unique_ptr<ExprAST> lhs, rhs; public: - char getOp() { return Op; } - ExprAST *getLHS() { return LHS.get(); } - ExprAST *getRHS() { return RHS.get(); } + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } - BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS, - std::unique_ptr<ExprAST> RHS) - : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)), - RHS(std::move(RHS)) {} + BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs, + std::unique_ptr<ExprAST> rhs) + : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } }; /// Expression class for function calls. class CallExprAST : public ExprAST { - std::string Callee; - std::vector<std::unique_ptr<ExprAST>> Args; + std::string callee; + std::vector<std::unique_ptr<ExprAST>> args; public: - CallExprAST(Location loc, const std::string &Callee, - std::vector<std::unique_ptr<ExprAST>> Args) - : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {} + CallExprAST(Location loc, const std::string &callee, + std::vector<std::unique_ptr<ExprAST>> args) + : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} - llvm::StringRef getCallee() { return Callee; } - llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; } + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } }; /// Expression class for builtin print calls. class PrintExprAST : public ExprAST { - std::unique_ptr<ExprAST> Arg; + std::unique_ptr<ExprAST> arg; public: - PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg) - : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {} + PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg) + : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} - ExprAST *getArg() { return Arg.get(); } + ExprAST *getArg() { return arg.get(); } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } }; /// This class represents the "prototype" for a function, which captures its @@ -215,23 +215,21 @@ public: : location(location), name(name), args(std::move(args)) {} const Location &loc() { return location; } - const std::string &getName() const { return name; } - const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() { - return args; - } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; } }; /// This class represents a function definition itself. class FunctionAST { - std::unique_ptr<PrototypeAST> Proto; - std::unique_ptr<ExprASTList> Body; + std::unique_ptr<PrototypeAST> proto; + std::unique_ptr<ExprASTList> body; public: - FunctionAST(std::unique_ptr<PrototypeAST> Proto, - std::unique_ptr<ExprASTList> Body) - : Proto(std::move(Proto)), Body(std::move(Body)) {} - PrototypeAST *getProto() { return Proto.get(); } - ExprASTList *getBody() { return Body.get(); } + FunctionAST(std::unique_ptr<PrototypeAST> proto, + std::unique_ptr<ExprASTList> body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } }; /// This class represents a list of functions to be processed together diff --git a/mlir/examples/toy/Ch3/include/toy/Lexer.h b/mlir/examples/toy/Ch3/include/toy/Lexer.h index 21f92614912..144388c460c 100644 --- a/mlir/examples/toy/Ch3/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch3/include/toy/Lexer.h @@ -89,13 +89,13 @@ public: /// Return the current identifier (prereq: getCurToken() == tok_identifier) llvm::StringRef getId() { assert(curTok == tok_identifier); - return IdentifierStr; + return identifierStr; } /// Return the current number (prereq: getCurToken() == tok_number) double getValue() { assert(curTok == tok_number); - return NumVal; + return numVal; } /// Return the location for the beginning of the current token. @@ -135,56 +135,58 @@ private: /// Return the next token from standard input. Token getTok() { // Skip any whitespace. - while (isspace(LastChar)) - LastChar = Token(getNextChar()); + while (isspace(lastChar)) + lastChar = Token(getNextChar()); // Save the current location before reading the token characters. lastLocation.line = curLineNum; lastLocation.col = curCol; - if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]* - IdentifierStr = (char)LastChar; - while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_') - IdentifierStr += (char)LastChar; + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; - if (IdentifierStr == "return") + if (identifierStr == "return") return tok_return; - if (IdentifierStr == "def") + if (identifierStr == "def") return tok_def; - if (IdentifierStr == "var") + if (identifierStr == "var") return tok_var; return tok_identifier; } - if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ - std::string NumStr; + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; do { - NumStr += LastChar; - LastChar = Token(getNextChar()); - } while (isdigit(LastChar) || LastChar == '.'); + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); - NumVal = strtod(NumStr.c_str(), nullptr); + numVal = strtod(numStr.c_str(), nullptr); return tok_number; } - if (LastChar == '#') { + if (lastChar == '#') { // Comment until end of line. - do - LastChar = Token(getNextChar()); - while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - if (LastChar != EOF) + if (lastChar != EOF) return getTok(); } // Check for end of file. Don't eat the EOF. - if (LastChar == EOF) + if (lastChar == EOF) return tok_eof; // Otherwise, just return the character as its ascii value. - Token ThisChar = Token(LastChar); - LastChar = Token(getNextChar()); - return ThisChar; + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; } /// The last token read from the input. @@ -194,15 +196,15 @@ private: Location lastLocation; /// If the current Token is an identifier, this string contains the value. - std::string IdentifierStr; + std::string identifierStr; /// If the current Token is a number, this contains the value. - double NumVal = 0; + double numVal = 0; /// The last value returned by getNextChar(). We need to keep it around as we /// always need to read ahead one character to decide when to end a token and /// we can't put it back in the stream after reading from it. - Token LastChar = Token(' '); + Token lastChar = Token(' '); /// Keep track of the current line number in the input stream int curLineNum = 0; diff --git a/mlir/examples/toy/Ch3/include/toy/Parser.h b/mlir/examples/toy/Ch3/include/toy/Parser.h index ec3d7654a85..9e219e56551 100644 --- a/mlir/examples/toy/Ch3/include/toy/Parser.h +++ b/mlir/examples/toy/Ch3/include/toy/Parser.h @@ -48,13 +48,13 @@ public: Parser(Lexer &lexer) : lexer(lexer) {} /// Parse a full Module. A module is a list of function definitions. - std::unique_ptr<ModuleAST> ParseModule() { + std::unique_ptr<ModuleAST> parseModule() { lexer.getNextToken(); // prime the lexer // Parse functions one at a time and accumulate in this vector. std::vector<FunctionAST> functions; - while (auto F = ParseDefinition()) { - functions.push_back(std::move(*F)); + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); if (lexer.getCurToken() == tok_eof) break; } @@ -70,14 +70,14 @@ private: /// Parse a return statement. /// return :== return ; | return expr ; - std::unique_ptr<ReturnExprAST> ParseReturn() { + std::unique_ptr<ReturnExprAST> parseReturn() { auto loc = lexer.getLastLocation(); lexer.consume(tok_return); // return takes an optional argument llvm::Optional<std::unique_ptr<ExprAST>> expr; if (lexer.getCurToken() != ';') { - expr = ParseExpression(); + expr = parseExpression(); if (!expr) return nullptr; } @@ -86,18 +86,18 @@ private: /// Parse a literal number. /// numberexpr ::= number - std::unique_ptr<ExprAST> ParseNumberExpr() { + std::unique_ptr<ExprAST> parseNumberExpr() { auto loc = lexer.getLastLocation(); - auto Result = + auto result = std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue()); lexer.consume(tok_number); - return std::move(Result); + return std::move(result); } /// Parse a literal array expression. /// tensorLiteral ::= [ literalList ] | number /// literalList ::= tensorLiteral | tensorLiteral, literalList - std::unique_ptr<ExprAST> ParseTensorLiteralExpr() { + std::unique_ptr<ExprAST> parseTensorLiteralExpr() { auto loc = lexer.getLastLocation(); lexer.consume(Token('[')); @@ -108,13 +108,13 @@ private: do { // We can have either another nested array or a number literal. if (lexer.getCurToken() == '[') { - values.push_back(ParseTensorLiteralExpr()); + values.push_back(parseTensorLiteralExpr()); if (!values.back()) return nullptr; // parse error in the nested array. } else { if (lexer.getCurToken() != tok_number) return parseError<ExprAST>("<num> or [", "in literal expression"); - values.push_back(ParseNumberExpr()); + values.push_back(parseNumberExpr()); } // End of this list on ']' @@ -130,8 +130,10 @@ private: if (values.empty()) return parseError<ExprAST>("<something>", "to fill literal expression"); lexer.getNextToken(); // eat ] + /// Fill in the dimensions now. First the current nesting level: dims.push_back(values.size()); + /// If there is any nested array, process all of them and ensure that /// dimensions are uniform. if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) { @@ -143,7 +145,7 @@ private: "inside literal expression"); // Append the nested dimensions to the current level - auto &firstDims = firstLiteral->getDims(); + auto firstDims = firstLiteral->getDims(); dims.insert(dims.end(), firstDims.begin(), firstDims.end()); // Sanity check that shape is uniform across all elements of the list. @@ -162,22 +164,22 @@ private: } /// parenexpr ::= '(' expression ')' - std::unique_ptr<ExprAST> ParseParenExpr() { + std::unique_ptr<ExprAST> parseParenExpr() { lexer.getNextToken(); // eat (. - auto V = ParseExpression(); - if (!V) + auto v = parseExpression(); + if (!v) return nullptr; if (lexer.getCurToken() != ')') return parseError<ExprAST>(")", "to close expression with parentheses"); lexer.consume(Token(')')); - return V; + return v; } /// identifierexpr /// ::= identifier /// ::= identifier '(' expression ')' - std::unique_ptr<ExprAST> ParseIdentifierExpr() { + std::unique_ptr<ExprAST> parseIdentifierExpr() { std::string name = lexer.getId(); auto loc = lexer.getLastLocation(); @@ -188,11 +190,11 @@ private: // This is a function call. lexer.consume(Token('(')); - std::vector<std::unique_ptr<ExprAST>> Args; + std::vector<std::unique_ptr<ExprAST>> args; if (lexer.getCurToken() != ')') { while (true) { - if (auto Arg = ParseExpression()) - Args.push_back(std::move(Arg)); + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); else return nullptr; @@ -208,14 +210,14 @@ private: // It can be a builtin call to print if (name == "print") { - if (Args.size() != 1) + if (args.size() != 1) return parseError<ExprAST>("<single arg>", "as argument to print()"); - return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0])); + return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0])); } // Call to a user-defined function - return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args)); + return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args)); } /// primary @@ -223,20 +225,20 @@ private: /// ::= numberexpr /// ::= parenexpr /// ::= tensorliteral - std::unique_ptr<ExprAST> ParsePrimary() { + std::unique_ptr<ExprAST> parsePrimary() { switch (lexer.getCurToken()) { default: llvm::errs() << "unknown token '" << lexer.getCurToken() << "' when expecting an expression\n"; return nullptr; case tok_identifier: - return ParseIdentifierExpr(); + return parseIdentifierExpr(); case tok_number: - return ParseNumberExpr(); + return parseNumberExpr(); case '(': - return ParseParenExpr(); + return parseParenExpr(); case '[': - return ParseTensorLiteralExpr(); + return parseTensorLiteralExpr(); case ';': return nullptr; case '}': @@ -248,54 +250,54 @@ private: /// argument indicates the precedence of the current binary operator. /// /// binoprhs ::= ('+' primary)* - std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec, - std::unique_ptr<ExprAST> LHS) { + std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec, + std::unique_ptr<ExprAST> lhs) { // If this is a binop, find its precedence. while (true) { - int TokPrec = GetTokPrecedence(); + int tokPrec = getTokPrecedence(); // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. - if (TokPrec < ExprPrec) - return LHS; + if (tokPrec < exprPrec) + return lhs; // Okay, we know this is a binop. - int BinOp = lexer.getCurToken(); - lexer.consume(Token(BinOp)); + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); auto loc = lexer.getLastLocation(); // Parse the primary expression after the binary operator. - auto RHS = ParsePrimary(); - if (!RHS) + auto rhs = parsePrimary(); + if (!rhs) return parseError<ExprAST>("expression", "to complete binary operator"); - // If BinOp binds less tightly with RHS than the operator after RHS, let - // the pending operator take RHS as its LHS. - int NextPrec = GetTokPrecedence(); - if (TokPrec < NextPrec) { - RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); - if (!RHS) + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) return nullptr; } - // Merge LHS/RHS. - LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp, - std::move(LHS), std::move(RHS)); + // Merge lhs/RHS. + lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); } } - /// expression::= primary binoprhs - std::unique_ptr<ExprAST> ParseExpression() { - auto LHS = ParsePrimary(); - if (!LHS) + /// expression::= primary binop rhs + std::unique_ptr<ExprAST> parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) return nullptr; - return ParseBinOpRHS(0, std::move(LHS)); + return parseBinOpRHS(0, std::move(lhs)); } /// type ::= < shape_list > /// shape_list ::= num | num , shape_list - std::unique_ptr<VarType> ParseType() { + std::unique_ptr<VarType> parseType() { if (lexer.getCurToken() != '<') return parseError<VarType>("<", "to begin type"); lexer.getNextToken(); // eat < @@ -319,7 +321,7 @@ private: /// and identifier and an optional type (shape specification) before the /// initializer. /// decl ::= var identifier [ type ] = expr - std::unique_ptr<VarDeclExprAST> ParseDeclaration() { + std::unique_ptr<VarDeclExprAST> parseDeclaration() { if (lexer.getCurToken() != tok_var) return parseError<VarDeclExprAST>("var", "to begin declaration"); auto loc = lexer.getLastLocation(); @@ -333,7 +335,7 @@ private: std::unique_ptr<VarType> type; // Type is optional, it can be inferred if (lexer.getCurToken() == '<') { - type = ParseType(); + type = parseType(); if (!type) return nullptr; } @@ -341,7 +343,7 @@ private: if (!type) type = std::make_unique<VarType>(); lexer.consume(Token('=')); - auto expr = ParseExpression(); + auto expr = parseExpression(); return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id), std::move(*type), std::move(expr)); } @@ -352,7 +354,7 @@ private: /// block ::= { expression_list } /// expression_list ::= block_expr ; expression_list /// block_expr ::= decl | "return" | expr - std::unique_ptr<ExprASTList> ParseBlock() { + std::unique_ptr<ExprASTList> parseBlock() { if (lexer.getCurToken() != '{') return parseError<ExprASTList>("{", "to begin block"); lexer.consume(Token('{')); @@ -366,19 +368,19 @@ private: while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { if (lexer.getCurToken() == tok_var) { // Variable declaration - auto varDecl = ParseDeclaration(); + auto varDecl = parseDeclaration(); if (!varDecl) return nullptr; exprList->push_back(std::move(varDecl)); } else if (lexer.getCurToken() == tok_return) { // Return statement - auto ret = ParseReturn(); + auto ret = parseReturn(); if (!ret) return nullptr; exprList->push_back(std::move(ret)); } else { // General expression - auto expr = ParseExpression(); + auto expr = parseExpression(); if (!expr) return nullptr; exprList->push_back(std::move(expr)); @@ -401,13 +403,13 @@ private: /// prototype ::= def id '(' decl_list ')' /// decl_list ::= identifier | identifier, decl_list - std::unique_ptr<PrototypeAST> ParsePrototype() { + std::unique_ptr<PrototypeAST> parsePrototype() { auto loc = lexer.getLastLocation(); lexer.consume(tok_def); if (lexer.getCurToken() != tok_identifier) return parseError<PrototypeAST>("function name", "in prototype"); - std::string FnName = lexer.getId(); + std::string fnName = lexer.getId(); lexer.consume(tok_identifier); if (lexer.getCurToken() != '(') @@ -435,7 +437,7 @@ private: // success. lexer.consume(Token(')')); - return std::make_unique<PrototypeAST>(std::move(loc), FnName, + return std::make_unique<PrototypeAST>(std::move(loc), fnName, std::move(args)); } @@ -443,18 +445,18 @@ private: /// `def` keyword, followed by a block containing a list of expressions. /// /// definition ::= prototype block - std::unique_ptr<FunctionAST> ParseDefinition() { - auto Proto = ParsePrototype(); - if (!Proto) + std::unique_ptr<FunctionAST> parseDefinition() { + auto proto = parsePrototype(); + if (!proto) return nullptr; - if (auto block = ParseBlock()) - return std::make_unique<FunctionAST>(std::move(Proto), std::move(block)); + if (auto block = parseBlock()) + return std::make_unique<FunctionAST>(std::move(proto), std::move(block)); return nullptr; } /// Get the precedence of the pending binary operator token. - int GetTokPrecedence() { + int getTokPrecedence() { if (!isascii(lexer.getCurToken())) return -1; diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index 8b434b139c7..da474e809b3 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -143,7 +143,7 @@ private: // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. auto &entryBlock = *function.addEntryBlock(); - auto &protoArgs = funcAST.getProto()->getArgs(); + auto protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : diff --git a/mlir/examples/toy/Ch3/parser/AST.cpp b/mlir/examples/toy/Ch3/parser/AST.cpp index 869f2ef2013..0c7735ec9a4 100644 --- a/mlir/examples/toy/Ch3/parser/AST.cpp +++ b/mlir/examples/toy/Ch3/parser/AST.cpp @@ -21,6 +21,7 @@ #include "toy/AST.h" +#include "mlir/Support/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -40,22 +41,22 @@ struct Indent { /// the way. The only data member is the current indentation level. class ASTDumper { public: - void dump(ModuleAST *Node); + void dump(ModuleAST *node); private: - void dump(VarType &type); + void dump(const VarType &type); void dump(VarDeclExprAST *varDecl); void dump(ExprAST *expr); void dump(ExprASTList *exprList); void dump(NumberExprAST *num); - void dump(LiteralExprAST *Node); - void dump(VariableExprAST *Node); - void dump(ReturnExprAST *Node); - void dump(BinaryExprAST *Node); - void dump(CallExprAST *Node); - void dump(PrintExprAST *Node); - void dump(PrototypeAST *Node); - void dump(FunctionAST *Node); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); // Actually print spaces matching the current indentation level void indent() { @@ -68,8 +69,8 @@ private: } // namespace /// Return a formatted string for the location of any node -template <typename T> static std::string loc(T *Node) { - const auto &loc = Node->loc(); +template <typename T> static std::string loc(T *node) { + const auto &loc = node->loc(); return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + llvm::Twine(loc.col)) .str(); @@ -125,60 +126,50 @@ void ASTDumper::dump(NumberExprAST *num) { llvm::errs() << num->getValue() << " " << loc(num) << "\n"; } -/// Helper to print recurisvely a literal. This handles nested array like: +/// Helper to print recursively a literal. This handles nested array like: /// [ [ 1, 2 ], [ 3, 4 ] ] /// We print out such array with the dimensions spelled out at every level: /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -void printLitHelper(ExprAST *lit_or_num) { +void printLitHelper(ExprAST *litOrNum) { // Inside a literal expression we can have either a number or another literal - if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) { + if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) { llvm::errs() << num->getValue(); return; } - auto *literal = llvm::cast<LiteralExprAST>(lit_or_num); + auto *literal = llvm::cast<LiteralExprAST>(litOrNum); // Print the dimension for this literal first llvm::errs() << "<"; - { - const char *sep = ""; - for (auto dim : literal->getDims()) { - llvm::errs() << sep << dim; - sep = ", "; - } - } + mlir::interleaveComma(literal->getDims(), llvm::errs()); llvm::errs() << ">"; // Now print the content, recursing on every element of the list llvm::errs() << "[ "; - const char *sep = ""; - for (auto &elt : literal->getValues()) { - llvm::errs() << sep; - printLitHelper(elt.get()); - sep = ", "; - } + mlir::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); llvm::errs() << "]"; } /// Print a literal, see the recursive helper above for the implementation. -void ASTDumper::dump(LiteralExprAST *Node) { +void ASTDumper::dump(LiteralExprAST *node) { INDENT(); llvm::errs() << "Literal: "; - printLitHelper(Node); - llvm::errs() << " " << loc(Node) << "\n"; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; } /// Print a variable reference (just a name). -void ASTDumper::dump(VariableExprAST *Node) { +void ASTDumper::dump(VariableExprAST *node) { INDENT(); - llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n"; + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; } /// Return statement print the return and its (optional) argument. -void ASTDumper::dump(ReturnExprAST *Node) { +void ASTDumper::dump(ReturnExprAST *node) { INDENT(); llvm::errs() << "Return\n"; - if (Node->getExpr().hasValue()) - return dump(*Node->getExpr()); + if (node->getExpr().hasValue()) + return dump(*node->getExpr()); { INDENT(); llvm::errs() << "(void)\n"; @@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) { } /// Print a binary operation, first the operator, then recurse into LHS and RHS. -void ASTDumper::dump(BinaryExprAST *Node) { +void ASTDumper::dump(BinaryExprAST *node) { INDENT(); - llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n"; - dump(Node->getLHS()); - dump(Node->getRHS()); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); } /// Print a call expression, first the callee name and the list of args by /// recursing into each individual argument. -void ASTDumper::dump(CallExprAST *Node) { +void ASTDumper::dump(CallExprAST *node) { INDENT(); - llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n"; - for (auto &arg : Node->getArgs()) + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) dump(arg.get()); indent(); llvm::errs() << "]\n"; } /// Print a builtin print call, first the builtin name and then the argument. -void ASTDumper::dump(PrintExprAST *Node) { +void ASTDumper::dump(PrintExprAST *node) { INDENT(); - llvm::errs() << "Print [ " << loc(Node) << "\n"; - dump(Node->getArg()); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); indent(); llvm::errs() << "]\n"; } /// Print type: only the shape is printed in between '<' and '>' -void ASTDumper::dump(VarType &type) { +void ASTDumper::dump(const VarType &type) { llvm::errs() << "<"; - const char *sep = ""; - for (auto shape : type.shape) { - llvm::errs() << sep << shape; - sep = ", "; - } + mlir::interleaveComma(type.shape, llvm::errs()); llvm::errs() << ">"; } /// Print a function prototype, first the function name, and then the list of /// parameters names. -void ASTDumper::dump(PrototypeAST *Node) { +void ASTDumper::dump(PrototypeAST *node) { INDENT(); - llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n"; + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; indent(); llvm::errs() << "Params: ["; - const char *sep = ""; - for (auto &arg : Node->getArgs()) { - llvm::errs() << sep << arg->getName(); - sep = ", "; - } + mlir::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); llvm::errs() << "]\n"; } /// Print a function, first the prototype and then the body. -void ASTDumper::dump(FunctionAST *Node) { +void ASTDumper::dump(FunctionAST *node) { INDENT(); llvm::errs() << "Function \n"; - dump(Node->getProto()); - dump(Node->getBody()); + dump(node->getProto()); + dump(node->getBody()); } /// Print a module, actually loop over the functions and print them in sequence. -void ASTDumper::dump(ModuleAST *Node) { +void ASTDumper::dump(ModuleAST *node) { INDENT(); llvm::errs() << "Module:\n"; - for (auto &F : *Node) - dump(&F); + for (auto &f : *node) + dump(&f); } namespace toy { diff --git a/mlir/examples/toy/Ch3/toyc.cpp b/mlir/examples/toy/Ch3/toyc.cpp index 7e62e13432c..d3bb5d1a160 100644 --- a/mlir/examples/toy/Ch3/toyc.cpp +++ b/mlir/examples/toy/Ch3/toyc.cpp @@ -63,20 +63,20 @@ static cl::opt<enum Action> emitAction( cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump"))); -static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations")); +static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations")); /// Returns a Toy AST resulting from parsing the file or a nullptr on error. std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) { - llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr = + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); - if (std::error_code EC = FileOrErr.getError()) { - llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; return nullptr; } - auto buffer = FileOrErr.get()->getBuffer(); + auto buffer = fileOrErr.get()->getBuffer(); LexerBuffer lexer(buffer.begin(), buffer.end(), filename); Parser parser(lexer); - return parser.ParseModule(); + return parser.parseModule(); } int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, @@ -118,7 +118,7 @@ int dumpMLIR() { if (int error = loadMLIR(sourceMgr, context, module)) return error; - if (EnableOpt) { + if (enableOpt) { mlir::PassManager pm(&context); // Apply any generic pass manager command line options and run the pipeline. applyPassManagerCLOptions(pm); diff --git a/mlir/examples/toy/Ch4/include/toy/AST.h b/mlir/examples/toy/Ch4/include/toy/AST.h index 2ad3392c11a..901164b0f39 100644 --- a/mlir/examples/toy/Ch4/include/toy/AST.h +++ b/mlir/examples/toy/Ch4/include/toy/AST.h @@ -54,7 +54,6 @@ public: ExprAST(ExprASTKind kind, Location location) : kind(kind), location(location) {} - virtual ~ExprAST() = default; ExprASTKind getKind() const { return kind; } @@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST { double Val; public: - NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {} + NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} double getValue() { return Val; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } }; /// Expression class for a literal value. @@ -93,10 +92,11 @@ public: : ExprAST(Expr_Literal, loc), values(std::move(values)), dims(std::move(dims)) {} - std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; } - std::vector<int64_t> &getDims() { return dims; } + llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; } + llvm::ArrayRef<int64_t> getDims() { return dims; } + /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } }; /// Expression class for referencing a variable, like "a". @@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST { std::string name; public: - VariableExprAST(Location loc, const std::string &name) + VariableExprAST(Location loc, llvm::StringRef name) : ExprAST(Expr_Var, loc), name(name) {} llvm::StringRef getName() { return name; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } }; /// Expression class for defining a variable. @@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST { std::unique_ptr<ExprAST> initVal; public: - VarDeclExprAST(Location loc, const std::string &name, VarType type, + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr<ExprAST> initVal) : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), initVal(std::move(initVal)) {} llvm::StringRef getName() { return name; } ExprAST *getInitVal() { return initVal.get(); } - VarType &getType() { return type; } + const VarType &getType() { return type; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } }; /// Expression class for a return operator. @@ -144,61 +144,61 @@ public: llvm::Optional<ExprAST *> getExpr() { if (expr.hasValue()) return expr->get(); - return llvm::NoneType(); + return llvm::None; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } }; /// Expression class for a binary operator. class BinaryExprAST : public ExprAST { - char Op; - std::unique_ptr<ExprAST> LHS, RHS; + char op; + std::unique_ptr<ExprAST> lhs, rhs; public: - char getOp() { return Op; } - ExprAST *getLHS() { return LHS.get(); } - ExprAST *getRHS() { return RHS.get(); } + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } - BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS, - std::unique_ptr<ExprAST> RHS) - : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)), - RHS(std::move(RHS)) {} + BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs, + std::unique_ptr<ExprAST> rhs) + : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } }; /// Expression class for function calls. class CallExprAST : public ExprAST { - std::string Callee; - std::vector<std::unique_ptr<ExprAST>> Args; + std::string callee; + std::vector<std::unique_ptr<ExprAST>> args; public: - CallExprAST(Location loc, const std::string &Callee, - std::vector<std::unique_ptr<ExprAST>> Args) - : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {} + CallExprAST(Location loc, const std::string &callee, + std::vector<std::unique_ptr<ExprAST>> args) + : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} - llvm::StringRef getCallee() { return Callee; } - llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; } + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } }; /// Expression class for builtin print calls. class PrintExprAST : public ExprAST { - std::unique_ptr<ExprAST> Arg; + std::unique_ptr<ExprAST> arg; public: - PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg) - : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {} + PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg) + : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} - ExprAST *getArg() { return Arg.get(); } + ExprAST *getArg() { return arg.get(); } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } }; /// This class represents the "prototype" for a function, which captures its @@ -215,23 +215,21 @@ public: : location(location), name(name), args(std::move(args)) {} const Location &loc() { return location; } - const std::string &getName() const { return name; } - const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() { - return args; - } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; } }; /// This class represents a function definition itself. class FunctionAST { - std::unique_ptr<PrototypeAST> Proto; - std::unique_ptr<ExprASTList> Body; + std::unique_ptr<PrototypeAST> proto; + std::unique_ptr<ExprASTList> body; public: - FunctionAST(std::unique_ptr<PrototypeAST> Proto, - std::unique_ptr<ExprASTList> Body) - : Proto(std::move(Proto)), Body(std::move(Body)) {} - PrototypeAST *getProto() { return Proto.get(); } - ExprASTList *getBody() { return Body.get(); } + FunctionAST(std::unique_ptr<PrototypeAST> proto, + std::unique_ptr<ExprASTList> body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } }; /// This class represents a list of functions to be processed together diff --git a/mlir/examples/toy/Ch4/include/toy/Lexer.h b/mlir/examples/toy/Ch4/include/toy/Lexer.h index 21f92614912..144388c460c 100644 --- a/mlir/examples/toy/Ch4/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch4/include/toy/Lexer.h @@ -89,13 +89,13 @@ public: /// Return the current identifier (prereq: getCurToken() == tok_identifier) llvm::StringRef getId() { assert(curTok == tok_identifier); - return IdentifierStr; + return identifierStr; } /// Return the current number (prereq: getCurToken() == tok_number) double getValue() { assert(curTok == tok_number); - return NumVal; + return numVal; } /// Return the location for the beginning of the current token. @@ -135,56 +135,58 @@ private: /// Return the next token from standard input. Token getTok() { // Skip any whitespace. - while (isspace(LastChar)) - LastChar = Token(getNextChar()); + while (isspace(lastChar)) + lastChar = Token(getNextChar()); // Save the current location before reading the token characters. lastLocation.line = curLineNum; lastLocation.col = curCol; - if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]* - IdentifierStr = (char)LastChar; - while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_') - IdentifierStr += (char)LastChar; + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; - if (IdentifierStr == "return") + if (identifierStr == "return") return tok_return; - if (IdentifierStr == "def") + if (identifierStr == "def") return tok_def; - if (IdentifierStr == "var") + if (identifierStr == "var") return tok_var; return tok_identifier; } - if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ - std::string NumStr; + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; do { - NumStr += LastChar; - LastChar = Token(getNextChar()); - } while (isdigit(LastChar) || LastChar == '.'); + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); - NumVal = strtod(NumStr.c_str(), nullptr); + numVal = strtod(numStr.c_str(), nullptr); return tok_number; } - if (LastChar == '#') { + if (lastChar == '#') { // Comment until end of line. - do - LastChar = Token(getNextChar()); - while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - if (LastChar != EOF) + if (lastChar != EOF) return getTok(); } // Check for end of file. Don't eat the EOF. - if (LastChar == EOF) + if (lastChar == EOF) return tok_eof; // Otherwise, just return the character as its ascii value. - Token ThisChar = Token(LastChar); - LastChar = Token(getNextChar()); - return ThisChar; + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; } /// The last token read from the input. @@ -194,15 +196,15 @@ private: Location lastLocation; /// If the current Token is an identifier, this string contains the value. - std::string IdentifierStr; + std::string identifierStr; /// If the current Token is a number, this contains the value. - double NumVal = 0; + double numVal = 0; /// The last value returned by getNextChar(). We need to keep it around as we /// always need to read ahead one character to decide when to end a token and /// we can't put it back in the stream after reading from it. - Token LastChar = Token(' '); + Token lastChar = Token(' '); /// Keep track of the current line number in the input stream int curLineNum = 0; diff --git a/mlir/examples/toy/Ch4/include/toy/Parser.h b/mlir/examples/toy/Ch4/include/toy/Parser.h index ec3d7654a85..9e219e56551 100644 --- a/mlir/examples/toy/Ch4/include/toy/Parser.h +++ b/mlir/examples/toy/Ch4/include/toy/Parser.h @@ -48,13 +48,13 @@ public: Parser(Lexer &lexer) : lexer(lexer) {} /// Parse a full Module. A module is a list of function definitions. - std::unique_ptr<ModuleAST> ParseModule() { + std::unique_ptr<ModuleAST> parseModule() { lexer.getNextToken(); // prime the lexer // Parse functions one at a time and accumulate in this vector. std::vector<FunctionAST> functions; - while (auto F = ParseDefinition()) { - functions.push_back(std::move(*F)); + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); if (lexer.getCurToken() == tok_eof) break; } @@ -70,14 +70,14 @@ private: /// Parse a return statement. /// return :== return ; | return expr ; - std::unique_ptr<ReturnExprAST> ParseReturn() { + std::unique_ptr<ReturnExprAST> parseReturn() { auto loc = lexer.getLastLocation(); lexer.consume(tok_return); // return takes an optional argument llvm::Optional<std::unique_ptr<ExprAST>> expr; if (lexer.getCurToken() != ';') { - expr = ParseExpression(); + expr = parseExpression(); if (!expr) return nullptr; } @@ -86,18 +86,18 @@ private: /// Parse a literal number. /// numberexpr ::= number - std::unique_ptr<ExprAST> ParseNumberExpr() { + std::unique_ptr<ExprAST> parseNumberExpr() { auto loc = lexer.getLastLocation(); - auto Result = + auto result = std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue()); lexer.consume(tok_number); - return std::move(Result); + return std::move(result); } /// Parse a literal array expression. /// tensorLiteral ::= [ literalList ] | number /// literalList ::= tensorLiteral | tensorLiteral, literalList - std::unique_ptr<ExprAST> ParseTensorLiteralExpr() { + std::unique_ptr<ExprAST> parseTensorLiteralExpr() { auto loc = lexer.getLastLocation(); lexer.consume(Token('[')); @@ -108,13 +108,13 @@ private: do { // We can have either another nested array or a number literal. if (lexer.getCurToken() == '[') { - values.push_back(ParseTensorLiteralExpr()); + values.push_back(parseTensorLiteralExpr()); if (!values.back()) return nullptr; // parse error in the nested array. } else { if (lexer.getCurToken() != tok_number) return parseError<ExprAST>("<num> or [", "in literal expression"); - values.push_back(ParseNumberExpr()); + values.push_back(parseNumberExpr()); } // End of this list on ']' @@ -130,8 +130,10 @@ private: if (values.empty()) return parseError<ExprAST>("<something>", "to fill literal expression"); lexer.getNextToken(); // eat ] + /// Fill in the dimensions now. First the current nesting level: dims.push_back(values.size()); + /// If there is any nested array, process all of them and ensure that /// dimensions are uniform. if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) { @@ -143,7 +145,7 @@ private: "inside literal expression"); // Append the nested dimensions to the current level - auto &firstDims = firstLiteral->getDims(); + auto firstDims = firstLiteral->getDims(); dims.insert(dims.end(), firstDims.begin(), firstDims.end()); // Sanity check that shape is uniform across all elements of the list. @@ -162,22 +164,22 @@ private: } /// parenexpr ::= '(' expression ')' - std::unique_ptr<ExprAST> ParseParenExpr() { + std::unique_ptr<ExprAST> parseParenExpr() { lexer.getNextToken(); // eat (. - auto V = ParseExpression(); - if (!V) + auto v = parseExpression(); + if (!v) return nullptr; if (lexer.getCurToken() != ')') return parseError<ExprAST>(")", "to close expression with parentheses"); lexer.consume(Token(')')); - return V; + return v; } /// identifierexpr /// ::= identifier /// ::= identifier '(' expression ')' - std::unique_ptr<ExprAST> ParseIdentifierExpr() { + std::unique_ptr<ExprAST> parseIdentifierExpr() { std::string name = lexer.getId(); auto loc = lexer.getLastLocation(); @@ -188,11 +190,11 @@ private: // This is a function call. lexer.consume(Token('(')); - std::vector<std::unique_ptr<ExprAST>> Args; + std::vector<std::unique_ptr<ExprAST>> args; if (lexer.getCurToken() != ')') { while (true) { - if (auto Arg = ParseExpression()) - Args.push_back(std::move(Arg)); + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); else return nullptr; @@ -208,14 +210,14 @@ private: // It can be a builtin call to print if (name == "print") { - if (Args.size() != 1) + if (args.size() != 1) return parseError<ExprAST>("<single arg>", "as argument to print()"); - return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0])); + return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0])); } // Call to a user-defined function - return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args)); + return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args)); } /// primary @@ -223,20 +225,20 @@ private: /// ::= numberexpr /// ::= parenexpr /// ::= tensorliteral - std::unique_ptr<ExprAST> ParsePrimary() { + std::unique_ptr<ExprAST> parsePrimary() { switch (lexer.getCurToken()) { default: llvm::errs() << "unknown token '" << lexer.getCurToken() << "' when expecting an expression\n"; return nullptr; case tok_identifier: - return ParseIdentifierExpr(); + return parseIdentifierExpr(); case tok_number: - return ParseNumberExpr(); + return parseNumberExpr(); case '(': - return ParseParenExpr(); + return parseParenExpr(); case '[': - return ParseTensorLiteralExpr(); + return parseTensorLiteralExpr(); case ';': return nullptr; case '}': @@ -248,54 +250,54 @@ private: /// argument indicates the precedence of the current binary operator. /// /// binoprhs ::= ('+' primary)* - std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec, - std::unique_ptr<ExprAST> LHS) { + std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec, + std::unique_ptr<ExprAST> lhs) { // If this is a binop, find its precedence. while (true) { - int TokPrec = GetTokPrecedence(); + int tokPrec = getTokPrecedence(); // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. - if (TokPrec < ExprPrec) - return LHS; + if (tokPrec < exprPrec) + return lhs; // Okay, we know this is a binop. - int BinOp = lexer.getCurToken(); - lexer.consume(Token(BinOp)); + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); auto loc = lexer.getLastLocation(); // Parse the primary expression after the binary operator. - auto RHS = ParsePrimary(); - if (!RHS) + auto rhs = parsePrimary(); + if (!rhs) return parseError<ExprAST>("expression", "to complete binary operator"); - // If BinOp binds less tightly with RHS than the operator after RHS, let - // the pending operator take RHS as its LHS. - int NextPrec = GetTokPrecedence(); - if (TokPrec < NextPrec) { - RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); - if (!RHS) + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) return nullptr; } - // Merge LHS/RHS. - LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp, - std::move(LHS), std::move(RHS)); + // Merge lhs/RHS. + lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); } } - /// expression::= primary binoprhs - std::unique_ptr<ExprAST> ParseExpression() { - auto LHS = ParsePrimary(); - if (!LHS) + /// expression::= primary binop rhs + std::unique_ptr<ExprAST> parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) return nullptr; - return ParseBinOpRHS(0, std::move(LHS)); + return parseBinOpRHS(0, std::move(lhs)); } /// type ::= < shape_list > /// shape_list ::= num | num , shape_list - std::unique_ptr<VarType> ParseType() { + std::unique_ptr<VarType> parseType() { if (lexer.getCurToken() != '<') return parseError<VarType>("<", "to begin type"); lexer.getNextToken(); // eat < @@ -319,7 +321,7 @@ private: /// and identifier and an optional type (shape specification) before the /// initializer. /// decl ::= var identifier [ type ] = expr - std::unique_ptr<VarDeclExprAST> ParseDeclaration() { + std::unique_ptr<VarDeclExprAST> parseDeclaration() { if (lexer.getCurToken() != tok_var) return parseError<VarDeclExprAST>("var", "to begin declaration"); auto loc = lexer.getLastLocation(); @@ -333,7 +335,7 @@ private: std::unique_ptr<VarType> type; // Type is optional, it can be inferred if (lexer.getCurToken() == '<') { - type = ParseType(); + type = parseType(); if (!type) return nullptr; } @@ -341,7 +343,7 @@ private: if (!type) type = std::make_unique<VarType>(); lexer.consume(Token('=')); - auto expr = ParseExpression(); + auto expr = parseExpression(); return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id), std::move(*type), std::move(expr)); } @@ -352,7 +354,7 @@ private: /// block ::= { expression_list } /// expression_list ::= block_expr ; expression_list /// block_expr ::= decl | "return" | expr - std::unique_ptr<ExprASTList> ParseBlock() { + std::unique_ptr<ExprASTList> parseBlock() { if (lexer.getCurToken() != '{') return parseError<ExprASTList>("{", "to begin block"); lexer.consume(Token('{')); @@ -366,19 +368,19 @@ private: while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { if (lexer.getCurToken() == tok_var) { // Variable declaration - auto varDecl = ParseDeclaration(); + auto varDecl = parseDeclaration(); if (!varDecl) return nullptr; exprList->push_back(std::move(varDecl)); } else if (lexer.getCurToken() == tok_return) { // Return statement - auto ret = ParseReturn(); + auto ret = parseReturn(); if (!ret) return nullptr; exprList->push_back(std::move(ret)); } else { // General expression - auto expr = ParseExpression(); + auto expr = parseExpression(); if (!expr) return nullptr; exprList->push_back(std::move(expr)); @@ -401,13 +403,13 @@ private: /// prototype ::= def id '(' decl_list ')' /// decl_list ::= identifier | identifier, decl_list - std::unique_ptr<PrototypeAST> ParsePrototype() { + std::unique_ptr<PrototypeAST> parsePrototype() { auto loc = lexer.getLastLocation(); lexer.consume(tok_def); if (lexer.getCurToken() != tok_identifier) return parseError<PrototypeAST>("function name", "in prototype"); - std::string FnName = lexer.getId(); + std::string fnName = lexer.getId(); lexer.consume(tok_identifier); if (lexer.getCurToken() != '(') @@ -435,7 +437,7 @@ private: // success. lexer.consume(Token(')')); - return std::make_unique<PrototypeAST>(std::move(loc), FnName, + return std::make_unique<PrototypeAST>(std::move(loc), fnName, std::move(args)); } @@ -443,18 +445,18 @@ private: /// `def` keyword, followed by a block containing a list of expressions. /// /// definition ::= prototype block - std::unique_ptr<FunctionAST> ParseDefinition() { - auto Proto = ParsePrototype(); - if (!Proto) + std::unique_ptr<FunctionAST> parseDefinition() { + auto proto = parsePrototype(); + if (!proto) return nullptr; - if (auto block = ParseBlock()) - return std::make_unique<FunctionAST>(std::move(Proto), std::move(block)); + if (auto block = parseBlock()) + return std::make_unique<FunctionAST>(std::move(proto), std::move(block)); return nullptr; } /// Get the precedence of the pending binary operator token. - int GetTokPrecedence() { + int getTokPrecedence() { if (!isascii(lexer.getCurToken())) return -1; diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index 5f9dd30a507..da474e809b3 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -91,7 +91,7 @@ private: mlir::ModuleOp theModule; /// The builder is a helper class to create IR inside a function. The builder - /// is stateful, in particular it keeeps an "insertion point": this is where + /// is stateful, in particular it keeps an "insertion point": this is where /// the next operations will be introduced. mlir::OpBuilder builder; @@ -143,7 +143,7 @@ private: // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. auto &entryBlock = *function.addEntryBlock(); - auto &protoArgs = funcAST.getProto()->getArgs(); + auto protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : diff --git a/mlir/examples/toy/Ch4/parser/AST.cpp b/mlir/examples/toy/Ch4/parser/AST.cpp index fde8b101e83..0c7735ec9a4 100644 --- a/mlir/examples/toy/Ch4/parser/AST.cpp +++ b/mlir/examples/toy/Ch4/parser/AST.cpp @@ -21,6 +21,7 @@ #include "toy/AST.h" +#include "mlir/Support/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -40,22 +41,22 @@ struct Indent { /// the way. The only data member is the current indentation level. class ASTDumper { public: - void dump(ModuleAST *Node); + void dump(ModuleAST *node); private: - void dump(VarType &type); + void dump(const VarType &type); void dump(VarDeclExprAST *varDecl); void dump(ExprAST *expr); void dump(ExprASTList *exprList); void dump(NumberExprAST *num); - void dump(LiteralExprAST *Node); - void dump(VariableExprAST *Node); - void dump(ReturnExprAST *Node); - void dump(BinaryExprAST *Node); - void dump(CallExprAST *Node); - void dump(PrintExprAST *Node); - void dump(PrototypeAST *Node); - void dump(FunctionAST *Node); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); // Actually print spaces matching the current indentation level void indent() { @@ -68,8 +69,8 @@ private: } // namespace /// Return a formatted string for the location of any node -template <typename T> static std::string loc(T *Node) { - const auto &loc = Node->loc(); +template <typename T> static std::string loc(T *node) { + const auto &loc = node->loc(); return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + llvm::Twine(loc.col)) .str(); @@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) { /// [ [ 1, 2 ], [ 3, 4 ] ] /// We print out such array with the dimensions spelled out at every level: /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -void printLitHelper(ExprAST *lit_or_num) { +void printLitHelper(ExprAST *litOrNum) { // Inside a literal expression we can have either a number or another literal - if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) { + if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) { llvm::errs() << num->getValue(); return; } - auto *literal = llvm::cast<LiteralExprAST>(lit_or_num); + auto *literal = llvm::cast<LiteralExprAST>(litOrNum); // Print the dimension for this literal first llvm::errs() << "<"; - { - const char *sep = ""; - for (auto dim : literal->getDims()) { - llvm::errs() << sep << dim; - sep = ", "; - } - } + mlir::interleaveComma(literal->getDims(), llvm::errs()); llvm::errs() << ">"; // Now print the content, recursing on every element of the list llvm::errs() << "[ "; - const char *sep = ""; - for (auto &elt : literal->getValues()) { - llvm::errs() << sep; - printLitHelper(elt.get()); - sep = ", "; - } + mlir::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); llvm::errs() << "]"; } /// Print a literal, see the recursive helper above for the implementation. -void ASTDumper::dump(LiteralExprAST *Node) { +void ASTDumper::dump(LiteralExprAST *node) { INDENT(); llvm::errs() << "Literal: "; - printLitHelper(Node); - llvm::errs() << " " << loc(Node) << "\n"; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; } /// Print a variable reference (just a name). -void ASTDumper::dump(VariableExprAST *Node) { +void ASTDumper::dump(VariableExprAST *node) { INDENT(); - llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n"; + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; } /// Return statement print the return and its (optional) argument. -void ASTDumper::dump(ReturnExprAST *Node) { +void ASTDumper::dump(ReturnExprAST *node) { INDENT(); llvm::errs() << "Return\n"; - if (Node->getExpr().hasValue()) - return dump(*Node->getExpr()); + if (node->getExpr().hasValue()) + return dump(*node->getExpr()); { INDENT(); llvm::errs() << "(void)\n"; @@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) { } /// Print a binary operation, first the operator, then recurse into LHS and RHS. -void ASTDumper::dump(BinaryExprAST *Node) { +void ASTDumper::dump(BinaryExprAST *node) { INDENT(); - llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n"; - dump(Node->getLHS()); - dump(Node->getRHS()); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); } /// Print a call expression, first the callee name and the list of args by /// recursing into each individual argument. -void ASTDumper::dump(CallExprAST *Node) { +void ASTDumper::dump(CallExprAST *node) { INDENT(); - llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n"; - for (auto &arg : Node->getArgs()) + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) dump(arg.get()); indent(); llvm::errs() << "]\n"; } /// Print a builtin print call, first the builtin name and then the argument. -void ASTDumper::dump(PrintExprAST *Node) { +void ASTDumper::dump(PrintExprAST *node) { INDENT(); - llvm::errs() << "Print [ " << loc(Node) << "\n"; - dump(Node->getArg()); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); indent(); llvm::errs() << "]\n"; } /// Print type: only the shape is printed in between '<' and '>' -void ASTDumper::dump(VarType &type) { +void ASTDumper::dump(const VarType &type) { llvm::errs() << "<"; - const char *sep = ""; - for (auto shape : type.shape) { - llvm::errs() << sep << shape; - sep = ", "; - } + mlir::interleaveComma(type.shape, llvm::errs()); llvm::errs() << ">"; } /// Print a function prototype, first the function name, and then the list of /// parameters names. -void ASTDumper::dump(PrototypeAST *Node) { +void ASTDumper::dump(PrototypeAST *node) { INDENT(); - llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n"; + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; indent(); llvm::errs() << "Params: ["; - const char *sep = ""; - for (auto &arg : Node->getArgs()) { - llvm::errs() << sep << arg->getName(); - sep = ", "; - } + mlir::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); llvm::errs() << "]\n"; } /// Print a function, first the prototype and then the body. -void ASTDumper::dump(FunctionAST *Node) { +void ASTDumper::dump(FunctionAST *node) { INDENT(); llvm::errs() << "Function \n"; - dump(Node->getProto()); - dump(Node->getBody()); + dump(node->getProto()); + dump(node->getBody()); } /// Print a module, actually loop over the functions and print them in sequence. -void ASTDumper::dump(ModuleAST *Node) { +void ASTDumper::dump(ModuleAST *node) { INDENT(); llvm::errs() << "Module:\n"; - for (auto &F : *Node) - dump(&F); + for (auto &f : *node) + dump(&f); } namespace toy { diff --git a/mlir/examples/toy/Ch4/toyc.cpp b/mlir/examples/toy/Ch4/toyc.cpp index d8e04d6ee89..a9f597dc032 100644 --- a/mlir/examples/toy/Ch4/toyc.cpp +++ b/mlir/examples/toy/Ch4/toyc.cpp @@ -64,20 +64,20 @@ static cl::opt<enum Action> emitAction( cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump"))); -static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations")); +static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations")); /// Returns a Toy AST resulting from parsing the file or a nullptr on error. std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) { - llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr = + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); - if (std::error_code EC = FileOrErr.getError()) { - llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; return nullptr; } - auto buffer = FileOrErr.get()->getBuffer(); + auto buffer = fileOrErr.get()->getBuffer(); LexerBuffer lexer(buffer.begin(), buffer.end(), filename); Parser parser(lexer); - return parser.ParseModule(); + return parser.parseModule(); } int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, @@ -119,7 +119,7 @@ int dumpMLIR() { if (int error = loadMLIR(sourceMgr, context, module)) return error; - if (EnableOpt) { + if (enableOpt) { mlir::PassManager pm(&context); // Apply any generic pass manager command line options and run the pipeline. applyPassManagerCLOptions(pm); diff --git a/mlir/examples/toy/Ch5/include/toy/AST.h b/mlir/examples/toy/Ch5/include/toy/AST.h index 2ad3392c11a..901164b0f39 100644 --- a/mlir/examples/toy/Ch5/include/toy/AST.h +++ b/mlir/examples/toy/Ch5/include/toy/AST.h @@ -54,7 +54,6 @@ public: ExprAST(ExprASTKind kind, Location location) : kind(kind), location(location) {} - virtual ~ExprAST() = default; ExprASTKind getKind() const { return kind; } @@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST { double Val; public: - NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {} + NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} double getValue() { return Val; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } }; /// Expression class for a literal value. @@ -93,10 +92,11 @@ public: : ExprAST(Expr_Literal, loc), values(std::move(values)), dims(std::move(dims)) {} - std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; } - std::vector<int64_t> &getDims() { return dims; } + llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; } + llvm::ArrayRef<int64_t> getDims() { return dims; } + /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } }; /// Expression class for referencing a variable, like "a". @@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST { std::string name; public: - VariableExprAST(Location loc, const std::string &name) + VariableExprAST(Location loc, llvm::StringRef name) : ExprAST(Expr_Var, loc), name(name) {} llvm::StringRef getName() { return name; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } }; /// Expression class for defining a variable. @@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST { std::unique_ptr<ExprAST> initVal; public: - VarDeclExprAST(Location loc, const std::string &name, VarType type, + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr<ExprAST> initVal) : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), initVal(std::move(initVal)) {} llvm::StringRef getName() { return name; } ExprAST *getInitVal() { return initVal.get(); } - VarType &getType() { return type; } + const VarType &getType() { return type; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } }; /// Expression class for a return operator. @@ -144,61 +144,61 @@ public: llvm::Optional<ExprAST *> getExpr() { if (expr.hasValue()) return expr->get(); - return llvm::NoneType(); + return llvm::None; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } }; /// Expression class for a binary operator. class BinaryExprAST : public ExprAST { - char Op; - std::unique_ptr<ExprAST> LHS, RHS; + char op; + std::unique_ptr<ExprAST> lhs, rhs; public: - char getOp() { return Op; } - ExprAST *getLHS() { return LHS.get(); } - ExprAST *getRHS() { return RHS.get(); } + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } - BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS, - std::unique_ptr<ExprAST> RHS) - : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)), - RHS(std::move(RHS)) {} + BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs, + std::unique_ptr<ExprAST> rhs) + : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } }; /// Expression class for function calls. class CallExprAST : public ExprAST { - std::string Callee; - std::vector<std::unique_ptr<ExprAST>> Args; + std::string callee; + std::vector<std::unique_ptr<ExprAST>> args; public: - CallExprAST(Location loc, const std::string &Callee, - std::vector<std::unique_ptr<ExprAST>> Args) - : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {} + CallExprAST(Location loc, const std::string &callee, + std::vector<std::unique_ptr<ExprAST>> args) + : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} - llvm::StringRef getCallee() { return Callee; } - llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; } + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } }; /// Expression class for builtin print calls. class PrintExprAST : public ExprAST { - std::unique_ptr<ExprAST> Arg; + std::unique_ptr<ExprAST> arg; public: - PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg) - : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {} + PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg) + : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} - ExprAST *getArg() { return Arg.get(); } + ExprAST *getArg() { return arg.get(); } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } }; /// This class represents the "prototype" for a function, which captures its @@ -215,23 +215,21 @@ public: : location(location), name(name), args(std::move(args)) {} const Location &loc() { return location; } - const std::string &getName() const { return name; } - const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() { - return args; - } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; } }; /// This class represents a function definition itself. class FunctionAST { - std::unique_ptr<PrototypeAST> Proto; - std::unique_ptr<ExprASTList> Body; + std::unique_ptr<PrototypeAST> proto; + std::unique_ptr<ExprASTList> body; public: - FunctionAST(std::unique_ptr<PrototypeAST> Proto, - std::unique_ptr<ExprASTList> Body) - : Proto(std::move(Proto)), Body(std::move(Body)) {} - PrototypeAST *getProto() { return Proto.get(); } - ExprASTList *getBody() { return Body.get(); } + FunctionAST(std::unique_ptr<PrototypeAST> proto, + std::unique_ptr<ExprASTList> body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } }; /// This class represents a list of functions to be processed together diff --git a/mlir/examples/toy/Ch5/include/toy/Lexer.h b/mlir/examples/toy/Ch5/include/toy/Lexer.h index 21f92614912..144388c460c 100644 --- a/mlir/examples/toy/Ch5/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch5/include/toy/Lexer.h @@ -89,13 +89,13 @@ public: /// Return the current identifier (prereq: getCurToken() == tok_identifier) llvm::StringRef getId() { assert(curTok == tok_identifier); - return IdentifierStr; + return identifierStr; } /// Return the current number (prereq: getCurToken() == tok_number) double getValue() { assert(curTok == tok_number); - return NumVal; + return numVal; } /// Return the location for the beginning of the current token. @@ -135,56 +135,58 @@ private: /// Return the next token from standard input. Token getTok() { // Skip any whitespace. - while (isspace(LastChar)) - LastChar = Token(getNextChar()); + while (isspace(lastChar)) + lastChar = Token(getNextChar()); // Save the current location before reading the token characters. lastLocation.line = curLineNum; lastLocation.col = curCol; - if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]* - IdentifierStr = (char)LastChar; - while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_') - IdentifierStr += (char)LastChar; + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; - if (IdentifierStr == "return") + if (identifierStr == "return") return tok_return; - if (IdentifierStr == "def") + if (identifierStr == "def") return tok_def; - if (IdentifierStr == "var") + if (identifierStr == "var") return tok_var; return tok_identifier; } - if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ - std::string NumStr; + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; do { - NumStr += LastChar; - LastChar = Token(getNextChar()); - } while (isdigit(LastChar) || LastChar == '.'); + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); - NumVal = strtod(NumStr.c_str(), nullptr); + numVal = strtod(numStr.c_str(), nullptr); return tok_number; } - if (LastChar == '#') { + if (lastChar == '#') { // Comment until end of line. - do - LastChar = Token(getNextChar()); - while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - if (LastChar != EOF) + if (lastChar != EOF) return getTok(); } // Check for end of file. Don't eat the EOF. - if (LastChar == EOF) + if (lastChar == EOF) return tok_eof; // Otherwise, just return the character as its ascii value. - Token ThisChar = Token(LastChar); - LastChar = Token(getNextChar()); - return ThisChar; + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; } /// The last token read from the input. @@ -194,15 +196,15 @@ private: Location lastLocation; /// If the current Token is an identifier, this string contains the value. - std::string IdentifierStr; + std::string identifierStr; /// If the current Token is a number, this contains the value. - double NumVal = 0; + double numVal = 0; /// The last value returned by getNextChar(). We need to keep it around as we /// always need to read ahead one character to decide when to end a token and /// we can't put it back in the stream after reading from it. - Token LastChar = Token(' '); + Token lastChar = Token(' '); /// Keep track of the current line number in the input stream int curLineNum = 0; diff --git a/mlir/examples/toy/Ch5/include/toy/Parser.h b/mlir/examples/toy/Ch5/include/toy/Parser.h index ec3d7654a85..9e219e56551 100644 --- a/mlir/examples/toy/Ch5/include/toy/Parser.h +++ b/mlir/examples/toy/Ch5/include/toy/Parser.h @@ -48,13 +48,13 @@ public: Parser(Lexer &lexer) : lexer(lexer) {} /// Parse a full Module. A module is a list of function definitions. - std::unique_ptr<ModuleAST> ParseModule() { + std::unique_ptr<ModuleAST> parseModule() { lexer.getNextToken(); // prime the lexer // Parse functions one at a time and accumulate in this vector. std::vector<FunctionAST> functions; - while (auto F = ParseDefinition()) { - functions.push_back(std::move(*F)); + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); if (lexer.getCurToken() == tok_eof) break; } @@ -70,14 +70,14 @@ private: /// Parse a return statement. /// return :== return ; | return expr ; - std::unique_ptr<ReturnExprAST> ParseReturn() { + std::unique_ptr<ReturnExprAST> parseReturn() { auto loc = lexer.getLastLocation(); lexer.consume(tok_return); // return takes an optional argument llvm::Optional<std::unique_ptr<ExprAST>> expr; if (lexer.getCurToken() != ';') { - expr = ParseExpression(); + expr = parseExpression(); if (!expr) return nullptr; } @@ -86,18 +86,18 @@ private: /// Parse a literal number. /// numberexpr ::= number - std::unique_ptr<ExprAST> ParseNumberExpr() { + std::unique_ptr<ExprAST> parseNumberExpr() { auto loc = lexer.getLastLocation(); - auto Result = + auto result = std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue()); lexer.consume(tok_number); - return std::move(Result); + return std::move(result); } /// Parse a literal array expression. /// tensorLiteral ::= [ literalList ] | number /// literalList ::= tensorLiteral | tensorLiteral, literalList - std::unique_ptr<ExprAST> ParseTensorLiteralExpr() { + std::unique_ptr<ExprAST> parseTensorLiteralExpr() { auto loc = lexer.getLastLocation(); lexer.consume(Token('[')); @@ -108,13 +108,13 @@ private: do { // We can have either another nested array or a number literal. if (lexer.getCurToken() == '[') { - values.push_back(ParseTensorLiteralExpr()); + values.push_back(parseTensorLiteralExpr()); if (!values.back()) return nullptr; // parse error in the nested array. } else { if (lexer.getCurToken() != tok_number) return parseError<ExprAST>("<num> or [", "in literal expression"); - values.push_back(ParseNumberExpr()); + values.push_back(parseNumberExpr()); } // End of this list on ']' @@ -130,8 +130,10 @@ private: if (values.empty()) return parseError<ExprAST>("<something>", "to fill literal expression"); lexer.getNextToken(); // eat ] + /// Fill in the dimensions now. First the current nesting level: dims.push_back(values.size()); + /// If there is any nested array, process all of them and ensure that /// dimensions are uniform. if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) { @@ -143,7 +145,7 @@ private: "inside literal expression"); // Append the nested dimensions to the current level - auto &firstDims = firstLiteral->getDims(); + auto firstDims = firstLiteral->getDims(); dims.insert(dims.end(), firstDims.begin(), firstDims.end()); // Sanity check that shape is uniform across all elements of the list. @@ -162,22 +164,22 @@ private: } /// parenexpr ::= '(' expression ')' - std::unique_ptr<ExprAST> ParseParenExpr() { + std::unique_ptr<ExprAST> parseParenExpr() { lexer.getNextToken(); // eat (. - auto V = ParseExpression(); - if (!V) + auto v = parseExpression(); + if (!v) return nullptr; if (lexer.getCurToken() != ')') return parseError<ExprAST>(")", "to close expression with parentheses"); lexer.consume(Token(')')); - return V; + return v; } /// identifierexpr /// ::= identifier /// ::= identifier '(' expression ')' - std::unique_ptr<ExprAST> ParseIdentifierExpr() { + std::unique_ptr<ExprAST> parseIdentifierExpr() { std::string name = lexer.getId(); auto loc = lexer.getLastLocation(); @@ -188,11 +190,11 @@ private: // This is a function call. lexer.consume(Token('(')); - std::vector<std::unique_ptr<ExprAST>> Args; + std::vector<std::unique_ptr<ExprAST>> args; if (lexer.getCurToken() != ')') { while (true) { - if (auto Arg = ParseExpression()) - Args.push_back(std::move(Arg)); + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); else return nullptr; @@ -208,14 +210,14 @@ private: // It can be a builtin call to print if (name == "print") { - if (Args.size() != 1) + if (args.size() != 1) return parseError<ExprAST>("<single arg>", "as argument to print()"); - return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0])); + return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0])); } // Call to a user-defined function - return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args)); + return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args)); } /// primary @@ -223,20 +225,20 @@ private: /// ::= numberexpr /// ::= parenexpr /// ::= tensorliteral - std::unique_ptr<ExprAST> ParsePrimary() { + std::unique_ptr<ExprAST> parsePrimary() { switch (lexer.getCurToken()) { default: llvm::errs() << "unknown token '" << lexer.getCurToken() << "' when expecting an expression\n"; return nullptr; case tok_identifier: - return ParseIdentifierExpr(); + return parseIdentifierExpr(); case tok_number: - return ParseNumberExpr(); + return parseNumberExpr(); case '(': - return ParseParenExpr(); + return parseParenExpr(); case '[': - return ParseTensorLiteralExpr(); + return parseTensorLiteralExpr(); case ';': return nullptr; case '}': @@ -248,54 +250,54 @@ private: /// argument indicates the precedence of the current binary operator. /// /// binoprhs ::= ('+' primary)* - std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec, - std::unique_ptr<ExprAST> LHS) { + std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec, + std::unique_ptr<ExprAST> lhs) { // If this is a binop, find its precedence. while (true) { - int TokPrec = GetTokPrecedence(); + int tokPrec = getTokPrecedence(); // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. - if (TokPrec < ExprPrec) - return LHS; + if (tokPrec < exprPrec) + return lhs; // Okay, we know this is a binop. - int BinOp = lexer.getCurToken(); - lexer.consume(Token(BinOp)); + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); auto loc = lexer.getLastLocation(); // Parse the primary expression after the binary operator. - auto RHS = ParsePrimary(); - if (!RHS) + auto rhs = parsePrimary(); + if (!rhs) return parseError<ExprAST>("expression", "to complete binary operator"); - // If BinOp binds less tightly with RHS than the operator after RHS, let - // the pending operator take RHS as its LHS. - int NextPrec = GetTokPrecedence(); - if (TokPrec < NextPrec) { - RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); - if (!RHS) + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) return nullptr; } - // Merge LHS/RHS. - LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp, - std::move(LHS), std::move(RHS)); + // Merge lhs/RHS. + lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); } } - /// expression::= primary binoprhs - std::unique_ptr<ExprAST> ParseExpression() { - auto LHS = ParsePrimary(); - if (!LHS) + /// expression::= primary binop rhs + std::unique_ptr<ExprAST> parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) return nullptr; - return ParseBinOpRHS(0, std::move(LHS)); + return parseBinOpRHS(0, std::move(lhs)); } /// type ::= < shape_list > /// shape_list ::= num | num , shape_list - std::unique_ptr<VarType> ParseType() { + std::unique_ptr<VarType> parseType() { if (lexer.getCurToken() != '<') return parseError<VarType>("<", "to begin type"); lexer.getNextToken(); // eat < @@ -319,7 +321,7 @@ private: /// and identifier and an optional type (shape specification) before the /// initializer. /// decl ::= var identifier [ type ] = expr - std::unique_ptr<VarDeclExprAST> ParseDeclaration() { + std::unique_ptr<VarDeclExprAST> parseDeclaration() { if (lexer.getCurToken() != tok_var) return parseError<VarDeclExprAST>("var", "to begin declaration"); auto loc = lexer.getLastLocation(); @@ -333,7 +335,7 @@ private: std::unique_ptr<VarType> type; // Type is optional, it can be inferred if (lexer.getCurToken() == '<') { - type = ParseType(); + type = parseType(); if (!type) return nullptr; } @@ -341,7 +343,7 @@ private: if (!type) type = std::make_unique<VarType>(); lexer.consume(Token('=')); - auto expr = ParseExpression(); + auto expr = parseExpression(); return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id), std::move(*type), std::move(expr)); } @@ -352,7 +354,7 @@ private: /// block ::= { expression_list } /// expression_list ::= block_expr ; expression_list /// block_expr ::= decl | "return" | expr - std::unique_ptr<ExprASTList> ParseBlock() { + std::unique_ptr<ExprASTList> parseBlock() { if (lexer.getCurToken() != '{') return parseError<ExprASTList>("{", "to begin block"); lexer.consume(Token('{')); @@ -366,19 +368,19 @@ private: while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { if (lexer.getCurToken() == tok_var) { // Variable declaration - auto varDecl = ParseDeclaration(); + auto varDecl = parseDeclaration(); if (!varDecl) return nullptr; exprList->push_back(std::move(varDecl)); } else if (lexer.getCurToken() == tok_return) { // Return statement - auto ret = ParseReturn(); + auto ret = parseReturn(); if (!ret) return nullptr; exprList->push_back(std::move(ret)); } else { // General expression - auto expr = ParseExpression(); + auto expr = parseExpression(); if (!expr) return nullptr; exprList->push_back(std::move(expr)); @@ -401,13 +403,13 @@ private: /// prototype ::= def id '(' decl_list ')' /// decl_list ::= identifier | identifier, decl_list - std::unique_ptr<PrototypeAST> ParsePrototype() { + std::unique_ptr<PrototypeAST> parsePrototype() { auto loc = lexer.getLastLocation(); lexer.consume(tok_def); if (lexer.getCurToken() != tok_identifier) return parseError<PrototypeAST>("function name", "in prototype"); - std::string FnName = lexer.getId(); + std::string fnName = lexer.getId(); lexer.consume(tok_identifier); if (lexer.getCurToken() != '(') @@ -435,7 +437,7 @@ private: // success. lexer.consume(Token(')')); - return std::make_unique<PrototypeAST>(std::move(loc), FnName, + return std::make_unique<PrototypeAST>(std::move(loc), fnName, std::move(args)); } @@ -443,18 +445,18 @@ private: /// `def` keyword, followed by a block containing a list of expressions. /// /// definition ::= prototype block - std::unique_ptr<FunctionAST> ParseDefinition() { - auto Proto = ParsePrototype(); - if (!Proto) + std::unique_ptr<FunctionAST> parseDefinition() { + auto proto = parsePrototype(); + if (!proto) return nullptr; - if (auto block = ParseBlock()) - return std::make_unique<FunctionAST>(std::move(Proto), std::move(block)); + if (auto block = parseBlock()) + return std::make_unique<FunctionAST>(std::move(proto), std::move(block)); return nullptr; } /// Get the precedence of the pending binary operator token. - int GetTokPrecedence() { + int getTokPrecedence() { if (!isascii(lexer.getCurToken())) return -1; diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index 8b434b139c7..da474e809b3 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -143,7 +143,7 @@ private: // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. auto &entryBlock = *function.addEntryBlock(); - auto &protoArgs = funcAST.getProto()->getArgs(); + auto protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : diff --git a/mlir/examples/toy/Ch5/parser/AST.cpp b/mlir/examples/toy/Ch5/parser/AST.cpp index fde8b101e83..0c7735ec9a4 100644 --- a/mlir/examples/toy/Ch5/parser/AST.cpp +++ b/mlir/examples/toy/Ch5/parser/AST.cpp @@ -21,6 +21,7 @@ #include "toy/AST.h" +#include "mlir/Support/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -40,22 +41,22 @@ struct Indent { /// the way. The only data member is the current indentation level. class ASTDumper { public: - void dump(ModuleAST *Node); + void dump(ModuleAST *node); private: - void dump(VarType &type); + void dump(const VarType &type); void dump(VarDeclExprAST *varDecl); void dump(ExprAST *expr); void dump(ExprASTList *exprList); void dump(NumberExprAST *num); - void dump(LiteralExprAST *Node); - void dump(VariableExprAST *Node); - void dump(ReturnExprAST *Node); - void dump(BinaryExprAST *Node); - void dump(CallExprAST *Node); - void dump(PrintExprAST *Node); - void dump(PrototypeAST *Node); - void dump(FunctionAST *Node); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); // Actually print spaces matching the current indentation level void indent() { @@ -68,8 +69,8 @@ private: } // namespace /// Return a formatted string for the location of any node -template <typename T> static std::string loc(T *Node) { - const auto &loc = Node->loc(); +template <typename T> static std::string loc(T *node) { + const auto &loc = node->loc(); return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + llvm::Twine(loc.col)) .str(); @@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) { /// [ [ 1, 2 ], [ 3, 4 ] ] /// We print out such array with the dimensions spelled out at every level: /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -void printLitHelper(ExprAST *lit_or_num) { +void printLitHelper(ExprAST *litOrNum) { // Inside a literal expression we can have either a number or another literal - if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) { + if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) { llvm::errs() << num->getValue(); return; } - auto *literal = llvm::cast<LiteralExprAST>(lit_or_num); + auto *literal = llvm::cast<LiteralExprAST>(litOrNum); // Print the dimension for this literal first llvm::errs() << "<"; - { - const char *sep = ""; - for (auto dim : literal->getDims()) { - llvm::errs() << sep << dim; - sep = ", "; - } - } + mlir::interleaveComma(literal->getDims(), llvm::errs()); llvm::errs() << ">"; // Now print the content, recursing on every element of the list llvm::errs() << "[ "; - const char *sep = ""; - for (auto &elt : literal->getValues()) { - llvm::errs() << sep; - printLitHelper(elt.get()); - sep = ", "; - } + mlir::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); llvm::errs() << "]"; } /// Print a literal, see the recursive helper above for the implementation. -void ASTDumper::dump(LiteralExprAST *Node) { +void ASTDumper::dump(LiteralExprAST *node) { INDENT(); llvm::errs() << "Literal: "; - printLitHelper(Node); - llvm::errs() << " " << loc(Node) << "\n"; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; } /// Print a variable reference (just a name). -void ASTDumper::dump(VariableExprAST *Node) { +void ASTDumper::dump(VariableExprAST *node) { INDENT(); - llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n"; + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; } /// Return statement print the return and its (optional) argument. -void ASTDumper::dump(ReturnExprAST *Node) { +void ASTDumper::dump(ReturnExprAST *node) { INDENT(); llvm::errs() << "Return\n"; - if (Node->getExpr().hasValue()) - return dump(*Node->getExpr()); + if (node->getExpr().hasValue()) + return dump(*node->getExpr()); { INDENT(); llvm::errs() << "(void)\n"; @@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) { } /// Print a binary operation, first the operator, then recurse into LHS and RHS. -void ASTDumper::dump(BinaryExprAST *Node) { +void ASTDumper::dump(BinaryExprAST *node) { INDENT(); - llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n"; - dump(Node->getLHS()); - dump(Node->getRHS()); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); } /// Print a call expression, first the callee name and the list of args by /// recursing into each individual argument. -void ASTDumper::dump(CallExprAST *Node) { +void ASTDumper::dump(CallExprAST *node) { INDENT(); - llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n"; - for (auto &arg : Node->getArgs()) + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) dump(arg.get()); indent(); llvm::errs() << "]\n"; } /// Print a builtin print call, first the builtin name and then the argument. -void ASTDumper::dump(PrintExprAST *Node) { +void ASTDumper::dump(PrintExprAST *node) { INDENT(); - llvm::errs() << "Print [ " << loc(Node) << "\n"; - dump(Node->getArg()); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); indent(); llvm::errs() << "]\n"; } /// Print type: only the shape is printed in between '<' and '>' -void ASTDumper::dump(VarType &type) { +void ASTDumper::dump(const VarType &type) { llvm::errs() << "<"; - const char *sep = ""; - for (auto shape : type.shape) { - llvm::errs() << sep << shape; - sep = ", "; - } + mlir::interleaveComma(type.shape, llvm::errs()); llvm::errs() << ">"; } /// Print a function prototype, first the function name, and then the list of /// parameters names. -void ASTDumper::dump(PrototypeAST *Node) { +void ASTDumper::dump(PrototypeAST *node) { INDENT(); - llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n"; + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; indent(); llvm::errs() << "Params: ["; - const char *sep = ""; - for (auto &arg : Node->getArgs()) { - llvm::errs() << sep << arg->getName(); - sep = ", "; - } + mlir::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); llvm::errs() << "]\n"; } /// Print a function, first the prototype and then the body. -void ASTDumper::dump(FunctionAST *Node) { +void ASTDumper::dump(FunctionAST *node) { INDENT(); llvm::errs() << "Function \n"; - dump(Node->getProto()); - dump(Node->getBody()); + dump(node->getProto()); + dump(node->getBody()); } /// Print a module, actually loop over the functions and print them in sequence. -void ASTDumper::dump(ModuleAST *Node) { +void ASTDumper::dump(ModuleAST *node) { INDENT(); llvm::errs() << "Module:\n"; - for (auto &F : *Node) - dump(&F); + for (auto &f : *node) + dump(&f); } namespace toy { diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp index 54cbdf1405f..57fb1a95c94 100644 --- a/mlir/examples/toy/Ch5/toyc.cpp +++ b/mlir/examples/toy/Ch5/toyc.cpp @@ -66,20 +66,20 @@ static cl::opt<enum Action> emitAction( cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", "output the MLIR dump after affine lowering"))); -static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations")); +static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations")); /// Returns a Toy AST resulting from parsing the file or a nullptr on error. std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) { - llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr = + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); - if (std::error_code EC = FileOrErr.getError()) { - llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; return nullptr; } - auto buffer = FileOrErr.get()->getBuffer(); + auto buffer = fileOrErr.get()->getBuffer(); LexerBuffer lexer(buffer.begin(), buffer.end(), filename); Parser parser(lexer); - return parser.ParseModule(); + return parser.parseModule(); } int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, @@ -128,7 +128,7 @@ int dumpMLIR() { // Check to see what granularity of MLIR we are compiling to. bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; - if (EnableOpt || isLoweringToAffine) { + if (enableOpt || isLoweringToAffine) { // Inline all functions into main and then delete them. pm.addPass(mlir::createInlinerPass()); pm.addPass(mlir::toy::createDeadFunctionEliminationPass()); @@ -150,7 +150,7 @@ int dumpMLIR() { optPM.addPass(mlir::createCSEPass()); // Add optimizations if enabled. - if (EnableOpt) { + if (enableOpt) { optPM.addPass(mlir::createLoopFusionPass()); optPM.addPass(mlir::createMemRefDataFlowOptPass()); } diff --git a/mlir/examples/toy/Ch6/include/toy/AST.h b/mlir/examples/toy/Ch6/include/toy/AST.h index 2ad3392c11a..901164b0f39 100644 --- a/mlir/examples/toy/Ch6/include/toy/AST.h +++ b/mlir/examples/toy/Ch6/include/toy/AST.h @@ -54,7 +54,6 @@ public: ExprAST(ExprASTKind kind, Location location) : kind(kind), location(location) {} - virtual ~ExprAST() = default; ExprASTKind getKind() const { return kind; } @@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST { double Val; public: - NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {} + NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} double getValue() { return Val; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } }; /// Expression class for a literal value. @@ -93,10 +92,11 @@ public: : ExprAST(Expr_Literal, loc), values(std::move(values)), dims(std::move(dims)) {} - std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; } - std::vector<int64_t> &getDims() { return dims; } + llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; } + llvm::ArrayRef<int64_t> getDims() { return dims; } + /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } }; /// Expression class for referencing a variable, like "a". @@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST { std::string name; public: - VariableExprAST(Location loc, const std::string &name) + VariableExprAST(Location loc, llvm::StringRef name) : ExprAST(Expr_Var, loc), name(name) {} llvm::StringRef getName() { return name; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } }; /// Expression class for defining a variable. @@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST { std::unique_ptr<ExprAST> initVal; public: - VarDeclExprAST(Location loc, const std::string &name, VarType type, + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr<ExprAST> initVal) : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), initVal(std::move(initVal)) {} llvm::StringRef getName() { return name; } ExprAST *getInitVal() { return initVal.get(); } - VarType &getType() { return type; } + const VarType &getType() { return type; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } }; /// Expression class for a return operator. @@ -144,61 +144,61 @@ public: llvm::Optional<ExprAST *> getExpr() { if (expr.hasValue()) return expr->get(); - return llvm::NoneType(); + return llvm::None; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } }; /// Expression class for a binary operator. class BinaryExprAST : public ExprAST { - char Op; - std::unique_ptr<ExprAST> LHS, RHS; + char op; + std::unique_ptr<ExprAST> lhs, rhs; public: - char getOp() { return Op; } - ExprAST *getLHS() { return LHS.get(); } - ExprAST *getRHS() { return RHS.get(); } + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } - BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS, - std::unique_ptr<ExprAST> RHS) - : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)), - RHS(std::move(RHS)) {} + BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs, + std::unique_ptr<ExprAST> rhs) + : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } }; /// Expression class for function calls. class CallExprAST : public ExprAST { - std::string Callee; - std::vector<std::unique_ptr<ExprAST>> Args; + std::string callee; + std::vector<std::unique_ptr<ExprAST>> args; public: - CallExprAST(Location loc, const std::string &Callee, - std::vector<std::unique_ptr<ExprAST>> Args) - : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {} + CallExprAST(Location loc, const std::string &callee, + std::vector<std::unique_ptr<ExprAST>> args) + : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} - llvm::StringRef getCallee() { return Callee; } - llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; } + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } }; /// Expression class for builtin print calls. class PrintExprAST : public ExprAST { - std::unique_ptr<ExprAST> Arg; + std::unique_ptr<ExprAST> arg; public: - PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg) - : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {} + PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg) + : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} - ExprAST *getArg() { return Arg.get(); } + ExprAST *getArg() { return arg.get(); } /// LLVM style RTTI - static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; } + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } }; /// This class represents the "prototype" for a function, which captures its @@ -215,23 +215,21 @@ public: : location(location), name(name), args(std::move(args)) {} const Location &loc() { return location; } - const std::string &getName() const { return name; } - const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() { - return args; - } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; } }; /// This class represents a function definition itself. class FunctionAST { - std::unique_ptr<PrototypeAST> Proto; - std::unique_ptr<ExprASTList> Body; + std::unique_ptr<PrototypeAST> proto; + std::unique_ptr<ExprASTList> body; public: - FunctionAST(std::unique_ptr<PrototypeAST> Proto, - std::unique_ptr<ExprASTList> Body) - : Proto(std::move(Proto)), Body(std::move(Body)) {} - PrototypeAST *getProto() { return Proto.get(); } - ExprASTList *getBody() { return Body.get(); } + FunctionAST(std::unique_ptr<PrototypeAST> proto, + std::unique_ptr<ExprASTList> body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } }; /// This class represents a list of functions to be processed together diff --git a/mlir/examples/toy/Ch6/include/toy/Lexer.h b/mlir/examples/toy/Ch6/include/toy/Lexer.h index 21f92614912..144388c460c 100644 --- a/mlir/examples/toy/Ch6/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch6/include/toy/Lexer.h @@ -89,13 +89,13 @@ public: /// Return the current identifier (prereq: getCurToken() == tok_identifier) llvm::StringRef getId() { assert(curTok == tok_identifier); - return IdentifierStr; + return identifierStr; } /// Return the current number (prereq: getCurToken() == tok_number) double getValue() { assert(curTok == tok_number); - return NumVal; + return numVal; } /// Return the location for the beginning of the current token. @@ -135,56 +135,58 @@ private: /// Return the next token from standard input. Token getTok() { // Skip any whitespace. - while (isspace(LastChar)) - LastChar = Token(getNextChar()); + while (isspace(lastChar)) + lastChar = Token(getNextChar()); // Save the current location before reading the token characters. lastLocation.line = curLineNum; lastLocation.col = curCol; - if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]* - IdentifierStr = (char)LastChar; - while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_') - IdentifierStr += (char)LastChar; + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; - if (IdentifierStr == "return") + if (identifierStr == "return") return tok_return; - if (IdentifierStr == "def") + if (identifierStr == "def") return tok_def; - if (IdentifierStr == "var") + if (identifierStr == "var") return tok_var; return tok_identifier; } - if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ - std::string NumStr; + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; do { - NumStr += LastChar; - LastChar = Token(getNextChar()); - } while (isdigit(LastChar) || LastChar == '.'); + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); - NumVal = strtod(NumStr.c_str(), nullptr); + numVal = strtod(numStr.c_str(), nullptr); return tok_number; } - if (LastChar == '#') { + if (lastChar == '#') { // Comment until end of line. - do - LastChar = Token(getNextChar()); - while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - if (LastChar != EOF) + if (lastChar != EOF) return getTok(); } // Check for end of file. Don't eat the EOF. - if (LastChar == EOF) + if (lastChar == EOF) return tok_eof; // Otherwise, just return the character as its ascii value. - Token ThisChar = Token(LastChar); - LastChar = Token(getNextChar()); - return ThisChar; + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; } /// The last token read from the input. @@ -194,15 +196,15 @@ private: Location lastLocation; /// If the current Token is an identifier, this string contains the value. - std::string IdentifierStr; + std::string identifierStr; /// If the current Token is a number, this contains the value. - double NumVal = 0; + double numVal = 0; /// The last value returned by getNextChar(). We need to keep it around as we /// always need to read ahead one character to decide when to end a token and /// we can't put it back in the stream after reading from it. - Token LastChar = Token(' '); + Token lastChar = Token(' '); /// Keep track of the current line number in the input stream int curLineNum = 0; diff --git a/mlir/examples/toy/Ch6/include/toy/Parser.h b/mlir/examples/toy/Ch6/include/toy/Parser.h index ec3d7654a85..9e219e56551 100644 --- a/mlir/examples/toy/Ch6/include/toy/Parser.h +++ b/mlir/examples/toy/Ch6/include/toy/Parser.h @@ -48,13 +48,13 @@ public: Parser(Lexer &lexer) : lexer(lexer) {} /// Parse a full Module. A module is a list of function definitions. - std::unique_ptr<ModuleAST> ParseModule() { + std::unique_ptr<ModuleAST> parseModule() { lexer.getNextToken(); // prime the lexer // Parse functions one at a time and accumulate in this vector. std::vector<FunctionAST> functions; - while (auto F = ParseDefinition()) { - functions.push_back(std::move(*F)); + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); if (lexer.getCurToken() == tok_eof) break; } @@ -70,14 +70,14 @@ private: /// Parse a return statement. /// return :== return ; | return expr ; - std::unique_ptr<ReturnExprAST> ParseReturn() { + std::unique_ptr<ReturnExprAST> parseReturn() { auto loc = lexer.getLastLocation(); lexer.consume(tok_return); // return takes an optional argument llvm::Optional<std::unique_ptr<ExprAST>> expr; if (lexer.getCurToken() != ';') { - expr = ParseExpression(); + expr = parseExpression(); if (!expr) return nullptr; } @@ -86,18 +86,18 @@ private: /// Parse a literal number. /// numberexpr ::= number - std::unique_ptr<ExprAST> ParseNumberExpr() { + std::unique_ptr<ExprAST> parseNumberExpr() { auto loc = lexer.getLastLocation(); - auto Result = + auto result = std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue()); lexer.consume(tok_number); - return std::move(Result); + return std::move(result); } /// Parse a literal array expression. /// tensorLiteral ::= [ literalList ] | number /// literalList ::= tensorLiteral | tensorLiteral, literalList - std::unique_ptr<ExprAST> ParseTensorLiteralExpr() { + std::unique_ptr<ExprAST> parseTensorLiteralExpr() { auto loc = lexer.getLastLocation(); lexer.consume(Token('[')); @@ -108,13 +108,13 @@ private: do { // We can have either another nested array or a number literal. if (lexer.getCurToken() == '[') { - values.push_back(ParseTensorLiteralExpr()); + values.push_back(parseTensorLiteralExpr()); if (!values.back()) return nullptr; // parse error in the nested array. } else { if (lexer.getCurToken() != tok_number) return parseError<ExprAST>("<num> or [", "in literal expression"); - values.push_back(ParseNumberExpr()); + values.push_back(parseNumberExpr()); } // End of this list on ']' @@ -130,8 +130,10 @@ private: if (values.empty()) return parseError<ExprAST>("<something>", "to fill literal expression"); lexer.getNextToken(); // eat ] + /// Fill in the dimensions now. First the current nesting level: dims.push_back(values.size()); + /// If there is any nested array, process all of them and ensure that /// dimensions are uniform. if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) { @@ -143,7 +145,7 @@ private: "inside literal expression"); // Append the nested dimensions to the current level - auto &firstDims = firstLiteral->getDims(); + auto firstDims = firstLiteral->getDims(); dims.insert(dims.end(), firstDims.begin(), firstDims.end()); // Sanity check that shape is uniform across all elements of the list. @@ -162,22 +164,22 @@ private: } /// parenexpr ::= '(' expression ')' - std::unique_ptr<ExprAST> ParseParenExpr() { + std::unique_ptr<ExprAST> parseParenExpr() { lexer.getNextToken(); // eat (. - auto V = ParseExpression(); - if (!V) + auto v = parseExpression(); + if (!v) return nullptr; if (lexer.getCurToken() != ')') return parseError<ExprAST>(")", "to close expression with parentheses"); lexer.consume(Token(')')); - return V; + return v; } /// identifierexpr /// ::= identifier /// ::= identifier '(' expression ')' - std::unique_ptr<ExprAST> ParseIdentifierExpr() { + std::unique_ptr<ExprAST> parseIdentifierExpr() { std::string name = lexer.getId(); auto loc = lexer.getLastLocation(); @@ -188,11 +190,11 @@ private: // This is a function call. lexer.consume(Token('(')); - std::vector<std::unique_ptr<ExprAST>> Args; + std::vector<std::unique_ptr<ExprAST>> args; if (lexer.getCurToken() != ')') { while (true) { - if (auto Arg = ParseExpression()) - Args.push_back(std::move(Arg)); + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); else return nullptr; @@ -208,14 +210,14 @@ private: // It can be a builtin call to print if (name == "print") { - if (Args.size() != 1) + if (args.size() != 1) return parseError<ExprAST>("<single arg>", "as argument to print()"); - return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0])); + return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0])); } // Call to a user-defined function - return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args)); + return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args)); } /// primary @@ -223,20 +225,20 @@ private: /// ::= numberexpr /// ::= parenexpr /// ::= tensorliteral - std::unique_ptr<ExprAST> ParsePrimary() { + std::unique_ptr<ExprAST> parsePrimary() { switch (lexer.getCurToken()) { default: llvm::errs() << "unknown token '" << lexer.getCurToken() << "' when expecting an expression\n"; return nullptr; case tok_identifier: - return ParseIdentifierExpr(); + return parseIdentifierExpr(); case tok_number: - return ParseNumberExpr(); + return parseNumberExpr(); case '(': - return ParseParenExpr(); + return parseParenExpr(); case '[': - return ParseTensorLiteralExpr(); + return parseTensorLiteralExpr(); case ';': return nullptr; case '}': @@ -248,54 +250,54 @@ private: /// argument indicates the precedence of the current binary operator. /// /// binoprhs ::= ('+' primary)* - std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec, - std::unique_ptr<ExprAST> LHS) { + std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec, + std::unique_ptr<ExprAST> lhs) { // If this is a binop, find its precedence. while (true) { - int TokPrec = GetTokPrecedence(); + int tokPrec = getTokPrecedence(); // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. - if (TokPrec < ExprPrec) - return LHS; + if (tokPrec < exprPrec) + return lhs; // Okay, we know this is a binop. - int BinOp = lexer.getCurToken(); - lexer.consume(Token(BinOp)); + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); auto loc = lexer.getLastLocation(); // Parse the primary expression after the binary operator. - auto RHS = ParsePrimary(); - if (!RHS) + auto rhs = parsePrimary(); + if (!rhs) return parseError<ExprAST>("expression", "to complete binary operator"); - // If BinOp binds less tightly with RHS than the operator after RHS, let - // the pending operator take RHS as its LHS. - int NextPrec = GetTokPrecedence(); - if (TokPrec < NextPrec) { - RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); - if (!RHS) + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) return nullptr; } - // Merge LHS/RHS. - LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp, - std::move(LHS), std::move(RHS)); + // Merge lhs/RHS. + lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); } } - /// expression::= primary binoprhs - std::unique_ptr<ExprAST> ParseExpression() { - auto LHS = ParsePrimary(); - if (!LHS) + /// expression::= primary binop rhs + std::unique_ptr<ExprAST> parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) return nullptr; - return ParseBinOpRHS(0, std::move(LHS)); + return parseBinOpRHS(0, std::move(lhs)); } /// type ::= < shape_list > /// shape_list ::= num | num , shape_list - std::unique_ptr<VarType> ParseType() { + std::unique_ptr<VarType> parseType() { if (lexer.getCurToken() != '<') return parseError<VarType>("<", "to begin type"); lexer.getNextToken(); // eat < @@ -319,7 +321,7 @@ private: /// and identifier and an optional type (shape specification) before the /// initializer. /// decl ::= var identifier [ type ] = expr - std::unique_ptr<VarDeclExprAST> ParseDeclaration() { + std::unique_ptr<VarDeclExprAST> parseDeclaration() { if (lexer.getCurToken() != tok_var) return parseError<VarDeclExprAST>("var", "to begin declaration"); auto loc = lexer.getLastLocation(); @@ -333,7 +335,7 @@ private: std::unique_ptr<VarType> type; // Type is optional, it can be inferred if (lexer.getCurToken() == '<') { - type = ParseType(); + type = parseType(); if (!type) return nullptr; } @@ -341,7 +343,7 @@ private: if (!type) type = std::make_unique<VarType>(); lexer.consume(Token('=')); - auto expr = ParseExpression(); + auto expr = parseExpression(); return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id), std::move(*type), std::move(expr)); } @@ -352,7 +354,7 @@ private: /// block ::= { expression_list } /// expression_list ::= block_expr ; expression_list /// block_expr ::= decl | "return" | expr - std::unique_ptr<ExprASTList> ParseBlock() { + std::unique_ptr<ExprASTList> parseBlock() { if (lexer.getCurToken() != '{') return parseError<ExprASTList>("{", "to begin block"); lexer.consume(Token('{')); @@ -366,19 +368,19 @@ private: while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { if (lexer.getCurToken() == tok_var) { // Variable declaration - auto varDecl = ParseDeclaration(); + auto varDecl = parseDeclaration(); if (!varDecl) return nullptr; exprList->push_back(std::move(varDecl)); } else if (lexer.getCurToken() == tok_return) { // Return statement - auto ret = ParseReturn(); + auto ret = parseReturn(); if (!ret) return nullptr; exprList->push_back(std::move(ret)); } else { // General expression - auto expr = ParseExpression(); + auto expr = parseExpression(); if (!expr) return nullptr; exprList->push_back(std::move(expr)); @@ -401,13 +403,13 @@ private: /// prototype ::= def id '(' decl_list ')' /// decl_list ::= identifier | identifier, decl_list - std::unique_ptr<PrototypeAST> ParsePrototype() { + std::unique_ptr<PrototypeAST> parsePrototype() { auto loc = lexer.getLastLocation(); lexer.consume(tok_def); if (lexer.getCurToken() != tok_identifier) return parseError<PrototypeAST>("function name", "in prototype"); - std::string FnName = lexer.getId(); + std::string fnName = lexer.getId(); lexer.consume(tok_identifier); if (lexer.getCurToken() != '(') @@ -435,7 +437,7 @@ private: // success. lexer.consume(Token(')')); - return std::make_unique<PrototypeAST>(std::move(loc), FnName, + return std::make_unique<PrototypeAST>(std::move(loc), fnName, std::move(args)); } @@ -443,18 +445,18 @@ private: /// `def` keyword, followed by a block containing a list of expressions. /// /// definition ::= prototype block - std::unique_ptr<FunctionAST> ParseDefinition() { - auto Proto = ParsePrototype(); - if (!Proto) + std::unique_ptr<FunctionAST> parseDefinition() { + auto proto = parsePrototype(); + if (!proto) return nullptr; - if (auto block = ParseBlock()) - return std::make_unique<FunctionAST>(std::move(Proto), std::move(block)); + if (auto block = parseBlock()) + return std::make_unique<FunctionAST>(std::move(proto), std::move(block)); return nullptr; } /// Get the precedence of the pending binary operator token. - int GetTokPrecedence() { + int getTokPrecedence() { if (!isascii(lexer.getCurToken())) return -1; diff --git a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp index 8b434b139c7..da474e809b3 100644 --- a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp @@ -143,7 +143,7 @@ private: // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. auto &entryBlock = *function.addEntryBlock(); - auto &protoArgs = funcAST.getProto()->getArgs(); + auto protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : diff --git a/mlir/examples/toy/Ch6/parser/AST.cpp b/mlir/examples/toy/Ch6/parser/AST.cpp index fde8b101e83..0c7735ec9a4 100644 --- a/mlir/examples/toy/Ch6/parser/AST.cpp +++ b/mlir/examples/toy/Ch6/parser/AST.cpp @@ -21,6 +21,7 @@ #include "toy/AST.h" +#include "mlir/Support/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -40,22 +41,22 @@ struct Indent { /// the way. The only data member is the current indentation level. class ASTDumper { public: - void dump(ModuleAST *Node); + void dump(ModuleAST *node); private: - void dump(VarType &type); + void dump(const VarType &type); void dump(VarDeclExprAST *varDecl); void dump(ExprAST *expr); void dump(ExprASTList *exprList); void dump(NumberExprAST *num); - void dump(LiteralExprAST *Node); - void dump(VariableExprAST *Node); - void dump(ReturnExprAST *Node); - void dump(BinaryExprAST *Node); - void dump(CallExprAST *Node); - void dump(PrintExprAST *Node); - void dump(PrototypeAST *Node); - void dump(FunctionAST *Node); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); // Actually print spaces matching the current indentation level void indent() { @@ -68,8 +69,8 @@ private: } // namespace /// Return a formatted string for the location of any node -template <typename T> static std::string loc(T *Node) { - const auto &loc = Node->loc(); +template <typename T> static std::string loc(T *node) { + const auto &loc = node->loc(); return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + llvm::Twine(loc.col)) .str(); @@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) { /// [ [ 1, 2 ], [ 3, 4 ] ] /// We print out such array with the dimensions spelled out at every level: /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -void printLitHelper(ExprAST *lit_or_num) { +void printLitHelper(ExprAST *litOrNum) { // Inside a literal expression we can have either a number or another literal - if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) { + if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) { llvm::errs() << num->getValue(); return; } - auto *literal = llvm::cast<LiteralExprAST>(lit_or_num); + auto *literal = llvm::cast<LiteralExprAST>(litOrNum); // Print the dimension for this literal first llvm::errs() << "<"; - { - const char *sep = ""; - for (auto dim : literal->getDims()) { - llvm::errs() << sep << dim; - sep = ", "; - } - } + mlir::interleaveComma(literal->getDims(), llvm::errs()); llvm::errs() << ">"; // Now print the content, recursing on every element of the list llvm::errs() << "[ "; - const char *sep = ""; - for (auto &elt : literal->getValues()) { - llvm::errs() << sep; - printLitHelper(elt.get()); - sep = ", "; - } + mlir::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); llvm::errs() << "]"; } /// Print a literal, see the recursive helper above for the implementation. -void ASTDumper::dump(LiteralExprAST *Node) { +void ASTDumper::dump(LiteralExprAST *node) { INDENT(); llvm::errs() << "Literal: "; - printLitHelper(Node); - llvm::errs() << " " << loc(Node) << "\n"; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; } /// Print a variable reference (just a name). -void ASTDumper::dump(VariableExprAST *Node) { +void ASTDumper::dump(VariableExprAST *node) { INDENT(); - llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n"; + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; } /// Return statement print the return and its (optional) argument. -void ASTDumper::dump(ReturnExprAST *Node) { +void ASTDumper::dump(ReturnExprAST *node) { INDENT(); llvm::errs() << "Return\n"; - if (Node->getExpr().hasValue()) - return dump(*Node->getExpr()); + if (node->getExpr().hasValue()) + return dump(*node->getExpr()); { INDENT(); llvm::errs() << "(void)\n"; @@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) { } /// Print a binary operation, first the operator, then recurse into LHS and RHS. -void ASTDumper::dump(BinaryExprAST *Node) { +void ASTDumper::dump(BinaryExprAST *node) { INDENT(); - llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n"; - dump(Node->getLHS()); - dump(Node->getRHS()); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); } /// Print a call expression, first the callee name and the list of args by /// recursing into each individual argument. -void ASTDumper::dump(CallExprAST *Node) { +void ASTDumper::dump(CallExprAST *node) { INDENT(); - llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n"; - for (auto &arg : Node->getArgs()) + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) dump(arg.get()); indent(); llvm::errs() << "]\n"; } /// Print a builtin print call, first the builtin name and then the argument. -void ASTDumper::dump(PrintExprAST *Node) { +void ASTDumper::dump(PrintExprAST *node) { INDENT(); - llvm::errs() << "Print [ " << loc(Node) << "\n"; - dump(Node->getArg()); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); indent(); llvm::errs() << "]\n"; } /// Print type: only the shape is printed in between '<' and '>' -void ASTDumper::dump(VarType &type) { +void ASTDumper::dump(const VarType &type) { llvm::errs() << "<"; - const char *sep = ""; - for (auto shape : type.shape) { - llvm::errs() << sep << shape; - sep = ", "; - } + mlir::interleaveComma(type.shape, llvm::errs()); llvm::errs() << ">"; } /// Print a function prototype, first the function name, and then the list of /// parameters names. -void ASTDumper::dump(PrototypeAST *Node) { +void ASTDumper::dump(PrototypeAST *node) { INDENT(); - llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n"; + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; indent(); llvm::errs() << "Params: ["; - const char *sep = ""; - for (auto &arg : Node->getArgs()) { - llvm::errs() << sep << arg->getName(); - sep = ", "; - } + mlir::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); llvm::errs() << "]\n"; } /// Print a function, first the prototype and then the body. -void ASTDumper::dump(FunctionAST *Node) { +void ASTDumper::dump(FunctionAST *node) { INDENT(); llvm::errs() << "Function \n"; - dump(Node->getProto()); - dump(Node->getBody()); + dump(node->getProto()); + dump(node->getBody()); } /// Print a module, actually loop over the functions and print them in sequence. -void ASTDumper::dump(ModuleAST *Node) { +void ASTDumper::dump(ModuleAST *node) { INDENT(); llvm::errs() << "Module:\n"; - for (auto &F : *Node) - dump(&F); + for (auto &f : *node) + dump(&f); } namespace toy { diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp index b7aed1b440c..fb7cd5493a2 100644 --- a/mlir/examples/toy/Ch6/toyc.cpp +++ b/mlir/examples/toy/Ch6/toyc.cpp @@ -85,20 +85,20 @@ static cl::opt<enum Action> emitAction( clEnumValN(RunJIT, "jit", "JIT the code and run it by invoking the main function"))); -static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations")); +static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations")); /// Returns a Toy AST resulting from parsing the file or a nullptr on error. std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) { - llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr = + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); - if (std::error_code EC = FileOrErr.getError()) { - llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; return nullptr; } - auto buffer = FileOrErr.get()->getBuffer(); + auto buffer = fileOrErr.get()->getBuffer(); LexerBuffer lexer(buffer.begin(), buffer.end(), filename); Parser parser(lexer); - return parser.ParseModule(); + return parser.parseModule(); } int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) { @@ -142,7 +142,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM; - if (EnableOpt || isLoweringToAffine) { + if (enableOpt || isLoweringToAffine) { // Inline all functions into main and then delete them. pm.addPass(mlir::createInlinerPass()); pm.addPass(mlir::toy::createDeadFunctionEliminationPass()); @@ -164,7 +164,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, optPM.addPass(mlir::createCSEPass()); // Add optimizations if enabled. - if (EnableOpt) { + if (enableOpt) { optPM.addPass(mlir::createLoopFusionPass()); optPM.addPass(mlir::createMemRefDataFlowOptPass()); } @@ -208,7 +208,7 @@ int dumpLLVMIR(mlir::ModuleOp module) { /// Optionally run an optimization pipeline over the llvm module. auto optPipeline = mlir::makeOptimizingTransformer( - /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0, + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, /*targetMachine=*/nullptr); if (auto err = optPipeline(llvmModule.get())) { llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; @@ -225,7 +225,7 @@ int runJit(mlir::ModuleOp module) { // An optimization pipeline to use within the execution engine. auto optPipeline = mlir::makeOptimizingTransformer( - /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0, + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, /*targetMachine=*/nullptr); // Create an MLIR execution engine. The execution engine eagerly JIT-compiles |

