summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/IR/op_base.td29
-rw-r--r--mlir/include/mlir/TableGen/Attribute.h3
-rw-r--r--mlir/lib/TableGen/Attribute.cpp4
-rw-r--r--mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp26
4 files changed, 51 insertions, 11 deletions
diff --git a/mlir/include/mlir/IR/op_base.td b/mlir/include/mlir/IR/op_base.td
index 02ddcd37a6d..353fa33aaca 100644
--- a/mlir/include/mlir/IR/op_base.td
+++ b/mlir/include/mlir/IR/op_base.td
@@ -160,6 +160,7 @@ class F<int width>
}
def F32 : F<32>;
+def F64 : F<64>;
// A container type is a type that has another type embedded within it.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
@@ -267,8 +268,14 @@ class Attr<Pred condition, string descr = ""> :
// Default value for attribute.
// Requires a constBuilderCall defined.
string defaultValue = ?;
+
+ // Whether the attribute is optional. Typically requires a custom
+ // convertFromStorage method to handle the case where the attribute is
+ // not present.
+ bit isOptional = 0b0;
}
+// Decorates an attribute to have an (unvalidated) default value if not present.
class DefaultValuedAttr<Attr attr, string val> :
Attr<attr.predicate, attr.description> {
// Construct this attribute with the input attribute and change only
@@ -281,6 +288,19 @@ class DefaultValuedAttr<Attr attr, string val> :
let defaultValue = val;
}
+// Decorates an attribute as optional. The return type of the generated
+// attribute accessor method will be Optional<>.
+class OptionalAttr<Attr attr> :
+ Attr<attr.predicate, attr.description> {
+ // Rewrite the attribute to be optional.
+ // Note: this has to be kept up to date with Attr above.
+ let storageType = attr.storageType;
+ let returnType = "Optional<" # attr.returnType #">";
+ let convertFromStorage = "{0} ? " # returnType # "({0}.getValue())" #
+ " : (llvm::None)";
+ let isOptional = 0b1;
+}
+
// A generic attribute that must be constructed around a specific type.
// Backed by a C++ class "attrName".
class TypeBasedAttr<BuildableType t, string attrName, string descr> :
@@ -299,7 +319,9 @@ class StringBasedAttr<string descr> : Attr<CPred<"true">, descr> {
// Base class for instantiating float attributes of fixed width.
class FloatAttrBase<BuildableType t, string descr> :
- TypeBasedAttr<t, "FloatAttr", descr>;
+ TypeBasedAttr<t, "FloatAttr", descr> {
+ let returnType = [{ APFloat }];
+}
// Base class for instantiating integer attributes of fixed width.
class IntegerAttrBase<BuildableType t, string descr> :
@@ -322,9 +344,8 @@ class ElementsAttrBase<Pred condition, string description> :
let convertFromStorage = "{0}";
}
def ElementsAttr: ElementsAttrBase<CPred<"true">, "constant vector/tensor">;
-def F32Attr : FloatAttrBase<F32, "32-bit float"> {
- let returnType = [{ APFloat }];
-}
+def F32Attr : FloatAttrBase<F32, "32-bit float">;
+def F64Attr : FloatAttrBase<F64, "64-bit float">;
def I32Attr : IntegerAttrBase<I32, "32-bit integer"> {
let storageType = [{ IntegerAttr }];
let returnType = [{ int }];
diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index e601fdf22ea..324194dbf64 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -99,6 +99,9 @@ public:
// Returns whether this attribute has a default value.
bool hasDefaultValue() const;
+ // Returns whether this attribute is optional.
+ bool isOptional() const;
+
// Returns the template that can be used to produce the default value of
// the attribute.
// Syntax: {0} should be replaced with a builder.
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index 2b8cda031ef..071cd61a4f4 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -118,6 +118,10 @@ bool tblgen::Attribute::hasDefaultValue() const {
return !getValueAsString(init).empty();
}
+bool tblgen::Attribute::isOptional() const {
+ return def->getValueAsBit("isOptional");
+}
+
std::string tblgen::Attribute::getDefaultValueTemplate() const {
assert(isConstBuildable() && "requiers constBuilderCall");
const auto *init = def->getValueInit("defaultValue");
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 167d44e2304..7606df0faaf 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -203,8 +203,9 @@ void OpEmitter::emitAttrGetters() {
OUT(2) << attr.getReturnType() << ' ' << getter << "() const {\n";
// Return the queried attribute with the correct return type.
- std::string attrVal = formatv("this->getAttr(\"{1}\").dyn_cast<{0}>()",
- attr.getStorageType(), name);
+ std::string attrVal =
+ formatv("this->getAttr(\"{1}\").dyn_cast_or_null<{0}>()",
+ attr.getStorageType(), name);
OUT(4) << "auto attr = " << attrVal << ";\n";
if (attr.hasDefaultValue()) {
// Returns the default value if not set.
@@ -265,7 +266,8 @@ void OpEmitter::emitStandaloneParamBuilder(bool isAllSameType) {
const auto &attr = namedAttr.attr;
if (attr.isDerivedAttr())
break;
- os << ", " << attr.getStorageType() << ' ' << namedAttr.getName();
+ os << ", /*optional*/" << attr.getStorageType() << ' '
+ << namedAttr.getName();
}
os << ") {\n";
@@ -315,10 +317,19 @@ void OpEmitter::emitStandaloneParamBuilder(bool isAllSameType) {
}
// Push all attributes to the result
- for (const auto &namedAttr : op.getAttributes())
- if (!namedAttr.attr.isDerivedAttr())
+ for (const auto &namedAttr : op.getAttributes()) {
+ if (!namedAttr.attr.isDerivedAttr()) {
+ bool emitNotNullCheck = namedAttr.attr.isOptional();
+ if (emitNotNullCheck) {
+ OUT(4) << formatv("if ({0}) {\n", namedAttr.getName());
+ }
OUT(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
namedAttr.getName());
+ if (emitNotNullCheck) {
+ OUT(4) << formatv("}\n");
+ }
+ }
+ }
OUT(2) << "}\n";
}
@@ -462,7 +473,8 @@ void OpEmitter::emitVerifier() {
continue;
}
- if (attr.hasDefaultValue()) {
+ bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
+ if (allowMissingAttr) {
// If the attribute has a default value, then only verify the predicate if
// set. This does effectively assume that the default value is valid.
// TODO: verify the debug value is valid (perhaps in debug mode only).
@@ -482,7 +494,7 @@ void OpEmitter::emitVerifier() {
name, attr.getTableGenDefName());
}
- if (attr.hasDefaultValue())
+ if (allowMissingAttr)
OUT(4) << "}\n";
}
OpenPOWER on IntegriCloud