diff options
| author | River Riddle <riverriddle@google.com> | 2019-11-01 15:39:30 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-01 15:40:16 -0700 |
| commit | 2ba4d802e030b51e78b7d29238ccc552ea19d1c4 (patch) | |
| tree | 849e569f89352035e505f006fb81a3f46efb7e2f /mlir/lib/Parser | |
| parent | 445cc3f6dd74e86575153a95ecfb8754d6d5b726 (diff) | |
| download | bcm5719-llvm-2ba4d802e030b51e78b7d29238ccc552ea19d1c4.tar.gz bcm5719-llvm-2ba4d802e030b51e78b7d29238ccc552ea19d1c4.zip | |
Remove the need for passing a location to parseAttribute/parseType.
Now that a proper parser is passed to these methods, there isn't a need to explicitly pass a source location. The source location can be recovered from the parser as necessary. This removes the need to explicitly decode an SMLoc in the case where we don't need to, which can be expensive.
This requires adding some basic nesting support to the parser for supporting nested parsers to allow for remapping source locations of the nested parsers to the top level parser for accurate diagnostics. This is due to the fact that the attribute and type parsers use different source buffers than the top level parser, as they may be represented in string form.
PiperOrigin-RevId: 278014858
Diffstat (limited to 'mlir/lib/Parser')
| -rw-r--r-- | mlir/lib/Parser/Lexer.h | 3 | ||||
| -rw-r--r-- | mlir/lib/Parser/Parser.cpp | 125 |
2 files changed, 94 insertions, 34 deletions
diff --git a/mlir/lib/Parser/Lexer.h b/mlir/lib/Parser/Lexer.h index 896c26cc927..b1807711457 100644 --- a/mlir/lib/Parser/Lexer.h +++ b/mlir/lib/Parser/Lexer.h @@ -45,6 +45,9 @@ public: /// at the designated point in the input. void resetPointer(const char *newPointer) { curPtr = newPointer; } + /// Returns the start of the buffer. + const char *getBufferBegin() { return curBuffer.data(); } + private: // Helpers. Token formToken(Token::Kind kind, const char *tokStart) { diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index a6e02279adb..368f262ade7 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -52,16 +52,27 @@ namespace { class Parser; //===----------------------------------------------------------------------===// -// AliasState +// SymbolState //===----------------------------------------------------------------------===// -/// This class contains record of any parsed top-level aliases. -struct AliasState { +/// This class contains record of any parsed top-level symbols. +struct SymbolState { // A map from attribute alias identifier to Attribute. llvm::StringMap<Attribute> attributeAliasDefinitions; // A map from type alias identifier to Type. llvm::StringMap<Type> typeAliasDefinitions; + + /// A set of locations into the main parser memory buffer for each of the + /// active nested parsers. Given that some nested parsers, i.e. custom dialect + /// parsers, operate on a temporary memory buffer, this provides an anchor + /// point for emitting diagnostics. + SmallVector<llvm::SMLoc, 1> nestedParserLocs; + + /// The top-level lexer that contains the original memory buffer provided by + /// the user. This is used by nested parsers to get a properly encoded source + /// location. + Lexer *topLevelLexer = nullptr; }; //===----------------------------------------------------------------------===// @@ -72,9 +83,18 @@ struct AliasState { /// such as the current lexer position etc. struct ParserState { ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx, - AliasState &aliases) + SymbolState &symbols) : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()), - aliases(aliases) {} + symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) { + // Set the top level lexer for the symbol state if one doesn't exist. + if (!symbols.topLevelLexer) + symbols.topLevelLexer = &lex; + } + ~ParserState() { + // Reset the top level lexer if it refers the lexer in our state. + if (symbols.topLevelLexer == &lex) + symbols.topLevelLexer = nullptr; + } ParserState(const ParserState &) = delete; void operator=(const ParserState &) = delete; @@ -87,8 +107,11 @@ struct ParserState { /// This is the next token that hasn't been consumed yet. Token curToken; - /// Any parsed alias state. - AliasState &aliases; + /// The current state for symbol parsing. + SymbolState &symbols; + + /// The depth of this parser in the nested parsing stack. + size_t parserDepth; }; //===----------------------------------------------------------------------===// @@ -140,7 +163,32 @@ public: /// Encode the specified source location information into an attribute for /// attachment to the IR. Location getEncodedSourceLocation(llvm::SMLoc loc) { - return state.lex.getEncodedSourceLocation(loc); + // If there are no active nested parsers, we can get the encoded source + // location directly. + if (state.parserDepth == 0) + return state.lex.getEncodedSourceLocation(loc); + // Otherwise, we need to re-encode it to point to the top level buffer. + return state.symbols.topLevelLexer->getEncodedSourceLocation( + remapLocationToTopLevelBuffer(loc)); + } + + /// Remaps the given SMLoc to the top level lexer of the parser. This is used + /// to adjust locations of potentially nested parsers to ensure that they can + /// be emitted properly as diagnostics. + llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) { + // If there are no active nested parsers, we can return location directly. + SymbolState &symbols = state.symbols; + if (state.parserDepth == 0) + return loc; + assert(symbols.topLevelLexer && "expected valid top-level lexer"); + + // Otherwise, we need to remap the location to the main parser. This is + // simply offseting the location onto the location of the last nested + // parser. + size_t offset = loc.getPointer() - state.lex.getBufferBegin(); + auto *rawLoc = + symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset; + return llvm::SMLoc::getFromPointer(rawLoc); } //===--------------------------------------------------------------------===// @@ -388,6 +436,11 @@ public: /// Return the location of the original name token. llvm::SMLoc getNameLoc() const override { return nameLoc; } + /// Re-encode the given source location as an MLIR location and return it. + Location getEncodedSourceLoc(llvm::SMLoc loc) override { + return parser.getEncodedSourceLocation(loc); + } + /// Returns the full specification of the symbol being parsed. This allows /// for using a separate parser if necessary. StringRef getFullSymbolSpec() const override { return fullSpec; } @@ -517,7 +570,7 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, return (p.emitError("expected string literal data in dialect symbol"), nullptr); symbolData = p.getToken().getStringValue(); - loc = p.getToken().getLoc(); + loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1); p.consumeToken(Token::string); // Consume the '>'. @@ -529,6 +582,7 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, auto dotHalves = identifier.split('.'); dialectName = dotHalves.first; auto prettyName = dotHalves.second; + loc = llvm::SMLoc::getFromPointer(prettyName.data()); // If the dialect's symbol is followed immediately by a <, then lex the body // of it into prettyName. @@ -541,8 +595,16 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, symbolData = prettyName.str(); } + // Record the name location of the type remapped to the top level buffer. + llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc); + p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer); + // Call into the provided symbol construction function. - return createSymbol(dialectName, symbolData, loc); + Symbol sym = createSymbol(dialectName, symbolData, loc); + + // Pop the last parser location. + p.getState().symbols.nestedParserLocs.pop_back(); + return sym; } /// Parses a symbol, of type 'T', and returns it if parsing was successful. If @@ -550,14 +612,14 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, /// string is returned in 'numRead'. template <typename T, typename ParserFn> static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context, - AliasState &aliasState, ParserFn &&parserFn, + SymbolState &symbolState, ParserFn &&parserFn, size_t *numRead = nullptr) { SourceMgr sourceMgr; auto memBuffer = MemoryBuffer::getMemBuffer( inputStr, /*BufferName=*/"<mlir_parser_buffer>", /*RequiresNullTerminator=*/false); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); - ParserState state(sourceMgr, context, aliasState); + ParserState state(sourceMgr, context, symbolState); Parser parser(state); Token startTok = parser.getToken(); @@ -573,8 +635,7 @@ static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context, // Otherwise, ensure that all of the tokens were parsed. } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) { - parser.emitError(endTok.getLoc(), - "encountered unexpected tokens after parsing"); + parser.emitError(endTok.getLoc(), "encountered unexpected token"); return T(); } return symbol; @@ -585,13 +646,12 @@ static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context, //===----------------------------------------------------------------------===// InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { - auto diag = mlir::emitError(getEncodedSourceLocation(loc), message); - // If we hit a parse error in response to a lexer error, then the lexer // already reported the error. if (getToken().is(Token::error)) - diag.abandon(); - return diag; + return InFlightDiagnostic(); + + return mlir::emitError(getEncodedSourceLocation(loc), message); } //===----------------------------------------------------------------------===// @@ -701,24 +761,22 @@ Type Parser::parseComplexType() { /// Type Parser::parseExtendedType() { return parseExtendedSymbol<Type>( - *this, Token::exclamation_identifier, state.aliases.typeAliasDefinitions, + *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, [&](StringRef dialectName, StringRef symbolData, llvm::SMLoc loc) -> Type { - Location encodedLoc = getEncodedSourceLocation(loc); - // If we found a registered dialect, then ask it to parse the type. if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { return parseSymbol<Type>( - symbolData, state.context, state.aliases, [&](Parser &parser) { + symbolData, state.context, state.symbols, [&](Parser &parser) { CustomDialectAsmParser customParser(symbolData, parser); - return dialect->parseType(customParser, encodedLoc); + return dialect->parseType(customParser); }); } // Otherwise, form a new opaque type. return OpaqueType::getChecked( Identifier::get(dialectName, state.context), symbolData, - state.context, encodedLoc); + state.context, getEncodedSourceLocation(loc)); }); } @@ -1315,7 +1373,7 @@ Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) { /// Attribute Parser::parseExtendedAttr(Type type) { Attribute attr = parseExtendedSymbol<Attribute>( - *this, Token::hash_identifier, state.aliases.attributeAliasDefinitions, + *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, [&](StringRef dialectName, StringRef symbolData, llvm::SMLoc loc) -> Attribute { // Parse an optional trailing colon type. @@ -1326,10 +1384,9 @@ Attribute Parser::parseExtendedAttr(Type type) { // If we found a registered dialect, then ask it to parse the attribute. if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { return parseSymbol<Attribute>( - symbolData, state.context, state.aliases, [&](Parser &parser) { + symbolData, state.context, state.symbols, [&](Parser &parser) { CustomDialectAsmParser customParser(symbolData, parser); - return dialect->parseAttribute(customParser, attrType, - getEncodedSourceLocation(loc)); + return dialect->parseAttribute(customParser, attrType); }); } @@ -4242,7 +4299,7 @@ ParseResult ModuleParser::parseAttributeAliasDef() { StringRef aliasName = getTokenSpelling().drop_front(); // Check for redefinitions. - if (getState().aliases.attributeAliasDefinitions.count(aliasName) > 0) + if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0) return emitError("redefinition of attribute alias id '" + aliasName + "'"); // Make sure this isn't invading the dialect attribute namespace. @@ -4261,7 +4318,7 @@ ParseResult ModuleParser::parseAttributeAliasDef() { if (!attr) return failure(); - getState().aliases.attributeAliasDefinitions[aliasName] = attr; + getState().symbols.attributeAliasDefinitions[aliasName] = attr; return success(); } @@ -4274,7 +4331,7 @@ ParseResult ModuleParser::parseTypeAliasDef() { StringRef aliasName = getTokenSpelling().drop_front(); // Check for redefinitions. - if (getState().aliases.typeAliasDefinitions.count(aliasName) > 0) + if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0) return emitError("redefinition of type alias id '" + aliasName + "'"); // Make sure this isn't invading the dialect type namespace. @@ -4295,7 +4352,7 @@ ParseResult ModuleParser::parseTypeAliasDef() { return failure(); // Register this alias with the parser state. - getState().aliases.typeAliasDefinitions.try_emplace(aliasName, aliasedType); + getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); return success(); } @@ -4374,7 +4431,7 @@ OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); - AliasState aliasState; + SymbolState aliasState; ParserState state(sourceMgr, context, aliasState); if (ModuleParser(state).parseModule(*module)) return nullptr; @@ -4440,7 +4497,7 @@ OwningModuleRef mlir::parseSourceString(StringRef moduleStr, template <typename T, typename ParserFn> static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context, size_t &numRead, ParserFn &&parserFn) { - AliasState aliasState; + SymbolState aliasState; return parseSymbol<T>( inputStr, context, aliasState, [&](Parser &parser) { |

