summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Parser
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Parser')
-rw-r--r--mlir/lib/Parser/CMakeLists.txt10
-rw-r--r--mlir/lib/Parser/Lexer.cpp394
-rw-r--r--mlir/lib/Parser/Lexer.h73
-rw-r--r--mlir/lib/Parser/Parser.cpp4825
-rw-r--r--mlir/lib/Parser/Token.cpp155
-rw-r--r--mlir/lib/Parser/Token.h107
-rw-r--r--mlir/lib/Parser/TokenKinds.def124
7 files changed, 5688 insertions, 0 deletions
diff --git a/mlir/lib/Parser/CMakeLists.txt b/mlir/lib/Parser/CMakeLists.txt
new file mode 100644
index 00000000000..9fd29ae7879
--- /dev/null
+++ b/mlir/lib/Parser/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_llvm_library(MLIRParser
+ Lexer.cpp
+ Parser.cpp
+ Token.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Parser
+ )
+add_dependencies(MLIRParser MLIRIR MLIRAnalysis)
+target_link_libraries(MLIRParser MLIRIR MLIRAnalysis)
diff --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp
new file mode 100644
index 00000000000..7d8337a9cb3
--- /dev/null
+++ b/mlir/lib/Parser/Lexer.cpp
@@ -0,0 +1,394 @@
+//===- Lexer.cpp - MLIR Lexer Implementation ------------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lexer for the MLIR textual form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Lexer.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Identifier.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "llvm/Support/SourceMgr.h"
+using namespace mlir;
+
+using llvm::SMLoc;
+using llvm::SourceMgr;
+
+// Returns true if 'c' is an allowable punctuation character: [$._-]
+// Returns false otherwise.
+static bool isPunct(char c) {
+ return c == '$' || c == '.' || c == '_' || c == '-';
+}
+
+Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context)
+ : sourceMgr(sourceMgr), context(context) {
+ auto bufferID = sourceMgr.getMainFileID();
+ curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer();
+ curPtr = curBuffer.begin();
+}
+
+/// Encode the specified source location information into an attribute for
+/// attachment to the IR.
+Location Lexer::getEncodedSourceLocation(llvm::SMLoc loc) {
+ auto &sourceMgr = getSourceMgr();
+ unsigned mainFileID = sourceMgr.getMainFileID();
+ auto lineAndColumn = sourceMgr.getLineAndColumn(loc, mainFileID);
+ auto *buffer = sourceMgr.getMemoryBuffer(mainFileID);
+
+ return FileLineColLoc::get(buffer->getBufferIdentifier(), lineAndColumn.first,
+ lineAndColumn.second, context);
+}
+
+/// emitError - Emit an error message and return an Token::error token.
+Token Lexer::emitError(const char *loc, const Twine &message) {
+ mlir::emitError(getEncodedSourceLocation(SMLoc::getFromPointer(loc)),
+ message);
+ return formToken(Token::error, loc);
+}
+
+Token Lexer::lexToken() {
+ while (true) {
+ const char *tokStart = curPtr;
+ switch (*curPtr++) {
+ default:
+ // Handle bare identifiers.
+ if (isalpha(curPtr[-1]))
+ return lexBareIdentifierOrKeyword(tokStart);
+
+ // Unknown character, emit an error.
+ return emitError(tokStart, "unexpected character");
+
+ case ' ':
+ case '\t':
+ case '\n':
+ case '\r':
+ // Handle whitespace.
+ continue;
+
+ case '_':
+ // Handle bare identifiers.
+ return lexBareIdentifierOrKeyword(tokStart);
+
+ case 0:
+ // This may either be a nul character in the source file or may be the EOF
+ // marker that llvm::MemoryBuffer guarantees will be there.
+ if (curPtr - 1 == curBuffer.end())
+ return formToken(Token::eof, tokStart);
+
+ LLVM_FALLTHROUGH;
+ case ':':
+ return formToken(Token::colon, tokStart);
+ case ',':
+ return formToken(Token::comma, tokStart);
+ case '.':
+ return lexEllipsis(tokStart);
+ case '(':
+ return formToken(Token::l_paren, tokStart);
+ case ')':
+ return formToken(Token::r_paren, tokStart);
+ case '{':
+ return formToken(Token::l_brace, tokStart);
+ case '}':
+ return formToken(Token::r_brace, tokStart);
+ case '[':
+ return formToken(Token::l_square, tokStart);
+ case ']':
+ return formToken(Token::r_square, tokStart);
+ case '<':
+ return formToken(Token::less, tokStart);
+ case '>':
+ return formToken(Token::greater, tokStart);
+ case '=':
+ return formToken(Token::equal, tokStart);
+
+ case '+':
+ return formToken(Token::plus, tokStart);
+ case '*':
+ return formToken(Token::star, tokStart);
+ case '-':
+ if (*curPtr == '>') {
+ ++curPtr;
+ return formToken(Token::arrow, tokStart);
+ }
+ return formToken(Token::minus, tokStart);
+
+ case '?':
+ return formToken(Token::question, tokStart);
+
+ case '/':
+ if (*curPtr == '/') {
+ skipComment();
+ continue;
+ }
+ return emitError(tokStart, "unexpected character");
+
+ case '@':
+ return lexAtIdentifier(tokStart);
+
+ case '!':
+ LLVM_FALLTHROUGH;
+ case '^':
+ LLVM_FALLTHROUGH;
+ case '#':
+ LLVM_FALLTHROUGH;
+ case '%':
+ return lexPrefixedIdentifier(tokStart);
+ case '"':
+ return lexString(tokStart);
+
+ case '0':
+ case '1':
+ case '2':
+ case '3':
+ case '4':
+ case '5':
+ case '6':
+ case '7':
+ case '8':
+ case '9':
+ return lexNumber(tokStart);
+ }
+ }
+}
+
+/// Lex an '@foo' identifier.
+///
+/// symbol-ref-id ::= `@` (bare-id | string-literal)
+///
+Token Lexer::lexAtIdentifier(const char *tokStart) {
+ char cur = *curPtr++;
+
+ // Try to parse a string literal, if present.
+ if (cur == '"') {
+ Token stringIdentifier = lexString(curPtr);
+ if (stringIdentifier.is(Token::error))
+ return stringIdentifier;
+ return formToken(Token::at_identifier, tokStart);
+ }
+
+ // Otherwise, these always start with a letter or underscore.
+ if (!isalpha(cur) && cur != '_')
+ return emitError(curPtr - 1,
+ "@ identifier expected to start with letter or '_'");
+
+ while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_' ||
+ *curPtr == '$' || *curPtr == '.')
+ ++curPtr;
+ return formToken(Token::at_identifier, tokStart);
+}
+
+/// Lex a bare identifier or keyword that starts with a letter.
+///
+/// bare-id ::= (letter|[_]) (letter|digit|[_$.])*
+/// integer-type ::= `i[1-9][0-9]*`
+///
+Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
+ // Match the rest of the identifier regex: [0-9a-zA-Z_.$]*
+ while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_' ||
+ *curPtr == '$' || *curPtr == '.')
+ ++curPtr;
+
+ // Check to see if this identifier is a keyword.
+ StringRef spelling(tokStart, curPtr - tokStart);
+
+ // Check for i123.
+ if (tokStart[0] == 'i') {
+ bool allDigits = true;
+ for (auto c : spelling.drop_front())
+ allDigits &= isdigit(c) != 0;
+ if (allDigits && spelling.size() != 1)
+ return Token(Token::inttype, spelling);
+ }
+
+ Token::Kind kind = llvm::StringSwitch<Token::Kind>(spelling)
+#define TOK_KEYWORD(SPELLING) .Case(#SPELLING, Token::kw_##SPELLING)
+#include "TokenKinds.def"
+ .Default(Token::bare_identifier);
+
+ return Token(kind, spelling);
+}
+
+/// Skip a comment line, starting with a '//'.
+///
+/// TODO: add a regex for comments here and to the spec.
+///
+void Lexer::skipComment() {
+ // Advance over the second '/' in a '//' comment.
+ assert(*curPtr == '/');
+ ++curPtr;
+
+ while (true) {
+ switch (*curPtr++) {
+ case '\n':
+ case '\r':
+ // Newline is end of comment.
+ return;
+ case 0:
+ // If this is the end of the buffer, end the comment.
+ if (curPtr - 1 == curBuffer.end()) {
+ --curPtr;
+ return;
+ }
+ LLVM_FALLTHROUGH;
+ default:
+ // Skip over other characters.
+ break;
+ }
+ }
+}
+
+/// Lex an ellipsis.
+///
+/// ellipsis ::= '...'
+///
+Token Lexer::lexEllipsis(const char *tokStart) {
+ assert(curPtr[-1] == '.');
+
+ if (curPtr == curBuffer.end() || *curPtr != '.' || *(curPtr + 1) != '.')
+ return emitError(curPtr, "expected three consecutive dots for an ellipsis");
+
+ curPtr += 2;
+ return formToken(Token::ellipsis, tokStart);
+}
+
+/// Lex a number literal.
+///
+/// integer-literal ::= digit+ | `0x` hex_digit+
+/// float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?
+///
+Token Lexer::lexNumber(const char *tokStart) {
+ assert(isdigit(curPtr[-1]));
+
+ // Handle the hexadecimal case.
+ if (curPtr[-1] == '0' && *curPtr == 'x') {
+ // If we see stuff like 0xi32, this is a literal `0` followed by an
+ // identifier `xi32`, stop after `0`.
+ if (!isxdigit(curPtr[1]))
+ return formToken(Token::integer, tokStart);
+
+ curPtr += 2;
+ while (isxdigit(*curPtr))
+ ++curPtr;
+
+ return formToken(Token::integer, tokStart);
+ }
+
+ // Handle the normal decimal case.
+ while (isdigit(*curPtr))
+ ++curPtr;
+
+ if (*curPtr != '.')
+ return formToken(Token::integer, tokStart);
+ ++curPtr;
+
+ // Skip over [0-9]*([eE][-+]?[0-9]+)?
+ while (isdigit(*curPtr))
+ ++curPtr;
+
+ if (*curPtr == 'e' || *curPtr == 'E') {
+ if (isdigit(static_cast<unsigned char>(curPtr[1])) ||
+ ((curPtr[1] == '-' || curPtr[1] == '+') &&
+ isdigit(static_cast<unsigned char>(curPtr[2])))) {
+ curPtr += 2;
+ while (isdigit(*curPtr))
+ ++curPtr;
+ }
+ }
+ return formToken(Token::floatliteral, tokStart);
+}
+
+/// Lex an identifier that starts with a prefix followed by suffix-id.
+///
+/// attribute-id ::= `#` suffix-id
+/// ssa-id ::= '%' suffix-id
+/// block-id ::= '^' suffix-id
+/// type-id ::= '!' suffix-id
+/// suffix-id ::= digit+ | (letter|id-punct) (letter|id-punct|digit)*
+/// id-punct ::= `$` | `.` | `_` | `-`
+///
+Token Lexer::lexPrefixedIdentifier(const char *tokStart) {
+ Token::Kind kind;
+ StringRef errorKind;
+ switch (*tokStart) {
+ case '#':
+ kind = Token::hash_identifier;
+ errorKind = "invalid attribute name";
+ break;
+ case '%':
+ kind = Token::percent_identifier;
+ errorKind = "invalid SSA name";
+ break;
+ case '^':
+ kind = Token::caret_identifier;
+ errorKind = "invalid block name";
+ break;
+ case '!':
+ kind = Token::exclamation_identifier;
+ errorKind = "invalid type identifier";
+ break;
+ default:
+ llvm_unreachable("invalid caller");
+ }
+
+ // Parse suffix-id.
+ if (isdigit(*curPtr)) {
+ // If suffix-id starts with a digit, the rest must be digits.
+ while (isdigit(*curPtr)) {
+ ++curPtr;
+ }
+ } else if (isalpha(*curPtr) || isPunct(*curPtr)) {
+ do {
+ ++curPtr;
+ } while (isalpha(*curPtr) || isdigit(*curPtr) || isPunct(*curPtr));
+ } else {
+ return emitError(curPtr - 1, errorKind);
+ }
+
+ return formToken(kind, tokStart);
+}
+
+/// Lex a string literal.
+///
+/// string-literal ::= '"' [^"\n\f\v\r]* '"'
+///
+/// TODO: define escaping rules.
+Token Lexer::lexString(const char *tokStart) {
+ assert(curPtr[-1] == '"');
+
+ while (true) {
+ switch (*curPtr++) {
+ case '"':
+ return formToken(Token::string, tokStart);
+ case 0:
+ // If this is a random nul character in the middle of a string, just
+ // include it. If it is the end of file, then it is an error.
+ if (curPtr - 1 != curBuffer.end())
+ continue;
+ LLVM_FALLTHROUGH;
+ case '\n':
+ case '\v':
+ case '\f':
+ return emitError(curPtr - 1, "expected '\"' in string literal");
+ case '\\':
+ // Handle explicitly a few escapes.
+ if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || *curPtr == 't')
+ ++curPtr;
+ else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1]))
+ // Support \xx for two hex digits.
+ curPtr += 2;
+ else
+ return emitError(curPtr - 1, "unknown escape in string literal");
+ continue;
+
+ default:
+ continue;
+ }
+ }
+}
diff --git a/mlir/lib/Parser/Lexer.h b/mlir/lib/Parser/Lexer.h
new file mode 100644
index 00000000000..a760dca9396
--- /dev/null
+++ b/mlir/lib/Parser/Lexer.h
@@ -0,0 +1,73 @@
+//===- Lexer.h - MLIR Lexer Interface ---------------------------*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the MLIR Lexer class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LIB_PARSER_LEXER_H
+#define MLIR_LIB_PARSER_LEXER_H
+
+#include "Token.h"
+#include "mlir/Parser.h"
+
+namespace mlir {
+class Location;
+
+/// This class breaks up the current file into a token stream.
+class Lexer {
+public:
+ explicit Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context);
+
+ const llvm::SourceMgr &getSourceMgr() { return sourceMgr; }
+
+ Token lexToken();
+
+ /// Encode the specified source location information into a Location object
+ /// for attachment to the IR or error reporting.
+ Location getEncodedSourceLocation(llvm::SMLoc loc);
+
+ /// Change the position of the lexer cursor. The next token we lex will start
+ /// at the designated point in the input.
+ void resetPointer(const char *newPointer) { curPtr = newPointer; }
+
+ /// Returns the start of the buffer.
+ const char *getBufferBegin() { return curBuffer.data(); }
+
+private:
+ // Helpers.
+ Token formToken(Token::Kind kind, const char *tokStart) {
+ return Token(kind, StringRef(tokStart, curPtr - tokStart));
+ }
+
+ Token emitError(const char *loc, const Twine &message);
+
+ // Lexer implementation methods.
+ Token lexAtIdentifier(const char *tokStart);
+ Token lexBareIdentifierOrKeyword(const char *tokStart);
+ Token lexEllipsis(const char *tokStart);
+ Token lexNumber(const char *tokStart);
+ Token lexPrefixedIdentifier(const char *tokStart);
+ Token lexString(const char *tokStart);
+
+ /// Skip a comment line, starting with a '//'.
+ void skipComment();
+
+ const llvm::SourceMgr &sourceMgr;
+ MLIRContext *context;
+
+ StringRef curBuffer;
+ const char *curPtr;
+
+ Lexer(const Lexer &) = delete;
+ void operator=(const Lexer &) = delete;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_LIB_PARSER_LEXER_H
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
new file mode 100644
index 00000000000..0198a45172b
--- /dev/null
+++ b/mlir/lib/Parser/Parser.cpp
@@ -0,0 +1,4825 @@
+//===- Parser.cpp - MLIR Parser Implementation ----------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the parser for the MLIR textual form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Parser.h"
+#include "Lexer.h"
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/bit.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/SMLoc.h"
+#include "llvm/Support/SourceMgr.h"
+#include <algorithm>
+using namespace mlir;
+using llvm::MemoryBuffer;
+using llvm::SMLoc;
+using llvm::SourceMgr;
+
+namespace {
+class Parser;
+
+//===----------------------------------------------------------------------===//
+// SymbolState
+//===----------------------------------------------------------------------===//
+
+/// This class contains record of any parsed top-level symbols.
+struct SymbolState {
+ // A map from attribute alias identifier to Attribute.
+ llvm::StringMap<Attribute> attributeAliasDefinitions;
+
+ // A map from type alias identifier to Type.
+ llvm::StringMap<Type> typeAliasDefinitions;
+
+ /// A set of locations into the main parser memory buffer for each of the
+ /// active nested parsers. Given that some nested parsers, i.e. custom dialect
+ /// parsers, operate on a temporary memory buffer, this provides an anchor
+ /// point for emitting diagnostics.
+ SmallVector<llvm::SMLoc, 1> nestedParserLocs;
+
+ /// The top-level lexer that contains the original memory buffer provided by
+ /// the user. This is used by nested parsers to get a properly encoded source
+ /// location.
+ Lexer *topLevelLexer = nullptr;
+};
+
+//===----------------------------------------------------------------------===//
+// ParserState
+//===----------------------------------------------------------------------===//
+
+/// This class refers to all of the state maintained globally by the parser,
+/// such as the current lexer position etc.
+struct ParserState {
+ ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx,
+ SymbolState &symbols)
+ : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()),
+ symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) {
+ // Set the top level lexer for the symbol state if one doesn't exist.
+ if (!symbols.topLevelLexer)
+ symbols.topLevelLexer = &lex;
+ }
+ ~ParserState() {
+ // Reset the top level lexer if it refers the lexer in our state.
+ if (symbols.topLevelLexer == &lex)
+ symbols.topLevelLexer = nullptr;
+ }
+ ParserState(const ParserState &) = delete;
+ void operator=(const ParserState &) = delete;
+
+ /// The context we're parsing into.
+ MLIRContext *const context;
+
+ /// The lexer for the source file we're parsing.
+ Lexer lex;
+
+ /// This is the next token that hasn't been consumed yet.
+ Token curToken;
+
+ /// The current state for symbol parsing.
+ SymbolState &symbols;
+
+ /// The depth of this parser in the nested parsing stack.
+ size_t parserDepth;
+};
+
+//===----------------------------------------------------------------------===//
+// Parser
+//===----------------------------------------------------------------------===//
+
+/// This class implement support for parsing global entities like types and
+/// shared entities like SSA names. It is intended to be subclassed by
+/// specialized subparsers that include state, e.g. when a local symbol table.
+class Parser {
+public:
+ Builder builder;
+
+ Parser(ParserState &state) : builder(state.context), state(state) {}
+
+ // Helper methods to get stuff from the parser-global state.
+ ParserState &getState() const { return state; }
+ MLIRContext *getContext() const { return state.context; }
+ const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }
+
+ /// Parse a comma-separated list of elements up until the specified end token.
+ ParseResult
+ parseCommaSeparatedListUntil(Token::Kind rightToken,
+ const std::function<ParseResult()> &parseElement,
+ bool allowEmptyList = true);
+
+ /// Parse a comma separated list of elements that must have at least one entry
+ /// in it.
+ ParseResult
+ parseCommaSeparatedList(const std::function<ParseResult()> &parseElement);
+
+ ParseResult parsePrettyDialectSymbolName(StringRef &prettyName);
+
+ // We have two forms of parsing methods - those that return a non-null
+ // pointer on success, and those that return a ParseResult to indicate whether
+ // they returned a failure. The second class fills in by-reference arguments
+ // as the results of their action.
+
+ //===--------------------------------------------------------------------===//
+ // Error Handling
+ //===--------------------------------------------------------------------===//
+
+ /// Emit an error and return failure.
+ InFlightDiagnostic emitError(const Twine &message = {}) {
+ return emitError(state.curToken.getLoc(), message);
+ }
+ InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {});
+
+ /// Encode the specified source location information into an attribute for
+ /// attachment to the IR.
+ Location getEncodedSourceLocation(llvm::SMLoc loc) {
+ // If there are no active nested parsers, we can get the encoded source
+ // location directly.
+ if (state.parserDepth == 0)
+ return state.lex.getEncodedSourceLocation(loc);
+ // Otherwise, we need to re-encode it to point to the top level buffer.
+ return state.symbols.topLevelLexer->getEncodedSourceLocation(
+ remapLocationToTopLevelBuffer(loc));
+ }
+
+ /// Remaps the given SMLoc to the top level lexer of the parser. This is used
+ /// to adjust locations of potentially nested parsers to ensure that they can
+ /// be emitted properly as diagnostics.
+ llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) {
+ // If there are no active nested parsers, we can return location directly.
+ SymbolState &symbols = state.symbols;
+ if (state.parserDepth == 0)
+ return loc;
+ assert(symbols.topLevelLexer && "expected valid top-level lexer");
+
+ // Otherwise, we need to remap the location to the main parser. This is
+ // simply offseting the location onto the location of the last nested
+ // parser.
+ size_t offset = loc.getPointer() - state.lex.getBufferBegin();
+ auto *rawLoc =
+ symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset;
+ return llvm::SMLoc::getFromPointer(rawLoc);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Token Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Return the current token the parser is inspecting.
+ const Token &getToken() const { return state.curToken; }
+ StringRef getTokenSpelling() const { return state.curToken.getSpelling(); }
+
+ /// If the current token has the specified kind, consume it and return true.
+ /// If not, return false.
+ bool consumeIf(Token::Kind kind) {
+ if (state.curToken.isNot(kind))
+ return false;
+ consumeToken(kind);
+ return true;
+ }
+
+ /// Advance the current lexer onto the next token.
+ void consumeToken() {
+ assert(state.curToken.isNot(Token::eof, Token::error) &&
+ "shouldn't advance past EOF or errors");
+ state.curToken = state.lex.lexToken();
+ }
+
+ /// Advance the current lexer onto the next token, asserting what the expected
+ /// current token is. This is preferred to the above method because it leads
+ /// to more self-documenting code with better checking.
+ void consumeToken(Token::Kind kind) {
+ assert(state.curToken.is(kind) && "consumed an unexpected token");
+ consumeToken();
+ }
+
+ /// Consume the specified token if present and return success. On failure,
+ /// output a diagnostic and return failure.
+ ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
+
+ //===--------------------------------------------------------------------===//
+ // Type Parsing
+ //===--------------------------------------------------------------------===//
+
+ ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
+ ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
+ ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
+
+ /// Parse an arbitrary type.
+ Type parseType();
+
+ /// Parse a complex type.
+ Type parseComplexType();
+
+ /// Parse an extended type.
+ Type parseExtendedType();
+
+ /// Parse a function type.
+ Type parseFunctionType();
+
+ /// Parse a memref type.
+ Type parseMemRefType();
+
+ /// Parse a non function type.
+ Type parseNonFunctionType();
+
+ /// Parse a tensor type.
+ Type parseTensorType();
+
+ /// Parse a tuple type.
+ Type parseTupleType();
+
+ /// Parse a vector type.
+ VectorType parseVectorType();
+ ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
+ bool allowDynamic = true);
+ ParseResult parseXInDimensionList();
+
+ /// Parse strided layout specification.
+ ParseResult parseStridedLayout(int64_t &offset,
+ SmallVectorImpl<int64_t> &strides);
+
+ // Parse a brace-delimiter list of comma-separated integers with `?` as an
+ // unknown marker.
+ ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions);
+
+ //===--------------------------------------------------------------------===//
+ // Attribute Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse an arbitrary attribute with an optional type.
+ Attribute parseAttribute(Type type = {});
+
+ /// Parse an attribute dictionary.
+ ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
+
+ /// Parse an extended attribute.
+ Attribute parseExtendedAttr(Type type);
+
+ /// Parse a float attribute.
+ Attribute parseFloatAttr(Type type, bool isNegative);
+
+ /// Parse a decimal or a hexadecimal literal, which can be either an integer
+ /// or a float attribute.
+ Attribute parseDecOrHexAttr(Type type, bool isNegative);
+
+ /// Parse an opaque elements attribute.
+ Attribute parseOpaqueElementsAttr();
+
+ /// Parse a dense elements attribute.
+ Attribute parseDenseElementsAttr();
+ ShapedType parseElementsLiteralType();
+
+ /// Parse a sparse elements attribute.
+ Attribute parseSparseElementsAttr();
+
+ //===--------------------------------------------------------------------===//
+ // Location Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse an inline location.
+ ParseResult parseLocation(LocationAttr &loc);
+
+ /// Parse a raw location instance.
+ ParseResult parseLocationInstance(LocationAttr &loc);
+
+ /// Parse a callsite location instance.
+ ParseResult parseCallSiteLocation(LocationAttr &loc);
+
+ /// Parse a fused location instance.
+ ParseResult parseFusedLocation(LocationAttr &loc);
+
+ /// Parse a name or FileLineCol location instance.
+ ParseResult parseNameOrFileLineColLocation(LocationAttr &loc);
+
+ /// Parse an optional trailing location.
+ ///
+ /// trailing-location ::= (`loc` `(` location `)`)?
+ ///
+ ParseResult parseOptionalTrailingLocation(Location &loc) {
+ // If there is a 'loc' we parse a trailing location.
+ if (!getToken().is(Token::kw_loc))
+ return success();
+
+ // Parse the location.
+ LocationAttr directLoc;
+ if (parseLocation(directLoc))
+ return failure();
+ loc = directLoc;
+ return success();
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Affine Parsing
+ //===--------------------------------------------------------------------===//
+
+ ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map,
+ IntegerSet &set);
+
+ /// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
+ ParseResult
+ parseAffineMapOfSSAIds(AffineMap &map,
+ function_ref<ParseResult(bool)> parseElement);
+
+private:
+ /// The Parser is subclassed and reinstantiated. Do not add additional
+ /// non-trivial state here, add it to the ParserState class.
+ ParserState &state;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+/// Parse a comma separated list of elements that must have at least one entry
+/// in it.
+ParseResult Parser::parseCommaSeparatedList(
+ const std::function<ParseResult()> &parseElement) {
+ // Non-empty case starts with an element.
+ if (parseElement())
+ return failure();
+
+ // Otherwise we have a list of comma separated elements.
+ while (consumeIf(Token::comma)) {
+ if (parseElement())
+ return failure();
+ }
+ return success();
+}
+
+/// Parse a comma-separated list of elements, terminated with an arbitrary
+/// token. This allows empty lists if allowEmptyList is true.
+///
+/// abstract-list ::= rightToken // if allowEmptyList == true
+/// abstract-list ::= element (',' element)* rightToken
+///
+ParseResult Parser::parseCommaSeparatedListUntil(
+ Token::Kind rightToken, const std::function<ParseResult()> &parseElement,
+ bool allowEmptyList) {
+ // Handle the empty case.
+ if (getToken().is(rightToken)) {
+ if (!allowEmptyList)
+ return emitError("expected list element");
+ consumeToken(rightToken);
+ return success();
+ }
+
+ if (parseCommaSeparatedList(parseElement) ||
+ parseToken(rightToken, "expected ',' or '" +
+ Token::getTokenSpelling(rightToken) + "'"))
+ return failure();
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// DialectAsmParser
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class provides the main implementation of the DialectAsmParser that
+/// allows for dialects to parse attributes and types. This allows for dialect
+/// hooking into the main MLIR parsing logic.
+class CustomDialectAsmParser : public DialectAsmParser {
+public:
+ CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
+ : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()),
+ parser(parser) {}
+ ~CustomDialectAsmParser() override {}
+
+ /// Emit a diagnostic at the specified location and return failure.
+ InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
+ return parser.emitError(loc, message);
+ }
+
+ /// Return a builder which provides useful access to MLIRContext, global
+ /// objects like types and attributes.
+ Builder &getBuilder() const override { return parser.builder; }
+
+ /// Get the location of the next token and store it into the argument. This
+ /// always succeeds.
+ llvm::SMLoc getCurrentLocation() override {
+ return parser.getToken().getLoc();
+ }
+
+ /// Return the location of the original name token.
+ llvm::SMLoc getNameLoc() const override { return nameLoc; }
+
+ /// Re-encode the given source location as an MLIR location and return it.
+ Location getEncodedSourceLoc(llvm::SMLoc loc) override {
+ return parser.getEncodedSourceLocation(loc);
+ }
+
+ /// Returns the full specification of the symbol being parsed. This allows
+ /// for using a separate parser if necessary.
+ StringRef getFullSymbolSpec() const override { return fullSpec; }
+
+ /// Parse a floating point value from the stream.
+ ParseResult parseFloat(double &result) override {
+ bool negative = parser.consumeIf(Token::minus);
+ Token curTok = parser.getToken();
+
+ // Check for a floating point value.
+ if (curTok.is(Token::floatliteral)) {
+ auto val = curTok.getFloatingPointValue();
+ if (!val.hasValue())
+ return emitError(curTok.getLoc(), "floating point value too large");
+ parser.consumeToken(Token::floatliteral);
+ result = negative ? -*val : *val;
+ return success();
+ }
+
+ // TODO(riverriddle) support hex floating point values.
+ return emitError(getCurrentLocation(), "expected floating point literal");
+ }
+
+ /// Parse an optional integer value from the stream.
+ OptionalParseResult parseOptionalInteger(uint64_t &result) override {
+ Token curToken = parser.getToken();
+ if (curToken.isNot(Token::integer, Token::minus))
+ return llvm::None;
+
+ bool negative = parser.consumeIf(Token::minus);
+ Token curTok = parser.getToken();
+ if (parser.parseToken(Token::integer, "expected integer value"))
+ return failure();
+
+ auto val = curTok.getUInt64IntegerValue();
+ if (!val)
+ return emitError(curTok.getLoc(), "integer value too large");
+ result = negative ? -*val : *val;
+ return success();
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Token Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse a `->` token.
+ ParseResult parseArrow() override {
+ return parser.parseToken(Token::arrow, "expected '->'");
+ }
+
+ /// Parses a `->` if present.
+ ParseResult parseOptionalArrow() override {
+ return success(parser.consumeIf(Token::arrow));
+ }
+
+ /// Parse a '{' token.
+ ParseResult parseLBrace() override {
+ return parser.parseToken(Token::l_brace, "expected '{'");
+ }
+
+ /// Parse a '{' token if present
+ ParseResult parseOptionalLBrace() override {
+ return success(parser.consumeIf(Token::l_brace));
+ }
+
+ /// Parse a `}` token.
+ ParseResult parseRBrace() override {
+ return parser.parseToken(Token::r_brace, "expected '}'");
+ }
+
+ /// Parse a `}` token if present
+ ParseResult parseOptionalRBrace() override {
+ return success(parser.consumeIf(Token::r_brace));
+ }
+
+ /// Parse a `:` token.
+ ParseResult parseColon() override {
+ return parser.parseToken(Token::colon, "expected ':'");
+ }
+
+ /// Parse a `:` token if present.
+ ParseResult parseOptionalColon() override {
+ return success(parser.consumeIf(Token::colon));
+ }
+
+ /// Parse a `,` token.
+ ParseResult parseComma() override {
+ return parser.parseToken(Token::comma, "expected ','");
+ }
+
+ /// Parse a `,` token if present.
+ ParseResult parseOptionalComma() override {
+ return success(parser.consumeIf(Token::comma));
+ }
+
+ /// Parses a `...` if present.
+ ParseResult parseOptionalEllipsis() override {
+ return success(parser.consumeIf(Token::ellipsis));
+ }
+
+ /// Parse a `=` token.
+ ParseResult parseEqual() override {
+ return parser.parseToken(Token::equal, "expected '='");
+ }
+
+ /// Parse a '<' token.
+ ParseResult parseLess() override {
+ return parser.parseToken(Token::less, "expected '<'");
+ }
+
+ /// Parse a `<` token if present.
+ ParseResult parseOptionalLess() override {
+ return success(parser.consumeIf(Token::less));
+ }
+
+ /// Parse a '>' token.
+ ParseResult parseGreater() override {
+ return parser.parseToken(Token::greater, "expected '>'");
+ }
+
+ /// Parse a `>` token if present.
+ ParseResult parseOptionalGreater() override {
+ return success(parser.consumeIf(Token::greater));
+ }
+
+ /// Parse a `(` token.
+ ParseResult parseLParen() override {
+ return parser.parseToken(Token::l_paren, "expected '('");
+ }
+
+ /// Parses a '(' if present.
+ ParseResult parseOptionalLParen() override {
+ return success(parser.consumeIf(Token::l_paren));
+ }
+
+ /// Parse a `)` token.
+ ParseResult parseRParen() override {
+ return parser.parseToken(Token::r_paren, "expected ')'");
+ }
+
+ /// Parses a ')' if present.
+ ParseResult parseOptionalRParen() override {
+ return success(parser.consumeIf(Token::r_paren));
+ }
+
+ /// Parse a `[` token.
+ ParseResult parseLSquare() override {
+ return parser.parseToken(Token::l_square, "expected '['");
+ }
+
+ /// Parses a '[' if present.
+ ParseResult parseOptionalLSquare() override {
+ return success(parser.consumeIf(Token::l_square));
+ }
+
+ /// Parse a `]` token.
+ ParseResult parseRSquare() override {
+ return parser.parseToken(Token::r_square, "expected ']'");
+ }
+
+ /// Parses a ']' if present.
+ ParseResult parseOptionalRSquare() override {
+ return success(parser.consumeIf(Token::r_square));
+ }
+
+ /// Parses a '?' if present.
+ ParseResult parseOptionalQuestion() override {
+ return success(parser.consumeIf(Token::question));
+ }
+
+ /// Parses a '*' if present.
+ ParseResult parseOptionalStar() override {
+ return success(parser.consumeIf(Token::star));
+ }
+
+ /// Returns if the current token corresponds to a keyword.
+ bool isCurrentTokenAKeyword() const {
+ return parser.getToken().is(Token::bare_identifier) ||
+ parser.getToken().isKeyword();
+ }
+
+ /// Parse the given keyword if present.
+ ParseResult parseOptionalKeyword(StringRef keyword) override {
+ // Check that the current token has the same spelling.
+ if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword)
+ return failure();
+ parser.consumeToken();
+ return success();
+ }
+
+ /// Parse a keyword, if present, into 'keyword'.
+ ParseResult parseOptionalKeyword(StringRef *keyword) override {
+ // Check that the current token is a keyword.
+ if (!isCurrentTokenAKeyword())
+ return failure();
+
+ *keyword = parser.getTokenSpelling();
+ parser.consumeToken();
+ return success();
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Attribute Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse an arbitrary attribute and return it in result.
+ ParseResult parseAttribute(Attribute &result, Type type) override {
+ result = parser.parseAttribute(type);
+ return success(static_cast<bool>(result));
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Type Parsing
+ //===--------------------------------------------------------------------===//
+
+ ParseResult parseType(Type &result) override {
+ result = parser.parseType();
+ return success(static_cast<bool>(result));
+ }
+
+ ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
+ bool allowDynamic) override {
+ return parser.parseDimensionListRanked(dimensions, allowDynamic);
+ }
+
+private:
+ /// The full symbol specification.
+ StringRef fullSpec;
+
+ /// The source location of the dialect symbol.
+ SMLoc nameLoc;
+
+ /// The main parser.
+ Parser &parser;
+};
+} // namespace
+
+/// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
+/// and may be recursive. Return with the 'prettyName' StringRef encompassing
+/// the entire pretty name.
+///
+/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
+/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
+/// | '(' pretty-dialect-sym-contents+ ')'
+/// | '[' pretty-dialect-sym-contents+ ']'
+/// | '{' pretty-dialect-sym-contents+ '}'
+/// | '[^[<({>\])}\0]+'
+///
+ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) {
+ // Pretty symbol names are a relatively unstructured format that contains a
+ // series of properly nested punctuation, with anything else in the middle.
+ // Scan ahead to find it and consume it if successful, otherwise emit an
+ // error.
+ auto *curPtr = getTokenSpelling().data();
+
+ SmallVector<char, 8> nestedPunctuation;
+
+ // Scan over the nested punctuation, bailing out on error and consuming until
+ // we find the end. We know that we're currently looking at the '<', so we
+ // can go until we find the matching '>' character.
+ assert(*curPtr == '<');
+ do {
+ char c = *curPtr++;
+ switch (c) {
+ case '\0':
+ // This also handles the EOF case.
+ return emitError("unexpected nul or EOF in pretty dialect name");
+ case '<':
+ case '[':
+ case '(':
+ case '{':
+ nestedPunctuation.push_back(c);
+ continue;
+
+ case '-':
+ // The sequence `->` is treated as special token.
+ if (*curPtr == '>')
+ ++curPtr;
+ continue;
+
+ case '>':
+ if (nestedPunctuation.pop_back_val() != '<')
+ return emitError("unbalanced '>' character in pretty dialect name");
+ break;
+ case ']':
+ if (nestedPunctuation.pop_back_val() != '[')
+ return emitError("unbalanced ']' character in pretty dialect name");
+ break;
+ case ')':
+ if (nestedPunctuation.pop_back_val() != '(')
+ return emitError("unbalanced ')' character in pretty dialect name");
+ break;
+ case '}':
+ if (nestedPunctuation.pop_back_val() != '{')
+ return emitError("unbalanced '}' character in pretty dialect name");
+ break;
+
+ default:
+ continue;
+ }
+ } while (!nestedPunctuation.empty());
+
+ // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
+ // consuming all this stuff, and return.
+ state.lex.resetPointer(curPtr);
+
+ unsigned length = curPtr - prettyName.begin();
+ prettyName = StringRef(prettyName.begin(), length);
+ consumeToken();
+ return success();
+}
+
+/// Parse an extended dialect symbol.
+template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
+static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
+ SymbolAliasMap &aliases,
+ CreateFn &&createSymbol) {
+ // Parse the dialect namespace.
+ StringRef identifier = p.getTokenSpelling().drop_front();
+ auto loc = p.getToken().getLoc();
+ p.consumeToken(identifierTok);
+
+ // If there is no '<' token following this, and if the typename contains no
+ // dot, then we are parsing a symbol alias.
+ if (p.getToken().isNot(Token::less) && !identifier.contains('.')) {
+ // Check for an alias for this type.
+ auto aliasIt = aliases.find(identifier);
+ if (aliasIt == aliases.end())
+ return (p.emitError("undefined symbol alias id '" + identifier + "'"),
+ nullptr);
+ return aliasIt->second;
+ }
+
+ // Otherwise, we are parsing a dialect-specific symbol. If the name contains
+ // a dot, then this is the "pretty" form. If not, it is the verbose form that
+ // looks like <"...">.
+ std::string symbolData;
+ auto dialectName = identifier;
+
+ // Handle the verbose form, where "identifier" is a simple dialect name.
+ if (!identifier.contains('.')) {
+ // Consume the '<'.
+ if (p.parseToken(Token::less, "expected '<' in dialect type"))
+ return nullptr;
+
+ // Parse the symbol specific data.
+ if (p.getToken().isNot(Token::string))
+ return (p.emitError("expected string literal data in dialect symbol"),
+ nullptr);
+ symbolData = p.getToken().getStringValue();
+ loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1);
+ p.consumeToken(Token::string);
+
+ // Consume the '>'.
+ if (p.parseToken(Token::greater, "expected '>' in dialect symbol"))
+ return nullptr;
+ } else {
+ // Ok, the dialect name is the part of the identifier before the dot, the
+ // part after the dot is the dialect's symbol, or the start thereof.
+ auto dotHalves = identifier.split('.');
+ dialectName = dotHalves.first;
+ auto prettyName = dotHalves.second;
+ loc = llvm::SMLoc::getFromPointer(prettyName.data());
+
+ // If the dialect's symbol is followed immediately by a <, then lex the body
+ // of it into prettyName.
+ if (p.getToken().is(Token::less) &&
+ prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) {
+ if (p.parsePrettyDialectSymbolName(prettyName))
+ return nullptr;
+ }
+
+ symbolData = prettyName.str();
+ }
+
+ // Record the name location of the type remapped to the top level buffer.
+ llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc);
+ p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer);
+
+ // Call into the provided symbol construction function.
+ Symbol sym = createSymbol(dialectName, symbolData, loc);
+
+ // Pop the last parser location.
+ p.getState().symbols.nestedParserLocs.pop_back();
+ return sym;
+}
+
+/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
+/// parsing failed, nullptr is returned. The number of bytes read from the input
+/// string is returned in 'numRead'.
+template <typename T, typename ParserFn>
+static T parseSymbol(StringRef inputStr, MLIRContext *context,
+ SymbolState &symbolState, ParserFn &&parserFn,
+ size_t *numRead = nullptr) {
+ SourceMgr sourceMgr;
+ auto memBuffer = MemoryBuffer::getMemBuffer(
+ inputStr, /*BufferName=*/"<mlir_parser_buffer>",
+ /*RequiresNullTerminator=*/false);
+ sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
+ ParserState state(sourceMgr, context, symbolState);
+ Parser parser(state);
+
+ Token startTok = parser.getToken();
+ T symbol = parserFn(parser);
+ if (!symbol)
+ return T();
+
+ // If 'numRead' is valid, then provide the number of bytes that were read.
+ Token endTok = parser.getToken();
+ if (numRead) {
+ *numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
+ startTok.getLoc().getPointer());
+
+ // Otherwise, ensure that all of the tokens were parsed.
+ } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) {
+ parser.emitError(endTok.getLoc(), "encountered unexpected token");
+ return T();
+ }
+ return symbol;
+}
+
+//===----------------------------------------------------------------------===//
+// Error Handling
+//===----------------------------------------------------------------------===//
+
+InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) {
+ auto diag = mlir::emitError(getEncodedSourceLocation(loc), message);
+
+ // If we hit a parse error in response to a lexer error, then the lexer
+ // already reported the error.
+ if (getToken().is(Token::error))
+ diag.abandon();
+ return diag;
+}
+
+//===----------------------------------------------------------------------===//
+// Token Parsing
+//===----------------------------------------------------------------------===//
+
+/// Consume the specified token if present and return success. On failure,
+/// output a diagnostic and return failure.
+ParseResult Parser::parseToken(Token::Kind expectedToken,
+ const Twine &message) {
+ if (consumeIf(expectedToken))
+ return success();
+ return emitError(message);
+}
+
+//===----------------------------------------------------------------------===//
+// Type Parsing
+//===----------------------------------------------------------------------===//
+
+/// Parse an arbitrary type.
+///
+/// type ::= function-type
+/// | non-function-type
+///
+Type Parser::parseType() {
+ if (getToken().is(Token::l_paren))
+ return parseFunctionType();
+ return parseNonFunctionType();
+}
+
+/// Parse a function result type.
+///
+/// function-result-type ::= type-list-parens
+/// | non-function-type
+///
+ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
+ if (getToken().is(Token::l_paren))
+ return parseTypeListParens(elements);
+
+ Type t = parseNonFunctionType();
+ if (!t)
+ return failure();
+ elements.push_back(t);
+ return success();
+}
+
+/// Parse a list of types without an enclosing parenthesis. The list must have
+/// at least one member.
+///
+/// type-list-no-parens ::= type (`,` type)*
+///
+ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
+ auto parseElt = [&]() -> ParseResult {
+ auto elt = parseType();
+ elements.push_back(elt);
+ return elt ? success() : failure();
+ };
+
+ return parseCommaSeparatedList(parseElt);
+}
+
+/// Parse a parenthesized list of types.
+///
+/// type-list-parens ::= `(` `)`
+/// | `(` type-list-no-parens `)`
+///
+ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
+ if (parseToken(Token::l_paren, "expected '('"))
+ return failure();
+
+ // Handle empty lists.
+ if (getToken().is(Token::r_paren))
+ return consumeToken(), success();
+
+ if (parseTypeListNoParens(elements) ||
+ parseToken(Token::r_paren, "expected ')'"))
+ return failure();
+ return success();
+}
+
+/// Parse a complex type.
+///
+/// complex-type ::= `complex` `<` type `>`
+///
+Type Parser::parseComplexType() {
+ consumeToken(Token::kw_complex);
+
+ // Parse the '<'.
+ if (parseToken(Token::less, "expected '<' in complex type"))
+ return nullptr;
+
+ auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
+ auto elementType = parseType();
+ if (!elementType ||
+ parseToken(Token::greater, "expected '>' in complex type"))
+ return nullptr;
+
+ return ComplexType::getChecked(elementType, typeLocation);
+}
+
+/// Parse an extended type.
+///
+/// extended-type ::= (dialect-type | type-alias)
+/// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
+/// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
+/// type-alias ::= `!` alias-name
+///
+Type Parser::parseExtendedType() {
+ return parseExtendedSymbol<Type>(
+ *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions,
+ [&](StringRef dialectName, StringRef symbolData,
+ llvm::SMLoc loc) -> Type {
+ // If we found a registered dialect, then ask it to parse the type.
+ if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
+ return parseSymbol<Type>(
+ symbolData, state.context, state.symbols, [&](Parser &parser) {
+ CustomDialectAsmParser customParser(symbolData, parser);
+ return dialect->parseType(customParser);
+ });
+ }
+
+ // Otherwise, form a new opaque type.
+ return OpaqueType::getChecked(
+ Identifier::get(dialectName, state.context), symbolData,
+ state.context, getEncodedSourceLocation(loc));
+ });
+}
+
+/// Parse a function type.
+///
+/// function-type ::= type-list-parens `->` function-result-type
+///
+Type Parser::parseFunctionType() {
+ assert(getToken().is(Token::l_paren));
+
+ SmallVector<Type, 4> arguments, results;
+ if (parseTypeListParens(arguments) ||
+ parseToken(Token::arrow, "expected '->' in function type") ||
+ parseFunctionResultTypes(results))
+ return nullptr;
+
+ return builder.getFunctionType(arguments, results);
+}
+
+/// Parse the offset and strides from a strided layout specification.
+///
+/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
+///
+ParseResult Parser::parseStridedLayout(int64_t &offset,
+ SmallVectorImpl<int64_t> &strides) {
+ // Parse offset.
+ consumeToken(Token::kw_offset);
+ if (!consumeIf(Token::colon))
+ return emitError("expected colon after `offset` keyword");
+ auto maybeOffset = getToken().getUnsignedIntegerValue();
+ bool question = getToken().is(Token::question);
+ if (!maybeOffset && !question)
+ return emitError("invalid offset");
+ offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue())
+ : MemRefType::getDynamicStrideOrOffset();
+ consumeToken();
+
+ if (!consumeIf(Token::comma))
+ return emitError("expected comma after offset value");
+
+ // Parse stride list.
+ if (!consumeIf(Token::kw_strides))
+ return emitError("expected `strides` keyword after offset specification");
+ if (!consumeIf(Token::colon))
+ return emitError("expected colon after `strides` keyword");
+ if (failed(parseStrideList(strides)))
+ return emitError("invalid braces-enclosed stride list");
+ if (llvm::any_of(strides, [](int64_t st) { return st == 0; }))
+ return emitError("invalid memref stride");
+
+ return success();
+}
+
+/// Parse a memref type.
+///
+/// memref-type ::= ranked-memref-type | unranked-memref-type
+///
+/// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
+/// (`,` semi-affine-map-composition)? (`,`
+/// memory-space)? `>`
+///
+/// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
+///
+/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
+/// memory-space ::= integer-literal /* | TODO: address-space-id */
+///
+Type Parser::parseMemRefType() {
+ consumeToken(Token::kw_memref);
+
+ if (parseToken(Token::less, "expected '<' in memref type"))
+ return nullptr;
+
+ bool isUnranked;
+ SmallVector<int64_t, 4> dimensions;
+
+ if (consumeIf(Token::star)) {
+ // This is an unranked memref type.
+ isUnranked = true;
+ if (parseXInDimensionList())
+ return nullptr;
+
+ } else {
+ isUnranked = false;
+ if (parseDimensionListRanked(dimensions))
+ return nullptr;
+ }
+
+ // Parse the element type.
+ auto typeLoc = getToken().getLoc();
+ auto elementType = parseType();
+ if (!elementType)
+ return nullptr;
+
+ // Parse semi-affine-map-composition.
+ SmallVector<AffineMap, 2> affineMapComposition;
+ unsigned memorySpace = 0;
+ bool parsedMemorySpace = false;
+
+ auto parseElt = [&]() -> ParseResult {
+ if (getToken().is(Token::integer)) {
+ // Parse memory space.
+ if (parsedMemorySpace)
+ return emitError("multiple memory spaces specified in memref type");
+ auto v = getToken().getUnsignedIntegerValue();
+ if (!v.hasValue())
+ return emitError("invalid memory space in memref type");
+ memorySpace = v.getValue();
+ consumeToken(Token::integer);
+ parsedMemorySpace = true;
+ } else {
+ if (isUnranked)
+ return emitError("cannot have affine map for unranked memref type");
+ if (parsedMemorySpace)
+ return emitError("expected memory space to be last in memref type");
+ if (getToken().is(Token::kw_offset)) {
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ if (failed(parseStridedLayout(offset, strides)))
+ return failure();
+ // Construct strided affine map.
+ auto map = makeStridedLinearLayoutMap(strides, offset,
+ elementType.getContext());
+ affineMapComposition.push_back(map);
+ } else {
+ // Parse affine map.
+ auto affineMap = parseAttribute();
+ if (!affineMap)
+ return failure();
+ // Verify that the parsed attribute is an affine map.
+ if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>())
+ affineMapComposition.push_back(affineMapAttr.getValue());
+ else
+ return emitError("expected affine map in memref type");
+ }
+ }
+ return success();
+ };
+
+ // Parse a list of mappings and address space if present.
+ if (consumeIf(Token::comma)) {
+ // Parse comma separated list of affine maps, followed by memory space.
+ if (parseCommaSeparatedListUntil(Token::greater, parseElt,
+ /*allowEmptyList=*/false)) {
+ return nullptr;
+ }
+ } else {
+ if (parseToken(Token::greater, "expected ',' or '>' in memref type"))
+ return nullptr;
+ }
+
+ if (isUnranked)
+ return UnrankedMemRefType::getChecked(elementType, memorySpace,
+ getEncodedSourceLocation(typeLoc));
+
+ return MemRefType::getChecked(dimensions, elementType, affineMapComposition,
+ memorySpace, getEncodedSourceLocation(typeLoc));
+}
+
+/// Parse any type except the function type.
+///
+/// non-function-type ::= integer-type
+/// | index-type
+/// | float-type
+/// | extended-type
+/// | vector-type
+/// | tensor-type
+/// | memref-type
+/// | complex-type
+/// | tuple-type
+/// | none-type
+///
+/// index-type ::= `index`
+/// float-type ::= `f16` | `bf16` | `f32` | `f64`
+/// none-type ::= `none`
+///
+Type Parser::parseNonFunctionType() {
+ switch (getToken().getKind()) {
+ default:
+ return (emitError("expected non-function type"), nullptr);
+ case Token::kw_memref:
+ return parseMemRefType();
+ case Token::kw_tensor:
+ return parseTensorType();
+ case Token::kw_complex:
+ return parseComplexType();
+ case Token::kw_tuple:
+ return parseTupleType();
+ case Token::kw_vector:
+ return parseVectorType();
+ // integer-type
+ case Token::inttype: {
+ auto width = getToken().getIntTypeBitwidth();
+ if (!width.hasValue())
+ return (emitError("invalid integer width"), nullptr);
+ auto loc = getEncodedSourceLocation(getToken().getLoc());
+ consumeToken(Token::inttype);
+ return IntegerType::getChecked(width.getValue(), builder.getContext(), loc);
+ }
+
+ // float-type
+ case Token::kw_bf16:
+ consumeToken(Token::kw_bf16);
+ return builder.getBF16Type();
+ case Token::kw_f16:
+ consumeToken(Token::kw_f16);
+ return builder.getF16Type();
+ case Token::kw_f32:
+ consumeToken(Token::kw_f32);
+ return builder.getF32Type();
+ case Token::kw_f64:
+ consumeToken(Token::kw_f64);
+ return builder.getF64Type();
+
+ // index-type
+ case Token::kw_index:
+ consumeToken(Token::kw_index);
+ return builder.getIndexType();
+
+ // none-type
+ case Token::kw_none:
+ consumeToken(Token::kw_none);
+ return builder.getNoneType();
+
+ // extended type
+ case Token::exclamation_identifier:
+ return parseExtendedType();
+ }
+}
+
+/// Parse a tensor type.
+///
+/// tensor-type ::= `tensor` `<` dimension-list type `>`
+/// dimension-list ::= dimension-list-ranked | `*x`
+///
+Type Parser::parseTensorType() {
+ consumeToken(Token::kw_tensor);
+
+ if (parseToken(Token::less, "expected '<' in tensor type"))
+ return nullptr;
+
+ bool isUnranked;
+ SmallVector<int64_t, 4> dimensions;
+
+ if (consumeIf(Token::star)) {
+ // This is an unranked tensor type.
+ isUnranked = true;
+
+ if (parseXInDimensionList())
+ return nullptr;
+
+ } else {
+ isUnranked = false;
+ if (parseDimensionListRanked(dimensions))
+ return nullptr;
+ }
+
+ // Parse the element type.
+ auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
+ auto elementType = parseType();
+ if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
+ return nullptr;
+
+ if (isUnranked)
+ return UnrankedTensorType::getChecked(elementType, typeLocation);
+ return RankedTensorType::getChecked(dimensions, elementType, typeLocation);
+}
+
+/// Parse a tuple type.
+///
+/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
+///
+Type Parser::parseTupleType() {
+ consumeToken(Token::kw_tuple);
+
+ // Parse the '<'.
+ if (parseToken(Token::less, "expected '<' in tuple type"))
+ return nullptr;
+
+ // Check for an empty tuple by directly parsing '>'.
+ if (consumeIf(Token::greater))
+ return TupleType::get(getContext());
+
+ // Parse the element types and the '>'.
+ SmallVector<Type, 4> types;
+ if (parseTypeListNoParens(types) ||
+ parseToken(Token::greater, "expected '>' in tuple type"))
+ return nullptr;
+
+ return TupleType::get(types, getContext());
+}
+
+/// Parse a vector type.
+///
+/// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>`
+/// non-empty-static-dimension-list ::= decimal-literal `x`
+/// static-dimension-list
+/// static-dimension-list ::= (decimal-literal `x`)*
+///
+VectorType Parser::parseVectorType() {
+ consumeToken(Token::kw_vector);
+
+ if (parseToken(Token::less, "expected '<' in vector type"))
+ return nullptr;
+
+ SmallVector<int64_t, 4> dimensions;
+ if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false))
+ return nullptr;
+ if (dimensions.empty())
+ return (emitError("expected dimension size in vector type"), nullptr);
+
+ // Parse the element type.
+ auto typeLoc = getToken().getLoc();
+ auto elementType = parseType();
+ if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
+ return nullptr;
+
+ return VectorType::getChecked(dimensions, elementType,
+ getEncodedSourceLocation(typeLoc));
+}
+
+/// Parse a dimension list of a tensor or memref type. This populates the
+/// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and
+/// errors out on `?` otherwise.
+///
+/// dimension-list-ranked ::= (dimension `x`)*
+/// dimension ::= `?` | decimal-literal
+///
+/// When `allowDynamic` is not set, this is used to parse:
+///
+/// static-dimension-list ::= (decimal-literal `x`)*
+ParseResult
+Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
+ bool allowDynamic) {
+ while (getToken().isAny(Token::integer, Token::question)) {
+ if (consumeIf(Token::question)) {
+ if (!allowDynamic)
+ return emitError("expected static shape");
+ dimensions.push_back(-1);
+ } else {
+ // Hexadecimal integer literals (starting with `0x`) are not allowed in
+ // aggregate type declarations. Therefore, `0xf32` should be processed as
+ // a sequence of separate elements `0`, `x`, `f32`.
+ if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
+ // We can get here only if the token is an integer literal. Hexadecimal
+ // integer literals can only start with `0x` (`1x` wouldn't lex as a
+ // literal, just `1` would, at which point we don't get into this
+ // branch).
+ assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
+ dimensions.push_back(0);
+ state.lex.resetPointer(getTokenSpelling().data() + 1);
+ consumeToken();
+ } else {
+ // Make sure this integer value is in bound and valid.
+ auto dimension = getToken().getUnsignedIntegerValue();
+ if (!dimension.hasValue())
+ return emitError("invalid dimension");
+ dimensions.push_back((int64_t)dimension.getValue());
+ consumeToken(Token::integer);
+ }
+ }
+
+ // Make sure we have an 'x' or something like 'xbf32'.
+ if (parseXInDimensionList())
+ return failure();
+ }
+
+ return success();
+}
+
+/// Parse an 'x' token in a dimension list, handling the case where the x is
+/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
+/// token.
+ParseResult Parser::parseXInDimensionList() {
+ if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
+ return emitError("expected 'x' in dimension list");
+
+ // If we had a prefix of 'x', lex the next token immediately after the 'x'.
+ if (getTokenSpelling().size() != 1)
+ state.lex.resetPointer(getTokenSpelling().data() + 1);
+
+ // Consume the 'x'.
+ consumeToken(Token::bare_identifier);
+
+ return success();
+}
+
+// Parse a comma-separated list of dimensions, possibly empty:
+// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
+ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
+ if (!consumeIf(Token::l_square))
+ return failure();
+ // Empty list early exit.
+ if (consumeIf(Token::r_square))
+ return success();
+ while (true) {
+ if (consumeIf(Token::question)) {
+ dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
+ } else {
+ // This must be an integer value.
+ int64_t val;
+ if (getToken().getSpelling().getAsInteger(10, val))
+ return emitError("invalid integer value: ") << getToken().getSpelling();
+ // Make sure it is not the one value for `?`.
+ if (ShapedType::isDynamic(val))
+ return emitError("invalid integer value: ")
+ << getToken().getSpelling()
+ << ", use `?` to specify a dynamic dimension";
+ dimensions.push_back(val);
+ consumeToken(Token::integer);
+ }
+ if (!consumeIf(Token::comma))
+ break;
+ }
+ if (!consumeIf(Token::r_square))
+ return failure();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute parsing.
+//===----------------------------------------------------------------------===//
+
+/// Return the symbol reference referred to by the given token, that is known to
+/// be an @-identifier.
+static std::string extractSymbolReference(Token tok) {
+ assert(tok.is(Token::at_identifier) && "expected valid @-identifier");
+ StringRef nameStr = tok.getSpelling().drop_front();
+
+ // Check to see if the reference is a string literal, or a bare identifier.
+ if (nameStr.front() == '"')
+ return tok.getStringValue();
+ return nameStr;
+}
+
+/// Parse an arbitrary attribute.
+///
+/// attribute-value ::= `unit`
+/// | bool-literal
+/// | integer-literal (`:` (index-type | integer-type))?
+/// | float-literal (`:` float-type)?
+/// | string-literal (`:` type)?
+/// | type
+/// | `[` (attribute-value (`,` attribute-value)*)? `]`
+/// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
+/// | symbol-ref-id (`::` symbol-ref-id)*
+/// | `dense` `<` attribute-value `>` `:`
+/// (tensor-type | vector-type)
+/// | `sparse` `<` attribute-value `,` attribute-value `>`
+/// `:` (tensor-type | vector-type)
+/// | `opaque` `<` dialect-namespace `,` hex-string-literal
+/// `>` `:` (tensor-type | vector-type)
+/// | extended-attribute
+///
+Attribute Parser::parseAttribute(Type type) {
+ switch (getToken().getKind()) {
+ // Parse an AffineMap or IntegerSet attribute.
+ case Token::l_paren: {
+ // Try to parse an affine map or an integer set reference.
+ AffineMap map;
+ IntegerSet set;
+ if (parseAffineMapOrIntegerSetReference(map, set))
+ return nullptr;
+ if (map)
+ return AffineMapAttr::get(map);
+ assert(set);
+ return IntegerSetAttr::get(set);
+ }
+
+ // Parse an array attribute.
+ case Token::l_square: {
+ consumeToken(Token::l_square);
+
+ SmallVector<Attribute, 4> elements;
+ auto parseElt = [&]() -> ParseResult {
+ elements.push_back(parseAttribute());
+ return elements.back() ? success() : failure();
+ };
+
+ if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
+ return nullptr;
+ return builder.getArrayAttr(elements);
+ }
+
+ // Parse a boolean attribute.
+ case Token::kw_false:
+ consumeToken(Token::kw_false);
+ return builder.getBoolAttr(false);
+ case Token::kw_true:
+ consumeToken(Token::kw_true);
+ return builder.getBoolAttr(true);
+
+ // Parse a dense elements attribute.
+ case Token::kw_dense:
+ return parseDenseElementsAttr();
+
+ // Parse a dictionary attribute.
+ case Token::l_brace: {
+ SmallVector<NamedAttribute, 4> elements;
+ if (parseAttributeDict(elements))
+ return nullptr;
+ return builder.getDictionaryAttr(elements);
+ }
+
+ // Parse an extended attribute, i.e. alias or dialect attribute.
+ case Token::hash_identifier:
+ return parseExtendedAttr(type);
+
+ // Parse floating point and integer attributes.
+ case Token::floatliteral:
+ return parseFloatAttr(type, /*isNegative=*/false);
+ case Token::integer:
+ return parseDecOrHexAttr(type, /*isNegative=*/false);
+ case Token::minus: {
+ consumeToken(Token::minus);
+ if (getToken().is(Token::integer))
+ return parseDecOrHexAttr(type, /*isNegative=*/true);
+ if (getToken().is(Token::floatliteral))
+ return parseFloatAttr(type, /*isNegative=*/true);
+
+ return (emitError("expected constant integer or floating point value"),
+ nullptr);
+ }
+
+ // Parse a location attribute.
+ case Token::kw_loc: {
+ LocationAttr attr;
+ return failed(parseLocation(attr)) ? Attribute() : attr;
+ }
+
+ // Parse an opaque elements attribute.
+ case Token::kw_opaque:
+ return parseOpaqueElementsAttr();
+
+ // Parse a sparse elements attribute.
+ case Token::kw_sparse:
+ return parseSparseElementsAttr();
+
+ // Parse a string attribute.
+ case Token::string: {
+ auto val = getToken().getStringValue();
+ consumeToken(Token::string);
+ // Parse the optional trailing colon type if one wasn't explicitly provided.
+ if (!type && consumeIf(Token::colon) && !(type = parseType()))
+ return Attribute();
+
+ return type ? StringAttr::get(val, type)
+ : StringAttr::get(val, getContext());
+ }
+
+ // Parse a symbol reference attribute.
+ case Token::at_identifier: {
+ std::string nameStr = extractSymbolReference(getToken());
+ consumeToken(Token::at_identifier);
+
+ // Parse any nested references.
+ std::vector<FlatSymbolRefAttr> nestedRefs;
+ while (getToken().is(Token::colon)) {
+ // Check for the '::' prefix.
+ const char *curPointer = getToken().getLoc().getPointer();
+ consumeToken(Token::colon);
+ if (!consumeIf(Token::colon)) {
+ state.lex.resetPointer(curPointer);
+ consumeToken();
+ break;
+ }
+ // Parse the reference itself.
+ auto curLoc = getToken().getLoc();
+ if (getToken().isNot(Token::at_identifier)) {
+ emitError(curLoc, "expected nested symbol reference identifier");
+ return Attribute();
+ }
+
+ std::string nameStr = extractSymbolReference(getToken());
+ consumeToken(Token::at_identifier);
+ nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext()));
+ }
+
+ return builder.getSymbolRefAttr(nameStr, nestedRefs);
+ }
+
+ // Parse a 'unit' attribute.
+ case Token::kw_unit:
+ consumeToken(Token::kw_unit);
+ return builder.getUnitAttr();
+
+ default:
+ // Parse a type attribute.
+ if (Type type = parseType())
+ return TypeAttr::get(type);
+ return nullptr;
+ }
+}
+
+/// Attribute dictionary.
+///
+/// attribute-dict ::= `{` `}`
+/// | `{` attribute-entry (`,` attribute-entry)* `}`
+/// attribute-entry ::= bare-id `=` attribute-value
+///
+ParseResult
+Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) {
+ if (parseToken(Token::l_brace, "expected '{' in attribute dictionary"))
+ return failure();
+
+ auto parseElt = [&]() -> ParseResult {
+ // We allow keywords as attribute names.
+ if (getToken().isNot(Token::bare_identifier, Token::inttype) &&
+ !getToken().isKeyword())
+ return emitError("expected attribute name");
+ Identifier nameId = builder.getIdentifier(getTokenSpelling());
+ consumeToken();
+
+ // Try to parse the '=' for the attribute value.
+ if (!consumeIf(Token::equal)) {
+ // If there is no '=', we treat this as a unit attribute.
+ attributes.push_back({nameId, builder.getUnitAttr()});
+ return success();
+ }
+
+ auto attr = parseAttribute();
+ if (!attr)
+ return failure();
+
+ attributes.push_back({nameId, attr});
+ return success();
+ };
+
+ if (parseCommaSeparatedListUntil(Token::r_brace, parseElt))
+ return failure();
+
+ return success();
+}
+
+/// Parse an extended attribute.
+///
+/// extended-attribute ::= (dialect-attribute | attribute-alias)
+/// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>`
+/// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body?
+/// attribute-alias ::= `#` alias-name
+///
+Attribute Parser::parseExtendedAttr(Type type) {
+ Attribute attr = parseExtendedSymbol<Attribute>(
+ *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions,
+ [&](StringRef dialectName, StringRef symbolData,
+ llvm::SMLoc loc) -> Attribute {
+ // Parse an optional trailing colon type.
+ Type attrType = type;
+ if (consumeIf(Token::colon) && !(attrType = parseType()))
+ return Attribute();
+
+ // If we found a registered dialect, then ask it to parse the attribute.
+ if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
+ return parseSymbol<Attribute>(
+ symbolData, state.context, state.symbols, [&](Parser &parser) {
+ CustomDialectAsmParser customParser(symbolData, parser);
+ return dialect->parseAttribute(customParser, attrType);
+ });
+ }
+
+ // Otherwise, form a new opaque attribute.
+ return OpaqueAttr::getChecked(
+ Identifier::get(dialectName, state.context), symbolData,
+ attrType ? attrType : NoneType::get(state.context),
+ getEncodedSourceLocation(loc));
+ });
+
+ // Ensure that the attribute has the same type as requested.
+ if (attr && type && attr.getType() != type) {
+ emitError("attribute type different than expected: expected ")
+ << type << ", but got " << attr.getType();
+ return nullptr;
+ }
+ return attr;
+}
+
+/// Parse a float attribute.
+Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
+ auto val = getToken().getFloatingPointValue();
+ if (!val.hasValue())
+ return (emitError("floating point value too large for attribute"), nullptr);
+ consumeToken(Token::floatliteral);
+ if (!type) {
+ // Default to F64 when no type is specified.
+ if (!consumeIf(Token::colon))
+ type = builder.getF64Type();
+ else if (!(type = parseType()))
+ return nullptr;
+ }
+ if (!type.isa<FloatType>())
+ return (emitError("floating point value not valid for specified type"),
+ nullptr);
+ return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
+}
+
+/// Construct a float attribute bitwise equivalent to the integer literal.
+static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type,
+ uint64_t value) {
+ int width = type.getIntOrFloatBitWidth();
+ APInt apInt(width, value);
+ if (apInt != value) {
+ p->emitError("hexadecimal float constant out of range for type");
+ return nullptr;
+ }
+ APFloat apFloat(type.getFloatSemantics(), apInt);
+ return p->builder.getFloatAttr(type, apFloat);
+}
+
+/// Parse a decimal or a hexadecimal literal, which can be either an integer
+/// or a float attribute.
+Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
+ auto val = getToken().getUInt64IntegerValue();
+ if (!val.hasValue())
+ return (emitError("integer constant out of range for attribute"), nullptr);
+
+ // Remember if the literal is hexadecimal.
+ StringRef spelling = getToken().getSpelling();
+ auto loc = state.curToken.getLoc();
+ bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+
+ consumeToken(Token::integer);
+ if (!type) {
+ // Default to i64 if not type is specified.
+ if (!consumeIf(Token::colon))
+ type = builder.getIntegerType(64);
+ else if (!(type = parseType()))
+ return nullptr;
+ }
+
+ if (auto floatType = type.dyn_cast<FloatType>()) {
+ // TODO(zinenko): Update once hex format for bfloat16 is supported.
+ if (type.isBF16())
+ return emitError(loc,
+ "hexadecimal float literal not supported for bfloat16"),
+ nullptr;
+ if (isNegative)
+ return emitError(
+ loc,
+ "hexadecimal float literal should not have a leading minus"),
+ nullptr;
+ if (!isHex) {
+ emitError(loc, "unexpected decimal integer literal for a float attribute")
+ .attachNote()
+ << "add a trailing dot to make the literal a float";
+ return nullptr;
+ }
+
+ // Construct a float attribute bitwise equivalent to the integer literal.
+ return buildHexadecimalFloatLiteral(this, floatType, *val);
+ }
+
+ if (!type.isIntOrIndex())
+ return emitError(loc, "integer literal not valid for specified type"),
+ nullptr;
+
+ // Parse the integer literal.
+ int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
+ APInt apInt(width, *val, isNegative);
+ if (apInt != *val)
+ return emitError(loc, "integer constant out of range for attribute"),
+ nullptr;
+
+ // Otherwise construct an integer attribute.
+ if (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0)
+ return emitError(loc, "integer constant out of range for attribute"),
+ nullptr;
+
+ return builder.getIntegerAttr(type, isNegative ? -apInt : apInt);
+}
+
+/// Parse an opaque elements attribute.
+Attribute Parser::parseOpaqueElementsAttr() {
+ consumeToken(Token::kw_opaque);
+ if (parseToken(Token::less, "expected '<' after 'opaque'"))
+ return nullptr;
+
+ if (getToken().isNot(Token::string))
+ return (emitError("expected dialect namespace"), nullptr);
+
+ auto name = getToken().getStringValue();
+ auto *dialect = builder.getContext()->getRegisteredDialect(name);
+ // TODO(shpeisman): Allow for having an unknown dialect on an opaque
+ // attribute. Otherwise, it can't be roundtripped without having the dialect
+ // registered.
+ if (!dialect)
+ return (emitError("no registered dialect with namespace '" + name + "'"),
+ nullptr);
+
+ consumeToken(Token::string);
+ if (parseToken(Token::comma, "expected ','"))
+ return nullptr;
+
+ if (getToken().getKind() != Token::string)
+ return (emitError("opaque string should start with '0x'"), nullptr);
+
+ auto val = getToken().getStringValue();
+ if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
+ return (emitError("opaque string should start with '0x'"), nullptr);
+
+ val = val.substr(2);
+ if (!llvm::all_of(val, llvm::isHexDigit))
+ return (emitError("opaque string only contains hex digits"), nullptr);
+
+ consumeToken(Token::string);
+ if (parseToken(Token::greater, "expected '>'") ||
+ parseToken(Token::colon, "expected ':'"))
+ return nullptr;
+
+ auto type = parseElementsLiteralType();
+ if (!type)
+ return nullptr;
+
+ return OpaqueElementsAttr::get(dialect, type, llvm::fromHex(val));
+}
+
+namespace {
+class TensorLiteralParser {
+public:
+ TensorLiteralParser(Parser &p) : p(p) {}
+
+ ParseResult parse() {
+ if (p.getToken().is(Token::l_square))
+ return parseList(shape);
+ return parseElement();
+ }
+
+ /// Build a dense attribute instance with the parsed elements and the given
+ /// shaped type.
+ DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type);
+
+ ArrayRef<int64_t> getShape() const { return shape; }
+
+private:
+ enum class ElementKind { Boolean, Integer, Float };
+
+ /// Return a string to represent the given element kind.
+ const char *getElementKindStr(ElementKind kind) {
+ switch (kind) {
+ case ElementKind::Boolean:
+ return "'boolean'";
+ case ElementKind::Integer:
+ return "'integer'";
+ case ElementKind::Float:
+ return "'float'";
+ }
+ llvm_unreachable("unknown element kind");
+ }
+
+ /// Build a Dense Integer attribute for the given type.
+ DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type,
+ IntegerType eltTy);
+
+ /// Build a Dense Float attribute for the given type.
+ DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type,
+ FloatType eltTy);
+
+ /// Parse a single element, returning failure if it isn't a valid element
+ /// literal. For example:
+ /// parseElement(1) -> Success, 1
+ /// parseElement([1]) -> Failure
+ ParseResult parseElement();
+
+ /// Parse a list of either lists or elements, returning the dimensions of the
+ /// parsed sub-tensors in dims. For example:
+ /// parseList([1, 2, 3]) -> Success, [3]
+ /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
+ /// parseList([[1, 2], 3]) -> Failure
+ /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
+ ParseResult parseList(SmallVectorImpl<int64_t> &dims);
+
+ Parser &p;
+
+ /// The shape inferred from the parsed elements.
+ SmallVector<int64_t, 4> shape;
+
+ /// Storage used when parsing elements, this is a pair of <is_negated, token>.
+ std::vector<std::pair<bool, Token>> storage;
+
+ /// A flag that indicates the type of elements that have been parsed.
+ Optional<ElementKind> knownEltKind;
+};
+} // namespace
+
+/// Build a dense attribute instance with the parsed elements and the given
+/// shaped type.
+DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
+ ShapedType type) {
+ // Check that the parsed storage size has the same number of elements to the
+ // type, or is a known splat.
+ if (!shape.empty() && getShape() != type.getShape()) {
+ p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
+ << "]) does not match type ([" << type.getShape() << "])";
+ return nullptr;
+ }
+
+ // If the type is an integer, build a set of APInt values from the storage
+ // with the correct bitwidth.
+ if (auto intTy = type.getElementType().dyn_cast<IntegerType>())
+ return getIntAttr(loc, type, intTy);
+
+ // Otherwise, this must be a floating point type.
+ auto floatTy = type.getElementType().dyn_cast<FloatType>();
+ if (!floatTy) {
+ p.emitError(loc) << "expected floating-point or integer element type, got "
+ << type.getElementType();
+ return nullptr;
+ }
+ return getFloatAttr(loc, type, floatTy);
+}
+
+/// Build a Dense Integer attribute for the given type.
+DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
+ ShapedType type,
+ IntegerType eltTy) {
+ std::vector<APInt> intElements;
+ intElements.reserve(storage.size());
+ for (const auto &signAndToken : storage) {
+ bool isNegative = signAndToken.first;
+ const Token &token = signAndToken.second;
+
+ // Check to see if floating point values were parsed.
+ if (token.is(Token::floatliteral)) {
+ p.emitError() << "expected integer elements, but parsed floating-point";
+ return nullptr;
+ }
+
+ assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
+ "unexpected token type");
+ if (token.isAny(Token::kw_true, Token::kw_false)) {
+ if (!eltTy.isInteger(1))
+ p.emitError() << "expected i1 type for 'true' or 'false' values";
+ APInt apInt(eltTy.getWidth(), token.is(Token::kw_true),
+ /*isSigned=*/false);
+ intElements.push_back(apInt);
+ continue;
+ }
+
+ // Create APInt values for each element with the correct bitwidth.
+ auto val = token.getUInt64IntegerValue();
+ if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0
+ : (int64_t)val.getValue() < 0)) {
+ p.emitError(token.getLoc(),
+ "integer constant out of range for attribute");
+ return nullptr;
+ }
+ APInt apInt(eltTy.getWidth(), val.getValue(), isNegative);
+ if (apInt != val.getValue())
+ return (p.emitError("integer constant out of range for type"), nullptr);
+ intElements.push_back(isNegative ? -apInt : apInt);
+ }
+
+ return DenseElementsAttr::get(type, intElements);
+}
+
+/// Build a Dense Float attribute for the given type.
+DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
+ ShapedType type,
+ FloatType eltTy) {
+ std::vector<Attribute> floatValues;
+ floatValues.reserve(storage.size());
+ for (const auto &signAndToken : storage) {
+ bool isNegative = signAndToken.first;
+ const Token &token = signAndToken.second;
+
+ // Handle hexadecimal float literals.
+ if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
+ if (isNegative) {
+ p.emitError(token.getLoc())
+ << "hexadecimal float literal should not have a leading minus";
+ return nullptr;
+ }
+ auto val = token.getUInt64IntegerValue();
+ if (!val.hasValue()) {
+ p.emitError("hexadecimal float constant out of range for attribute");
+ return nullptr;
+ }
+ FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val);
+ if (!attr)
+ return nullptr;
+ floatValues.push_back(attr);
+ continue;
+ }
+
+ // Check to see if any decimal integers or booleans were parsed.
+ if (!token.is(Token::floatliteral)) {
+ p.emitError() << "expected floating-point elements, but parsed integer";
+ return nullptr;
+ }
+
+ // Build the float values from tokens.
+ auto val = token.getFloatingPointValue();
+ if (!val.hasValue()) {
+ p.emitError("floating point value too large for attribute");
+ return nullptr;
+ }
+ floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val));
+ }
+
+ return DenseElementsAttr::get(type, floatValues);
+}
+
+ParseResult TensorLiteralParser::parseElement() {
+ switch (p.getToken().getKind()) {
+ // Parse a boolean element.
+ case Token::kw_true:
+ case Token::kw_false:
+ case Token::floatliteral:
+ case Token::integer:
+ storage.emplace_back(/*isNegative=*/false, p.getToken());
+ p.consumeToken();
+ break;
+
+ // Parse a signed integer or a negative floating-point element.
+ case Token::minus:
+ p.consumeToken(Token::minus);
+ if (!p.getToken().isAny(Token::floatliteral, Token::integer))
+ return p.emitError("expected integer or floating point literal");
+ storage.emplace_back(/*isNegative=*/true, p.getToken());
+ p.consumeToken();
+ break;
+
+ default:
+ return p.emitError("expected element literal of primitive type");
+ }
+
+ return success();
+}
+
+/// Parse a list of either lists or elements, returning the dimensions of the
+/// parsed sub-tensors in dims. For example:
+/// parseList([1, 2, 3]) -> Success, [3]
+/// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
+/// parseList([[1, 2], 3]) -> Failure
+/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
+ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
+ p.consumeToken(Token::l_square);
+
+ auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
+ const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
+ if (prevDims == newDims)
+ return success();
+ return p.emitError("tensor literal is invalid; ranks are not consistent "
+ "between elements");
+ };
+
+ bool first = true;
+ SmallVector<int64_t, 4> newDims;
+ unsigned size = 0;
+ auto parseCommaSeparatedList = [&]() -> ParseResult {
+ SmallVector<int64_t, 4> thisDims;
+ if (p.getToken().getKind() == Token::l_square) {
+ if (parseList(thisDims))
+ return failure();
+ } else if (parseElement()) {
+ return failure();
+ }
+ ++size;
+ if (!first)
+ return checkDims(newDims, thisDims);
+ newDims = thisDims;
+ first = false;
+ return success();
+ };
+ if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList))
+ return failure();
+
+ // Return the sublists' dimensions with 'size' prepended.
+ dims.clear();
+ dims.push_back(size);
+ dims.append(newDims.begin(), newDims.end());
+ return success();
+}
+
+/// Parse a dense elements attribute.
+Attribute Parser::parseDenseElementsAttr() {
+ consumeToken(Token::kw_dense);
+ if (parseToken(Token::less, "expected '<' after 'dense'"))
+ return nullptr;
+
+ // Parse the literal data.
+ TensorLiteralParser literalParser(*this);
+ if (literalParser.parse())
+ return nullptr;
+
+ if (parseToken(Token::greater, "expected '>'") ||
+ parseToken(Token::colon, "expected ':'"))
+ return nullptr;
+
+ auto typeLoc = getToken().getLoc();
+ auto type = parseElementsLiteralType();
+ if (!type)
+ return nullptr;
+ return literalParser.getAttr(typeLoc, type);
+}
+
+/// Shaped type for elements attribute.
+///
+/// elements-literal-type ::= vector-type | ranked-tensor-type
+///
+/// This method also checks the type has static shape.
+ShapedType Parser::parseElementsLiteralType() {
+ auto type = parseType();
+ if (!type)
+ return nullptr;
+
+ if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) {
+ emitError("elements literal must be a ranked tensor or vector type");
+ return nullptr;
+ }
+
+ auto sType = type.cast<ShapedType>();
+ if (!sType.hasStaticShape())
+ return (emitError("elements literal type must have static shape"), nullptr);
+
+ return sType;
+}
+
+/// Parse a sparse elements attribute.
+Attribute Parser::parseSparseElementsAttr() {
+ consumeToken(Token::kw_sparse);
+ if (parseToken(Token::less, "Expected '<' after 'sparse'"))
+ return nullptr;
+
+ /// Parse indices
+ auto indicesLoc = getToken().getLoc();
+ TensorLiteralParser indiceParser(*this);
+ if (indiceParser.parse())
+ return nullptr;
+
+ if (parseToken(Token::comma, "expected ','"))
+ return nullptr;
+
+ /// Parse values.
+ auto valuesLoc = getToken().getLoc();
+ TensorLiteralParser valuesParser(*this);
+ if (valuesParser.parse())
+ return nullptr;
+
+ if (parseToken(Token::greater, "expected '>'") ||
+ parseToken(Token::colon, "expected ':'"))
+ return nullptr;
+
+ auto type = parseElementsLiteralType();
+ if (!type)
+ return nullptr;
+
+ // If the indices are a splat, i.e. the literal parser parsed an element and
+ // not a list, we set the shape explicitly. The indices are represented by a
+ // 2-dimensional shape where the second dimension is the rank of the type.
+ // Given that the parsed indices is a splat, we know that we only have one
+ // indice and thus one for the first dimension.
+ auto indiceEltType = builder.getIntegerType(64);
+ ShapedType indicesType;
+ if (indiceParser.getShape().empty()) {
+ indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
+ } else {
+ // Otherwise, set the shape to the one parsed by the literal parser.
+ indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
+ }
+ auto indices = indiceParser.getAttr(indicesLoc, indicesType);
+
+ // If the values are a splat, set the shape explicitly based on the number of
+ // indices. The number of indices is encoded in the first dimension of the
+ // indice shape type.
+ auto valuesEltType = type.getElementType();
+ ShapedType valuesType =
+ valuesParser.getShape().empty()
+ ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
+ : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
+ auto values = valuesParser.getAttr(valuesLoc, valuesType);
+
+ /// Sanity check.
+ if (valuesType.getRank() != 1)
+ return (emitError("expected 1-d tensor for values"), nullptr);
+
+ auto sameShape = (indicesType.getRank() == 1) ||
+ (type.getRank() == indicesType.getDimSize(1));
+ auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
+ if (!sameShape || !sameElementNum) {
+ emitError() << "expected shape ([" << type.getShape()
+ << "]); inferred shape of indices literal (["
+ << indicesType.getShape()
+ << "]); inferred shape of values literal (["
+ << valuesType.getShape() << "])";
+ return nullptr;
+ }
+
+ // Build the sparse elements attribute by the indices and values.
+ return SparseElementsAttr::get(type, indices, values);
+}
+
+//===----------------------------------------------------------------------===//
+// Location parsing.
+//===----------------------------------------------------------------------===//
+
+/// Parse a location.
+///
+/// location ::= `loc` inline-location
+/// inline-location ::= '(' location-inst ')'
+///
+ParseResult Parser::parseLocation(LocationAttr &loc) {
+ // Check for 'loc' identifier.
+ if (parseToken(Token::kw_loc, "expected 'loc' keyword"))
+ return emitError();
+
+ // Parse the inline-location.
+ if (parseToken(Token::l_paren, "expected '(' in inline location") ||
+ parseLocationInstance(loc) ||
+ parseToken(Token::r_paren, "expected ')' in inline location"))
+ return failure();
+ return success();
+}
+
+/// Specific location instances.
+///
+/// location-inst ::= filelinecol-location |
+/// name-location |
+/// callsite-location |
+/// fused-location |
+/// unknown-location
+/// filelinecol-location ::= string-literal ':' integer-literal
+/// ':' integer-literal
+/// name-location ::= string-literal
+/// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')'
+/// fused-location ::= fused ('<' attribute-value '>')?
+/// '[' location-inst (location-inst ',')* ']'
+/// unknown-location ::= 'unknown'
+///
+ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) {
+ consumeToken(Token::bare_identifier);
+
+ // Parse the '('.
+ if (parseToken(Token::l_paren, "expected '(' in callsite location"))
+ return failure();
+
+ // Parse the callee location.
+ LocationAttr calleeLoc;
+ if (parseLocationInstance(calleeLoc))
+ return failure();
+
+ // Parse the 'at'.
+ if (getToken().isNot(Token::bare_identifier) ||
+ getToken().getSpelling() != "at")
+ return emitError("expected 'at' in callsite location");
+ consumeToken(Token::bare_identifier);
+
+ // Parse the caller location.
+ LocationAttr callerLoc;
+ if (parseLocationInstance(callerLoc))
+ return failure();
+
+ // Parse the ')'.
+ if (parseToken(Token::r_paren, "expected ')' in callsite location"))
+ return failure();
+
+ // Return the callsite location.
+ loc = CallSiteLoc::get(calleeLoc, callerLoc);
+ return success();
+}
+
+ParseResult Parser::parseFusedLocation(LocationAttr &loc) {
+ consumeToken(Token::bare_identifier);
+
+ // Try to parse the optional metadata.
+ Attribute metadata;
+ if (consumeIf(Token::less)) {
+ metadata = parseAttribute();
+ if (!metadata)
+ return emitError("expected valid attribute metadata");
+ // Parse the '>' token.
+ if (parseToken(Token::greater,
+ "expected '>' after fused location metadata"))
+ return failure();
+ }
+
+ SmallVector<Location, 4> locations;
+ auto parseElt = [&] {
+ LocationAttr newLoc;
+ if (parseLocationInstance(newLoc))
+ return failure();
+ locations.push_back(newLoc);
+ return success();
+ };
+
+ if (parseToken(Token::l_square, "expected '[' in fused location") ||
+ parseCommaSeparatedList(parseElt) ||
+ parseToken(Token::r_square, "expected ']' in fused location"))
+ return failure();
+
+ // Return the fused location.
+ loc = FusedLoc::get(locations, metadata, getContext());
+ return success();
+}
+
+ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
+ auto *ctx = getContext();
+ auto str = getToken().getStringValue();
+ consumeToken(Token::string);
+
+ // If the next token is ':' this is a filelinecol location.
+ if (consumeIf(Token::colon)) {
+ // Parse the line number.
+ if (getToken().isNot(Token::integer))
+ return emitError("expected integer line number in FileLineColLoc");
+ auto line = getToken().getUnsignedIntegerValue();
+ if (!line.hasValue())
+ return emitError("expected integer line number in FileLineColLoc");
+ consumeToken(Token::integer);
+
+ // Parse the ':'.
+ if (parseToken(Token::colon, "expected ':' in FileLineColLoc"))
+ return failure();
+
+ // Parse the column number.
+ if (getToken().isNot(Token::integer))
+ return emitError("expected integer column number in FileLineColLoc");
+ auto column = getToken().getUnsignedIntegerValue();
+ if (!column.hasValue())
+ return emitError("expected integer column number in FileLineColLoc");
+ consumeToken(Token::integer);
+
+ loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx);
+ return success();
+ }
+
+ // Otherwise, this is a NameLoc.
+
+ // Check for a child location.
+ if (consumeIf(Token::l_paren)) {
+ auto childSourceLoc = getToken().getLoc();
+
+ // Parse the child location.
+ LocationAttr childLoc;
+ if (parseLocationInstance(childLoc))
+ return failure();
+
+ // The child must not be another NameLoc.
+ if (childLoc.isa<NameLoc>())
+ return emitError(childSourceLoc,
+ "child of NameLoc cannot be another NameLoc");
+ loc = NameLoc::get(Identifier::get(str, ctx), childLoc);
+
+ // Parse the closing ')'.
+ if (parseToken(Token::r_paren,
+ "expected ')' after child location of NameLoc"))
+ return failure();
+ } else {
+ loc = NameLoc::get(Identifier::get(str, ctx), ctx);
+ }
+
+ return success();
+}
+
+ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
+ // Handle either name or filelinecol locations.
+ if (getToken().is(Token::string))
+ return parseNameOrFileLineColLocation(loc);
+
+ // Bare tokens required for other cases.
+ if (!getToken().is(Token::bare_identifier))
+ return emitError("expected location instance");
+
+ // Check for the 'callsite' signifying a callsite location.
+ if (getToken().getSpelling() == "callsite")
+ return parseCallSiteLocation(loc);
+
+ // If the token is 'fused', then this is a fused location.
+ if (getToken().getSpelling() == "fused")
+ return parseFusedLocation(loc);
+
+ // Check for a 'unknown' for an unknown location.
+ if (getToken().getSpelling() == "unknown") {
+ consumeToken(Token::bare_identifier);
+ loc = UnknownLoc::get(getContext());
+ return success();
+ }
+
+ return emitError("expected location instance");
+}
+
+//===----------------------------------------------------------------------===//
+// Affine parsing.
+//===----------------------------------------------------------------------===//
+
+/// Lower precedence ops (all at the same precedence level). LNoOp is false in
+/// the boolean sense.
+enum AffineLowPrecOp {
+ /// Null value.
+ LNoOp,
+ Add,
+ Sub
+};
+
+/// Higher precedence ops - all at the same precedence level. HNoOp is false
+/// in the boolean sense.
+enum AffineHighPrecOp {
+ /// Null value.
+ HNoOp,
+ Mul,
+ FloorDiv,
+ CeilDiv,
+ Mod
+};
+
+namespace {
+/// This is a specialized parser for affine structures (affine maps, affine
+/// expressions, and integer sets), maintaining the state transient to their
+/// bodies.
+class AffineParser : public Parser {
+public:
+ AffineParser(ParserState &state, bool allowParsingSSAIds = false,
+ function_ref<ParseResult(bool)> parseElement = nullptr)
+ : Parser(state), allowParsingSSAIds(allowParsingSSAIds),
+ parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {}
+
+ AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols);
+ ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set);
+ IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols);
+ ParseResult parseAffineMapOfSSAIds(AffineMap &map);
+ void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds,
+ unsigned &numDims);
+
+private:
+ // Binary affine op parsing.
+ AffineLowPrecOp consumeIfLowPrecOp();
+ AffineHighPrecOp consumeIfHighPrecOp();
+
+ // Identifier lists for polyhedral structures.
+ ParseResult parseDimIdList(unsigned &numDims);
+ ParseResult parseSymbolIdList(unsigned &numSymbols);
+ ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims,
+ unsigned &numSymbols);
+ ParseResult parseIdentifierDefinition(AffineExpr idExpr);
+
+ AffineExpr parseAffineExpr();
+ AffineExpr parseParentheticalExpr();
+ AffineExpr parseNegateExpression(AffineExpr lhs);
+ AffineExpr parseIntegerExpr();
+ AffineExpr parseBareIdExpr();
+ AffineExpr parseSSAIdExpr(bool isSymbol);
+ AffineExpr parseSymbolSSAIdExpr();
+
+ AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
+ AffineExpr rhs, SMLoc opLoc);
+ AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
+ AffineExpr rhs);
+ AffineExpr parseAffineOperandExpr(AffineExpr lhs);
+ AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
+ AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
+ SMLoc llhsOpLoc);
+ AffineExpr parseAffineConstraint(bool *isEq);
+
+private:
+ bool allowParsingSSAIds;
+ function_ref<ParseResult(bool)> parseElement;
+ unsigned numDimOperands;
+ unsigned numSymbolOperands;
+ SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
+};
+} // end anonymous namespace
+
+/// Create an affine binary high precedence op expression (mul's, div's, mod).
+/// opLoc is the location of the op token to be used to report errors
+/// for non-conforming expressions.
+AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
+ AffineExpr lhs, AffineExpr rhs,
+ SMLoc opLoc) {
+ // TODO: make the error location info accurate.
+ switch (op) {
+ case Mul:
+ if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) {
+ emitError(opLoc, "non-affine expression: at least one of the multiply "
+ "operands has to be either a constant or symbolic");
+ return nullptr;
+ }
+ return lhs * rhs;
+ case FloorDiv:
+ if (!rhs.isSymbolicOrConstant()) {
+ emitError(opLoc, "non-affine expression: right operand of floordiv "
+ "has to be either a constant or symbolic");
+ return nullptr;
+ }
+ return lhs.floorDiv(rhs);
+ case CeilDiv:
+ if (!rhs.isSymbolicOrConstant()) {
+ emitError(opLoc, "non-affine expression: right operand of ceildiv "
+ "has to be either a constant or symbolic");
+ return nullptr;
+ }
+ return lhs.ceilDiv(rhs);
+ case Mod:
+ if (!rhs.isSymbolicOrConstant()) {
+ emitError(opLoc, "non-affine expression: right operand of mod "
+ "has to be either a constant or symbolic");
+ return nullptr;
+ }
+ return lhs % rhs;
+ case HNoOp:
+ llvm_unreachable("can't create affine expression for null high prec op");
+ return nullptr;
+ }
+ llvm_unreachable("Unknown AffineHighPrecOp");
+}
+
+/// Create an affine binary low precedence op expression (add, sub).
+AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
+ AffineExpr lhs, AffineExpr rhs) {
+ switch (op) {
+ case AffineLowPrecOp::Add:
+ return lhs + rhs;
+ case AffineLowPrecOp::Sub:
+ return lhs - rhs;
+ case AffineLowPrecOp::LNoOp:
+ llvm_unreachable("can't create affine expression for null low prec op");
+ return nullptr;
+ }
+ llvm_unreachable("Unknown AffineLowPrecOp");
+}
+
+/// Consume this token if it is a lower precedence affine op (there are only
+/// two precedence levels).
+AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
+ switch (getToken().getKind()) {
+ case Token::plus:
+ consumeToken(Token::plus);
+ return AffineLowPrecOp::Add;
+ case Token::minus:
+ consumeToken(Token::minus);
+ return AffineLowPrecOp::Sub;
+ default:
+ return AffineLowPrecOp::LNoOp;
+ }
+}
+
+/// Consume this token if it is a higher precedence affine op (there are only
+/// two precedence levels)
+AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
+ switch (getToken().getKind()) {
+ case Token::star:
+ consumeToken(Token::star);
+ return Mul;
+ case Token::kw_floordiv:
+ consumeToken(Token::kw_floordiv);
+ return FloorDiv;
+ case Token::kw_ceildiv:
+ consumeToken(Token::kw_ceildiv);
+ return CeilDiv;
+ case Token::kw_mod:
+ consumeToken(Token::kw_mod);
+ return Mod;
+ default:
+ return HNoOp;
+ }
+}
+
+/// Parse a high precedence op expression list: mul, div, and mod are high
+/// precedence binary ops, i.e., parse a
+/// expr_1 op_1 expr_2 op_2 ... expr_n
+/// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod).
+/// All affine binary ops are left associative.
+/// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is
+/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
+/// null. llhsOpLoc is the location of the llhsOp token that will be used to
+/// report an error for non-conforming expressions.
+AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs,
+ AffineHighPrecOp llhsOp,
+ SMLoc llhsOpLoc) {
+ AffineExpr lhs = parseAffineOperandExpr(llhs);
+ if (!lhs)
+ return nullptr;
+
+ // Found an LHS. Parse the remaining expression.
+ auto opLoc = getToken().getLoc();
+ if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
+ if (llhs) {
+ AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc);
+ if (!expr)
+ return nullptr;
+ return parseAffineHighPrecOpExpr(expr, op, opLoc);
+ }
+ // No LLHS, get RHS
+ return parseAffineHighPrecOpExpr(lhs, op, opLoc);
+ }
+
+ // This is the last operand in this expression.
+ if (llhs)
+ return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
+
+ // No llhs, 'lhs' itself is the expression.
+ return lhs;
+}
+
+/// Parse an affine expression inside parentheses.
+///
+/// affine-expr ::= `(` affine-expr `)`
+AffineExpr AffineParser::parseParentheticalExpr() {
+ if (parseToken(Token::l_paren, "expected '('"))
+ return nullptr;
+ if (getToken().is(Token::r_paren))
+ return (emitError("no expression inside parentheses"), nullptr);
+
+ auto expr = parseAffineExpr();
+ if (!expr)
+ return nullptr;
+ if (parseToken(Token::r_paren, "expected ')'"))
+ return nullptr;
+
+ return expr;
+}
+
+/// Parse the negation expression.
+///
+/// affine-expr ::= `-` affine-expr
+AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
+ if (parseToken(Token::minus, "expected '-'"))
+ return nullptr;
+
+ AffineExpr operand = parseAffineOperandExpr(lhs);
+ // Since negation has the highest precedence of all ops (including high
+ // precedence ops) but lower than parentheses, we are only going to use
+ // parseAffineOperandExpr instead of parseAffineExpr here.
+ if (!operand)
+ // Extra error message although parseAffineOperandExpr would have
+ // complained. Leads to a better diagnostic.
+ return (emitError("missing operand of negation"), nullptr);
+ return (-1) * operand;
+}
+
+/// Parse a bare id that may appear in an affine expression.
+///
+/// affine-expr ::= bare-id
+AffineExpr AffineParser::parseBareIdExpr() {
+ if (getToken().isNot(Token::bare_identifier))
+ return (emitError("expected bare identifier"), nullptr);
+
+ StringRef sRef = getTokenSpelling();
+ for (auto entry : dimsAndSymbols) {
+ if (entry.first == sRef) {
+ consumeToken(Token::bare_identifier);
+ return entry.second;
+ }
+ }
+
+ return (emitError("use of undeclared identifier"), nullptr);
+}
+
+/// Parse an SSA id which may appear in an affine expression.
+AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) {
+ if (!allowParsingSSAIds)
+ return (emitError("unexpected ssa identifier"), nullptr);
+ if (getToken().isNot(Token::percent_identifier))
+ return (emitError("expected ssa identifier"), nullptr);
+ auto name = getTokenSpelling();
+ // Check if we already parsed this SSA id.
+ for (auto entry : dimsAndSymbols) {
+ if (entry.first == name) {
+ consumeToken(Token::percent_identifier);
+ return entry.second;
+ }
+ }
+ // Parse the SSA id and add an AffineDim/SymbolExpr to represent it.
+ if (parseElement(isSymbol))
+ return (emitError("failed to parse ssa identifier"), nullptr);
+ auto idExpr = isSymbol
+ ? getAffineSymbolExpr(numSymbolOperands++, getContext())
+ : getAffineDimExpr(numDimOperands++, getContext());
+ dimsAndSymbols.push_back({name, idExpr});
+ return idExpr;
+}
+
+AffineExpr AffineParser::parseSymbolSSAIdExpr() {
+ if (parseToken(Token::kw_symbol, "expected symbol keyword") ||
+ parseToken(Token::l_paren, "expected '(' at start of SSA symbol"))
+ return nullptr;
+ AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true);
+ if (!symbolExpr)
+ return nullptr;
+ if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol"))
+ return nullptr;
+ return symbolExpr;
+}
+
+/// Parse a positive integral constant appearing in an affine expression.
+///
+/// affine-expr ::= integer-literal
+AffineExpr AffineParser::parseIntegerExpr() {
+ auto val = getToken().getUInt64IntegerValue();
+ if (!val.hasValue() || (int64_t)val.getValue() < 0)
+ return (emitError("constant too large for index"), nullptr);
+
+ consumeToken(Token::integer);
+ return builder.getAffineConstantExpr((int64_t)val.getValue());
+}
+
+/// Parses an expression that can be a valid operand of an affine expression.
+/// lhs: if non-null, lhs is an affine expression that is the lhs of a binary
+/// operator, the rhs of which is being parsed. This is used to determine
+/// whether an error should be emitted for a missing right operand.
+// Eg: for an expression without parentheses (like i + j + k + l), each
+// of the four identifiers is an operand. For i + j*k + l, j*k is not an
+// operand expression, it's an op expression and will be parsed via
+// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and
+// -l are valid operands that will be parsed by this function.
+AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
+ switch (getToken().getKind()) {
+ case Token::bare_identifier:
+ return parseBareIdExpr();
+ case Token::kw_symbol:
+ return parseSymbolSSAIdExpr();
+ case Token::percent_identifier:
+ return parseSSAIdExpr(/*isSymbol=*/false);
+ case Token::integer:
+ return parseIntegerExpr();
+ case Token::l_paren:
+ return parseParentheticalExpr();
+ case Token::minus:
+ return parseNegateExpression(lhs);
+ case Token::kw_ceildiv:
+ case Token::kw_floordiv:
+ case Token::kw_mod:
+ case Token::plus:
+ case Token::star:
+ if (lhs)
+ emitError("missing right operand of binary operator");
+ else
+ emitError("missing left operand of binary operator");
+ return nullptr;
+ default:
+ if (lhs)
+ emitError("missing right operand of binary operator");
+ else
+ emitError("expected affine expression");
+ return nullptr;
+ }
+}
+
+/// Parse affine expressions that are bare-id's, integer constants,
+/// parenthetical affine expressions, and affine op expressions that are a
+/// composition of those.
+///
+/// All binary op's associate from left to right.
+///
+/// {add, sub} have lower precedence than {mul, div, and mod}.
+///
+/// Add, sub'are themselves at the same precedence level. Mul, floordiv,
+/// ceildiv, and mod are at the same higher precedence level. Negation has
+/// higher precedence than any binary op.
+///
+/// llhs: the affine expression appearing on the left of the one being parsed.
+/// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null,
+/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned
+/// if llhs is non-null; otherwise lhs is returned. This is to deal with left
+/// associativity.
+///
+/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
+/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
+/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
+AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
+ AffineLowPrecOp llhsOp) {
+ AffineExpr lhs;
+ if (!(lhs = parseAffineOperandExpr(llhs)))
+ return nullptr;
+
+ // Found an LHS. Deal with the ops.
+ if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
+ if (llhs) {
+ AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
+ return parseAffineLowPrecOpExpr(sum, lOp);
+ }
+ // No LLHS, get RHS and form the expression.
+ return parseAffineLowPrecOpExpr(lhs, lOp);
+ }
+ auto opLoc = getToken().getLoc();
+ if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
+ // We have a higher precedence op here. Get the rhs operand for the llhs
+ // through parseAffineHighPrecOpExpr.
+ AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
+ if (!highRes)
+ return nullptr;
+
+ // If llhs is null, the product forms the first operand of the yet to be
+ // found expression. If non-null, the op to associate with llhs is llhsOp.
+ AffineExpr expr =
+ llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;
+
+ // Recurse for subsequent low prec op's after the affine high prec op
+ // expression.
+ if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
+ return parseAffineLowPrecOpExpr(expr, nextOp);
+ return expr;
+ }
+ // Last operand in the expression list.
+ if (llhs)
+ return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
+ // No llhs, 'lhs' itself is the expression.
+ return lhs;
+}
+
+/// Parse an affine expression.
+/// affine-expr ::= `(` affine-expr `)`
+/// | `-` affine-expr
+/// | affine-expr `+` affine-expr
+/// | affine-expr `-` affine-expr
+/// | affine-expr `*` affine-expr
+/// | affine-expr `floordiv` affine-expr
+/// | affine-expr `ceildiv` affine-expr
+/// | affine-expr `mod` affine-expr
+/// | bare-id
+/// | integer-literal
+///
+/// Additional conditions are checked depending on the production. For eg.,
+/// one of the operands for `*` has to be either constant/symbolic; the second
+/// operand for floordiv, ceildiv, and mod has to be a positive integer.
+AffineExpr AffineParser::parseAffineExpr() {
+ return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
+}
+
+/// Parse a dim or symbol from the lists appearing before the actual
+/// expressions of the affine map. Update our state to store the
+/// dimensional/symbolic identifier.
+ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) {
+ if (getToken().isNot(Token::bare_identifier))
+ return emitError("expected bare identifier");
+
+ auto name = getTokenSpelling();
+ for (auto entry : dimsAndSymbols) {
+ if (entry.first == name)
+ return emitError("redefinition of identifier '" + name + "'");
+ }
+ consumeToken(Token::bare_identifier);
+
+ dimsAndSymbols.push_back({name, idExpr});
+ return success();
+}
+
+/// Parse the list of dimensional identifiers to an affine map.
+ParseResult AffineParser::parseDimIdList(unsigned &numDims) {
+ if (parseToken(Token::l_paren,
+ "expected '(' at start of dimensional identifiers list")) {
+ return failure();
+ }
+
+ auto parseElt = [&]() -> ParseResult {
+ auto dimension = getAffineDimExpr(numDims++, getContext());
+ return parseIdentifierDefinition(dimension);
+ };
+ return parseCommaSeparatedListUntil(Token::r_paren, parseElt);
+}
+
+/// Parse the list of symbolic identifiers to an affine map.
+ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) {
+ consumeToken(Token::l_square);
+ auto parseElt = [&]() -> ParseResult {
+ auto symbol = getAffineSymbolExpr(numSymbols++, getContext());
+ return parseIdentifierDefinition(symbol);
+ };
+ return parseCommaSeparatedListUntil(Token::r_square, parseElt);
+}
+
+/// Parse the list of symbolic identifiers to an affine map.
+ParseResult
+AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims,
+ unsigned &numSymbols) {
+ if (parseDimIdList(numDims)) {
+ return failure();
+ }
+ if (!getToken().is(Token::l_square)) {
+ numSymbols = 0;
+ return success();
+ }
+ return parseSymbolIdList(numSymbols);
+}
+
+/// Parses an ambiguous affine map or integer set definition inline.
+ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map,
+ IntegerSet &set) {
+ unsigned numDims = 0, numSymbols = 0;
+
+ // List of dimensional and optional symbol identifiers.
+ if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) {
+ return failure();
+ }
+
+ // This is needed for parsing attributes as we wouldn't know whether we would
+ // be parsing an integer set attribute or an affine map attribute.
+ bool isArrow = getToken().is(Token::arrow);
+ bool isColon = getToken().is(Token::colon);
+ if (!isArrow && !isColon) {
+ return emitError("expected '->' or ':'");
+ } else if (isArrow) {
+ parseToken(Token::arrow, "expected '->' or '['");
+ map = parseAffineMapRange(numDims, numSymbols);
+ return map ? success() : failure();
+ } else if (parseToken(Token::colon, "expected ':' or '['")) {
+ return failure();
+ }
+
+ if ((set = parseIntegerSetConstraints(numDims, numSymbols)))
+ return success();
+
+ return failure();
+}
+
+/// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
+ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) {
+ if (parseToken(Token::l_square, "expected '['"))
+ return failure();
+
+ SmallVector<AffineExpr, 4> exprs;
+ auto parseElt = [&]() -> ParseResult {
+ auto elt = parseAffineExpr();
+ exprs.push_back(elt);
+ return elt ? success() : failure();
+ };
+
+ // Parse a multi-dimensional affine expression (a comma-separated list of
+ // 1-d affine expressions); the list cannot be empty. Grammar:
+ // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
+ if (parseCommaSeparatedListUntil(Token::r_square, parseElt,
+ /*allowEmptyList=*/true))
+ return failure();
+ // Parsed a valid affine map.
+ if (exprs.empty())
+ map = AffineMap::get(getContext());
+ else
+ map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands,
+ exprs);
+ return success();
+}
+
+/// Parse the range and sizes affine map definition inline.
+///
+/// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
+///
+/// multi-dim-affine-expr ::= `(` `)`
+/// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)`
+AffineMap AffineParser::parseAffineMapRange(unsigned numDims,
+ unsigned numSymbols) {
+ parseToken(Token::l_paren, "expected '(' at start of affine map range");
+
+ SmallVector<AffineExpr, 4> exprs;
+ auto parseElt = [&]() -> ParseResult {
+ auto elt = parseAffineExpr();
+ ParseResult res = elt ? success() : failure();
+ exprs.push_back(elt);
+ return res;
+ };
+
+ // Parse a multi-dimensional affine expression (a comma-separated list of
+ // 1-d affine expressions); the list cannot be empty. Grammar:
+ // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
+ if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
+ return AffineMap();
+
+ if (exprs.empty())
+ return AffineMap::get(getContext());
+
+ // Parsed a valid affine map.
+ return AffineMap::get(numDims, numSymbols, exprs);
+}
+
+/// Parse an affine constraint.
+/// affine-constraint ::= affine-expr `>=` `0`
+/// | affine-expr `==` `0`
+///
+/// isEq is set to true if the parsed constraint is an equality, false if it
+/// is an inequality (greater than or equal).
+///
+AffineExpr AffineParser::parseAffineConstraint(bool *isEq) {
+ AffineExpr expr = parseAffineExpr();
+ if (!expr)
+ return nullptr;
+
+ if (consumeIf(Token::greater) && consumeIf(Token::equal) &&
+ getToken().is(Token::integer)) {
+ auto dim = getToken().getUnsignedIntegerValue();
+ if (dim.hasValue() && dim.getValue() == 0) {
+ consumeToken(Token::integer);
+ *isEq = false;
+ return expr;
+ }
+ return (emitError("expected '0' after '>='"), nullptr);
+ }
+
+ if (consumeIf(Token::equal) && consumeIf(Token::equal) &&
+ getToken().is(Token::integer)) {
+ auto dim = getToken().getUnsignedIntegerValue();
+ if (dim.hasValue() && dim.getValue() == 0) {
+ consumeToken(Token::integer);
+ *isEq = true;
+ return expr;
+ }
+ return (emitError("expected '0' after '=='"), nullptr);
+ }
+
+ return (emitError("expected '== 0' or '>= 0' at end of affine constraint"),
+ nullptr);
+}
+
+/// Parse the constraints that are part of an integer set definition.
+/// integer-set-inline
+/// ::= dim-and-symbol-id-lists `:`
+/// '(' affine-constraint-conjunction? ')'
+/// affine-constraint-conjunction ::= affine-constraint (`,`
+/// affine-constraint)*
+///
+IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
+ unsigned numSymbols) {
+ if (parseToken(Token::l_paren,
+ "expected '(' at start of integer set constraint list"))
+ return IntegerSet();
+
+ SmallVector<AffineExpr, 4> constraints;
+ SmallVector<bool, 4> isEqs;
+ auto parseElt = [&]() -> ParseResult {
+ bool isEq;
+ auto elt = parseAffineConstraint(&isEq);
+ ParseResult res = elt ? success() : failure();
+ if (elt) {
+ constraints.push_back(elt);
+ isEqs.push_back(isEq);
+ }
+ return res;
+ };
+
+ // Parse a list of affine constraints (comma-separated).
+ if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
+ return IntegerSet();
+
+ // If no constraints were parsed, then treat this as a degenerate 'true' case.
+ if (constraints.empty()) {
+ /* 0 == 0 */
+ auto zero = getAffineConstantExpr(0, getContext());
+ return IntegerSet::get(numDims, numSymbols, zero, true);
+ }
+
+ // Parsed a valid integer set.
+ return IntegerSet::get(numDims, numSymbols, constraints, isEqs);
+}
+
+/// Parse an ambiguous reference to either and affine map or an integer set.
+ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map,
+ IntegerSet &set) {
+ return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set);
+}
+
+/// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to
+/// parse SSA value uses encountered while parsing affine expressions.
+ParseResult
+Parser::parseAffineMapOfSSAIds(AffineMap &map,
+ function_ref<ParseResult(bool)> parseElement) {
+ return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement)
+ .parseAffineMapOfSSAIds(map);
+}
+
+//===----------------------------------------------------------------------===//
+// OperationParser
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class provides support for parsing operations and regions of
+/// operations.
+class OperationParser : public Parser {
+public:
+ OperationParser(ParserState &state, ModuleOp moduleOp)
+ : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) {
+ }
+
+ ~OperationParser();
+
+ /// After parsing is finished, this function must be called to see if there
+ /// are any remaining issues.
+ ParseResult finalize();
+
+ //===--------------------------------------------------------------------===//
+ // SSA Value Handling
+ //===--------------------------------------------------------------------===//
+
+ /// This represents a use of an SSA value in the program. The first two
+ /// entries in the tuple are the name and result number of a reference. The
+ /// third is the location of the reference, which is used in case this ends
+ /// up being a use of an undefined value.
+ struct SSAUseInfo {
+ StringRef name; // Value name, e.g. %42 or %abc
+ unsigned number; // Number, specified with #12
+ SMLoc loc; // Location of first definition or use.
+ };
+
+ /// Push a new SSA name scope to the parser.
+ void pushSSANameScope(bool isIsolated);
+
+ /// Pop the last SSA name scope from the parser.
+ ParseResult popSSANameScope();
+
+ /// Register a definition of a value with the symbol table.
+ ParseResult addDefinition(SSAUseInfo useInfo, Value value);
+
+ /// Parse an optional list of SSA uses into 'results'.
+ ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results);
+
+ /// Parse a single SSA use into 'result'.
+ ParseResult parseSSAUse(SSAUseInfo &result);
+
+ /// Given a reference to an SSA value and its type, return a reference. This
+ /// returns null on failure.
+ Value resolveSSAUse(SSAUseInfo useInfo, Type type);
+
+ ParseResult parseSSADefOrUseAndType(
+ const std::function<ParseResult(SSAUseInfo, Type)> &action);
+
+ ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value> &results);
+
+ /// Return the location of the value identified by its name and number if it
+ /// has been already reference.
+ Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) {
+ auto &values = isolatedNameScopes.back().values;
+ if (!values.count(name) || number >= values[name].size())
+ return {};
+ if (values[name][number].first)
+ return values[name][number].second;
+ return {};
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Operation Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse an operation instance.
+ ParseResult parseOperation();
+
+ /// Parse a single operation successor and its operand list.
+ ParseResult parseSuccessorAndUseList(Block *&dest,
+ SmallVectorImpl<Value> &operands);
+
+ /// Parse a comma-separated list of operation successors in brackets.
+ ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations,
+ SmallVectorImpl<SmallVector<Value, 4>> &operands);
+
+ /// Parse an operation instance that is in the generic form.
+ Operation *parseGenericOperation();
+
+ /// Parse an operation instance that is in the generic form and insert it at
+ /// the provided insertion point.
+ Operation *parseGenericOperation(Block *insertBlock,
+ Block::iterator insertPt);
+
+ /// Parse an operation instance that is in the op-defined custom form.
+ Operation *parseCustomOperation();
+
+ //===--------------------------------------------------------------------===//
+ // Region Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse a region into 'region' with the provided entry block arguments.
+ /// 'isIsolatedNameScope' indicates if the naming scope of this region is
+ /// isolated from those above.
+ ParseResult parseRegion(Region &region,
+ ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments,
+ bool isIsolatedNameScope = false);
+
+ /// Parse a region body into 'region'.
+ ParseResult parseRegionBody(Region &region);
+
+ //===--------------------------------------------------------------------===//
+ // Block Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse a new block into 'block'.
+ ParseResult parseBlock(Block *&block);
+
+ /// Parse a list of operations into 'block'.
+ ParseResult parseBlockBody(Block *block);
+
+ /// Parse a (possibly empty) list of block arguments.
+ ParseResult parseOptionalBlockArgList(SmallVectorImpl<BlockArgument> &results,
+ Block *owner);
+
+ /// Get the block with the specified name, creating it if it doesn't
+ /// already exist. The location specified is the point of use, which allows
+ /// us to diagnose references to blocks that are not defined precisely.
+ Block *getBlockNamed(StringRef name, SMLoc loc);
+
+ /// Define the block with the specified name. Returns the Block* or nullptr in
+ /// the case of redefinition.
+ Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing);
+
+private:
+ /// Returns the info for a block at the current scope for the given name.
+ std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) {
+ return blocksByName.back()[name];
+ }
+
+ /// Insert a new forward reference to the given block.
+ void insertForwardRef(Block *block, SMLoc loc) {
+ forwardRef.back().try_emplace(block, loc);
+ }
+
+ /// Erase any forward reference to the given block.
+ bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); }
+
+ /// Record that a definition was added at the current scope.
+ void recordDefinition(StringRef def);
+
+ /// Get the value entry for the given SSA name.
+ SmallVectorImpl<std::pair<Value, SMLoc>> &getSSAValueEntry(StringRef name);
+
+ /// Create a forward reference placeholder value with the given location and
+ /// result type.
+ Value createForwardRefPlaceholder(SMLoc loc, Type type);
+
+ /// Return true if this is a forward reference.
+ bool isForwardRefPlaceholder(Value value) {
+ return forwardRefPlaceholders.count(value);
+ }
+
+ /// This struct represents an isolated SSA name scope. This scope may contain
+ /// other nested non-isolated scopes. These scopes are used for operations
+ /// that are known to be isolated to allow for reusing names within their
+ /// regions, even if those names are used above.
+ struct IsolatedSSANameScope {
+ /// Record that a definition was added at the current scope.
+ void recordDefinition(StringRef def) {
+ definitionsPerScope.back().insert(def);
+ }
+
+ /// Push a nested name scope.
+ void pushSSANameScope() { definitionsPerScope.push_back({}); }
+
+ /// Pop a nested name scope.
+ void popSSANameScope() {
+ for (auto &def : definitionsPerScope.pop_back_val())
+ values.erase(def.getKey());
+ }
+
+ /// This keeps track of all of the SSA values we are tracking for each name
+ /// scope, indexed by their name. This has one entry per result number.
+ llvm::StringMap<SmallVector<std::pair<Value, SMLoc>, 1>> values;
+
+ /// This keeps track of all of the values defined by a specific name scope.
+ SmallVector<llvm::StringSet<>, 2> definitionsPerScope;
+ };
+
+ /// A list of isolated name scopes.
+ SmallVector<IsolatedSSANameScope, 2> isolatedNameScopes;
+
+ /// This keeps track of the block names as well as the location of the first
+ /// reference for each nested name scope. This is used to diagnose invalid
+ /// block references and memorize them.
+ SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName;
+ SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef;
+
+ /// These are all of the placeholders we've made along with the location of
+ /// their first reference, to allow checking for use of undefined values.
+ DenseMap<Value, SMLoc> forwardRefPlaceholders;
+
+ /// The builder used when creating parsed operation instances.
+ OpBuilder opBuilder;
+
+ /// The top level module operation.
+ ModuleOp moduleOp;
+};
+} // end anonymous namespace
+
+OperationParser::~OperationParser() {
+ for (auto &fwd : forwardRefPlaceholders) {
+ // Drop all uses of undefined forward declared reference and destroy
+ // defining operation.
+ fwd.first->dropAllUses();
+ fwd.first->getDefiningOp()->destroy();
+ }
+}
+
+/// After parsing is finished, this function must be called to see if there are
+/// any remaining issues.
+ParseResult OperationParser::finalize() {
+ // Check for any forward references that are left. If we find any, error
+ // out.
+ if (!forwardRefPlaceholders.empty()) {
+ SmallVector<std::pair<const char *, Value>, 4> errors;
+ // Iteration over the map isn't deterministic, so sort by source location.
+ for (auto entry : forwardRefPlaceholders)
+ errors.push_back({entry.second.getPointer(), entry.first});
+ llvm::array_pod_sort(errors.begin(), errors.end());
+
+ for (auto entry : errors) {
+ auto loc = SMLoc::getFromPointer(entry.first);
+ emitError(loc, "use of undeclared SSA value name");
+ }
+ return failure();
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SSA Value Handling
+//===----------------------------------------------------------------------===//
+
+void OperationParser::pushSSANameScope(bool isIsolated) {
+ blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>());
+ forwardRef.push_back(DenseMap<Block *, SMLoc>());
+
+ // Push back a new name definition scope.
+ if (isIsolated)
+ isolatedNameScopes.push_back({});
+ isolatedNameScopes.back().pushSSANameScope();
+}
+
+ParseResult OperationParser::popSSANameScope() {
+ auto forwardRefInCurrentScope = forwardRef.pop_back_val();
+
+ // Verify that all referenced blocks were defined.
+ if (!forwardRefInCurrentScope.empty()) {
+ SmallVector<std::pair<const char *, Block *>, 4> errors;
+ // Iteration over the map isn't deterministic, so sort by source location.
+ for (auto entry : forwardRefInCurrentScope) {
+ errors.push_back({entry.second.getPointer(), entry.first});
+ // Add this block to the top-level region to allow for automatic cleanup.
+ moduleOp.getOperation()->getRegion(0).push_back(entry.first);
+ }
+ llvm::array_pod_sort(errors.begin(), errors.end());
+
+ for (auto entry : errors) {
+ auto loc = SMLoc::getFromPointer(entry.first);
+ emitError(loc, "reference to an undefined block");
+ }
+ return failure();
+ }
+
+ // Pop the next nested namescope. If there is only one internal namescope,
+ // just pop the isolated scope.
+ auto &currentNameScope = isolatedNameScopes.back();
+ if (currentNameScope.definitionsPerScope.size() == 1)
+ isolatedNameScopes.pop_back();
+ else
+ currentNameScope.popSSANameScope();
+
+ blocksByName.pop_back();
+ return success();
+}
+
+/// Register a definition of a value with the symbol table.
+ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value value) {
+ auto &entries = getSSAValueEntry(useInfo.name);
+
+ // Make sure there is a slot for this value.
+ if (entries.size() <= useInfo.number)
+ entries.resize(useInfo.number + 1);
+
+ // If we already have an entry for this, check to see if it was a definition
+ // or a forward reference.
+ if (auto existing = entries[useInfo.number].first) {
+ if (!isForwardRefPlaceholder(existing)) {
+ return emitError(useInfo.loc)
+ .append("redefinition of SSA value '", useInfo.name, "'")
+ .attachNote(getEncodedSourceLocation(entries[useInfo.number].second))
+ .append("previously defined here");
+ }
+
+ // If it was a forward reference, update everything that used it to use
+ // the actual definition instead, delete the forward ref, and remove it
+ // from our set of forward references we track.
+ existing->replaceAllUsesWith(value);
+ existing->getDefiningOp()->destroy();
+ forwardRefPlaceholders.erase(existing);
+ }
+
+ /// Record this definition for the current scope.
+ entries[useInfo.number] = {value, useInfo.loc};
+ recordDefinition(useInfo.name);
+ return success();
+}
+
+/// Parse a (possibly empty) list of SSA operands.
+///
+/// ssa-use-list ::= ssa-use (`,` ssa-use)*
+/// ssa-use-list-opt ::= ssa-use-list?
+///
+ParseResult
+OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
+ if (getToken().isNot(Token::percent_identifier))
+ return success();
+ return parseCommaSeparatedList([&]() -> ParseResult {
+ SSAUseInfo result;
+ if (parseSSAUse(result))
+ return failure();
+ results.push_back(result);
+ return success();
+ });
+}
+
+/// Parse a SSA operand for an operation.
+///
+/// ssa-use ::= ssa-id
+///
+ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) {
+ result.name = getTokenSpelling();
+ result.number = 0;
+ result.loc = getToken().getLoc();
+ if (parseToken(Token::percent_identifier, "expected SSA operand"))
+ return failure();
+
+ // If we have an attribute ID, it is a result number.
+ if (getToken().is(Token::hash_identifier)) {
+ if (auto value = getToken().getHashIdentifierNumber())
+ result.number = value.getValue();
+ else
+ return emitError("invalid SSA value result number");
+ consumeToken(Token::hash_identifier);
+ }
+
+ return success();
+}
+
+/// Given an unbound reference to an SSA value and its type, return the value
+/// it specifies. This returns null on failure.
+Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
+ auto &entries = getSSAValueEntry(useInfo.name);
+
+ // If we have already seen a value of this name, return it.
+ if (useInfo.number < entries.size() && entries[useInfo.number].first) {
+ auto result = entries[useInfo.number].first;
+ // Check that the type matches the other uses.
+ if (result->getType() == type)
+ return result;
+
+ emitError(useInfo.loc, "use of value '")
+ .append(useInfo.name,
+ "' expects different type than prior uses: ", type, " vs ",
+ result->getType())
+ .attachNote(getEncodedSourceLocation(entries[useInfo.number].second))
+ .append("prior use here");
+ return nullptr;
+ }
+
+ // Make sure we have enough slots for this.
+ if (entries.size() <= useInfo.number)
+ entries.resize(useInfo.number + 1);
+
+ // If the value has already been defined and this is an overly large result
+ // number, diagnose that.
+ if (entries[0].first && !isForwardRefPlaceholder(entries[0].first))
+ return (emitError(useInfo.loc, "reference to invalid result number"),
+ nullptr);
+
+ // Otherwise, this is a forward reference. Create a placeholder and remember
+ // that we did so.
+ auto result = createForwardRefPlaceholder(useInfo.loc, type);
+ entries[useInfo.number].first = result;
+ entries[useInfo.number].second = useInfo.loc;
+ return result;
+}
+
+/// Parse an SSA use with an associated type.
+///
+/// ssa-use-and-type ::= ssa-use `:` type
+ParseResult OperationParser::parseSSADefOrUseAndType(
+ const std::function<ParseResult(SSAUseInfo, Type)> &action) {
+ SSAUseInfo useInfo;
+ if (parseSSAUse(useInfo) ||
+ parseToken(Token::colon, "expected ':' and type for SSA operand"))
+ return failure();
+
+ auto type = parseType();
+ if (!type)
+ return failure();
+
+ return action(useInfo, type);
+}
+
+/// Parse a (possibly empty) list of SSA operands, followed by a colon, then
+/// followed by a type list.
+///
+/// ssa-use-and-type-list
+/// ::= ssa-use-list ':' type-list-no-parens
+///
+ParseResult OperationParser::parseOptionalSSAUseAndTypeList(
+ SmallVectorImpl<Value> &results) {
+ SmallVector<SSAUseInfo, 4> valueIDs;
+ if (parseOptionalSSAUseList(valueIDs))
+ return failure();
+
+ // If there were no operands, then there is no colon or type lists.
+ if (valueIDs.empty())
+ return success();
+
+ SmallVector<Type, 4> types;
+ if (parseToken(Token::colon, "expected ':' in operand list") ||
+ parseTypeListNoParens(types))
+ return failure();
+
+ if (valueIDs.size() != types.size())
+ return emitError("expected ")
+ << valueIDs.size() << " types to match operand list";
+
+ results.reserve(valueIDs.size());
+ for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) {
+ if (auto value = resolveSSAUse(valueIDs[i], types[i]))
+ results.push_back(value);
+ else
+ return failure();
+ }
+
+ return success();
+}
+
+/// Record that a definition was added at the current scope.
+void OperationParser::recordDefinition(StringRef def) {
+ isolatedNameScopes.back().recordDefinition(def);
+}
+
+/// Get the value entry for the given SSA name.
+SmallVectorImpl<std::pair<Value, SMLoc>> &
+OperationParser::getSSAValueEntry(StringRef name) {
+ return isolatedNameScopes.back().values[name];
+}
+
+/// Create and remember a new placeholder for a forward reference.
+Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
+ // Forward references are always created as operations, because we just need
+ // something with a def/use chain.
+ //
+ // We create these placeholders as having an empty name, which we know
+ // cannot be created through normal user input, allowing us to distinguish
+ // them.
+ auto name = OperationName("placeholder", getContext());
+ auto *op = Operation::create(
+ getEncodedSourceLocation(loc), name, type, /*operands=*/{},
+ /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
+ /*resizableOperandList=*/false);
+ forwardRefPlaceholders[op->getResult(0)] = loc;
+ return op->getResult(0);
+}
+
+//===----------------------------------------------------------------------===//
+// Operation Parsing
+//===----------------------------------------------------------------------===//
+
+/// Parse an operation.
+///
+/// operation ::= op-result-list?
+/// (generic-operation | custom-operation)
+/// trailing-location?
+/// generic-operation ::= string-literal '(' ssa-use-list? ')' attribute-dict?
+/// `:` function-type
+/// custom-operation ::= bare-id custom-operation-format
+/// op-result-list ::= op-result (`,` op-result)* `=`
+/// op-result ::= ssa-id (`:` integer-literal)
+///
+ParseResult OperationParser::parseOperation() {
+ auto loc = getToken().getLoc();
+ SmallVector<std::tuple<StringRef, unsigned, SMLoc>, 1> resultIDs;
+ size_t numExpectedResults = 0;
+ if (getToken().is(Token::percent_identifier)) {
+ // Parse the group of result ids.
+ auto parseNextResult = [&]() -> ParseResult {
+ // Parse the next result id.
+ if (!getToken().is(Token::percent_identifier))
+ return emitError("expected valid ssa identifier");
+
+ Token nameTok = getToken();
+ consumeToken(Token::percent_identifier);
+
+ // If the next token is a ':', we parse the expected result count.
+ size_t expectedSubResults = 1;
+ if (consumeIf(Token::colon)) {
+ // Check that the next token is an integer.
+ if (!getToken().is(Token::integer))
+ return emitError("expected integer number of results");
+
+ // Check that number of results is > 0.
+ auto val = getToken().getUInt64IntegerValue();
+ if (!val.hasValue() || val.getValue() < 1)
+ return emitError("expected named operation to have atleast 1 result");
+ consumeToken(Token::integer);
+ expectedSubResults = *val;
+ }
+
+ resultIDs.emplace_back(nameTok.getSpelling(), expectedSubResults,
+ nameTok.getLoc());
+ numExpectedResults += expectedSubResults;
+ return success();
+ };
+ if (parseCommaSeparatedList(parseNextResult))
+ return failure();
+
+ if (parseToken(Token::equal, "expected '=' after SSA name"))
+ return failure();
+ }
+
+ Operation *op;
+ if (getToken().is(Token::bare_identifier) || getToken().isKeyword())
+ op = parseCustomOperation();
+ else if (getToken().is(Token::string))
+ op = parseGenericOperation();
+ else
+ return emitError("expected operation name in quotes");
+
+ // If parsing of the basic operation failed, then this whole thing fails.
+ if (!op)
+ return failure();
+
+ // If the operation had a name, register it.
+ if (!resultIDs.empty()) {
+ if (op->getNumResults() == 0)
+ return emitError(loc, "cannot name an operation with no results");
+ if (numExpectedResults != op->getNumResults())
+ return emitError(loc, "operation defines ")
+ << op->getNumResults() << " results but was provided "
+ << numExpectedResults << " to bind";
+
+ // Add definitions for each of the result groups.
+ unsigned opResI = 0;
+ for (std::tuple<StringRef, unsigned, SMLoc> &resIt : resultIDs) {
+ for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
+ if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)},
+ op->getResult(opResI++)))
+ return failure();
+ }
+ }
+ }
+
+ return success();
+}
+
+/// Parse a single operation successor and its operand list.
+///
+/// successor ::= block-id branch-use-list?
+/// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)`
+///
+ParseResult
+OperationParser::parseSuccessorAndUseList(Block *&dest,
+ SmallVectorImpl<Value> &operands) {
+ // Verify branch is identifier and get the matching block.
+ if (!getToken().is(Token::caret_identifier))
+ return emitError("expected block name");
+ dest = getBlockNamed(getTokenSpelling(), getToken().getLoc());
+ consumeToken();
+
+ // Handle optional arguments.
+ if (consumeIf(Token::l_paren) &&
+ (parseOptionalSSAUseAndTypeList(operands) ||
+ parseToken(Token::r_paren, "expected ')' to close argument list"))) {
+ return failure();
+ }
+
+ return success();
+}
+
+/// Parse a comma-separated list of operation successors in brackets.
+///
+/// successor-list ::= `[` successor (`,` successor )* `]`
+///
+ParseResult OperationParser::parseSuccessors(
+ SmallVectorImpl<Block *> &destinations,
+ SmallVectorImpl<SmallVector<Value, 4>> &operands) {
+ if (parseToken(Token::l_square, "expected '['"))
+ return failure();
+
+ auto parseElt = [this, &destinations, &operands]() {
+ Block *dest;
+ SmallVector<Value, 4> destOperands;
+ auto res = parseSuccessorAndUseList(dest, destOperands);
+ destinations.push_back(dest);
+ operands.push_back(destOperands);
+ return res;
+ };
+ return parseCommaSeparatedListUntil(Token::r_square, parseElt,
+ /*allowEmptyList=*/false);
+}
+
+namespace {
+// RAII-style guard for cleaning up the regions in the operation state before
+// deleting them. Within the parser, regions may get deleted if parsing failed,
+// and other errors may be present, in particular undominated uses. This makes
+// sure such uses are deleted.
+struct CleanupOpStateRegions {
+ ~CleanupOpStateRegions() {
+ SmallVector<Region *, 4> regionsToClean;
+ regionsToClean.reserve(state.regions.size());
+ for (auto &region : state.regions)
+ if (region)
+ for (auto &block : *region)
+ block.dropAllDefinedValueUses();
+ }
+ OperationState &state;
+};
+} // namespace
+
+Operation *OperationParser::parseGenericOperation() {
+ // Get location information for the operation.
+ auto srcLocation = getEncodedSourceLocation(getToken().getLoc());
+
+ auto name = getToken().getStringValue();
+ if (name.empty())
+ return (emitError("empty operation name is invalid"), nullptr);
+ if (name.find('\0') != StringRef::npos)
+ return (emitError("null character not allowed in operation name"), nullptr);
+
+ consumeToken(Token::string);
+
+ OperationState result(srcLocation, name);
+
+ // Generic operations have a resizable operation list.
+ result.setOperandListToResizable();
+
+ // Parse the operand list.
+ SmallVector<SSAUseInfo, 8> operandInfos;
+
+ if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
+ parseOptionalSSAUseList(operandInfos) ||
+ parseToken(Token::r_paren, "expected ')' to end operand list")) {
+ return nullptr;
+ }
+
+ // Parse the successor list but don't add successors to the result yet to
+ // avoid messing up with the argument order.
+ SmallVector<Block *, 2> successors;
+ SmallVector<SmallVector<Value, 4>, 2> successorOperands;
+ if (getToken().is(Token::l_square)) {
+ // Check if the operation is a known terminator.
+ const AbstractOperation *abstractOp = result.name.getAbstractOperation();
+ if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator))
+ return emitError("successors in non-terminator"), nullptr;
+ if (parseSuccessors(successors, successorOperands))
+ return nullptr;
+ }
+
+ // Parse the region list.
+ CleanupOpStateRegions guard{result};
+ if (consumeIf(Token::l_paren)) {
+ do {
+ // Create temporary regions with the top level region as parent.
+ result.regions.emplace_back(new Region(moduleOp));
+ if (parseRegion(*result.regions.back(), /*entryArguments=*/{}))
+ return nullptr;
+ } while (consumeIf(Token::comma));
+ if (parseToken(Token::r_paren, "expected ')' to end region list"))
+ return nullptr;
+ }
+
+ if (getToken().is(Token::l_brace)) {
+ if (parseAttributeDict(result.attributes))
+ return nullptr;
+ }
+
+ if (parseToken(Token::colon, "expected ':' followed by operation type"))
+ return nullptr;
+
+ auto typeLoc = getToken().getLoc();
+ auto type = parseType();
+ if (!type)
+ return nullptr;
+ auto fnType = type.dyn_cast<FunctionType>();
+ if (!fnType)
+ return (emitError(typeLoc, "expected function type"), nullptr);
+
+ result.addTypes(fnType.getResults());
+
+ // Check that we have the right number of types for the operands.
+ auto operandTypes = fnType.getInputs();
+ if (operandTypes.size() != operandInfos.size()) {
+ auto plural = "s"[operandInfos.size() == 1];
+ return (emitError(typeLoc, "expected ")
+ << operandInfos.size() << " operand type" << plural
+ << " but had " << operandTypes.size(),
+ nullptr);
+ }
+
+ // Resolve all of the operands.
+ for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
+ result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
+ if (!result.operands.back())
+ return nullptr;
+ }
+
+ // Add the successors, and their operands after the proper operands.
+ for (const auto &succ : llvm::zip(successors, successorOperands)) {
+ Block *successor = std::get<0>(succ);
+ const SmallVector<Value, 4> &operands = std::get<1>(succ);
+ result.addSuccessor(successor, operands);
+ }
+
+ // Parse a location if one is present.
+ if (parseOptionalTrailingLocation(result.location))
+ return nullptr;
+
+ return opBuilder.createOperation(result);
+}
+
+Operation *OperationParser::parseGenericOperation(Block *insertBlock,
+ Block::iterator insertPt) {
+ OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder);
+ opBuilder.setInsertionPoint(insertBlock, insertPt);
+ return parseGenericOperation();
+}
+
+namespace {
+class CustomOpAsmParser : public OpAsmParser {
+public:
+ CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition,
+ OperationParser &parser)
+ : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {}
+
+ /// Parse an instance of the operation described by 'opDefinition' into the
+ /// provided operation state.
+ ParseResult parseOperation(OperationState &opState) {
+ if (opDefinition->parseAssembly(*this, opState))
+ return failure();
+ return success();
+ }
+
+ Operation *parseGenericOperation(Block *insertBlock,
+ Block::iterator insertPt) final {
+ return parser.parseGenericOperation(insertBlock, insertPt);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Utilities
+ //===--------------------------------------------------------------------===//
+
+ /// Return if any errors were emitted during parsing.
+ bool didEmitError() const { return emittedError; }
+
+ /// Emit a diagnostic at the specified location and return failure.
+ InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
+ emittedError = true;
+ return parser.emitError(loc, "custom op '" + opDefinition->name + "' " +
+ message);
+ }
+
+ llvm::SMLoc getCurrentLocation() override {
+ return parser.getToken().getLoc();
+ }
+
+ Builder &getBuilder() const override { return parser.builder; }
+
+ llvm::SMLoc getNameLoc() const override { return nameLoc; }
+
+ //===--------------------------------------------------------------------===//
+ // Token Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse a `->` token.
+ ParseResult parseArrow() override {
+ return parser.parseToken(Token::arrow, "expected '->'");
+ }
+
+ /// Parses a `->` if present.
+ ParseResult parseOptionalArrow() override {
+ return success(parser.consumeIf(Token::arrow));
+ }
+
+ /// Parse a `:` token.
+ ParseResult parseColon() override {
+ return parser.parseToken(Token::colon, "expected ':'");
+ }
+
+ /// Parse a `:` token if present.
+ ParseResult parseOptionalColon() override {
+ return success(parser.consumeIf(Token::colon));
+ }
+
+ /// Parse a `,` token.
+ ParseResult parseComma() override {
+ return parser.parseToken(Token::comma, "expected ','");
+ }
+
+ /// Parse a `,` token if present.
+ ParseResult parseOptionalComma() override {
+ return success(parser.consumeIf(Token::comma));
+ }
+
+ /// Parses a `...` if present.
+ ParseResult parseOptionalEllipsis() override {
+ return success(parser.consumeIf(Token::ellipsis));
+ }
+
+ /// Parse a `=` token.
+ ParseResult parseEqual() override {
+ return parser.parseToken(Token::equal, "expected '='");
+ }
+
+ /// Parse a '<' token.
+ ParseResult parseLess() override {
+ return parser.parseToken(Token::less, "expected '<'");
+ }
+
+ /// Parse a '>' token.
+ ParseResult parseGreater() override {
+ return parser.parseToken(Token::greater, "expected '>'");
+ }
+
+ /// Parse a `(` token.
+ ParseResult parseLParen() override {
+ return parser.parseToken(Token::l_paren, "expected '('");
+ }
+
+ /// Parses a '(' if present.
+ ParseResult parseOptionalLParen() override {
+ return success(parser.consumeIf(Token::l_paren));
+ }
+
+ /// Parse a `)` token.
+ ParseResult parseRParen() override {
+ return parser.parseToken(Token::r_paren, "expected ')'");
+ }
+
+ /// Parses a ')' if present.
+ ParseResult parseOptionalRParen() override {
+ return success(parser.consumeIf(Token::r_paren));
+ }
+
+ /// Parse a `[` token.
+ ParseResult parseLSquare() override {
+ return parser.parseToken(Token::l_square, "expected '['");
+ }
+
+ /// Parses a '[' if present.
+ ParseResult parseOptionalLSquare() override {
+ return success(parser.consumeIf(Token::l_square));
+ }
+
+ /// Parse a `]` token.
+ ParseResult parseRSquare() override {
+ return parser.parseToken(Token::r_square, "expected ']'");
+ }
+
+ /// Parses a ']' if present.
+ ParseResult parseOptionalRSquare() override {
+ return success(parser.consumeIf(Token::r_square));
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Attribute Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse an arbitrary attribute of a given type and return it in result. This
+ /// also adds the attribute to the specified attribute list with the specified
+ /// name.
+ ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName,
+ SmallVectorImpl<NamedAttribute> &attrs) override {
+ result = parser.parseAttribute(type);
+ if (!result)
+ return failure();
+
+ attrs.push_back(parser.builder.getNamedAttr(attrName, result));
+ return success();
+ }
+
+ /// Parse a named dictionary into 'result' if it is present.
+ ParseResult
+ parseOptionalAttrDict(SmallVectorImpl<NamedAttribute> &result) override {
+ if (parser.getToken().isNot(Token::l_brace))
+ return success();
+ return parser.parseAttributeDict(result);
+ }
+
+ /// Parse a named dictionary into 'result' if the `attributes` keyword is
+ /// present.
+ ParseResult parseOptionalAttrDictWithKeyword(
+ SmallVectorImpl<NamedAttribute> &result) override {
+ if (failed(parseOptionalKeyword("attributes")))
+ return success();
+ return parser.parseAttributeDict(result);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Identifier Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Returns if the current token corresponds to a keyword.
+ bool isCurrentTokenAKeyword() const {
+ return parser.getToken().is(Token::bare_identifier) ||
+ parser.getToken().isKeyword();
+ }
+
+ /// Parse the given keyword if present.
+ ParseResult parseOptionalKeyword(StringRef keyword) override {
+ // Check that the current token has the same spelling.
+ if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword)
+ return failure();
+ parser.consumeToken();
+ return success();
+ }
+
+ /// Parse a keyword, if present, into 'keyword'.
+ ParseResult parseOptionalKeyword(StringRef *keyword) override {
+ // Check that the current token is a keyword.
+ if (!isCurrentTokenAKeyword())
+ return failure();
+
+ *keyword = parser.getTokenSpelling();
+ parser.consumeToken();
+ return success();
+ }
+
+ /// Parse an optional @-identifier and store it (without the '@' symbol) in a
+ /// string attribute named 'attrName'.
+ ParseResult
+ parseOptionalSymbolName(StringAttr &result, StringRef attrName,
+ SmallVectorImpl<NamedAttribute> &attrs) override {
+ Token atToken = parser.getToken();
+ if (atToken.isNot(Token::at_identifier))
+ return failure();
+
+ result = getBuilder().getStringAttr(extractSymbolReference(atToken));
+ attrs.push_back(getBuilder().getNamedAttr(attrName, result));
+ parser.consumeToken();
+ return success();
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Operand Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse a single operand.
+ ParseResult parseOperand(OperandType &result) override {
+ OperationParser::SSAUseInfo useInfo;
+ if (parser.parseSSAUse(useInfo))
+ return failure();
+
+ result = {useInfo.loc, useInfo.name, useInfo.number};
+ return success();
+ }
+
+ /// Parse zero or more SSA comma-separated operand references with a specified
+ /// surrounding delimiter, and an optional required operand count.
+ ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
+ int requiredOperandCount = -1,
+ Delimiter delimiter = Delimiter::None) override {
+ return parseOperandOrRegionArgList(result, /*isOperandList=*/true,
+ requiredOperandCount, delimiter);
+ }
+
+ /// Parse zero or more SSA comma-separated operand or region arguments with
+ /// optional surrounding delimiter and required operand count.
+ ParseResult
+ parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result,
+ bool isOperandList, int requiredOperandCount = -1,
+ Delimiter delimiter = Delimiter::None) {
+ auto startLoc = parser.getToken().getLoc();
+
+ // Handle delimiters.
+ switch (delimiter) {
+ case Delimiter::None:
+ // Don't check for the absence of a delimiter if the number of operands
+ // is unknown (and hence the operand list could be empty).
+ if (requiredOperandCount == -1)
+ break;
+ // Token already matches an identifier and so can't be a delimiter.
+ if (parser.getToken().is(Token::percent_identifier))
+ break;
+ // Test against known delimiters.
+ if (parser.getToken().is(Token::l_paren) ||
+ parser.getToken().is(Token::l_square))
+ return emitError(startLoc, "unexpected delimiter");
+ return emitError(startLoc, "invalid operand");
+ case Delimiter::OptionalParen:
+ if (parser.getToken().isNot(Token::l_paren))
+ return success();
+ LLVM_FALLTHROUGH;
+ case Delimiter::Paren:
+ if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
+ return failure();
+ break;
+ case Delimiter::OptionalSquare:
+ if (parser.getToken().isNot(Token::l_square))
+ return success();
+ LLVM_FALLTHROUGH;
+ case Delimiter::Square:
+ if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
+ return failure();
+ break;
+ }
+
+ // Check for zero operands.
+ if (parser.getToken().is(Token::percent_identifier)) {
+ do {
+ OperandType operandOrArg;
+ if (isOperandList ? parseOperand(operandOrArg)
+ : parseRegionArgument(operandOrArg))
+ return failure();
+ result.push_back(operandOrArg);
+ } while (parser.consumeIf(Token::comma));
+ }
+
+ // Handle delimiters. If we reach here, the optional delimiters were
+ // present, so we need to parse their closing one.
+ switch (delimiter) {
+ case Delimiter::None:
+ break;
+ case Delimiter::OptionalParen:
+ case Delimiter::Paren:
+ if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
+ return failure();
+ break;
+ case Delimiter::OptionalSquare:
+ case Delimiter::Square:
+ if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
+ return failure();
+ break;
+ }
+
+ if (requiredOperandCount != -1 &&
+ result.size() != static_cast<size_t>(requiredOperandCount))
+ return emitError(startLoc, "expected ")
+ << requiredOperandCount << " operands";
+ return success();
+ }
+
+ /// Parse zero or more trailing SSA comma-separated trailing operand
+ /// references with a specified surrounding delimiter, and an optional
+ /// required operand count. A leading comma is expected before the operands.
+ ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
+ int requiredOperandCount,
+ Delimiter delimiter) override {
+ if (parser.getToken().is(Token::comma)) {
+ parseComma();
+ return parseOperandList(result, requiredOperandCount, delimiter);
+ }
+ if (requiredOperandCount != -1)
+ return emitError(parser.getToken().getLoc(), "expected ")
+ << requiredOperandCount << " operands";
+ return success();
+ }
+
+ /// Resolve an operand to an SSA value, emitting an error on failure.
+ ParseResult resolveOperand(const OperandType &operand, Type type,
+ SmallVectorImpl<Value> &result) override {
+ OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number,
+ operand.location};
+ if (auto value = parser.resolveSSAUse(operandInfo, type)) {
+ result.push_back(value);
+ return success();
+ }
+ return failure();
+ }
+
+ /// Parse an AffineMap of SSA ids.
+ ParseResult
+ parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands,
+ Attribute &mapAttr, StringRef attrName,
+ SmallVectorImpl<NamedAttribute> &attrs) override {
+ SmallVector<OperandType, 2> dimOperands;
+ SmallVector<OperandType, 1> symOperands;
+
+ auto parseElement = [&](bool isSymbol) -> ParseResult {
+ OperandType operand;
+ if (parseOperand(operand))
+ return failure();
+ if (isSymbol)
+ symOperands.push_back(operand);
+ else
+ dimOperands.push_back(operand);
+ return success();
+ };
+
+ AffineMap map;
+ if (parser.parseAffineMapOfSSAIds(map, parseElement))
+ return failure();
+ // Add AffineMap attribute.
+ if (map) {
+ mapAttr = AffineMapAttr::get(map);
+ attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr));
+ }
+
+ // Add dim operands before symbol operands in 'operands'.
+ operands.assign(dimOperands.begin(), dimOperands.end());
+ operands.append(symOperands.begin(), symOperands.end());
+ return success();
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Region Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse a region that takes `arguments` of `argTypes` types. This
+ /// effectively defines the SSA values of `arguments` and assigns their type.
+ ParseResult parseRegion(Region &region, ArrayRef<OperandType> arguments,
+ ArrayRef<Type> argTypes,
+ bool enableNameShadowing) override {
+ assert(arguments.size() == argTypes.size() &&
+ "mismatching number of arguments and types");
+
+ SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2>
+ regionArguments;
+ for (const auto &pair : llvm::zip(arguments, argTypes)) {
+ const OperandType &operand = std::get<0>(pair);
+ Type type = std::get<1>(pair);
+ OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number,
+ operand.location};
+ regionArguments.emplace_back(operandInfo, type);
+ }
+
+ // Try to parse the region.
+ assert((!enableNameShadowing ||
+ opDefinition->hasProperty(OperationProperty::IsolatedFromAbove)) &&
+ "name shadowing is only allowed on isolated regions");
+ if (parser.parseRegion(region, regionArguments, enableNameShadowing))
+ return failure();
+ return success();
+ }
+
+ /// Parses a region if present.
+ ParseResult parseOptionalRegion(Region &region,
+ ArrayRef<OperandType> arguments,
+ ArrayRef<Type> argTypes,
+ bool enableNameShadowing) override {
+ if (parser.getToken().isNot(Token::l_brace))
+ return success();
+ return parseRegion(region, arguments, argTypes, enableNameShadowing);
+ }
+
+ /// Parse a region argument. The type of the argument will be resolved later
+ /// by a call to `parseRegion`.
+ ParseResult parseRegionArgument(OperandType &argument) override {
+ return parseOperand(argument);
+ }
+
+ /// Parse a region argument if present.
+ ParseResult parseOptionalRegionArgument(OperandType &argument) override {
+ if (parser.getToken().isNot(Token::percent_identifier))
+ return success();
+ return parseRegionArgument(argument);
+ }
+
+ ParseResult
+ parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
+ int requiredOperandCount = -1,
+ Delimiter delimiter = Delimiter::None) override {
+ return parseOperandOrRegionArgList(result, /*isOperandList=*/false,
+ requiredOperandCount, delimiter);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Successor Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse a single operation successor and its operand list.
+ ParseResult
+ parseSuccessorAndUseList(Block *&dest,
+ SmallVectorImpl<Value> &operands) override {
+ return parser.parseSuccessorAndUseList(dest, operands);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Type Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse a type.
+ ParseResult parseType(Type &result) override {
+ return failure(!(result = parser.parseType()));
+ }
+
+ /// Parse an optional arrow followed by a type list.
+ ParseResult
+ parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override {
+ if (!parser.consumeIf(Token::arrow))
+ return success();
+ return parser.parseFunctionResultTypes(result);
+ }
+
+ /// Parse a colon followed by a type.
+ ParseResult parseColonType(Type &result) override {
+ return failure(parser.parseToken(Token::colon, "expected ':'") ||
+ !(result = parser.parseType()));
+ }
+
+ /// Parse a colon followed by a type list, which must have at least one type.
+ ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override {
+ if (parser.parseToken(Token::colon, "expected ':'"))
+ return failure();
+ return parser.parseTypeListNoParens(result);
+ }
+
+ /// Parse an optional colon followed by a type list, which if present must
+ /// have at least one type.
+ ParseResult
+ parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override {
+ if (!parser.consumeIf(Token::colon))
+ return success();
+ return parser.parseTypeListNoParens(result);
+ }
+
+private:
+ /// The source location of the operation name.
+ SMLoc nameLoc;
+
+ /// The abstract information of the operation.
+ const AbstractOperation *opDefinition;
+
+ /// The main operation parser.
+ OperationParser &parser;
+
+ /// A flag that indicates if any errors were emitted during parsing.
+ bool emittedError = false;
+};
+} // end anonymous namespace.
+
+Operation *OperationParser::parseCustomOperation() {
+ auto opLoc = getToken().getLoc();
+ auto opName = getTokenSpelling();
+
+ auto *opDefinition = AbstractOperation::lookup(opName, getContext());
+ if (!opDefinition && !opName.contains('.')) {
+ // If the operation name has no namespace prefix we treat it as a standard
+ // operation and prefix it with "std".
+ // TODO: Would it be better to just build a mapping of the registered
+ // operations in the standard dialect?
+ opDefinition =
+ AbstractOperation::lookup(Twine("std." + opName).str(), getContext());
+ }
+
+ if (!opDefinition) {
+ emitError(opLoc) << "custom op '" << opName << "' is unknown";
+ return nullptr;
+ }
+
+ consumeToken();
+
+ // If the custom op parser crashes, produce some indication to help
+ // debugging.
+ std::string opNameStr = opName.str();
+ llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'",
+ opNameStr.c_str());
+
+ // Get location information for the operation.
+ auto srcLocation = getEncodedSourceLocation(opLoc);
+
+ // Have the op implementation take a crack and parsing this.
+ OperationState opState(srcLocation, opDefinition->name);
+ CleanupOpStateRegions guard{opState};
+ CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this);
+ if (opAsmParser.parseOperation(opState))
+ return nullptr;
+
+ // If it emitted an error, we failed.
+ if (opAsmParser.didEmitError())
+ return nullptr;
+
+ // Parse a location if one is present.
+ if (parseOptionalTrailingLocation(opState.location))
+ return nullptr;
+
+ // Otherwise, we succeeded. Use the state it parsed as our op information.
+ return opBuilder.createOperation(opState);
+}
+
+//===----------------------------------------------------------------------===//
+// Region Parsing
+//===----------------------------------------------------------------------===//
+
+/// Region.
+///
+/// region ::= '{' region-body
+///
+ParseResult OperationParser::parseRegion(
+ Region &region,
+ ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments,
+ bool isIsolatedNameScope) {
+ // Parse the '{'.
+ if (parseToken(Token::l_brace, "expected '{' to begin a region"))
+ return failure();
+
+ // Check for an empty region.
+ if (entryArguments.empty() && consumeIf(Token::r_brace))
+ return success();
+ auto currentPt = opBuilder.saveInsertionPoint();
+
+ // Push a new named value scope.
+ pushSSANameScope(isIsolatedNameScope);
+
+ // Parse the first block directly to allow for it to be unnamed.
+ Block *block = new Block();
+
+ // Add arguments to the entry block.
+ if (!entryArguments.empty()) {
+ for (auto &placeholderArgPair : entryArguments) {
+ auto &argInfo = placeholderArgPair.first;
+ // Ensure that the argument was not already defined.
+ if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) {
+ return emitError(argInfo.loc, "region entry argument '" + argInfo.name +
+ "' is already in use")
+ .attachNote(getEncodedSourceLocation(*defLoc))
+ << "previously referenced here";
+ }
+ if (addDefinition(placeholderArgPair.first,
+ block->addArgument(placeholderArgPair.second))) {
+ delete block;
+ return failure();
+ }
+ }
+
+ // If we had named arguments, then don't allow a block name.
+ if (getToken().is(Token::caret_identifier))
+ return emitError("invalid block name in region with named arguments");
+ }
+
+ if (parseBlock(block)) {
+ delete block;
+ return failure();
+ }
+
+ // Verify that no other arguments were parsed.
+ if (!entryArguments.empty() &&
+ block->getNumArguments() > entryArguments.size()) {
+ delete block;
+ return emitError("entry block arguments were already defined");
+ }
+
+ // Parse the rest of the region.
+ region.push_back(block);
+ if (parseRegionBody(region))
+ return failure();
+
+ // Pop the SSA value scope for this region.
+ if (popSSANameScope())
+ return failure();
+
+ // Reset the original insertion point.
+ opBuilder.restoreInsertionPoint(currentPt);
+ return success();
+}
+
+/// Region.
+///
+/// region-body ::= block* '}'
+///
+ParseResult OperationParser::parseRegionBody(Region &region) {
+ // Parse the list of blocks.
+ while (!consumeIf(Token::r_brace)) {
+ Block *newBlock = nullptr;
+ if (parseBlock(newBlock))
+ return failure();
+ region.push_back(newBlock);
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Block Parsing
+//===----------------------------------------------------------------------===//
+
+/// Block declaration.
+///
+/// block ::= block-label? operation*
+/// block-label ::= block-id block-arg-list? `:`
+/// block-id ::= caret-id
+/// block-arg-list ::= `(` ssa-id-and-type-list? `)`
+///
+ParseResult OperationParser::parseBlock(Block *&block) {
+ // The first block of a region may already exist, if it does the caret
+ // identifier is optional.
+ if (block && getToken().isNot(Token::caret_identifier))
+ return parseBlockBody(block);
+
+ SMLoc nameLoc = getToken().getLoc();
+ auto name = getTokenSpelling();
+ if (parseToken(Token::caret_identifier, "expected block name"))
+ return failure();
+
+ block = defineBlockNamed(name, nameLoc, block);
+
+ // Fail if the block was already defined.
+ if (!block)
+ return emitError(nameLoc, "redefinition of block '") << name << "'";
+
+ // If an argument list is present, parse it.
+ if (consumeIf(Token::l_paren)) {
+ SmallVector<BlockArgument, 8> bbArgs;
+ if (parseOptionalBlockArgList(bbArgs, block) ||
+ parseToken(Token::r_paren, "expected ')' to end argument list"))
+ return failure();
+ }
+
+ if (parseToken(Token::colon, "expected ':' after block name"))
+ return failure();
+
+ return parseBlockBody(block);
+}
+
+ParseResult OperationParser::parseBlockBody(Block *block) {
+ // Set the insertion point to the end of the block to parse.
+ opBuilder.setInsertionPointToEnd(block);
+
+ // Parse the list of operations that make up the body of the block.
+ while (getToken().isNot(Token::caret_identifier, Token::r_brace))
+ if (parseOperation())
+ return failure();
+
+ return success();
+}
+
+/// Get the block with the specified name, creating it if it doesn't already
+/// exist. The location specified is the point of use, which allows
+/// us to diagnose references to blocks that are not defined precisely.
+Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) {
+ auto &blockAndLoc = getBlockInfoByName(name);
+ if (!blockAndLoc.first) {
+ blockAndLoc = {new Block(), loc};
+ insertForwardRef(blockAndLoc.first, loc);
+ }
+
+ return blockAndLoc.first;
+}
+
+/// Define the block with the specified name. Returns the Block* or nullptr in
+/// the case of redefinition.
+Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc,
+ Block *existing) {
+ auto &blockAndLoc = getBlockInfoByName(name);
+ if (!blockAndLoc.first) {
+ // If the caller provided a block, use it. Otherwise create a new one.
+ if (!existing)
+ existing = new Block();
+ blockAndLoc.first = existing;
+ blockAndLoc.second = loc;
+ return blockAndLoc.first;
+ }
+
+ // Forward declarations are removed once defined, so if we are defining a
+ // existing block and it is not a forward declaration, then it is a
+ // redeclaration.
+ if (!eraseForwardRef(blockAndLoc.first))
+ return nullptr;
+ return blockAndLoc.first;
+}
+
+/// Parse a (possibly empty) list of SSA operands with types as block arguments.
+///
+/// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)*
+///
+ParseResult OperationParser::parseOptionalBlockArgList(
+ SmallVectorImpl<BlockArgument> &results, Block *owner) {
+ if (getToken().is(Token::r_brace))
+ return success();
+
+ // If the block already has arguments, then we're handling the entry block.
+ // Parse and register the names for the arguments, but do not add them.
+ bool definingExistingArgs = owner->getNumArguments() != 0;
+ unsigned nextArgument = 0;
+
+ return parseCommaSeparatedList([&]() -> ParseResult {
+ return parseSSADefOrUseAndType(
+ [&](SSAUseInfo useInfo, Type type) -> ParseResult {
+ // If this block did not have existing arguments, define a new one.
+ if (!definingExistingArgs)
+ return addDefinition(useInfo, owner->addArgument(type));
+
+ // Otherwise, ensure that this argument has already been created.
+ if (nextArgument >= owner->getNumArguments())
+ return emitError("too many arguments specified in argument list");
+
+ // Finally, make sure the existing argument has the correct type.
+ auto arg = owner->getArgument(nextArgument++);
+ if (arg->getType() != type)
+ return emitError("argument and block argument type mismatch");
+ return addDefinition(useInfo, arg);
+ });
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// Top-level entity parsing.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This parser handles entities that are only valid at the top level of the
+/// file.
+class ModuleParser : public Parser {
+public:
+ explicit ModuleParser(ParserState &state) : Parser(state) {}
+
+ ParseResult parseModule(ModuleOp module);
+
+private:
+ /// Parse an attribute alias declaration.
+ ParseResult parseAttributeAliasDef();
+
+ /// Parse an attribute alias declaration.
+ ParseResult parseTypeAliasDef();
+};
+} // end anonymous namespace
+
+/// Parses an attribute alias declaration.
+///
+/// attribute-alias-def ::= '#' alias-name `=` attribute-value
+///
+ParseResult ModuleParser::parseAttributeAliasDef() {
+ assert(getToken().is(Token::hash_identifier));
+ StringRef aliasName = getTokenSpelling().drop_front();
+
+ // Check for redefinitions.
+ if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0)
+ return emitError("redefinition of attribute alias id '" + aliasName + "'");
+
+ // Make sure this isn't invading the dialect attribute namespace.
+ if (aliasName.contains('.'))
+ return emitError("attribute names with a '.' are reserved for "
+ "dialect-defined names");
+
+ consumeToken(Token::hash_identifier);
+
+ // Parse the '='.
+ if (parseToken(Token::equal, "expected '=' in attribute alias definition"))
+ return failure();
+
+ // Parse the attribute value.
+ Attribute attr = parseAttribute();
+ if (!attr)
+ return failure();
+
+ getState().symbols.attributeAliasDefinitions[aliasName] = attr;
+ return success();
+}
+
+/// Parse a type alias declaration.
+///
+/// type-alias-def ::= '!' alias-name `=` 'type' type
+///
+ParseResult ModuleParser::parseTypeAliasDef() {
+ assert(getToken().is(Token::exclamation_identifier));
+ StringRef aliasName = getTokenSpelling().drop_front();
+
+ // Check for redefinitions.
+ if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0)
+ return emitError("redefinition of type alias id '" + aliasName + "'");
+
+ // Make sure this isn't invading the dialect type namespace.
+ if (aliasName.contains('.'))
+ return emitError("type names with a '.' are reserved for "
+ "dialect-defined names");
+
+ consumeToken(Token::exclamation_identifier);
+
+ // Parse the '=' and 'type'.
+ if (parseToken(Token::equal, "expected '=' in type alias definition") ||
+ parseToken(Token::kw_type, "expected 'type' in type alias definition"))
+ return failure();
+
+ // Parse the type.
+ Type aliasedType = parseType();
+ if (!aliasedType)
+ return failure();
+
+ // Register this alias with the parser state.
+ getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType);
+ return success();
+}
+
+/// This is the top-level module parser.
+ParseResult ModuleParser::parseModule(ModuleOp module) {
+ OperationParser opParser(getState(), module);
+
+ // Module itself is a name scope.
+ opParser.pushSSANameScope(/*isIsolated=*/true);
+
+ while (true) {
+ switch (getToken().getKind()) {
+ default:
+ // Parse a top-level operation.
+ if (opParser.parseOperation())
+ return failure();
+ break;
+
+ // If we got to the end of the file, then we're done.
+ case Token::eof: {
+ if (opParser.finalize())
+ return failure();
+
+ // Handle the case where the top level module was explicitly defined.
+ auto &bodyBlocks = module.getBodyRegion().getBlocks();
+ auto &operations = bodyBlocks.front().getOperations();
+ assert(!operations.empty() && "expected a valid module terminator");
+
+ // Check that the first operation is a module, and it is the only
+ // non-terminator operation.
+ ModuleOp nested = dyn_cast<ModuleOp>(operations.front());
+ if (nested && std::next(operations.begin(), 2) == operations.end()) {
+ // Merge the data of the nested module operation into 'module'.
+ module.setLoc(nested.getLoc());
+ module.setAttrs(nested.getOperation()->getAttrList());
+ bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks());
+
+ // Erase the original module body.
+ bodyBlocks.pop_front();
+ }
+
+ return opParser.popSSANameScope();
+ }
+
+ // If we got an error token, then the lexer already emitted an error, just
+ // stop. Someday we could introduce error recovery if there was demand
+ // for it.
+ case Token::error:
+ return failure();
+
+ // Parse an attribute alias.
+ case Token::hash_identifier:
+ if (parseAttributeAliasDef())
+ return failure();
+ break;
+
+ // Parse a type alias.
+ case Token::exclamation_identifier:
+ if (parseTypeAliasDef())
+ return failure();
+ break;
+ }
+ }
+}
+
+//===----------------------------------------------------------------------===//
+
+/// This parses the file specified by the indicated SourceMgr and returns an
+/// MLIR module if it was valid. If not, it emits diagnostics and returns
+/// null.
+OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
+ MLIRContext *context) {
+ auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
+
+ // This is the result module we are parsing into.
+ OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
+ sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
+
+ SymbolState aliasState;
+ ParserState state(sourceMgr, context, aliasState);
+ if (ModuleParser(state).parseModule(*module))
+ return nullptr;
+
+ // Make sure the parse module has no other structural problems detected by
+ // the verifier.
+ if (failed(verify(*module)))
+ return nullptr;
+
+ return module;
+}
+
+/// This parses the file specified by the indicated filename and returns an
+/// MLIR module if it was valid. If not, the error message is emitted through
+/// the error handler registered in the context, and a null pointer is returned.
+OwningModuleRef mlir::parseSourceFile(StringRef filename,
+ MLIRContext *context) {
+ llvm::SourceMgr sourceMgr;
+ return parseSourceFile(filename, sourceMgr, context);
+}
+
+/// This parses the file specified by the indicated filename using the provided
+/// SourceMgr and returns an MLIR module if it was valid. If not, the error
+/// message is emitted through the error handler registered in the context, and
+/// a null pointer is returned.
+OwningModuleRef mlir::parseSourceFile(StringRef filename,
+ llvm::SourceMgr &sourceMgr,
+ MLIRContext *context) {
+ if (sourceMgr.getNumBuffers() != 0) {
+ // TODO(b/136086478): Extend to support multiple buffers.
+ emitError(mlir::UnknownLoc::get(context),
+ "only main buffer parsed at the moment");
+ return nullptr;
+ }
+ auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename);
+ if (std::error_code error = file_or_err.getError()) {
+ emitError(mlir::UnknownLoc::get(context),
+ "could not open input file " + filename);
+ return nullptr;
+ }
+
+ // Load the MLIR module.
+ sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
+ return parseSourceFile(sourceMgr, context);
+}
+
+/// This parses the program string to a MLIR module if it was valid. If not,
+/// it emits diagnostics and returns null.
+OwningModuleRef mlir::parseSourceString(StringRef moduleStr,
+ MLIRContext *context) {
+ auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr);
+ if (!memBuffer)
+ return nullptr;
+
+ SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
+ return parseSourceFile(sourceMgr, context);
+}
+
+/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
+/// parsing failed, nullptr is returned. The number of bytes read from the input
+/// string is returned in 'numRead'.
+template <typename T, typename ParserFn>
+static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
+ ParserFn &&parserFn) {
+ SymbolState aliasState;
+ return parseSymbol<T>(
+ inputStr, context, aliasState,
+ [&](Parser &parser) {
+ SourceMgrDiagnosticHandler handler(
+ const_cast<llvm::SourceMgr &>(parser.getSourceMgr()),
+ parser.getContext());
+ return parserFn(parser);
+ },
+ &numRead);
+}
+
+Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) {
+ size_t numRead = 0;
+ return parseAttribute(attrStr, context, numRead);
+}
+Attribute mlir::parseAttribute(StringRef attrStr, Type type) {
+ size_t numRead = 0;
+ return parseAttribute(attrStr, type, numRead);
+}
+
+Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
+ size_t &numRead) {
+ return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) {
+ return parser.parseAttribute();
+ });
+}
+Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) {
+ return parseSymbol<Attribute>(
+ attrStr, type.getContext(), numRead,
+ [type](Parser &parser) { return parser.parseAttribute(type); });
+}
+
+Type mlir::parseType(StringRef typeStr, MLIRContext *context) {
+ size_t numRead = 0;
+ return parseType(typeStr, context, numRead);
+}
+
+Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) {
+ return parseSymbol<Type>(typeStr, context, numRead,
+ [](Parser &parser) { return parser.parseType(); });
+}
diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp
new file mode 100644
index 00000000000..84de4c396f4
--- /dev/null
+++ b/mlir/lib/Parser/Token.cpp
@@ -0,0 +1,155 @@
+//===- Token.cpp - MLIR Token Implementation ------------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Token class for the MLIR textual form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Token.h"
+#include "llvm/ADT/StringExtras.h"
+using namespace mlir;
+using llvm::SMLoc;
+using llvm::SMRange;
+
+SMLoc Token::getLoc() const { return SMLoc::getFromPointer(spelling.data()); }
+
+SMLoc Token::getEndLoc() const {
+ return SMLoc::getFromPointer(spelling.data() + spelling.size());
+}
+
+SMRange Token::getLocRange() const { return SMRange(getLoc(), getEndLoc()); }
+
+/// For an integer token, return its value as an unsigned. If it doesn't fit,
+/// return None.
+Optional<unsigned> Token::getUnsignedIntegerValue() const {
+ bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+
+ unsigned result = 0;
+ if (spelling.getAsInteger(isHex ? 0 : 10, result))
+ return None;
+ return result;
+}
+
+/// For an integer token, return its value as a uint64_t. If it doesn't fit,
+/// return None.
+Optional<uint64_t> Token::getUInt64IntegerValue() const {
+ bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+
+ uint64_t result = 0;
+ if (spelling.getAsInteger(isHex ? 0 : 10, result))
+ return None;
+ return result;
+}
+
+/// For a floatliteral, return its value as a double. Return None if the value
+/// underflows or overflows.
+Optional<double> Token::getFloatingPointValue() const {
+ double result = 0;
+ if (spelling.getAsDouble(result))
+ return None;
+ return result;
+}
+
+/// For an inttype token, return its bitwidth.
+Optional<unsigned> Token::getIntTypeBitwidth() const {
+ unsigned result = 0;
+ if (spelling[1] == '0' || spelling.drop_front().getAsInteger(10, result) ||
+ result == 0)
+ return None;
+ return result;
+}
+
+/// Given a token containing a string literal, return its value, including
+/// removing the quote characters and unescaping the contents of the string. The
+/// lexer has already verified that this token is valid.
+std::string Token::getStringValue() const {
+ assert(getKind() == string ||
+ (getKind() == at_identifier && getSpelling()[1] == '"'));
+ // Start by dropping the quotes.
+ StringRef bytes = getSpelling().drop_front().drop_back();
+ if (getKind() == at_identifier)
+ bytes = bytes.drop_front();
+
+ std::string result;
+ result.reserve(bytes.size());
+ for (unsigned i = 0, e = bytes.size(); i != e;) {
+ auto c = bytes[i++];
+ if (c != '\\') {
+ result.push_back(c);
+ continue;
+ }
+
+ assert(i + 1 <= e && "invalid string should be caught by lexer");
+ auto c1 = bytes[i++];
+ switch (c1) {
+ case '"':
+ case '\\':
+ result.push_back(c1);
+ continue;
+ case 'n':
+ result.push_back('\n');
+ continue;
+ case 't':
+ result.push_back('\t');
+ continue;
+ default:
+ break;
+ }
+
+ assert(i + 1 <= e && "invalid string should be caught by lexer");
+ auto c2 = bytes[i++];
+
+ assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
+ result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2));
+ }
+
+ return result;
+}
+
+/// Given a hash_identifier token like #123, try to parse the number out of
+/// the identifier, returning None if it is a named identifier like #x or
+/// if the integer doesn't fit.
+Optional<unsigned> Token::getHashIdentifierNumber() const {
+ assert(getKind() == hash_identifier);
+ unsigned result = 0;
+ if (spelling.drop_front().getAsInteger(10, result))
+ return None;
+ return result;
+}
+
+/// Given a punctuation or keyword token kind, return the spelling of the
+/// token as a string. Warning: This will abort on markers, identifiers and
+/// literal tokens since they have no fixed spelling.
+StringRef Token::getTokenSpelling(Kind kind) {
+ switch (kind) {
+ default:
+ llvm_unreachable("This token kind has no fixed spelling");
+#define TOK_PUNCTUATION(NAME, SPELLING) \
+ case NAME: \
+ return SPELLING;
+#define TOK_OPERATOR(NAME, SPELLING) \
+ case NAME: \
+ return SPELLING;
+#define TOK_KEYWORD(SPELLING) \
+ case kw_##SPELLING: \
+ return #SPELLING;
+#include "TokenKinds.def"
+ }
+}
+
+/// Return true if this is one of the keyword token kinds (e.g. kw_if).
+bool Token::isKeyword() const {
+ switch (kind) {
+ default:
+ return false;
+#define TOK_KEYWORD(SPELLING) \
+ case kw_##SPELLING: \
+ return true;
+#include "TokenKinds.def"
+ }
+}
diff --git a/mlir/lib/Parser/Token.h b/mlir/lib/Parser/Token.h
new file mode 100644
index 00000000000..7487736fac7
--- /dev/null
+++ b/mlir/lib/Parser/Token.h
@@ -0,0 +1,107 @@
+//===- Token.h - MLIR Token Interface ---------------------------*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LIB_PARSER_TOKEN_H
+#define MLIR_LIB_PARSER_TOKEN_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/SMLoc.h"
+
+namespace mlir {
+
+/// This represents a token in the MLIR syntax.
+class Token {
+public:
+ enum Kind {
+#define TOK_MARKER(NAME) NAME,
+#define TOK_IDENTIFIER(NAME) NAME,
+#define TOK_LITERAL(NAME) NAME,
+#define TOK_PUNCTUATION(NAME, SPELLING) NAME,
+#define TOK_OPERATOR(NAME, SPELLING) NAME,
+#define TOK_KEYWORD(SPELLING) kw_##SPELLING,
+#include "TokenKinds.def"
+ };
+
+ Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
+
+ // Return the bytes that make up this token.
+ StringRef getSpelling() const { return spelling; }
+
+ // Token classification.
+ Kind getKind() const { return kind; }
+ bool is(Kind K) const { return kind == K; }
+
+ bool isAny(Kind k1, Kind k2) const { return is(k1) || is(k2); }
+
+ /// Return true if this token is one of the specified kinds.
+ template <typename... T>
+ bool isAny(Kind k1, Kind k2, Kind k3, T... others) const {
+ if (is(k1))
+ return true;
+ return isAny(k2, k3, others...);
+ }
+
+ bool isNot(Kind k) const { return kind != k; }
+
+ /// Return true if this token isn't one of the specified kinds.
+ template <typename... T> bool isNot(Kind k1, Kind k2, T... others) const {
+ return !isAny(k1, k2, others...);
+ }
+
+ /// Return true if this is one of the keyword token kinds (e.g. kw_if).
+ bool isKeyword() const;
+
+ // Helpers to decode specific sorts of tokens.
+
+ /// For an integer token, return its value as an unsigned. If it doesn't fit,
+ /// return None.
+ Optional<unsigned> getUnsignedIntegerValue() const;
+
+ /// For an integer token, return its value as an uint64_t. If it doesn't fit,
+ /// return None.
+ Optional<uint64_t> getUInt64IntegerValue() const;
+
+ /// For a floatliteral token, return its value as a double. Returns None in
+ /// the case of underflow or overflow.
+ Optional<double> getFloatingPointValue() const;
+
+ /// For an inttype token, return its bitwidth.
+ Optional<unsigned> getIntTypeBitwidth() const;
+
+ /// Given a hash_identifier token like #123, try to parse the number out of
+ /// the identifier, returning None if it is a named identifier like #x or
+ /// if the integer doesn't fit.
+ Optional<unsigned> getHashIdentifierNumber() const;
+
+ /// Given a token containing a string literal, return its value, including
+ /// removing the quote characters and unescaping the contents of the string.
+ std::string getStringValue() const;
+
+ // Location processing.
+ llvm::SMLoc getLoc() const;
+ llvm::SMLoc getEndLoc() const;
+ llvm::SMRange getLocRange() const;
+
+ /// Given a punctuation or keyword token kind, return the spelling of the
+ /// token as a string. Warning: This will abort on markers, identifiers and
+ /// literal tokens since they have no fixed spelling.
+ static StringRef getTokenSpelling(Kind kind);
+
+private:
+ /// Discriminator that indicates the sort of token this is.
+ Kind kind;
+
+ /// A reference to the entire token contents; this is always a pointer into
+ /// a memory buffer owned by the source manager.
+ StringRef spelling;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_LIB_PARSER_TOKEN_H
diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def
new file mode 100644
index 00000000000..fc9f7821f1a
--- /dev/null
+++ b/mlir/lib/Parser/TokenKinds.def
@@ -0,0 +1,124 @@
+//===- TokenKinds.def - MLIR Token Description ------------------*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file is intended to be #include'd multiple times to extract information
+// about tokens for various clients in the lexer.
+//
+//===----------------------------------------------------------------------===//
+
+#if !defined(TOK_MARKER) && !defined(TOK_IDENTIFIER) && !defined(TOK_LITERAL)&&\
+ !defined(TOK_PUNCTUATION) && !defined(TOK_OPERATOR) && !defined(TOK_KEYWORD)
+# error Must define one of the TOK_ macros.
+#endif
+
+#ifndef TOK_MARKER
+#define TOK_MARKER(X)
+#endif
+#ifndef TOK_IDENTIFIER
+#define TOK_IDENTIFIER(NAME)
+#endif
+#ifndef TOK_LITERAL
+#define TOK_LITERAL(NAME)
+#endif
+#ifndef TOK_PUNCTUATION
+#define TOK_PUNCTUATION(NAME, SPELLING)
+#endif
+#ifndef TOK_OPERATOR
+#define TOK_OPERATOR(NAME, SPELLING)
+#endif
+#ifndef TOK_KEYWORD
+#define TOK_KEYWORD(SPELLING)
+#endif
+
+
+// Markers
+TOK_MARKER(eof)
+TOK_MARKER(error)
+
+// Identifiers.
+TOK_IDENTIFIER(bare_identifier) // foo
+TOK_IDENTIFIER(at_identifier) // @foo
+TOK_IDENTIFIER(hash_identifier) // #foo
+TOK_IDENTIFIER(percent_identifier) // %foo
+TOK_IDENTIFIER(caret_identifier) // ^foo
+TOK_IDENTIFIER(exclamation_identifier) // !foo
+
+// Literals
+TOK_LITERAL(floatliteral) // 2.0
+TOK_LITERAL(integer) // 42
+TOK_LITERAL(string) // "foo"
+TOK_LITERAL(inttype) // i421
+
+// Punctuation.
+TOK_PUNCTUATION(arrow, "->")
+TOK_PUNCTUATION(at, "@")
+TOK_PUNCTUATION(colon, ":")
+TOK_PUNCTUATION(comma, ",")
+TOK_PUNCTUATION(question, "?")
+TOK_PUNCTUATION(l_paren, "(")
+TOK_PUNCTUATION(r_paren, ")")
+TOK_PUNCTUATION(l_brace, "{")
+TOK_PUNCTUATION(r_brace, "}")
+TOK_PUNCTUATION(l_square, "[")
+TOK_PUNCTUATION(r_square, "]")
+TOK_PUNCTUATION(less, "<")
+TOK_PUNCTUATION(greater, ">")
+TOK_PUNCTUATION(equal, "=")
+TOK_PUNCTUATION(ellipsis, "...")
+// TODO: More punctuation.
+
+// Operators.
+TOK_OPERATOR(plus, "+")
+TOK_OPERATOR(minus, "-")
+TOK_OPERATOR(star, "*")
+// TODO: More operator tokens
+
+// Keywords. These turn "foo" into Token::kw_foo enums.
+
+// NOTE: Please key these alphabetized to make it easier to find something in
+// this list and to cater to OCD.
+TOK_KEYWORD(attributes)
+TOK_KEYWORD(bf16)
+TOK_KEYWORD(ceildiv)
+TOK_KEYWORD(complex)
+TOK_KEYWORD(dense)
+TOK_KEYWORD(f16)
+TOK_KEYWORD(f32)
+TOK_KEYWORD(f64)
+TOK_KEYWORD(false)
+TOK_KEYWORD(floordiv)
+TOK_KEYWORD(for)
+TOK_KEYWORD(func)
+TOK_KEYWORD(index)
+TOK_KEYWORD(loc)
+TOK_KEYWORD(max)
+TOK_KEYWORD(memref)
+TOK_KEYWORD(min)
+TOK_KEYWORD(mod)
+TOK_KEYWORD(none)
+TOK_KEYWORD(offset)
+TOK_KEYWORD(opaque)
+TOK_KEYWORD(size)
+TOK_KEYWORD(sparse)
+TOK_KEYWORD(step)
+TOK_KEYWORD(strides)
+TOK_KEYWORD(symbol)
+TOK_KEYWORD(tensor)
+TOK_KEYWORD(to)
+TOK_KEYWORD(true)
+TOK_KEYWORD(tuple)
+TOK_KEYWORD(type)
+TOK_KEYWORD(unit)
+TOK_KEYWORD(vector)
+
+#undef TOK_MARKER
+#undef TOK_IDENTIFIER
+#undef TOK_LITERAL
+#undef TOK_PUNCTUATION
+#undef TOK_OPERATOR
+#undef TOK_KEYWORD
OpenPOWER on IntegriCloud