diff options
| author | River Riddle <riverriddle@google.com> | 2019-02-28 09:30:52 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 16:50:15 -0700 |
| commit | db1757f8586d8a6fc79936bf89fcb76c05c1262b (patch) | |
| tree | 151008cc9e0b11fab40e32a16953b47a17ae35c8 | |
| parent | 8cc50208a66efa8de2279299ecb0b8e344b340f0 (diff) | |
| download | bcm5719-llvm-db1757f8586d8a6fc79936bf89fcb76c05c1262b.tar.gz bcm5719-llvm-db1757f8586d8a6fc79936bf89fcb76c05c1262b.zip | |
Add support for named function argument attributes. The attribute dictionary is printed after the argument type:
func @arg_attrs(i32 {arg_attr: 10})
func @arg_attrs(%arg0: i32 {arg_attr: 10})
PiperOrigin-RevId: 236136830
| -rw-r--r-- | mlir/g3doc/LangRef.md | 4 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/Attributes.h | 1 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/Function.h | 47 | ||||
| -rw-r--r-- | mlir/lib/Analysis/Verifier.cpp | 32 | ||||
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 23 | ||||
| -rw-r--r-- | mlir/lib/IR/Attributes.cpp | 5 | ||||
| -rw-r--r-- | mlir/lib/IR/Function.cpp | 15 | ||||
| -rw-r--r-- | mlir/lib/Parser/Parser.cpp | 45 | ||||
| -rw-r--r-- | mlir/test/IR/parser.mlir | 8 |
9 files changed, 139 insertions, 41 deletions
diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md index 07a422fdfe0..1f3a7d697e2 100644 --- a/mlir/g3doc/LangRef.md +++ b/mlir/g3doc/LangRef.md @@ -1101,8 +1101,8 @@ function ::= `func` function-signature function-attributes? function-body? function-signature ::= function-id `(` argument-list `)` (`->` function-result-type)? argument-list ::= named-argument (`,` named-argument)* | /*empty*/ -argument-list ::= type (`,` type)* | /*empty*/ named-argument ::= ssa-id `:` -type +argument-list ::= type attribute-dict? (`,` type attribute-dict?)* | /*empty*/ +named-argument ::= ssa-id `:` type attribute-dict? function-attributes ::= `attributes` attribute-dict function-body ::= `{` block+ `}` diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 52c5e4eb495..f51af212313 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -548,6 +548,7 @@ using NamedAttribute = std::pair<Identifier, Attribute>; /// searches for everything. class NamedAttributeList { public: + NamedAttributeList() : attrs(nullptr) {} NamedAttributeList(MLIRContext *context, ArrayRef<NamedAttribute> attributes); /// Return all of the attributes on this operation. diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 9431b6702c1..5ac9176550f 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -178,18 +178,41 @@ public: /// constants to names. Attributes may be dynamically added and removed over /// the lifetime of an function. - /// Return all of the attributes on this instruction. + /// Return all of the attributes on this function. ArrayRef<NamedAttribute> getAttrs() const { return attrs.getAttrs(); } + /// Return all of the attributes for the argument at 'index'. + ArrayRef<NamedAttribute> getArgAttrs(unsigned index) const { + assert(index < getNumArguments() && "invalid argument number"); + return argAttrs[index].getAttrs(); + } + /// Set the attributes held by this function. void setAttrs(ArrayRef<NamedAttribute> attributes) { attrs.setAttrs(getContext(), attributes); } + /// Set the attributes held by the argument at 'index'. + void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) { + assert(index < getNumArguments() && "invalid argument number"); + argAttrs[index].setAttrs(getContext(), attributes); + } + /// Return the specified attribute if present, null otherwise. Attribute getAttr(Identifier name) const { return attrs.get(name); } Attribute getAttr(StringRef name) const { return attrs.get(name); } + /// Return the specified attribute, if present, for the argument at 'index', + /// null otherwise. + Attribute getArgAttr(unsigned index, Identifier name) const { + assert(index < getNumArguments() && "invalid argument number"); + return argAttrs[index].get(name); + } + Attribute getArgAttr(unsigned index, StringRef name) const { + assert(index < getNumArguments() && "invalid argument number"); + return argAttrs[index].get(name); + } + template <typename AttrClass> AttrClass getAttrOfType(Identifier name) const { return getAttr(name).dyn_cast_or_null<AttrClass>(); } @@ -198,17 +221,36 @@ public: return getAttr(name).dyn_cast_or_null<AttrClass>(); } + template <typename AttrClass> + AttrClass getArgAttrOfType(unsigned index, Identifier name) const { + return getArgAttr(index, name).dyn_cast_or_null<AttrClass>(); + } + + template <typename AttrClass> + AttrClass getArgAttrOfType(unsigned index, StringRef name) const { + return getArgAttr(index, name).dyn_cast_or_null<AttrClass>(); + } + /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. void setAttr(Identifier name, Attribute value) { attrs.set(getContext(), name, value); } + void setArgAttr(unsigned index, Identifier name, Attribute value) { + assert(index < getNumArguments() && "invalid argument number"); + argAttrs[index].set(getContext(), name, value); + } /// Remove the attribute with the specified name if it exists. The return /// value indicates whether the attribute was present or not. NamedAttributeList::RemoveResult removeAttr(Identifier name) { return attrs.remove(getContext(), name); } + NamedAttributeList::RemoveResult removeArgAttr(unsigned index, + Identifier name) { + assert(index < getNumArguments() && "invalid argument number"); + return attrs.remove(getContext(), name); + } //===--------------------------------------------------------------------===// // Other @@ -272,6 +314,9 @@ private: /// This holds general named attributes for the function. NamedAttributeList attrs; + /// The attributes lists for each of the function arguments. + std::vector<NamedAttributeList> argAttrs; + /// The contents of the body. BlockList blocks; diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 739aac6b892..1be4a86692a 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -38,6 +38,7 @@ #include "mlir/IR/Function.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/Module.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/Regex.h" #include "llvm/Support/raw_ostream.h" @@ -125,15 +126,6 @@ bool FuncVerifier::verify() { if (!funcNameRegex.match(fn.getName().strref())) return failure("invalid function name '" + fn.getName().strref() + "'", fn); - // External functions have nothing more to check. - if (fn.isExternal()) - return false; - - // Verify the first block has no predecessors. - auto *firstBB = &fn.front(); - if (!firstBB->hasNoPredecessors()) - return failure("entry block of function may not have predecessors", fn); - /// Verify that all of the attributes are okay. for (auto attr : fn.getAttrs()) { if (!attrNameRegex.match(attr.first)) @@ -143,6 +135,28 @@ bool FuncVerifier::verify() { return true; } + /// Verify that all of the argument attributes are okay. + for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) { + for (auto attr : fn.getArgAttrs(i)) { + if (!attrNameRegex.match(attr.first)) + return failure( + llvm::formatv("invalid attribute name '{0}' on argument {1}", + attr.first.strref(), i), + fn); + if (verifyAttribute(attr.second, fn)) + return true; + } + } + + // External functions have nothing more to check. + if (fn.isExternal()) + return false; + + // Verify the first block has no predecessors. + auto *firstBB = &fn.front(); + if (!firstBB->hasNoPredecessors()) + return failure("entry block of function may not have predecessors", fn); + // Verify that the argument list of the function and the arg list of the first // block line up. auto fnInputTypes = fn.getType().getInputs(); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index e46c6710425..d025f224974 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1299,20 +1299,21 @@ void FunctionPrinter::printFunctionSignature() { os << "func @" << function->getName() << '('; auto fnType = function->getType(); + bool isExternal = function->isExternal(); + for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) { + if (i > 0) + os << ", "; - // If this is an external function, don't print argument labels. - if (function->isExternal()) { - interleaveComma(fnType.getInputs(), - [&](Type eltType) { printType(eltType); }); - } else { - for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) { - if (i > 0) - os << ", "; - auto *arg = function->getArgument(i); - printOperand(arg); + // If this is an external function, don't print argument labels. + if (!isExternal) { + printOperand(function->getArgument(i)); os << ": "; - printType(arg->getType()); } + + printType(fnType.getInput(i)); + + // Print the attributes for this argument. + printOptionalAttrDict(function->getArgAttrs(i)); } os << ')'; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 7168acf09a6..d64f9c49914 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -501,7 +501,10 @@ Attribute NamedAttributeList::get(StringRef name) const { return nullptr; } Attribute NamedAttributeList::get(Identifier name) const { - return get(name.strref()); + for (auto elt : getAttrs()) + if (elt.first == name) + return elt.second; + return nullptr; } /// If the an attribute exists with the specified name, change it to the new diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 37bf19f788e..c9a482144af 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -29,7 +29,8 @@ using namespace mlir; Function::Function(Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs) : name(Identifier::get(name, type.getContext())), location(location), - type(type), attrs(type.getContext(), attrs), blocks(this) {} + type(type), attrs(type.getContext(), attrs), + argAttrs(type.getNumInputs()), blocks(this) {} Function::~Function() { // Instructions may have cyclic references, which need to be dropped before we @@ -167,7 +168,8 @@ Function *Function::clone(BlockAndValueMapping &mapper) const { // If the function has a body, then the user might be deleting arguments to // the function by specifying them in the mapper. If so, we don't add the // argument to the input type vector. - if (!empty()) { + bool isExternalFn = isExternal(); + if (!isExternalFn) { SmallVector<Type, 4> inputTypes; for (unsigned i = 0, e = getNumArguments(); i != e; ++i) if (!mapper.contains(getArgument(i))) @@ -175,8 +177,15 @@ Function *Function::clone(BlockAndValueMapping &mapper) const { newType = FunctionType::get(inputTypes, type.getResults(), getContext()); } - // Create a new function and clone the current function into it. + // Create the new function. Function *newFunc = new Function(getLoc(), getName(), newType); + + /// Set the argument attributes for arguments that aren't being replaced. + for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i) + if (isExternalFn || !mapper.contains(getArgument(i))) + newFunc->setArgAttrs(destI++, getArgAttrs(i)); + + /// Clone the current function into the new one and return it. cloneInto(newFunc, mapper); return newFunc; } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 372ce007b60..ba5f669f06b 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3380,10 +3380,13 @@ private: ParseResult parseTypeAliasDef(); // Functions. - ParseResult parseArgumentList(SmallVectorImpl<Type> &argTypes, - SmallVectorImpl<StringRef> &argNames); - ParseResult parseFunctionSignature(StringRef &name, FunctionType &type, - SmallVectorImpl<StringRef> &argNames); + ParseResult + parseArgumentList(SmallVectorImpl<Type> &argTypes, + SmallVectorImpl<StringRef> &argNames, + SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs); + ParseResult parseFunctionSignature( + StringRef &name, FunctionType &type, SmallVectorImpl<StringRef> &argNames, + SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs); ParseResult parseFunc(); }; } // end anonymous namespace @@ -3466,13 +3469,14 @@ ParseResult ModuleParser::parseTypeAliasDef() { /// Parse a (possibly empty) list of Function arguments with types. /// -/// named-argument ::= ssa-id `:` type +/// named-argument ::= ssa-id `:` type attribute-dict? /// argument-list ::= named-argument (`,` named-argument)* | /*empty*/ -/// argument-list ::= type (`,` type)* | /*empty*/ +/// argument-list ::= type attribute-dict? (`,` type attribute-dict?)* +/// | /*empty*/ /// -ParseResult -ModuleParser::parseArgumentList(SmallVectorImpl<Type> &argTypes, - SmallVectorImpl<StringRef> &argNames) { +ParseResult ModuleParser::parseArgumentList( + SmallVectorImpl<Type> &argTypes, SmallVectorImpl<StringRef> &argNames, + SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) { consumeToken(Token::l_paren); // The argument list either has to consistently have ssa-id's followed by @@ -3502,6 +3506,14 @@ ModuleParser::parseArgumentList(SmallVectorImpl<Type> &argTypes, if (!elt) return ParseFailure; argTypes.push_back(elt); + + // Parse the attribute dict. + SmallVector<NamedAttribute, 2> attrs; + if (getToken().is(Token::l_brace)) { + if (parseAttributeDict(attrs)) + return ParseFailure; + } + argAttrs.push_back(attrs); return ParseSuccess; }; @@ -3514,9 +3526,9 @@ ModuleParser::parseArgumentList(SmallVectorImpl<Type> &argTypes, /// function-signature ::= /// function-id `(` argument-list `)` (`->` type-list)? /// -ParseResult -ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type, - SmallVectorImpl<StringRef> &argNames) { +ParseResult ModuleParser::parseFunctionSignature( + StringRef &name, FunctionType &type, SmallVectorImpl<StringRef> &argNames, + SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) { if (getToken().isNot(Token::at_identifier)) return emitError("expected a function identifier like '@foo'"); @@ -3527,7 +3539,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type, return emitError("expected '(' in function signature"); SmallVector<Type, 4> argTypes; - if (parseArgumentList(argTypes, argNames)) + if (parseArgumentList(argTypes, argNames, argAttrs)) return ParseFailure; // Parse the return type if present. @@ -3553,9 +3565,10 @@ ParseResult ModuleParser::parseFunc() { StringRef name; FunctionType type; SmallVector<StringRef, 4> argNames; + SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs; auto loc = getToken().getLoc(); - if (parseFunctionSignature(name, type, argNames)) + if (parseFunctionSignature(name, type, argNames, argAttrs)) return ParseFailure; // If function attributes are present, parse them. @@ -3579,6 +3592,10 @@ ParseResult ModuleParser::parseFunc() { if (parseOptionalTrailingLocation(function)) return ParseFailure; + // Add the attributes to the function arguments. + for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) + function->setArgAttrs(i, argAttrs[i]); + // External functions have no body. if (getToken().isNot(Token::l_brace)) return ParseSuccess; diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 894b6746c56..575eb6194ce 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -812,3 +812,11 @@ func @internal_attrs() // CHECK-LABEL: func @_valid.function$name func @_valid.function$name() + +// CHECK-LABEL: func @external_func_arg_attrs(i32, i1 {arg.attr: 10}, i32) +func @external_func_arg_attrs(i32, i1 {arg.attr: 10}, i32) + +// CHECK-LABEL: func @func_arg_attrs(%arg0: i1 {arg.attr: 10}) +func @func_arg_attrs(%arg0: i1 {arg.attr: 10}) { + return +} |

