diff options
Diffstat (limited to 'mlir/lib/IR/AsmPrinter.cpp')
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 322 |
1 files changed, 145 insertions, 177 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index d97b2fc8664..aeec1abb29c 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -35,6 +35,7 @@ #include "mlir/Support/STLExtras.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallString.h" @@ -79,6 +80,9 @@ static llvm::cl::opt<bool> namespace { class ModuleState { + /// A special index constant used for non-kind attribute aliases. + static constexpr int kNonAttrKindAlias = -1; + public: /// This is the current context if it is knowable, otherwise this is null. MLIRContext *const context; @@ -88,51 +92,68 @@ public: // Initializes module state, populating affine map state. void initialize(Module *module); - StringRef getAffineMapAlias(AffineMap affineMap) const { - return affineMapToAlias.lookup(affineMap); - } - - int getAffineMapId(AffineMap affineMap) const { - auto it = affineMapIds.find(affineMap); - if (it == affineMapIds.end()) { - return -1; + Twine getAttributeAlias(Attribute attr) const { + auto alias = attrToAlias.find(attr); + if (alias == attrToAlias.end()) + return Twine(); + + // Return the alias for this attribute, along with the index if this was + // generated by a kind alias. + int kindIndex = alias->second.second; + return alias->second.first + + (kindIndex == kNonAttrKindAlias ? Twine() : Twine(kindIndex)); + } + + void printAttributeAliases(raw_ostream &os) const { + auto printAlias = [&](StringRef alias, Attribute attr, int index) { + os << '#' << alias; + if (index != kNonAttrKindAlias) + os << index; + os << " = " << attr << '\n'; + }; + + // Print all of the attribute kind aliases. + for (auto &kindAlias : attrKindToAlias) { + for (unsigned i = 0, e = kindAlias.second.second.size(); i != e; ++i) + printAlias(kindAlias.second.first, kindAlias.second.second[i], i); + os << "\n"; } - return it->second; - } - - ArrayRef<AffineMap> getAffineMapIds() const { return affineMapsById; } - StringRef getIntegerSetAlias(IntegerSet integerSet) const { - return integerSetToAlias.lookup(integerSet); - } - - int getIntegerSetId(IntegerSet integerSet) const { - auto it = integerSetIds.find(integerSet); - if (it == integerSetIds.end()) { - return -1; + // In a second pass print all of the remaining attribute aliases that aren't + // kind aliases. + for (Attribute attr : usedAttributes) { + auto alias = attrToAlias.find(attr); + if (alias != attrToAlias.end() && + alias->second.second == kNonAttrKindAlias) + printAlias(alias->second.first, attr, alias->second.second); } - return it->second; } - ArrayRef<IntegerSet> getIntegerSetIds() const { return integerSetsById; } - StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); } - ArrayRef<Type> getTypeIds() const { return usedTypes.getArrayRef(); } - -private: - void recordAffineMapReference(AffineMap affineMap) { - if (affineMapIds.count(affineMap) == 0) { - affineMapIds[affineMap] = affineMapsById.size(); - affineMapsById.push_back(affineMap); + void printTypeAliases(raw_ostream &os) const { + for (Type type : usedTypes) { + auto alias = typeToAlias.find(type); + if (alias != typeToAlias.end()) + os << '!' << alias->second << " = type " << type << '\n'; } } - void recordIntegerSetReference(IntegerSet integerSet) { - if (integerSetIds.count(integerSet) == 0) { - integerSetIds[integerSet] = integerSetsById.size(); - integerSetsById.push_back(integerSet); - } +private: + void recordAttributeReference(Attribute attr) { + // Don't recheck attributes that have already been seen or those that + // already have an alias. + if (!usedAttributes.insert(attr) || attrToAlias.count(attr)) + return; + + // If this attribute kind has an alias, then record one for this attribute. + auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind())); + if (alias == attrKindToAlias.end()) + return; + std::pair<StringRef, int> attrAlias(alias->second.first, + alias->second.second.size()); + attrToAlias.insert({attr, attrAlias}); + alias->second.second.push_back(attr); } void recordTypeReference(Type ty) { usedTypes.insert(ty); } @@ -145,15 +166,23 @@ private: // Initialize symbol aliases. void initializeSymbolAliases(); - DenseMap<AffineMap, int> affineMapIds; - std::vector<AffineMap> affineMapsById; - DenseMap<AffineMap, StringRef> affineMapToAlias; + /// Set of attributes known to be used within the module. + llvm::SetVector<Attribute> usedAttributes; - DenseMap<IntegerSet, int> integerSetIds; - std::vector<IntegerSet> integerSetsById; - DenseMap<IntegerSet, StringRef> integerSetToAlias; + /// Mapping between attribute and a pair comprised of a base alias name and a + /// count suffix. If the suffix is set to -1, it is not displayed. + llvm::MapVector<Attribute, std::pair<StringRef, int>> attrToAlias; + /// Mapping between attribute kind and a pair comprised of a base alias name + /// and a unique list of attributes belonging to this kind sorted by location + /// seen in the module. + llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>> + attrKindToAlias; + + /// Set of types known to be used within the module. llvm::SetVector<Type> usedTypes; + + /// A mapping between a type and a given alias. DenseMap<Type, StringRef> typeToAlias; }; } // end anonymous namespace @@ -169,24 +198,18 @@ void ModuleState::visitType(Type type) { visitType(result); } else if (auto memref = type.dyn_cast<MemRefType>()) { // Visit affine maps in memref type. - for (auto map : memref.getAffineMaps()) { - recordAffineMapReference(map); - } + for (auto map : memref.getAffineMaps()) + recordAttributeReference(AffineMapAttr::get(map)); } else if (auto vecOrTensor = type.dyn_cast<VectorOrTensorType>()) { visitType(vecOrTensor.getElementType()); } } void ModuleState::visitAttribute(Attribute attr) { - if (auto mapAttr = attr.dyn_cast<AffineMapAttr>()) { - recordAffineMapReference(mapAttr.getValue()); - } else if (auto setAttr = attr.dyn_cast<IntegerSetAttr>()) { - recordIntegerSetReference(setAttr.getValue()); - } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) { - for (auto elt : arrayAttr.getValue()) { + recordAttributeReference(attr); + if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) + for (auto elt : arrayAttr.getValue()) visitAttribute(elt); - } - } } void ModuleState::visitOperation(Operation *op) { @@ -202,21 +225,14 @@ void ModuleState::visitOperation(Operation *op) { } // Utility to generate a function to register a symbol alias. -template <typename SymbolsInModuleSetTy, typename SymbolTy> -static void registerSymbolAlias(StringRef name, SymbolTy sym, - SymbolsInModuleSetTy &symbolsInModuleSet, - llvm::StringSet<> &usedAliases, - DenseMap<SymbolTy, StringRef> &symToAlias) { +static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) { assert(!name.empty() && "expected alias name to be non-empty"); - assert(sym && "expected alias symbol to be non-null"); // TODO(riverriddle) Assert that the provided alias name can be lexed as // an identifier. - // Check if the symbol is not referenced by the module or the name is - // already used by another alias. - if (!symbolsInModuleSet.count(sym) || !usedAliases.insert(name).second) - return; - symToAlias.try_emplace(sym, name); + // Check that the alias doesn't contain a '.' character and the name is not + // already in use. + return !name.contains('.') && usedAliases.insert(name).second; } void ModuleState::initializeSymbolAliases() { @@ -228,53 +244,76 @@ void ModuleState::initializeSymbolAliases() { auto dialects = context->getRegisteredDialects(); // Collect the set of aliases from each dialect. - SmallVector<std::pair<StringRef, AffineMap>, 8> affineMapAliases; - SmallVector<std::pair<StringRef, IntegerSet>, 8> integerSetAliases; - SmallVector<std::pair<StringRef, Type>, 16> typeAliases; + SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases; + SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases; + SmallVector<std::pair<Type, StringRef>, 16> typeAliases; + + // AffineMap/Integer set have specific kind aliases. + attributeKindAliases.emplace_back( + static_cast<unsigned>(Attribute::Kind::AffineMap), "map"); + attributeKindAliases.emplace_back( + static_cast<unsigned>(Attribute::Kind::IntegerSet), "set"); + for (auto *dialect : dialects) { - dialect->getAffineMapAliases(affineMapAliases); - dialect->getIntegerSetAliases(integerSetAliases); + dialect->getAttributeKindAliases(attributeKindAliases); + dialect->getAttributeAliases(attributeAliases); dialect->getTypeAliases(typeAliases); } - // Register the affine aliases. - // Create a regex for the non-alias names of sets and maps, so that an alias - // is not registered with a conflicting name. - llvm::Regex reservedAffineNames("(set|map)[0-9]+"); - - // AffineMap aliases - for (auto &affineAliasPair : affineMapAliases) { - if (!reservedAffineNames.match(affineAliasPair.first)) - registerSymbolAlias(affineAliasPair.first, affineAliasPair.second, - affineMapIds, usedAliases, affineMapToAlias); + // Setup the attribute kind aliases. + StringRef alias; + unsigned attrKind; + for (auto &attrAliasPair : attributeKindAliases) { + std::tie(attrKind, alias) = attrAliasPair; + assert(!alias.empty() && "expected non-empty alias string"); + if (!usedAliases.count(alias) && !alias.contains('.')) + attrKindToAlias.insert({attrKind, {alias, {}}}); } - // IntegerSet aliases - for (auto &integerSetAliasPair : integerSetAliases) { - if (!reservedAffineNames.match(integerSetAliasPair.first)) - registerSymbolAlias(integerSetAliasPair.first, integerSetAliasPair.second, - integerSetIds, usedAliases, integerSetToAlias); + // Clear the set of used identifiers so that the attribute kind aliases are + // just a prefix and not the full alias, i.e. there may be some overlap. + usedAliases.clear(); + + // Register the attribute aliases. + // Create a regex for the attribute kind alias names, these have a prefix with + // a counter appended to the end. We prevent normal aliases from having these + // names to avoid collisions. + llvm::Regex reservedAttrNames("[0-9]+$"); + + // Attribute value aliases. + Attribute attr; + for (auto &attrAliasPair : attributeAliases) { + std::tie(attr, alias) = attrAliasPair; + if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases)) + attrToAlias.insert({attr, {alias, kNonAttrKindAlias}}); } // Clear the set of used identifiers as types can have the same identifiers as // affine structures. usedAliases.clear(); + // Type aliases. for (auto &typeAliasPair : typeAliases) - registerSymbolAlias(typeAliasPair.first, typeAliasPair.second, usedTypes, - usedAliases, typeToAlias); + if (canRegisterAlias(typeAliasPair.second, usedAliases)) + typeToAlias.insert(typeAliasPair); } // Initializes module state, populating affine map and integer set state. void ModuleState::initialize(Module *module) { + // Initialize the symbol aliases. + initializeSymbolAliases(); + + // Walk the module and visit each operation. for (auto &fn : *module) { visitType(fn.getType()); + for (auto attr : fn.getAttrs()) + ModuleState::visitAttribute(attr.second); + for (auto attrList : fn.getAllArgAttrs()) + for (auto attr : attrList.getAttrs()) + ModuleState::visitAttribute(attr.second); fn.walk([&](Operation *op) { ModuleState::visitOperation(op); }); } - - // Initialize the symbol aliases. - initializeSymbolAliases(); } //===----------------------------------------------------------------------===// @@ -318,12 +357,6 @@ protected: void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {}); void printAttributeOptionalType(Attribute attr, bool includeType); - void printAffineMapId(int affineMapId) const; - void printAffineMapReference(AffineMap affineMap); - void printAffineMapAlias(StringRef alias) const; - void printIntegerSetId(int integerSetId) const; - void printIntegerSetReference(IntegerSet integerSet); - void printIntegerSetAlias(StringRef alias) const; void printTrailingLocation(Location loc); void printLocationInternal(Location loc, bool pretty = false); void printDenseElementsAttr(DenseElementsAttr attr); @@ -340,58 +373,6 @@ protected: }; } // end anonymous namespace -// Prints affine map identifier. -void ModulePrinter::printAffineMapId(int affineMapId) const { - os << "#map" << affineMapId; -} - -void ModulePrinter::printAffineMapAlias(StringRef alias) const { - os << '#' << alias; -} - -void ModulePrinter::printAffineMapReference(AffineMap affineMap) { - // Check for an affine map alias. - auto alias = state.getAffineMapAlias(affineMap); - if (!alias.empty()) - return printAffineMapAlias(alias); - - int mapId = state.getAffineMapId(affineMap); - if (mapId >= 0) { - // Map will be printed at top of module so print reference to its id. - printAffineMapId(mapId); - } else { - // Map not in module state so print inline. - affineMap.print(os); - } -} - -// Prints integer set identifier. -void ModulePrinter::printIntegerSetId(int integerSetId) const { - os << "#set" << integerSetId; -} - -void ModulePrinter::printIntegerSetAlias(StringRef alias) const { - os << '#' << alias; -} - -void ModulePrinter::printIntegerSetReference(IntegerSet integerSet) { - // Check for an integer set alias. - auto alias = state.getIntegerSetAlias(integerSet); - if (!alias.empty()) { - printIntegerSetAlias(alias); - return; - } - - int setId; - if ((setId = state.getIntegerSetId(integerSet)) >= 0) { - // The set will be printed at top of module; so print reference to its id. - printIntegerSetId(setId); - } else { - // Set not in module state so print inline. - integerSet.print(os); - } -} - void ModulePrinter::printTrailingLocation(Location loc) { // Check to see if we are printing debug information. if (!shouldPrintDebugInfoOpt) @@ -463,31 +444,11 @@ void ModulePrinter::printLocationInternal(Location loc, bool pretty) { } void ModulePrinter::print(Module *module) { - for (const auto &map : state.getAffineMapIds()) { - StringRef alias = state.getAffineMapAlias(map); - if (!alias.empty()) - printAffineMapAlias(alias); - else - printAffineMapId(state.getAffineMapId(map)); - os << " = "; - map.print(os); - os << '\n'; - } - for (const auto &set : state.getIntegerSetIds()) { - StringRef alias = state.getIntegerSetAlias(set); - if (!alias.empty()) - printIntegerSetAlias(alias); - else - printIntegerSetId(state.getIntegerSetId(set)); - os << " = "; - set.print(os); - os << '\n'; - } - for (const auto &type : state.getTypeIds()) { - StringRef alias = state.getTypeAlias(type); - if (!alias.empty()) - os << '!' << alias << " = type " << type << '\n'; - } + // Output the aliases at the top level. + state.printAttributeAliases(os); + state.printTypeAliases(os); + + // Print the module. for (auto &fn : *module) print(&fn); } @@ -545,6 +506,13 @@ void ModulePrinter::printAttributeOptionalType(Attribute attr, return; } + // Check for an alias for this attribute. + Twine alias = state.getAttributeAlias(attr); + if (!alias.isTriviallyEmpty()) { + os << '#' << alias; + return; + } + switch (attr.getKind()) { case Attribute::Kind::Unit: os << "unit"; @@ -587,10 +555,10 @@ void ModulePrinter::printAttributeOptionalType(Attribute attr, os << ']'; break; case Attribute::Kind::AffineMap: - printAffineMapReference(attr.cast<AffineMapAttr>().getValue()); + attr.cast<AffineMapAttr>().getValue().print(os); break; case Attribute::Kind::IntegerSet: - printIntegerSetReference(attr.cast<IntegerSetAttr>().getValue()); + attr.cast<IntegerSetAttr>().getValue().print(os); break; case Attribute::Kind::Type: printType(attr.cast<TypeAttr>().getValue()); @@ -889,7 +857,7 @@ void ModulePrinter::printType(Type type) { printType(v.getElementType()); for (auto map : v.getAffineMaps()) { os << ", "; - printAffineMapReference(map); + printAttribute(AffineMapAttr::get(map)); } // Only print the memory space if it is the non-default one. if (v.getMemorySpace()) @@ -1303,12 +1271,12 @@ void FunctionPrinter::numberValueID(Value *value) { Type type = op->getResult(0)->getType(); if (auto intCst = cst.dyn_cast<IntegerAttr>()) { if (type.isIndex()) { - specialName << 'c' << intCst; + specialName << 'c' << intCst.getInt(); } else if (type.cast<IntegerType>().isInteger(1)) { // i1 constants get special names. specialName << (intCst.getInt() ? "true" : "false"); } else { - specialName << 'c' << intCst << '_' << type; + specialName << 'c' << intCst.getInt() << '_' << type; } } else if (cst.isa<FunctionAttr>()) { specialName << 'f'; @@ -1638,7 +1606,7 @@ void ModulePrinter::print(Function *fn) { FunctionPrinter(fn, *this).print(); } void Attribute::print(raw_ostream &os) const { ModuleState state(/*no context is known*/ nullptr); - ModulePrinter(os, state).printAttribute(*this); + ModulePrinter(os, state).printAttributeAndType(*this); } void Attribute::dump() const { |

