summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/g3doc/OpDefinitions.md18
-rw-r--r--mlir/test/mlir-tblgen/op-decl.td5
-rw-r--r--mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp183
3 files changed, 108 insertions, 98 deletions
diff --git a/mlir/g3doc/OpDefinitions.md b/mlir/g3doc/OpDefinitions.md
index d00b19b316a..0e786a0a4e7 100644
--- a/mlir/g3doc/OpDefinitions.md
+++ b/mlir/g3doc/OpDefinitions.md
@@ -290,7 +290,7 @@ class. See [Constraints](#constraints) for more information.
### Operation interfaces
[Operation interfaces](Interfaces.md#operation-interfaces) are a mechanism by
-which to opaquely call methods and access information on an *Op instance,
+which to opaquely call methods and access information on an *Op instance*,
without knowing the exact operation type. Operation interfaces defined in C++
can be accessed in the ODS framework via the `OpInterfaceTrait` class. Aside
from using pre-existing interfaces in the C++ API, the ODS framework also
@@ -414,7 +414,7 @@ The following builders are generated:
// All result-types/operands/attributes have one aggregate parameter.
static void build(Builder *tblgen_builder, OperationState &tblgen_state,
ArrayRef<Type> resultTypes,
- ArrayRef<Value> operands,
+ ValueRange operands,
ArrayRef<NamedAttribute> attributes);
// Each result-type/operand/attribute has a separate parameter. The parameters
@@ -433,7 +433,19 @@ static void build(Builder *tblgen_builder, OperationState &tblgen_state,
Value *i32_operand, Value *f32_operand, ...,
APInt i32_attr, StringRef f32_attr, ...);
-// (And potentially others depending on the specific op.)
+// Each operand/attribute has a separate parameter but result type is aggregate.
+static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+ ArrayRef<Type> resultTypes,
+ Value *i32_operand, Value *f32_operand, ...,
+ IntegerAttr i32_attr, FloatAttr f32_attr, ...);
+
+// All operands/attributes have aggregate parameters.
+// Generated if InferTypeOpInterface interface is specified.
+static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+ ValueRange operands,
+ ArrayRef<NamedAttribute> attributes);
+
+// (And manually specified builders depending on the specific op.)
```
The first form provides basic uniformity so that we can create ops using the
diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index 2c90c279e38..c0420cb19c1 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -68,8 +68,9 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> {
// CHECK: FloatAttr attr2Attr()
// CHECK: Optional< APFloat > attr2();
// CHECK: static void build(Value *val);
-// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef<Type> s, Value *a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2);
-// CHECK: static void build(Builder *, OperationState &tblgen_state, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes);
+// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef<Type> s, Value *a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2)
+// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef<Type> s, Value *a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2)
+// CHECK: static void build(Builder *, OperationState &tblgen_state, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes)
// CHECK: static ParseResult parse(OpAsmParser &parser, OperationState &result);
// CHECK: void print(OpAsmPrinter &p);
// CHECK: LogicalResult verify();
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 37fa9c7840b..a73b11359fe 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -514,21 +514,9 @@ private:
// Generates builder methods for the operation.
void genBuilder();
- // Generates the build() method that takes each result-type/operand/attribute
- // as a stand-alone parameter. Attributes will take wrapped mlir::Attribute
- // values. The generated build() method also requires specifying result types
- // for all results.
- void genSeparateParamWrappedAttrBuilder();
-
- // Generates the build() method that takes each result-type/operand/attribute
- // as a stand-alone parameter. Attributes will take raw values without
- // mlir::Attribute wrapper. The generated build() method also requires
- // specifying result types for all results.
- void genSeparateParamUnwrappedAttrBuilder();
-
- // Generates the build() method that takes a single parameter for all the
- // result types and a separate parameter for each operand/attribute.
- void genCollectiveTypeParamBuilder();
+ // Generates the build() method that takes each operand/attribute
+ // as a stand-alone parameter.
+ void genSeparateArgParamBuilder();
// Generates the build() method that takes each operand/attribute as a
// stand-alone parameter. The generated build() method uses first operand's
@@ -897,26 +885,11 @@ void OpEmitter::genNamedRegionGetters() {
}
}
-void OpEmitter::genSeparateParamWrappedAttrBuilder() {
- std::string paramList;
- llvm::SmallVector<std::string, 4> resultNames;
- buildParamList(paramList, resultNames, TypeParamKind::Separate);
-
- auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
- genCodeForAddingArgAndRegionForBuilder(m.body());
-
- // Push all result types to the operation state
- for (int i = 0, e = op.getNumResults(); i < e; ++i) {
- m.body() << " " << builderOpState << ".addTypes(" << resultNames[i]
- << ");\n";
- }
-}
-
-void OpEmitter::genSeparateParamUnwrappedAttrBuilder() {
+static bool canGenerateUnwrappedBuilder(Operator &op) {
// If this op does not have native attributes at all, return directly to avoid
// redefining builders.
if (op.getNumNativeAttributes() == 0)
- return;
+ return false;
bool canGenerate = false;
// We are generating builders that take raw values for attributes. We need to
@@ -930,47 +903,75 @@ void OpEmitter::genSeparateParamUnwrappedAttrBuilder() {
break;
}
}
- if (!canGenerate)
- return;
-
- std::string paramList;
- llvm::SmallVector<std::string, 4> resultNames;
- buildParamList(paramList, resultNames, TypeParamKind::Separate,
- AttrParamKind::UnwrappedValue);
-
- auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
- genCodeForAddingArgAndRegionForBuilder(m.body(), /*isRawValueAttr=*/true);
-
- // Push all result types to the operation state.
- for (int i = 0, e = op.getNumResults(); i < e; ++i) {
- m.body() << " " << builderOpState << ".addTypes(" << resultNames[i]
- << ");\n";
- }
+ return canGenerate;
}
-void OpEmitter::genCollectiveTypeParamBuilder() {
- auto numResults = op.getNumResults();
-
- // If this op has no results, then just skip generating this builder.
- // Otherwise we are generating the same signature as the separate-parameter
- // builder.
- if (numResults == 0)
- return;
-
- // Similarly for ops with one single variadic result, which will also have one
- // `ArrayRef<Type>` parameter for the result type.
- if (numResults == 1 && op.getResult(0).isVariadic())
- return;
-
- std::string paramList;
- llvm::SmallVector<std::string, 4> resultNames;
- buildParamList(paramList, resultNames, TypeParamKind::Collective);
+void OpEmitter::genSeparateArgParamBuilder() {
+ SmallVector<AttrParamKind, 2> attrBuilderType;
+ attrBuilderType.push_back(AttrParamKind::WrappedAttr);
+ if (canGenerateUnwrappedBuilder(op))
+ attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
+
+ // Emit with separate builders with or without unwrapped attributes and/or
+ // inferring result type.
+ auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
+ bool inferType) {
+ std::string paramList;
+ llvm::SmallVector<std::string, 4> resultNames;
+ buildParamList(paramList, resultNames, paramKind, attrType);
+
+ auto &m =
+ opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
+ auto &body = m.body();
+ genCodeForAddingArgAndRegionForBuilder(
+ body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);
+
+ // Push all result types to the operation state
+
+ if (inferType) {
+ // Generate builder that infers type too.
+ // TODO(jpienaar): Subsume this with general checking if type can be
+ // infered automatically.
+ // TODO(jpienaar): Expand to handle regions.
+ body << formatv(R"(
+ SmallVector<Type, 2> inferedReturnTypes;
+ if (succeeded({0}::inferReturnTypes({1}.location, {1}.operands,
+ {1}.attributes, /*regions=*/{{}, inferedReturnTypes)))
+ {1}.addTypes(inferedReturnTypes);
+ else
+ llvm::report_fatal_error("Failed to infer result type(s).");)",
+ opClass.getClassName(), builderOpState);
+ return;
+ }
- auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
- genCodeForAddingArgAndRegionForBuilder(m.body());
+ switch (paramKind) {
+ case TypeParamKind::None:
+ return;
+ case TypeParamKind::Separate:
+ for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+ body << " " << builderOpState << ".addTypes(" << resultNames[i]
+ << ");\n";
+ }
+ return;
+ case TypeParamKind::Collective:
+ body << " " << builderOpState << ".addTypes(resultTypes);\n";
+ return;
+ };
+ llvm_unreachable("unhandled TypeParamKind");
+ };
- // Push all result types to the operation state
- m.body() << formatv(" {0}.addTypes(resultTypes);\n", builderOpState);
+ bool canInferType =
+ op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0;
+ for (auto attrType : attrBuilderType) {
+ emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
+ if (canInferType)
+ emit(attrType, TypeParamKind::None, /*inferType=*/true);
+ // Emit separate arg build with collective type, unless there is only one
+ // variadic result, in which case the above would have already generated
+ // the same build method.
+ if (op.getNumResults() == 1 && !op.getResult(0).isVariadic())
+ emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
+ }
}
void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
@@ -1021,8 +1022,7 @@ void OpEmitter::genInferedTypeCollectiveParamBuilder() {
/*regions=*/{{}, inferedReturnTypes)))
build(builder, tblgen_state, inferedReturnTypes, operands, attributes);
else
- llvm::report_fatal_error("Failed to infer result type(s).");
- )",
+ llvm::report_fatal_error("Failed to infer result type(s).");)",
opClass.getClassName(), builderOpState);
}
@@ -1111,18 +1111,13 @@ void OpEmitter::genBuilder() {
// Generate default builders that requires all result type, operands, and
// attributes as parameters.
- // We generate three builders here:
- // 1. one having a stand-alone parameter for each result type / operand /
- // attribute, and
- genSeparateParamWrappedAttrBuilder();
- genSeparateParamUnwrappedAttrBuilder();
- // 2. one having a stand-alone parameter for each operand / attribute and
- // an aggregated parameter for all result types, and
- genCollectiveTypeParamBuilder();
- // 3. one having an aggregated parameter for all result types / operands /
+ // We generate three classes of builders here:
+ // 1. one having a stand-alone parameter for each operand / attribute, and
+ genSeparateArgParamBuilder();
+ // 2. one having an aggregated parameter for all result types / operands /
// attributes, and
genCollectiveParamBuilder();
- // 4. one having a stand-alone parameter for each operand and attribute,
+ // 3. one having a stand-alone parameter for each operand and attribute,
// use the first operand or attribute's type as all result types
// to facilitate different call patterns.
if (op.getNumVariadicResults() == 0) {
@@ -1133,11 +1128,6 @@ void OpEmitter::genBuilder() {
if (op.getTrait("OpTrait::FirstAttrDerivedResultType"))
genUseAttrAsResultTypeBuilder();
}
- // TODO(jpienaar): Subsume this with general checking if type can be infered
- // automatically.
- // TODO(jpienaar): Expand to handle regions.
- if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0)
- genInferedTypeCollectiveParamBuilder();
}
void OpEmitter::genCollectiveParamBuilder() {
@@ -1156,13 +1146,6 @@ void OpEmitter::genCollectiveParamBuilder() {
auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
auto &body = m.body();
- // Result types
- if (numVariadicResults == 0 || numNonVariadicResults != 0)
- body << " assert(resultTypes.size()"
- << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
- << "u && \"mismatched number of return types\");\n";
- body << " " << builderOpState << ".addTypes(resultTypes);\n";
-
// Operands
if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
body << " assert(operands.size()"
@@ -1179,6 +1162,20 @@ void OpEmitter::genCollectiveParamBuilder() {
for (int i = 0; i < numRegions; ++i)
m.body() << " (void)" << builderOpState << ".addRegion();\n";
}
+
+ // Result types
+ if (numVariadicResults == 0 || numNonVariadicResults != 0)
+ body << " assert(resultTypes.size()"
+ << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
+ << "u && \"mismatched number of return types\");\n";
+ body << " " << builderOpState << ".addTypes(resultTypes);\n";
+
+ // Generate builder that infers type too.
+ // TODO(jpienaar): Subsume this with general checking if type can be infered
+ // automatically.
+ // TODO(jpienaar): Expand to handle regions.
+ if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0)
+ genInferedTypeCollectiveParamBuilder();
}
void OpEmitter::buildParamList(std::string &paramList,
OpenPOWER on IntegriCloud