summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR/AsmPrinter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/IR/AsmPrinter.cpp')
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp322
1 files changed, 145 insertions, 177 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index d97b2fc8664..aeec1abb29c 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -35,6 +35,7 @@
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallString.h"
@@ -79,6 +80,9 @@ static llvm::cl::opt<bool>
namespace {
class ModuleState {
+ /// A special index constant used for non-kind attribute aliases.
+ static constexpr int kNonAttrKindAlias = -1;
+
public:
/// This is the current context if it is knowable, otherwise this is null.
MLIRContext *const context;
@@ -88,51 +92,68 @@ public:
// Initializes module state, populating affine map state.
void initialize(Module *module);
- StringRef getAffineMapAlias(AffineMap affineMap) const {
- return affineMapToAlias.lookup(affineMap);
- }
-
- int getAffineMapId(AffineMap affineMap) const {
- auto it = affineMapIds.find(affineMap);
- if (it == affineMapIds.end()) {
- return -1;
+ Twine getAttributeAlias(Attribute attr) const {
+ auto alias = attrToAlias.find(attr);
+ if (alias == attrToAlias.end())
+ return Twine();
+
+ // Return the alias for this attribute, along with the index if this was
+ // generated by a kind alias.
+ int kindIndex = alias->second.second;
+ return alias->second.first +
+ (kindIndex == kNonAttrKindAlias ? Twine() : Twine(kindIndex));
+ }
+
+ void printAttributeAliases(raw_ostream &os) const {
+ auto printAlias = [&](StringRef alias, Attribute attr, int index) {
+ os << '#' << alias;
+ if (index != kNonAttrKindAlias)
+ os << index;
+ os << " = " << attr << '\n';
+ };
+
+ // Print all of the attribute kind aliases.
+ for (auto &kindAlias : attrKindToAlias) {
+ for (unsigned i = 0, e = kindAlias.second.second.size(); i != e; ++i)
+ printAlias(kindAlias.second.first, kindAlias.second.second[i], i);
+ os << "\n";
}
- return it->second;
- }
-
- ArrayRef<AffineMap> getAffineMapIds() const { return affineMapsById; }
- StringRef getIntegerSetAlias(IntegerSet integerSet) const {
- return integerSetToAlias.lookup(integerSet);
- }
-
- int getIntegerSetId(IntegerSet integerSet) const {
- auto it = integerSetIds.find(integerSet);
- if (it == integerSetIds.end()) {
- return -1;
+ // In a second pass print all of the remaining attribute aliases that aren't
+ // kind aliases.
+ for (Attribute attr : usedAttributes) {
+ auto alias = attrToAlias.find(attr);
+ if (alias != attrToAlias.end() &&
+ alias->second.second == kNonAttrKindAlias)
+ printAlias(alias->second.first, attr, alias->second.second);
}
- return it->second;
}
- ArrayRef<IntegerSet> getIntegerSetIds() const { return integerSetsById; }
-
StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); }
- ArrayRef<Type> getTypeIds() const { return usedTypes.getArrayRef(); }
-
-private:
- void recordAffineMapReference(AffineMap affineMap) {
- if (affineMapIds.count(affineMap) == 0) {
- affineMapIds[affineMap] = affineMapsById.size();
- affineMapsById.push_back(affineMap);
+ void printTypeAliases(raw_ostream &os) const {
+ for (Type type : usedTypes) {
+ auto alias = typeToAlias.find(type);
+ if (alias != typeToAlias.end())
+ os << '!' << alias->second << " = type " << type << '\n';
}
}
- void recordIntegerSetReference(IntegerSet integerSet) {
- if (integerSetIds.count(integerSet) == 0) {
- integerSetIds[integerSet] = integerSetsById.size();
- integerSetsById.push_back(integerSet);
- }
+private:
+ void recordAttributeReference(Attribute attr) {
+ // Don't recheck attributes that have already been seen or those that
+ // already have an alias.
+ if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
+ return;
+
+ // If this attribute kind has an alias, then record one for this attribute.
+ auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
+ if (alias == attrKindToAlias.end())
+ return;
+ std::pair<StringRef, int> attrAlias(alias->second.first,
+ alias->second.second.size());
+ attrToAlias.insert({attr, attrAlias});
+ alias->second.second.push_back(attr);
}
void recordTypeReference(Type ty) { usedTypes.insert(ty); }
@@ -145,15 +166,23 @@ private:
// Initialize symbol aliases.
void initializeSymbolAliases();
- DenseMap<AffineMap, int> affineMapIds;
- std::vector<AffineMap> affineMapsById;
- DenseMap<AffineMap, StringRef> affineMapToAlias;
+ /// Set of attributes known to be used within the module.
+ llvm::SetVector<Attribute> usedAttributes;
- DenseMap<IntegerSet, int> integerSetIds;
- std::vector<IntegerSet> integerSetsById;
- DenseMap<IntegerSet, StringRef> integerSetToAlias;
+ /// Mapping between attribute and a pair comprised of a base alias name and a
+ /// count suffix. If the suffix is set to -1, it is not displayed.
+ llvm::MapVector<Attribute, std::pair<StringRef, int>> attrToAlias;
+ /// Mapping between attribute kind and a pair comprised of a base alias name
+ /// and a unique list of attributes belonging to this kind sorted by location
+ /// seen in the module.
+ llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>>
+ attrKindToAlias;
+
+ /// Set of types known to be used within the module.
llvm::SetVector<Type> usedTypes;
+
+ /// A mapping between a type and a given alias.
DenseMap<Type, StringRef> typeToAlias;
};
} // end anonymous namespace
@@ -169,24 +198,18 @@ void ModuleState::visitType(Type type) {
visitType(result);
} else if (auto memref = type.dyn_cast<MemRefType>()) {
// Visit affine maps in memref type.
- for (auto map : memref.getAffineMaps()) {
- recordAffineMapReference(map);
- }
+ for (auto map : memref.getAffineMaps())
+ recordAttributeReference(AffineMapAttr::get(map));
} else if (auto vecOrTensor = type.dyn_cast<VectorOrTensorType>()) {
visitType(vecOrTensor.getElementType());
}
}
void ModuleState::visitAttribute(Attribute attr) {
- if (auto mapAttr = attr.dyn_cast<AffineMapAttr>()) {
- recordAffineMapReference(mapAttr.getValue());
- } else if (auto setAttr = attr.dyn_cast<IntegerSetAttr>()) {
- recordIntegerSetReference(setAttr.getValue());
- } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
- for (auto elt : arrayAttr.getValue()) {
+ recordAttributeReference(attr);
+ if (auto arrayAttr = attr.dyn_cast<ArrayAttr>())
+ for (auto elt : arrayAttr.getValue())
visitAttribute(elt);
- }
- }
}
void ModuleState::visitOperation(Operation *op) {
@@ -202,21 +225,14 @@ void ModuleState::visitOperation(Operation *op) {
}
// Utility to generate a function to register a symbol alias.
-template <typename SymbolsInModuleSetTy, typename SymbolTy>
-static void registerSymbolAlias(StringRef name, SymbolTy sym,
- SymbolsInModuleSetTy &symbolsInModuleSet,
- llvm::StringSet<> &usedAliases,
- DenseMap<SymbolTy, StringRef> &symToAlias) {
+static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
assert(!name.empty() && "expected alias name to be non-empty");
- assert(sym && "expected alias symbol to be non-null");
// TODO(riverriddle) Assert that the provided alias name can be lexed as
// an identifier.
- // Check if the symbol is not referenced by the module or the name is
- // already used by another alias.
- if (!symbolsInModuleSet.count(sym) || !usedAliases.insert(name).second)
- return;
- symToAlias.try_emplace(sym, name);
+ // Check that the alias doesn't contain a '.' character and the name is not
+ // already in use.
+ return !name.contains('.') && usedAliases.insert(name).second;
}
void ModuleState::initializeSymbolAliases() {
@@ -228,53 +244,76 @@ void ModuleState::initializeSymbolAliases() {
auto dialects = context->getRegisteredDialects();
// Collect the set of aliases from each dialect.
- SmallVector<std::pair<StringRef, AffineMap>, 8> affineMapAliases;
- SmallVector<std::pair<StringRef, IntegerSet>, 8> integerSetAliases;
- SmallVector<std::pair<StringRef, Type>, 16> typeAliases;
+ SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
+ SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
+ SmallVector<std::pair<Type, StringRef>, 16> typeAliases;
+
+ // AffineMap/Integer set have specific kind aliases.
+ attributeKindAliases.emplace_back(
+ static_cast<unsigned>(Attribute::Kind::AffineMap), "map");
+ attributeKindAliases.emplace_back(
+ static_cast<unsigned>(Attribute::Kind::IntegerSet), "set");
+
for (auto *dialect : dialects) {
- dialect->getAffineMapAliases(affineMapAliases);
- dialect->getIntegerSetAliases(integerSetAliases);
+ dialect->getAttributeKindAliases(attributeKindAliases);
+ dialect->getAttributeAliases(attributeAliases);
dialect->getTypeAliases(typeAliases);
}
- // Register the affine aliases.
- // Create a regex for the non-alias names of sets and maps, so that an alias
- // is not registered with a conflicting name.
- llvm::Regex reservedAffineNames("(set|map)[0-9]+");
-
- // AffineMap aliases
- for (auto &affineAliasPair : affineMapAliases) {
- if (!reservedAffineNames.match(affineAliasPair.first))
- registerSymbolAlias(affineAliasPair.first, affineAliasPair.second,
- affineMapIds, usedAliases, affineMapToAlias);
+ // Setup the attribute kind aliases.
+ StringRef alias;
+ unsigned attrKind;
+ for (auto &attrAliasPair : attributeKindAliases) {
+ std::tie(attrKind, alias) = attrAliasPair;
+ assert(!alias.empty() && "expected non-empty alias string");
+ if (!usedAliases.count(alias) && !alias.contains('.'))
+ attrKindToAlias.insert({attrKind, {alias, {}}});
}
- // IntegerSet aliases
- for (auto &integerSetAliasPair : integerSetAliases) {
- if (!reservedAffineNames.match(integerSetAliasPair.first))
- registerSymbolAlias(integerSetAliasPair.first, integerSetAliasPair.second,
- integerSetIds, usedAliases, integerSetToAlias);
+ // Clear the set of used identifiers so that the attribute kind aliases are
+ // just a prefix and not the full alias, i.e. there may be some overlap.
+ usedAliases.clear();
+
+ // Register the attribute aliases.
+ // Create a regex for the attribute kind alias names, these have a prefix with
+ // a counter appended to the end. We prevent normal aliases from having these
+ // names to avoid collisions.
+ llvm::Regex reservedAttrNames("[0-9]+$");
+
+ // Attribute value aliases.
+ Attribute attr;
+ for (auto &attrAliasPair : attributeAliases) {
+ std::tie(attr, alias) = attrAliasPair;
+ if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases))
+ attrToAlias.insert({attr, {alias, kNonAttrKindAlias}});
}
// Clear the set of used identifiers as types can have the same identifiers as
// affine structures.
usedAliases.clear();
+ // Type aliases.
for (auto &typeAliasPair : typeAliases)
- registerSymbolAlias(typeAliasPair.first, typeAliasPair.second, usedTypes,
- usedAliases, typeToAlias);
+ if (canRegisterAlias(typeAliasPair.second, usedAliases))
+ typeToAlias.insert(typeAliasPair);
}
// Initializes module state, populating affine map and integer set state.
void ModuleState::initialize(Module *module) {
+ // Initialize the symbol aliases.
+ initializeSymbolAliases();
+
+ // Walk the module and visit each operation.
for (auto &fn : *module) {
visitType(fn.getType());
+ for (auto attr : fn.getAttrs())
+ ModuleState::visitAttribute(attr.second);
+ for (auto attrList : fn.getAllArgAttrs())
+ for (auto attr : attrList.getAttrs())
+ ModuleState::visitAttribute(attr.second);
fn.walk([&](Operation *op) { ModuleState::visitOperation(op); });
}
-
- // Initialize the symbol aliases.
- initializeSymbolAliases();
}
//===----------------------------------------------------------------------===//
@@ -318,12 +357,6 @@ protected:
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {});
void printAttributeOptionalType(Attribute attr, bool includeType);
- void printAffineMapId(int affineMapId) const;
- void printAffineMapReference(AffineMap affineMap);
- void printAffineMapAlias(StringRef alias) const;
- void printIntegerSetId(int integerSetId) const;
- void printIntegerSetReference(IntegerSet integerSet);
- void printIntegerSetAlias(StringRef alias) const;
void printTrailingLocation(Location loc);
void printLocationInternal(Location loc, bool pretty = false);
void printDenseElementsAttr(DenseElementsAttr attr);
@@ -340,58 +373,6 @@ protected:
};
} // end anonymous namespace
-// Prints affine map identifier.
-void ModulePrinter::printAffineMapId(int affineMapId) const {
- os << "#map" << affineMapId;
-}
-
-void ModulePrinter::printAffineMapAlias(StringRef alias) const {
- os << '#' << alias;
-}
-
-void ModulePrinter::printAffineMapReference(AffineMap affineMap) {
- // Check for an affine map alias.
- auto alias = state.getAffineMapAlias(affineMap);
- if (!alias.empty())
- return printAffineMapAlias(alias);
-
- int mapId = state.getAffineMapId(affineMap);
- if (mapId >= 0) {
- // Map will be printed at top of module so print reference to its id.
- printAffineMapId(mapId);
- } else {
- // Map not in module state so print inline.
- affineMap.print(os);
- }
-}
-
-// Prints integer set identifier.
-void ModulePrinter::printIntegerSetId(int integerSetId) const {
- os << "#set" << integerSetId;
-}
-
-void ModulePrinter::printIntegerSetAlias(StringRef alias) const {
- os << '#' << alias;
-}
-
-void ModulePrinter::printIntegerSetReference(IntegerSet integerSet) {
- // Check for an integer set alias.
- auto alias = state.getIntegerSetAlias(integerSet);
- if (!alias.empty()) {
- printIntegerSetAlias(alias);
- return;
- }
-
- int setId;
- if ((setId = state.getIntegerSetId(integerSet)) >= 0) {
- // The set will be printed at top of module; so print reference to its id.
- printIntegerSetId(setId);
- } else {
- // Set not in module state so print inline.
- integerSet.print(os);
- }
-}
-
void ModulePrinter::printTrailingLocation(Location loc) {
// Check to see if we are printing debug information.
if (!shouldPrintDebugInfoOpt)
@@ -463,31 +444,11 @@ void ModulePrinter::printLocationInternal(Location loc, bool pretty) {
}
void ModulePrinter::print(Module *module) {
- for (const auto &map : state.getAffineMapIds()) {
- StringRef alias = state.getAffineMapAlias(map);
- if (!alias.empty())
- printAffineMapAlias(alias);
- else
- printAffineMapId(state.getAffineMapId(map));
- os << " = ";
- map.print(os);
- os << '\n';
- }
- for (const auto &set : state.getIntegerSetIds()) {
- StringRef alias = state.getIntegerSetAlias(set);
- if (!alias.empty())
- printIntegerSetAlias(alias);
- else
- printIntegerSetId(state.getIntegerSetId(set));
- os << " = ";
- set.print(os);
- os << '\n';
- }
- for (const auto &type : state.getTypeIds()) {
- StringRef alias = state.getTypeAlias(type);
- if (!alias.empty())
- os << '!' << alias << " = type " << type << '\n';
- }
+ // Output the aliases at the top level.
+ state.printAttributeAliases(os);
+ state.printTypeAliases(os);
+
+ // Print the module.
for (auto &fn : *module)
print(&fn);
}
@@ -545,6 +506,13 @@ void ModulePrinter::printAttributeOptionalType(Attribute attr,
return;
}
+ // Check for an alias for this attribute.
+ Twine alias = state.getAttributeAlias(attr);
+ if (!alias.isTriviallyEmpty()) {
+ os << '#' << alias;
+ return;
+ }
+
switch (attr.getKind()) {
case Attribute::Kind::Unit:
os << "unit";
@@ -587,10 +555,10 @@ void ModulePrinter::printAttributeOptionalType(Attribute attr,
os << ']';
break;
case Attribute::Kind::AffineMap:
- printAffineMapReference(attr.cast<AffineMapAttr>().getValue());
+ attr.cast<AffineMapAttr>().getValue().print(os);
break;
case Attribute::Kind::IntegerSet:
- printIntegerSetReference(attr.cast<IntegerSetAttr>().getValue());
+ attr.cast<IntegerSetAttr>().getValue().print(os);
break;
case Attribute::Kind::Type:
printType(attr.cast<TypeAttr>().getValue());
@@ -889,7 +857,7 @@ void ModulePrinter::printType(Type type) {
printType(v.getElementType());
for (auto map : v.getAffineMaps()) {
os << ", ";
- printAffineMapReference(map);
+ printAttribute(AffineMapAttr::get(map));
}
// Only print the memory space if it is the non-default one.
if (v.getMemorySpace())
@@ -1303,12 +1271,12 @@ void FunctionPrinter::numberValueID(Value *value) {
Type type = op->getResult(0)->getType();
if (auto intCst = cst.dyn_cast<IntegerAttr>()) {
if (type.isIndex()) {
- specialName << 'c' << intCst;
+ specialName << 'c' << intCst.getInt();
} else if (type.cast<IntegerType>().isInteger(1)) {
// i1 constants get special names.
specialName << (intCst.getInt() ? "true" : "false");
} else {
- specialName << 'c' << intCst << '_' << type;
+ specialName << 'c' << intCst.getInt() << '_' << type;
}
} else if (cst.isa<FunctionAttr>()) {
specialName << 'f';
@@ -1638,7 +1606,7 @@ void ModulePrinter::print(Function *fn) { FunctionPrinter(fn, *this).print(); }
void Attribute::print(raw_ostream &os) const {
ModuleState state(/*no context is known*/ nullptr);
- ModulePrinter(os, state).printAttribute(*this);
+ ModulePrinter(os, state).printAttributeAndType(*this);
}
void Attribute::dump() const {
OpenPOWER on IntegriCloud