diff options
| author | Alex Zinenko <zinenko@google.com> | 2019-02-05 11:47:02 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 16:14:50 -0700 |
| commit | 40d5d09f9d52c581fd4419e8a54f4b952d904bb2 (patch) | |
| tree | 726a32361aa6fada028ba34728eb1f2a7928573a /mlir | |
| parent | 1b1f293a5d50b9dc90f9a9a00b78e9d0c9df667d (diff) | |
| download | bcm5719-llvm-40d5d09f9d52c581fd4419e8a54f4b952d904bb2.tar.gz bcm5719-llvm-40d5d09f9d52c581fd4419e8a54f4b952d904bb2.zip | |
Print parens around the return type of a function if it is also a function type
Existing type syntax contains the following productions:
function-type ::= type-list-parens `->` type-list
type-list ::= type | type-list-parens
type ::= <..> | function-type
Due to these rules, when the parser sees `->` followed by `(`, it cannot
disambiguate if `(` starts a parenthesized list of function result types, or a
parenthesized list of operands of another function type, returned from the
current function. We would need an unknown amount of lookahead to try to find
the `->` at the right level of function nesting to differentiate between type
lists and singular function types.
Instead, require the result type of the function that is a function type itself
to be always parenthesized, at the syntax level. Update the spec and the
parser to correspond to the production rule names used in the spec (although it
would have worked without modifications). Fix the function type parsing bug in
the process, as it used to accept the non-parenthesized list of types for
arguments, disallowed by the spec.
PiperOrigin-RevId: 232528361
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/g3doc/LangRef.md | 28 | ||||
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 16 | ||||
| -rw-r--r-- | mlir/lib/Parser/Parser.cpp | 81 | ||||
| -rw-r--r-- | mlir/test/IR/core-ops.mlir | 4 | ||||
| -rw-r--r-- | mlir/test/IR/invalid-locations.mlir | 2 | ||||
| -rw-r--r-- | mlir/test/IR/invalid-ops.mlir | 2 | ||||
| -rw-r--r-- | mlir/test/IR/invalid.mlir | 28 | ||||
| -rw-r--r-- | mlir/test/IR/parser.mlir | 22 |
8 files changed, 125 insertions, 58 deletions
diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md index 205195122d9..af85a6db64a 100644 --- a/mlir/g3doc/LangRef.md +++ b/mlir/g3doc/LangRef.md @@ -504,18 +504,18 @@ have application-specific semantics. For example, MLIR supports a set of [dialect-specific types](#dialect-specific-types). ``` {.ebnf} -type ::= integer-type - | index-type - | float-type - | vector-type - | tensor-type - | memref-type +type ::= non-function-type | function-type - | dialect-type - | type-alias -// MLIR doesn't have a tuple type but functions can return multiple values. -type-list ::= type-list-parens | type +non-function-type ::= integer-type + | index-type + | float-type + | vector-type + | tensor-type + | memref-type + | dialect-type + | type-alias + type-list-no-parens ::= type (`,` type)* type-list-parens ::= `(` `)` | `(` type-list-no-parens `)` @@ -559,7 +559,11 @@ Builtin types consist of only the types needed for the validity of the IR. Syntax: ``` {.ebnf} -function-type ::= type-list-parens `->` type-list +// MLIR doesn't have a tuple type but functions can return multiple values. +function-result-type ::= type-list-parens + | non-function-type + +function-type ::= type-list-parens `->` function-result-type ``` MLIR supports first-class functions: the @@ -897,7 +901,7 @@ associated attributes according to the following grammar: ``` {.ebnf} function ::= `func` function-signature function-attributes? function-body? -function-signature ::= function-id `(` argument-list `)` (`->` type-list)? +function-signature ::= function-id `(` argument-list `)` (`->` function-result-type)? argument-list ::= named-argument (`,` named-argument)* | /*empty*/ argument-list ::= type (`,` type)* | /*empty*/ named-argument ::= ssa-id `:` type diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 127f4952d8f..e08b1409f68 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -746,7 +746,7 @@ void ModulePrinter::printType(Type type) { interleaveComma(func.getInputs(), [&](Type type) { printType(type); }); os << ") -> "; auto results = func.getResults(); - if (results.size() == 1) + if (results.size() == 1 && !results[0].isa<FunctionType>()) os << results[0]; else { os << '('; @@ -1313,10 +1313,17 @@ void FunctionPrinter::printFunctionSignature() { switch (fnType.getResults().size()) { case 0: break; - case 1: + case 1: { os << " -> "; - printType(fnType.getResults()[0]); + auto resultType = fnType.getResults()[0]; + bool resultIsFunc = resultType.isa<FunctionType>(); + if (resultIsFunc) + os << '('; + printType(resultType); + if (resultIsFunc) + os << ')'; break; + } default: os << " -> ("; interleaveComma(fnType.getResults(), @@ -1482,7 +1489,8 @@ void FunctionPrinter::printGenericOp(const Instruction *op) { [&](const Value *value) { printType(value->getType()); }); os << ") -> "; - if (op->getNumResults() == 1) { + if (op->getNumResults() == 1 && + !op->getResult(0)->getType().isa<FunctionType>()) { printType(op->getResult(0)->getType()); } else { os << '('; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index b7e4fb147cb..bd9c5e9475d 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -187,9 +187,11 @@ public: Type parseTensorType(); Type parseMemRefType(); Type parseFunctionType(); + Type parseNonFunctionType(); Type parseType(); ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements); - ParseResult parseTypeList(SmallVectorImpl<Type> &elements); + ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements); + ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements); // Attribute parsing. Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc, @@ -308,32 +310,29 @@ ParseResult Parser::parseCommaSeparatedListUntil( // Type Parsing //===----------------------------------------------------------------------===// -/// Parse an arbitrary type. +/// Parse any type except the function type. /// -/// type ::= integer-type -/// | index-type -/// | float-type -/// | extended-type -/// | vector-type -/// | tensor-type -/// | memref-type -/// | function-type +/// non-function-type ::= integer-type +/// | index-type +/// | float-type +/// | extended-type +/// | vector-type +/// | tensor-type +/// | memref-type /// /// index-type ::= `index` /// float-type ::= `f16` | `bf16` | `f32` | `f64` /// -Type Parser::parseType() { +Type Parser::parseNonFunctionType() { switch (getToken().getKind()) { default: - return (emitError("expected type"), nullptr); + return (emitError("expected non-function type"), nullptr); case Token::kw_memref: return parseMemRefType(); case Token::kw_tensor: return parseTensorType(); case Token::kw_vector: return parseVectorType(); - case Token::l_paren: - return parseFunctionType(); // integer-type case Token::inttype: { auto width = getToken().getIntTypeBitwidth(); @@ -369,6 +368,17 @@ Type Parser::parseType() { } } +/// Parse an arbitrary type. +/// +/// type ::= function-type +/// | non-function-type +/// +Type Parser::parseType() { + if (getToken().is(Token::l_paren)) + return parseFunctionType(); + return parseNonFunctionType(); +} + /// Parse a vector type. /// /// vector-type ::= `vector` `<` const-dimension-list primitive-type `>` @@ -640,9 +650,9 @@ Type Parser::parseFunctionType() { assert(getToken().is(Token::l_paren)); SmallVector<Type, 4> arguments, results; - if (parseTypeList(arguments) || + if (parseTypeListParens(arguments) || parseToken(Token::arrow, "expected '->' in function type") || - parseTypeList(results)) + parseFunctionResultTypes(results)) return nullptr; return builder.getFunctionType(arguments, results); @@ -663,27 +673,38 @@ ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { return parseCommaSeparatedList(parseElt); } -/// Parse a "type list", which is a singular type, or a parenthesized list of -/// types. +/// Parse a parenthesized list of types. /// -/// type-list ::= type-list-parens | type /// type-list-parens ::= `(` `)` /// | `(` type-list-no-parens `)` /// -ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) { - auto parseElt = [&]() -> ParseResult { - auto elt = parseType(); - elements.push_back(elt); - return elt ? ParseSuccess : ParseFailure; - }; +ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { + if (parseToken(Token::l_paren, "expected '('")) + return ParseFailure; - // If there is no parens, then it must be a singular type. - if (!consumeIf(Token::l_paren)) - return parseElt(); + // Handle empty lists. + if (getToken().is(Token::r_paren)) + return consumeToken(), ParseSuccess; - if (parseCommaSeparatedListUntil(Token::r_paren, parseElt)) + if (parseTypeListNoParens(elements) || + parseToken(Token::r_paren, "expected ')'")) return ParseFailure; + return ParseSuccess; +} + +/// Parse a function result type. +/// +/// function-result-type ::= type-list-parens +/// | non-function-type +/// +ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { + if (getToken().is(Token::l_paren)) + return parseTypeListParens(elements); + Type t = parseNonFunctionType(); + if (!t) + return ParseFailure; + elements.push_back(t); return ParseSuccess; } @@ -3489,7 +3510,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type, // Parse the return type if present. SmallVector<Type, 4> results; if (consumeIf(Token::arrow)) { - if (parseTypeList(results)) + if (parseFunctionResultTypes(results)) return ParseFailure; } type = builder.getFunctionType(argTypes, results); diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index ff5b2563670..cb6d92d9670 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -1,8 +1,8 @@ // RUN: mlir-opt %s | FileCheck %s // Verify the printed output can be parsed. // RUN: mlir-opt %s | mlir-opt | FileCheck %s -// TODO(b/123888077): The following fails due to constant with function pointer. -// Disabled: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s +// Verify the generic form can be parsed. +// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s // CHECK: #map0 = (d0) -> (d0 + 1) diff --git a/mlir/test/IR/invalid-locations.mlir b/mlir/test/IR/invalid-locations.mlir index 327ccdded04..7d4e19349c6 100644 --- a/mlir/test/IR/invalid-locations.mlir +++ b/mlir/test/IR/invalid-locations.mlir @@ -67,7 +67,7 @@ func @location_fused_missing_greater() { func @location_fused_missing_metadata() { ^bb: - // expected-error@+1 {{expected type}} + // expected-error@+1 {{expected non-function type}} return loc(fused<) // expected-error {{expected valid attribute metadata}} } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 9097d401613..9e536493b07 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -209,7 +209,7 @@ func @func_with_ops(i32, i64) { // Comparisons must have the "predicate" attribute. func @func_with_ops(i32, i32) { ^bb0(%a : i32, %b : i32): - %r = cmpi %a, %b : i32 // expected-error {{expected type}} + %r = cmpi %a, %b : i32 // expected-error {{expected non-function type}} } // ----- diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 1ba6c45c06b..9f53630cfb5 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -3,12 +3,12 @@ // Check different error cases. // ----- -func @illegaltype(i) // expected-error {{expected type}} +func @illegaltype(i) // expected-error {{expected non-function type}} // ----- func @illegaltype() { - %0 = constant splat<<vector 4 x f32>, 0> : vector<4 x f32> // expected-error {{expected type}} + %0 = constant splat<<vector 4 x f32>, 0> : vector<4 x f32> // expected-error {{expected non-function type}} } // ----- @@ -227,7 +227,7 @@ func @incomplete_for() { // ----- func @nonconstant_step(%1 : i32) { - for %2 = 1 to 5 step %1 { // expected-error {{expected type}} + for %2 = 1 to 5 step %1 { // expected-error {{expected non-function type}} // ----- @@ -326,7 +326,7 @@ func @d() {return} // expected-error {{custom op 'func' is unknown}} // ----- -func @malformed_type(%a : intt) { // expected-error {{expected type}} +func @malformed_type(%a : intt) { // expected-error {{expected non-function type}} } // ----- @@ -392,7 +392,7 @@ func @condbr_badtype() { ^bb0: %c = "foo"() : () -> i1 %a = "foo"() : () -> i32 - cond_br %c, ^bb0(%a, %a : i32, ^bb0) // expected-error {{expected type}} + cond_br %c, ^bb0(%a, %a : i32, ^bb0) // expected-error {{expected non-function type}} } // ----- @@ -506,6 +506,18 @@ func @undefined_function() { // ----- +func @invalid_result_type() -> () -> () // expected-error {{expected a top level entity}} + +// ----- + +func @func() -> (() -> ()) +func @referer() { + %f = constant @func : () -> () -> () // expected-error {{reference to function with mismatched type}} + return +} + +// ----- + #map1 = (i)[j] -> (i+j) func @bound_symbol_mismatch(%N : index) { @@ -538,7 +550,7 @@ func @large_bound() { // ----- func @max_in_upper_bound(%N : index) { - for %i = 1 to max (i)->(N, 100) { //expected-error {{expected type}} + for %i = 1 to max (i)->(N, 100) { //expected-error {{expected non-function type}} } return } @@ -595,7 +607,7 @@ func @invalid_if_operands3(%N : index) { // expected-error@+1 {{expected '"' in string literal}} "J// ----- func @calls(%arg0: i32) { - // expected-error@+1 {{expected type}} + // expected-error@+1 {{expected non-function type}} %z = "casdasda"(%x) : (ppop32) -> i32 } // ----- @@ -767,7 +779,7 @@ func @type_alias_unknown(!unknown_alias) -> () { // expected-error {{undefined t // ----- -!missing_type_alias = type // expected-error@+2 {{expected type}} +!missing_type_alias = type // expected-error@+2 {{expected non-function type}} // ----- diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 417f0f13f33..4b1b21473b2 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -182,6 +182,28 @@ func @func_with_two_args(%a : f16, %b : i8) -> (i1, i32) { return %c#0, %c#1 : i1, i32 // CHECK: return %0#0, %0#1 : i1, i32 } // CHECK: } +// CHECK-LABEL: func @second_order_func() -> (() -> ()) { +func @second_order_func() -> (() -> ()) { +// CHECK-NEXT: %f = constant @emptyMLF : () -> () + %c = constant @emptyMLF : () -> () +// CHECK-NEXT: return %f : () -> () + return %c : () -> () +} + +// CHECK-LABEL: func @third_order_func() -> (() -> (() -> ())) { +func @third_order_func() -> (() -> (() -> ())) { +// CHECK-NEXT: %f = constant @second_order_func : () -> (() -> ()) + %c = constant @second_order_func : () -> (() -> ()) +// CHECK-NEXT: return %f : () -> (() -> ()) + return %c : () -> (() -> ()) +} + +// CHECK-LABEL: func @identity_functor(%arg0: () -> ()) -> (() -> ()) { +func @identity_functor(%a : () -> ()) -> (() -> ()) { +// CHECK-NEXT: return %arg0 : () -> () + return %a : () -> () +} + // CHECK-LABEL: func @func_ops_in_loop() { func @func_ops_in_loop() { // CHECK: %0 = "foo"() : () -> i64 |

