summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Parser
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Parser')
-rw-r--r--mlir/lib/Parser/Parser.cpp28
1 files changed, 26 insertions, 2 deletions
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 35c694b6a43..2843aae4bb8 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1400,7 +1400,7 @@ static std::string extractSymbolReference(Token tok) {
/// | type
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
/// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
-/// | symbol-ref-id
+/// | symbol-ref-id (`::` symbol-ref-id)*
/// | `dense` `<` attribute-value `>` `:`
/// (tensor-type | vector-type)
/// | `sparse` `<` attribute-value `,` attribute-value `>`
@@ -1509,7 +1509,31 @@ Attribute Parser::parseAttribute(Type type) {
case Token::at_identifier: {
std::string nameStr = extractSymbolReference(getToken());
consumeToken(Token::at_identifier);
- return builder.getSymbolRefAttr(nameStr);
+
+ // Parse any nested references.
+ std::vector<FlatSymbolRefAttr> nestedRefs;
+ while (getToken().is(Token::colon)) {
+ // Check for the '::' prefix.
+ const char *curPointer = getToken().getLoc().getPointer();
+ consumeToken(Token::colon);
+ if (!consumeIf(Token::colon)) {
+ state.lex.resetPointer(curPointer);
+ consumeToken();
+ break;
+ }
+ // Parse the reference itself.
+ auto curLoc = getToken().getLoc();
+ if (getToken().isNot(Token::at_identifier)) {
+ emitError(curLoc, "expected nested symbol reference identifier");
+ return Attribute();
+ }
+
+ std::string nameStr = extractSymbolReference(getToken());
+ consumeToken(Token::at_identifier);
+ nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext()));
+ }
+
+ return builder.getSymbolRefAttr(nameStr, nestedRefs);
}
// Parse a 'unit' attribute.
OpenPOWER on IntegriCloud