summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2019-04-04 09:25:38 -0700
committerMehdi Amini <joker.eph@gmail.com>2019-04-05 07:40:50 -0700
commit76cb20532636eb52300970d2b3eb00101b3821a1 (patch)
tree209584d042e4857f3f0cdca81c84a62c0232685e /mlir
parentc7790df2ed9bdcde12683aee6cb89a2668b56661 (diff)
downloadbcm5719-llvm-76cb20532636eb52300970d2b3eb00101b3821a1.tar.gz
bcm5719-llvm-76cb20532636eb52300970d2b3eb00101b3821a1.zip
[TableGen] Enforce constraints on attributes
Previously, attribute constraints are basically unused: we set true for almost anything. This CL refactors common attribute kinds and sets constraints on them properly. And fixed verification failures found by this change. A noticeable one is that certain TF ops' attributes are required to be 64-bit integer, but the corresponding TFLite ops expect 32-bit integer attributes. Added bitwidth converters to handle this difference. -- PiperOrigin-RevId: 241944008
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/IR/OpBase.td69
-rw-r--r--mlir/include/mlir/LLVMIR/LLVMOps.td2
-rw-r--r--mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp2
-rw-r--r--mlir/test/Quantization/convert-fakequant-invalid.mlir6
-rw-r--r--mlir/test/Quantization/convert-fakequant.mlir14
-rw-r--r--mlir/test/mlir-tblgen/attr-enum.td8
-rw-r--r--mlir/test/mlir-tblgen/op-attribute.td31
-rw-r--r--mlir/test/mlir-tblgen/pattern-attr.td41
-rw-r--r--mlir/test/mlir-tblgen/pattern-bound-symbol.td3
-rw-r--r--mlir/test/mlir-tblgen/pattern-tAttr.td2
-rw-r--r--mlir/test/mlir-tblgen/predicate.td4
-rw-r--r--mlir/tools/mlir-tblgen/RewriterGen.cpp38
12 files changed, 166 insertions, 54 deletions
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 473997a50da..e149d796667 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -421,56 +421,65 @@ class OptionalAttr<Attr attr> : Attr<attr.predicate, attr.description> {
let isOptional = 0b1;
}
-// A generic attribute that must be constructed around a specific type.
-// Backed by a C++ class "attrName".
-class TypeBasedAttr<BuildableType t, string attrName, string descr> :
- Attr<CPred<"true">, descr> {
- let constBuilderCall =
- "{0}.get" # attrName # "({0}." # t.builderCall # ", {1})";
- let storageType = attrName;
+// A generic attribute that must be constructed around a specific type
+// `attrValType`. Backed by MLIR attribute kind `attrKind`.
+class TypedAttrBase<BuildableType attrValType, string attrKind,
+ Pred condition, string descr> :
+ Attr<condition, descr> {
+ let constBuilderCall = "{0}.get" # attrKind # "({0}." #
+ attrValType.builderCall # ", {1})";
+ let storageType = attrKind;
}
// Any attribute.
-def AnyAttr : Attr<CPred<"true">, "any"> {
+def AnyAttr : Attr<CPred<"true">, "any attribute"> {
let storageType = "Attribute";
let returnType = "Attribute";
let convertFromStorage = "{0}";
let constBuilderCall = "{1}";
}
-def BoolAttr : Attr<CPred<"true">, "bool"> {
+def BoolAttr : Attr<CPred<"{0}.isa<BoolAttr>()">, "bool attribute"> {
let storageType = [{ BoolAttr }];
let returnType = [{ bool }];
let constBuilderCall = [{ {0}.getBoolAttr({1}) }];
}
-// Base class for instantiating integer attributes of fixed width.
-class IntegerAttrBase<BuildableType t, string descr> :
- TypeBasedAttr<t, "IntegerAttr", descr> {
+// Base class for integer attributes of fixed width.
+class IntegerAttrBase<I attrValType, string descr> :
+ TypedAttrBase<attrValType, "IntegerAttr",
+ AllOf<[CPred<"{0}.isa<IntegerAttr>()">,
+ CPred<"{0}.cast<IntegerAttr>().getType()."
+ "isInteger(" # attrValType.bitwidth # ")">]>,
+ descr> {
let returnType = [{ APInt }];
}
-def I32Attr : IntegerAttrBase<I32, "32-bit integer">;
-def I64Attr : IntegerAttrBase<I64, "64-bit integer">;
+def I32Attr : IntegerAttrBase<I32, "32-bit integer attribute">;
+def I64Attr : IntegerAttrBase<I64, "64-bit integer attribute">;
-// Base class for instantiating float attributes of fixed width.
-class FloatAttrBase<BuildableType t, string descr> :
- TypeBasedAttr<t, "FloatAttr", descr> {
+// Base class for float attributes of fixed width.
+class FloatAttrBase<F attrValType, string descr> :
+ TypedAttrBase<attrValType, "FloatAttr",
+ AllOf<[CPred<"{0}.isa<FloatAttr>()">,
+ CPred<"{0}.cast<FloatAttr>().getType().isF" #
+ attrValType.bitwidth # "()">]>,
+ descr> {
let returnType = [{ APFloat }];
}
-def F32Attr : FloatAttrBase<F32, "32-bit float">;
-def F64Attr : FloatAttrBase<F64, "64-bit float">;
+def F32Attr : FloatAttrBase<F32, "32-bit float attribute">;
+def F64Attr : FloatAttrBase<F64, "64-bit float attribute">;
// An attribute backed by a string type.
-class StringBasedAttr<Pred condition, string descr> :
- Attr<condition, descr> {
+class StringBasedAttr<Pred condition, string descr> : Attr<condition, descr> {
let constBuilderCall = [{ {0}.getStringAttr("{1}") }];
let storageType = [{ StringAttr }];
let returnType = [{ StringRef }];
}
-def StrAttr : StringBasedAttr<CPred<"true">, "string">;
+def StrAttr : StringBasedAttr<CPred<"{0}.isa<StringAttr>()">,
+ "string attribute">;
// An enum attribute case.
class EnumAttrCase<string sym> : StringBasedAttr<
@@ -485,7 +494,9 @@ class EnumAttrCase<string sym> : StringBasedAttr<
// on the string: only the symbols of the allowed cases are permitted as the
// string value.
class EnumAttr<string name, string description, list<EnumAttrCase> cases> :
- StringBasedAttr<AnyOf<!foreach(case, cases, case.predicate)>, description> {
+ StringBasedAttr<AllOf<[StrAttr.predicate,
+ AnyOf<!foreach(case, cases, case.predicate)>]>,
+ description> {
// The C++ enum class name
string className = name;
// List of all accepted cases
@@ -499,16 +510,20 @@ class ElementsAttrBase<Pred condition, string description> :
let convertFromStorage = "{0}";
}
-def ElementsAttr: ElementsAttrBase<CPred<"true">, "constant vector/tensor">;
+def ElementsAttr: ElementsAttrBase<CPred<"{0}.isa<ElementsAttr>()">,
+ "constant vector/tensor attribute">;
-def ArrayAttr : Attr<CPred<"true">, "array"> {
+// TODO(antiagainst): Define common ArrayAttr subclasses and properly
+// set constraints.
+def ArrayAttr : Attr<CPred<"{0}.isa<ArrayAttr>()">, "array attribute"> {
let storageType = [{ ArrayAttr }];
let returnType = [{ ArrayAttr }];
code convertFromStorage = "{0}";
}
// Attributes containing functions.
-def FunctionAttr : Attr<CPred<"{0}.isa<FunctionAttr>()">, "function"> {
+def FunctionAttr : Attr<CPred<"{0}.isa<FunctionAttr>()">,
+ "function attribute"> {
let storageType = [{ FunctionAttr }];
let returnType = [{ Function * }];
let convertFromStorage = [{ {0}.getValue() }];
@@ -531,7 +546,7 @@ class TypeAttrBase<string retType, string description> :
// DerivedAttr are attributes whose value is computed from properties
// of the operation. They do not require additional storage and are
// materialized as needed.
-class DerivedAttr<code ret, code b> : Attr<CPred<"true">, "derived"> {
+class DerivedAttr<code ret, code b> : Attr<CPred<"true">, "derived attribute"> {
let returnType = ret;
code body = b;
}
diff --git a/mlir/include/mlir/LLVMIR/LLVMOps.td b/mlir/include/mlir/LLVMIR/LLVMOps.td
index be3e1491fea..d9a6f2b98ff 100644
--- a/mlir/include/mlir/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/LLVMIR/LLVMOps.td
@@ -151,7 +151,7 @@ def LLVM_SRemOp : LLVM_ArithmeticOp<"srem", "CreateSRem">;
// Other integer operations.
def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>,
- Arguments<(ins I32Attr:$predicate, LLVM_Type:$lhs,
+ Arguments<(ins I64Attr:$predicate, LLVM_Type:$lhs,
LLVM_Type:$rhs)> {
let llvmBuilder = [{
$res = builder.CreateICmp(getLLVMCmpPredicate(
diff --git a/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp b/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp
index a1c9568e422..2d007e2b423 100644
--- a/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp
@@ -68,7 +68,7 @@ public:
UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
- fqOp.min().convertToDouble(), fqOp.max().convertToDouble(),
+ fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
fqOp.narrow_range(), converter.expressedType);
if (!uniformElementType) {
diff --git a/mlir/test/Quantization/convert-fakequant-invalid.mlir b/mlir/test/Quantization/convert-fakequant-invalid.mlir
index bdaab47ab63..e2ba76a5a79 100644
--- a/mlir/test/Quantization/convert-fakequant-invalid.mlir
+++ b/mlir/test/Quantization/convert-fakequant-invalid.mlir
@@ -6,7 +6,7 @@ func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// expected-error@+1 {{FakeQuant range must straddle zero: [1.100000,1.500000]}}
%0 = "quant.const_fake_quant"(%arg0) {
- min: 1.1, max: 1.5, num_bits: 8
+ min: 1.1 : f32, max: 1.5 : f32, num_bits: 8
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
@@ -17,7 +17,7 @@ func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// expected-error@+1 {{FakeQuant range must straddle zero: [1.100000,1.000000}}
%0 = "quant.const_fake_quant"(%arg0) {
- min: 1.1, max: 1.0, num_bits: 8
+ min: 1.1 : f32, max: 1.0 : f32, num_bits: 8
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
@@ -28,7 +28,7 @@ func @fakeQuantArgs(tensor<8x4x3xi1>) -> tensor<8x4x3xi1> {
^bb0(%arg0: tensor<8x4x3xi1>):
// expected-error@+1 {{op operand #0 must be tensor of 32-bit float values}}
%0 = "quant.const_fake_quant"(%arg0) {
- min: 1.1, max: 1.0, num_bits: 8
+ min: 1.1 : f32, max: 1.0 : f32, num_bits: 8
} : (tensor<8x4x3xi1>) -> tensor<8x4x3xi1>
return %0 : tensor<8x4x3xi1>
}
diff --git a/mlir/test/Quantization/convert-fakequant.mlir b/mlir/test/Quantization/convert-fakequant.mlir
index fcfa18e832f..8acabd50cce 100644
--- a/mlir/test/Quantization/convert-fakequant.mlir
+++ b/mlir/test/Quantization/convert-fakequant.mlir
@@ -10,7 +10,7 @@ func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>)
// CHECK-SAME: -> tensor<8x4x3xf32>
%0 = "quant.const_fake_quant"(%arg0) {
- min: 0.0, max: 1.0, num_bits: 8
+ min: 0.0 : f32, max: 1.0 : f32, num_bits: 8
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
@@ -25,7 +25,7 @@ func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>)
// CHECK-SAME: -> tensor<8x4x3xf32>
%0 = "quant.const_fake_quant"(%arg0) {
- min: 0.0, max: 1.0, num_bits: 8, narrow_range: true
+ min: 0.0 : f32, max: 1.0 : f32, num_bits: 8, narrow_range: true
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
@@ -40,7 +40,7 @@ func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32
// CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>)
// CHECK-SAME: -> tensor<8x4x3xf32>
%0 = "quant.const_fake_quant"(%arg0) {
- min: -1.0, max: 0.9921875, num_bits: 8, narrow_range: false
+ min: -1.0 : f32, max: 0.9921875 : f32, num_bits: 8, narrow_range: false
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
@@ -52,11 +52,11 @@ func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32
func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[i16:f32]{3.05175781185626E-5}">>
- // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[i16:f32]{3.05175781185626E-5}">>)
+ // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[i16:f32]{3.0517578125E-5}">>
+ // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[i16:f32]{3.0517578125E-5}">>)
// CHECK-SAME: -> tensor<8x4x3xf32>
%0 = "quant.const_fake_quant"(%arg0) {
- min: -1.0, max: 0.999969482, num_bits: 16
+ min: -1.0 : f32, max: 0.999969482 : f32, num_bits: 16
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
@@ -71,7 +71,7 @@ func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
// CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<!quant<"uniform[u8:f32]{0.0039215686274509803}">>)
// CHECK-SAME: -> tensor<f32>
%0 = "quant.const_fake_quant"(%arg0) {
- min: 0.0, max: 1.0, num_bits: 8
+ min: 0.0 : f32, max: 1.0 : f32, num_bits: 8
} : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
diff --git a/mlir/test/mlir-tblgen/attr-enum.td b/mlir/test/mlir-tblgen/attr-enum.td
index 6d27985256f..84fbdb0b666 100644
--- a/mlir/test/mlir-tblgen/attr-enum.td
+++ b/mlir/test/mlir-tblgen/attr-enum.td
@@ -21,7 +21,7 @@ def NS_OpA : Op<"op_a_with_enum_attr", []> {
// DEF-LABEL: OpA::verify()
// DEF: auto tblgen_attr = this->getAttr("attr");
-// DEF: if (!(((tblgen_attr.cast<StringAttr>().getValue() == "A")) || ((tblgen_attr.cast<StringAttr>().getValue() == "B")) || ((tblgen_attr.cast<StringAttr>().getValue() == "C"))))
+// DEF: if (!(((tblgen_attr.isa<StringAttr>())) && (((tblgen_attr.cast<StringAttr>().getValue() == "A")) || ((tblgen_attr.cast<StringAttr>().getValue() == "B")) || ((tblgen_attr.cast<StringAttr>().getValue() == "C")))))
// DEF-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: some enum");
def NS_OpB : Op<"op_b_with_enum_attr", []> {
@@ -31,8 +31,12 @@ def NS_OpB : Op<"op_b_with_enum_attr", []> {
def : Pat<(NS_OpA NS_SomeEnum_A:$attr), (NS_OpB NS_SomeEnum_B)>;
// PAT-LABEL: struct GeneratedConvert0
+
// PAT: PatternMatchResult match
-// PAT: if (!((op0->getAttrOfType<StringAttr>("attr").cast<StringAttr>().getValue() == "A"))) return matchFailure();
+// PAT: auto attr = op0->getAttrOfType<StringAttr>("attr");
+// PAT-NEXT: if (!attr) return matchFailure();
+// PAT-NEXT: if (!((attr.cast<StringAttr>().getValue() == "A"))) return matchFailure();
+
// PAT: void rewrite
// PAT: auto vOpB0 = rewriter.create<NS::OpB>(loc,
// PAT-NEXT: rewriter.getStringAttr("B")
diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index 6c81dc6a6bf..4a4776524ca 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -64,6 +64,37 @@ def AOp : Op<"a_op", []> {
// CHECK-NEXT: if (tblgen_cAttr) {
// CHECK-NEXT: if (!((some-condition))) return emitOpError("attribute 'cAttr' failed to satisfy constraint: some attribute kind");
+def SomeTypeAttr : TypeAttrBase<"SomeType", "some type attribute">;
+
+// Test common attribute kinds' constraints
+// ---
+
+def BOp : Op<"b_op", []> {
+ let arguments = (ins
+ AnyAttr:$any_attr,
+ BoolAttr:$bool_attr,
+ I32Attr:$i32_attr,
+ I64Attr:$i64_attr,
+ F32Attr:$f32_attr,
+ F64Attr:$f64_attr,
+ StrAttr:$str_attr,
+ ElementsAttr:$elements_attr,
+ FunctionAttr:$function_attr,
+ SomeTypeAttr:$type_attr
+ );
+}
+
+// CHECK-LABEL: BOp::verify
+// CHECK: if (!((true)))
+// CHECK: if (!((tblgen_bool_attr.isa<BoolAttr>())))
+// CHECK: if (!(((tblgen_i32_attr.isa<IntegerAttr>())) && ((tblgen_i32_attr.cast<IntegerAttr>().getType().isInteger(32)))))
+// CHECK: if (!(((tblgen_i64_attr.isa<IntegerAttr>())) && ((tblgen_i64_attr.cast<IntegerAttr>().getType().isInteger(64)))))
+// CHECK: if (!(((tblgen_f32_attr.isa<FloatAttr>())) && ((tblgen_f32_attr.cast<FloatAttr>().getType().isF32()))))
+// CHECK: if (!(((tblgen_f64_attr.isa<FloatAttr>())) && ((tblgen_f64_attr.cast<FloatAttr>().getType().isF64()))))
+// CHECK: if (!((tblgen_str_attr.isa<StringAttr>())))
+// CHECK: if (!((tblgen_elements_attr.isa<ElementsAttr>())))
+// CHECK: if (!((tblgen_function_attr.isa<FunctionAttr>())))
+// CHECK: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<SomeType>()))))
def MixOperandsAndAttrs : Op<"mix_operands_and_attrs", []> {
let arguments = (ins F32Attr:$attr, F32:$operand, F32Attr:$otherAttr, F32:$otherArg);
diff --git a/mlir/test/mlir-tblgen/pattern-attr.td b/mlir/test/mlir-tblgen/pattern-attr.td
new file mode 100644
index 00000000000..17908502075
--- /dev/null
+++ b/mlir/test/mlir-tblgen/pattern-attr.td
@@ -0,0 +1,41 @@
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def MoreConstraint : AttrConstraint<CPred<"MoreConstraint">, "more constraint">;
+
+def OpA : Op<"op_a", []> {
+ let arguments = (ins
+ I32Attr:$required_attr,
+ OptionalAttr<I32Attr>:$optional_attr,
+ DefaultValuedAttr<I32Attr, "42">:$default_valued_attr,
+ I32Attr:$more_attr
+ );
+
+ let results = (outs I32:$result);
+}
+
+def : Pat<(OpA $required, $optional, $default_valued, MoreConstraint:$more),
+ (OpA $required, $optional, $default_valued, $more)>;
+
+// Test attribute capturing
+// ---
+
+// CHECK-LABEL: struct GeneratedConvert0
+
+// CHECK: auto attr = op0->getAttrOfType<IntegerAttr>("required_attr");
+// CHECK-NEXT: if (!attr) return matchFailure();
+// CHECK-NEXT: s.required = attr;
+
+// CHECK: auto attr = op0->getAttrOfType<IntegerAttr>("optional_attr");
+// CHECK-NEXT: s.optional = attr;
+
+// CHECK: auto attr = op0->getAttrOfType<IntegerAttr>("default_valued_attr");
+// CHECK-NEXT: if (!attr) attr = mlir::Builder(ctx).getIntegerAttr(mlir::Builder(ctx).getIntegerType(32), 42);
+// CHECK-NEXT: s.default_valued = attr;
+
+// CHECK: auto attr = op0->getAttrOfType<IntegerAttr>("more_attr");
+// CHECK-NEXT: if (!attr) return matchFailure();
+// CHECK-NEXT: if (!((MoreConstraint))) return matchFailure();
+// CHECK-NEXT: s.more = attr;
+
diff --git a/mlir/test/mlir-tblgen/pattern-bound-symbol.td b/mlir/test/mlir-tblgen/pattern-bound-symbol.td
index 55f4d163116..46cf2e28a42 100644
--- a/mlir/test/mlir-tblgen/pattern-bound-symbol.td
+++ b/mlir/test/mlir-tblgen/pattern-bound-symbol.td
@@ -45,7 +45,8 @@ def : Pattern<(OpA:$res_a $operand, $attr),
// CHECK: mlir::Operation* tblgen_res_a; (void)tblgen_res_a;
// CHECK: tblgen_res_a = op0;
// CHECK: s.operand = op0->getOperand(0);
-// CHECK: s.attr = op0->getAttrOfType<IntegerAttr>("attr");
+// CHECK: attr = op0->getAttrOfType<IntegerAttr>("attr");
+// CHECK: s.attr = attr;
// CHECK: if (!(tblgen_res_a->hasOneUse())) return matchFailure();
// Test bound results in result pattern
diff --git a/mlir/test/mlir-tblgen/pattern-tAttr.td b/mlir/test/mlir-tblgen/pattern-tAttr.td
index 08911156ead..32008d95819 100644
--- a/mlir/test/mlir-tblgen/pattern-tAttr.td
+++ b/mlir/test/mlir-tblgen/pattern-tAttr.td
@@ -4,7 +4,7 @@ include "mlir/IR/OpBase.td"
// Create a Type and Attribute.
def T : BuildableType<"buildT()">;
-def T_Attr : TypeBasedAttr<T, "Attribute", "attribute of T type">;
+def T_Attr : TypedAttrBase<T, "Attribute",CPred<"true">, "attribute of T type">;
def T_Const_Attr : ConstantAttr<T_Attr, "attrValue">;
def T_Compose_Attr : tAttr<"{0}.getArrayAttr({{{1}, {2}})">;
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index fc2ab3fb6d2..6b912afa7fd 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -61,7 +61,7 @@ def OpF : Op<"op_for_int_min_val", []> {
// CHECK-LABEL: OpF::verify()
// CHECK: (tblgen_attr.cast<IntegerAttr>().getInt() >= 10)
-// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: 32-bit integer whose minimal value is 10");
+// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: 32-bit integer attribute whose minimal value is 10");
def OpG : Op<"op_for_arr_min_count", []> {
let arguments = (ins Confined<ArrayAttr, [ArrayMinCount<8>]>:$attr);
@@ -69,4 +69,4 @@ def OpG : Op<"op_for_arr_min_count", []> {
// CHECK-LABEL: OpG::verify()
// CHECK: (tblgen_attr.cast<ArrayAttr>().size() >= 8)
-// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: array with at least 8 elements");
+// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements");
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 6622a783a95..3b3dfd38af5 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -331,8 +331,29 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
int indent) {
Operator &op = tree.getDialectOp(opMap);
auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
- auto matcher = tree.getArgAsLeaf(index);
+ const auto &attr = namedAttr->attr;
+
+ os.indent(indent) << "{\n";
+ indent += 2;
+ os.indent(indent) << formatv(
+ "auto attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
+ attr.getStorageType(), namedAttr->getName());
+
+ // TODO(antiagainst): This should use getter method to avoid duplication.
+ if (attr.hasDefaultValue()) {
+ os.indent(indent) << "if (!attr) attr = "
+ << formatv(attr.getDefaultValueTemplate().c_str(),
+ "mlir::Builder(ctx)")
+ << ";\n";
+ } else if (attr.isOptional()) {
+ // For a missing attribut that is optional according to definition, we
+ // should just capature a mlir::Attribute() to signal the missing state.
+ // That is precisely what getAttr() returns on missing attributes.
+ } else {
+ os.indent(indent) << "if (!attr) return matchFailure();\n";
+ }
+ auto matcher = tree.getArgAsLeaf(index);
if (!matcher.isUnspecified()) {
if (!matcher.isAttrMatcher()) {
PrintFatalError(
@@ -342,20 +363,19 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
- std::string condition = formatv(
- matcher.getConditionTemplate().c_str(),
- formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
- namedAttr->attr.getStorageType(), namedAttr->getName()));
- os.indent(indent) << "if (!(" << condition << ")) return matchFailure();\n";
+ os.indent(indent) << "if (!("
+ << formatv(matcher.getConditionTemplate().c_str(), "attr")
+ << ")) return matchFailure();\n";
}
// Capture the value
auto name = tree.getArgName(index);
if (!name.empty()) {
- os.indent(indent) << getBoundArgument(name) << " = op" << depth
- << "->getAttrOfType<" << namedAttr->attr.getStorageType()
- << ">(\"" << namedAttr->getName() << "\");\n";
+ os.indent(indent) << getBoundArgument(name) << " = attr;\n";
}
+
+ indent -= 2;
+ os.indent(indent) << "}\n";
}
void PatternEmitter::emitMatchMethod(DagNode tree) {
OpenPOWER on IntegriCloud