diff options
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/IR/op_base.td | 29 | ||||
| -rw-r--r-- | mlir/include/mlir/TableGen/Attribute.h | 3 | ||||
| -rw-r--r-- | mlir/lib/TableGen/Attribute.cpp | 4 | ||||
| -rw-r--r-- | mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 26 |
4 files changed, 51 insertions, 11 deletions
diff --git a/mlir/include/mlir/IR/op_base.td b/mlir/include/mlir/IR/op_base.td index 02ddcd37a6d..353fa33aaca 100644 --- a/mlir/include/mlir/IR/op_base.td +++ b/mlir/include/mlir/IR/op_base.td @@ -160,6 +160,7 @@ class F<int width> } def F32 : F<32>; +def F64 : F<64>; // A container type is a type that has another type embedded within it. class ContainerType<Type etype, Pred containerPred, code elementTypeCall, @@ -267,8 +268,14 @@ class Attr<Pred condition, string descr = ""> : // Default value for attribute. // Requires a constBuilderCall defined. string defaultValue = ?; + + // Whether the attribute is optional. Typically requires a custom + // convertFromStorage method to handle the case where the attribute is + // not present. + bit isOptional = 0b0; } +// Decorates an attribute to have an (unvalidated) default value if not present. class DefaultValuedAttr<Attr attr, string val> : Attr<attr.predicate, attr.description> { // Construct this attribute with the input attribute and change only @@ -281,6 +288,19 @@ class DefaultValuedAttr<Attr attr, string val> : let defaultValue = val; } +// Decorates an attribute as optional. The return type of the generated +// attribute accessor method will be Optional<>. +class OptionalAttr<Attr attr> : + Attr<attr.predicate, attr.description> { + // Rewrite the attribute to be optional. + // Note: this has to be kept up to date with Attr above. + let storageType = attr.storageType; + let returnType = "Optional<" # attr.returnType #">"; + let convertFromStorage = "{0} ? " # returnType # "({0}.getValue())" # + " : (llvm::None)"; + let isOptional = 0b1; +} + // A generic attribute that must be constructed around a specific type. // Backed by a C++ class "attrName". class TypeBasedAttr<BuildableType t, string attrName, string descr> : @@ -299,7 +319,9 @@ class StringBasedAttr<string descr> : Attr<CPred<"true">, descr> { // Base class for instantiating float attributes of fixed width. class FloatAttrBase<BuildableType t, string descr> : - TypeBasedAttr<t, "FloatAttr", descr>; + TypeBasedAttr<t, "FloatAttr", descr> { + let returnType = [{ APFloat }]; +} // Base class for instantiating integer attributes of fixed width. class IntegerAttrBase<BuildableType t, string descr> : @@ -322,9 +344,8 @@ class ElementsAttrBase<Pred condition, string description> : let convertFromStorage = "{0}"; } def ElementsAttr: ElementsAttrBase<CPred<"true">, "constant vector/tensor">; -def F32Attr : FloatAttrBase<F32, "32-bit float"> { - let returnType = [{ APFloat }]; -} +def F32Attr : FloatAttrBase<F32, "32-bit float">; +def F64Attr : FloatAttrBase<F64, "64-bit float">; def I32Attr : IntegerAttrBase<I32, "32-bit integer"> { let storageType = [{ IntegerAttr }]; let returnType = [{ int }]; diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index e601fdf22ea..324194dbf64 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -99,6 +99,9 @@ public: // Returns whether this attribute has a default value. bool hasDefaultValue() const; + // Returns whether this attribute is optional. + bool isOptional() const; + // Returns the template that can be used to produce the default value of // the attribute. // Syntax: {0} should be replaced with a builder. diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 2b8cda031ef..071cd61a4f4 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -118,6 +118,10 @@ bool tblgen::Attribute::hasDefaultValue() const { return !getValueAsString(init).empty(); } +bool tblgen::Attribute::isOptional() const { + return def->getValueAsBit("isOptional"); +} + std::string tblgen::Attribute::getDefaultValueTemplate() const { assert(isConstBuildable() && "requiers constBuilderCall"); const auto *init = def->getValueInit("defaultValue"); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 167d44e2304..7606df0faaf 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -203,8 +203,9 @@ void OpEmitter::emitAttrGetters() { OUT(2) << attr.getReturnType() << ' ' << getter << "() const {\n"; // Return the queried attribute with the correct return type. - std::string attrVal = formatv("this->getAttr(\"{1}\").dyn_cast<{0}>()", - attr.getStorageType(), name); + std::string attrVal = + formatv("this->getAttr(\"{1}\").dyn_cast_or_null<{0}>()", + attr.getStorageType(), name); OUT(4) << "auto attr = " << attrVal << ";\n"; if (attr.hasDefaultValue()) { // Returns the default value if not set. @@ -265,7 +266,8 @@ void OpEmitter::emitStandaloneParamBuilder(bool isAllSameType) { const auto &attr = namedAttr.attr; if (attr.isDerivedAttr()) break; - os << ", " << attr.getStorageType() << ' ' << namedAttr.getName(); + os << ", /*optional*/" << attr.getStorageType() << ' ' + << namedAttr.getName(); } os << ") {\n"; @@ -315,10 +317,19 @@ void OpEmitter::emitStandaloneParamBuilder(bool isAllSameType) { } // Push all attributes to the result - for (const auto &namedAttr : op.getAttributes()) - if (!namedAttr.attr.isDerivedAttr()) + for (const auto &namedAttr : op.getAttributes()) { + if (!namedAttr.attr.isDerivedAttr()) { + bool emitNotNullCheck = namedAttr.attr.isOptional(); + if (emitNotNullCheck) { + OUT(4) << formatv("if ({0}) {\n", namedAttr.getName()); + } OUT(4) << formatv("result->addAttribute(\"{0}\", {0});\n", namedAttr.getName()); + if (emitNotNullCheck) { + OUT(4) << formatv("}\n"); + } + } + } OUT(2) << "}\n"; } @@ -462,7 +473,8 @@ void OpEmitter::emitVerifier() { continue; } - if (attr.hasDefaultValue()) { + bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional(); + if (allowMissingAttr) { // If the attribute has a default value, then only verify the predicate if // set. This does effectively assume that the default value is valid. // TODO: verify the debug value is valid (perhaps in debug mode only). @@ -482,7 +494,7 @@ void OpEmitter::emitVerifier() { name, attr.getTableGenDefName()); } - if (attr.hasDefaultValue()) + if (allowMissingAttr) OUT(4) << "}\n"; } |

