diff options
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/g3doc/OpDefinitions.md | 18 | ||||
-rw-r--r-- | mlir/test/mlir-tblgen/op-decl.td | 5 | ||||
-rw-r--r-- | mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 183 |
3 files changed, 108 insertions, 98 deletions
diff --git a/mlir/g3doc/OpDefinitions.md b/mlir/g3doc/OpDefinitions.md index d00b19b316a..0e786a0a4e7 100644 --- a/mlir/g3doc/OpDefinitions.md +++ b/mlir/g3doc/OpDefinitions.md @@ -290,7 +290,7 @@ class. See [Constraints](#constraints) for more information. ### Operation interfaces [Operation interfaces](Interfaces.md#operation-interfaces) are a mechanism by -which to opaquely call methods and access information on an *Op instance, +which to opaquely call methods and access information on an *Op instance*, without knowing the exact operation type. Operation interfaces defined in C++ can be accessed in the ODS framework via the `OpInterfaceTrait` class. Aside from using pre-existing interfaces in the C++ API, the ODS framework also @@ -414,7 +414,7 @@ The following builders are generated: // All result-types/operands/attributes have one aggregate parameter. static void build(Builder *tblgen_builder, OperationState &tblgen_state, ArrayRef<Type> resultTypes, - ArrayRef<Value> operands, + ValueRange operands, ArrayRef<NamedAttribute> attributes); // Each result-type/operand/attribute has a separate parameter. The parameters @@ -433,7 +433,19 @@ static void build(Builder *tblgen_builder, OperationState &tblgen_state, Value *i32_operand, Value *f32_operand, ..., APInt i32_attr, StringRef f32_attr, ...); -// (And potentially others depending on the specific op.) +// Each operand/attribute has a separate parameter but result type is aggregate. +static void build(Builder *tblgen_builder, OperationState &tblgen_state, + ArrayRef<Type> resultTypes, + Value *i32_operand, Value *f32_operand, ..., + IntegerAttr i32_attr, FloatAttr f32_attr, ...); + +// All operands/attributes have aggregate parameters. +// Generated if InferTypeOpInterface interface is specified. +static void build(Builder *tblgen_builder, OperationState &tblgen_state, + ValueRange operands, + ArrayRef<NamedAttribute> attributes); + +// (And manually specified builders depending on the specific op.) ``` The first form provides basic uniformity so that we can create ops using the diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td index 2c90c279e38..c0420cb19c1 100644 --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -68,8 +68,9 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> { // CHECK: FloatAttr attr2Attr() // CHECK: Optional< APFloat > attr2(); // CHECK: static void build(Value *val); -// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef<Type> s, Value *a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2); -// CHECK: static void build(Builder *, OperationState &tblgen_state, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes); +// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef<Type> s, Value *a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2) +// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef<Type> s, Value *a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2) +// CHECK: static void build(Builder *, OperationState &tblgen_state, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes) // CHECK: static ParseResult parse(OpAsmParser &parser, OperationState &result); // CHECK: void print(OpAsmPrinter &p); // CHECK: LogicalResult verify(); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 37fa9c7840b..a73b11359fe 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -514,21 +514,9 @@ private: // Generates builder methods for the operation. void genBuilder(); - // Generates the build() method that takes each result-type/operand/attribute - // as a stand-alone parameter. Attributes will take wrapped mlir::Attribute - // values. The generated build() method also requires specifying result types - // for all results. - void genSeparateParamWrappedAttrBuilder(); - - // Generates the build() method that takes each result-type/operand/attribute - // as a stand-alone parameter. Attributes will take raw values without - // mlir::Attribute wrapper. The generated build() method also requires - // specifying result types for all results. - void genSeparateParamUnwrappedAttrBuilder(); - - // Generates the build() method that takes a single parameter for all the - // result types and a separate parameter for each operand/attribute. - void genCollectiveTypeParamBuilder(); + // Generates the build() method that takes each operand/attribute + // as a stand-alone parameter. + void genSeparateArgParamBuilder(); // Generates the build() method that takes each operand/attribute as a // stand-alone parameter. The generated build() method uses first operand's @@ -897,26 +885,11 @@ void OpEmitter::genNamedRegionGetters() { } } -void OpEmitter::genSeparateParamWrappedAttrBuilder() { - std::string paramList; - llvm::SmallVector<std::string, 4> resultNames; - buildParamList(paramList, resultNames, TypeParamKind::Separate); - - auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); - genCodeForAddingArgAndRegionForBuilder(m.body()); - - // Push all result types to the operation state - for (int i = 0, e = op.getNumResults(); i < e; ++i) { - m.body() << " " << builderOpState << ".addTypes(" << resultNames[i] - << ");\n"; - } -} - -void OpEmitter::genSeparateParamUnwrappedAttrBuilder() { +static bool canGenerateUnwrappedBuilder(Operator &op) { // If this op does not have native attributes at all, return directly to avoid // redefining builders. if (op.getNumNativeAttributes() == 0) - return; + return false; bool canGenerate = false; // We are generating builders that take raw values for attributes. We need to @@ -930,47 +903,75 @@ void OpEmitter::genSeparateParamUnwrappedAttrBuilder() { break; } } - if (!canGenerate) - return; - - std::string paramList; - llvm::SmallVector<std::string, 4> resultNames; - buildParamList(paramList, resultNames, TypeParamKind::Separate, - AttrParamKind::UnwrappedValue); - - auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); - genCodeForAddingArgAndRegionForBuilder(m.body(), /*isRawValueAttr=*/true); - - // Push all result types to the operation state. - for (int i = 0, e = op.getNumResults(); i < e; ++i) { - m.body() << " " << builderOpState << ".addTypes(" << resultNames[i] - << ");\n"; - } + return canGenerate; } -void OpEmitter::genCollectiveTypeParamBuilder() { - auto numResults = op.getNumResults(); - - // If this op has no results, then just skip generating this builder. - // Otherwise we are generating the same signature as the separate-parameter - // builder. - if (numResults == 0) - return; - - // Similarly for ops with one single variadic result, which will also have one - // `ArrayRef<Type>` parameter for the result type. - if (numResults == 1 && op.getResult(0).isVariadic()) - return; - - std::string paramList; - llvm::SmallVector<std::string, 4> resultNames; - buildParamList(paramList, resultNames, TypeParamKind::Collective); +void OpEmitter::genSeparateArgParamBuilder() { + SmallVector<AttrParamKind, 2> attrBuilderType; + attrBuilderType.push_back(AttrParamKind::WrappedAttr); + if (canGenerateUnwrappedBuilder(op)) + attrBuilderType.push_back(AttrParamKind::UnwrappedValue); + + // Emit with separate builders with or without unwrapped attributes and/or + // inferring result type. + auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, + bool inferType) { + std::string paramList; + llvm::SmallVector<std::string, 4> resultNames; + buildParamList(paramList, resultNames, paramKind, attrType); + + auto &m = + opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); + auto &body = m.body(); + genCodeForAddingArgAndRegionForBuilder( + body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue); + + // Push all result types to the operation state + + if (inferType) { + // Generate builder that infers type too. + // TODO(jpienaar): Subsume this with general checking if type can be + // infered automatically. + // TODO(jpienaar): Expand to handle regions. + body << formatv(R"( + SmallVector<Type, 2> inferedReturnTypes; + if (succeeded({0}::inferReturnTypes({1}.location, {1}.operands, + {1}.attributes, /*regions=*/{{}, inferedReturnTypes))) + {1}.addTypes(inferedReturnTypes); + else + llvm::report_fatal_error("Failed to infer result type(s).");)", + opClass.getClassName(), builderOpState); + return; + } - auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); - genCodeForAddingArgAndRegionForBuilder(m.body()); + switch (paramKind) { + case TypeParamKind::None: + return; + case TypeParamKind::Separate: + for (int i = 0, e = op.getNumResults(); i < e; ++i) { + body << " " << builderOpState << ".addTypes(" << resultNames[i] + << ");\n"; + } + return; + case TypeParamKind::Collective: + body << " " << builderOpState << ".addTypes(resultTypes);\n"; + return; + }; + llvm_unreachable("unhandled TypeParamKind"); + }; - // Push all result types to the operation state - m.body() << formatv(" {0}.addTypes(resultTypes);\n", builderOpState); + bool canInferType = + op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0; + for (auto attrType : attrBuilderType) { + emit(attrType, TypeParamKind::Separate, /*inferType=*/false); + if (canInferType) + emit(attrType, TypeParamKind::None, /*inferType=*/true); + // Emit separate arg build with collective type, unless there is only one + // variadic result, in which case the above would have already generated + // the same build method. + if (op.getNumResults() == 1 && !op.getResult(0).isVariadic()) + emit(attrType, TypeParamKind::Collective, /*inferType=*/false); + } } void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { @@ -1021,8 +1022,7 @@ void OpEmitter::genInferedTypeCollectiveParamBuilder() { /*regions=*/{{}, inferedReturnTypes))) build(builder, tblgen_state, inferedReturnTypes, operands, attributes); else - llvm::report_fatal_error("Failed to infer result type(s)."); - )", + llvm::report_fatal_error("Failed to infer result type(s).");)", opClass.getClassName(), builderOpState); } @@ -1111,18 +1111,13 @@ void OpEmitter::genBuilder() { // Generate default builders that requires all result type, operands, and // attributes as parameters. - // We generate three builders here: - // 1. one having a stand-alone parameter for each result type / operand / - // attribute, and - genSeparateParamWrappedAttrBuilder(); - genSeparateParamUnwrappedAttrBuilder(); - // 2. one having a stand-alone parameter for each operand / attribute and - // an aggregated parameter for all result types, and - genCollectiveTypeParamBuilder(); - // 3. one having an aggregated parameter for all result types / operands / + // We generate three classes of builders here: + // 1. one having a stand-alone parameter for each operand / attribute, and + genSeparateArgParamBuilder(); + // 2. one having an aggregated parameter for all result types / operands / // attributes, and genCollectiveParamBuilder(); - // 4. one having a stand-alone parameter for each operand and attribute, + // 3. one having a stand-alone parameter for each operand and attribute, // use the first operand or attribute's type as all result types // to facilitate different call patterns. if (op.getNumVariadicResults() == 0) { @@ -1133,11 +1128,6 @@ void OpEmitter::genBuilder() { if (op.getTrait("OpTrait::FirstAttrDerivedResultType")) genUseAttrAsResultTypeBuilder(); } - // TODO(jpienaar): Subsume this with general checking if type can be infered - // automatically. - // TODO(jpienaar): Expand to handle regions. - if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0) - genInferedTypeCollectiveParamBuilder(); } void OpEmitter::genCollectiveParamBuilder() { @@ -1156,13 +1146,6 @@ void OpEmitter::genCollectiveParamBuilder() { auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); auto &body = m.body(); - // Result types - if (numVariadicResults == 0 || numNonVariadicResults != 0) - body << " assert(resultTypes.size()" - << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults - << "u && \"mismatched number of return types\");\n"; - body << " " << builderOpState << ".addTypes(resultTypes);\n"; - // Operands if (numVariadicOperands == 0 || numNonVariadicOperands != 0) body << " assert(operands.size()" @@ -1179,6 +1162,20 @@ void OpEmitter::genCollectiveParamBuilder() { for (int i = 0; i < numRegions; ++i) m.body() << " (void)" << builderOpState << ".addRegion();\n"; } + + // Result types + if (numVariadicResults == 0 || numNonVariadicResults != 0) + body << " assert(resultTypes.size()" + << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults + << "u && \"mismatched number of return types\");\n"; + body << " " << builderOpState << ".addTypes(resultTypes);\n"; + + // Generate builder that infers type too. + // TODO(jpienaar): Subsume this with general checking if type can be infered + // automatically. + // TODO(jpienaar): Expand to handle regions. + if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0) + genInferedTypeCollectiveParamBuilder(); } void OpEmitter::buildParamList(std::string ¶mList, |