diff options
| author | Lei Zhang <antiagainst@google.com> | 2019-04-04 09:25:38 -0700 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2019-04-05 07:40:50 -0700 |
| commit | 76cb20532636eb52300970d2b3eb00101b3821a1 (patch) | |
| tree | 209584d042e4857f3f0cdca81c84a62c0232685e /mlir | |
| parent | c7790df2ed9bdcde12683aee6cb89a2668b56661 (diff) | |
| download | bcm5719-llvm-76cb20532636eb52300970d2b3eb00101b3821a1.tar.gz bcm5719-llvm-76cb20532636eb52300970d2b3eb00101b3821a1.zip | |
[TableGen] Enforce constraints on attributes
Previously, attribute constraints are basically unused: we set true for almost
anything. This CL refactors common attribute kinds and sets constraints on
them properly. And fixed verification failures found by this change.
A noticeable one is that certain TF ops' attributes are required to be 64-bit
integer, but the corresponding TFLite ops expect 32-bit integer attributes.
Added bitwidth converters to handle this difference.
--
PiperOrigin-RevId: 241944008
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/IR/OpBase.td | 69 | ||||
| -rw-r--r-- | mlir/include/mlir/LLVMIR/LLVMOps.td | 2 | ||||
| -rw-r--r-- | mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp | 2 | ||||
| -rw-r--r-- | mlir/test/Quantization/convert-fakequant-invalid.mlir | 6 | ||||
| -rw-r--r-- | mlir/test/Quantization/convert-fakequant.mlir | 14 | ||||
| -rw-r--r-- | mlir/test/mlir-tblgen/attr-enum.td | 8 | ||||
| -rw-r--r-- | mlir/test/mlir-tblgen/op-attribute.td | 31 | ||||
| -rw-r--r-- | mlir/test/mlir-tblgen/pattern-attr.td | 41 | ||||
| -rw-r--r-- | mlir/test/mlir-tblgen/pattern-bound-symbol.td | 3 | ||||
| -rw-r--r-- | mlir/test/mlir-tblgen/pattern-tAttr.td | 2 | ||||
| -rw-r--r-- | mlir/test/mlir-tblgen/predicate.td | 4 | ||||
| -rw-r--r-- | mlir/tools/mlir-tblgen/RewriterGen.cpp | 38 |
12 files changed, 166 insertions, 54 deletions
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 473997a50da..e149d796667 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -421,56 +421,65 @@ class OptionalAttr<Attr attr> : Attr<attr.predicate, attr.description> { let isOptional = 0b1; } -// A generic attribute that must be constructed around a specific type. -// Backed by a C++ class "attrName". -class TypeBasedAttr<BuildableType t, string attrName, string descr> : - Attr<CPred<"true">, descr> { - let constBuilderCall = - "{0}.get" # attrName # "({0}." # t.builderCall # ", {1})"; - let storageType = attrName; +// A generic attribute that must be constructed around a specific type +// `attrValType`. Backed by MLIR attribute kind `attrKind`. +class TypedAttrBase<BuildableType attrValType, string attrKind, + Pred condition, string descr> : + Attr<condition, descr> { + let constBuilderCall = "{0}.get" # attrKind # "({0}." # + attrValType.builderCall # ", {1})"; + let storageType = attrKind; } // Any attribute. -def AnyAttr : Attr<CPred<"true">, "any"> { +def AnyAttr : Attr<CPred<"true">, "any attribute"> { let storageType = "Attribute"; let returnType = "Attribute"; let convertFromStorage = "{0}"; let constBuilderCall = "{1}"; } -def BoolAttr : Attr<CPred<"true">, "bool"> { +def BoolAttr : Attr<CPred<"{0}.isa<BoolAttr>()">, "bool attribute"> { let storageType = [{ BoolAttr }]; let returnType = [{ bool }]; let constBuilderCall = [{ {0}.getBoolAttr({1}) }]; } -// Base class for instantiating integer attributes of fixed width. -class IntegerAttrBase<BuildableType t, string descr> : - TypeBasedAttr<t, "IntegerAttr", descr> { +// Base class for integer attributes of fixed width. +class IntegerAttrBase<I attrValType, string descr> : + TypedAttrBase<attrValType, "IntegerAttr", + AllOf<[CPred<"{0}.isa<IntegerAttr>()">, + CPred<"{0}.cast<IntegerAttr>().getType()." + "isInteger(" # attrValType.bitwidth # ")">]>, + descr> { let returnType = [{ APInt }]; } -def I32Attr : IntegerAttrBase<I32, "32-bit integer">; -def I64Attr : IntegerAttrBase<I64, "64-bit integer">; +def I32Attr : IntegerAttrBase<I32, "32-bit integer attribute">; +def I64Attr : IntegerAttrBase<I64, "64-bit integer attribute">; -// Base class for instantiating float attributes of fixed width. -class FloatAttrBase<BuildableType t, string descr> : - TypeBasedAttr<t, "FloatAttr", descr> { +// Base class for float attributes of fixed width. +class FloatAttrBase<F attrValType, string descr> : + TypedAttrBase<attrValType, "FloatAttr", + AllOf<[CPred<"{0}.isa<FloatAttr>()">, + CPred<"{0}.cast<FloatAttr>().getType().isF" # + attrValType.bitwidth # "()">]>, + descr> { let returnType = [{ APFloat }]; } -def F32Attr : FloatAttrBase<F32, "32-bit float">; -def F64Attr : FloatAttrBase<F64, "64-bit float">; +def F32Attr : FloatAttrBase<F32, "32-bit float attribute">; +def F64Attr : FloatAttrBase<F64, "64-bit float attribute">; // An attribute backed by a string type. -class StringBasedAttr<Pred condition, string descr> : - Attr<condition, descr> { +class StringBasedAttr<Pred condition, string descr> : Attr<condition, descr> { let constBuilderCall = [{ {0}.getStringAttr("{1}") }]; let storageType = [{ StringAttr }]; let returnType = [{ StringRef }]; } -def StrAttr : StringBasedAttr<CPred<"true">, "string">; +def StrAttr : StringBasedAttr<CPred<"{0}.isa<StringAttr>()">, + "string attribute">; // An enum attribute case. class EnumAttrCase<string sym> : StringBasedAttr< @@ -485,7 +494,9 @@ class EnumAttrCase<string sym> : StringBasedAttr< // on the string: only the symbols of the allowed cases are permitted as the // string value. class EnumAttr<string name, string description, list<EnumAttrCase> cases> : - StringBasedAttr<AnyOf<!foreach(case, cases, case.predicate)>, description> { + StringBasedAttr<AllOf<[StrAttr.predicate, + AnyOf<!foreach(case, cases, case.predicate)>]>, + description> { // The C++ enum class name string className = name; // List of all accepted cases @@ -499,16 +510,20 @@ class ElementsAttrBase<Pred condition, string description> : let convertFromStorage = "{0}"; } -def ElementsAttr: ElementsAttrBase<CPred<"true">, "constant vector/tensor">; +def ElementsAttr: ElementsAttrBase<CPred<"{0}.isa<ElementsAttr>()">, + "constant vector/tensor attribute">; -def ArrayAttr : Attr<CPred<"true">, "array"> { +// TODO(antiagainst): Define common ArrayAttr subclasses and properly +// set constraints. +def ArrayAttr : Attr<CPred<"{0}.isa<ArrayAttr>()">, "array attribute"> { let storageType = [{ ArrayAttr }]; let returnType = [{ ArrayAttr }]; code convertFromStorage = "{0}"; } // Attributes containing functions. -def FunctionAttr : Attr<CPred<"{0}.isa<FunctionAttr>()">, "function"> { +def FunctionAttr : Attr<CPred<"{0}.isa<FunctionAttr>()">, + "function attribute"> { let storageType = [{ FunctionAttr }]; let returnType = [{ Function * }]; let convertFromStorage = [{ {0}.getValue() }]; @@ -531,7 +546,7 @@ class TypeAttrBase<string retType, string description> : // DerivedAttr are attributes whose value is computed from properties // of the operation. They do not require additional storage and are // materialized as needed. -class DerivedAttr<code ret, code b> : Attr<CPred<"true">, "derived"> { +class DerivedAttr<code ret, code b> : Attr<CPred<"true">, "derived attribute"> { let returnType = ret; code body = b; } diff --git a/mlir/include/mlir/LLVMIR/LLVMOps.td b/mlir/include/mlir/LLVMIR/LLVMOps.td index be3e1491fea..d9a6f2b98ff 100644 --- a/mlir/include/mlir/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/LLVMIR/LLVMOps.td @@ -151,7 +151,7 @@ def LLVM_SRemOp : LLVM_ArithmeticOp<"srem", "CreateSRem">; // Other integer operations. def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>, - Arguments<(ins I32Attr:$predicate, LLVM_Type:$lhs, + Arguments<(ins I64Attr:$predicate, LLVM_Type:$lhs, LLVM_Type:$rhs)> { let llvmBuilder = [{ $res = builder.CreateICmp(getLLVMCmpPredicate( diff --git a/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp b/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp index a1c9568e422..2d007e2b423 100644 --- a/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp @@ -68,7 +68,7 @@ public: UniformQuantizedType uniformElementType = fakeQuantAttrsToType( fqOp.getLoc(), fqOp.num_bits().getSExtValue(), - fqOp.min().convertToDouble(), fqOp.max().convertToDouble(), + fqOp.min().convertToFloat(), fqOp.max().convertToFloat(), fqOp.narrow_range(), converter.expressedType); if (!uniformElementType) { diff --git a/mlir/test/Quantization/convert-fakequant-invalid.mlir b/mlir/test/Quantization/convert-fakequant-invalid.mlir index bdaab47ab63..e2ba76a5a79 100644 --- a/mlir/test/Quantization/convert-fakequant-invalid.mlir +++ b/mlir/test/Quantization/convert-fakequant-invalid.mlir @@ -6,7 +6,7 @@ func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { ^bb0(%arg0: tensor<8x4x3xf32>): // expected-error@+1 {{FakeQuant range must straddle zero: [1.100000,1.500000]}} %0 = "quant.const_fake_quant"(%arg0) { - min: 1.1, max: 1.5, num_bits: 8 + min: 1.1 : f32, max: 1.5 : f32, num_bits: 8 } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } @@ -17,7 +17,7 @@ func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { ^bb0(%arg0: tensor<8x4x3xf32>): // expected-error@+1 {{FakeQuant range must straddle zero: [1.100000,1.000000}} %0 = "quant.const_fake_quant"(%arg0) { - min: 1.1, max: 1.0, num_bits: 8 + min: 1.1 : f32, max: 1.0 : f32, num_bits: 8 } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } @@ -28,7 +28,7 @@ func @fakeQuantArgs(tensor<8x4x3xi1>) -> tensor<8x4x3xi1> { ^bb0(%arg0: tensor<8x4x3xi1>): // expected-error@+1 {{op operand #0 must be tensor of 32-bit float values}} %0 = "quant.const_fake_quant"(%arg0) { - min: 1.1, max: 1.0, num_bits: 8 + min: 1.1 : f32, max: 1.0 : f32, num_bits: 8 } : (tensor<8x4x3xi1>) -> tensor<8x4x3xi1> return %0 : tensor<8x4x3xi1> } diff --git a/mlir/test/Quantization/convert-fakequant.mlir b/mlir/test/Quantization/convert-fakequant.mlir index fcfa18e832f..8acabd50cce 100644 --- a/mlir/test/Quantization/convert-fakequant.mlir +++ b/mlir/test/Quantization/convert-fakequant.mlir @@ -10,7 +10,7 @@ func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>) // CHECK-SAME: -> tensor<8x4x3xf32> %0 = "quant.const_fake_quant"(%arg0) { - min: 0.0, max: 1.0, num_bits: 8 + min: 0.0 : f32, max: 1.0 : f32, num_bits: 8 } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } @@ -25,7 +25,7 @@ func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>) // CHECK-SAME: -> tensor<8x4x3xf32> %0 = "quant.const_fake_quant"(%arg0) { - min: 0.0, max: 1.0, num_bits: 8, narrow_range: true + min: 0.0 : f32, max: 1.0 : f32, num_bits: 8, narrow_range: true } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } @@ -40,7 +40,7 @@ func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32 // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) // CHECK-SAME: -> tensor<8x4x3xf32> %0 = "quant.const_fake_quant"(%arg0) { - min: -1.0, max: 0.9921875, num_bits: 8, narrow_range: false + min: -1.0 : f32, max: 0.9921875 : f32, num_bits: 8, narrow_range: false } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } @@ -52,11 +52,11 @@ func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32 func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { ^bb0(%arg0: tensor<8x4x3xf32>): // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[i16:f32]{3.05175781185626E-5}">> - // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[i16:f32]{3.05175781185626E-5}">>) + // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[i16:f32]{3.0517578125E-5}">> + // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[i16:f32]{3.0517578125E-5}">>) // CHECK-SAME: -> tensor<8x4x3xf32> %0 = "quant.const_fake_quant"(%arg0) { - min: -1.0, max: 0.999969482, num_bits: 16 + min: -1.0 : f32, max: 0.999969482 : f32, num_bits: 16 } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } @@ -71,7 +71,7 @@ func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> { // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<!quant<"uniform[u8:f32]{0.0039215686274509803}">>) // CHECK-SAME: -> tensor<f32> %0 = "quant.const_fake_quant"(%arg0) { - min: 0.0, max: 1.0, num_bits: 8 + min: 0.0 : f32, max: 1.0 : f32, num_bits: 8 } : (tensor<f32>) -> tensor<f32> return %0 : tensor<f32> } diff --git a/mlir/test/mlir-tblgen/attr-enum.td b/mlir/test/mlir-tblgen/attr-enum.td index 6d27985256f..84fbdb0b666 100644 --- a/mlir/test/mlir-tblgen/attr-enum.td +++ b/mlir/test/mlir-tblgen/attr-enum.td @@ -21,7 +21,7 @@ def NS_OpA : Op<"op_a_with_enum_attr", []> { // DEF-LABEL: OpA::verify() // DEF: auto tblgen_attr = this->getAttr("attr"); -// DEF: if (!(((tblgen_attr.cast<StringAttr>().getValue() == "A")) || ((tblgen_attr.cast<StringAttr>().getValue() == "B")) || ((tblgen_attr.cast<StringAttr>().getValue() == "C")))) +// DEF: if (!(((tblgen_attr.isa<StringAttr>())) && (((tblgen_attr.cast<StringAttr>().getValue() == "A")) || ((tblgen_attr.cast<StringAttr>().getValue() == "B")) || ((tblgen_attr.cast<StringAttr>().getValue() == "C"))))) // DEF-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: some enum"); def NS_OpB : Op<"op_b_with_enum_attr", []> { @@ -31,8 +31,12 @@ def NS_OpB : Op<"op_b_with_enum_attr", []> { def : Pat<(NS_OpA NS_SomeEnum_A:$attr), (NS_OpB NS_SomeEnum_B)>; // PAT-LABEL: struct GeneratedConvert0 + // PAT: PatternMatchResult match -// PAT: if (!((op0->getAttrOfType<StringAttr>("attr").cast<StringAttr>().getValue() == "A"))) return matchFailure(); +// PAT: auto attr = op0->getAttrOfType<StringAttr>("attr"); +// PAT-NEXT: if (!attr) return matchFailure(); +// PAT-NEXT: if (!((attr.cast<StringAttr>().getValue() == "A"))) return matchFailure(); + // PAT: void rewrite // PAT: auto vOpB0 = rewriter.create<NS::OpB>(loc, // PAT-NEXT: rewriter.getStringAttr("B") diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td index 6c81dc6a6bf..4a4776524ca 100644 --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -64,6 +64,37 @@ def AOp : Op<"a_op", []> { // CHECK-NEXT: if (tblgen_cAttr) { // CHECK-NEXT: if (!((some-condition))) return emitOpError("attribute 'cAttr' failed to satisfy constraint: some attribute kind"); +def SomeTypeAttr : TypeAttrBase<"SomeType", "some type attribute">; + +// Test common attribute kinds' constraints +// --- + +def BOp : Op<"b_op", []> { + let arguments = (ins + AnyAttr:$any_attr, + BoolAttr:$bool_attr, + I32Attr:$i32_attr, + I64Attr:$i64_attr, + F32Attr:$f32_attr, + F64Attr:$f64_attr, + StrAttr:$str_attr, + ElementsAttr:$elements_attr, + FunctionAttr:$function_attr, + SomeTypeAttr:$type_attr + ); +} + +// CHECK-LABEL: BOp::verify +// CHECK: if (!((true))) +// CHECK: if (!((tblgen_bool_attr.isa<BoolAttr>()))) +// CHECK: if (!(((tblgen_i32_attr.isa<IntegerAttr>())) && ((tblgen_i32_attr.cast<IntegerAttr>().getType().isInteger(32))))) +// CHECK: if (!(((tblgen_i64_attr.isa<IntegerAttr>())) && ((tblgen_i64_attr.cast<IntegerAttr>().getType().isInteger(64))))) +// CHECK: if (!(((tblgen_f32_attr.isa<FloatAttr>())) && ((tblgen_f32_attr.cast<FloatAttr>().getType().isF32())))) +// CHECK: if (!(((tblgen_f64_attr.isa<FloatAttr>())) && ((tblgen_f64_attr.cast<FloatAttr>().getType().isF64())))) +// CHECK: if (!((tblgen_str_attr.isa<StringAttr>()))) +// CHECK: if (!((tblgen_elements_attr.isa<ElementsAttr>()))) +// CHECK: if (!((tblgen_function_attr.isa<FunctionAttr>()))) +// CHECK: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<SomeType>())))) def MixOperandsAndAttrs : Op<"mix_operands_and_attrs", []> { let arguments = (ins F32Attr:$attr, F32:$operand, F32Attr:$otherAttr, F32:$otherArg); diff --git a/mlir/test/mlir-tblgen/pattern-attr.td b/mlir/test/mlir-tblgen/pattern-attr.td new file mode 100644 index 00000000000..17908502075 --- /dev/null +++ b/mlir/test/mlir-tblgen/pattern-attr.td @@ -0,0 +1,41 @@ +// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s + +include "mlir/IR/OpBase.td" + +def MoreConstraint : AttrConstraint<CPred<"MoreConstraint">, "more constraint">; + +def OpA : Op<"op_a", []> { + let arguments = (ins + I32Attr:$required_attr, + OptionalAttr<I32Attr>:$optional_attr, + DefaultValuedAttr<I32Attr, "42">:$default_valued_attr, + I32Attr:$more_attr + ); + + let results = (outs I32:$result); +} + +def : Pat<(OpA $required, $optional, $default_valued, MoreConstraint:$more), + (OpA $required, $optional, $default_valued, $more)>; + +// Test attribute capturing +// --- + +// CHECK-LABEL: struct GeneratedConvert0 + +// CHECK: auto attr = op0->getAttrOfType<IntegerAttr>("required_attr"); +// CHECK-NEXT: if (!attr) return matchFailure(); +// CHECK-NEXT: s.required = attr; + +// CHECK: auto attr = op0->getAttrOfType<IntegerAttr>("optional_attr"); +// CHECK-NEXT: s.optional = attr; + +// CHECK: auto attr = op0->getAttrOfType<IntegerAttr>("default_valued_attr"); +// CHECK-NEXT: if (!attr) attr = mlir::Builder(ctx).getIntegerAttr(mlir::Builder(ctx).getIntegerType(32), 42); +// CHECK-NEXT: s.default_valued = attr; + +// CHECK: auto attr = op0->getAttrOfType<IntegerAttr>("more_attr"); +// CHECK-NEXT: if (!attr) return matchFailure(); +// CHECK-NEXT: if (!((MoreConstraint))) return matchFailure(); +// CHECK-NEXT: s.more = attr; + diff --git a/mlir/test/mlir-tblgen/pattern-bound-symbol.td b/mlir/test/mlir-tblgen/pattern-bound-symbol.td index 55f4d163116..46cf2e28a42 100644 --- a/mlir/test/mlir-tblgen/pattern-bound-symbol.td +++ b/mlir/test/mlir-tblgen/pattern-bound-symbol.td @@ -45,7 +45,8 @@ def : Pattern<(OpA:$res_a $operand, $attr), // CHECK: mlir::Operation* tblgen_res_a; (void)tblgen_res_a; // CHECK: tblgen_res_a = op0; // CHECK: s.operand = op0->getOperand(0); -// CHECK: s.attr = op0->getAttrOfType<IntegerAttr>("attr"); +// CHECK: attr = op0->getAttrOfType<IntegerAttr>("attr"); +// CHECK: s.attr = attr; // CHECK: if (!(tblgen_res_a->hasOneUse())) return matchFailure(); // Test bound results in result pattern diff --git a/mlir/test/mlir-tblgen/pattern-tAttr.td b/mlir/test/mlir-tblgen/pattern-tAttr.td index 08911156ead..32008d95819 100644 --- a/mlir/test/mlir-tblgen/pattern-tAttr.td +++ b/mlir/test/mlir-tblgen/pattern-tAttr.td @@ -4,7 +4,7 @@ include "mlir/IR/OpBase.td" // Create a Type and Attribute. def T : BuildableType<"buildT()">; -def T_Attr : TypeBasedAttr<T, "Attribute", "attribute of T type">; +def T_Attr : TypedAttrBase<T, "Attribute",CPred<"true">, "attribute of T type">; def T_Const_Attr : ConstantAttr<T_Attr, "attrValue">; def T_Compose_Attr : tAttr<"{0}.getArrayAttr({{{1}, {2}})">; diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td index fc2ab3fb6d2..6b912afa7fd 100644 --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -61,7 +61,7 @@ def OpF : Op<"op_for_int_min_val", []> { // CHECK-LABEL: OpF::verify() // CHECK: (tblgen_attr.cast<IntegerAttr>().getInt() >= 10) -// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: 32-bit integer whose minimal value is 10"); +// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: 32-bit integer attribute whose minimal value is 10"); def OpG : Op<"op_for_arr_min_count", []> { let arguments = (ins Confined<ArrayAttr, [ArrayMinCount<8>]>:$attr); @@ -69,4 +69,4 @@ def OpG : Op<"op_for_arr_min_count", []> { // CHECK-LABEL: OpG::verify() // CHECK: (tblgen_attr.cast<ArrayAttr>().size() >= 8) -// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: array with at least 8 elements"); +// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements"); diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 6622a783a95..3b3dfd38af5 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -331,8 +331,29 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, int indent) { Operator &op = tree.getDialectOp(opMap); auto *namedAttr = op.getArg(index).get<NamedAttribute *>(); - auto matcher = tree.getArgAsLeaf(index); + const auto &attr = namedAttr->attr; + + os.indent(indent) << "{\n"; + indent += 2; + os.indent(indent) << formatv( + "auto attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth, + attr.getStorageType(), namedAttr->getName()); + + // TODO(antiagainst): This should use getter method to avoid duplication. + if (attr.hasDefaultValue()) { + os.indent(indent) << "if (!attr) attr = " + << formatv(attr.getDefaultValueTemplate().c_str(), + "mlir::Builder(ctx)") + << ";\n"; + } else if (attr.isOptional()) { + // For a missing attribut that is optional according to definition, we + // should just capature a mlir::Attribute() to signal the missing state. + // That is precisely what getAttr() returns on missing attributes. + } else { + os.indent(indent) << "if (!attr) return matchFailure();\n"; + } + auto matcher = tree.getArgAsLeaf(index); if (!matcher.isUnspecified()) { if (!matcher.isAttrMatcher()) { PrintFatalError( @@ -342,20 +363,19 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, // If a constraint is specified, we need to generate C++ statements to // check the constraint. - std::string condition = formatv( - matcher.getConditionTemplate().c_str(), - formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth, - namedAttr->attr.getStorageType(), namedAttr->getName())); - os.indent(indent) << "if (!(" << condition << ")) return matchFailure();\n"; + os.indent(indent) << "if (!(" + << formatv(matcher.getConditionTemplate().c_str(), "attr") + << ")) return matchFailure();\n"; } // Capture the value auto name = tree.getArgName(index); if (!name.empty()) { - os.indent(indent) << getBoundArgument(name) << " = op" << depth - << "->getAttrOfType<" << namedAttr->attr.getStorageType() - << ">(\"" << namedAttr->getName() << "\");\n"; + os.indent(indent) << getBoundArgument(name) << " = attr;\n"; } + + indent -= 2; + os.indent(indent) << "}\n"; } void PatternEmitter::emitMatchMethod(DagNode tree) { |

