summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2019-02-05 11:47:02 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 16:14:50 -0700
commit40d5d09f9d52c581fd4419e8a54f4b952d904bb2 (patch)
tree726a32361aa6fada028ba34728eb1f2a7928573a /mlir
parent1b1f293a5d50b9dc90f9a9a00b78e9d0c9df667d (diff)
downloadbcm5719-llvm-40d5d09f9d52c581fd4419e8a54f4b952d904bb2.tar.gz
bcm5719-llvm-40d5d09f9d52c581fd4419e8a54f4b952d904bb2.zip
Print parens around the return type of a function if it is also a function type
Existing type syntax contains the following productions: function-type ::= type-list-parens `->` type-list type-list ::= type | type-list-parens type ::= <..> | function-type Due to these rules, when the parser sees `->` followed by `(`, it cannot disambiguate if `(` starts a parenthesized list of function result types, or a parenthesized list of operands of another function type, returned from the current function. We would need an unknown amount of lookahead to try to find the `->` at the right level of function nesting to differentiate between type lists and singular function types. Instead, require the result type of the function that is a function type itself to be always parenthesized, at the syntax level. Update the spec and the parser to correspond to the production rule names used in the spec (although it would have worked without modifications). Fix the function type parsing bug in the process, as it used to accept the non-parenthesized list of types for arguments, disallowed by the spec. PiperOrigin-RevId: 232528361
Diffstat (limited to 'mlir')
-rw-r--r--mlir/g3doc/LangRef.md28
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp16
-rw-r--r--mlir/lib/Parser/Parser.cpp81
-rw-r--r--mlir/test/IR/core-ops.mlir4
-rw-r--r--mlir/test/IR/invalid-locations.mlir2
-rw-r--r--mlir/test/IR/invalid-ops.mlir2
-rw-r--r--mlir/test/IR/invalid.mlir28
-rw-r--r--mlir/test/IR/parser.mlir22
8 files changed, 125 insertions, 58 deletions
diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md
index 205195122d9..af85a6db64a 100644
--- a/mlir/g3doc/LangRef.md
+++ b/mlir/g3doc/LangRef.md
@@ -504,18 +504,18 @@ have application-specific semantics. For example, MLIR supports a set of
[dialect-specific types](#dialect-specific-types).
``` {.ebnf}
-type ::= integer-type
- | index-type
- | float-type
- | vector-type
- | tensor-type
- | memref-type
+type ::= non-function-type
| function-type
- | dialect-type
- | type-alias
-// MLIR doesn't have a tuple type but functions can return multiple values.
-type-list ::= type-list-parens | type
+non-function-type ::= integer-type
+ | index-type
+ | float-type
+ | vector-type
+ | tensor-type
+ | memref-type
+ | dialect-type
+ | type-alias
+
type-list-no-parens ::= type (`,` type)*
type-list-parens ::= `(` `)`
| `(` type-list-no-parens `)`
@@ -559,7 +559,11 @@ Builtin types consist of only the types needed for the validity of the IR.
Syntax:
``` {.ebnf}
-function-type ::= type-list-parens `->` type-list
+// MLIR doesn't have a tuple type but functions can return multiple values.
+function-result-type ::= type-list-parens
+ | non-function-type
+
+function-type ::= type-list-parens `->` function-result-type
```
MLIR supports first-class functions: the
@@ -897,7 +901,7 @@ associated attributes according to the following grammar:
``` {.ebnf}
function ::= `func` function-signature function-attributes? function-body?
-function-signature ::= function-id `(` argument-list `)` (`->` type-list)?
+function-signature ::= function-id `(` argument-list `)` (`->` function-result-type)?
argument-list ::= named-argument (`,` named-argument)* | /*empty*/
argument-list ::= type (`,` type)* | /*empty*/ named-argument ::= ssa-id `:`
type
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 127f4952d8f..e08b1409f68 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -746,7 +746,7 @@ void ModulePrinter::printType(Type type) {
interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
os << ") -> ";
auto results = func.getResults();
- if (results.size() == 1)
+ if (results.size() == 1 && !results[0].isa<FunctionType>())
os << results[0];
else {
os << '(';
@@ -1313,10 +1313,17 @@ void FunctionPrinter::printFunctionSignature() {
switch (fnType.getResults().size()) {
case 0:
break;
- case 1:
+ case 1: {
os << " -> ";
- printType(fnType.getResults()[0]);
+ auto resultType = fnType.getResults()[0];
+ bool resultIsFunc = resultType.isa<FunctionType>();
+ if (resultIsFunc)
+ os << '(';
+ printType(resultType);
+ if (resultIsFunc)
+ os << ')';
break;
+ }
default:
os << " -> (";
interleaveComma(fnType.getResults(),
@@ -1482,7 +1489,8 @@ void FunctionPrinter::printGenericOp(const Instruction *op) {
[&](const Value *value) { printType(value->getType()); });
os << ") -> ";
- if (op->getNumResults() == 1) {
+ if (op->getNumResults() == 1 &&
+ !op->getResult(0)->getType().isa<FunctionType>()) {
printType(op->getResult(0)->getType());
} else {
os << '(';
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index b7e4fb147cb..bd9c5e9475d 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -187,9 +187,11 @@ public:
Type parseTensorType();
Type parseMemRefType();
Type parseFunctionType();
+ Type parseNonFunctionType();
Type parseType();
ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
- ParseResult parseTypeList(SmallVectorImpl<Type> &elements);
+ ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
+ ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
// Attribute parsing.
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
@@ -308,32 +310,29 @@ ParseResult Parser::parseCommaSeparatedListUntil(
// Type Parsing
//===----------------------------------------------------------------------===//
-/// Parse an arbitrary type.
+/// Parse any type except the function type.
///
-/// type ::= integer-type
-/// | index-type
-/// | float-type
-/// | extended-type
-/// | vector-type
-/// | tensor-type
-/// | memref-type
-/// | function-type
+/// non-function-type ::= integer-type
+/// | index-type
+/// | float-type
+/// | extended-type
+/// | vector-type
+/// | tensor-type
+/// | memref-type
///
/// index-type ::= `index`
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
///
-Type Parser::parseType() {
+Type Parser::parseNonFunctionType() {
switch (getToken().getKind()) {
default:
- return (emitError("expected type"), nullptr);
+ return (emitError("expected non-function type"), nullptr);
case Token::kw_memref:
return parseMemRefType();
case Token::kw_tensor:
return parseTensorType();
case Token::kw_vector:
return parseVectorType();
- case Token::l_paren:
- return parseFunctionType();
// integer-type
case Token::inttype: {
auto width = getToken().getIntTypeBitwidth();
@@ -369,6 +368,17 @@ Type Parser::parseType() {
}
}
+/// Parse an arbitrary type.
+///
+/// type ::= function-type
+/// | non-function-type
+///
+Type Parser::parseType() {
+ if (getToken().is(Token::l_paren))
+ return parseFunctionType();
+ return parseNonFunctionType();
+}
+
/// Parse a vector type.
///
/// vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
@@ -640,9 +650,9 @@ Type Parser::parseFunctionType() {
assert(getToken().is(Token::l_paren));
SmallVector<Type, 4> arguments, results;
- if (parseTypeList(arguments) ||
+ if (parseTypeListParens(arguments) ||
parseToken(Token::arrow, "expected '->' in function type") ||
- parseTypeList(results))
+ parseFunctionResultTypes(results))
return nullptr;
return builder.getFunctionType(arguments, results);
@@ -663,27 +673,38 @@ ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
return parseCommaSeparatedList(parseElt);
}
-/// Parse a "type list", which is a singular type, or a parenthesized list of
-/// types.
+/// Parse a parenthesized list of types.
///
-/// type-list ::= type-list-parens | type
/// type-list-parens ::= `(` `)`
/// | `(` type-list-no-parens `)`
///
-ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) {
- auto parseElt = [&]() -> ParseResult {
- auto elt = parseType();
- elements.push_back(elt);
- return elt ? ParseSuccess : ParseFailure;
- };
+ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
+ if (parseToken(Token::l_paren, "expected '('"))
+ return ParseFailure;
- // If there is no parens, then it must be a singular type.
- if (!consumeIf(Token::l_paren))
- return parseElt();
+ // Handle empty lists.
+ if (getToken().is(Token::r_paren))
+ return consumeToken(), ParseSuccess;
- if (parseCommaSeparatedListUntil(Token::r_paren, parseElt))
+ if (parseTypeListNoParens(elements) ||
+ parseToken(Token::r_paren, "expected ')'"))
return ParseFailure;
+ return ParseSuccess;
+}
+
+/// Parse a function result type.
+///
+/// function-result-type ::= type-list-parens
+/// | non-function-type
+///
+ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
+ if (getToken().is(Token::l_paren))
+ return parseTypeListParens(elements);
+ Type t = parseNonFunctionType();
+ if (!t)
+ return ParseFailure;
+ elements.push_back(t);
return ParseSuccess;
}
@@ -3489,7 +3510,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
// Parse the return type if present.
SmallVector<Type, 4> results;
if (consumeIf(Token::arrow)) {
- if (parseTypeList(results))
+ if (parseFunctionResultTypes(results))
return ParseFailure;
}
type = builder.getFunctionType(argTypes, results);
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index ff5b2563670..cb6d92d9670 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -1,8 +1,8 @@
// RUN: mlir-opt %s | FileCheck %s
// Verify the printed output can be parsed.
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// TODO(b/123888077): The following fails due to constant with function pointer.
-// Disabled: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
// CHECK: #map0 = (d0) -> (d0 + 1)
diff --git a/mlir/test/IR/invalid-locations.mlir b/mlir/test/IR/invalid-locations.mlir
index 327ccdded04..7d4e19349c6 100644
--- a/mlir/test/IR/invalid-locations.mlir
+++ b/mlir/test/IR/invalid-locations.mlir
@@ -67,7 +67,7 @@ func @location_fused_missing_greater() {
func @location_fused_missing_metadata() {
^bb:
- // expected-error@+1 {{expected type}}
+ // expected-error@+1 {{expected non-function type}}
return loc(fused<) // expected-error {{expected valid attribute metadata}}
}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 9097d401613..9e536493b07 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -209,7 +209,7 @@ func @func_with_ops(i32, i64) {
// Comparisons must have the "predicate" attribute.
func @func_with_ops(i32, i32) {
^bb0(%a : i32, %b : i32):
- %r = cmpi %a, %b : i32 // expected-error {{expected type}}
+ %r = cmpi %a, %b : i32 // expected-error {{expected non-function type}}
}
// -----
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 1ba6c45c06b..9f53630cfb5 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -3,12 +3,12 @@
// Check different error cases.
// -----
-func @illegaltype(i) // expected-error {{expected type}}
+func @illegaltype(i) // expected-error {{expected non-function type}}
// -----
func @illegaltype() {
- %0 = constant splat<<vector 4 x f32>, 0> : vector<4 x f32> // expected-error {{expected type}}
+ %0 = constant splat<<vector 4 x f32>, 0> : vector<4 x f32> // expected-error {{expected non-function type}}
}
// -----
@@ -227,7 +227,7 @@ func @incomplete_for() {
// -----
func @nonconstant_step(%1 : i32) {
- for %2 = 1 to 5 step %1 { // expected-error {{expected type}}
+ for %2 = 1 to 5 step %1 { // expected-error {{expected non-function type}}
// -----
@@ -326,7 +326,7 @@ func @d() {return} // expected-error {{custom op 'func' is unknown}}
// -----
-func @malformed_type(%a : intt) { // expected-error {{expected type}}
+func @malformed_type(%a : intt) { // expected-error {{expected non-function type}}
}
// -----
@@ -392,7 +392,7 @@ func @condbr_badtype() {
^bb0:
%c = "foo"() : () -> i1
%a = "foo"() : () -> i32
- cond_br %c, ^bb0(%a, %a : i32, ^bb0) // expected-error {{expected type}}
+ cond_br %c, ^bb0(%a, %a : i32, ^bb0) // expected-error {{expected non-function type}}
}
// -----
@@ -506,6 +506,18 @@ func @undefined_function() {
// -----
+func @invalid_result_type() -> () -> () // expected-error {{expected a top level entity}}
+
+// -----
+
+func @func() -> (() -> ())
+func @referer() {
+ %f = constant @func : () -> () -> () // expected-error {{reference to function with mismatched type}}
+ return
+}
+
+// -----
+
#map1 = (i)[j] -> (i+j)
func @bound_symbol_mismatch(%N : index) {
@@ -538,7 +550,7 @@ func @large_bound() {
// -----
func @max_in_upper_bound(%N : index) {
- for %i = 1 to max (i)->(N, 100) { //expected-error {{expected type}}
+ for %i = 1 to max (i)->(N, 100) { //expected-error {{expected non-function type}}
}
return
}
@@ -595,7 +607,7 @@ func @invalid_if_operands3(%N : index) {
// expected-error@+1 {{expected '"' in string literal}}
"J// -----
func @calls(%arg0: i32) {
- // expected-error@+1 {{expected type}}
+ // expected-error@+1 {{expected non-function type}}
%z = "casdasda"(%x) : (ppop32) -> i32
}
// -----
@@ -767,7 +779,7 @@ func @type_alias_unknown(!unknown_alias) -> () { // expected-error {{undefined t
// -----
-!missing_type_alias = type // expected-error@+2 {{expected type}}
+!missing_type_alias = type // expected-error@+2 {{expected non-function type}}
// -----
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 417f0f13f33..4b1b21473b2 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -182,6 +182,28 @@ func @func_with_two_args(%a : f16, %b : i8) -> (i1, i32) {
return %c#0, %c#1 : i1, i32 // CHECK: return %0#0, %0#1 : i1, i32
} // CHECK: }
+// CHECK-LABEL: func @second_order_func() -> (() -> ()) {
+func @second_order_func() -> (() -> ()) {
+// CHECK-NEXT: %f = constant @emptyMLF : () -> ()
+ %c = constant @emptyMLF : () -> ()
+// CHECK-NEXT: return %f : () -> ()
+ return %c : () -> ()
+}
+
+// CHECK-LABEL: func @third_order_func() -> (() -> (() -> ())) {
+func @third_order_func() -> (() -> (() -> ())) {
+// CHECK-NEXT: %f = constant @second_order_func : () -> (() -> ())
+ %c = constant @second_order_func : () -> (() -> ())
+// CHECK-NEXT: return %f : () -> (() -> ())
+ return %c : () -> (() -> ())
+}
+
+// CHECK-LABEL: func @identity_functor(%arg0: () -> ()) -> (() -> ()) {
+func @identity_functor(%a : () -> ()) -> (() -> ()) {
+// CHECK-NEXT: return %arg0 : () -> ()
+ return %a : () -> ()
+}
+
// CHECK-LABEL: func @func_ops_in_loop() {
func @func_ops_in_loop() {
// CHECK: %0 = "foo"() : () -> i64
OpenPOWER on IntegriCloud