diff options
| author | Jacques Pienaar <jpienaar@google.com> | 2019-12-06 10:52:38 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-06 10:53:06 -0800 |
| commit | 398f04aa49109fd5d1eff2c1946a2956dc6b29c6 (patch) | |
| tree | 20c0b758574c11999f3d0c670fe5d0dc7a765b97 /mlir/tools | |
| parent | e216a72ab8587c443e4c5c06aabc71c36712ce7e (diff) | |
| download | bcm5719-llvm-398f04aa49109fd5d1eff2c1946a2956dc6b29c6.tar.gz bcm5719-llvm-398f04aa49109fd5d1eff2c1946a2956dc6b29c6.zip | |
Generate builder for ops that use InferTypeOpInterface trait in ODS
For ops with infer type op interface defined, generate version that calls the inferal method on build. This is intermediate step to removing special casing of SameOperandsAndResultType & FirstAttrDereivedResultType. After that would be generating the inference code, with the initial focus on shaped container types. In between I plan to refactor these a bit to reuse generated paths. The intention would not be to add the type inference trait in multiple places, but rather to take advantage of the current modelling in ODS where possible to emit it instead.
Switch the `inferReturnTypes` method to be static.
Skipping ops with regions here as I don't like the Region vs unique_ptr<Region> difference at the moment, and I want the infer return type trait to be useful for verification too. So instead, just skip it for now to avoid churn.
PiperOrigin-RevId: 284217913
Diffstat (limited to 'mlir/tools')
| -rw-r--r-- | mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 48 |
1 files changed, 38 insertions, 10 deletions
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index b5fd0862b45..004b93d5941 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -541,6 +541,11 @@ private: // operand's type as all results' types. void genUseOperandAsResultTypeCollectiveParamBuilder(); + // Generates the build() method that takes aggregate operands/attributes + // parameters. This build() method uses inferred types as result types. + // Requires: The type needs to be inferable via InferTypeOpInterface. + void genInferedTypeCollectiveParamBuilder(); + // Generates the build() method that takes each operand/attribute as a // stand-alone parameter. The generated build() method uses first attribute's // type as all result's types. @@ -968,11 +973,6 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); auto &body = m.body(); - // Result types - SmallVector<std::string, 2> resultTypes(numResults, "operands[0]->getType()"); - body << " " << builderOpState << ".addTypes({" - << llvm::join(resultTypes, ", ") << "});\n\n"; - // Operands body << " " << builderOpState << ".addOperands(operands);\n\n"; @@ -984,6 +984,27 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { for (int i = 0; i < numRegions; ++i) m.body() << " (void)" << builderOpState << ".addRegion();\n"; } + + // Result types + SmallVector<std::string, 2> resultTypes(numResults, "operands[0]->getType()"); + body << " " << builderOpState << ".addTypes({" + << llvm::join(resultTypes, ", ") << "});\n\n"; +} + +void OpEmitter::genInferedTypeCollectiveParamBuilder() { + // TODO(jpienaar): Expand to support regions. + std::string params = + (Twine("Builder *, OperationState &") + builderOpState + + ", ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes") + .str(); + auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); + auto &body = m.body(); + + body << " " << builderOpState << ".addOperands(operands);\n\n"; + body << " " << builderOpState << ".addAttributes(attributes);\n"; + body << " " << builderOpState << ".addTypes(" << opClass.getClassName() + << "::inferReturnTypes(" << builderOpState + << ".location, operands, attributes, /*regions=*/{}));\n"; } void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { @@ -1026,15 +1047,17 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() { } else { resultType = "attr.second.getType()"; } - SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType); - body << " " << builderOpState << ".addTypes({" - << llvm::join(resultTypes, ", ") << "});\n"; - body << " }\n"; // Operands body << " " << builderOpState << ".addOperands(operands);\n\n"; // Attributes body << " " << builderOpState << ".addAttributes(attributes);\n"; + + // Result types + SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType); + body << " " << builderOpState << ".addTypes({" + << llvm::join(resultTypes, ", ") << "});\n"; + body << " }\n"; } void OpEmitter::genBuilder() { @@ -1082,7 +1105,7 @@ void OpEmitter::genBuilder() { genCollectiveParamBuilder(); // 4. one having a stand-alone parameter for each operand and attribute, // use the first operand or attribute's type as all result types - // to facilitate different call patterns. + // to facilitate different call patterns. if (op.getNumVariadicResults() == 0) { if (op.getTrait("OpTrait::SameOperandsAndResultType")) { genUseOperandAsResultTypeSeparateParamBuilder(); @@ -1091,6 +1114,11 @@ void OpEmitter::genBuilder() { if (op.getTrait("OpTrait::FirstAttrDerivedResultType")) genUseAttrAsResultTypeBuilder(); } + // TODO(jpienaar): Subsume this with general checking if type can be infered + // automatically. + // TODO(jpienaar): Expand to handle regions. + if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0) + genInferedTypeCollectiveParamBuilder(); } void OpEmitter::genCollectiveParamBuilder() { |

