summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-02-28 09:30:52 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 16:50:15 -0700
commitdb1757f8586d8a6fc79936bf89fcb76c05c1262b (patch)
tree151008cc9e0b11fab40e32a16953b47a17ae35c8
parent8cc50208a66efa8de2279299ecb0b8e344b340f0 (diff)
downloadbcm5719-llvm-db1757f8586d8a6fc79936bf89fcb76c05c1262b.tar.gz
bcm5719-llvm-db1757f8586d8a6fc79936bf89fcb76c05c1262b.zip
Add support for named function argument attributes. The attribute dictionary is printed after the argument type:
func @arg_attrs(i32 {arg_attr: 10}) func @arg_attrs(%arg0: i32 {arg_attr: 10}) PiperOrigin-RevId: 236136830
-rw-r--r--mlir/g3doc/LangRef.md4
-rw-r--r--mlir/include/mlir/IR/Attributes.h1
-rw-r--r--mlir/include/mlir/IR/Function.h47
-rw-r--r--mlir/lib/Analysis/Verifier.cpp32
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp23
-rw-r--r--mlir/lib/IR/Attributes.cpp5
-rw-r--r--mlir/lib/IR/Function.cpp15
-rw-r--r--mlir/lib/Parser/Parser.cpp45
-rw-r--r--mlir/test/IR/parser.mlir8
9 files changed, 139 insertions, 41 deletions
diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md
index 07a422fdfe0..1f3a7d697e2 100644
--- a/mlir/g3doc/LangRef.md
+++ b/mlir/g3doc/LangRef.md
@@ -1101,8 +1101,8 @@ function ::= `func` function-signature function-attributes? function-body?
function-signature ::= function-id `(` argument-list `)` (`->` function-result-type)?
argument-list ::= named-argument (`,` named-argument)* | /*empty*/
-argument-list ::= type (`,` type)* | /*empty*/ named-argument ::= ssa-id `:`
-type
+argument-list ::= type attribute-dict? (`,` type attribute-dict?)* | /*empty*/
+named-argument ::= ssa-id `:` type attribute-dict?
function-attributes ::= `attributes` attribute-dict
function-body ::= `{` block+ `}`
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 52c5e4eb495..f51af212313 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -548,6 +548,7 @@ using NamedAttribute = std::pair<Identifier, Attribute>;
/// searches for everything.
class NamedAttributeList {
public:
+ NamedAttributeList() : attrs(nullptr) {}
NamedAttributeList(MLIRContext *context, ArrayRef<NamedAttribute> attributes);
/// Return all of the attributes on this operation.
diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
index 9431b6702c1..5ac9176550f 100644
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -178,18 +178,41 @@ public:
/// constants to names. Attributes may be dynamically added and removed over
/// the lifetime of an function.
- /// Return all of the attributes on this instruction.
+ /// Return all of the attributes on this function.
ArrayRef<NamedAttribute> getAttrs() const { return attrs.getAttrs(); }
+ /// Return all of the attributes for the argument at 'index'.
+ ArrayRef<NamedAttribute> getArgAttrs(unsigned index) const {
+ assert(index < getNumArguments() && "invalid argument number");
+ return argAttrs[index].getAttrs();
+ }
+
/// Set the attributes held by this function.
void setAttrs(ArrayRef<NamedAttribute> attributes) {
attrs.setAttrs(getContext(), attributes);
}
+ /// Set the attributes held by the argument at 'index'.
+ void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) {
+ assert(index < getNumArguments() && "invalid argument number");
+ argAttrs[index].setAttrs(getContext(), attributes);
+ }
+
/// Return the specified attribute if present, null otherwise.
Attribute getAttr(Identifier name) const { return attrs.get(name); }
Attribute getAttr(StringRef name) const { return attrs.get(name); }
+ /// Return the specified attribute, if present, for the argument at 'index',
+ /// null otherwise.
+ Attribute getArgAttr(unsigned index, Identifier name) const {
+ assert(index < getNumArguments() && "invalid argument number");
+ return argAttrs[index].get(name);
+ }
+ Attribute getArgAttr(unsigned index, StringRef name) const {
+ assert(index < getNumArguments() && "invalid argument number");
+ return argAttrs[index].get(name);
+ }
+
template <typename AttrClass> AttrClass getAttrOfType(Identifier name) const {
return getAttr(name).dyn_cast_or_null<AttrClass>();
}
@@ -198,17 +221,36 @@ public:
return getAttr(name).dyn_cast_or_null<AttrClass>();
}
+ template <typename AttrClass>
+ AttrClass getArgAttrOfType(unsigned index, Identifier name) const {
+ return getArgAttr(index, name).dyn_cast_or_null<AttrClass>();
+ }
+
+ template <typename AttrClass>
+ AttrClass getArgAttrOfType(unsigned index, StringRef name) const {
+ return getArgAttr(index, name).dyn_cast_or_null<AttrClass>();
+ }
+
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void setAttr(Identifier name, Attribute value) {
attrs.set(getContext(), name, value);
}
+ void setArgAttr(unsigned index, Identifier name, Attribute value) {
+ assert(index < getNumArguments() && "invalid argument number");
+ argAttrs[index].set(getContext(), name, value);
+ }
/// Remove the attribute with the specified name if it exists. The return
/// value indicates whether the attribute was present or not.
NamedAttributeList::RemoveResult removeAttr(Identifier name) {
return attrs.remove(getContext(), name);
}
+ NamedAttributeList::RemoveResult removeArgAttr(unsigned index,
+ Identifier name) {
+ assert(index < getNumArguments() && "invalid argument number");
+ return attrs.remove(getContext(), name);
+ }
//===--------------------------------------------------------------------===//
// Other
@@ -272,6 +314,9 @@ private:
/// This holds general named attributes for the function.
NamedAttributeList attrs;
+ /// The attributes lists for each of the function arguments.
+ std::vector<NamedAttributeList> argAttrs;
+
/// The contents of the body.
BlockList blocks;
diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp
index 739aac6b892..1be4a86692a 100644
--- a/mlir/lib/Analysis/Verifier.cpp
+++ b/mlir/lib/Analysis/Verifier.cpp
@@ -38,6 +38,7 @@
#include "mlir/IR/Function.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/Module.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
@@ -125,15 +126,6 @@ bool FuncVerifier::verify() {
if (!funcNameRegex.match(fn.getName().strref()))
return failure("invalid function name '" + fn.getName().strref() + "'", fn);
- // External functions have nothing more to check.
- if (fn.isExternal())
- return false;
-
- // Verify the first block has no predecessors.
- auto *firstBB = &fn.front();
- if (!firstBB->hasNoPredecessors())
- return failure("entry block of function may not have predecessors", fn);
-
/// Verify that all of the attributes are okay.
for (auto attr : fn.getAttrs()) {
if (!attrNameRegex.match(attr.first))
@@ -143,6 +135,28 @@ bool FuncVerifier::verify() {
return true;
}
+ /// Verify that all of the argument attributes are okay.
+ for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
+ for (auto attr : fn.getArgAttrs(i)) {
+ if (!attrNameRegex.match(attr.first))
+ return failure(
+ llvm::formatv("invalid attribute name '{0}' on argument {1}",
+ attr.first.strref(), i),
+ fn);
+ if (verifyAttribute(attr.second, fn))
+ return true;
+ }
+ }
+
+ // External functions have nothing more to check.
+ if (fn.isExternal())
+ return false;
+
+ // Verify the first block has no predecessors.
+ auto *firstBB = &fn.front();
+ if (!firstBB->hasNoPredecessors())
+ return failure("entry block of function may not have predecessors", fn);
+
// Verify that the argument list of the function and the arg list of the first
// block line up.
auto fnInputTypes = fn.getType().getInputs();
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index e46c6710425..d025f224974 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1299,20 +1299,21 @@ void FunctionPrinter::printFunctionSignature() {
os << "func @" << function->getName() << '(';
auto fnType = function->getType();
+ bool isExternal = function->isExternal();
+ for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
+ if (i > 0)
+ os << ", ";
- // If this is an external function, don't print argument labels.
- if (function->isExternal()) {
- interleaveComma(fnType.getInputs(),
- [&](Type eltType) { printType(eltType); });
- } else {
- for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
- if (i > 0)
- os << ", ";
- auto *arg = function->getArgument(i);
- printOperand(arg);
+ // If this is an external function, don't print argument labels.
+ if (!isExternal) {
+ printOperand(function->getArgument(i));
os << ": ";
- printType(arg->getType());
}
+
+ printType(fnType.getInput(i));
+
+ // Print the attributes for this argument.
+ printOptionalAttrDict(function->getArgAttrs(i));
}
os << ')';
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 7168acf09a6..d64f9c49914 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -501,7 +501,10 @@ Attribute NamedAttributeList::get(StringRef name) const {
return nullptr;
}
Attribute NamedAttributeList::get(Identifier name) const {
- return get(name.strref());
+ for (auto elt : getAttrs())
+ if (elt.first == name)
+ return elt.second;
+ return nullptr;
}
/// If the an attribute exists with the specified name, change it to the new
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
index 37bf19f788e..c9a482144af 100644
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -29,7 +29,8 @@ using namespace mlir;
Function::Function(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: name(Identifier::get(name, type.getContext())), location(location),
- type(type), attrs(type.getContext(), attrs), blocks(this) {}
+ type(type), attrs(type.getContext(), attrs),
+ argAttrs(type.getNumInputs()), blocks(this) {}
Function::~Function() {
// Instructions may have cyclic references, which need to be dropped before we
@@ -167,7 +168,8 @@ Function *Function::clone(BlockAndValueMapping &mapper) const {
// If the function has a body, then the user might be deleting arguments to
// the function by specifying them in the mapper. If so, we don't add the
// argument to the input type vector.
- if (!empty()) {
+ bool isExternalFn = isExternal();
+ if (!isExternalFn) {
SmallVector<Type, 4> inputTypes;
for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
if (!mapper.contains(getArgument(i)))
@@ -175,8 +177,15 @@ Function *Function::clone(BlockAndValueMapping &mapper) const {
newType = FunctionType::get(inputTypes, type.getResults(), getContext());
}
- // Create a new function and clone the current function into it.
+ // Create the new function.
Function *newFunc = new Function(getLoc(), getName(), newType);
+
+ /// Set the argument attributes for arguments that aren't being replaced.
+ for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i)
+ if (isExternalFn || !mapper.contains(getArgument(i)))
+ newFunc->setArgAttrs(destI++, getArgAttrs(i));
+
+ /// Clone the current function into the new one and return it.
cloneInto(newFunc, mapper);
return newFunc;
}
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 372ce007b60..ba5f669f06b 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -3380,10 +3380,13 @@ private:
ParseResult parseTypeAliasDef();
// Functions.
- ParseResult parseArgumentList(SmallVectorImpl<Type> &argTypes,
- SmallVectorImpl<StringRef> &argNames);
- ParseResult parseFunctionSignature(StringRef &name, FunctionType &type,
- SmallVectorImpl<StringRef> &argNames);
+ ParseResult
+ parseArgumentList(SmallVectorImpl<Type> &argTypes,
+ SmallVectorImpl<StringRef> &argNames,
+ SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs);
+ ParseResult parseFunctionSignature(
+ StringRef &name, FunctionType &type, SmallVectorImpl<StringRef> &argNames,
+ SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs);
ParseResult parseFunc();
};
} // end anonymous namespace
@@ -3466,13 +3469,14 @@ ParseResult ModuleParser::parseTypeAliasDef() {
/// Parse a (possibly empty) list of Function arguments with types.
///
-/// named-argument ::= ssa-id `:` type
+/// named-argument ::= ssa-id `:` type attribute-dict?
/// argument-list ::= named-argument (`,` named-argument)* | /*empty*/
-/// argument-list ::= type (`,` type)* | /*empty*/
+/// argument-list ::= type attribute-dict? (`,` type attribute-dict?)*
+/// | /*empty*/
///
-ParseResult
-ModuleParser::parseArgumentList(SmallVectorImpl<Type> &argTypes,
- SmallVectorImpl<StringRef> &argNames) {
+ParseResult ModuleParser::parseArgumentList(
+ SmallVectorImpl<Type> &argTypes, SmallVectorImpl<StringRef> &argNames,
+ SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) {
consumeToken(Token::l_paren);
// The argument list either has to consistently have ssa-id's followed by
@@ -3502,6 +3506,14 @@ ModuleParser::parseArgumentList(SmallVectorImpl<Type> &argTypes,
if (!elt)
return ParseFailure;
argTypes.push_back(elt);
+
+ // Parse the attribute dict.
+ SmallVector<NamedAttribute, 2> attrs;
+ if (getToken().is(Token::l_brace)) {
+ if (parseAttributeDict(attrs))
+ return ParseFailure;
+ }
+ argAttrs.push_back(attrs);
return ParseSuccess;
};
@@ -3514,9 +3526,9 @@ ModuleParser::parseArgumentList(SmallVectorImpl<Type> &argTypes,
/// function-signature ::=
/// function-id `(` argument-list `)` (`->` type-list)?
///
-ParseResult
-ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
- SmallVectorImpl<StringRef> &argNames) {
+ParseResult ModuleParser::parseFunctionSignature(
+ StringRef &name, FunctionType &type, SmallVectorImpl<StringRef> &argNames,
+ SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) {
if (getToken().isNot(Token::at_identifier))
return emitError("expected a function identifier like '@foo'");
@@ -3527,7 +3539,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
return emitError("expected '(' in function signature");
SmallVector<Type, 4> argTypes;
- if (parseArgumentList(argTypes, argNames))
+ if (parseArgumentList(argTypes, argNames, argAttrs))
return ParseFailure;
// Parse the return type if present.
@@ -3553,9 +3565,10 @@ ParseResult ModuleParser::parseFunc() {
StringRef name;
FunctionType type;
SmallVector<StringRef, 4> argNames;
+ SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
auto loc = getToken().getLoc();
- if (parseFunctionSignature(name, type, argNames))
+ if (parseFunctionSignature(name, type, argNames, argAttrs))
return ParseFailure;
// If function attributes are present, parse them.
@@ -3579,6 +3592,10 @@ ParseResult ModuleParser::parseFunc() {
if (parseOptionalTrailingLocation(function))
return ParseFailure;
+ // Add the attributes to the function arguments.
+ for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i)
+ function->setArgAttrs(i, argAttrs[i]);
+
// External functions have no body.
if (getToken().isNot(Token::l_brace))
return ParseSuccess;
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 894b6746c56..575eb6194ce 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -812,3 +812,11 @@ func @internal_attrs()
// CHECK-LABEL: func @_valid.function$name
func @_valid.function$name()
+
+// CHECK-LABEL: func @external_func_arg_attrs(i32, i1 {arg.attr: 10}, i32)
+func @external_func_arg_attrs(i32, i1 {arg.attr: 10}, i32)
+
+// CHECK-LABEL: func @func_arg_attrs(%arg0: i1 {arg.attr: 10})
+func @func_arg_attrs(%arg0: i1 {arg.attr: 10}) {
+ return
+}
OpenPOWER on IntegriCloud