summaryrefslogtreecommitdiffstats
path: root/mlir/examples
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-11-06 18:20:24 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-06 18:21:03 -0800
commit22cfff7043daa10ca3e00afd49ff80b882bbb107 (patch)
treec080e8627df81d30895b59cd69483ee081c5783f /mlir/examples
parentf6188b5b07418dda04743a51f0ddcbca30c7a196 (diff)
downloadbcm5719-llvm-22cfff7043daa10ca3e00afd49ff80b882bbb107.tar.gz
bcm5719-llvm-22cfff7043daa10ca3e00afd49ff80b882bbb107.zip
NFC: Uniformize parser naming scheme in Toy tutorial to camelCase and tidy a bit of the implementation.
PiperOrigin-RevId: 278982817
Diffstat (limited to 'mlir/examples')
-rw-r--r--mlir/examples/toy/Ch1/CMakeLists.txt5
-rw-r--r--mlir/examples/toy/Ch1/include/toy/AST.h92
-rw-r--r--mlir/examples/toy/Ch1/include/toy/Lexer.h60
-rw-r--r--mlir/examples/toy/Ch1/include/toy/Parser.h138
-rw-r--r--mlir/examples/toy/Ch1/parser/AST.cpp114
-rw-r--r--mlir/examples/toy/Ch1/toyc.cpp14
-rw-r--r--mlir/examples/toy/Ch2/include/toy/AST.h92
-rw-r--r--mlir/examples/toy/Ch2/include/toy/Lexer.h60
-rw-r--r--mlir/examples/toy/Ch2/include/toy/Parser.h138
-rw-r--r--mlir/examples/toy/Ch2/mlir/MLIRGen.cpp2
-rw-r--r--mlir/examples/toy/Ch2/parser/AST.cpp114
-rw-r--r--mlir/examples/toy/Ch2/toyc.cpp10
-rw-r--r--mlir/examples/toy/Ch3/include/toy/AST.h92
-rw-r--r--mlir/examples/toy/Ch3/include/toy/Lexer.h60
-rw-r--r--mlir/examples/toy/Ch3/include/toy/Parser.h138
-rw-r--r--mlir/examples/toy/Ch3/mlir/MLIRGen.cpp2
-rw-r--r--mlir/examples/toy/Ch3/parser/AST.cpp116
-rw-r--r--mlir/examples/toy/Ch3/toyc.cpp14
-rw-r--r--mlir/examples/toy/Ch4/include/toy/AST.h92
-rw-r--r--mlir/examples/toy/Ch4/include/toy/Lexer.h60
-rw-r--r--mlir/examples/toy/Ch4/include/toy/Parser.h138
-rw-r--r--mlir/examples/toy/Ch4/mlir/MLIRGen.cpp4
-rw-r--r--mlir/examples/toy/Ch4/parser/AST.cpp114
-rw-r--r--mlir/examples/toy/Ch4/toyc.cpp14
-rw-r--r--mlir/examples/toy/Ch5/include/toy/AST.h92
-rw-r--r--mlir/examples/toy/Ch5/include/toy/Lexer.h60
-rw-r--r--mlir/examples/toy/Ch5/include/toy/Parser.h138
-rw-r--r--mlir/examples/toy/Ch5/mlir/MLIRGen.cpp2
-rw-r--r--mlir/examples/toy/Ch5/parser/AST.cpp114
-rw-r--r--mlir/examples/toy/Ch5/toyc.cpp16
-rw-r--r--mlir/examples/toy/Ch6/include/toy/AST.h92
-rw-r--r--mlir/examples/toy/Ch6/include/toy/Lexer.h60
-rw-r--r--mlir/examples/toy/Ch6/include/toy/Parser.h138
-rw-r--r--mlir/examples/toy/Ch6/mlir/MLIRGen.cpp2
-rw-r--r--mlir/examples/toy/Ch6/parser/AST.cpp114
-rw-r--r--mlir/examples/toy/Ch6/toyc.cpp20
36 files changed, 1225 insertions, 1306 deletions
diff --git a/mlir/examples/toy/Ch1/CMakeLists.txt b/mlir/examples/toy/Ch1/CMakeLists.txt
index dd26cf140be..f4e85556130 100644
--- a/mlir/examples/toy/Ch1/CMakeLists.txt
+++ b/mlir/examples/toy/Ch1/CMakeLists.txt
@@ -6,4 +6,7 @@ add_toy_chapter(toyc-ch1
toyc.cpp
parser/AST.cpp
)
-include_directories(include/) \ No newline at end of file
+include_directories(include/)
+target_link_libraries(toyc-ch1
+ PRIVATE
+ MLIRSupport)
diff --git a/mlir/examples/toy/Ch1/include/toy/AST.h b/mlir/examples/toy/Ch1/include/toy/AST.h
index 2ad3392c11a..901164b0f39 100644
--- a/mlir/examples/toy/Ch1/include/toy/AST.h
+++ b/mlir/examples/toy/Ch1/include/toy/AST.h
@@ -54,7 +54,6 @@ public:
ExprAST(ExprASTKind kind, Location location)
: kind(kind), location(location) {}
-
virtual ~ExprAST() = default;
ExprASTKind getKind() const { return kind; }
@@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST {
double Val;
public:
- NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
+ NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {}
double getValue() { return Val; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
};
/// Expression class for a literal value.
@@ -93,10 +92,11 @@ public:
: ExprAST(Expr_Literal, loc), values(std::move(values)),
dims(std::move(dims)) {}
- std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
- std::vector<int64_t> &getDims() { return dims; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
+ llvm::ArrayRef<int64_t> getDims() { return dims; }
+
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
};
/// Expression class for referencing a variable, like "a".
@@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST {
std::string name;
public:
- VariableExprAST(Location loc, const std::string &name)
+ VariableExprAST(Location loc, llvm::StringRef name)
: ExprAST(Expr_Var, loc), name(name) {}
llvm::StringRef getName() { return name; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
};
/// Expression class for defining a variable.
@@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST {
std::unique_ptr<ExprAST> initVal;
public:
- VarDeclExprAST(Location loc, const std::string &name, VarType type,
+ VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
std::unique_ptr<ExprAST> initVal)
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
initVal(std::move(initVal)) {}
llvm::StringRef getName() { return name; }
ExprAST *getInitVal() { return initVal.get(); }
- VarType &getType() { return type; }
+ const VarType &getType() { return type; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
};
/// Expression class for a return operator.
@@ -144,61 +144,61 @@ public:
llvm::Optional<ExprAST *> getExpr() {
if (expr.hasValue())
return expr->get();
- return llvm::NoneType();
+ return llvm::None;
}
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
};
/// Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
- char Op;
- std::unique_ptr<ExprAST> LHS, RHS;
+ char op;
+ std::unique_ptr<ExprAST> lhs, rhs;
public:
- char getOp() { return Op; }
- ExprAST *getLHS() { return LHS.get(); }
- ExprAST *getRHS() { return RHS.get(); }
+ char getOp() { return op; }
+ ExprAST *getLHS() { return lhs.get(); }
+ ExprAST *getRHS() { return rhs.get(); }
- BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
- std::unique_ptr<ExprAST> RHS)
- : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
- RHS(std::move(RHS)) {}
+ BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs,
+ std::unique_ptr<ExprAST> rhs)
+ : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)),
+ rhs(std::move(rhs)) {}
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
};
/// Expression class for function calls.
class CallExprAST : public ExprAST {
- std::string Callee;
- std::vector<std::unique_ptr<ExprAST>> Args;
+ std::string callee;
+ std::vector<std::unique_ptr<ExprAST>> args;
public:
- CallExprAST(Location loc, const std::string &Callee,
- std::vector<std::unique_ptr<ExprAST>> Args)
- : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
+ CallExprAST(Location loc, const std::string &callee,
+ std::vector<std::unique_ptr<ExprAST>> args)
+ : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
- llvm::StringRef getCallee() { return Callee; }
- llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
+ llvm::StringRef getCallee() { return callee; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
};
/// Expression class for builtin print calls.
class PrintExprAST : public ExprAST {
- std::unique_ptr<ExprAST> Arg;
+ std::unique_ptr<ExprAST> arg;
public:
- PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
- : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
+ PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
+ : ExprAST(Expr_Print, loc), arg(std::move(arg)) {}
- ExprAST *getArg() { return Arg.get(); }
+ ExprAST *getArg() { return arg.get(); }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
};
/// This class represents the "prototype" for a function, which captures its
@@ -215,23 +215,21 @@ public:
: location(location), name(name), args(std::move(args)) {}
const Location &loc() { return location; }
- const std::string &getName() const { return name; }
- const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
- return args;
- }
+ llvm::StringRef getName() const { return name; }
+ llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
};
/// This class represents a function definition itself.
class FunctionAST {
- std::unique_ptr<PrototypeAST> Proto;
- std::unique_ptr<ExprASTList> Body;
+ std::unique_ptr<PrototypeAST> proto;
+ std::unique_ptr<ExprASTList> body;
public:
- FunctionAST(std::unique_ptr<PrototypeAST> Proto,
- std::unique_ptr<ExprASTList> Body)
- : Proto(std::move(Proto)), Body(std::move(Body)) {}
- PrototypeAST *getProto() { return Proto.get(); }
- ExprASTList *getBody() { return Body.get(); }
+ FunctionAST(std::unique_ptr<PrototypeAST> proto,
+ std::unique_ptr<ExprASTList> body)
+ : proto(std::move(proto)), body(std::move(body)) {}
+ PrototypeAST *getProto() { return proto.get(); }
+ ExprASTList *getBody() { return body.get(); }
};
/// This class represents a list of functions to be processed together
diff --git a/mlir/examples/toy/Ch1/include/toy/Lexer.h b/mlir/examples/toy/Ch1/include/toy/Lexer.h
index 21f92614912..144388c460c 100644
--- a/mlir/examples/toy/Ch1/include/toy/Lexer.h
+++ b/mlir/examples/toy/Ch1/include/toy/Lexer.h
@@ -89,13 +89,13 @@ public:
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
llvm::StringRef getId() {
assert(curTok == tok_identifier);
- return IdentifierStr;
+ return identifierStr;
}
/// Return the current number (prereq: getCurToken() == tok_number)
double getValue() {
assert(curTok == tok_number);
- return NumVal;
+ return numVal;
}
/// Return the location for the beginning of the current token.
@@ -135,56 +135,58 @@ private:
/// Return the next token from standard input.
Token getTok() {
// Skip any whitespace.
- while (isspace(LastChar))
- LastChar = Token(getNextChar());
+ while (isspace(lastChar))
+ lastChar = Token(getNextChar());
// Save the current location before reading the token characters.
lastLocation.line = curLineNum;
lastLocation.col = curCol;
- if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
- IdentifierStr = (char)LastChar;
- while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
- IdentifierStr += (char)LastChar;
+ // Identifier: [a-zA-Z][a-zA-Z0-9_]*
+ if (isalpha(lastChar)) {
+ identifierStr = (char)lastChar;
+ while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
+ identifierStr += (char)lastChar;
- if (IdentifierStr == "return")
+ if (identifierStr == "return")
return tok_return;
- if (IdentifierStr == "def")
+ if (identifierStr == "def")
return tok_def;
- if (IdentifierStr == "var")
+ if (identifierStr == "var")
return tok_var;
return tok_identifier;
}
- if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
- std::string NumStr;
+ // Number: [0-9.]+
+ if (isdigit(lastChar) || lastChar == '.') {
+ std::string numStr;
do {
- NumStr += LastChar;
- LastChar = Token(getNextChar());
- } while (isdigit(LastChar) || LastChar == '.');
+ numStr += lastChar;
+ lastChar = Token(getNextChar());
+ } while (isdigit(lastChar) || lastChar == '.');
- NumVal = strtod(NumStr.c_str(), nullptr);
+ numVal = strtod(numStr.c_str(), nullptr);
return tok_number;
}
- if (LastChar == '#') {
+ if (lastChar == '#') {
// Comment until end of line.
- do
- LastChar = Token(getNextChar());
- while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
+ do {
+ lastChar = Token(getNextChar());
+ } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
- if (LastChar != EOF)
+ if (lastChar != EOF)
return getTok();
}
// Check for end of file. Don't eat the EOF.
- if (LastChar == EOF)
+ if (lastChar == EOF)
return tok_eof;
// Otherwise, just return the character as its ascii value.
- Token ThisChar = Token(LastChar);
- LastChar = Token(getNextChar());
- return ThisChar;
+ Token thisChar = Token(lastChar);
+ lastChar = Token(getNextChar());
+ return thisChar;
}
/// The last token read from the input.
@@ -194,15 +196,15 @@ private:
Location lastLocation;
/// If the current Token is an identifier, this string contains the value.
- std::string IdentifierStr;
+ std::string identifierStr;
/// If the current Token is a number, this contains the value.
- double NumVal = 0;
+ double numVal = 0;
/// The last value returned by getNextChar(). We need to keep it around as we
/// always need to read ahead one character to decide when to end a token and
/// we can't put it back in the stream after reading from it.
- Token LastChar = Token(' ');
+ Token lastChar = Token(' ');
/// Keep track of the current line number in the input stream
int curLineNum = 0;
diff --git a/mlir/examples/toy/Ch1/include/toy/Parser.h b/mlir/examples/toy/Ch1/include/toy/Parser.h
index ec3d7654a85..9e219e56551 100644
--- a/mlir/examples/toy/Ch1/include/toy/Parser.h
+++ b/mlir/examples/toy/Ch1/include/toy/Parser.h
@@ -48,13 +48,13 @@ public:
Parser(Lexer &lexer) : lexer(lexer) {}
/// Parse a full Module. A module is a list of function definitions.
- std::unique_ptr<ModuleAST> ParseModule() {
+ std::unique_ptr<ModuleAST> parseModule() {
lexer.getNextToken(); // prime the lexer
// Parse functions one at a time and accumulate in this vector.
std::vector<FunctionAST> functions;
- while (auto F = ParseDefinition()) {
- functions.push_back(std::move(*F));
+ while (auto f = parseDefinition()) {
+ functions.push_back(std::move(*f));
if (lexer.getCurToken() == tok_eof)
break;
}
@@ -70,14 +70,14 @@ private:
/// Parse a return statement.
/// return :== return ; | return expr ;
- std::unique_ptr<ReturnExprAST> ParseReturn() {
+ std::unique_ptr<ReturnExprAST> parseReturn() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_return);
// return takes an optional argument
llvm::Optional<std::unique_ptr<ExprAST>> expr;
if (lexer.getCurToken() != ';') {
- expr = ParseExpression();
+ expr = parseExpression();
if (!expr)
return nullptr;
}
@@ -86,18 +86,18 @@ private:
/// Parse a literal number.
/// numberexpr ::= number
- std::unique_ptr<ExprAST> ParseNumberExpr() {
+ std::unique_ptr<ExprAST> parseNumberExpr() {
auto loc = lexer.getLastLocation();
- auto Result =
+ auto result =
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
lexer.consume(tok_number);
- return std::move(Result);
+ return std::move(result);
}
/// Parse a literal array expression.
/// tensorLiteral ::= [ literalList ] | number
/// literalList ::= tensorLiteral | tensorLiteral, literalList
- std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
+ std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
auto loc = lexer.getLastLocation();
lexer.consume(Token('['));
@@ -108,13 +108,13 @@ private:
do {
// We can have either another nested array or a number literal.
if (lexer.getCurToken() == '[') {
- values.push_back(ParseTensorLiteralExpr());
+ values.push_back(parseTensorLiteralExpr());
if (!values.back())
return nullptr; // parse error in the nested array.
} else {
if (lexer.getCurToken() != tok_number)
return parseError<ExprAST>("<num> or [", "in literal expression");
- values.push_back(ParseNumberExpr());
+ values.push_back(parseNumberExpr());
}
// End of this list on ']'
@@ -130,8 +130,10 @@ private:
if (values.empty())
return parseError<ExprAST>("<something>", "to fill literal expression");
lexer.getNextToken(); // eat ]
+
/// Fill in the dimensions now. First the current nesting level:
dims.push_back(values.size());
+
/// If there is any nested array, process all of them and ensure that
/// dimensions are uniform.
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
@@ -143,7 +145,7 @@ private:
"inside literal expression");
// Append the nested dimensions to the current level
- auto &firstDims = firstLiteral->getDims();
+ auto firstDims = firstLiteral->getDims();
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
// Sanity check that shape is uniform across all elements of the list.
@@ -162,22 +164,22 @@ private:
}
/// parenexpr ::= '(' expression ')'
- std::unique_ptr<ExprAST> ParseParenExpr() {
+ std::unique_ptr<ExprAST> parseParenExpr() {
lexer.getNextToken(); // eat (.
- auto V = ParseExpression();
- if (!V)
+ auto v = parseExpression();
+ if (!v)
return nullptr;
if (lexer.getCurToken() != ')')
return parseError<ExprAST>(")", "to close expression with parentheses");
lexer.consume(Token(')'));
- return V;
+ return v;
}
/// identifierexpr
/// ::= identifier
/// ::= identifier '(' expression ')'
- std::unique_ptr<ExprAST> ParseIdentifierExpr() {
+ std::unique_ptr<ExprAST> parseIdentifierExpr() {
std::string name = lexer.getId();
auto loc = lexer.getLastLocation();
@@ -188,11 +190,11 @@ private:
// This is a function call.
lexer.consume(Token('('));
- std::vector<std::unique_ptr<ExprAST>> Args;
+ std::vector<std::unique_ptr<ExprAST>> args;
if (lexer.getCurToken() != ')') {
while (true) {
- if (auto Arg = ParseExpression())
- Args.push_back(std::move(Arg));
+ if (auto arg = parseExpression())
+ args.push_back(std::move(arg));
else
return nullptr;
@@ -208,14 +210,14 @@ private:
// It can be a builtin call to print
if (name == "print") {
- if (Args.size() != 1)
+ if (args.size() != 1)
return parseError<ExprAST>("<single arg>", "as argument to print()");
- return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0]));
+ return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
}
// Call to a user-defined function
- return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args));
+ return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
}
/// primary
@@ -223,20 +225,20 @@ private:
/// ::= numberexpr
/// ::= parenexpr
/// ::= tensorliteral
- std::unique_ptr<ExprAST> ParsePrimary() {
+ std::unique_ptr<ExprAST> parsePrimary() {
switch (lexer.getCurToken()) {
default:
llvm::errs() << "unknown token '" << lexer.getCurToken()
<< "' when expecting an expression\n";
return nullptr;
case tok_identifier:
- return ParseIdentifierExpr();
+ return parseIdentifierExpr();
case tok_number:
- return ParseNumberExpr();
+ return parseNumberExpr();
case '(':
- return ParseParenExpr();
+ return parseParenExpr();
case '[':
- return ParseTensorLiteralExpr();
+ return parseTensorLiteralExpr();
case ';':
return nullptr;
case '}':
@@ -248,54 +250,54 @@ private:
/// argument indicates the precedence of the current binary operator.
///
/// binoprhs ::= ('+' primary)*
- std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
- std::unique_ptr<ExprAST> LHS) {
+ std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
+ std::unique_ptr<ExprAST> lhs) {
// If this is a binop, find its precedence.
while (true) {
- int TokPrec = GetTokPrecedence();
+ int tokPrec = getTokPrecedence();
// If this is a binop that binds at least as tightly as the current binop,
// consume it, otherwise we are done.
- if (TokPrec < ExprPrec)
- return LHS;
+ if (tokPrec < exprPrec)
+ return lhs;
// Okay, we know this is a binop.
- int BinOp = lexer.getCurToken();
- lexer.consume(Token(BinOp));
+ int binOp = lexer.getCurToken();
+ lexer.consume(Token(binOp));
auto loc = lexer.getLastLocation();
// Parse the primary expression after the binary operator.
- auto RHS = ParsePrimary();
- if (!RHS)
+ auto rhs = parsePrimary();
+ if (!rhs)
return parseError<ExprAST>("expression", "to complete binary operator");
- // If BinOp binds less tightly with RHS than the operator after RHS, let
- // the pending operator take RHS as its LHS.
- int NextPrec = GetTokPrecedence();
- if (TokPrec < NextPrec) {
- RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
- if (!RHS)
+ // If BinOp binds less tightly with rhs than the operator after rhs, let
+ // the pending operator take rhs as its lhs.
+ int nextPrec = getTokPrecedence();
+ if (tokPrec < nextPrec) {
+ rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
+ if (!rhs)
return nullptr;
}
- // Merge LHS/RHS.
- LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp,
- std::move(LHS), std::move(RHS));
+ // Merge lhs/RHS.
+ lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
+ std::move(lhs), std::move(rhs));
}
}
- /// expression::= primary binoprhs
- std::unique_ptr<ExprAST> ParseExpression() {
- auto LHS = ParsePrimary();
- if (!LHS)
+ /// expression::= primary binop rhs
+ std::unique_ptr<ExprAST> parseExpression() {
+ auto lhs = parsePrimary();
+ if (!lhs)
return nullptr;
- return ParseBinOpRHS(0, std::move(LHS));
+ return parseBinOpRHS(0, std::move(lhs));
}
/// type ::= < shape_list >
/// shape_list ::= num | num , shape_list
- std::unique_ptr<VarType> ParseType() {
+ std::unique_ptr<VarType> parseType() {
if (lexer.getCurToken() != '<')
return parseError<VarType>("<", "to begin type");
lexer.getNextToken(); // eat <
@@ -319,7 +321,7 @@ private:
/// and identifier and an optional type (shape specification) before the
/// initializer.
/// decl ::= var identifier [ type ] = expr
- std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
+ std::unique_ptr<VarDeclExprAST> parseDeclaration() {
if (lexer.getCurToken() != tok_var)
return parseError<VarDeclExprAST>("var", "to begin declaration");
auto loc = lexer.getLastLocation();
@@ -333,7 +335,7 @@ private:
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
if (lexer.getCurToken() == '<') {
- type = ParseType();
+ type = parseType();
if (!type)
return nullptr;
}
@@ -341,7 +343,7 @@ private:
if (!type)
type = std::make_unique<VarType>();
lexer.consume(Token('='));
- auto expr = ParseExpression();
+ auto expr = parseExpression();
return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
std::move(*type), std::move(expr));
}
@@ -352,7 +354,7 @@ private:
/// block ::= { expression_list }
/// expression_list ::= block_expr ; expression_list
/// block_expr ::= decl | "return" | expr
- std::unique_ptr<ExprASTList> ParseBlock() {
+ std::unique_ptr<ExprASTList> parseBlock() {
if (lexer.getCurToken() != '{')
return parseError<ExprASTList>("{", "to begin block");
lexer.consume(Token('{'));
@@ -366,19 +368,19 @@ private:
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
if (lexer.getCurToken() == tok_var) {
// Variable declaration
- auto varDecl = ParseDeclaration();
+ auto varDecl = parseDeclaration();
if (!varDecl)
return nullptr;
exprList->push_back(std::move(varDecl));
} else if (lexer.getCurToken() == tok_return) {
// Return statement
- auto ret = ParseReturn();
+ auto ret = parseReturn();
if (!ret)
return nullptr;
exprList->push_back(std::move(ret));
} else {
// General expression
- auto expr = ParseExpression();
+ auto expr = parseExpression();
if (!expr)
return nullptr;
exprList->push_back(std::move(expr));
@@ -401,13 +403,13 @@ private:
/// prototype ::= def id '(' decl_list ')'
/// decl_list ::= identifier | identifier, decl_list
- std::unique_ptr<PrototypeAST> ParsePrototype() {
+ std::unique_ptr<PrototypeAST> parsePrototype() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_def);
if (lexer.getCurToken() != tok_identifier)
return parseError<PrototypeAST>("function name", "in prototype");
- std::string FnName = lexer.getId();
+ std::string fnName = lexer.getId();
lexer.consume(tok_identifier);
if (lexer.getCurToken() != '(')
@@ -435,7 +437,7 @@ private:
// success.
lexer.consume(Token(')'));
- return std::make_unique<PrototypeAST>(std::move(loc), FnName,
+ return std::make_unique<PrototypeAST>(std::move(loc), fnName,
std::move(args));
}
@@ -443,18 +445,18 @@ private:
/// `def` keyword, followed by a block containing a list of expressions.
///
/// definition ::= prototype block
- std::unique_ptr<FunctionAST> ParseDefinition() {
- auto Proto = ParsePrototype();
- if (!Proto)
+ std::unique_ptr<FunctionAST> parseDefinition() {
+ auto proto = parsePrototype();
+ if (!proto)
return nullptr;
- if (auto block = ParseBlock())
- return std::make_unique<FunctionAST>(std::move(Proto), std::move(block));
+ if (auto block = parseBlock())
+ return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
return nullptr;
}
/// Get the precedence of the pending binary operator token.
- int GetTokPrecedence() {
+ int getTokPrecedence() {
if (!isascii(lexer.getCurToken()))
return -1;
diff --git a/mlir/examples/toy/Ch1/parser/AST.cpp b/mlir/examples/toy/Ch1/parser/AST.cpp
index fde8b101e83..0c7735ec9a4 100644
--- a/mlir/examples/toy/Ch1/parser/AST.cpp
+++ b/mlir/examples/toy/Ch1/parser/AST.cpp
@@ -21,6 +21,7 @@
#include "toy/AST.h"
+#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@@ -40,22 +41,22 @@ struct Indent {
/// the way. The only data member is the current indentation level.
class ASTDumper {
public:
- void dump(ModuleAST *Node);
+ void dump(ModuleAST *node);
private:
- void dump(VarType &type);
+ void dump(const VarType &type);
void dump(VarDeclExprAST *varDecl);
void dump(ExprAST *expr);
void dump(ExprASTList *exprList);
void dump(NumberExprAST *num);
- void dump(LiteralExprAST *Node);
- void dump(VariableExprAST *Node);
- void dump(ReturnExprAST *Node);
- void dump(BinaryExprAST *Node);
- void dump(CallExprAST *Node);
- void dump(PrintExprAST *Node);
- void dump(PrototypeAST *Node);
- void dump(FunctionAST *Node);
+ void dump(LiteralExprAST *node);
+ void dump(VariableExprAST *node);
+ void dump(ReturnExprAST *node);
+ void dump(BinaryExprAST *node);
+ void dump(CallExprAST *node);
+ void dump(PrintExprAST *node);
+ void dump(PrototypeAST *node);
+ void dump(FunctionAST *node);
// Actually print spaces matching the current indentation level
void indent() {
@@ -68,8 +69,8 @@ private:
} // namespace
/// Return a formatted string for the location of any node
-template <typename T> static std::string loc(T *Node) {
- const auto &loc = Node->loc();
+template <typename T> static std::string loc(T *node) {
+ const auto &loc = node->loc();
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
llvm::Twine(loc.col))
.str();
@@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) {
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such array with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
-void printLitHelper(ExprAST *lit_or_num) {
+void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal
- if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
+ if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue();
return;
}
- auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
+ auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
// Print the dimension for this literal first
llvm::errs() << "<";
- {
- const char *sep = "";
- for (auto dim : literal->getDims()) {
- llvm::errs() << sep << dim;
- sep = ", ";
- }
- }
+ mlir::interleaveComma(literal->getDims(), llvm::errs());
llvm::errs() << ">";
// Now print the content, recursing on every element of the list
llvm::errs() << "[ ";
- const char *sep = "";
- for (auto &elt : literal->getValues()) {
- llvm::errs() << sep;
- printLitHelper(elt.get());
- sep = ", ";
- }
+ mlir::interleaveComma(literal->getValues(), llvm::errs(),
+ [&](auto &elt) { printLitHelper(elt.get()); });
llvm::errs() << "]";
}
/// Print a literal, see the recursive helper above for the implementation.
-void ASTDumper::dump(LiteralExprAST *Node) {
+void ASTDumper::dump(LiteralExprAST *node) {
INDENT();
llvm::errs() << "Literal: ";
- printLitHelper(Node);
- llvm::errs() << " " << loc(Node) << "\n";
+ printLitHelper(node);
+ llvm::errs() << " " << loc(node) << "\n";
}
/// Print a variable reference (just a name).
-void ASTDumper::dump(VariableExprAST *Node) {
+void ASTDumper::dump(VariableExprAST *node) {
INDENT();
- llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
+ llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
}
/// Return statement print the return and its (optional) argument.
-void ASTDumper::dump(ReturnExprAST *Node) {
+void ASTDumper::dump(ReturnExprAST *node) {
INDENT();
llvm::errs() << "Return\n";
- if (Node->getExpr().hasValue())
- return dump(*Node->getExpr());
+ if (node->getExpr().hasValue())
+ return dump(*node->getExpr());
{
INDENT();
llvm::errs() << "(void)\n";
@@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) {
}
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
-void ASTDumper::dump(BinaryExprAST *Node) {
+void ASTDumper::dump(BinaryExprAST *node) {
INDENT();
- llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
- dump(Node->getLHS());
- dump(Node->getRHS());
+ llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
+ dump(node->getLHS());
+ dump(node->getRHS());
}
/// Print a call expression, first the callee name and the list of args by
/// recursing into each individual argument.
-void ASTDumper::dump(CallExprAST *Node) {
+void ASTDumper::dump(CallExprAST *node) {
INDENT();
- llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
- for (auto &arg : Node->getArgs())
+ llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
+ for (auto &arg : node->getArgs())
dump(arg.get());
indent();
llvm::errs() << "]\n";
}
/// Print a builtin print call, first the builtin name and then the argument.
-void ASTDumper::dump(PrintExprAST *Node) {
+void ASTDumper::dump(PrintExprAST *node) {
INDENT();
- llvm::errs() << "Print [ " << loc(Node) << "\n";
- dump(Node->getArg());
+ llvm::errs() << "Print [ " << loc(node) << "\n";
+ dump(node->getArg());
indent();
llvm::errs() << "]\n";
}
/// Print type: only the shape is printed in between '<' and '>'
-void ASTDumper::dump(VarType &type) {
+void ASTDumper::dump(const VarType &type) {
llvm::errs() << "<";
- const char *sep = "";
- for (auto shape : type.shape) {
- llvm::errs() << sep << shape;
- sep = ", ";
- }
+ mlir::interleaveComma(type.shape, llvm::errs());
llvm::errs() << ">";
}
/// Print a function prototype, first the function name, and then the list of
/// parameters names.
-void ASTDumper::dump(PrototypeAST *Node) {
+void ASTDumper::dump(PrototypeAST *node) {
INDENT();
- llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
+ llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
indent();
llvm::errs() << "Params: [";
- const char *sep = "";
- for (auto &arg : Node->getArgs()) {
- llvm::errs() << sep << arg->getName();
- sep = ", ";
- }
+ mlir::interleaveComma(node->getArgs(), llvm::errs(),
+ [](auto &arg) { llvm::errs() << arg->getName(); });
llvm::errs() << "]\n";
}
/// Print a function, first the prototype and then the body.
-void ASTDumper::dump(FunctionAST *Node) {
+void ASTDumper::dump(FunctionAST *node) {
INDENT();
llvm::errs() << "Function \n";
- dump(Node->getProto());
- dump(Node->getBody());
+ dump(node->getProto());
+ dump(node->getBody());
}
/// Print a module, actually loop over the functions and print them in sequence.
-void ASTDumper::dump(ModuleAST *Node) {
+void ASTDumper::dump(ModuleAST *node) {
INDENT();
llvm::errs() << "Module:\n";
- for (auto &F : *Node)
- dump(&F);
+ for (auto &f : *node)
+ dump(&f);
}
namespace toy {
diff --git a/mlir/examples/toy/Ch1/toyc.cpp b/mlir/examples/toy/Ch1/toyc.cpp
index dd308caa24b..37794d5c4d9 100644
--- a/mlir/examples/toy/Ch1/toyc.cpp
+++ b/mlir/examples/toy/Ch1/toyc.cpp
@@ -30,7 +30,7 @@
using namespace toy;
namespace cl = llvm::cl;
-static cl::opt<std::string> InputFilename(cl::Positional,
+static cl::opt<std::string> inputFilename(cl::Positional,
cl::desc("<input toy file>"),
cl::init("-"),
cl::value_desc("filename"));
@@ -44,22 +44,22 @@ static cl::opt<enum Action>
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
- llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
- if (std::error_code EC = FileOrErr.getError()) {
- llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+ if (std::error_code ec = fileOrErr.getError()) {
+ llvm::errs() << "Could not open input file: " << ec.message() << "\n";
return nullptr;
}
- auto buffer = FileOrErr.get()->getBuffer();
+ auto buffer = fileOrErr.get()->getBuffer();
LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
Parser parser(lexer);
- return parser.ParseModule();
+ return parser.parseModule();
}
int main(int argc, char **argv) {
cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
- auto moduleAST = parseInputFile(InputFilename);
+ auto moduleAST = parseInputFile(inputFilename);
if (!moduleAST)
return 1;
diff --git a/mlir/examples/toy/Ch2/include/toy/AST.h b/mlir/examples/toy/Ch2/include/toy/AST.h
index 2ad3392c11a..901164b0f39 100644
--- a/mlir/examples/toy/Ch2/include/toy/AST.h
+++ b/mlir/examples/toy/Ch2/include/toy/AST.h
@@ -54,7 +54,6 @@ public:
ExprAST(ExprASTKind kind, Location location)
: kind(kind), location(location) {}
-
virtual ~ExprAST() = default;
ExprASTKind getKind() const { return kind; }
@@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST {
double Val;
public:
- NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
+ NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {}
double getValue() { return Val; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
};
/// Expression class for a literal value.
@@ -93,10 +92,11 @@ public:
: ExprAST(Expr_Literal, loc), values(std::move(values)),
dims(std::move(dims)) {}
- std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
- std::vector<int64_t> &getDims() { return dims; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
+ llvm::ArrayRef<int64_t> getDims() { return dims; }
+
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
};
/// Expression class for referencing a variable, like "a".
@@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST {
std::string name;
public:
- VariableExprAST(Location loc, const std::string &name)
+ VariableExprAST(Location loc, llvm::StringRef name)
: ExprAST(Expr_Var, loc), name(name) {}
llvm::StringRef getName() { return name; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
};
/// Expression class for defining a variable.
@@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST {
std::unique_ptr<ExprAST> initVal;
public:
- VarDeclExprAST(Location loc, const std::string &name, VarType type,
+ VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
std::unique_ptr<ExprAST> initVal)
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
initVal(std::move(initVal)) {}
llvm::StringRef getName() { return name; }
ExprAST *getInitVal() { return initVal.get(); }
- VarType &getType() { return type; }
+ const VarType &getType() { return type; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
};
/// Expression class for a return operator.
@@ -144,61 +144,61 @@ public:
llvm::Optional<ExprAST *> getExpr() {
if (expr.hasValue())
return expr->get();
- return llvm::NoneType();
+ return llvm::None;
}
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
};
/// Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
- char Op;
- std::unique_ptr<ExprAST> LHS, RHS;
+ char op;
+ std::unique_ptr<ExprAST> lhs, rhs;
public:
- char getOp() { return Op; }
- ExprAST *getLHS() { return LHS.get(); }
- ExprAST *getRHS() { return RHS.get(); }
+ char getOp() { return op; }
+ ExprAST *getLHS() { return lhs.get(); }
+ ExprAST *getRHS() { return rhs.get(); }
- BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
- std::unique_ptr<ExprAST> RHS)
- : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
- RHS(std::move(RHS)) {}
+ BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs,
+ std::unique_ptr<ExprAST> rhs)
+ : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)),
+ rhs(std::move(rhs)) {}
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
};
/// Expression class for function calls.
class CallExprAST : public ExprAST {
- std::string Callee;
- std::vector<std::unique_ptr<ExprAST>> Args;
+ std::string callee;
+ std::vector<std::unique_ptr<ExprAST>> args;
public:
- CallExprAST(Location loc, const std::string &Callee,
- std::vector<std::unique_ptr<ExprAST>> Args)
- : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
+ CallExprAST(Location loc, const std::string &callee,
+ std::vector<std::unique_ptr<ExprAST>> args)
+ : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
- llvm::StringRef getCallee() { return Callee; }
- llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
+ llvm::StringRef getCallee() { return callee; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
};
/// Expression class for builtin print calls.
class PrintExprAST : public ExprAST {
- std::unique_ptr<ExprAST> Arg;
+ std::unique_ptr<ExprAST> arg;
public:
- PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
- : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
+ PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
+ : ExprAST(Expr_Print, loc), arg(std::move(arg)) {}
- ExprAST *getArg() { return Arg.get(); }
+ ExprAST *getArg() { return arg.get(); }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
};
/// This class represents the "prototype" for a function, which captures its
@@ -215,23 +215,21 @@ public:
: location(location), name(name), args(std::move(args)) {}
const Location &loc() { return location; }
- const std::string &getName() const { return name; }
- const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
- return args;
- }
+ llvm::StringRef getName() const { return name; }
+ llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
};
/// This class represents a function definition itself.
class FunctionAST {
- std::unique_ptr<PrototypeAST> Proto;
- std::unique_ptr<ExprASTList> Body;
+ std::unique_ptr<PrototypeAST> proto;
+ std::unique_ptr<ExprASTList> body;
public:
- FunctionAST(std::unique_ptr<PrototypeAST> Proto,
- std::unique_ptr<ExprASTList> Body)
- : Proto(std::move(Proto)), Body(std::move(Body)) {}
- PrototypeAST *getProto() { return Proto.get(); }
- ExprASTList *getBody() { return Body.get(); }
+ FunctionAST(std::unique_ptr<PrototypeAST> proto,
+ std::unique_ptr<ExprASTList> body)
+ : proto(std::move(proto)), body(std::move(body)) {}
+ PrototypeAST *getProto() { return proto.get(); }
+ ExprASTList *getBody() { return body.get(); }
};
/// This class represents a list of functions to be processed together
diff --git a/mlir/examples/toy/Ch2/include/toy/Lexer.h b/mlir/examples/toy/Ch2/include/toy/Lexer.h
index 21f92614912..144388c460c 100644
--- a/mlir/examples/toy/Ch2/include/toy/Lexer.h
+++ b/mlir/examples/toy/Ch2/include/toy/Lexer.h
@@ -89,13 +89,13 @@ public:
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
llvm::StringRef getId() {
assert(curTok == tok_identifier);
- return IdentifierStr;
+ return identifierStr;
}
/// Return the current number (prereq: getCurToken() == tok_number)
double getValue() {
assert(curTok == tok_number);
- return NumVal;
+ return numVal;
}
/// Return the location for the beginning of the current token.
@@ -135,56 +135,58 @@ private:
/// Return the next token from standard input.
Token getTok() {
// Skip any whitespace.
- while (isspace(LastChar))
- LastChar = Token(getNextChar());
+ while (isspace(lastChar))
+ lastChar = Token(getNextChar());
// Save the current location before reading the token characters.
lastLocation.line = curLineNum;
lastLocation.col = curCol;
- if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
- IdentifierStr = (char)LastChar;
- while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
- IdentifierStr += (char)LastChar;
+ // Identifier: [a-zA-Z][a-zA-Z0-9_]*
+ if (isalpha(lastChar)) {
+ identifierStr = (char)lastChar;
+ while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
+ identifierStr += (char)lastChar;
- if (IdentifierStr == "return")
+ if (identifierStr == "return")
return tok_return;
- if (IdentifierStr == "def")
+ if (identifierStr == "def")
return tok_def;
- if (IdentifierStr == "var")
+ if (identifierStr == "var")
return tok_var;
return tok_identifier;
}
- if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
- std::string NumStr;
+ // Number: [0-9.]+
+ if (isdigit(lastChar) || lastChar == '.') {
+ std::string numStr;
do {
- NumStr += LastChar;
- LastChar = Token(getNextChar());
- } while (isdigit(LastChar) || LastChar == '.');
+ numStr += lastChar;
+ lastChar = Token(getNextChar());
+ } while (isdigit(lastChar) || lastChar == '.');
- NumVal = strtod(NumStr.c_str(), nullptr);
+ numVal = strtod(numStr.c_str(), nullptr);
return tok_number;
}
- if (LastChar == '#') {
+ if (lastChar == '#') {
// Comment until end of line.
- do
- LastChar = Token(getNextChar());
- while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
+ do {
+ lastChar = Token(getNextChar());
+ } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
- if (LastChar != EOF)
+ if (lastChar != EOF)
return getTok();
}
// Check for end of file. Don't eat the EOF.
- if (LastChar == EOF)
+ if (lastChar == EOF)
return tok_eof;
// Otherwise, just return the character as its ascii value.
- Token ThisChar = Token(LastChar);
- LastChar = Token(getNextChar());
- return ThisChar;
+ Token thisChar = Token(lastChar);
+ lastChar = Token(getNextChar());
+ return thisChar;
}
/// The last token read from the input.
@@ -194,15 +196,15 @@ private:
Location lastLocation;
/// If the current Token is an identifier, this string contains the value.
- std::string IdentifierStr;
+ std::string identifierStr;
/// If the current Token is a number, this contains the value.
- double NumVal = 0;
+ double numVal = 0;
/// The last value returned by getNextChar(). We need to keep it around as we
/// always need to read ahead one character to decide when to end a token and
/// we can't put it back in the stream after reading from it.
- Token LastChar = Token(' ');
+ Token lastChar = Token(' ');
/// Keep track of the current line number in the input stream
int curLineNum = 0;
diff --git a/mlir/examples/toy/Ch2/include/toy/Parser.h b/mlir/examples/toy/Ch2/include/toy/Parser.h
index ec3d7654a85..9e219e56551 100644
--- a/mlir/examples/toy/Ch2/include/toy/Parser.h
+++ b/mlir/examples/toy/Ch2/include/toy/Parser.h
@@ -48,13 +48,13 @@ public:
Parser(Lexer &lexer) : lexer(lexer) {}
/// Parse a full Module. A module is a list of function definitions.
- std::unique_ptr<ModuleAST> ParseModule() {
+ std::unique_ptr<ModuleAST> parseModule() {
lexer.getNextToken(); // prime the lexer
// Parse functions one at a time and accumulate in this vector.
std::vector<FunctionAST> functions;
- while (auto F = ParseDefinition()) {
- functions.push_back(std::move(*F));
+ while (auto f = parseDefinition()) {
+ functions.push_back(std::move(*f));
if (lexer.getCurToken() == tok_eof)
break;
}
@@ -70,14 +70,14 @@ private:
/// Parse a return statement.
/// return :== return ; | return expr ;
- std::unique_ptr<ReturnExprAST> ParseReturn() {
+ std::unique_ptr<ReturnExprAST> parseReturn() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_return);
// return takes an optional argument
llvm::Optional<std::unique_ptr<ExprAST>> expr;
if (lexer.getCurToken() != ';') {
- expr = ParseExpression();
+ expr = parseExpression();
if (!expr)
return nullptr;
}
@@ -86,18 +86,18 @@ private:
/// Parse a literal number.
/// numberexpr ::= number
- std::unique_ptr<ExprAST> ParseNumberExpr() {
+ std::unique_ptr<ExprAST> parseNumberExpr() {
auto loc = lexer.getLastLocation();
- auto Result =
+ auto result =
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
lexer.consume(tok_number);
- return std::move(Result);
+ return std::move(result);
}
/// Parse a literal array expression.
/// tensorLiteral ::= [ literalList ] | number
/// literalList ::= tensorLiteral | tensorLiteral, literalList
- std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
+ std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
auto loc = lexer.getLastLocation();
lexer.consume(Token('['));
@@ -108,13 +108,13 @@ private:
do {
// We can have either another nested array or a number literal.
if (lexer.getCurToken() == '[') {
- values.push_back(ParseTensorLiteralExpr());
+ values.push_back(parseTensorLiteralExpr());
if (!values.back())
return nullptr; // parse error in the nested array.
} else {
if (lexer.getCurToken() != tok_number)
return parseError<ExprAST>("<num> or [", "in literal expression");
- values.push_back(ParseNumberExpr());
+ values.push_back(parseNumberExpr());
}
// End of this list on ']'
@@ -130,8 +130,10 @@ private:
if (values.empty())
return parseError<ExprAST>("<something>", "to fill literal expression");
lexer.getNextToken(); // eat ]
+
/// Fill in the dimensions now. First the current nesting level:
dims.push_back(values.size());
+
/// If there is any nested array, process all of them and ensure that
/// dimensions are uniform.
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
@@ -143,7 +145,7 @@ private:
"inside literal expression");
// Append the nested dimensions to the current level
- auto &firstDims = firstLiteral->getDims();
+ auto firstDims = firstLiteral->getDims();
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
// Sanity check that shape is uniform across all elements of the list.
@@ -162,22 +164,22 @@ private:
}
/// parenexpr ::= '(' expression ')'
- std::unique_ptr<ExprAST> ParseParenExpr() {
+ std::unique_ptr<ExprAST> parseParenExpr() {
lexer.getNextToken(); // eat (.
- auto V = ParseExpression();
- if (!V)
+ auto v = parseExpression();
+ if (!v)
return nullptr;
if (lexer.getCurToken() != ')')
return parseError<ExprAST>(")", "to close expression with parentheses");
lexer.consume(Token(')'));
- return V;
+ return v;
}
/// identifierexpr
/// ::= identifier
/// ::= identifier '(' expression ')'
- std::unique_ptr<ExprAST> ParseIdentifierExpr() {
+ std::unique_ptr<ExprAST> parseIdentifierExpr() {
std::string name = lexer.getId();
auto loc = lexer.getLastLocation();
@@ -188,11 +190,11 @@ private:
// This is a function call.
lexer.consume(Token('('));
- std::vector<std::unique_ptr<ExprAST>> Args;
+ std::vector<std::unique_ptr<ExprAST>> args;
if (lexer.getCurToken() != ')') {
while (true) {
- if (auto Arg = ParseExpression())
- Args.push_back(std::move(Arg));
+ if (auto arg = parseExpression())
+ args.push_back(std::move(arg));
else
return nullptr;
@@ -208,14 +210,14 @@ private:
// It can be a builtin call to print
if (name == "print") {
- if (Args.size() != 1)
+ if (args.size() != 1)
return parseError<ExprAST>("<single arg>", "as argument to print()");
- return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0]));
+ return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
}
// Call to a user-defined function
- return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args));
+ return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
}
/// primary
@@ -223,20 +225,20 @@ private:
/// ::= numberexpr
/// ::= parenexpr
/// ::= tensorliteral
- std::unique_ptr<ExprAST> ParsePrimary() {
+ std::unique_ptr<ExprAST> parsePrimary() {
switch (lexer.getCurToken()) {
default:
llvm::errs() << "unknown token '" << lexer.getCurToken()
<< "' when expecting an expression\n";
return nullptr;
case tok_identifier:
- return ParseIdentifierExpr();
+ return parseIdentifierExpr();
case tok_number:
- return ParseNumberExpr();
+ return parseNumberExpr();
case '(':
- return ParseParenExpr();
+ return parseParenExpr();
case '[':
- return ParseTensorLiteralExpr();
+ return parseTensorLiteralExpr();
case ';':
return nullptr;
case '}':
@@ -248,54 +250,54 @@ private:
/// argument indicates the precedence of the current binary operator.
///
/// binoprhs ::= ('+' primary)*
- std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
- std::unique_ptr<ExprAST> LHS) {
+ std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
+ std::unique_ptr<ExprAST> lhs) {
// If this is a binop, find its precedence.
while (true) {
- int TokPrec = GetTokPrecedence();
+ int tokPrec = getTokPrecedence();
// If this is a binop that binds at least as tightly as the current binop,
// consume it, otherwise we are done.
- if (TokPrec < ExprPrec)
- return LHS;
+ if (tokPrec < exprPrec)
+ return lhs;
// Okay, we know this is a binop.
- int BinOp = lexer.getCurToken();
- lexer.consume(Token(BinOp));
+ int binOp = lexer.getCurToken();
+ lexer.consume(Token(binOp));
auto loc = lexer.getLastLocation();
// Parse the primary expression after the binary operator.
- auto RHS = ParsePrimary();
- if (!RHS)
+ auto rhs = parsePrimary();
+ if (!rhs)
return parseError<ExprAST>("expression", "to complete binary operator");
- // If BinOp binds less tightly with RHS than the operator after RHS, let
- // the pending operator take RHS as its LHS.
- int NextPrec = GetTokPrecedence();
- if (TokPrec < NextPrec) {
- RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
- if (!RHS)
+ // If BinOp binds less tightly with rhs than the operator after rhs, let
+ // the pending operator take rhs as its lhs.
+ int nextPrec = getTokPrecedence();
+ if (tokPrec < nextPrec) {
+ rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
+ if (!rhs)
return nullptr;
}
- // Merge LHS/RHS.
- LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp,
- std::move(LHS), std::move(RHS));
+ // Merge lhs/RHS.
+ lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
+ std::move(lhs), std::move(rhs));
}
}
- /// expression::= primary binoprhs
- std::unique_ptr<ExprAST> ParseExpression() {
- auto LHS = ParsePrimary();
- if (!LHS)
+ /// expression::= primary binop rhs
+ std::unique_ptr<ExprAST> parseExpression() {
+ auto lhs = parsePrimary();
+ if (!lhs)
return nullptr;
- return ParseBinOpRHS(0, std::move(LHS));
+ return parseBinOpRHS(0, std::move(lhs));
}
/// type ::= < shape_list >
/// shape_list ::= num | num , shape_list
- std::unique_ptr<VarType> ParseType() {
+ std::unique_ptr<VarType> parseType() {
if (lexer.getCurToken() != '<')
return parseError<VarType>("<", "to begin type");
lexer.getNextToken(); // eat <
@@ -319,7 +321,7 @@ private:
/// and identifier and an optional type (shape specification) before the
/// initializer.
/// decl ::= var identifier [ type ] = expr
- std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
+ std::unique_ptr<VarDeclExprAST> parseDeclaration() {
if (lexer.getCurToken() != tok_var)
return parseError<VarDeclExprAST>("var", "to begin declaration");
auto loc = lexer.getLastLocation();
@@ -333,7 +335,7 @@ private:
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
if (lexer.getCurToken() == '<') {
- type = ParseType();
+ type = parseType();
if (!type)
return nullptr;
}
@@ -341,7 +343,7 @@ private:
if (!type)
type = std::make_unique<VarType>();
lexer.consume(Token('='));
- auto expr = ParseExpression();
+ auto expr = parseExpression();
return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
std::move(*type), std::move(expr));
}
@@ -352,7 +354,7 @@ private:
/// block ::= { expression_list }
/// expression_list ::= block_expr ; expression_list
/// block_expr ::= decl | "return" | expr
- std::unique_ptr<ExprASTList> ParseBlock() {
+ std::unique_ptr<ExprASTList> parseBlock() {
if (lexer.getCurToken() != '{')
return parseError<ExprASTList>("{", "to begin block");
lexer.consume(Token('{'));
@@ -366,19 +368,19 @@ private:
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
if (lexer.getCurToken() == tok_var) {
// Variable declaration
- auto varDecl = ParseDeclaration();
+ auto varDecl = parseDeclaration();
if (!varDecl)
return nullptr;
exprList->push_back(std::move(varDecl));
} else if (lexer.getCurToken() == tok_return) {
// Return statement
- auto ret = ParseReturn();
+ auto ret = parseReturn();
if (!ret)
return nullptr;
exprList->push_back(std::move(ret));
} else {
// General expression
- auto expr = ParseExpression();
+ auto expr = parseExpression();
if (!expr)
return nullptr;
exprList->push_back(std::move(expr));
@@ -401,13 +403,13 @@ private:
/// prototype ::= def id '(' decl_list ')'
/// decl_list ::= identifier | identifier, decl_list
- std::unique_ptr<PrototypeAST> ParsePrototype() {
+ std::unique_ptr<PrototypeAST> parsePrototype() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_def);
if (lexer.getCurToken() != tok_identifier)
return parseError<PrototypeAST>("function name", "in prototype");
- std::string FnName = lexer.getId();
+ std::string fnName = lexer.getId();
lexer.consume(tok_identifier);
if (lexer.getCurToken() != '(')
@@ -435,7 +437,7 @@ private:
// success.
lexer.consume(Token(')'));
- return std::make_unique<PrototypeAST>(std::move(loc), FnName,
+ return std::make_unique<PrototypeAST>(std::move(loc), fnName,
std::move(args));
}
@@ -443,18 +445,18 @@ private:
/// `def` keyword, followed by a block containing a list of expressions.
///
/// definition ::= prototype block
- std::unique_ptr<FunctionAST> ParseDefinition() {
- auto Proto = ParsePrototype();
- if (!Proto)
+ std::unique_ptr<FunctionAST> parseDefinition() {
+ auto proto = parsePrototype();
+ if (!proto)
return nullptr;
- if (auto block = ParseBlock())
- return std::make_unique<FunctionAST>(std::move(Proto), std::move(block));
+ if (auto block = parseBlock())
+ return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
return nullptr;
}
/// Get the precedence of the pending binary operator token.
- int GetTokPrecedence() {
+ int getTokPrecedence() {
if (!isascii(lexer.getCurToken()))
return -1;
diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
index 8b434b139c7..da474e809b3 100644
--- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
@@ -143,7 +143,7 @@ private:
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
- auto &protoArgs = funcAST.getProto()->getArgs();
+ auto protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
diff --git a/mlir/examples/toy/Ch2/parser/AST.cpp b/mlir/examples/toy/Ch2/parser/AST.cpp
index fde8b101e83..0c7735ec9a4 100644
--- a/mlir/examples/toy/Ch2/parser/AST.cpp
+++ b/mlir/examples/toy/Ch2/parser/AST.cpp
@@ -21,6 +21,7 @@
#include "toy/AST.h"
+#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@@ -40,22 +41,22 @@ struct Indent {
/// the way. The only data member is the current indentation level.
class ASTDumper {
public:
- void dump(ModuleAST *Node);
+ void dump(ModuleAST *node);
private:
- void dump(VarType &type);
+ void dump(const VarType &type);
void dump(VarDeclExprAST *varDecl);
void dump(ExprAST *expr);
void dump(ExprASTList *exprList);
void dump(NumberExprAST *num);
- void dump(LiteralExprAST *Node);
- void dump(VariableExprAST *Node);
- void dump(ReturnExprAST *Node);
- void dump(BinaryExprAST *Node);
- void dump(CallExprAST *Node);
- void dump(PrintExprAST *Node);
- void dump(PrototypeAST *Node);
- void dump(FunctionAST *Node);
+ void dump(LiteralExprAST *node);
+ void dump(VariableExprAST *node);
+ void dump(ReturnExprAST *node);
+ void dump(BinaryExprAST *node);
+ void dump(CallExprAST *node);
+ void dump(PrintExprAST *node);
+ void dump(PrototypeAST *node);
+ void dump(FunctionAST *node);
// Actually print spaces matching the current indentation level
void indent() {
@@ -68,8 +69,8 @@ private:
} // namespace
/// Return a formatted string for the location of any node
-template <typename T> static std::string loc(T *Node) {
- const auto &loc = Node->loc();
+template <typename T> static std::string loc(T *node) {
+ const auto &loc = node->loc();
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
llvm::Twine(loc.col))
.str();
@@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) {
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such array with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
-void printLitHelper(ExprAST *lit_or_num) {
+void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal
- if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
+ if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue();
return;
}
- auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
+ auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
// Print the dimension for this literal first
llvm::errs() << "<";
- {
- const char *sep = "";
- for (auto dim : literal->getDims()) {
- llvm::errs() << sep << dim;
- sep = ", ";
- }
- }
+ mlir::interleaveComma(literal->getDims(), llvm::errs());
llvm::errs() << ">";
// Now print the content, recursing on every element of the list
llvm::errs() << "[ ";
- const char *sep = "";
- for (auto &elt : literal->getValues()) {
- llvm::errs() << sep;
- printLitHelper(elt.get());
- sep = ", ";
- }
+ mlir::interleaveComma(literal->getValues(), llvm::errs(),
+ [&](auto &elt) { printLitHelper(elt.get()); });
llvm::errs() << "]";
}
/// Print a literal, see the recursive helper above for the implementation.
-void ASTDumper::dump(LiteralExprAST *Node) {
+void ASTDumper::dump(LiteralExprAST *node) {
INDENT();
llvm::errs() << "Literal: ";
- printLitHelper(Node);
- llvm::errs() << " " << loc(Node) << "\n";
+ printLitHelper(node);
+ llvm::errs() << " " << loc(node) << "\n";
}
/// Print a variable reference (just a name).
-void ASTDumper::dump(VariableExprAST *Node) {
+void ASTDumper::dump(VariableExprAST *node) {
INDENT();
- llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
+ llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
}
/// Return statement print the return and its (optional) argument.
-void ASTDumper::dump(ReturnExprAST *Node) {
+void ASTDumper::dump(ReturnExprAST *node) {
INDENT();
llvm::errs() << "Return\n";
- if (Node->getExpr().hasValue())
- return dump(*Node->getExpr());
+ if (node->getExpr().hasValue())
+ return dump(*node->getExpr());
{
INDENT();
llvm::errs() << "(void)\n";
@@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) {
}
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
-void ASTDumper::dump(BinaryExprAST *Node) {
+void ASTDumper::dump(BinaryExprAST *node) {
INDENT();
- llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
- dump(Node->getLHS());
- dump(Node->getRHS());
+ llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
+ dump(node->getLHS());
+ dump(node->getRHS());
}
/// Print a call expression, first the callee name and the list of args by
/// recursing into each individual argument.
-void ASTDumper::dump(CallExprAST *Node) {
+void ASTDumper::dump(CallExprAST *node) {
INDENT();
- llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
- for (auto &arg : Node->getArgs())
+ llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
+ for (auto &arg : node->getArgs())
dump(arg.get());
indent();
llvm::errs() << "]\n";
}
/// Print a builtin print call, first the builtin name and then the argument.
-void ASTDumper::dump(PrintExprAST *Node) {
+void ASTDumper::dump(PrintExprAST *node) {
INDENT();
- llvm::errs() << "Print [ " << loc(Node) << "\n";
- dump(Node->getArg());
+ llvm::errs() << "Print [ " << loc(node) << "\n";
+ dump(node->getArg());
indent();
llvm::errs() << "]\n";
}
/// Print type: only the shape is printed in between '<' and '>'
-void ASTDumper::dump(VarType &type) {
+void ASTDumper::dump(const VarType &type) {
llvm::errs() << "<";
- const char *sep = "";
- for (auto shape : type.shape) {
- llvm::errs() << sep << shape;
- sep = ", ";
- }
+ mlir::interleaveComma(type.shape, llvm::errs());
llvm::errs() << ">";
}
/// Print a function prototype, first the function name, and then the list of
/// parameters names.
-void ASTDumper::dump(PrototypeAST *Node) {
+void ASTDumper::dump(PrototypeAST *node) {
INDENT();
- llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
+ llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
indent();
llvm::errs() << "Params: [";
- const char *sep = "";
- for (auto &arg : Node->getArgs()) {
- llvm::errs() << sep << arg->getName();
- sep = ", ";
- }
+ mlir::interleaveComma(node->getArgs(), llvm::errs(),
+ [](auto &arg) { llvm::errs() << arg->getName(); });
llvm::errs() << "]\n";
}
/// Print a function, first the prototype and then the body.
-void ASTDumper::dump(FunctionAST *Node) {
+void ASTDumper::dump(FunctionAST *node) {
INDENT();
llvm::errs() << "Function \n";
- dump(Node->getProto());
- dump(Node->getBody());
+ dump(node->getProto());
+ dump(node->getBody());
}
/// Print a module, actually loop over the functions and print them in sequence.
-void ASTDumper::dump(ModuleAST *Node) {
+void ASTDumper::dump(ModuleAST *node) {
INDENT();
llvm::errs() << "Module:\n";
- for (auto &F : *Node)
- dump(&F);
+ for (auto &f : *node)
+ dump(&f);
}
namespace toy {
diff --git a/mlir/examples/toy/Ch2/toyc.cpp b/mlir/examples/toy/Ch2/toyc.cpp
index 547ac9e65b9..e6a69b92832 100644
--- a/mlir/examples/toy/Ch2/toyc.cpp
+++ b/mlir/examples/toy/Ch2/toyc.cpp
@@ -63,16 +63,16 @@ static cl::opt<enum Action> emitAction(
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
- llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
- if (std::error_code EC = FileOrErr.getError()) {
- llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+ if (std::error_code ec = fileOrErr.getError()) {
+ llvm::errs() << "Could not open input file: " << ec.message() << "\n";
return nullptr;
}
- auto buffer = FileOrErr.get()->getBuffer();
+ auto buffer = fileOrErr.get()->getBuffer();
LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
Parser parser(lexer);
- return parser.ParseModule();
+ return parser.parseModule();
}
int dumpMLIR() {
diff --git a/mlir/examples/toy/Ch3/include/toy/AST.h b/mlir/examples/toy/Ch3/include/toy/AST.h
index 2ad3392c11a..901164b0f39 100644
--- a/mlir/examples/toy/Ch3/include/toy/AST.h
+++ b/mlir/examples/toy/Ch3/include/toy/AST.h
@@ -54,7 +54,6 @@ public:
ExprAST(ExprASTKind kind, Location location)
: kind(kind), location(location) {}
-
virtual ~ExprAST() = default;
ExprASTKind getKind() const { return kind; }
@@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST {
double Val;
public:
- NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
+ NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {}
double getValue() { return Val; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
};
/// Expression class for a literal value.
@@ -93,10 +92,11 @@ public:
: ExprAST(Expr_Literal, loc), values(std::move(values)),
dims(std::move(dims)) {}
- std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
- std::vector<int64_t> &getDims() { return dims; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
+ llvm::ArrayRef<int64_t> getDims() { return dims; }
+
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
};
/// Expression class for referencing a variable, like "a".
@@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST {
std::string name;
public:
- VariableExprAST(Location loc, const std::string &name)
+ VariableExprAST(Location loc, llvm::StringRef name)
: ExprAST(Expr_Var, loc), name(name) {}
llvm::StringRef getName() { return name; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
};
/// Expression class for defining a variable.
@@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST {
std::unique_ptr<ExprAST> initVal;
public:
- VarDeclExprAST(Location loc, const std::string &name, VarType type,
+ VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
std::unique_ptr<ExprAST> initVal)
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
initVal(std::move(initVal)) {}
llvm::StringRef getName() { return name; }
ExprAST *getInitVal() { return initVal.get(); }
- VarType &getType() { return type; }
+ const VarType &getType() { return type; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
};
/// Expression class for a return operator.
@@ -144,61 +144,61 @@ public:
llvm::Optional<ExprAST *> getExpr() {
if (expr.hasValue())
return expr->get();
- return llvm::NoneType();
+ return llvm::None;
}
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
};
/// Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
- char Op;
- std::unique_ptr<ExprAST> LHS, RHS;
+ char op;
+ std::unique_ptr<ExprAST> lhs, rhs;
public:
- char getOp() { return Op; }
- ExprAST *getLHS() { return LHS.get(); }
- ExprAST *getRHS() { return RHS.get(); }
+ char getOp() { return op; }
+ ExprAST *getLHS() { return lhs.get(); }
+ ExprAST *getRHS() { return rhs.get(); }
- BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
- std::unique_ptr<ExprAST> RHS)
- : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
- RHS(std::move(RHS)) {}
+ BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs,
+ std::unique_ptr<ExprAST> rhs)
+ : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)),
+ rhs(std::move(rhs)) {}
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
};
/// Expression class for function calls.
class CallExprAST : public ExprAST {
- std::string Callee;
- std::vector<std::unique_ptr<ExprAST>> Args;
+ std::string callee;
+ std::vector<std::unique_ptr<ExprAST>> args;
public:
- CallExprAST(Location loc, const std::string &Callee,
- std::vector<std::unique_ptr<ExprAST>> Args)
- : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
+ CallExprAST(Location loc, const std::string &callee,
+ std::vector<std::unique_ptr<ExprAST>> args)
+ : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
- llvm::StringRef getCallee() { return Callee; }
- llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
+ llvm::StringRef getCallee() { return callee; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
};
/// Expression class for builtin print calls.
class PrintExprAST : public ExprAST {
- std::unique_ptr<ExprAST> Arg;
+ std::unique_ptr<ExprAST> arg;
public:
- PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
- : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
+ PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
+ : ExprAST(Expr_Print, loc), arg(std::move(arg)) {}
- ExprAST *getArg() { return Arg.get(); }
+ ExprAST *getArg() { return arg.get(); }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
};
/// This class represents the "prototype" for a function, which captures its
@@ -215,23 +215,21 @@ public:
: location(location), name(name), args(std::move(args)) {}
const Location &loc() { return location; }
- const std::string &getName() const { return name; }
- const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
- return args;
- }
+ llvm::StringRef getName() const { return name; }
+ llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
};
/// This class represents a function definition itself.
class FunctionAST {
- std::unique_ptr<PrototypeAST> Proto;
- std::unique_ptr<ExprASTList> Body;
+ std::unique_ptr<PrototypeAST> proto;
+ std::unique_ptr<ExprASTList> body;
public:
- FunctionAST(std::unique_ptr<PrototypeAST> Proto,
- std::unique_ptr<ExprASTList> Body)
- : Proto(std::move(Proto)), Body(std::move(Body)) {}
- PrototypeAST *getProto() { return Proto.get(); }
- ExprASTList *getBody() { return Body.get(); }
+ FunctionAST(std::unique_ptr<PrototypeAST> proto,
+ std::unique_ptr<ExprASTList> body)
+ : proto(std::move(proto)), body(std::move(body)) {}
+ PrototypeAST *getProto() { return proto.get(); }
+ ExprASTList *getBody() { return body.get(); }
};
/// This class represents a list of functions to be processed together
diff --git a/mlir/examples/toy/Ch3/include/toy/Lexer.h b/mlir/examples/toy/Ch3/include/toy/Lexer.h
index 21f92614912..144388c460c 100644
--- a/mlir/examples/toy/Ch3/include/toy/Lexer.h
+++ b/mlir/examples/toy/Ch3/include/toy/Lexer.h
@@ -89,13 +89,13 @@ public:
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
llvm::StringRef getId() {
assert(curTok == tok_identifier);
- return IdentifierStr;
+ return identifierStr;
}
/// Return the current number (prereq: getCurToken() == tok_number)
double getValue() {
assert(curTok == tok_number);
- return NumVal;
+ return numVal;
}
/// Return the location for the beginning of the current token.
@@ -135,56 +135,58 @@ private:
/// Return the next token from standard input.
Token getTok() {
// Skip any whitespace.
- while (isspace(LastChar))
- LastChar = Token(getNextChar());
+ while (isspace(lastChar))
+ lastChar = Token(getNextChar());
// Save the current location before reading the token characters.
lastLocation.line = curLineNum;
lastLocation.col = curCol;
- if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
- IdentifierStr = (char)LastChar;
- while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
- IdentifierStr += (char)LastChar;
+ // Identifier: [a-zA-Z][a-zA-Z0-9_]*
+ if (isalpha(lastChar)) {
+ identifierStr = (char)lastChar;
+ while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
+ identifierStr += (char)lastChar;
- if (IdentifierStr == "return")
+ if (identifierStr == "return")
return tok_return;
- if (IdentifierStr == "def")
+ if (identifierStr == "def")
return tok_def;
- if (IdentifierStr == "var")
+ if (identifierStr == "var")
return tok_var;
return tok_identifier;
}
- if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
- std::string NumStr;
+ // Number: [0-9.]+
+ if (isdigit(lastChar) || lastChar == '.') {
+ std::string numStr;
do {
- NumStr += LastChar;
- LastChar = Token(getNextChar());
- } while (isdigit(LastChar) || LastChar == '.');
+ numStr += lastChar;
+ lastChar = Token(getNextChar());
+ } while (isdigit(lastChar) || lastChar == '.');
- NumVal = strtod(NumStr.c_str(), nullptr);
+ numVal = strtod(numStr.c_str(), nullptr);
return tok_number;
}
- if (LastChar == '#') {
+ if (lastChar == '#') {
// Comment until end of line.
- do
- LastChar = Token(getNextChar());
- while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
+ do {
+ lastChar = Token(getNextChar());
+ } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
- if (LastChar != EOF)
+ if (lastChar != EOF)
return getTok();
}
// Check for end of file. Don't eat the EOF.
- if (LastChar == EOF)
+ if (lastChar == EOF)
return tok_eof;
// Otherwise, just return the character as its ascii value.
- Token ThisChar = Token(LastChar);
- LastChar = Token(getNextChar());
- return ThisChar;
+ Token thisChar = Token(lastChar);
+ lastChar = Token(getNextChar());
+ return thisChar;
}
/// The last token read from the input.
@@ -194,15 +196,15 @@ private:
Location lastLocation;
/// If the current Token is an identifier, this string contains the value.
- std::string IdentifierStr;
+ std::string identifierStr;
/// If the current Token is a number, this contains the value.
- double NumVal = 0;
+ double numVal = 0;
/// The last value returned by getNextChar(). We need to keep it around as we
/// always need to read ahead one character to decide when to end a token and
/// we can't put it back in the stream after reading from it.
- Token LastChar = Token(' ');
+ Token lastChar = Token(' ');
/// Keep track of the current line number in the input stream
int curLineNum = 0;
diff --git a/mlir/examples/toy/Ch3/include/toy/Parser.h b/mlir/examples/toy/Ch3/include/toy/Parser.h
index ec3d7654a85..9e219e56551 100644
--- a/mlir/examples/toy/Ch3/include/toy/Parser.h
+++ b/mlir/examples/toy/Ch3/include/toy/Parser.h
@@ -48,13 +48,13 @@ public:
Parser(Lexer &lexer) : lexer(lexer) {}
/// Parse a full Module. A module is a list of function definitions.
- std::unique_ptr<ModuleAST> ParseModule() {
+ std::unique_ptr<ModuleAST> parseModule() {
lexer.getNextToken(); // prime the lexer
// Parse functions one at a time and accumulate in this vector.
std::vector<FunctionAST> functions;
- while (auto F = ParseDefinition()) {
- functions.push_back(std::move(*F));
+ while (auto f = parseDefinition()) {
+ functions.push_back(std::move(*f));
if (lexer.getCurToken() == tok_eof)
break;
}
@@ -70,14 +70,14 @@ private:
/// Parse a return statement.
/// return :== return ; | return expr ;
- std::unique_ptr<ReturnExprAST> ParseReturn() {
+ std::unique_ptr<ReturnExprAST> parseReturn() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_return);
// return takes an optional argument
llvm::Optional<std::unique_ptr<ExprAST>> expr;
if (lexer.getCurToken() != ';') {
- expr = ParseExpression();
+ expr = parseExpression();
if (!expr)
return nullptr;
}
@@ -86,18 +86,18 @@ private:
/// Parse a literal number.
/// numberexpr ::= number
- std::unique_ptr<ExprAST> ParseNumberExpr() {
+ std::unique_ptr<ExprAST> parseNumberExpr() {
auto loc = lexer.getLastLocation();
- auto Result =
+ auto result =
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
lexer.consume(tok_number);
- return std::move(Result);
+ return std::move(result);
}
/// Parse a literal array expression.
/// tensorLiteral ::= [ literalList ] | number
/// literalList ::= tensorLiteral | tensorLiteral, literalList
- std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
+ std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
auto loc = lexer.getLastLocation();
lexer.consume(Token('['));
@@ -108,13 +108,13 @@ private:
do {
// We can have either another nested array or a number literal.
if (lexer.getCurToken() == '[') {
- values.push_back(ParseTensorLiteralExpr());
+ values.push_back(parseTensorLiteralExpr());
if (!values.back())
return nullptr; // parse error in the nested array.
} else {
if (lexer.getCurToken() != tok_number)
return parseError<ExprAST>("<num> or [", "in literal expression");
- values.push_back(ParseNumberExpr());
+ values.push_back(parseNumberExpr());
}
// End of this list on ']'
@@ -130,8 +130,10 @@ private:
if (values.empty())
return parseError<ExprAST>("<something>", "to fill literal expression");
lexer.getNextToken(); // eat ]
+
/// Fill in the dimensions now. First the current nesting level:
dims.push_back(values.size());
+
/// If there is any nested array, process all of them and ensure that
/// dimensions are uniform.
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
@@ -143,7 +145,7 @@ private:
"inside literal expression");
// Append the nested dimensions to the current level
- auto &firstDims = firstLiteral->getDims();
+ auto firstDims = firstLiteral->getDims();
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
// Sanity check that shape is uniform across all elements of the list.
@@ -162,22 +164,22 @@ private:
}
/// parenexpr ::= '(' expression ')'
- std::unique_ptr<ExprAST> ParseParenExpr() {
+ std::unique_ptr<ExprAST> parseParenExpr() {
lexer.getNextToken(); // eat (.
- auto V = ParseExpression();
- if (!V)
+ auto v = parseExpression();
+ if (!v)
return nullptr;
if (lexer.getCurToken() != ')')
return parseError<ExprAST>(")", "to close expression with parentheses");
lexer.consume(Token(')'));
- return V;
+ return v;
}
/// identifierexpr
/// ::= identifier
/// ::= identifier '(' expression ')'
- std::unique_ptr<ExprAST> ParseIdentifierExpr() {
+ std::unique_ptr<ExprAST> parseIdentifierExpr() {
std::string name = lexer.getId();
auto loc = lexer.getLastLocation();
@@ -188,11 +190,11 @@ private:
// This is a function call.
lexer.consume(Token('('));
- std::vector<std::unique_ptr<ExprAST>> Args;
+ std::vector<std::unique_ptr<ExprAST>> args;
if (lexer.getCurToken() != ')') {
while (true) {
- if (auto Arg = ParseExpression())
- Args.push_back(std::move(Arg));
+ if (auto arg = parseExpression())
+ args.push_back(std::move(arg));
else
return nullptr;
@@ -208,14 +210,14 @@ private:
// It can be a builtin call to print
if (name == "print") {
- if (Args.size() != 1)
+ if (args.size() != 1)
return parseError<ExprAST>("<single arg>", "as argument to print()");
- return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0]));
+ return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
}
// Call to a user-defined function
- return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args));
+ return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
}
/// primary
@@ -223,20 +225,20 @@ private:
/// ::= numberexpr
/// ::= parenexpr
/// ::= tensorliteral
- std::unique_ptr<ExprAST> ParsePrimary() {
+ std::unique_ptr<ExprAST> parsePrimary() {
switch (lexer.getCurToken()) {
default:
llvm::errs() << "unknown token '" << lexer.getCurToken()
<< "' when expecting an expression\n";
return nullptr;
case tok_identifier:
- return ParseIdentifierExpr();
+ return parseIdentifierExpr();
case tok_number:
- return ParseNumberExpr();
+ return parseNumberExpr();
case '(':
- return ParseParenExpr();
+ return parseParenExpr();
case '[':
- return ParseTensorLiteralExpr();
+ return parseTensorLiteralExpr();
case ';':
return nullptr;
case '}':
@@ -248,54 +250,54 @@ private:
/// argument indicates the precedence of the current binary operator.
///
/// binoprhs ::= ('+' primary)*
- std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
- std::unique_ptr<ExprAST> LHS) {
+ std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
+ std::unique_ptr<ExprAST> lhs) {
// If this is a binop, find its precedence.
while (true) {
- int TokPrec = GetTokPrecedence();
+ int tokPrec = getTokPrecedence();
// If this is a binop that binds at least as tightly as the current binop,
// consume it, otherwise we are done.
- if (TokPrec < ExprPrec)
- return LHS;
+ if (tokPrec < exprPrec)
+ return lhs;
// Okay, we know this is a binop.
- int BinOp = lexer.getCurToken();
- lexer.consume(Token(BinOp));
+ int binOp = lexer.getCurToken();
+ lexer.consume(Token(binOp));
auto loc = lexer.getLastLocation();
// Parse the primary expression after the binary operator.
- auto RHS = ParsePrimary();
- if (!RHS)
+ auto rhs = parsePrimary();
+ if (!rhs)
return parseError<ExprAST>("expression", "to complete binary operator");
- // If BinOp binds less tightly with RHS than the operator after RHS, let
- // the pending operator take RHS as its LHS.
- int NextPrec = GetTokPrecedence();
- if (TokPrec < NextPrec) {
- RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
- if (!RHS)
+ // If BinOp binds less tightly with rhs than the operator after rhs, let
+ // the pending operator take rhs as its lhs.
+ int nextPrec = getTokPrecedence();
+ if (tokPrec < nextPrec) {
+ rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
+ if (!rhs)
return nullptr;
}
- // Merge LHS/RHS.
- LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp,
- std::move(LHS), std::move(RHS));
+ // Merge lhs/RHS.
+ lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
+ std::move(lhs), std::move(rhs));
}
}
- /// expression::= primary binoprhs
- std::unique_ptr<ExprAST> ParseExpression() {
- auto LHS = ParsePrimary();
- if (!LHS)
+ /// expression::= primary binop rhs
+ std::unique_ptr<ExprAST> parseExpression() {
+ auto lhs = parsePrimary();
+ if (!lhs)
return nullptr;
- return ParseBinOpRHS(0, std::move(LHS));
+ return parseBinOpRHS(0, std::move(lhs));
}
/// type ::= < shape_list >
/// shape_list ::= num | num , shape_list
- std::unique_ptr<VarType> ParseType() {
+ std::unique_ptr<VarType> parseType() {
if (lexer.getCurToken() != '<')
return parseError<VarType>("<", "to begin type");
lexer.getNextToken(); // eat <
@@ -319,7 +321,7 @@ private:
/// and identifier and an optional type (shape specification) before the
/// initializer.
/// decl ::= var identifier [ type ] = expr
- std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
+ std::unique_ptr<VarDeclExprAST> parseDeclaration() {
if (lexer.getCurToken() != tok_var)
return parseError<VarDeclExprAST>("var", "to begin declaration");
auto loc = lexer.getLastLocation();
@@ -333,7 +335,7 @@ private:
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
if (lexer.getCurToken() == '<') {
- type = ParseType();
+ type = parseType();
if (!type)
return nullptr;
}
@@ -341,7 +343,7 @@ private:
if (!type)
type = std::make_unique<VarType>();
lexer.consume(Token('='));
- auto expr = ParseExpression();
+ auto expr = parseExpression();
return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
std::move(*type), std::move(expr));
}
@@ -352,7 +354,7 @@ private:
/// block ::= { expression_list }
/// expression_list ::= block_expr ; expression_list
/// block_expr ::= decl | "return" | expr
- std::unique_ptr<ExprASTList> ParseBlock() {
+ std::unique_ptr<ExprASTList> parseBlock() {
if (lexer.getCurToken() != '{')
return parseError<ExprASTList>("{", "to begin block");
lexer.consume(Token('{'));
@@ -366,19 +368,19 @@ private:
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
if (lexer.getCurToken() == tok_var) {
// Variable declaration
- auto varDecl = ParseDeclaration();
+ auto varDecl = parseDeclaration();
if (!varDecl)
return nullptr;
exprList->push_back(std::move(varDecl));
} else if (lexer.getCurToken() == tok_return) {
// Return statement
- auto ret = ParseReturn();
+ auto ret = parseReturn();
if (!ret)
return nullptr;
exprList->push_back(std::move(ret));
} else {
// General expression
- auto expr = ParseExpression();
+ auto expr = parseExpression();
if (!expr)
return nullptr;
exprList->push_back(std::move(expr));
@@ -401,13 +403,13 @@ private:
/// prototype ::= def id '(' decl_list ')'
/// decl_list ::= identifier | identifier, decl_list
- std::unique_ptr<PrototypeAST> ParsePrototype() {
+ std::unique_ptr<PrototypeAST> parsePrototype() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_def);
if (lexer.getCurToken() != tok_identifier)
return parseError<PrototypeAST>("function name", "in prototype");
- std::string FnName = lexer.getId();
+ std::string fnName = lexer.getId();
lexer.consume(tok_identifier);
if (lexer.getCurToken() != '(')
@@ -435,7 +437,7 @@ private:
// success.
lexer.consume(Token(')'));
- return std::make_unique<PrototypeAST>(std::move(loc), FnName,
+ return std::make_unique<PrototypeAST>(std::move(loc), fnName,
std::move(args));
}
@@ -443,18 +445,18 @@ private:
/// `def` keyword, followed by a block containing a list of expressions.
///
/// definition ::= prototype block
- std::unique_ptr<FunctionAST> ParseDefinition() {
- auto Proto = ParsePrototype();
- if (!Proto)
+ std::unique_ptr<FunctionAST> parseDefinition() {
+ auto proto = parsePrototype();
+ if (!proto)
return nullptr;
- if (auto block = ParseBlock())
- return std::make_unique<FunctionAST>(std::move(Proto), std::move(block));
+ if (auto block = parseBlock())
+ return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
return nullptr;
}
/// Get the precedence of the pending binary operator token.
- int GetTokPrecedence() {
+ int getTokPrecedence() {
if (!isascii(lexer.getCurToken()))
return -1;
diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp
index 8b434b139c7..da474e809b3 100644
--- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp
@@ -143,7 +143,7 @@ private:
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
- auto &protoArgs = funcAST.getProto()->getArgs();
+ auto protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
diff --git a/mlir/examples/toy/Ch3/parser/AST.cpp b/mlir/examples/toy/Ch3/parser/AST.cpp
index 869f2ef2013..0c7735ec9a4 100644
--- a/mlir/examples/toy/Ch3/parser/AST.cpp
+++ b/mlir/examples/toy/Ch3/parser/AST.cpp
@@ -21,6 +21,7 @@
#include "toy/AST.h"
+#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@@ -40,22 +41,22 @@ struct Indent {
/// the way. The only data member is the current indentation level.
class ASTDumper {
public:
- void dump(ModuleAST *Node);
+ void dump(ModuleAST *node);
private:
- void dump(VarType &type);
+ void dump(const VarType &type);
void dump(VarDeclExprAST *varDecl);
void dump(ExprAST *expr);
void dump(ExprASTList *exprList);
void dump(NumberExprAST *num);
- void dump(LiteralExprAST *Node);
- void dump(VariableExprAST *Node);
- void dump(ReturnExprAST *Node);
- void dump(BinaryExprAST *Node);
- void dump(CallExprAST *Node);
- void dump(PrintExprAST *Node);
- void dump(PrototypeAST *Node);
- void dump(FunctionAST *Node);
+ void dump(LiteralExprAST *node);
+ void dump(VariableExprAST *node);
+ void dump(ReturnExprAST *node);
+ void dump(BinaryExprAST *node);
+ void dump(CallExprAST *node);
+ void dump(PrintExprAST *node);
+ void dump(PrototypeAST *node);
+ void dump(FunctionAST *node);
// Actually print spaces matching the current indentation level
void indent() {
@@ -68,8 +69,8 @@ private:
} // namespace
/// Return a formatted string for the location of any node
-template <typename T> static std::string loc(T *Node) {
- const auto &loc = Node->loc();
+template <typename T> static std::string loc(T *node) {
+ const auto &loc = node->loc();
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
llvm::Twine(loc.col))
.str();
@@ -125,60 +126,50 @@ void ASTDumper::dump(NumberExprAST *num) {
llvm::errs() << num->getValue() << " " << loc(num) << "\n";
}
-/// Helper to print recurisvely a literal. This handles nested array like:
+/// Helper to print recursively a literal. This handles nested array like:
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such array with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
-void printLitHelper(ExprAST *lit_or_num) {
+void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal
- if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
+ if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue();
return;
}
- auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
+ auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
// Print the dimension for this literal first
llvm::errs() << "<";
- {
- const char *sep = "";
- for (auto dim : literal->getDims()) {
- llvm::errs() << sep << dim;
- sep = ", ";
- }
- }
+ mlir::interleaveComma(literal->getDims(), llvm::errs());
llvm::errs() << ">";
// Now print the content, recursing on every element of the list
llvm::errs() << "[ ";
- const char *sep = "";
- for (auto &elt : literal->getValues()) {
- llvm::errs() << sep;
- printLitHelper(elt.get());
- sep = ", ";
- }
+ mlir::interleaveComma(literal->getValues(), llvm::errs(),
+ [&](auto &elt) { printLitHelper(elt.get()); });
llvm::errs() << "]";
}
/// Print a literal, see the recursive helper above for the implementation.
-void ASTDumper::dump(LiteralExprAST *Node) {
+void ASTDumper::dump(LiteralExprAST *node) {
INDENT();
llvm::errs() << "Literal: ";
- printLitHelper(Node);
- llvm::errs() << " " << loc(Node) << "\n";
+ printLitHelper(node);
+ llvm::errs() << " " << loc(node) << "\n";
}
/// Print a variable reference (just a name).
-void ASTDumper::dump(VariableExprAST *Node) {
+void ASTDumper::dump(VariableExprAST *node) {
INDENT();
- llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
+ llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
}
/// Return statement print the return and its (optional) argument.
-void ASTDumper::dump(ReturnExprAST *Node) {
+void ASTDumper::dump(ReturnExprAST *node) {
INDENT();
llvm::errs() << "Return\n";
- if (Node->getExpr().hasValue())
- return dump(*Node->getExpr());
+ if (node->getExpr().hasValue())
+ return dump(*node->getExpr());
{
INDENT();
llvm::errs() << "(void)\n";
@@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) {
}
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
-void ASTDumper::dump(BinaryExprAST *Node) {
+void ASTDumper::dump(BinaryExprAST *node) {
INDENT();
- llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
- dump(Node->getLHS());
- dump(Node->getRHS());
+ llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
+ dump(node->getLHS());
+ dump(node->getRHS());
}
/// Print a call expression, first the callee name and the list of args by
/// recursing into each individual argument.
-void ASTDumper::dump(CallExprAST *Node) {
+void ASTDumper::dump(CallExprAST *node) {
INDENT();
- llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
- for (auto &arg : Node->getArgs())
+ llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
+ for (auto &arg : node->getArgs())
dump(arg.get());
indent();
llvm::errs() << "]\n";
}
/// Print a builtin print call, first the builtin name and then the argument.
-void ASTDumper::dump(PrintExprAST *Node) {
+void ASTDumper::dump(PrintExprAST *node) {
INDENT();
- llvm::errs() << "Print [ " << loc(Node) << "\n";
- dump(Node->getArg());
+ llvm::errs() << "Print [ " << loc(node) << "\n";
+ dump(node->getArg());
indent();
llvm::errs() << "]\n";
}
/// Print type: only the shape is printed in between '<' and '>'
-void ASTDumper::dump(VarType &type) {
+void ASTDumper::dump(const VarType &type) {
llvm::errs() << "<";
- const char *sep = "";
- for (auto shape : type.shape) {
- llvm::errs() << sep << shape;
- sep = ", ";
- }
+ mlir::interleaveComma(type.shape, llvm::errs());
llvm::errs() << ">";
}
/// Print a function prototype, first the function name, and then the list of
/// parameters names.
-void ASTDumper::dump(PrototypeAST *Node) {
+void ASTDumper::dump(PrototypeAST *node) {
INDENT();
- llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
+ llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
indent();
llvm::errs() << "Params: [";
- const char *sep = "";
- for (auto &arg : Node->getArgs()) {
- llvm::errs() << sep << arg->getName();
- sep = ", ";
- }
+ mlir::interleaveComma(node->getArgs(), llvm::errs(),
+ [](auto &arg) { llvm::errs() << arg->getName(); });
llvm::errs() << "]\n";
}
/// Print a function, first the prototype and then the body.
-void ASTDumper::dump(FunctionAST *Node) {
+void ASTDumper::dump(FunctionAST *node) {
INDENT();
llvm::errs() << "Function \n";
- dump(Node->getProto());
- dump(Node->getBody());
+ dump(node->getProto());
+ dump(node->getBody());
}
/// Print a module, actually loop over the functions and print them in sequence.
-void ASTDumper::dump(ModuleAST *Node) {
+void ASTDumper::dump(ModuleAST *node) {
INDENT();
llvm::errs() << "Module:\n";
- for (auto &F : *Node)
- dump(&F);
+ for (auto &f : *node)
+ dump(&f);
}
namespace toy {
diff --git a/mlir/examples/toy/Ch3/toyc.cpp b/mlir/examples/toy/Ch3/toyc.cpp
index 7e62e13432c..d3bb5d1a160 100644
--- a/mlir/examples/toy/Ch3/toyc.cpp
+++ b/mlir/examples/toy/Ch3/toyc.cpp
@@ -63,20 +63,20 @@ static cl::opt<enum Action> emitAction(
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")));
-static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations"));
+static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
- llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
- if (std::error_code EC = FileOrErr.getError()) {
- llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+ if (std::error_code ec = fileOrErr.getError()) {
+ llvm::errs() << "Could not open input file: " << ec.message() << "\n";
return nullptr;
}
- auto buffer = FileOrErr.get()->getBuffer();
+ auto buffer = fileOrErr.get()->getBuffer();
LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
Parser parser(lexer);
- return parser.ParseModule();
+ return parser.parseModule();
}
int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
@@ -118,7 +118,7 @@ int dumpMLIR() {
if (int error = loadMLIR(sourceMgr, context, module))
return error;
- if (EnableOpt) {
+ if (enableOpt) {
mlir::PassManager pm(&context);
// Apply any generic pass manager command line options and run the pipeline.
applyPassManagerCLOptions(pm);
diff --git a/mlir/examples/toy/Ch4/include/toy/AST.h b/mlir/examples/toy/Ch4/include/toy/AST.h
index 2ad3392c11a..901164b0f39 100644
--- a/mlir/examples/toy/Ch4/include/toy/AST.h
+++ b/mlir/examples/toy/Ch4/include/toy/AST.h
@@ -54,7 +54,6 @@ public:
ExprAST(ExprASTKind kind, Location location)
: kind(kind), location(location) {}
-
virtual ~ExprAST() = default;
ExprASTKind getKind() const { return kind; }
@@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST {
double Val;
public:
- NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
+ NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {}
double getValue() { return Val; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
};
/// Expression class for a literal value.
@@ -93,10 +92,11 @@ public:
: ExprAST(Expr_Literal, loc), values(std::move(values)),
dims(std::move(dims)) {}
- std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
- std::vector<int64_t> &getDims() { return dims; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
+ llvm::ArrayRef<int64_t> getDims() { return dims; }
+
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
};
/// Expression class for referencing a variable, like "a".
@@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST {
std::string name;
public:
- VariableExprAST(Location loc, const std::string &name)
+ VariableExprAST(Location loc, llvm::StringRef name)
: ExprAST(Expr_Var, loc), name(name) {}
llvm::StringRef getName() { return name; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
};
/// Expression class for defining a variable.
@@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST {
std::unique_ptr<ExprAST> initVal;
public:
- VarDeclExprAST(Location loc, const std::string &name, VarType type,
+ VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
std::unique_ptr<ExprAST> initVal)
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
initVal(std::move(initVal)) {}
llvm::StringRef getName() { return name; }
ExprAST *getInitVal() { return initVal.get(); }
- VarType &getType() { return type; }
+ const VarType &getType() { return type; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
};
/// Expression class for a return operator.
@@ -144,61 +144,61 @@ public:
llvm::Optional<ExprAST *> getExpr() {
if (expr.hasValue())
return expr->get();
- return llvm::NoneType();
+ return llvm::None;
}
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
};
/// Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
- char Op;
- std::unique_ptr<ExprAST> LHS, RHS;
+ char op;
+ std::unique_ptr<ExprAST> lhs, rhs;
public:
- char getOp() { return Op; }
- ExprAST *getLHS() { return LHS.get(); }
- ExprAST *getRHS() { return RHS.get(); }
+ char getOp() { return op; }
+ ExprAST *getLHS() { return lhs.get(); }
+ ExprAST *getRHS() { return rhs.get(); }
- BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
- std::unique_ptr<ExprAST> RHS)
- : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
- RHS(std::move(RHS)) {}
+ BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs,
+ std::unique_ptr<ExprAST> rhs)
+ : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)),
+ rhs(std::move(rhs)) {}
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
};
/// Expression class for function calls.
class CallExprAST : public ExprAST {
- std::string Callee;
- std::vector<std::unique_ptr<ExprAST>> Args;
+ std::string callee;
+ std::vector<std::unique_ptr<ExprAST>> args;
public:
- CallExprAST(Location loc, const std::string &Callee,
- std::vector<std::unique_ptr<ExprAST>> Args)
- : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
+ CallExprAST(Location loc, const std::string &callee,
+ std::vector<std::unique_ptr<ExprAST>> args)
+ : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
- llvm::StringRef getCallee() { return Callee; }
- llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
+ llvm::StringRef getCallee() { return callee; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
};
/// Expression class for builtin print calls.
class PrintExprAST : public ExprAST {
- std::unique_ptr<ExprAST> Arg;
+ std::unique_ptr<ExprAST> arg;
public:
- PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
- : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
+ PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
+ : ExprAST(Expr_Print, loc), arg(std::move(arg)) {}
- ExprAST *getArg() { return Arg.get(); }
+ ExprAST *getArg() { return arg.get(); }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
};
/// This class represents the "prototype" for a function, which captures its
@@ -215,23 +215,21 @@ public:
: location(location), name(name), args(std::move(args)) {}
const Location &loc() { return location; }
- const std::string &getName() const { return name; }
- const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
- return args;
- }
+ llvm::StringRef getName() const { return name; }
+ llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
};
/// This class represents a function definition itself.
class FunctionAST {
- std::unique_ptr<PrototypeAST> Proto;
- std::unique_ptr<ExprASTList> Body;
+ std::unique_ptr<PrototypeAST> proto;
+ std::unique_ptr<ExprASTList> body;
public:
- FunctionAST(std::unique_ptr<PrototypeAST> Proto,
- std::unique_ptr<ExprASTList> Body)
- : Proto(std::move(Proto)), Body(std::move(Body)) {}
- PrototypeAST *getProto() { return Proto.get(); }
- ExprASTList *getBody() { return Body.get(); }
+ FunctionAST(std::unique_ptr<PrototypeAST> proto,
+ std::unique_ptr<ExprASTList> body)
+ : proto(std::move(proto)), body(std::move(body)) {}
+ PrototypeAST *getProto() { return proto.get(); }
+ ExprASTList *getBody() { return body.get(); }
};
/// This class represents a list of functions to be processed together
diff --git a/mlir/examples/toy/Ch4/include/toy/Lexer.h b/mlir/examples/toy/Ch4/include/toy/Lexer.h
index 21f92614912..144388c460c 100644
--- a/mlir/examples/toy/Ch4/include/toy/Lexer.h
+++ b/mlir/examples/toy/Ch4/include/toy/Lexer.h
@@ -89,13 +89,13 @@ public:
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
llvm::StringRef getId() {
assert(curTok == tok_identifier);
- return IdentifierStr;
+ return identifierStr;
}
/// Return the current number (prereq: getCurToken() == tok_number)
double getValue() {
assert(curTok == tok_number);
- return NumVal;
+ return numVal;
}
/// Return the location for the beginning of the current token.
@@ -135,56 +135,58 @@ private:
/// Return the next token from standard input.
Token getTok() {
// Skip any whitespace.
- while (isspace(LastChar))
- LastChar = Token(getNextChar());
+ while (isspace(lastChar))
+ lastChar = Token(getNextChar());
// Save the current location before reading the token characters.
lastLocation.line = curLineNum;
lastLocation.col = curCol;
- if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
- IdentifierStr = (char)LastChar;
- while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
- IdentifierStr += (char)LastChar;
+ // Identifier: [a-zA-Z][a-zA-Z0-9_]*
+ if (isalpha(lastChar)) {
+ identifierStr = (char)lastChar;
+ while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
+ identifierStr += (char)lastChar;
- if (IdentifierStr == "return")
+ if (identifierStr == "return")
return tok_return;
- if (IdentifierStr == "def")
+ if (identifierStr == "def")
return tok_def;
- if (IdentifierStr == "var")
+ if (identifierStr == "var")
return tok_var;
return tok_identifier;
}
- if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
- std::string NumStr;
+ // Number: [0-9.]+
+ if (isdigit(lastChar) || lastChar == '.') {
+ std::string numStr;
do {
- NumStr += LastChar;
- LastChar = Token(getNextChar());
- } while (isdigit(LastChar) || LastChar == '.');
+ numStr += lastChar;
+ lastChar = Token(getNextChar());
+ } while (isdigit(lastChar) || lastChar == '.');
- NumVal = strtod(NumStr.c_str(), nullptr);
+ numVal = strtod(numStr.c_str(), nullptr);
return tok_number;
}
- if (LastChar == '#') {
+ if (lastChar == '#') {
// Comment until end of line.
- do
- LastChar = Token(getNextChar());
- while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
+ do {
+ lastChar = Token(getNextChar());
+ } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
- if (LastChar != EOF)
+ if (lastChar != EOF)
return getTok();
}
// Check for end of file. Don't eat the EOF.
- if (LastChar == EOF)
+ if (lastChar == EOF)
return tok_eof;
// Otherwise, just return the character as its ascii value.
- Token ThisChar = Token(LastChar);
- LastChar = Token(getNextChar());
- return ThisChar;
+ Token thisChar = Token(lastChar);
+ lastChar = Token(getNextChar());
+ return thisChar;
}
/// The last token read from the input.
@@ -194,15 +196,15 @@ private:
Location lastLocation;
/// If the current Token is an identifier, this string contains the value.
- std::string IdentifierStr;
+ std::string identifierStr;
/// If the current Token is a number, this contains the value.
- double NumVal = 0;
+ double numVal = 0;
/// The last value returned by getNextChar(). We need to keep it around as we
/// always need to read ahead one character to decide when to end a token and
/// we can't put it back in the stream after reading from it.
- Token LastChar = Token(' ');
+ Token lastChar = Token(' ');
/// Keep track of the current line number in the input stream
int curLineNum = 0;
diff --git a/mlir/examples/toy/Ch4/include/toy/Parser.h b/mlir/examples/toy/Ch4/include/toy/Parser.h
index ec3d7654a85..9e219e56551 100644
--- a/mlir/examples/toy/Ch4/include/toy/Parser.h
+++ b/mlir/examples/toy/Ch4/include/toy/Parser.h
@@ -48,13 +48,13 @@ public:
Parser(Lexer &lexer) : lexer(lexer) {}
/// Parse a full Module. A module is a list of function definitions.
- std::unique_ptr<ModuleAST> ParseModule() {
+ std::unique_ptr<ModuleAST> parseModule() {
lexer.getNextToken(); // prime the lexer
// Parse functions one at a time and accumulate in this vector.
std::vector<FunctionAST> functions;
- while (auto F = ParseDefinition()) {
- functions.push_back(std::move(*F));
+ while (auto f = parseDefinition()) {
+ functions.push_back(std::move(*f));
if (lexer.getCurToken() == tok_eof)
break;
}
@@ -70,14 +70,14 @@ private:
/// Parse a return statement.
/// return :== return ; | return expr ;
- std::unique_ptr<ReturnExprAST> ParseReturn() {
+ std::unique_ptr<ReturnExprAST> parseReturn() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_return);
// return takes an optional argument
llvm::Optional<std::unique_ptr<ExprAST>> expr;
if (lexer.getCurToken() != ';') {
- expr = ParseExpression();
+ expr = parseExpression();
if (!expr)
return nullptr;
}
@@ -86,18 +86,18 @@ private:
/// Parse a literal number.
/// numberexpr ::= number
- std::unique_ptr<ExprAST> ParseNumberExpr() {
+ std::unique_ptr<ExprAST> parseNumberExpr() {
auto loc = lexer.getLastLocation();
- auto Result =
+ auto result =
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
lexer.consume(tok_number);
- return std::move(Result);
+ return std::move(result);
}
/// Parse a literal array expression.
/// tensorLiteral ::= [ literalList ] | number
/// literalList ::= tensorLiteral | tensorLiteral, literalList
- std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
+ std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
auto loc = lexer.getLastLocation();
lexer.consume(Token('['));
@@ -108,13 +108,13 @@ private:
do {
// We can have either another nested array or a number literal.
if (lexer.getCurToken() == '[') {
- values.push_back(ParseTensorLiteralExpr());
+ values.push_back(parseTensorLiteralExpr());
if (!values.back())
return nullptr; // parse error in the nested array.
} else {
if (lexer.getCurToken() != tok_number)
return parseError<ExprAST>("<num> or [", "in literal expression");
- values.push_back(ParseNumberExpr());
+ values.push_back(parseNumberExpr());
}
// End of this list on ']'
@@ -130,8 +130,10 @@ private:
if (values.empty())
return parseError<ExprAST>("<something>", "to fill literal expression");
lexer.getNextToken(); // eat ]
+
/// Fill in the dimensions now. First the current nesting level:
dims.push_back(values.size());
+
/// If there is any nested array, process all of them and ensure that
/// dimensions are uniform.
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
@@ -143,7 +145,7 @@ private:
"inside literal expression");
// Append the nested dimensions to the current level
- auto &firstDims = firstLiteral->getDims();
+ auto firstDims = firstLiteral->getDims();
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
// Sanity check that shape is uniform across all elements of the list.
@@ -162,22 +164,22 @@ private:
}
/// parenexpr ::= '(' expression ')'
- std::unique_ptr<ExprAST> ParseParenExpr() {
+ std::unique_ptr<ExprAST> parseParenExpr() {
lexer.getNextToken(); // eat (.
- auto V = ParseExpression();
- if (!V)
+ auto v = parseExpression();
+ if (!v)
return nullptr;
if (lexer.getCurToken() != ')')
return parseError<ExprAST>(")", "to close expression with parentheses");
lexer.consume(Token(')'));
- return V;
+ return v;
}
/// identifierexpr
/// ::= identifier
/// ::= identifier '(' expression ')'
- std::unique_ptr<ExprAST> ParseIdentifierExpr() {
+ std::unique_ptr<ExprAST> parseIdentifierExpr() {
std::string name = lexer.getId();
auto loc = lexer.getLastLocation();
@@ -188,11 +190,11 @@ private:
// This is a function call.
lexer.consume(Token('('));
- std::vector<std::unique_ptr<ExprAST>> Args;
+ std::vector<std::unique_ptr<ExprAST>> args;
if (lexer.getCurToken() != ')') {
while (true) {
- if (auto Arg = ParseExpression())
- Args.push_back(std::move(Arg));
+ if (auto arg = parseExpression())
+ args.push_back(std::move(arg));
else
return nullptr;
@@ -208,14 +210,14 @@ private:
// It can be a builtin call to print
if (name == "print") {
- if (Args.size() != 1)
+ if (args.size() != 1)
return parseError<ExprAST>("<single arg>", "as argument to print()");
- return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0]));
+ return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
}
// Call to a user-defined function
- return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args));
+ return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
}
/// primary
@@ -223,20 +225,20 @@ private:
/// ::= numberexpr
/// ::= parenexpr
/// ::= tensorliteral
- std::unique_ptr<ExprAST> ParsePrimary() {
+ std::unique_ptr<ExprAST> parsePrimary() {
switch (lexer.getCurToken()) {
default:
llvm::errs() << "unknown token '" << lexer.getCurToken()
<< "' when expecting an expression\n";
return nullptr;
case tok_identifier:
- return ParseIdentifierExpr();
+ return parseIdentifierExpr();
case tok_number:
- return ParseNumberExpr();
+ return parseNumberExpr();
case '(':
- return ParseParenExpr();
+ return parseParenExpr();
case '[':
- return ParseTensorLiteralExpr();
+ return parseTensorLiteralExpr();
case ';':
return nullptr;
case '}':
@@ -248,54 +250,54 @@ private:
/// argument indicates the precedence of the current binary operator.
///
/// binoprhs ::= ('+' primary)*
- std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
- std::unique_ptr<ExprAST> LHS) {
+ std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
+ std::unique_ptr<ExprAST> lhs) {
// If this is a binop, find its precedence.
while (true) {
- int TokPrec = GetTokPrecedence();
+ int tokPrec = getTokPrecedence();
// If this is a binop that binds at least as tightly as the current binop,
// consume it, otherwise we are done.
- if (TokPrec < ExprPrec)
- return LHS;
+ if (tokPrec < exprPrec)
+ return lhs;
// Okay, we know this is a binop.
- int BinOp = lexer.getCurToken();
- lexer.consume(Token(BinOp));
+ int binOp = lexer.getCurToken();
+ lexer.consume(Token(binOp));
auto loc = lexer.getLastLocation();
// Parse the primary expression after the binary operator.
- auto RHS = ParsePrimary();
- if (!RHS)
+ auto rhs = parsePrimary();
+ if (!rhs)
return parseError<ExprAST>("expression", "to complete binary operator");
- // If BinOp binds less tightly with RHS than the operator after RHS, let
- // the pending operator take RHS as its LHS.
- int NextPrec = GetTokPrecedence();
- if (TokPrec < NextPrec) {
- RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
- if (!RHS)
+ // If BinOp binds less tightly with rhs than the operator after rhs, let
+ // the pending operator take rhs as its lhs.
+ int nextPrec = getTokPrecedence();
+ if (tokPrec < nextPrec) {
+ rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
+ if (!rhs)
return nullptr;
}
- // Merge LHS/RHS.
- LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp,
- std::move(LHS), std::move(RHS));
+ // Merge lhs/RHS.
+ lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
+ std::move(lhs), std::move(rhs));
}
}
- /// expression::= primary binoprhs
- std::unique_ptr<ExprAST> ParseExpression() {
- auto LHS = ParsePrimary();
- if (!LHS)
+ /// expression::= primary binop rhs
+ std::unique_ptr<ExprAST> parseExpression() {
+ auto lhs = parsePrimary();
+ if (!lhs)
return nullptr;
- return ParseBinOpRHS(0, std::move(LHS));
+ return parseBinOpRHS(0, std::move(lhs));
}
/// type ::= < shape_list >
/// shape_list ::= num | num , shape_list
- std::unique_ptr<VarType> ParseType() {
+ std::unique_ptr<VarType> parseType() {
if (lexer.getCurToken() != '<')
return parseError<VarType>("<", "to begin type");
lexer.getNextToken(); // eat <
@@ -319,7 +321,7 @@ private:
/// and identifier and an optional type (shape specification) before the
/// initializer.
/// decl ::= var identifier [ type ] = expr
- std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
+ std::unique_ptr<VarDeclExprAST> parseDeclaration() {
if (lexer.getCurToken() != tok_var)
return parseError<VarDeclExprAST>("var", "to begin declaration");
auto loc = lexer.getLastLocation();
@@ -333,7 +335,7 @@ private:
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
if (lexer.getCurToken() == '<') {
- type = ParseType();
+ type = parseType();
if (!type)
return nullptr;
}
@@ -341,7 +343,7 @@ private:
if (!type)
type = std::make_unique<VarType>();
lexer.consume(Token('='));
- auto expr = ParseExpression();
+ auto expr = parseExpression();
return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
std::move(*type), std::move(expr));
}
@@ -352,7 +354,7 @@ private:
/// block ::= { expression_list }
/// expression_list ::= block_expr ; expression_list
/// block_expr ::= decl | "return" | expr
- std::unique_ptr<ExprASTList> ParseBlock() {
+ std::unique_ptr<ExprASTList> parseBlock() {
if (lexer.getCurToken() != '{')
return parseError<ExprASTList>("{", "to begin block");
lexer.consume(Token('{'));
@@ -366,19 +368,19 @@ private:
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
if (lexer.getCurToken() == tok_var) {
// Variable declaration
- auto varDecl = ParseDeclaration();
+ auto varDecl = parseDeclaration();
if (!varDecl)
return nullptr;
exprList->push_back(std::move(varDecl));
} else if (lexer.getCurToken() == tok_return) {
// Return statement
- auto ret = ParseReturn();
+ auto ret = parseReturn();
if (!ret)
return nullptr;
exprList->push_back(std::move(ret));
} else {
// General expression
- auto expr = ParseExpression();
+ auto expr = parseExpression();
if (!expr)
return nullptr;
exprList->push_back(std::move(expr));
@@ -401,13 +403,13 @@ private:
/// prototype ::= def id '(' decl_list ')'
/// decl_list ::= identifier | identifier, decl_list
- std::unique_ptr<PrototypeAST> ParsePrototype() {
+ std::unique_ptr<PrototypeAST> parsePrototype() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_def);
if (lexer.getCurToken() != tok_identifier)
return parseError<PrototypeAST>("function name", "in prototype");
- std::string FnName = lexer.getId();
+ std::string fnName = lexer.getId();
lexer.consume(tok_identifier);
if (lexer.getCurToken() != '(')
@@ -435,7 +437,7 @@ private:
// success.
lexer.consume(Token(')'));
- return std::make_unique<PrototypeAST>(std::move(loc), FnName,
+ return std::make_unique<PrototypeAST>(std::move(loc), fnName,
std::move(args));
}
@@ -443,18 +445,18 @@ private:
/// `def` keyword, followed by a block containing a list of expressions.
///
/// definition ::= prototype block
- std::unique_ptr<FunctionAST> ParseDefinition() {
- auto Proto = ParsePrototype();
- if (!Proto)
+ std::unique_ptr<FunctionAST> parseDefinition() {
+ auto proto = parsePrototype();
+ if (!proto)
return nullptr;
- if (auto block = ParseBlock())
- return std::make_unique<FunctionAST>(std::move(Proto), std::move(block));
+ if (auto block = parseBlock())
+ return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
return nullptr;
}
/// Get the precedence of the pending binary operator token.
- int GetTokPrecedence() {
+ int getTokPrecedence() {
if (!isascii(lexer.getCurToken()))
return -1;
diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
index 5f9dd30a507..da474e809b3 100644
--- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
@@ -91,7 +91,7 @@ private:
mlir::ModuleOp theModule;
/// The builder is a helper class to create IR inside a function. The builder
- /// is stateful, in particular it keeeps an "insertion point": this is where
+ /// is stateful, in particular it keeps an "insertion point": this is where
/// the next operations will be introduced.
mlir::OpBuilder builder;
@@ -143,7 +143,7 @@ private:
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
- auto &protoArgs = funcAST.getProto()->getArgs();
+ auto protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
diff --git a/mlir/examples/toy/Ch4/parser/AST.cpp b/mlir/examples/toy/Ch4/parser/AST.cpp
index fde8b101e83..0c7735ec9a4 100644
--- a/mlir/examples/toy/Ch4/parser/AST.cpp
+++ b/mlir/examples/toy/Ch4/parser/AST.cpp
@@ -21,6 +21,7 @@
#include "toy/AST.h"
+#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@@ -40,22 +41,22 @@ struct Indent {
/// the way. The only data member is the current indentation level.
class ASTDumper {
public:
- void dump(ModuleAST *Node);
+ void dump(ModuleAST *node);
private:
- void dump(VarType &type);
+ void dump(const VarType &type);
void dump(VarDeclExprAST *varDecl);
void dump(ExprAST *expr);
void dump(ExprASTList *exprList);
void dump(NumberExprAST *num);
- void dump(LiteralExprAST *Node);
- void dump(VariableExprAST *Node);
- void dump(ReturnExprAST *Node);
- void dump(BinaryExprAST *Node);
- void dump(CallExprAST *Node);
- void dump(PrintExprAST *Node);
- void dump(PrototypeAST *Node);
- void dump(FunctionAST *Node);
+ void dump(LiteralExprAST *node);
+ void dump(VariableExprAST *node);
+ void dump(ReturnExprAST *node);
+ void dump(BinaryExprAST *node);
+ void dump(CallExprAST *node);
+ void dump(PrintExprAST *node);
+ void dump(PrototypeAST *node);
+ void dump(FunctionAST *node);
// Actually print spaces matching the current indentation level
void indent() {
@@ -68,8 +69,8 @@ private:
} // namespace
/// Return a formatted string for the location of any node
-template <typename T> static std::string loc(T *Node) {
- const auto &loc = Node->loc();
+template <typename T> static std::string loc(T *node) {
+ const auto &loc = node->loc();
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
llvm::Twine(loc.col))
.str();
@@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) {
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such array with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
-void printLitHelper(ExprAST *lit_or_num) {
+void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal
- if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
+ if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue();
return;
}
- auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
+ auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
// Print the dimension for this literal first
llvm::errs() << "<";
- {
- const char *sep = "";
- for (auto dim : literal->getDims()) {
- llvm::errs() << sep << dim;
- sep = ", ";
- }
- }
+ mlir::interleaveComma(literal->getDims(), llvm::errs());
llvm::errs() << ">";
// Now print the content, recursing on every element of the list
llvm::errs() << "[ ";
- const char *sep = "";
- for (auto &elt : literal->getValues()) {
- llvm::errs() << sep;
- printLitHelper(elt.get());
- sep = ", ";
- }
+ mlir::interleaveComma(literal->getValues(), llvm::errs(),
+ [&](auto &elt) { printLitHelper(elt.get()); });
llvm::errs() << "]";
}
/// Print a literal, see the recursive helper above for the implementation.
-void ASTDumper::dump(LiteralExprAST *Node) {
+void ASTDumper::dump(LiteralExprAST *node) {
INDENT();
llvm::errs() << "Literal: ";
- printLitHelper(Node);
- llvm::errs() << " " << loc(Node) << "\n";
+ printLitHelper(node);
+ llvm::errs() << " " << loc(node) << "\n";
}
/// Print a variable reference (just a name).
-void ASTDumper::dump(VariableExprAST *Node) {
+void ASTDumper::dump(VariableExprAST *node) {
INDENT();
- llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
+ llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
}
/// Return statement print the return and its (optional) argument.
-void ASTDumper::dump(ReturnExprAST *Node) {
+void ASTDumper::dump(ReturnExprAST *node) {
INDENT();
llvm::errs() << "Return\n";
- if (Node->getExpr().hasValue())
- return dump(*Node->getExpr());
+ if (node->getExpr().hasValue())
+ return dump(*node->getExpr());
{
INDENT();
llvm::errs() << "(void)\n";
@@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) {
}
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
-void ASTDumper::dump(BinaryExprAST *Node) {
+void ASTDumper::dump(BinaryExprAST *node) {
INDENT();
- llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
- dump(Node->getLHS());
- dump(Node->getRHS());
+ llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
+ dump(node->getLHS());
+ dump(node->getRHS());
}
/// Print a call expression, first the callee name and the list of args by
/// recursing into each individual argument.
-void ASTDumper::dump(CallExprAST *Node) {
+void ASTDumper::dump(CallExprAST *node) {
INDENT();
- llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
- for (auto &arg : Node->getArgs())
+ llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
+ for (auto &arg : node->getArgs())
dump(arg.get());
indent();
llvm::errs() << "]\n";
}
/// Print a builtin print call, first the builtin name and then the argument.
-void ASTDumper::dump(PrintExprAST *Node) {
+void ASTDumper::dump(PrintExprAST *node) {
INDENT();
- llvm::errs() << "Print [ " << loc(Node) << "\n";
- dump(Node->getArg());
+ llvm::errs() << "Print [ " << loc(node) << "\n";
+ dump(node->getArg());
indent();
llvm::errs() << "]\n";
}
/// Print type: only the shape is printed in between '<' and '>'
-void ASTDumper::dump(VarType &type) {
+void ASTDumper::dump(const VarType &type) {
llvm::errs() << "<";
- const char *sep = "";
- for (auto shape : type.shape) {
- llvm::errs() << sep << shape;
- sep = ", ";
- }
+ mlir::interleaveComma(type.shape, llvm::errs());
llvm::errs() << ">";
}
/// Print a function prototype, first the function name, and then the list of
/// parameters names.
-void ASTDumper::dump(PrototypeAST *Node) {
+void ASTDumper::dump(PrototypeAST *node) {
INDENT();
- llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
+ llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
indent();
llvm::errs() << "Params: [";
- const char *sep = "";
- for (auto &arg : Node->getArgs()) {
- llvm::errs() << sep << arg->getName();
- sep = ", ";
- }
+ mlir::interleaveComma(node->getArgs(), llvm::errs(),
+ [](auto &arg) { llvm::errs() << arg->getName(); });
llvm::errs() << "]\n";
}
/// Print a function, first the prototype and then the body.
-void ASTDumper::dump(FunctionAST *Node) {
+void ASTDumper::dump(FunctionAST *node) {
INDENT();
llvm::errs() << "Function \n";
- dump(Node->getProto());
- dump(Node->getBody());
+ dump(node->getProto());
+ dump(node->getBody());
}
/// Print a module, actually loop over the functions and print them in sequence.
-void ASTDumper::dump(ModuleAST *Node) {
+void ASTDumper::dump(ModuleAST *node) {
INDENT();
llvm::errs() << "Module:\n";
- for (auto &F : *Node)
- dump(&F);
+ for (auto &f : *node)
+ dump(&f);
}
namespace toy {
diff --git a/mlir/examples/toy/Ch4/toyc.cpp b/mlir/examples/toy/Ch4/toyc.cpp
index d8e04d6ee89..a9f597dc032 100644
--- a/mlir/examples/toy/Ch4/toyc.cpp
+++ b/mlir/examples/toy/Ch4/toyc.cpp
@@ -64,20 +64,20 @@ static cl::opt<enum Action> emitAction(
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")));
-static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations"));
+static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
- llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
- if (std::error_code EC = FileOrErr.getError()) {
- llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+ if (std::error_code ec = fileOrErr.getError()) {
+ llvm::errs() << "Could not open input file: " << ec.message() << "\n";
return nullptr;
}
- auto buffer = FileOrErr.get()->getBuffer();
+ auto buffer = fileOrErr.get()->getBuffer();
LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
Parser parser(lexer);
- return parser.ParseModule();
+ return parser.parseModule();
}
int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
@@ -119,7 +119,7 @@ int dumpMLIR() {
if (int error = loadMLIR(sourceMgr, context, module))
return error;
- if (EnableOpt) {
+ if (enableOpt) {
mlir::PassManager pm(&context);
// Apply any generic pass manager command line options and run the pipeline.
applyPassManagerCLOptions(pm);
diff --git a/mlir/examples/toy/Ch5/include/toy/AST.h b/mlir/examples/toy/Ch5/include/toy/AST.h
index 2ad3392c11a..901164b0f39 100644
--- a/mlir/examples/toy/Ch5/include/toy/AST.h
+++ b/mlir/examples/toy/Ch5/include/toy/AST.h
@@ -54,7 +54,6 @@ public:
ExprAST(ExprASTKind kind, Location location)
: kind(kind), location(location) {}
-
virtual ~ExprAST() = default;
ExprASTKind getKind() const { return kind; }
@@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST {
double Val;
public:
- NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
+ NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {}
double getValue() { return Val; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
};
/// Expression class for a literal value.
@@ -93,10 +92,11 @@ public:
: ExprAST(Expr_Literal, loc), values(std::move(values)),
dims(std::move(dims)) {}
- std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
- std::vector<int64_t> &getDims() { return dims; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
+ llvm::ArrayRef<int64_t> getDims() { return dims; }
+
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
};
/// Expression class for referencing a variable, like "a".
@@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST {
std::string name;
public:
- VariableExprAST(Location loc, const std::string &name)
+ VariableExprAST(Location loc, llvm::StringRef name)
: ExprAST(Expr_Var, loc), name(name) {}
llvm::StringRef getName() { return name; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
};
/// Expression class for defining a variable.
@@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST {
std::unique_ptr<ExprAST> initVal;
public:
- VarDeclExprAST(Location loc, const std::string &name, VarType type,
+ VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
std::unique_ptr<ExprAST> initVal)
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
initVal(std::move(initVal)) {}
llvm::StringRef getName() { return name; }
ExprAST *getInitVal() { return initVal.get(); }
- VarType &getType() { return type; }
+ const VarType &getType() { return type; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
};
/// Expression class for a return operator.
@@ -144,61 +144,61 @@ public:
llvm::Optional<ExprAST *> getExpr() {
if (expr.hasValue())
return expr->get();
- return llvm::NoneType();
+ return llvm::None;
}
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
};
/// Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
- char Op;
- std::unique_ptr<ExprAST> LHS, RHS;
+ char op;
+ std::unique_ptr<ExprAST> lhs, rhs;
public:
- char getOp() { return Op; }
- ExprAST *getLHS() { return LHS.get(); }
- ExprAST *getRHS() { return RHS.get(); }
+ char getOp() { return op; }
+ ExprAST *getLHS() { return lhs.get(); }
+ ExprAST *getRHS() { return rhs.get(); }
- BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
- std::unique_ptr<ExprAST> RHS)
- : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
- RHS(std::move(RHS)) {}
+ BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs,
+ std::unique_ptr<ExprAST> rhs)
+ : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)),
+ rhs(std::move(rhs)) {}
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
};
/// Expression class for function calls.
class CallExprAST : public ExprAST {
- std::string Callee;
- std::vector<std::unique_ptr<ExprAST>> Args;
+ std::string callee;
+ std::vector<std::unique_ptr<ExprAST>> args;
public:
- CallExprAST(Location loc, const std::string &Callee,
- std::vector<std::unique_ptr<ExprAST>> Args)
- : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
+ CallExprAST(Location loc, const std::string &callee,
+ std::vector<std::unique_ptr<ExprAST>> args)
+ : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
- llvm::StringRef getCallee() { return Callee; }
- llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
+ llvm::StringRef getCallee() { return callee; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
};
/// Expression class for builtin print calls.
class PrintExprAST : public ExprAST {
- std::unique_ptr<ExprAST> Arg;
+ std::unique_ptr<ExprAST> arg;
public:
- PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
- : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
+ PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
+ : ExprAST(Expr_Print, loc), arg(std::move(arg)) {}
- ExprAST *getArg() { return Arg.get(); }
+ ExprAST *getArg() { return arg.get(); }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
};
/// This class represents the "prototype" for a function, which captures its
@@ -215,23 +215,21 @@ public:
: location(location), name(name), args(std::move(args)) {}
const Location &loc() { return location; }
- const std::string &getName() const { return name; }
- const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
- return args;
- }
+ llvm::StringRef getName() const { return name; }
+ llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
};
/// This class represents a function definition itself.
class FunctionAST {
- std::unique_ptr<PrototypeAST> Proto;
- std::unique_ptr<ExprASTList> Body;
+ std::unique_ptr<PrototypeAST> proto;
+ std::unique_ptr<ExprASTList> body;
public:
- FunctionAST(std::unique_ptr<PrototypeAST> Proto,
- std::unique_ptr<ExprASTList> Body)
- : Proto(std::move(Proto)), Body(std::move(Body)) {}
- PrototypeAST *getProto() { return Proto.get(); }
- ExprASTList *getBody() { return Body.get(); }
+ FunctionAST(std::unique_ptr<PrototypeAST> proto,
+ std::unique_ptr<ExprASTList> body)
+ : proto(std::move(proto)), body(std::move(body)) {}
+ PrototypeAST *getProto() { return proto.get(); }
+ ExprASTList *getBody() { return body.get(); }
};
/// This class represents a list of functions to be processed together
diff --git a/mlir/examples/toy/Ch5/include/toy/Lexer.h b/mlir/examples/toy/Ch5/include/toy/Lexer.h
index 21f92614912..144388c460c 100644
--- a/mlir/examples/toy/Ch5/include/toy/Lexer.h
+++ b/mlir/examples/toy/Ch5/include/toy/Lexer.h
@@ -89,13 +89,13 @@ public:
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
llvm::StringRef getId() {
assert(curTok == tok_identifier);
- return IdentifierStr;
+ return identifierStr;
}
/// Return the current number (prereq: getCurToken() == tok_number)
double getValue() {
assert(curTok == tok_number);
- return NumVal;
+ return numVal;
}
/// Return the location for the beginning of the current token.
@@ -135,56 +135,58 @@ private:
/// Return the next token from standard input.
Token getTok() {
// Skip any whitespace.
- while (isspace(LastChar))
- LastChar = Token(getNextChar());
+ while (isspace(lastChar))
+ lastChar = Token(getNextChar());
// Save the current location before reading the token characters.
lastLocation.line = curLineNum;
lastLocation.col = curCol;
- if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
- IdentifierStr = (char)LastChar;
- while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
- IdentifierStr += (char)LastChar;
+ // Identifier: [a-zA-Z][a-zA-Z0-9_]*
+ if (isalpha(lastChar)) {
+ identifierStr = (char)lastChar;
+ while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
+ identifierStr += (char)lastChar;
- if (IdentifierStr == "return")
+ if (identifierStr == "return")
return tok_return;
- if (IdentifierStr == "def")
+ if (identifierStr == "def")
return tok_def;
- if (IdentifierStr == "var")
+ if (identifierStr == "var")
return tok_var;
return tok_identifier;
}
- if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
- std::string NumStr;
+ // Number: [0-9.]+
+ if (isdigit(lastChar) || lastChar == '.') {
+ std::string numStr;
do {
- NumStr += LastChar;
- LastChar = Token(getNextChar());
- } while (isdigit(LastChar) || LastChar == '.');
+ numStr += lastChar;
+ lastChar = Token(getNextChar());
+ } while (isdigit(lastChar) || lastChar == '.');
- NumVal = strtod(NumStr.c_str(), nullptr);
+ numVal = strtod(numStr.c_str(), nullptr);
return tok_number;
}
- if (LastChar == '#') {
+ if (lastChar == '#') {
// Comment until end of line.
- do
- LastChar = Token(getNextChar());
- while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
+ do {
+ lastChar = Token(getNextChar());
+ } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
- if (LastChar != EOF)
+ if (lastChar != EOF)
return getTok();
}
// Check for end of file. Don't eat the EOF.
- if (LastChar == EOF)
+ if (lastChar == EOF)
return tok_eof;
// Otherwise, just return the character as its ascii value.
- Token ThisChar = Token(LastChar);
- LastChar = Token(getNextChar());
- return ThisChar;
+ Token thisChar = Token(lastChar);
+ lastChar = Token(getNextChar());
+ return thisChar;
}
/// The last token read from the input.
@@ -194,15 +196,15 @@ private:
Location lastLocation;
/// If the current Token is an identifier, this string contains the value.
- std::string IdentifierStr;
+ std::string identifierStr;
/// If the current Token is a number, this contains the value.
- double NumVal = 0;
+ double numVal = 0;
/// The last value returned by getNextChar(). We need to keep it around as we
/// always need to read ahead one character to decide when to end a token and
/// we can't put it back in the stream after reading from it.
- Token LastChar = Token(' ');
+ Token lastChar = Token(' ');
/// Keep track of the current line number in the input stream
int curLineNum = 0;
diff --git a/mlir/examples/toy/Ch5/include/toy/Parser.h b/mlir/examples/toy/Ch5/include/toy/Parser.h
index ec3d7654a85..9e219e56551 100644
--- a/mlir/examples/toy/Ch5/include/toy/Parser.h
+++ b/mlir/examples/toy/Ch5/include/toy/Parser.h
@@ -48,13 +48,13 @@ public:
Parser(Lexer &lexer) : lexer(lexer) {}
/// Parse a full Module. A module is a list of function definitions.
- std::unique_ptr<ModuleAST> ParseModule() {
+ std::unique_ptr<ModuleAST> parseModule() {
lexer.getNextToken(); // prime the lexer
// Parse functions one at a time and accumulate in this vector.
std::vector<FunctionAST> functions;
- while (auto F = ParseDefinition()) {
- functions.push_back(std::move(*F));
+ while (auto f = parseDefinition()) {
+ functions.push_back(std::move(*f));
if (lexer.getCurToken() == tok_eof)
break;
}
@@ -70,14 +70,14 @@ private:
/// Parse a return statement.
/// return :== return ; | return expr ;
- std::unique_ptr<ReturnExprAST> ParseReturn() {
+ std::unique_ptr<ReturnExprAST> parseReturn() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_return);
// return takes an optional argument
llvm::Optional<std::unique_ptr<ExprAST>> expr;
if (lexer.getCurToken() != ';') {
- expr = ParseExpression();
+ expr = parseExpression();
if (!expr)
return nullptr;
}
@@ -86,18 +86,18 @@ private:
/// Parse a literal number.
/// numberexpr ::= number
- std::unique_ptr<ExprAST> ParseNumberExpr() {
+ std::unique_ptr<ExprAST> parseNumberExpr() {
auto loc = lexer.getLastLocation();
- auto Result =
+ auto result =
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
lexer.consume(tok_number);
- return std::move(Result);
+ return std::move(result);
}
/// Parse a literal array expression.
/// tensorLiteral ::= [ literalList ] | number
/// literalList ::= tensorLiteral | tensorLiteral, literalList
- std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
+ std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
auto loc = lexer.getLastLocation();
lexer.consume(Token('['));
@@ -108,13 +108,13 @@ private:
do {
// We can have either another nested array or a number literal.
if (lexer.getCurToken() == '[') {
- values.push_back(ParseTensorLiteralExpr());
+ values.push_back(parseTensorLiteralExpr());
if (!values.back())
return nullptr; // parse error in the nested array.
} else {
if (lexer.getCurToken() != tok_number)
return parseError<ExprAST>("<num> or [", "in literal expression");
- values.push_back(ParseNumberExpr());
+ values.push_back(parseNumberExpr());
}
// End of this list on ']'
@@ -130,8 +130,10 @@ private:
if (values.empty())
return parseError<ExprAST>("<something>", "to fill literal expression");
lexer.getNextToken(); // eat ]
+
/// Fill in the dimensions now. First the current nesting level:
dims.push_back(values.size());
+
/// If there is any nested array, process all of them and ensure that
/// dimensions are uniform.
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
@@ -143,7 +145,7 @@ private:
"inside literal expression");
// Append the nested dimensions to the current level
- auto &firstDims = firstLiteral->getDims();
+ auto firstDims = firstLiteral->getDims();
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
// Sanity check that shape is uniform across all elements of the list.
@@ -162,22 +164,22 @@ private:
}
/// parenexpr ::= '(' expression ')'
- std::unique_ptr<ExprAST> ParseParenExpr() {
+ std::unique_ptr<ExprAST> parseParenExpr() {
lexer.getNextToken(); // eat (.
- auto V = ParseExpression();
- if (!V)
+ auto v = parseExpression();
+ if (!v)
return nullptr;
if (lexer.getCurToken() != ')')
return parseError<ExprAST>(")", "to close expression with parentheses");
lexer.consume(Token(')'));
- return V;
+ return v;
}
/// identifierexpr
/// ::= identifier
/// ::= identifier '(' expression ')'
- std::unique_ptr<ExprAST> ParseIdentifierExpr() {
+ std::unique_ptr<ExprAST> parseIdentifierExpr() {
std::string name = lexer.getId();
auto loc = lexer.getLastLocation();
@@ -188,11 +190,11 @@ private:
// This is a function call.
lexer.consume(Token('('));
- std::vector<std::unique_ptr<ExprAST>> Args;
+ std::vector<std::unique_ptr<ExprAST>> args;
if (lexer.getCurToken() != ')') {
while (true) {
- if (auto Arg = ParseExpression())
- Args.push_back(std::move(Arg));
+ if (auto arg = parseExpression())
+ args.push_back(std::move(arg));
else
return nullptr;
@@ -208,14 +210,14 @@ private:
// It can be a builtin call to print
if (name == "print") {
- if (Args.size() != 1)
+ if (args.size() != 1)
return parseError<ExprAST>("<single arg>", "as argument to print()");
- return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0]));
+ return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
}
// Call to a user-defined function
- return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args));
+ return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
}
/// primary
@@ -223,20 +225,20 @@ private:
/// ::= numberexpr
/// ::= parenexpr
/// ::= tensorliteral
- std::unique_ptr<ExprAST> ParsePrimary() {
+ std::unique_ptr<ExprAST> parsePrimary() {
switch (lexer.getCurToken()) {
default:
llvm::errs() << "unknown token '" << lexer.getCurToken()
<< "' when expecting an expression\n";
return nullptr;
case tok_identifier:
- return ParseIdentifierExpr();
+ return parseIdentifierExpr();
case tok_number:
- return ParseNumberExpr();
+ return parseNumberExpr();
case '(':
- return ParseParenExpr();
+ return parseParenExpr();
case '[':
- return ParseTensorLiteralExpr();
+ return parseTensorLiteralExpr();
case ';':
return nullptr;
case '}':
@@ -248,54 +250,54 @@ private:
/// argument indicates the precedence of the current binary operator.
///
/// binoprhs ::= ('+' primary)*
- std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
- std::unique_ptr<ExprAST> LHS) {
+ std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
+ std::unique_ptr<ExprAST> lhs) {
// If this is a binop, find its precedence.
while (true) {
- int TokPrec = GetTokPrecedence();
+ int tokPrec = getTokPrecedence();
// If this is a binop that binds at least as tightly as the current binop,
// consume it, otherwise we are done.
- if (TokPrec < ExprPrec)
- return LHS;
+ if (tokPrec < exprPrec)
+ return lhs;
// Okay, we know this is a binop.
- int BinOp = lexer.getCurToken();
- lexer.consume(Token(BinOp));
+ int binOp = lexer.getCurToken();
+ lexer.consume(Token(binOp));
auto loc = lexer.getLastLocation();
// Parse the primary expression after the binary operator.
- auto RHS = ParsePrimary();
- if (!RHS)
+ auto rhs = parsePrimary();
+ if (!rhs)
return parseError<ExprAST>("expression", "to complete binary operator");
- // If BinOp binds less tightly with RHS than the operator after RHS, let
- // the pending operator take RHS as its LHS.
- int NextPrec = GetTokPrecedence();
- if (TokPrec < NextPrec) {
- RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
- if (!RHS)
+ // If BinOp binds less tightly with rhs than the operator after rhs, let
+ // the pending operator take rhs as its lhs.
+ int nextPrec = getTokPrecedence();
+ if (tokPrec < nextPrec) {
+ rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
+ if (!rhs)
return nullptr;
}
- // Merge LHS/RHS.
- LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp,
- std::move(LHS), std::move(RHS));
+ // Merge lhs/RHS.
+ lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
+ std::move(lhs), std::move(rhs));
}
}
- /// expression::= primary binoprhs
- std::unique_ptr<ExprAST> ParseExpression() {
- auto LHS = ParsePrimary();
- if (!LHS)
+ /// expression::= primary binop rhs
+ std::unique_ptr<ExprAST> parseExpression() {
+ auto lhs = parsePrimary();
+ if (!lhs)
return nullptr;
- return ParseBinOpRHS(0, std::move(LHS));
+ return parseBinOpRHS(0, std::move(lhs));
}
/// type ::= < shape_list >
/// shape_list ::= num | num , shape_list
- std::unique_ptr<VarType> ParseType() {
+ std::unique_ptr<VarType> parseType() {
if (lexer.getCurToken() != '<')
return parseError<VarType>("<", "to begin type");
lexer.getNextToken(); // eat <
@@ -319,7 +321,7 @@ private:
/// and identifier and an optional type (shape specification) before the
/// initializer.
/// decl ::= var identifier [ type ] = expr
- std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
+ std::unique_ptr<VarDeclExprAST> parseDeclaration() {
if (lexer.getCurToken() != tok_var)
return parseError<VarDeclExprAST>("var", "to begin declaration");
auto loc = lexer.getLastLocation();
@@ -333,7 +335,7 @@ private:
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
if (lexer.getCurToken() == '<') {
- type = ParseType();
+ type = parseType();
if (!type)
return nullptr;
}
@@ -341,7 +343,7 @@ private:
if (!type)
type = std::make_unique<VarType>();
lexer.consume(Token('='));
- auto expr = ParseExpression();
+ auto expr = parseExpression();
return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
std::move(*type), std::move(expr));
}
@@ -352,7 +354,7 @@ private:
/// block ::= { expression_list }
/// expression_list ::= block_expr ; expression_list
/// block_expr ::= decl | "return" | expr
- std::unique_ptr<ExprASTList> ParseBlock() {
+ std::unique_ptr<ExprASTList> parseBlock() {
if (lexer.getCurToken() != '{')
return parseError<ExprASTList>("{", "to begin block");
lexer.consume(Token('{'));
@@ -366,19 +368,19 @@ private:
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
if (lexer.getCurToken() == tok_var) {
// Variable declaration
- auto varDecl = ParseDeclaration();
+ auto varDecl = parseDeclaration();
if (!varDecl)
return nullptr;
exprList->push_back(std::move(varDecl));
} else if (lexer.getCurToken() == tok_return) {
// Return statement
- auto ret = ParseReturn();
+ auto ret = parseReturn();
if (!ret)
return nullptr;
exprList->push_back(std::move(ret));
} else {
// General expression
- auto expr = ParseExpression();
+ auto expr = parseExpression();
if (!expr)
return nullptr;
exprList->push_back(std::move(expr));
@@ -401,13 +403,13 @@ private:
/// prototype ::= def id '(' decl_list ')'
/// decl_list ::= identifier | identifier, decl_list
- std::unique_ptr<PrototypeAST> ParsePrototype() {
+ std::unique_ptr<PrototypeAST> parsePrototype() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_def);
if (lexer.getCurToken() != tok_identifier)
return parseError<PrototypeAST>("function name", "in prototype");
- std::string FnName = lexer.getId();
+ std::string fnName = lexer.getId();
lexer.consume(tok_identifier);
if (lexer.getCurToken() != '(')
@@ -435,7 +437,7 @@ private:
// success.
lexer.consume(Token(')'));
- return std::make_unique<PrototypeAST>(std::move(loc), FnName,
+ return std::make_unique<PrototypeAST>(std::move(loc), fnName,
std::move(args));
}
@@ -443,18 +445,18 @@ private:
/// `def` keyword, followed by a block containing a list of expressions.
///
/// definition ::= prototype block
- std::unique_ptr<FunctionAST> ParseDefinition() {
- auto Proto = ParsePrototype();
- if (!Proto)
+ std::unique_ptr<FunctionAST> parseDefinition() {
+ auto proto = parsePrototype();
+ if (!proto)
return nullptr;
- if (auto block = ParseBlock())
- return std::make_unique<FunctionAST>(std::move(Proto), std::move(block));
+ if (auto block = parseBlock())
+ return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
return nullptr;
}
/// Get the precedence of the pending binary operator token.
- int GetTokPrecedence() {
+ int getTokPrecedence() {
if (!isascii(lexer.getCurToken()))
return -1;
diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp
index 8b434b139c7..da474e809b3 100644
--- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp
@@ -143,7 +143,7 @@ private:
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
- auto &protoArgs = funcAST.getProto()->getArgs();
+ auto protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
diff --git a/mlir/examples/toy/Ch5/parser/AST.cpp b/mlir/examples/toy/Ch5/parser/AST.cpp
index fde8b101e83..0c7735ec9a4 100644
--- a/mlir/examples/toy/Ch5/parser/AST.cpp
+++ b/mlir/examples/toy/Ch5/parser/AST.cpp
@@ -21,6 +21,7 @@
#include "toy/AST.h"
+#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@@ -40,22 +41,22 @@ struct Indent {
/// the way. The only data member is the current indentation level.
class ASTDumper {
public:
- void dump(ModuleAST *Node);
+ void dump(ModuleAST *node);
private:
- void dump(VarType &type);
+ void dump(const VarType &type);
void dump(VarDeclExprAST *varDecl);
void dump(ExprAST *expr);
void dump(ExprASTList *exprList);
void dump(NumberExprAST *num);
- void dump(LiteralExprAST *Node);
- void dump(VariableExprAST *Node);
- void dump(ReturnExprAST *Node);
- void dump(BinaryExprAST *Node);
- void dump(CallExprAST *Node);
- void dump(PrintExprAST *Node);
- void dump(PrototypeAST *Node);
- void dump(FunctionAST *Node);
+ void dump(LiteralExprAST *node);
+ void dump(VariableExprAST *node);
+ void dump(ReturnExprAST *node);
+ void dump(BinaryExprAST *node);
+ void dump(CallExprAST *node);
+ void dump(PrintExprAST *node);
+ void dump(PrototypeAST *node);
+ void dump(FunctionAST *node);
// Actually print spaces matching the current indentation level
void indent() {
@@ -68,8 +69,8 @@ private:
} // namespace
/// Return a formatted string for the location of any node
-template <typename T> static std::string loc(T *Node) {
- const auto &loc = Node->loc();
+template <typename T> static std::string loc(T *node) {
+ const auto &loc = node->loc();
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
llvm::Twine(loc.col))
.str();
@@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) {
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such array with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
-void printLitHelper(ExprAST *lit_or_num) {
+void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal
- if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
+ if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue();
return;
}
- auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
+ auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
// Print the dimension for this literal first
llvm::errs() << "<";
- {
- const char *sep = "";
- for (auto dim : literal->getDims()) {
- llvm::errs() << sep << dim;
- sep = ", ";
- }
- }
+ mlir::interleaveComma(literal->getDims(), llvm::errs());
llvm::errs() << ">";
// Now print the content, recursing on every element of the list
llvm::errs() << "[ ";
- const char *sep = "";
- for (auto &elt : literal->getValues()) {
- llvm::errs() << sep;
- printLitHelper(elt.get());
- sep = ", ";
- }
+ mlir::interleaveComma(literal->getValues(), llvm::errs(),
+ [&](auto &elt) { printLitHelper(elt.get()); });
llvm::errs() << "]";
}
/// Print a literal, see the recursive helper above for the implementation.
-void ASTDumper::dump(LiteralExprAST *Node) {
+void ASTDumper::dump(LiteralExprAST *node) {
INDENT();
llvm::errs() << "Literal: ";
- printLitHelper(Node);
- llvm::errs() << " " << loc(Node) << "\n";
+ printLitHelper(node);
+ llvm::errs() << " " << loc(node) << "\n";
}
/// Print a variable reference (just a name).
-void ASTDumper::dump(VariableExprAST *Node) {
+void ASTDumper::dump(VariableExprAST *node) {
INDENT();
- llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
+ llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
}
/// Return statement print the return and its (optional) argument.
-void ASTDumper::dump(ReturnExprAST *Node) {
+void ASTDumper::dump(ReturnExprAST *node) {
INDENT();
llvm::errs() << "Return\n";
- if (Node->getExpr().hasValue())
- return dump(*Node->getExpr());
+ if (node->getExpr().hasValue())
+ return dump(*node->getExpr());
{
INDENT();
llvm::errs() << "(void)\n";
@@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) {
}
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
-void ASTDumper::dump(BinaryExprAST *Node) {
+void ASTDumper::dump(BinaryExprAST *node) {
INDENT();
- llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
- dump(Node->getLHS());
- dump(Node->getRHS());
+ llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
+ dump(node->getLHS());
+ dump(node->getRHS());
}
/// Print a call expression, first the callee name and the list of args by
/// recursing into each individual argument.
-void ASTDumper::dump(CallExprAST *Node) {
+void ASTDumper::dump(CallExprAST *node) {
INDENT();
- llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
- for (auto &arg : Node->getArgs())
+ llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
+ for (auto &arg : node->getArgs())
dump(arg.get());
indent();
llvm::errs() << "]\n";
}
/// Print a builtin print call, first the builtin name and then the argument.
-void ASTDumper::dump(PrintExprAST *Node) {
+void ASTDumper::dump(PrintExprAST *node) {
INDENT();
- llvm::errs() << "Print [ " << loc(Node) << "\n";
- dump(Node->getArg());
+ llvm::errs() << "Print [ " << loc(node) << "\n";
+ dump(node->getArg());
indent();
llvm::errs() << "]\n";
}
/// Print type: only the shape is printed in between '<' and '>'
-void ASTDumper::dump(VarType &type) {
+void ASTDumper::dump(const VarType &type) {
llvm::errs() << "<";
- const char *sep = "";
- for (auto shape : type.shape) {
- llvm::errs() << sep << shape;
- sep = ", ";
- }
+ mlir::interleaveComma(type.shape, llvm::errs());
llvm::errs() << ">";
}
/// Print a function prototype, first the function name, and then the list of
/// parameters names.
-void ASTDumper::dump(PrototypeAST *Node) {
+void ASTDumper::dump(PrototypeAST *node) {
INDENT();
- llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
+ llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
indent();
llvm::errs() << "Params: [";
- const char *sep = "";
- for (auto &arg : Node->getArgs()) {
- llvm::errs() << sep << arg->getName();
- sep = ", ";
- }
+ mlir::interleaveComma(node->getArgs(), llvm::errs(),
+ [](auto &arg) { llvm::errs() << arg->getName(); });
llvm::errs() << "]\n";
}
/// Print a function, first the prototype and then the body.
-void ASTDumper::dump(FunctionAST *Node) {
+void ASTDumper::dump(FunctionAST *node) {
INDENT();
llvm::errs() << "Function \n";
- dump(Node->getProto());
- dump(Node->getBody());
+ dump(node->getProto());
+ dump(node->getBody());
}
/// Print a module, actually loop over the functions and print them in sequence.
-void ASTDumper::dump(ModuleAST *Node) {
+void ASTDumper::dump(ModuleAST *node) {
INDENT();
llvm::errs() << "Module:\n";
- for (auto &F : *Node)
- dump(&F);
+ for (auto &f : *node)
+ dump(&f);
}
namespace toy {
diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp
index 54cbdf1405f..57fb1a95c94 100644
--- a/mlir/examples/toy/Ch5/toyc.cpp
+++ b/mlir/examples/toy/Ch5/toyc.cpp
@@ -66,20 +66,20 @@ static cl::opt<enum Action> emitAction(
cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine",
"output the MLIR dump after affine lowering")));
-static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations"));
+static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
- llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
- if (std::error_code EC = FileOrErr.getError()) {
- llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+ if (std::error_code ec = fileOrErr.getError()) {
+ llvm::errs() << "Could not open input file: " << ec.message() << "\n";
return nullptr;
}
- auto buffer = FileOrErr.get()->getBuffer();
+ auto buffer = fileOrErr.get()->getBuffer();
LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
Parser parser(lexer);
- return parser.ParseModule();
+ return parser.parseModule();
}
int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
@@ -128,7 +128,7 @@ int dumpMLIR() {
// Check to see what granularity of MLIR we are compiling to.
bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine;
- if (EnableOpt || isLoweringToAffine) {
+ if (enableOpt || isLoweringToAffine) {
// Inline all functions into main and then delete them.
pm.addPass(mlir::createInlinerPass());
pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
@@ -150,7 +150,7 @@ int dumpMLIR() {
optPM.addPass(mlir::createCSEPass());
// Add optimizations if enabled.
- if (EnableOpt) {
+ if (enableOpt) {
optPM.addPass(mlir::createLoopFusionPass());
optPM.addPass(mlir::createMemRefDataFlowOptPass());
}
diff --git a/mlir/examples/toy/Ch6/include/toy/AST.h b/mlir/examples/toy/Ch6/include/toy/AST.h
index 2ad3392c11a..901164b0f39 100644
--- a/mlir/examples/toy/Ch6/include/toy/AST.h
+++ b/mlir/examples/toy/Ch6/include/toy/AST.h
@@ -54,7 +54,6 @@ public:
ExprAST(ExprASTKind kind, Location location)
: kind(kind), location(location) {}
-
virtual ~ExprAST() = default;
ExprASTKind getKind() const { return kind; }
@@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST {
double Val;
public:
- NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
+ NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {}
double getValue() { return Val; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
};
/// Expression class for a literal value.
@@ -93,10 +92,11 @@ public:
: ExprAST(Expr_Literal, loc), values(std::move(values)),
dims(std::move(dims)) {}
- std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
- std::vector<int64_t> &getDims() { return dims; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
+ llvm::ArrayRef<int64_t> getDims() { return dims; }
+
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
};
/// Expression class for referencing a variable, like "a".
@@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST {
std::string name;
public:
- VariableExprAST(Location loc, const std::string &name)
+ VariableExprAST(Location loc, llvm::StringRef name)
: ExprAST(Expr_Var, loc), name(name) {}
llvm::StringRef getName() { return name; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
};
/// Expression class for defining a variable.
@@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST {
std::unique_ptr<ExprAST> initVal;
public:
- VarDeclExprAST(Location loc, const std::string &name, VarType type,
+ VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
std::unique_ptr<ExprAST> initVal)
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
initVal(std::move(initVal)) {}
llvm::StringRef getName() { return name; }
ExprAST *getInitVal() { return initVal.get(); }
- VarType &getType() { return type; }
+ const VarType &getType() { return type; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
};
/// Expression class for a return operator.
@@ -144,61 +144,61 @@ public:
llvm::Optional<ExprAST *> getExpr() {
if (expr.hasValue())
return expr->get();
- return llvm::NoneType();
+ return llvm::None;
}
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
};
/// Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
- char Op;
- std::unique_ptr<ExprAST> LHS, RHS;
+ char op;
+ std::unique_ptr<ExprAST> lhs, rhs;
public:
- char getOp() { return Op; }
- ExprAST *getLHS() { return LHS.get(); }
- ExprAST *getRHS() { return RHS.get(); }
+ char getOp() { return op; }
+ ExprAST *getLHS() { return lhs.get(); }
+ ExprAST *getRHS() { return rhs.get(); }
- BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
- std::unique_ptr<ExprAST> RHS)
- : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
- RHS(std::move(RHS)) {}
+ BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs,
+ std::unique_ptr<ExprAST> rhs)
+ : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)),
+ rhs(std::move(rhs)) {}
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
};
/// Expression class for function calls.
class CallExprAST : public ExprAST {
- std::string Callee;
- std::vector<std::unique_ptr<ExprAST>> Args;
+ std::string callee;
+ std::vector<std::unique_ptr<ExprAST>> args;
public:
- CallExprAST(Location loc, const std::string &Callee,
- std::vector<std::unique_ptr<ExprAST>> Args)
- : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
+ CallExprAST(Location loc, const std::string &callee,
+ std::vector<std::unique_ptr<ExprAST>> args)
+ : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
- llvm::StringRef getCallee() { return Callee; }
- llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
+ llvm::StringRef getCallee() { return callee; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
};
/// Expression class for builtin print calls.
class PrintExprAST : public ExprAST {
- std::unique_ptr<ExprAST> Arg;
+ std::unique_ptr<ExprAST> arg;
public:
- PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
- : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
+ PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
+ : ExprAST(Expr_Print, loc), arg(std::move(arg)) {}
- ExprAST *getArg() { return Arg.get(); }
+ ExprAST *getArg() { return arg.get(); }
/// LLVM style RTTI
- static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
+ static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
};
/// This class represents the "prototype" for a function, which captures its
@@ -215,23 +215,21 @@ public:
: location(location), name(name), args(std::move(args)) {}
const Location &loc() { return location; }
- const std::string &getName() const { return name; }
- const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
- return args;
- }
+ llvm::StringRef getName() const { return name; }
+ llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
};
/// This class represents a function definition itself.
class FunctionAST {
- std::unique_ptr<PrototypeAST> Proto;
- std::unique_ptr<ExprASTList> Body;
+ std::unique_ptr<PrototypeAST> proto;
+ std::unique_ptr<ExprASTList> body;
public:
- FunctionAST(std::unique_ptr<PrototypeAST> Proto,
- std::unique_ptr<ExprASTList> Body)
- : Proto(std::move(Proto)), Body(std::move(Body)) {}
- PrototypeAST *getProto() { return Proto.get(); }
- ExprASTList *getBody() { return Body.get(); }
+ FunctionAST(std::unique_ptr<PrototypeAST> proto,
+ std::unique_ptr<ExprASTList> body)
+ : proto(std::move(proto)), body(std::move(body)) {}
+ PrototypeAST *getProto() { return proto.get(); }
+ ExprASTList *getBody() { return body.get(); }
};
/// This class represents a list of functions to be processed together
diff --git a/mlir/examples/toy/Ch6/include/toy/Lexer.h b/mlir/examples/toy/Ch6/include/toy/Lexer.h
index 21f92614912..144388c460c 100644
--- a/mlir/examples/toy/Ch6/include/toy/Lexer.h
+++ b/mlir/examples/toy/Ch6/include/toy/Lexer.h
@@ -89,13 +89,13 @@ public:
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
llvm::StringRef getId() {
assert(curTok == tok_identifier);
- return IdentifierStr;
+ return identifierStr;
}
/// Return the current number (prereq: getCurToken() == tok_number)
double getValue() {
assert(curTok == tok_number);
- return NumVal;
+ return numVal;
}
/// Return the location for the beginning of the current token.
@@ -135,56 +135,58 @@ private:
/// Return the next token from standard input.
Token getTok() {
// Skip any whitespace.
- while (isspace(LastChar))
- LastChar = Token(getNextChar());
+ while (isspace(lastChar))
+ lastChar = Token(getNextChar());
// Save the current location before reading the token characters.
lastLocation.line = curLineNum;
lastLocation.col = curCol;
- if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
- IdentifierStr = (char)LastChar;
- while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
- IdentifierStr += (char)LastChar;
+ // Identifier: [a-zA-Z][a-zA-Z0-9_]*
+ if (isalpha(lastChar)) {
+ identifierStr = (char)lastChar;
+ while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
+ identifierStr += (char)lastChar;
- if (IdentifierStr == "return")
+ if (identifierStr == "return")
return tok_return;
- if (IdentifierStr == "def")
+ if (identifierStr == "def")
return tok_def;
- if (IdentifierStr == "var")
+ if (identifierStr == "var")
return tok_var;
return tok_identifier;
}
- if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
- std::string NumStr;
+ // Number: [0-9.]+
+ if (isdigit(lastChar) || lastChar == '.') {
+ std::string numStr;
do {
- NumStr += LastChar;
- LastChar = Token(getNextChar());
- } while (isdigit(LastChar) || LastChar == '.');
+ numStr += lastChar;
+ lastChar = Token(getNextChar());
+ } while (isdigit(lastChar) || lastChar == '.');
- NumVal = strtod(NumStr.c_str(), nullptr);
+ numVal = strtod(numStr.c_str(), nullptr);
return tok_number;
}
- if (LastChar == '#') {
+ if (lastChar == '#') {
// Comment until end of line.
- do
- LastChar = Token(getNextChar());
- while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
+ do {
+ lastChar = Token(getNextChar());
+ } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
- if (LastChar != EOF)
+ if (lastChar != EOF)
return getTok();
}
// Check for end of file. Don't eat the EOF.
- if (LastChar == EOF)
+ if (lastChar == EOF)
return tok_eof;
// Otherwise, just return the character as its ascii value.
- Token ThisChar = Token(LastChar);
- LastChar = Token(getNextChar());
- return ThisChar;
+ Token thisChar = Token(lastChar);
+ lastChar = Token(getNextChar());
+ return thisChar;
}
/// The last token read from the input.
@@ -194,15 +196,15 @@ private:
Location lastLocation;
/// If the current Token is an identifier, this string contains the value.
- std::string IdentifierStr;
+ std::string identifierStr;
/// If the current Token is a number, this contains the value.
- double NumVal = 0;
+ double numVal = 0;
/// The last value returned by getNextChar(). We need to keep it around as we
/// always need to read ahead one character to decide when to end a token and
/// we can't put it back in the stream after reading from it.
- Token LastChar = Token(' ');
+ Token lastChar = Token(' ');
/// Keep track of the current line number in the input stream
int curLineNum = 0;
diff --git a/mlir/examples/toy/Ch6/include/toy/Parser.h b/mlir/examples/toy/Ch6/include/toy/Parser.h
index ec3d7654a85..9e219e56551 100644
--- a/mlir/examples/toy/Ch6/include/toy/Parser.h
+++ b/mlir/examples/toy/Ch6/include/toy/Parser.h
@@ -48,13 +48,13 @@ public:
Parser(Lexer &lexer) : lexer(lexer) {}
/// Parse a full Module. A module is a list of function definitions.
- std::unique_ptr<ModuleAST> ParseModule() {
+ std::unique_ptr<ModuleAST> parseModule() {
lexer.getNextToken(); // prime the lexer
// Parse functions one at a time and accumulate in this vector.
std::vector<FunctionAST> functions;
- while (auto F = ParseDefinition()) {
- functions.push_back(std::move(*F));
+ while (auto f = parseDefinition()) {
+ functions.push_back(std::move(*f));
if (lexer.getCurToken() == tok_eof)
break;
}
@@ -70,14 +70,14 @@ private:
/// Parse a return statement.
/// return :== return ; | return expr ;
- std::unique_ptr<ReturnExprAST> ParseReturn() {
+ std::unique_ptr<ReturnExprAST> parseReturn() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_return);
// return takes an optional argument
llvm::Optional<std::unique_ptr<ExprAST>> expr;
if (lexer.getCurToken() != ';') {
- expr = ParseExpression();
+ expr = parseExpression();
if (!expr)
return nullptr;
}
@@ -86,18 +86,18 @@ private:
/// Parse a literal number.
/// numberexpr ::= number
- std::unique_ptr<ExprAST> ParseNumberExpr() {
+ std::unique_ptr<ExprAST> parseNumberExpr() {
auto loc = lexer.getLastLocation();
- auto Result =
+ auto result =
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
lexer.consume(tok_number);
- return std::move(Result);
+ return std::move(result);
}
/// Parse a literal array expression.
/// tensorLiteral ::= [ literalList ] | number
/// literalList ::= tensorLiteral | tensorLiteral, literalList
- std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
+ std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
auto loc = lexer.getLastLocation();
lexer.consume(Token('['));
@@ -108,13 +108,13 @@ private:
do {
// We can have either another nested array or a number literal.
if (lexer.getCurToken() == '[') {
- values.push_back(ParseTensorLiteralExpr());
+ values.push_back(parseTensorLiteralExpr());
if (!values.back())
return nullptr; // parse error in the nested array.
} else {
if (lexer.getCurToken() != tok_number)
return parseError<ExprAST>("<num> or [", "in literal expression");
- values.push_back(ParseNumberExpr());
+ values.push_back(parseNumberExpr());
}
// End of this list on ']'
@@ -130,8 +130,10 @@ private:
if (values.empty())
return parseError<ExprAST>("<something>", "to fill literal expression");
lexer.getNextToken(); // eat ]
+
/// Fill in the dimensions now. First the current nesting level:
dims.push_back(values.size());
+
/// If there is any nested array, process all of them and ensure that
/// dimensions are uniform.
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
@@ -143,7 +145,7 @@ private:
"inside literal expression");
// Append the nested dimensions to the current level
- auto &firstDims = firstLiteral->getDims();
+ auto firstDims = firstLiteral->getDims();
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
// Sanity check that shape is uniform across all elements of the list.
@@ -162,22 +164,22 @@ private:
}
/// parenexpr ::= '(' expression ')'
- std::unique_ptr<ExprAST> ParseParenExpr() {
+ std::unique_ptr<ExprAST> parseParenExpr() {
lexer.getNextToken(); // eat (.
- auto V = ParseExpression();
- if (!V)
+ auto v = parseExpression();
+ if (!v)
return nullptr;
if (lexer.getCurToken() != ')')
return parseError<ExprAST>(")", "to close expression with parentheses");
lexer.consume(Token(')'));
- return V;
+ return v;
}
/// identifierexpr
/// ::= identifier
/// ::= identifier '(' expression ')'
- std::unique_ptr<ExprAST> ParseIdentifierExpr() {
+ std::unique_ptr<ExprAST> parseIdentifierExpr() {
std::string name = lexer.getId();
auto loc = lexer.getLastLocation();
@@ -188,11 +190,11 @@ private:
// This is a function call.
lexer.consume(Token('('));
- std::vector<std::unique_ptr<ExprAST>> Args;
+ std::vector<std::unique_ptr<ExprAST>> args;
if (lexer.getCurToken() != ')') {
while (true) {
- if (auto Arg = ParseExpression())
- Args.push_back(std::move(Arg));
+ if (auto arg = parseExpression())
+ args.push_back(std::move(arg));
else
return nullptr;
@@ -208,14 +210,14 @@ private:
// It can be a builtin call to print
if (name == "print") {
- if (Args.size() != 1)
+ if (args.size() != 1)
return parseError<ExprAST>("<single arg>", "as argument to print()");
- return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0]));
+ return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
}
// Call to a user-defined function
- return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args));
+ return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
}
/// primary
@@ -223,20 +225,20 @@ private:
/// ::= numberexpr
/// ::= parenexpr
/// ::= tensorliteral
- std::unique_ptr<ExprAST> ParsePrimary() {
+ std::unique_ptr<ExprAST> parsePrimary() {
switch (lexer.getCurToken()) {
default:
llvm::errs() << "unknown token '" << lexer.getCurToken()
<< "' when expecting an expression\n";
return nullptr;
case tok_identifier:
- return ParseIdentifierExpr();
+ return parseIdentifierExpr();
case tok_number:
- return ParseNumberExpr();
+ return parseNumberExpr();
case '(':
- return ParseParenExpr();
+ return parseParenExpr();
case '[':
- return ParseTensorLiteralExpr();
+ return parseTensorLiteralExpr();
case ';':
return nullptr;
case '}':
@@ -248,54 +250,54 @@ private:
/// argument indicates the precedence of the current binary operator.
///
/// binoprhs ::= ('+' primary)*
- std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
- std::unique_ptr<ExprAST> LHS) {
+ std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
+ std::unique_ptr<ExprAST> lhs) {
// If this is a binop, find its precedence.
while (true) {
- int TokPrec = GetTokPrecedence();
+ int tokPrec = getTokPrecedence();
// If this is a binop that binds at least as tightly as the current binop,
// consume it, otherwise we are done.
- if (TokPrec < ExprPrec)
- return LHS;
+ if (tokPrec < exprPrec)
+ return lhs;
// Okay, we know this is a binop.
- int BinOp = lexer.getCurToken();
- lexer.consume(Token(BinOp));
+ int binOp = lexer.getCurToken();
+ lexer.consume(Token(binOp));
auto loc = lexer.getLastLocation();
// Parse the primary expression after the binary operator.
- auto RHS = ParsePrimary();
- if (!RHS)
+ auto rhs = parsePrimary();
+ if (!rhs)
return parseError<ExprAST>("expression", "to complete binary operator");
- // If BinOp binds less tightly with RHS than the operator after RHS, let
- // the pending operator take RHS as its LHS.
- int NextPrec = GetTokPrecedence();
- if (TokPrec < NextPrec) {
- RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
- if (!RHS)
+ // If BinOp binds less tightly with rhs than the operator after rhs, let
+ // the pending operator take rhs as its lhs.
+ int nextPrec = getTokPrecedence();
+ if (tokPrec < nextPrec) {
+ rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
+ if (!rhs)
return nullptr;
}
- // Merge LHS/RHS.
- LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp,
- std::move(LHS), std::move(RHS));
+ // Merge lhs/RHS.
+ lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
+ std::move(lhs), std::move(rhs));
}
}
- /// expression::= primary binoprhs
- std::unique_ptr<ExprAST> ParseExpression() {
- auto LHS = ParsePrimary();
- if (!LHS)
+ /// expression::= primary binop rhs
+ std::unique_ptr<ExprAST> parseExpression() {
+ auto lhs = parsePrimary();
+ if (!lhs)
return nullptr;
- return ParseBinOpRHS(0, std::move(LHS));
+ return parseBinOpRHS(0, std::move(lhs));
}
/// type ::= < shape_list >
/// shape_list ::= num | num , shape_list
- std::unique_ptr<VarType> ParseType() {
+ std::unique_ptr<VarType> parseType() {
if (lexer.getCurToken() != '<')
return parseError<VarType>("<", "to begin type");
lexer.getNextToken(); // eat <
@@ -319,7 +321,7 @@ private:
/// and identifier and an optional type (shape specification) before the
/// initializer.
/// decl ::= var identifier [ type ] = expr
- std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
+ std::unique_ptr<VarDeclExprAST> parseDeclaration() {
if (lexer.getCurToken() != tok_var)
return parseError<VarDeclExprAST>("var", "to begin declaration");
auto loc = lexer.getLastLocation();
@@ -333,7 +335,7 @@ private:
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
if (lexer.getCurToken() == '<') {
- type = ParseType();
+ type = parseType();
if (!type)
return nullptr;
}
@@ -341,7 +343,7 @@ private:
if (!type)
type = std::make_unique<VarType>();
lexer.consume(Token('='));
- auto expr = ParseExpression();
+ auto expr = parseExpression();
return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
std::move(*type), std::move(expr));
}
@@ -352,7 +354,7 @@ private:
/// block ::= { expression_list }
/// expression_list ::= block_expr ; expression_list
/// block_expr ::= decl | "return" | expr
- std::unique_ptr<ExprASTList> ParseBlock() {
+ std::unique_ptr<ExprASTList> parseBlock() {
if (lexer.getCurToken() != '{')
return parseError<ExprASTList>("{", "to begin block");
lexer.consume(Token('{'));
@@ -366,19 +368,19 @@ private:
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
if (lexer.getCurToken() == tok_var) {
// Variable declaration
- auto varDecl = ParseDeclaration();
+ auto varDecl = parseDeclaration();
if (!varDecl)
return nullptr;
exprList->push_back(std::move(varDecl));
} else if (lexer.getCurToken() == tok_return) {
// Return statement
- auto ret = ParseReturn();
+ auto ret = parseReturn();
if (!ret)
return nullptr;
exprList->push_back(std::move(ret));
} else {
// General expression
- auto expr = ParseExpression();
+ auto expr = parseExpression();
if (!expr)
return nullptr;
exprList->push_back(std::move(expr));
@@ -401,13 +403,13 @@ private:
/// prototype ::= def id '(' decl_list ')'
/// decl_list ::= identifier | identifier, decl_list
- std::unique_ptr<PrototypeAST> ParsePrototype() {
+ std::unique_ptr<PrototypeAST> parsePrototype() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_def);
if (lexer.getCurToken() != tok_identifier)
return parseError<PrototypeAST>("function name", "in prototype");
- std::string FnName = lexer.getId();
+ std::string fnName = lexer.getId();
lexer.consume(tok_identifier);
if (lexer.getCurToken() != '(')
@@ -435,7 +437,7 @@ private:
// success.
lexer.consume(Token(')'));
- return std::make_unique<PrototypeAST>(std::move(loc), FnName,
+ return std::make_unique<PrototypeAST>(std::move(loc), fnName,
std::move(args));
}
@@ -443,18 +445,18 @@ private:
/// `def` keyword, followed by a block containing a list of expressions.
///
/// definition ::= prototype block
- std::unique_ptr<FunctionAST> ParseDefinition() {
- auto Proto = ParsePrototype();
- if (!Proto)
+ std::unique_ptr<FunctionAST> parseDefinition() {
+ auto proto = parsePrototype();
+ if (!proto)
return nullptr;
- if (auto block = ParseBlock())
- return std::make_unique<FunctionAST>(std::move(Proto), std::move(block));
+ if (auto block = parseBlock())
+ return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
return nullptr;
}
/// Get the precedence of the pending binary operator token.
- int GetTokPrecedence() {
+ int getTokPrecedence() {
if (!isascii(lexer.getCurToken()))
return -1;
diff --git a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp
index 8b434b139c7..da474e809b3 100644
--- a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp
@@ -143,7 +143,7 @@ private:
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
- auto &protoArgs = funcAST.getProto()->getArgs();
+ auto protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
diff --git a/mlir/examples/toy/Ch6/parser/AST.cpp b/mlir/examples/toy/Ch6/parser/AST.cpp
index fde8b101e83..0c7735ec9a4 100644
--- a/mlir/examples/toy/Ch6/parser/AST.cpp
+++ b/mlir/examples/toy/Ch6/parser/AST.cpp
@@ -21,6 +21,7 @@
#include "toy/AST.h"
+#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@@ -40,22 +41,22 @@ struct Indent {
/// the way. The only data member is the current indentation level.
class ASTDumper {
public:
- void dump(ModuleAST *Node);
+ void dump(ModuleAST *node);
private:
- void dump(VarType &type);
+ void dump(const VarType &type);
void dump(VarDeclExprAST *varDecl);
void dump(ExprAST *expr);
void dump(ExprASTList *exprList);
void dump(NumberExprAST *num);
- void dump(LiteralExprAST *Node);
- void dump(VariableExprAST *Node);
- void dump(ReturnExprAST *Node);
- void dump(BinaryExprAST *Node);
- void dump(CallExprAST *Node);
- void dump(PrintExprAST *Node);
- void dump(PrototypeAST *Node);
- void dump(FunctionAST *Node);
+ void dump(LiteralExprAST *node);
+ void dump(VariableExprAST *node);
+ void dump(ReturnExprAST *node);
+ void dump(BinaryExprAST *node);
+ void dump(CallExprAST *node);
+ void dump(PrintExprAST *node);
+ void dump(PrototypeAST *node);
+ void dump(FunctionAST *node);
// Actually print spaces matching the current indentation level
void indent() {
@@ -68,8 +69,8 @@ private:
} // namespace
/// Return a formatted string for the location of any node
-template <typename T> static std::string loc(T *Node) {
- const auto &loc = Node->loc();
+template <typename T> static std::string loc(T *node) {
+ const auto &loc = node->loc();
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
llvm::Twine(loc.col))
.str();
@@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) {
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such array with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
-void printLitHelper(ExprAST *lit_or_num) {
+void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal
- if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
+ if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue();
return;
}
- auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
+ auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
// Print the dimension for this literal first
llvm::errs() << "<";
- {
- const char *sep = "";
- for (auto dim : literal->getDims()) {
- llvm::errs() << sep << dim;
- sep = ", ";
- }
- }
+ mlir::interleaveComma(literal->getDims(), llvm::errs());
llvm::errs() << ">";
// Now print the content, recursing on every element of the list
llvm::errs() << "[ ";
- const char *sep = "";
- for (auto &elt : literal->getValues()) {
- llvm::errs() << sep;
- printLitHelper(elt.get());
- sep = ", ";
- }
+ mlir::interleaveComma(literal->getValues(), llvm::errs(),
+ [&](auto &elt) { printLitHelper(elt.get()); });
llvm::errs() << "]";
}
/// Print a literal, see the recursive helper above for the implementation.
-void ASTDumper::dump(LiteralExprAST *Node) {
+void ASTDumper::dump(LiteralExprAST *node) {
INDENT();
llvm::errs() << "Literal: ";
- printLitHelper(Node);
- llvm::errs() << " " << loc(Node) << "\n";
+ printLitHelper(node);
+ llvm::errs() << " " << loc(node) << "\n";
}
/// Print a variable reference (just a name).
-void ASTDumper::dump(VariableExprAST *Node) {
+void ASTDumper::dump(VariableExprAST *node) {
INDENT();
- llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
+ llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
}
/// Return statement print the return and its (optional) argument.
-void ASTDumper::dump(ReturnExprAST *Node) {
+void ASTDumper::dump(ReturnExprAST *node) {
INDENT();
llvm::errs() << "Return\n";
- if (Node->getExpr().hasValue())
- return dump(*Node->getExpr());
+ if (node->getExpr().hasValue())
+ return dump(*node->getExpr());
{
INDENT();
llvm::errs() << "(void)\n";
@@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) {
}
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
-void ASTDumper::dump(BinaryExprAST *Node) {
+void ASTDumper::dump(BinaryExprAST *node) {
INDENT();
- llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
- dump(Node->getLHS());
- dump(Node->getRHS());
+ llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
+ dump(node->getLHS());
+ dump(node->getRHS());
}
/// Print a call expression, first the callee name and the list of args by
/// recursing into each individual argument.
-void ASTDumper::dump(CallExprAST *Node) {
+void ASTDumper::dump(CallExprAST *node) {
INDENT();
- llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
- for (auto &arg : Node->getArgs())
+ llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
+ for (auto &arg : node->getArgs())
dump(arg.get());
indent();
llvm::errs() << "]\n";
}
/// Print a builtin print call, first the builtin name and then the argument.
-void ASTDumper::dump(PrintExprAST *Node) {
+void ASTDumper::dump(PrintExprAST *node) {
INDENT();
- llvm::errs() << "Print [ " << loc(Node) << "\n";
- dump(Node->getArg());
+ llvm::errs() << "Print [ " << loc(node) << "\n";
+ dump(node->getArg());
indent();
llvm::errs() << "]\n";
}
/// Print type: only the shape is printed in between '<' and '>'
-void ASTDumper::dump(VarType &type) {
+void ASTDumper::dump(const VarType &type) {
llvm::errs() << "<";
- const char *sep = "";
- for (auto shape : type.shape) {
- llvm::errs() << sep << shape;
- sep = ", ";
- }
+ mlir::interleaveComma(type.shape, llvm::errs());
llvm::errs() << ">";
}
/// Print a function prototype, first the function name, and then the list of
/// parameters names.
-void ASTDumper::dump(PrototypeAST *Node) {
+void ASTDumper::dump(PrototypeAST *node) {
INDENT();
- llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
+ llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
indent();
llvm::errs() << "Params: [";
- const char *sep = "";
- for (auto &arg : Node->getArgs()) {
- llvm::errs() << sep << arg->getName();
- sep = ", ";
- }
+ mlir::interleaveComma(node->getArgs(), llvm::errs(),
+ [](auto &arg) { llvm::errs() << arg->getName(); });
llvm::errs() << "]\n";
}
/// Print a function, first the prototype and then the body.
-void ASTDumper::dump(FunctionAST *Node) {
+void ASTDumper::dump(FunctionAST *node) {
INDENT();
llvm::errs() << "Function \n";
- dump(Node->getProto());
- dump(Node->getBody());
+ dump(node->getProto());
+ dump(node->getBody());
}
/// Print a module, actually loop over the functions and print them in sequence.
-void ASTDumper::dump(ModuleAST *Node) {
+void ASTDumper::dump(ModuleAST *node) {
INDENT();
llvm::errs() << "Module:\n";
- for (auto &F : *Node)
- dump(&F);
+ for (auto &f : *node)
+ dump(&f);
}
namespace toy {
diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp
index b7aed1b440c..fb7cd5493a2 100644
--- a/mlir/examples/toy/Ch6/toyc.cpp
+++ b/mlir/examples/toy/Ch6/toyc.cpp
@@ -85,20 +85,20 @@ static cl::opt<enum Action> emitAction(
clEnumValN(RunJIT, "jit",
"JIT the code and run it by invoking the main function")));
-static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations"));
+static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
- llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
- if (std::error_code EC = FileOrErr.getError()) {
- llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+ if (std::error_code ec = fileOrErr.getError()) {
+ llvm::errs() << "Could not open input file: " << ec.message() << "\n";
return nullptr;
}
- auto buffer = FileOrErr.get()->getBuffer();
+ auto buffer = fileOrErr.get()->getBuffer();
LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
Parser parser(lexer);
- return parser.ParseModule();
+ return parser.parseModule();
}
int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
@@ -142,7 +142,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine;
bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM;
- if (EnableOpt || isLoweringToAffine) {
+ if (enableOpt || isLoweringToAffine) {
// Inline all functions into main and then delete them.
pm.addPass(mlir::createInlinerPass());
pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
@@ -164,7 +164,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
optPM.addPass(mlir::createCSEPass());
// Add optimizations if enabled.
- if (EnableOpt) {
+ if (enableOpt) {
optPM.addPass(mlir::createLoopFusionPass());
optPM.addPass(mlir::createMemRefDataFlowOptPass());
}
@@ -208,7 +208,7 @@ int dumpLLVMIR(mlir::ModuleOp module) {
/// Optionally run an optimization pipeline over the llvm module.
auto optPipeline = mlir::makeOptimizingTransformer(
- /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
+ /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0,
/*targetMachine=*/nullptr);
if (auto err = optPipeline(llvmModule.get())) {
llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
@@ -225,7 +225,7 @@ int runJit(mlir::ModuleOp module) {
// An optimization pipeline to use within the execution engine.
auto optPipeline = mlir::makeOptimizingTransformer(
- /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
+ /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0,
/*targetMachine=*/nullptr);
// Create an MLIR execution engine. The execution engine eagerly JIT-compiles
OpenPOWER on IntegriCloud