//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // OpDefinitionsGen uses the description of operations to generate C++ // definitions for ops. // //===----------------------------------------------------------------------===// #include "mlir/Support/STLExtras.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/ODSDialectHook.h" #include "mlir/TableGen/OpClass.h" #include "mlir/TableGen/OpInterfaces.h" #include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" #define DEBUG_TYPE "mlir-tblgen-opdefgen" using namespace mlir; using namespace mlir::tblgen; using llvm::CodeInit; using llvm::DefInit; using llvm::formatv; using llvm::Init; using llvm::ListInit; using llvm::Record; using llvm::RecordKeeper; using llvm::StringInit; //===----------------------------------------------------------------------===// // Dialect hook registration //===----------------------------------------------------------------------===// static llvm::ManagedStatic> dialectHooks; ODSDialectHookRegistration::ODSDialectHookRegistration( StringRef dialectName, DialectEmitFunction emitFn) { bool inserted = dialectHooks->try_emplace(dialectName, emitFn).second; assert(inserted && "Multiple ODS hooks for the same dialect!"); (void)inserted; } //===----------------------------------------------------------------------===// // Static string definitions 
//===----------------------------------------------------------------------===//

// Prefixes / names used for identifiers in the generated C++ code.
static const char *const tblgenNamePrefix = "tblgen_";
static const char *const generatedArgName = "tblgen_arg";
static const char *const builderOpState = "tblgen_state";

// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic is not for
// general use; it assumes all variadic operands/results must have the same
// number of values.
//
// {0}: The list of whether each declared operand/result is variadic.
// {1}: The total number of non-variadic operands/results.
// {2}: The total number of variadic operands/results.
// {3}: The total number of actual values.
// {4}: The begin iterator of the actual values.
// {5}: "operand" or "result".
//
// NOTE: the template contains `//` line comments, so each statement must stay
// on its own line; otherwise the comments swallow the generated code.
const char *sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {5} corresponds to.
  // This assumes all static variadic {5}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {5} (variadic or not) as size 1. So here for each previous static variadic
  // {5}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {5} starts.
  int offset = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;

  return {{std::next({4}, offset), std::next({4}, offset + size)};
)";

// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic assumes
// the op has an attribute specifying the size of each operand/result segment
// (variadic or not).
//
// {0}: The name of the attribute specifying the segment sizes.
// {1}: The begin iterator of the actual values. const char *attrSizedSegmentValueRangeCalcCode = R"( auto sizeAttr = getAttrOfType("{0}"); unsigned start = 0; for (unsigned i = 0; i < index; ++i) start += (*(sizeAttr.begin() + i)).getZExtValue(); unsigned end = start + (*(sizeAttr.begin() + index)).getZExtValue(); return {{std::next({1}, start), std::next({1}, end)}; )"; static const char *const opCommentHeader = R"( //===----------------------------------------------------------------------===// // {0} {1} //===----------------------------------------------------------------------===// )"; //===----------------------------------------------------------------------===// // Utility structs and functions //===----------------------------------------------------------------------===// // Returns whether the record has a value of the given name that can be returned // via getValueAsString. static inline bool hasStringAttribute(const Record &record, StringRef fieldName) { auto valueInit = record.getValueInit(fieldName); return isa(valueInit) || isa(valueInit); } static std::string getArgumentName(const Operator &op, int index) { const auto &operand = op.getOperand(index); if (!operand.name.empty()) return operand.name; else return formatv("{0}_{1}", generatedArgName, index); } // Returns true if we can use unwrapped value for the given `attr` in builders. static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) { return attr.getReturnType() != attr.getStorageType() && // We need to wrap the raw value into an attribute in the builder impl // so we need to make sure that the attribute specifies how to do that. !attr.getConstBuilderTemplate().empty(); } //===----------------------------------------------------------------------===// // Op emitter //===----------------------------------------------------------------------===// namespace { // Simple RAII helper for defining ifdef-undef-endif scopes. 
class IfDefScope { public: IfDefScope(StringRef name, raw_ostream &os) : name(name), os(os) { os << "#ifdef " << name << "\n" << "#undef " << name << "\n\n"; } ~IfDefScope() { os << "\n#endif // " << name << "\n\n"; } private: StringRef name; raw_ostream &os; }; } // end anonymous namespace namespace { // Helper class to emit a record into the given output stream. class OpEmitter { public: static void emitDecl(const Operator &op, raw_ostream &os); static void emitDef(const Operator &op, raw_ostream &os); private: OpEmitter(const Operator &op); void emitDecl(raw_ostream &os); void emitDef(raw_ostream &os); // Generates the OpAsmOpInterface for this operation if possible. void genOpAsmInterface(); // Generates the `getOperationName` method for this op. void genOpNameGetter(); // Generates getters for the attributes. void genAttrGetters(); // Generates getters for named operands. void genNamedOperandGetters(); // Generates getters for named results. void genNamedResultGetters(); // Generates getters for named regions. void genNamedRegionGetters(); // Generates builder methods for the operation. void genBuilder(); // Generates the build() method that takes each operand/attribute // as a stand-alone parameter. void genSeparateArgParamBuilder(); // Generates the build() method that takes each operand/attribute as a // stand-alone parameter. The generated build() method uses first operand's // type as all results' types. void genUseOperandAsResultTypeSeparateParamBuilder(); // Generates the build() method that takes all operands/attributes // collectively as one parameter. The generated build() method uses first // operand's type as all results' types. void genUseOperandAsResultTypeCollectiveParamBuilder(); // Generates the build() method that takes aggregate operands/attributes // parameters. This build() method uses inferred types as result types. // Requires: The type needs to be inferable via InferTypeOpInterface. 
void genInferedTypeCollectiveParamBuilder(); // Generates the build() method that takes each operand/attribute as a // stand-alone parameter. The generated build() method uses first attribute's // type as all result's types. void genUseAttrAsResultTypeBuilder(); // Generates the build() method that takes all result types collectively as // one parameter. Similarly for operands and attributes. void genCollectiveParamBuilder(); // The kind of parameter to generate for result types in builders. enum class TypeParamKind { None, // No result type in parameter list. Separate, // A separate parameter for each result type. Collective, // An ArrayRef for all result types. }; // The kind of parameter to generate for attributes in builders. enum class AttrParamKind { WrappedAttr, // A wrapped MLIR Attribute instance. UnwrappedValue, // A raw value without MLIR Attribute wrapper. }; // Builds the parameter list for build() method of this op. This method writes // to `paramList` the comma-separated parameter list and updates // `resultTypeNames` with the names for parameters for specifying result // types. The given `typeParamKind` and `attrParamKind` controls how result // types and attributes are placed in the parameter list. void buildParamList(std::string ¶mList, SmallVectorImpl &resultTypeNames, TypeParamKind typeParamKind, AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); // Adds op arguments and regions into operation state for build() methods. void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, bool isRawValueAttr = false); // Generates canonicalizer declaration for the operation. void genCanonicalizerDecls(); // Generates the folder declaration for the operation. void genFolderDecls(); // Generates the parser for the operation. void genParser(); // Generates the printer for the operation. void genPrinter(); // Generates verify method for the operation. void genVerifier(); // Generates verify statements for operands and results in the operation. 
// The generated code will be attached to `body`. void genOperandResultVerifier(OpMethodBody &body, Operator::value_range values, StringRef valueKind); // Generates verify statements for regions in the operation. // The generated code will be attached to `body`. void genRegionVerifier(OpMethodBody &body); // Generates the traits used by the object. void genTraits(); // Generate the OpInterface methods. void genOpInterfaceMethods(); private: // The TableGen record for this op. // TODO(antiagainst,zinenko): OpEmitter should not have a Record directly, // it should rather go through the Operator for better abstraction. const Record &def; // The wrapper operator class for querying information from this op. Operator op; // The C++ code builder for this op OpClass opClass; // The format context for verification code generation. FmtContext verifyCtx; }; } // end anonymous namespace OpEmitter::OpEmitter(const Operator &op) : def(op.getDef()), op(op), opClass(op.getCppClassName(), op.getExtraClassDeclaration()) { verifyCtx.withOp("(*this->getOperation())"); genTraits(); // Generate C++ code for various op methods. The order here determines the // methods in the generated file. genOpAsmInterface(); genOpNameGetter(); genNamedOperandGetters(); genNamedResultGetters(); genNamedRegionGetters(); genAttrGetters(); genBuilder(); genParser(); genPrinter(); genVerifier(); genCanonicalizerDecls(); genFolderDecls(); genOpInterfaceMethods(); // If a dialect hook is registered for this op's dialect, emit dialect // specific content. 
auto dialectHookIt = dialectHooks->find(op.getDialectName()); if (dialectHookIt != dialectHooks->end()) { dialectHookIt->second(op, opClass); } } void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) { OpEmitter(op).emitDecl(os); } void OpEmitter::emitDef(const Operator &op, raw_ostream &os) { OpEmitter(op).emitDef(os); } void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); } void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); } void OpEmitter::genAttrGetters() { FmtContext fctx; fctx.withBuilder("mlir::Builder(this->getContext())"); // Emit the derived attribute body. auto emitDerivedAttr = [&](StringRef name, Attribute attr) { auto &method = opClass.newMethod(attr.getReturnType(), name); auto &body = method.body(); body << " " << attr.getDerivedCodeBody() << "\n"; }; // Emit with return type specified. auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) { auto &method = opClass.newMethod(attr.getReturnType(), name); auto &body = method.body(); body << " auto attr = " << name << "Attr();\n"; if (attr.hasDefaultValue()) { // Returns the default value if not set. // TODO: this is inefficient, we are recreating the attribute for every // call. This should be set instead. std::string defaultValue = tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()); body << " if (!attr)\n return " << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf(defaultValue)) << ";\n"; } body << " return " << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr")) << ";\n"; }; // Generate raw named accessor type. This is a wrapper class that allows // referring to the attributes via accessors instead of having to use // the string interface for better compile time verification. 
auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { auto &method = opClass.newMethod(attr.getStorageType(), (name + "Attr").str()); auto &body = method.body(); body << " return this->getAttr(\"" << name << "\")."; if (attr.isOptional() || attr.hasDefaultValue()) body << "dyn_cast_or_null<"; else body << "cast<"; body << attr.getStorageType() << ">();"; }; for (auto &namedAttr : op.getAttributes()) { const auto &name = namedAttr.name; const auto &attr = namedAttr.attr; if (attr.isDerivedAttr()) { emitDerivedAttr(name, attr); } else { emitAttrWithStorageType(name, attr); emitAttrWithReturnType(name, attr); } } } // Generates the named operand getter methods for the given Operator `op` and // puts them in `opClass`. Uses `rangeType` as the return type of getters that // return a range of operands (individual operands are `Value ` and each // element in the range must also be `Value `); use `rangeBeginCall` to get // an iterator to the beginning of the operand range; use `rangeSizeCall` to // obtain the number of operands. `getOperandCallPattern` contains the code // necessary to obtain a single operand whose position will be substituted // instead of // "{0}" marker in the pattern. Note that the pattern should work for any kind // of ops, in particular for one-operand ops that may not have the // `getOperand(unsigned)` method. 
static void generateNamedOperandGetters(const Operator &op, Class &opClass, StringRef rangeType, StringRef rangeBeginCall, StringRef rangeSizeCall, StringRef getOperandCallPattern) { const int numOperands = op.getNumOperands(); const int numVariadicOperands = op.getNumVariadicOperands(); const int numNormalOperands = numOperands - numVariadicOperands; const auto *sameVariadicSize = op.getTrait("OpTrait::SameVariadicOperandSize"); const auto *attrSizedOperands = op.getTrait("OpTrait::AttrSizedOperandSegments"); if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) { PrintFatalError(op.getLoc(), "op has multiple variadic operands but no " "specification over their sizes"); } if (numVariadicOperands < 2 && attrSizedOperands) { PrintFatalError(op.getLoc(), "op must have at least two variadic operands " "to use 'AttrSizedOperandSegments' trait"); } if (attrSizedOperands && sameVariadicSize) { PrintFatalError(op.getLoc(), "op cannot have both 'AttrSizedOperandSegments' and " "'SameVariadicOperandSize' traits"); } // First emit a "sink" getter method upon which we layer all nicer named // getter methods. auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index"); if (numVariadicOperands == 0) { // We still need to match the return type, which is a range. m.body() << " return {std::next(" << rangeBeginCall << ", index), std::next(" << rangeBeginCall << ", index + 1)};"; } else if (attrSizedOperands) { m.body() << formatv(attrSizedSegmentValueRangeCalcCode, "operand_segment_sizes", rangeBeginCall); } else { // Because the op can have arbitrarily interleaved variadic and non-variadic // operands, we need to embed a list in the "sink" getter method for // calculation at run-time. 
llvm::SmallVector isVariadic; isVariadic.reserve(numOperands); for (int i = 0; i < numOperands; ++i) { isVariadic.push_back(llvm::toStringRef(op.getOperand(i).isVariadic())); } std::string isVariadicList = llvm::join(isVariadic, ", "); m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, numNormalOperands, numVariadicOperands, rangeSizeCall, rangeBeginCall, "operand"); } // Then we emit nicer named getter methods by redirecting to the "sink" getter // method. for (int i = 0; i != numOperands; ++i) { const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; if (operand.isVariadic()) { auto &m = opClass.newMethod(rangeType, operand.name); m.body() << " return getODSOperands(" << i << ");"; } else { auto &m = opClass.newMethod("Value ", operand.name); m.body() << " return *getODSOperands(" << i << ").begin();"; } } } void OpEmitter::genNamedOperandGetters() { if (op.getTrait("OpTrait::AttrSizedOperandSegments")) opClass.setHasOperandAdaptorClass(false); generateNamedOperandGetters( op, opClass, /*rangeType=*/"Operation::operand_range", /*rangeBeginCall=*/"getOperation()->operand_begin()", /*rangeSizeCall=*/"getOperation()->getNumOperands()", /*getOperandCallPattern=*/"getOperation()->getOperand({0})"); } void OpEmitter::genNamedResultGetters() { const int numResults = op.getNumResults(); const int numVariadicResults = op.getNumVariadicResults(); const int numNormalResults = numResults - numVariadicResults; // If we have more than one variadic results, we need more complicated logic // to calculate the value range for each result. 
const auto *sameVariadicSize = op.getTrait("OpTrait::SameVariadicResultSize"); const auto *attrSizedResults = op.getTrait("OpTrait::AttrSizedResultSegments"); if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) { PrintFatalError(op.getLoc(), "op has multiple variadic results but no " "specification over their sizes"); } if (numVariadicResults < 2 && attrSizedResults) { PrintFatalError(op.getLoc(), "op must have at least two variadic results " "to use 'AttrSizedResultSegments' trait"); } if (attrSizedResults && sameVariadicSize) { PrintFatalError(op.getLoc(), "op cannot have both 'AttrSizedResultSegments' and " "'SameVariadicResultSize' traits"); } auto &m = opClass.newMethod("Operation::result_range", "getODSResults", "unsigned index"); if (numVariadicResults == 0) { m.body() << " return {std::next(getOperation()->result_begin(), index), " "std::next(getOperation()->result_begin(), index + 1)};"; } else if (attrSizedResults) { m.body() << formatv(attrSizedSegmentValueRangeCalcCode, "result_segment_sizes", "getOperation()->result_begin()"); } else { llvm::SmallVector isVariadic; isVariadic.reserve(numResults); for (int i = 0; i < numResults; ++i) { isVariadic.push_back(llvm::toStringRef(op.getResult(i).isVariadic())); } std::string isVariadicList = llvm::join(isVariadic, ", "); m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, numNormalResults, numVariadicResults, "getOperation()->getNumResults()", "getOperation()->result_begin()", "result"); } for (int i = 0; i != numResults; ++i) { const auto &result = op.getResult(i); if (result.name.empty()) continue; if (result.isVariadic()) { auto &m = opClass.newMethod("Operation::result_range", result.name); m.body() << " return getODSResults(" << i << ");"; } else { auto &m = opClass.newMethod("Value ", result.name); m.body() << " return *getODSResults(" << i << ").begin();"; } } } void OpEmitter::genNamedRegionGetters() { unsigned numRegions = op.getNumRegions(); for (unsigned i = 0; 
i < numRegions; ++i) { const auto ®ion = op.getRegion(i); if (!region.name.empty()) { auto &m = opClass.newMethod("Region &", region.name); m.body() << formatv(" return this->getOperation()->getRegion({0});", i); } } } static bool canGenerateUnwrappedBuilder(Operator &op) { // If this op does not have native attributes at all, return directly to avoid // redefining builders. if (op.getNumNativeAttributes() == 0) return false; bool canGenerate = false; // We are generating builders that take raw values for attributes. We need to // make sure the native attributes have a meaningful "unwrapped" value type // different from the wrapped mlir::Attribute type to avoid redefining // builders. This checks for the op has at least one such native attribute. for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) { NamedAttribute &namedAttr = op.getAttribute(i); if (canUseUnwrappedRawValue(namedAttr.attr)) { canGenerate = true; break; } } return canGenerate; } void OpEmitter::genSeparateArgParamBuilder() { SmallVector attrBuilderType; attrBuilderType.push_back(AttrParamKind::WrappedAttr); if (canGenerateUnwrappedBuilder(op)) attrBuilderType.push_back(AttrParamKind::UnwrappedValue); // Emit with separate builders with or without unwrapped attributes and/or // inferring result type. auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, bool inferType) { std::string paramList; llvm::SmallVector resultNames; buildParamList(paramList, resultNames, paramKind, attrType); auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); auto &body = m.body(); genCodeForAddingArgAndRegionForBuilder( body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue); // Push all result types to the operation state if (inferType) { // Generate builder that infers type too. // TODO(jpienaar): Subsume this with general checking if type can be // infered automatically. // TODO(jpienaar): Expand to handle regions. 
body << formatv(R"( SmallVector inferedReturnTypes; if (succeeded({0}::inferReturnTypes({1}.location, {1}.operands, {1}.attributes, /*regions=*/{{}, inferedReturnTypes))) {1}.addTypes(inferedReturnTypes); else llvm::report_fatal_error("Failed to infer result type(s).");)", opClass.getClassName(), builderOpState); return; } switch (paramKind) { case TypeParamKind::None: return; case TypeParamKind::Separate: for (int i = 0, e = op.getNumResults(); i < e; ++i) { body << " " << builderOpState << ".addTypes(" << resultNames[i] << ");\n"; } return; case TypeParamKind::Collective: body << " " << builderOpState << ".addTypes(resultTypes);\n"; return; }; llvm_unreachable("unhandled TypeParamKind"); }; bool canInferType = op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0; for (auto attrType : attrBuilderType) { emit(attrType, TypeParamKind::Separate, /*inferType=*/false); if (canInferType) emit(attrType, TypeParamKind::None, /*inferType=*/true); // Emit separate arg build with collective type, unless there is only one // variadic result, in which case the above would have already generated // the same build method. if (!(op.getNumResults() == 1 && op.getResult(0).isVariadic())) emit(attrType, TypeParamKind::Collective, /*inferType=*/false); } } void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { // If this op has a variadic result, we cannot generate this builder because // we don't know how many results to create. 
if (op.getNumVariadicResults() != 0) return; int numResults = op.getNumResults(); // Signature std::string params = std::string("Builder *, OperationState &") + builderOpState + ", ValueRange operands, ArrayRef attributes"; auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); auto &body = m.body(); // Operands body << " " << builderOpState << ".addOperands(operands);\n\n"; // Attributes body << " " << builderOpState << ".addAttributes(attributes);\n"; // Create the correct number of regions if (int numRegions = op.getNumRegions()) { for (int i = 0; i < numRegions; ++i) m.body() << " (void)" << builderOpState << ".addRegion();\n"; } // Result types SmallVector resultTypes(numResults, "operands[0].getType()"); body << " " << builderOpState << ".addTypes({" << llvm::join(resultTypes, ", ") << "});\n\n"; } void OpEmitter::genInferedTypeCollectiveParamBuilder() { // TODO(jpienaar): Expand to support regions. const char *params = "Builder *builder, OperationState &{0}, " "ValueRange operands, ArrayRef attributes"; auto &m = opClass.newMethod("void", "build", formatv(params, builderOpState).str(), OpMethod::MP_Static); auto &body = m.body(); body << formatv(R"( SmallVector inferedReturnTypes; if (succeeded({0}::inferReturnTypes({1}.location, operands, attributes, /*regions=*/{{}, inferedReturnTypes))) build(builder, tblgen_state, inferedReturnTypes, operands, attributes); else llvm::report_fatal_error("Failed to infer result type(s).");)", opClass.getClassName(), builderOpState); } void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { std::string paramList; llvm::SmallVector resultNames; buildParamList(paramList, resultNames, TypeParamKind::None); auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); genCodeForAddingArgAndRegionForBuilder(m.body()); auto numResults = op.getNumResults(); if (numResults == 0) return; // Push all result types to the operation state const char *index = op.getOperand(0).isVariadic() ? 
".front()" : ""; std::string resultType = formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str(); m.body() << " " << builderOpState << ".addTypes({" << resultType; for (int i = 1; i != numResults; ++i) m.body() << ", " << resultType; m.body() << "});\n\n"; } void OpEmitter::genUseAttrAsResultTypeBuilder() { std::string params = std::string("Builder *, OperationState &") + builderOpState + ", ValueRange operands, ArrayRef attributes"; auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); auto &body = m.body(); // Push all result types to the operation state std::string resultType; const auto &namedAttr = op.getAttribute(0); body << " for (auto attr : attributes) {\n"; body << " if (attr.first != \"" << namedAttr.name << "\") continue;\n"; if (namedAttr.attr.isTypeAttr()) { resultType = "attr.second.cast().getValue()"; } else { resultType = "attr.second.getType()"; } // Operands body << " " << builderOpState << ".addOperands(operands);\n\n"; // Attributes body << " " << builderOpState << ".addAttributes(attributes);\n"; // Result types SmallVector resultTypes(op.getNumResults(), resultType); body << " " << builderOpState << ".addTypes({" << llvm::join(resultTypes, ", ") << "});\n"; body << " }\n"; } void OpEmitter::genBuilder() { // Handle custom builders if provided. // TODO(antiagainst): Create wrapper class for OpBuilder to hide the native // TableGen API calls here. 
{ auto *listInit = dyn_cast_or_null(def.getValueInit("builders")); if (listInit) { for (Init *init : listInit->getValues()) { Record *builderDef = cast(init)->getDef(); StringRef params = builderDef->getValueAsString("params"); StringRef body = builderDef->getValueAsString("body"); bool hasBody = !body.empty(); auto &method = opClass.newMethod("void", "build", params, OpMethod::MP_Static, /*declOnly=*/!hasBody); if (hasBody) method.body() << body; } } if (op.skipDefaultBuilders()) { if (!listInit || listInit->empty()) PrintFatalError( op.getLoc(), "default builders are skipped and no custom builders provided"); return; } } // Generate default builders that requires all result type, operands, and // attributes as parameters. // We generate three classes of builders here: // 1. one having a stand-alone parameter for each operand / attribute, and genSeparateArgParamBuilder(); // 2. one having an aggregated parameter for all result types / operands / // attributes, and genCollectiveParamBuilder(); // 3. one having a stand-alone parameter for each operand and attribute, // use the first operand or attribute's type as all result types // to facilitate different call patterns. 
if (op.getNumVariadicResults() == 0) { if (op.getTrait("OpTrait::SameOperandsAndResultType")) { genUseOperandAsResultTypeSeparateParamBuilder(); genUseOperandAsResultTypeCollectiveParamBuilder(); } if (op.getTrait("OpTrait::FirstAttrDerivedResultType")) genUseAttrAsResultTypeBuilder(); } } void OpEmitter::genCollectiveParamBuilder() { int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariadicResults(); int numNonVariadicResults = numResults - numVariadicResults; int numOperands = op.getNumOperands(); int numVariadicOperands = op.getNumVariadicOperands(); int numNonVariadicOperands = numOperands - numVariadicOperands; // Signature std::string params = std::string("Builder *, OperationState &") + builderOpState + ", ArrayRef resultTypes, ValueRange operands, " "ArrayRef attributes"; auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); auto &body = m.body(); // Operands if (numVariadicOperands == 0 || numNonVariadicOperands != 0) body << " assert(operands.size()" << (numVariadicOperands != 0 ? " >= " : " == ") << numNonVariadicOperands << "u && \"mismatched number of parameters\");\n"; body << " " << builderOpState << ".addOperands(operands);\n\n"; // Attributes body << " " << builderOpState << ".addAttributes(attributes);\n"; // Create the correct number of regions if (int numRegions = op.getNumRegions()) { for (int i = 0; i < numRegions; ++i) m.body() << " (void)" << builderOpState << ".addRegion();\n"; } // Result types if (numVariadicResults == 0 || numNonVariadicResults != 0) body << " assert(resultTypes.size()" << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults << "u && \"mismatched number of return types\");\n"; body << " " << builderOpState << ".addTypes(resultTypes);\n"; // Generate builder that infers type too. // TODO(jpienaar): Subsume this with general checking if type can be infered // automatically. // TODO(jpienaar): Expand to handle regions. 
if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0) genInferedTypeCollectiveParamBuilder(); } void OpEmitter::buildParamList(std::string ¶mList, SmallVectorImpl &resultTypeNames, TypeParamKind typeParamKind, AttrParamKind attrParamKind) { resultTypeNames.clear(); auto numResults = op.getNumResults(); resultTypeNames.reserve(numResults); paramList = "Builder *tblgen_builder, OperationState &"; paramList.append(builderOpState); switch (typeParamKind) { case TypeParamKind::None: break; case TypeParamKind::Separate: { // Add parameters for all return types for (int i = 0; i < numResults; ++i) { const auto &result = op.getResult(i); std::string resultName = result.name; if (resultName.empty()) resultName = formatv("resultType{0}", i); paramList.append(result.isVariadic() ? ", ArrayRef " : ", Type "); paramList.append(resultName); resultTypeNames.emplace_back(std::move(resultName)); } } break; case TypeParamKind::Collective: { paramList.append(", ArrayRef resultTypes"); resultTypeNames.push_back("resultTypes"); } break; } // Add parameters for all arguments (operands and attributes). int numOperands = 0; int numAttrs = 0; int defaultValuedAttrStartIndex = op.getNumArgs(); if (attrParamKind == AttrParamKind::UnwrappedValue) { // Calculate the start index from which we can attach default values in the // builder declaration. for (int i = op.getNumArgs() - 1; i >= 0; --i) { auto *namedAttr = op.getArg(i).dyn_cast(); if (!namedAttr || !namedAttr->attr.hasDefaultValue()) break; if (!canUseUnwrappedRawValue(namedAttr->attr)) break; // Creating an APInt requires us to provide bitwidth, value, and // signedness, which is complicated compared to others. Similarly // for APFloat. // TODO(b/144412160) Adjust the 'returnType' field of such attributes // to support them. 
StringRef retType = namedAttr->attr.getReturnType(); if (retType == "APInt" || retType == "APFloat") break; defaultValuedAttrStartIndex = i; } } for (int i = 0, e = op.getNumArgs(); i < e; ++i) { auto argument = op.getArg(i); if (argument.is()) { const auto &operand = op.getOperand(numOperands); paramList.append(operand.isVariadic() ? ", ValueRange " : ", Value "); paramList.append(getArgumentName(op, numOperands)); ++numOperands; } else { const auto &namedAttr = op.getAttribute(numAttrs); const auto &attr = namedAttr.attr; paramList.append(", "); if (attr.isOptional()) paramList.append("/*optional*/"); switch (attrParamKind) { case AttrParamKind::WrappedAttr: paramList.append(attr.getStorageType()); break; case AttrParamKind::UnwrappedValue: if (canUseUnwrappedRawValue(attr)) { paramList.append(attr.getReturnType()); } else { paramList.append(attr.getStorageType()); } break; } paramList.append(" "); paramList.append(namedAttr.name); // Attach default value if requested and possible. if (attrParamKind == AttrParamKind::UnwrappedValue && i >= defaultValuedAttrStartIndex) { bool isString = attr.getReturnType() == "StringRef"; paramList.append(" = "); if (isString) paramList.append("\""); paramList.append(attr.getDefaultValue()); if (isString) paramList.append("\""); } ++numAttrs; } } } void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, bool isRawValueAttr) { // Push all operands to the result for (int i = 0, e = op.getNumOperands(); i < e; ++i) { body << " " << builderOpState << ".addOperands(" << getArgumentName(op, i) << ");\n"; } // Push all attributes to the result for (const auto &namedAttr : op.getAttributes()) { auto &attr = namedAttr.attr; if (!attr.isDerivedAttr()) { bool emitNotNullCheck = attr.isOptional(); if (emitNotNullCheck) { body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; } if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { // If this is a raw value, then we need to wrap it in an Attribute // instance. 
FmtContext fctx; fctx.withBuilder("(*tblgen_builder)"); std::string value = tgfmt(attr.getConstBuilderTemplate(), &fctx, namedAttr.name); body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState, namedAttr.name, value); } else { body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState, namedAttr.name); } if (emitNotNullCheck) { body << " }\n"; } } } // Create the correct number of regions if (int numRegions = op.getNumRegions()) { for (int i = 0; i < numRegions; ++i) body << " (void)" << builderOpState << ".addRegion();\n"; } } void OpEmitter::genCanonicalizerDecls() { if (!def.getValueAsBit("hasCanonicalizer")) return; const char *const params = "OwningRewritePatternList &results, MLIRContext *context"; opClass.newMethod("void", "getCanonicalizationPatterns", params, OpMethod::MP_Static, /*declOnly=*/true); } void OpEmitter::genFolderDecls() { bool hasSingleResult = op.getNumResults() == 1 && op.getNumVariadicResults() == 0; if (def.getValueAsBit("hasFolder")) { if (hasSingleResult) { const char *const params = "ArrayRef operands"; opClass.newMethod("OpFoldResult", "fold", params, OpMethod::MP_None, /*declOnly=*/true); } else { const char *const params = "ArrayRef operands, " "SmallVectorImpl &results"; opClass.newMethod("LogicalResult", "fold", params, OpMethod::MP_None, /*declOnly=*/true); } } } void OpEmitter::genOpInterfaceMethods() { for (const auto &trait : op.getTraits()) { auto opTrait = dyn_cast(&trait); if (!opTrait || !opTrait->shouldDeclareMethods()) continue; auto interface = opTrait->getOpInterface(); for (auto method : interface.getMethods()) { // Don't declare if the method has a body. if (method.getBody()) continue; std::string args; llvm::raw_string_ostream os(args); mlir::interleaveComma(method.getArguments(), os, [&](const OpInterfaceMethod::Argument &arg) { os << arg.type << " " << arg.name; }); opClass.newMethod(method.getReturnType(), method.getName(), os.str(), method.isStatic() ? 
OpMethod::MP_Static : OpMethod::MP_None, /*declOnly=*/true); } } } void OpEmitter::genParser() { if (!hasStringAttribute(def, "parser")) return; auto &method = opClass.newMethod( "ParseResult", "parse", "OpAsmParser &parser, OperationState &result", OpMethod::MP_Static); FmtContext fctx; fctx.addSubst("cppClass", opClass.getClassName()); auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r"); method.body() << " " << tgfmt(parser, &fctx); } void OpEmitter::genPrinter() { auto valueInit = def.getValueInit("printer"); CodeInit *codeInit = dyn_cast(valueInit); if (!codeInit) return; auto &method = opClass.newMethod("void", "print", "OpAsmPrinter &p"); FmtContext fctx; fctx.addSubst("cppClass", opClass.getClassName()); auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r"); method.body() << " " << tgfmt(printer, &fctx); } void OpEmitter::genVerifier() { auto valueInit = def.getValueInit("verifier"); CodeInit *codeInit = dyn_cast(valueInit); bool hasCustomVerify = codeInit && !codeInit->getValue().empty(); auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/""); auto &body = method.body(); // Populate substitutions for attributes and named operands and results. for (const auto &namedAttr : op.getAttributes()) verifyCtx.addSubst(namedAttr.name, formatv("this->getAttr(\"{0}\")", namedAttr.name)); for (int i = 0, e = op.getNumOperands(); i < e; ++i) { auto &value = op.getOperand(i); // Skip from from first variadic operands for now. Else getOperand index // used below doesn't match. if (value.isVariadic()) break; if (!value.name.empty()) verifyCtx.addSubst(value.name, formatv("this->getOperation()->getOperand({0})", i)); } for (int i = 0, e = op.getNumResults(); i < e; ++i) { auto &value = op.getResult(i); // Skip from from first variadic results for now. Else getResult index used // below doesn't match. 
if (value.isVariadic()) break; if (!value.name.empty()) verifyCtx.addSubst(value.name, formatv("this->getOperation()->getResult({0})", i)); } // Verify the attributes have the correct type. for (const auto &namedAttr : op.getAttributes()) { const auto &attr = namedAttr.attr; if (attr.isDerivedAttr()) continue; auto attrName = namedAttr.name; // Prefix with `tblgen_` to avoid hiding the attribute accessor. auto varName = tblgenNamePrefix + attrName; body << formatv(" auto {0} = this->getAttr(\"{1}\");\n", varName, attrName); bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional(); if (allowMissingAttr) { // If the attribute has a default value, then only verify the predicate if // set. This does effectively assume that the default value is valid. // TODO: verify the debug value is valid (perhaps in debug mode only). body << " if (" << varName << ") {\n"; } else { body << " if (!" << varName << ") return emitOpError(\"requires attribute '" << attrName << "'\");\n {\n"; } auto attrPred = attr.getPredicate(); if (!attrPred.isNull()) { body << tgfmt( " if (!($0)) return emitOpError(\"attribute '$1' " "failed to satisfy constraint: $2\");\n", /*ctx=*/nullptr, tgfmt(attrPred.getCondition(), &verifyCtx.withSelf(varName)), attrName, attr.getDescription()); } body << " }\n"; } const char *code = R"( auto sizeAttr = getAttrOfType("{0}"); auto numElements = sizeAttr.getType().cast().getNumElements(); if (numElements != {1}) {{ return emitOpError("'{0}' attribute for specifying {2} segments " "must have {1} elements"); } )"; for (auto &trait : op.getTraits()) { if (auto *t = dyn_cast(&trait)) { body << tgfmt(" if (!($0)) {\n " "return emitOpError(\"failed to verify that $1\");\n }\n", &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx), t->getDescription()); } else if (auto *t = dyn_cast(&trait)) { if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") { body << formatv(code, "operand_segment_sizes", op.getNumOperands(), "operand"); } else if (t->getTrait() == 
"OpTrait::AttrSizedResultSegments") { body << formatv(code, "result_segment_sizes", op.getNumResults(), "result"); } } } // These should happen after we verified the traits because // getODSOperands()/getODSResults() may depend on traits (e.g., // AttrSizedOperandSegments/AttrSizedResultSegments). genOperandResultVerifier(body, op.getOperands(), "operand"); genOperandResultVerifier(body, op.getResults(), "result"); genRegionVerifier(body); if (hasCustomVerify) { FmtContext fctx; fctx.addSubst("cppClass", opClass.getClassName()); auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r"); body << " " << tgfmt(printer, &fctx); } else { body << " return mlir::success();\n"; } } void OpEmitter::genOperandResultVerifier(OpMethodBody &body, Operator::value_range values, StringRef valueKind) { FmtContext fctx; body << " {\n"; body << " unsigned index = 0; (void)index;\n"; for (auto staticValue : llvm::enumerate(values)) { if (!staticValue.value().hasPredicate()) continue; // Emit a loop to check all the dynamic values in the pack. 
body << formatv(" for (Value v : getODS{0}{1}s({2})) {{\n", // Capitalize the first letter to match the function name valueKind.substr(0, 1).upper(), valueKind.substr(1), staticValue.index()); auto constraint = staticValue.value().constraint; body << " (void)v;\n" << " if (!(" << tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("v.getType()")) << ")) {\n" << formatv(" return emitOpError(\"{0} #\") << index " "<< \" must be {1}, but got \" << v.getType();\n", valueKind, constraint.getDescription()) << " }\n" // if << " ++index;\n" << " }\n"; // for } body << " }\n"; } void OpEmitter::genRegionVerifier(OpMethodBody &body) { unsigned numRegions = op.getNumRegions(); // Verify this op has the correct number of regions body << formatv( " if (this->getOperation()->getNumRegions() != {0}) {\n " "return emitOpError(\"has incorrect number of regions: expected {0} but " "found \") << this->getOperation()->getNumRegions();\n }\n", numRegions); for (unsigned i = 0; i < numRegions; ++i) { const auto ®ion = op.getRegion(i); std::string name = formatv("#{0}", i); if (!region.name.empty()) { name += formatv(" ('{0}')", region.name); } auto getRegion = formatv("this->getOperation()->getRegion({0})", i).str(); auto constraint = tgfmt(region.constraint.getConditionTemplate(), &verifyCtx.withSelf(getRegion)) .str(); body << formatv(" if (!({0})) {\n " "return emitOpError(\"region {1} failed to verify " "constraint: {2}\");\n }\n", constraint, name, region.constraint.getDescription()); } } void OpEmitter::genTraits() { int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariadicResults(); // Add return size trait. 
if (numVariadicResults != 0) { if (numResults == numVariadicResults) opClass.addTrait("OpTrait::VariadicResults"); else opClass.addTrait("OpTrait::AtLeastNResults<" + Twine(numResults - numVariadicResults) + ">::Impl"); } else { switch (numResults) { case 0: opClass.addTrait("OpTrait::ZeroResult"); break; case 1: opClass.addTrait("OpTrait::OneResult"); break; default: opClass.addTrait("OpTrait::NResults<" + Twine(numResults) + ">::Impl"); break; } } for (const auto &trait : op.getTraits()) { if (auto opTrait = dyn_cast(&trait)) opClass.addTrait(opTrait->getTrait()); else if (auto opTrait = dyn_cast(&trait)) opClass.addTrait(opTrait->getTrait()); } // Add variadic size trait and normal op traits. int numOperands = op.getNumOperands(); int numVariadicOperands = op.getNumVariadicOperands(); // Add operand size trait. if (numVariadicOperands != 0) { if (numOperands == numVariadicOperands) opClass.addTrait("OpTrait::VariadicOperands"); else opClass.addTrait("OpTrait::AtLeastNOperands<" + Twine(numOperands - numVariadicOperands) + ">::Impl"); } else { switch (numOperands) { case 0: opClass.addTrait("OpTrait::ZeroOperands"); break; case 1: opClass.addTrait("OpTrait::OneOperand"); break; default: opClass.addTrait("OpTrait::NOperands<" + Twine(numOperands) + ">::Impl"); break; } } } void OpEmitter::genOpNameGetter() { auto &method = opClass.newMethod("StringRef", "getOperationName", /*params=*/"", OpMethod::MP_Static); method.body() << " return \"" << op.getOperationName() << "\";\n"; } void OpEmitter::genOpAsmInterface() { // If the user only has one results or specifically added the Asm trait, // then don't generate it for them. We specifically only handle multi result // operations, because the name of a single result in the common case is not // interesting(generally 'result'/'output'/etc.). // TODO: We could also add a flag to allow operations to opt in to this // generation, even if they only have a single operation. 
int numResults = op.getNumResults(); if (numResults <= 1 || op.getTrait("OpAsmOpInterface::Trait")) return; SmallVector resultNames(numResults); for (int i = 0; i != numResults; ++i) resultNames[i] = op.getResultName(i); // Don't add the trait if none of the results have a valid name. if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); })) return; opClass.addTrait("OpAsmOpInterface::Trait"); // Generate the right accessor for the number of results. auto &method = opClass.newMethod("void", "getAsmResultNames", "OpAsmSetValueNameFn setNameFn"); auto &body = method.body(); for (int i = 0; i != numResults; ++i) { body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n" << " if (!llvm::empty(resultGroup" << i << "))\n" << " setNameFn(*resultGroup" << i << ".begin(), \"" << resultNames[i] << "\");\n"; } } //===----------------------------------------------------------------------===// // OpOperandAdaptor emitter //===----------------------------------------------------------------------===// namespace { // Helper class to emit Op operand adaptors to an output stream. Operand // adaptors are wrappers around ArrayRef that provide named operand // getters identical to those defined in the Op. 
class OpOperandAdaptorEmitter { public: static void emitDecl(const Operator &op, raw_ostream &os); static void emitDef(const Operator &op, raw_ostream &os); private: explicit OpOperandAdaptorEmitter(const Operator &op); Class adapterClass; }; } // end namespace OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) : adapterClass(op.getCppClassName().str() + "OperandAdaptor") { adapterClass.newField("ArrayRef", "tblgen_operands"); auto &constructor = adapterClass.newConstructor("ArrayRef values"); constructor.body() << " tblgen_operands = values;\n"; generateNamedOperandGetters(op, adapterClass, /*rangeType=*/"ArrayRef", /*rangeBeginCall=*/"tblgen_operands.begin()", /*rangeSizeCall=*/"tblgen_operands.size()", /*getOperandCallPattern=*/"tblgen_operands[{0}]"); } void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) { OpOperandAdaptorEmitter(op).adapterClass.writeDeclTo(os); } void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) { OpOperandAdaptorEmitter(op).adapterClass.writeDefTo(os); } // Emits the opcode enum and op classes. static void emitOpClasses(const std::vector &defs, raw_ostream &os, bool emitDecl) { IfDefScope scope("GET_OP_CLASSES", os); // First emit forward declaration for each class, this allows them to refer // to each others in traits for example. if (emitDecl) { for (auto *def : defs) { Operator op(*def); os << "class " << op.getCppClassName() << ";\n"; } } for (auto *def : defs) { Operator op(*def); const auto *attrSizedOperands = op.getTrait("OpTrait::AttrSizedOperandSegments"); if (emitDecl) { os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); // We cannot generate the operand adaptor class if operand getters depend // on an attribute. 
if (!attrSizedOperands) OpOperandAdaptorEmitter::emitDecl(op, os); OpEmitter::emitDecl(op, os); } else { os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); if (!attrSizedOperands) OpOperandAdaptorEmitter::emitDef(op, os); OpEmitter::emitDef(op, os); } } } // Emits a comma-separated list of the ops. static void emitOpList(const std::vector &defs, raw_ostream &os) { IfDefScope scope("GET_OP_LIST", os); interleave( // TODO: We are constructing the Operator wrapper instance just for // getting it's qualified class name here. Reduce the overhead by having a // lightweight version of Operator class just for that purpose. defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); }, [&os]() { os << ",\n"; }); } static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { emitSourceFileHeader("Op Declarations", os); const auto &defs = recordKeeper.getAllDerivedDefinitions("Op"); emitOpClasses(defs, os, /*emitDecl=*/true); return false; } static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { emitSourceFileHeader("Op Definitions", os); const auto &defs = recordKeeper.getAllDerivedDefinitions("Op"); emitOpList(defs, os); emitOpClasses(defs, os, /*emitDecl=*/false); return false; } static mlir::GenRegistration genOpDecls("gen-op-decls", "Generate op declarations", [](const RecordKeeper &records, raw_ostream &os) { return emitOpDecls(records, os); }); static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions", [](const RecordKeeper &records, raw_ostream &os) { return emitOpDefs(records, os); });