diff options
Diffstat (limited to 'mlir/lib/Parser/Parser.cpp')
| -rw-r--r-- | mlir/lib/Parser/Parser.cpp | 28 |
1 files changed, 26 insertions, 2 deletions
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 35c694b6a43..2843aae4bb8 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1400,7 +1400,7 @@ static std::string extractSymbolReference(Token tok) { /// | type /// | `[` (attribute-value (`,` attribute-value)*)? `]` /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` -/// | symbol-ref-id +/// | symbol-ref-id (`::` symbol-ref-id)* /// | `dense` `<` attribute-value `>` `:` /// (tensor-type | vector-type) /// | `sparse` `<` attribute-value `,` attribute-value `>` @@ -1509,7 +1509,31 @@ Attribute Parser::parseAttribute(Type type) { case Token::at_identifier: { std::string nameStr = extractSymbolReference(getToken()); consumeToken(Token::at_identifier); - return builder.getSymbolRefAttr(nameStr); + + // Parse any nested references. + std::vector<FlatSymbolRefAttr> nestedRefs; + while (getToken().is(Token::colon)) { + // Check for the '::' prefix. + const char *curPointer = getToken().getLoc().getPointer(); + consumeToken(Token::colon); + if (!consumeIf(Token::colon)) { + state.lex.resetPointer(curPointer); + consumeToken(); + break; + } + // Parse the reference itself. + auto curLoc = getToken().getLoc(); + if (getToken().isNot(Token::at_identifier)) { + emitError(curLoc, "expected nested symbol reference identifier"); + return Attribute(); + } + + std::string nameStr = extractSymbolReference(getToken()); + consumeToken(Token::at_identifier); + nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); + } + + return builder.getSymbolRefAttr(nameStr, nestedRefs); } // Parse a 'unit' attribute. |

