//===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements the MLIR AsmPrinter class, which is used to implement // the various print() methods on the core IR objects. // //===----------------------------------------------------------------------===// #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Regex.h" using namespace mlir; void Identifier::print(raw_ostream &os) const { os << str(); } void Identifier::dump() const { print(llvm::errs()); } void OperationName::print(raw_ostream &os) const { os << getStringRef(); } void OperationName::dump() const { print(llvm::errs()); } OpAsmPrinter::~OpAsmPrinter() {} //===----------------------------------------------------------------------===// // ModuleState //===----------------------------------------------------------------------===// // TODO(riverriddle) Rethink this flag when we have a pass that can remove debug // info or when we have a system for printer flags. static llvm::cl::opt shouldPrintDebugInfoOpt("mlir-print-debuginfo", llvm::cl::desc("Print debug info in MLIR output"), llvm::cl::init(false)); static llvm::cl::opt printPrettyDebugInfo( "mlir-pretty-debuginfo", llvm::cl::desc("Print pretty debug info in MLIR output"), llvm::cl::init(false)); // Use the generic op output form in the operation printer even if the custom // form is defined. static llvm::cl::opt printGenericOpForm("mlir-print-op-generic", llvm::cl::desc("Print the generic op form"), llvm::cl::init(false), llvm::cl::Hidden); namespace { /// A special index constant used for non-kind attribute aliases. static constexpr int kNonAttrKindAlias = -1; class ModuleState { public: explicit ModuleState(MLIRContext *context) : interfaces(context) {} void initialize(Operation *op); Twine getAttributeAlias(Attribute attr) const { auto alias = attrToAlias.find(attr); if (alias == attrToAlias.end()) return Twine(); // Return the alias for this attribute, along with the index if this was // generated by a kind alias. int kindIndex = alias->second.second; return alias->second.first + (kindIndex == kNonAttrKindAlias ? Twine() : Twine(kindIndex)); } void printAttributeAliases(raw_ostream &os) const { auto printAlias = [&](StringRef alias, Attribute attr, int index) { os << '#' << alias; if (index != kNonAttrKindAlias) os << index; os << " = " << attr << '\n'; }; // Print all of the attribute kind aliases. for (auto &kindAlias : attrKindToAlias) { for (unsigned i = 0, e = kindAlias.second.second.size(); i != e; ++i) printAlias(kindAlias.second.first, kindAlias.second.second[i], i); os << "\n"; } // In a second pass print all of the remaining attribute aliases that aren't // kind aliases. for (Attribute attr : usedAttributes) { auto alias = attrToAlias.find(attr); if (alias != attrToAlias.end() && alias->second.second == kNonAttrKindAlias) printAlias(alias->second.first, attr, alias->second.second); } } StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); } void printTypeAliases(raw_ostream &os) const { for (Type type : usedTypes) { auto alias = typeToAlias.find(type); if (alias != typeToAlias.end()) os << '!' << alias->second << " = type " << type << '\n'; } } /// Get an instance of the OpAsmDialectInterface for the given dialect, or /// null if one wasn't registered. const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) { return interfaces.getInterfaceFor(dialect); } private: void recordAttributeReference(Attribute attr) { // Don't recheck attributes that have already been seen or those that // already have an alias. if (!usedAttributes.insert(attr) || attrToAlias.count(attr)) return; // If this attribute kind has an alias, then record one for this attribute. auto alias = attrKindToAlias.find(static_cast(attr.getKind())); if (alias == attrKindToAlias.end()) return; std::pair attrAlias(alias->second.first, alias->second.second.size()); attrToAlias.insert({attr, attrAlias}); alias->second.second.push_back(attr); } void recordTypeReference(Type ty) { usedTypes.insert(ty); } // Visit functions. void visitOperation(Operation *op); void visitType(Type type); void visitAttribute(Attribute attr); // Initialize symbol aliases. void initializeSymbolAliases(); /// Set of attributes known to be used within the module. llvm::SetVector usedAttributes; /// Mapping between attribute and a pair comprised of a base alias name and a /// count suffix. If the suffix is set to -1, it is not displayed. llvm::MapVector> attrToAlias; /// Mapping between attribute kind and a pair comprised of a base alias name /// and a unique list of attributes belonging to this kind sorted by location /// seen in the module. llvm::MapVector>> attrKindToAlias; /// Set of types known to be used within the module. llvm::SetVector usedTypes; /// A mapping between a type and a given alias. DenseMap typeToAlias; /// Collection of OpAsm interfaces implemented in the context. DialectInterfaceCollection interfaces; }; } // end anonymous namespace // TODO Support visiting other types/operations when implemented. void ModuleState::visitType(Type type) { recordTypeReference(type); if (auto funcType = type.dyn_cast()) { // Visit input and result types for functions. for (auto input : funcType.getInputs()) visitType(input); for (auto result : funcType.getResults()) visitType(result); return; } if (auto memref = type.dyn_cast()) { // Visit affine maps in memref type. for (auto map : memref.getAffineMaps()) recordAttributeReference(AffineMapAttr::get(map)); } if (auto shapedType = type.dyn_cast()) { visitType(shapedType.getElementType()); } } void ModuleState::visitAttribute(Attribute attr) { recordAttributeReference(attr); if (auto arrayAttr = attr.dyn_cast()) { for (auto elt : arrayAttr.getValue()) visitAttribute(elt); } else if (auto typeAttr = attr.dyn_cast()) { visitType(typeAttr.getValue()); } } void ModuleState::visitOperation(Operation *op) { // Visit all the types used in the operation. for (auto type : op->getOperandTypes()) visitType(type); for (auto type : op->getResultTypes()) visitType(type); for (auto ®ion : op->getRegions()) for (auto &block : region) for (auto *arg : block.getArguments()) visitType(arg->getType()); // Visit each of the attributes. for (auto elt : op->getAttrs()) visitAttribute(elt.second); } // Utility to generate a function to register a symbol alias. static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) { assert(!name.empty() && "expected alias name to be non-empty"); // TODO(riverriddle) Assert that the provided alias name can be lexed as // an identifier. // Check that the alias doesn't contain a '.' character and the name is not // already in use. return !name.contains('.') && usedAliases.insert(name).second; } void ModuleState::initializeSymbolAliases() { // Track the identifiers in use for each symbol so that the same identifier // isn't used twice. llvm::StringSet<> usedAliases; // Collect the set of aliases from each dialect. SmallVector, 8> attributeKindAliases; SmallVector, 8> attributeAliases; SmallVector, 16> typeAliases; // AffineMap/Integer set have specific kind aliases. attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map"); attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set"); for (auto &interface : interfaces) { interface.getAttributeKindAliases(attributeKindAliases); interface.getAttributeAliases(attributeAliases); interface.getTypeAliases(typeAliases); } // Setup the attribute kind aliases. StringRef alias; unsigned attrKind; for (auto &attrAliasPair : attributeKindAliases) { std::tie(attrKind, alias) = attrAliasPair; assert(!alias.empty() && "expected non-empty alias string"); if (!usedAliases.count(alias) && !alias.contains('.')) attrKindToAlias.insert({attrKind, {alias, {}}}); } // Clear the set of used identifiers so that the attribute kind aliases are // just a prefix and not the full alias, i.e. there may be some overlap. usedAliases.clear(); // Register the attribute aliases. // Create a regex for the attribute kind alias names, these have a prefix with // a counter appended to the end. We prevent normal aliases from having these // names to avoid collisions. llvm::Regex reservedAttrNames("[0-9]+$"); // Attribute value aliases. Attribute attr; for (auto &attrAliasPair : attributeAliases) { std::tie(attr, alias) = attrAliasPair; if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases)) attrToAlias.insert({attr, {alias, kNonAttrKindAlias}}); } // Clear the set of used identifiers as types can have the same identifiers as // affine structures. usedAliases.clear(); // Type aliases. for (auto &typeAliasPair : typeAliases) if (canRegisterAlias(typeAliasPair.second, usedAliases)) typeToAlias.insert(typeAliasPair); } void ModuleState::initialize(Operation *op) { // Initialize the symbol aliases. initializeSymbolAliases(); // Visit each of the nested operations. op->walk([&](Operation *op) { visitOperation(op); }); } //===----------------------------------------------------------------------===// // ModulePrinter //===----------------------------------------------------------------------===// namespace { class ModulePrinter { public: ModulePrinter(raw_ostream &os, ModuleState *state = nullptr) : os(os), state(state) {} explicit ModulePrinter(ModulePrinter &printer) : os(printer.os), state(printer.state) {} template inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; }); } void print(ModuleOp module); /// Print the given attribute. If 'mayElideType' is true, some attributes are /// printed without the type when the type matches the default used in the /// parser (for example i64 is the default for integer attributes). void printAttribute(Attribute attr, bool mayElideType = false); void printType(Type type); void printLocation(LocationAttr loc); void printAffineMap(AffineMap map); void printAffineExpr( AffineExpr expr, llvm::function_ref printValueName = nullptr); void printAffineConstraint(AffineExpr expr, bool isEq); void printIntegerSet(IntegerSet set); protected: void printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs = {}); void printTrailingLocation(Location loc); void printLocationInternal(LocationAttr loc, bool pretty = false); void printDenseElementsAttr(DenseElementsAttr attr); /// This enum is used to represent the binding stength of the enclosing /// context that an AffineExprStorage is being printed in, so we can /// intelligently produce parens. enum class BindingStrength { Weak, // + and - Strong, // All other binary operators. }; void printAffineExprInternal( AffineExpr expr, BindingStrength enclosingTightness, llvm::function_ref printValueName = nullptr); /// The output stream for the printer. raw_ostream &os; /// An optional printer state for the module. ModuleState *state; }; } // end anonymous namespace void ModulePrinter::printTrailingLocation(Location loc) { // Check to see if we are printing debug information. if (!shouldPrintDebugInfoOpt) return; os << " "; printLocation(loc); } void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) { switch (loc.getKind()) { case StandardAttributes::UnknownLocation: if (pretty) os << "[unknown]"; else os << "unknown"; break; case StandardAttributes::FileLineColLocation: { auto fileLoc = loc.cast(); auto mayQuote = pretty ? "" : "\""; os << mayQuote << fileLoc.getFilename() << mayQuote << ':' << fileLoc.getLine() << ':' << fileLoc.getColumn(); break; } case StandardAttributes::NameLocation: { auto nameLoc = loc.cast(); os << '\"' << nameLoc.getName() << '\"'; // Print the child if it isn't unknown. auto childLoc = nameLoc.getChildLoc(); if (!childLoc.isa()) { os << '('; printLocationInternal(childLoc, pretty); os << ')'; } break; } case StandardAttributes::CallSiteLocation: { auto callLocation = loc.cast(); auto caller = callLocation.getCaller(); auto callee = callLocation.getCallee(); if (!pretty) os << "callsite("; printLocationInternal(callee, pretty); if (pretty) { if (callee.isa()) { if (caller.isa()) { os << " at "; } else { os << "\n at "; } } else { os << "\n at "; } } else { os << " at "; } printLocationInternal(caller, pretty); if (!pretty) os << ")"; break; } case StandardAttributes::FusedLocation: { auto fusedLoc = loc.cast(); if (!pretty) os << "fused"; if (auto metadata = fusedLoc.getMetadata()) os << '<' << metadata << '>'; os << '['; interleave( fusedLoc.getLocations(), [&](Location loc) { printLocationInternal(loc, pretty); }, [&]() { os << ", "; }); os << ']'; break; } } } /// Print a floating point value in a way that the parser will be able to /// round-trip losslessly. static void printFloatValue(const APFloat &apValue, raw_ostream &os) { // We would like to output the FP constant value in exponential notation, // but we cannot do this if doing so will lose precision. Check here to // make sure that we only output it in exponential format if we can parse // the value back and get the same value. bool isInf = apValue.isInfinity(); bool isNaN = apValue.isNaN(); if (!isInf && !isNaN) { SmallString<128> strValue; apValue.toString(strValue, 6, 0, false); // Check to make sure that the stringized number is not some string like // "Inf" or NaN, that atof will accept, but the lexer will not. Check // that the string matches the "[-+]?[0-9]" regex. assert(((strValue[0] >= '0' && strValue[0] <= '9') || ((strValue[0] == '-' || strValue[0] == '+') && (strValue[1] >= '0' && strValue[1] <= '9'))) && "[-+]?[0-9] regex does not match!"); // Parse back the stringized version and check that the value is equal // (i.e., there is no precision loss). If it is not, use the default format // of APFloat instead of the exponential notation. if (!APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { strValue.clear(); apValue.toString(strValue); } os << strValue; return; } // Print special values in hexadecimal format. The sign bit should be // included in the literal. SmallVector str; APInt apInt = apValue.bitcastToAPInt(); apInt.toString(str, /*Radix=*/16, /*Signed=*/false, /*formatAsCLiteral=*/true); os << str; } void ModulePrinter::printLocation(LocationAttr loc) { if (printPrettyDebugInfo) { printLocationInternal(loc, /*pretty=*/true); } else { os << "loc("; printLocationInternal(loc); os << ')'; } } /// Returns if the given dialect symbol data is simple enough to print in the /// pretty form, i.e. without the enclosing "". static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { // The name must start with an identifier. if (symName.empty() || !isalpha(symName.front())) return false; // Ignore all the characters that are valid in an identifier in the symbol // name. symName = symName.drop_while([](char c) { return llvm::isAlnum(c) || c == '.'; }); if (symName.empty()) return true; // If we got to an unexpected character, then it must be a <>. Check those // recursively. if (symName.front() != '<' || symName.back() != '>') return false; SmallVector nestedPunctuation; do { // If we ran out of characters, then we had a punctuation mismatch. if (symName.empty()) return false; auto c = symName.front(); symName = symName.drop_front(); switch (c) { // We never allow null characters. This is an EOF indicator for the lexer // which we could handle, but isn't important for any known dialect. case '\0': return false; case '<': case '[': case '(': case '{': nestedPunctuation.push_back(c); continue; case '-': // Treat `->` as a special token. if (!symName.empty() && symName.front() == '>') { symName = symName.drop_front(); continue; } break; // Reject types with mismatched brackets. case '>': if (nestedPunctuation.pop_back_val() != '<') return false; break; case ']': if (nestedPunctuation.pop_back_val() != '[') return false; break; case ')': if (nestedPunctuation.pop_back_val() != '(') return false; break; case '}': if (nestedPunctuation.pop_back_val() != '{') return false; break; default: continue; } // We're done when the punctuation is fully matched. } while (!nestedPunctuation.empty()); // If there were extra characters, then we failed. return symName.empty(); } /// Print the given dialect symbol to the stream. static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, StringRef dialectName, StringRef symString) { os << symPrefix << dialectName; // If this symbol name is simple enough, print it directly in pretty form, // otherwise, we print it as an escaped string. if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) { os << '.' << symString; return; } // TODO: escape the symbol name, it could contain " characters. os << "<\"" << symString << "\">"; } void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) { if (!attr) { os << "<>"; return; } // Check for an alias for this attribute. if (state) { Twine alias = state->getAttributeAlias(attr); if (!alias.isTriviallyEmpty()) { os << '#' << alias; return; } } switch (attr.getKind()) { default: { auto &dialect = attr.getDialect(); // Ask the dialect to serialize the attribute to a string. std::string attrName; { llvm::raw_string_ostream attrNameStr(attrName); dialect.printAttribute(attr, attrNameStr); } printDialectSymbol(os, "#", dialect.getNamespace(), attrName); break; } case StandardAttributes::Opaque: { auto opaqueAttr = attr.cast(); printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), opaqueAttr.getAttrData()); break; } case StandardAttributes::Unit: os << "unit"; break; case StandardAttributes::Bool: os << (attr.cast().getValue() ? "true" : "false"); // BoolAttr always elides the type. return; case StandardAttributes::Dictionary: os << '{'; interleaveComma(attr.cast().getValue(), [&](NamedAttribute attr) { os << attr.first << " = "; printAttribute(attr.second); }); os << '}'; break; case StandardAttributes::Integer: { auto intAttr = attr.cast(); // Print all integer attributes as signed unless i1. bool isSigned = intAttr.getType().isIndex() || intAttr.getType().getIntOrFloatBitWidth() != 1; intAttr.getValue().print(os, isSigned); // IntegerAttr elides the type if I64. if (mayElideType && intAttr.getType().isInteger(64)) return; break; } case StandardAttributes::Float: { auto floatAttr = attr.cast(); printFloatValue(floatAttr.getValue(), os); // FloatAttr elides the type if F64. if (mayElideType && floatAttr.getType().isF64()) return; break; } case StandardAttributes::String: os << '"'; printEscapedString(attr.cast().getValue(), os); os << '"'; break; case StandardAttributes::Array: os << '['; interleaveComma(attr.cast().getValue(), [&](Attribute attr) { printAttribute(attr, /*mayElideType=*/true); }); os << ']'; break; case StandardAttributes::AffineMap: attr.cast().getValue().print(os); // AffineMap always elides the type. return; case StandardAttributes::IntegerSet: attr.cast().getValue().print(os); break; case StandardAttributes::Type: printType(attr.cast().getValue()); break; case StandardAttributes::SymbolRef: os << '@' << attr.cast().getValue(); break; case StandardAttributes::OpaqueElements: { auto eltsAttr = attr.cast(); os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", "; os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">"; break; } case StandardAttributes::DenseElements: { auto eltsAttr = attr.cast(); os << "dense<"; printDenseElementsAttr(eltsAttr); os << '>'; break; } case StandardAttributes::SparseElements: { auto elementsAttr = attr.cast(); os << "sparse<"; printDenseElementsAttr(elementsAttr.getIndices()); os << ", "; printDenseElementsAttr(elementsAttr.getValues()); os << '>'; break; } // Location attributes. case StandardAttributes::CallSiteLocation: case StandardAttributes::FileLineColLocation: case StandardAttributes::FusedLocation: case StandardAttributes::NameLocation: case StandardAttributes::UnknownLocation: printLocation(attr.cast()); break; } // Print the type if it isn't a 'none' type. auto attrType = attr.getType(); if (!attrType.isa()) { os << " : "; printType(attrType); } } /// Print the integer element of the given DenseElementsAttr at 'index'. static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os, unsigned index) { APInt value = *std::next(attr.int_value_begin(), index); if (value.getBitWidth() == 1) os << (value.getBoolValue() ? "true" : "false"); else value.print(os, /*isSigned=*/true); } /// Print the float element of the given DenseElementsAttr at 'index'. static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os, unsigned index) { APFloat value = *std::next(attr.float_value_begin(), index); printFloatValue(value, os); } void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) { auto type = attr.getType(); auto shape = type.getShape(); auto rank = type.getRank(); // The function used to print elements of this attribute. auto printEltFn = type.getElementType().isa() ? printDenseIntElement : printDenseFloatElement; // Special case for 0-d and splat tensors. if (attr.isSplat()) { printEltFn(attr, os, 0); return; } // Special case for degenerate tensors. auto numElements = type.getNumElements(); if (numElements == 0) { for (int i = 0; i < rank; ++i) os << '['; for (int i = 0; i < rank; ++i) os << ']'; return; } // We use a mixed-radix counter to iterate through the shape. When we bump a // non-least-significant digit, we emit a close bracket. When we next emit an // element we re-open all closed brackets. // The mixed-radix counter, with radices in 'shape'. SmallVector counter(rank, 0); // The number of brackets that have been opened and not closed. unsigned openBrackets = 0; auto bumpCounter = [&]() { // Bump the least significant digit. ++counter[rank - 1]; // Iterate backwards bubbling back the increment. for (unsigned i = rank - 1; i > 0; --i) if (counter[i] >= shape[i]) { // Index 'i' is rolled over. Bump (i-1) and close a bracket. counter[i] = 0; ++counter[i - 1]; --openBrackets; os << ']'; } }; for (unsigned idx = 0, e = numElements; idx != e; ++idx) { if (idx != 0) os << ", "; while (openBrackets++ < rank) os << '['; openBrackets = rank; printEltFn(attr, os, idx); bumpCounter(); } while (openBrackets-- > 0) os << ']'; } void ModulePrinter::printType(Type type) { // Check for an alias for this type. if (state) { StringRef alias = state->getTypeAlias(type); if (!alias.empty()) { os << '!' << alias; return; } } switch (type.getKind()) { default: { auto &dialect = type.getDialect(); // Ask the dialect to serialize the type to a string. std::string typeName; { llvm::raw_string_ostream typeNameStr(typeName); dialect.printType(type, typeNameStr); } printDialectSymbol(os, "!", dialect.getNamespace(), typeName); return; } case Type::Kind::Opaque: { auto opaqueTy = type.cast(); printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), opaqueTy.getTypeData()); return; } case StandardTypes::Index: os << "index"; return; case StandardTypes::BF16: os << "bf16"; return; case StandardTypes::F16: os << "f16"; return; case StandardTypes::F32: os << "f32"; return; case StandardTypes::F64: os << "f64"; return; case StandardTypes::Integer: { auto integer = type.cast(); os << 'i' << integer.getWidth(); return; } case Type::Kind::Function: { auto func = type.cast(); os << '('; interleaveComma(func.getInputs(), [&](Type type) { printType(type); }); os << ") -> "; auto results = func.getResults(); if (results.size() == 1 && !results[0].isa()) os << results[0]; else { os << '('; interleaveComma(results, [&](Type type) { printType(type); }); os << ')'; } return; } case StandardTypes::Vector: { auto v = type.cast(); os << "vector<"; for (auto dim : v.getShape()) os << dim << 'x'; os << v.getElementType() << '>'; return; } case StandardTypes::RankedTensor: { auto v = type.cast(); os << "tensor<"; for (auto dim : v.getShape()) { if (dim < 0) os << '?'; else os << dim; os << 'x'; } os << v.getElementType() << '>'; return; } case StandardTypes::UnrankedTensor: { auto v = type.cast(); os << "tensor<*x"; printType(v.getElementType()); os << '>'; return; } case StandardTypes::MemRef: { auto v = type.cast(); os << "memref<"; for (auto dim : v.getShape()) { if (dim < 0) os << '?'; else os << dim; os << 'x'; } printType(v.getElementType()); for (auto map : v.getAffineMaps()) { os << ", "; printAttribute(AffineMapAttr::get(map)); } // Only print the memory space if it is the non-default one. if (v.getMemorySpace()) os << ", " << v.getMemorySpace(); os << '>'; return; } case StandardTypes::Complex: os << "complex<"; printType(type.cast().getElementType()); os << '>'; return; case StandardTypes::Tuple: { auto tuple = type.cast(); os << "tuple<"; interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); }); os << '>'; return; } case StandardTypes::None: os << "none"; return; } } //===----------------------------------------------------------------------===// // Affine expressions and maps //===----------------------------------------------------------------------===// void ModulePrinter::printAffineExpr( AffineExpr expr, llvm::function_ref printValueName) { printAffineExprInternal(expr, BindingStrength::Weak, printValueName); } void ModulePrinter::printAffineExprInternal( AffineExpr expr, BindingStrength enclosingTightness, llvm::function_ref printValueName) { const char *binopSpelling = nullptr; switch (expr.getKind()) { case AffineExprKind::SymbolId: { unsigned pos = expr.cast().getPosition(); if (printValueName) printValueName(pos, /*isSymbol=*/true); else os << 's' << pos; return; } case AffineExprKind::DimId: { unsigned pos = expr.cast().getPosition(); if (printValueName) printValueName(pos, /*isSymbol=*/false); else os << 'd' << pos; return; } case AffineExprKind::Constant: os << expr.cast().getValue(); return; case AffineExprKind::Add: binopSpelling = " + "; break; case AffineExprKind::Mul: binopSpelling = " * "; break; case AffineExprKind::FloorDiv: binopSpelling = " floordiv "; break; case AffineExprKind::CeilDiv: binopSpelling = " ceildiv "; break; case AffineExprKind::Mod: binopSpelling = " mod "; break; } auto binOp = expr.cast(); AffineExpr lhsExpr = binOp.getLHS(); AffineExpr rhsExpr = binOp.getRHS(); // Handle tightly binding binary operators. if (binOp.getKind() != AffineExprKind::Add) { if (enclosingTightness == BindingStrength::Strong) os << '('; // Pretty print multiplication with -1. auto rhsConst = rhsExpr.dyn_cast(); if (rhsConst && rhsConst.getValue() == -1) { os << "-"; printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); return; } printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); os << binopSpelling; printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); if (enclosingTightness == BindingStrength::Strong) os << ')'; return; } // Print out special "pretty" forms for add. if (enclosingTightness == BindingStrength::Strong) os << '('; // Pretty print addition to a product that has a negative operand as a // subtraction. if (auto rhs = rhsExpr.dyn_cast()) { if (rhs.getKind() == AffineExprKind::Mul) { AffineExpr rrhsExpr = rhs.getRHS(); if (auto rrhs = rrhsExpr.dyn_cast()) { if (rrhs.getValue() == -1) { printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); os << " - "; if (rhs.getLHS().getKind() == AffineExprKind::Add) { printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, printValueName); } else { printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, printValueName); } if (enclosingTightness == BindingStrength::Strong) os << ')'; return; } if (rrhs.getValue() < -1) { printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); os << " - "; printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, printValueName); os << " * " << -rrhs.getValue(); if (enclosingTightness == BindingStrength::Strong) os << ')'; return; } } } } // Pretty print addition to a negative number as a subtraction. if (auto rhsConst = rhsExpr.dyn_cast()) { if (rhsConst.getValue() < 0) { printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); os << " - " << -rhsConst.getValue(); if (enclosingTightness == BindingStrength::Strong) os << ')'; return; } } printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); os << " + "; printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); if (enclosingTightness == BindingStrength::Strong) os << ')'; } void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) { printAffineExprInternal(expr, BindingStrength::Weak); isEq ? os << " == 0" : os << " >= 0"; } void ModulePrinter::printAffineMap(AffineMap map) { // Dimension identifiers. os << '('; for (int i = 0; i < (int)map.getNumDims() - 1; ++i) os << 'd' << i << ", "; if (map.getNumDims() >= 1) os << 'd' << map.getNumDims() - 1; os << ')'; // Symbolic identifiers. if (map.getNumSymbols() != 0) { os << '['; for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) os << 's' << i << ", "; if (map.getNumSymbols() >= 1) os << 's' << map.getNumSymbols() - 1; os << ']'; } // Result affine expressions. os << " -> ("; interleaveComma(map.getResults(), [&](AffineExpr expr) { printAffineExpr(expr); }); os << ')'; } void ModulePrinter::printIntegerSet(IntegerSet set) { // Dimension identifiers. os << '('; for (unsigned i = 1; i < set.getNumDims(); ++i) os << 'd' << i - 1 << ", "; if (set.getNumDims() >= 1) os << 'd' << set.getNumDims() - 1; os << ')'; // Symbolic identifiers. if (set.getNumSymbols() != 0) { os << '['; for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) os << 's' << i << ", "; if (set.getNumSymbols() >= 1) os << 's' << set.getNumSymbols() - 1; os << ']'; } // Print constraints. os << " : ("; int numConstraints = set.getNumConstraints(); for (int i = 1; i < numConstraints; ++i) { printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); os << ", "; } if (numConstraints >= 1) printAffineConstraint(set.getConstraint(numConstraints - 1), set.isEq(numConstraints - 1)); os << ')'; } //===----------------------------------------------------------------------===// // Operation printing //===----------------------------------------------------------------------===// void ModulePrinter::printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs) { // If there are no attributes, then there is nothing to be done. if (attrs.empty()) return; // Filter out any attributes that shouldn't be included. SmallVector filteredAttrs; for (auto attr : attrs) { // If the caller has requested that this attribute be ignored, then drop it. if (llvm::any_of(elidedAttrs, [&](StringRef elided) { return attr.first.is(elided); })) continue; // Otherwise add it to our filteredAttrs list. filteredAttrs.push_back(attr); } // If there are no attributes left to print after filtering, then we're done. if (filteredAttrs.empty()) return; // Otherwise, print them all out in braces. os << " {"; interleaveComma(filteredAttrs, [&](NamedAttribute attr) { os << attr.first; // Pretty printing elides the attribute value for unit attributes. if (attr.second.isa()) return; os << " = "; printAttribute(attr.second); }); os << '}'; } namespace { // OperationPrinter contains common functionality for printing operations. class OperationPrinter : public ModulePrinter, private OpAsmPrinter { public: OperationPrinter(Operation *op, ModulePrinter &other); OperationPrinter(Region *region, ModulePrinter &other); // Methods to print operations. void print(Operation *op); void print(Block *block, bool printBlockArgs = true, bool printBlockTerminator = true); void printOperation(Operation *op); void printGenericOp(Operation *op) override; // Implement OpAsmPrinter. raw_ostream &getStream() const override { return os; } void printType(Type type) override { ModulePrinter::printType(type); } void printAttribute(Attribute attr) override { ModulePrinter::printAttribute(attr); } void printOperand(Value *value) override { printValueID(value); } void printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs = {}) override { return ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs); }; enum { nameSentinel = ~0U }; void printBlockName(Block *block) { auto id = getBlockID(block); if (id != ~0U) os << "^bb" << id; else os << "^INVALIDBLOCK"; } unsigned getBlockID(Block *block) { auto it = blockIDs.find(block); return it != blockIDs.end() ? it->second : ~0U; } void printSuccessorAndUseList(Operation *term, unsigned index) override; /// Print a region. void printRegion(Region &blocks, bool printEntryBlockArgs, bool printBlockTerminators) override { os << " {\n"; if (!blocks.empty()) { auto *entryBlock = &blocks.front(); print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0, printBlockTerminators); for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1)) print(&b); } os.indent(currentIndent) << "}"; } /// Renumber the arguments for the specified region to the same names as the /// SSA values in namesToUse. This may only be used for IsolatedFromAbove /// operations. If any entry in namesToUse is null, the corresponding /// argument name is left alone. void shadowRegionArgs(Region ®ion, ArrayRef namesToUse) override; void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ArrayRef operands) override { AffineMap map = mapAttr.getValue(); unsigned numDims = map.getNumDims(); auto printValueName = [&](unsigned pos, bool isSymbol) { unsigned index = isSymbol ? numDims + pos : pos; assert(index < operands.size()); if (isSymbol) os << "symbol("; printValueID(operands[index]); if (isSymbol) os << ')'; }; interleaveComma(map.getResults(), [&](AffineExpr expr) { printAffineExpr(expr, printValueName); }); } // Number of spaces used for indenting nested operations. const static unsigned indentWidth = 2; protected: void numberValueID(Value *value); void numberValuesInRegion(Region ®ion); void numberValuesInBlock(Block &block); void printValueID(Value *value, bool printResultNo = true) const { printValueIDImpl(value, printResultNo, os); } private: void printValueIDImpl(Value *value, bool printResultNo, raw_ostream &stream) const; /// Uniques the given value name within the printer. If the given name /// conflicts, it is automatically renamed. StringRef uniqueValueName(StringRef name); /// This is the value ID for each SSA value. If this returns ~0, then the /// valueID has an entry in valueNames. DenseMap valueIDs; DenseMap valueNames; /// This is the block ID for each block in the current. DenseMap blockIDs; /// This keeps track of all of the non-numeric names that are in flight, /// allowing us to check for duplicates. /// Note: the value of the map is unused. llvm::ScopedHashTable usedNames; llvm::BumpPtrAllocator usedNameAllocator; // This is the current indentation level for nested structures. unsigned currentIndent = 0; /// This is the next value ID to assign in numbering. unsigned nextValueID = 0; /// This is the next ID to assign to a region entry block argument. unsigned nextArgumentID = 0; /// This is the next ID to assign when a name conflict is detected. unsigned nextConflictID = 0; }; } // end anonymous namespace OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other) : ModulePrinter(other) { if (op->getNumResults() != 0) numberValueID(op->getResult(0)); for (auto ®ion : op->getRegions()) numberValuesInRegion(region); } OperationPrinter::OperationPrinter(Region *region, ModulePrinter &other) : ModulePrinter(other) { numberValuesInRegion(*region); } /// Number all of the SSA values in the specified region. void OperationPrinter::numberValuesInRegion(Region ®ion) { // Save the current value ids to allow for numbering values in sibling regions // the same. unsigned curValueID = nextValueID; unsigned curArgumentID = nextArgumentID; unsigned curConflictID = nextConflictID; // Push a new used names scope. llvm::ScopedHashTable::ScopeTy usedNamesScope(usedNames); // Number the values within this region in a breadth-first order. unsigned nextBlockID = 0; for (auto &block : region) { // Each block gets a unique ID, and all of the operations within it get // numbered as well. blockIDs[&block] = nextBlockID++; numberValuesInBlock(block); } // After that we traverse the nested regions. // TODO: Rework this loop to not use recursion. for (auto &block : region) { for (auto &op : block) for (auto &nestedRegion : op.getRegions()) numberValuesInRegion(nestedRegion); } // Restore the original value ids. nextValueID = curValueID; nextArgumentID = curArgumentID; nextConflictID = curConflictID; } /// Number all of the SSA values in the specified block, without traversing /// nested regions. void OperationPrinter::numberValuesInBlock(Block &block) { // Number the block arguments. for (auto *arg : block.getArguments()) numberValueID(arg); // We number operation that have results, and we only number the first result. for (auto &op : block) if (op.getNumResults() != 0) numberValueID(op.getResult(0)); } void OperationPrinter::numberValueID(Value *value) { assert(!valueIDs.count(value) && "Value numbered multiple times"); SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); // Check to see if this value requested a special name. auto *op = value->getDefiningOp(); if (state && op) { if (auto *interface = state->getOpAsmInterface(op->getDialect())) interface->getOpResultName(op, specialName); } if (specialNameBuffer.empty()) { switch (value->getKind()) { case Value::Kind::BlockArgument: // If this is an argument to the entry block of a region, give it an 'arg' // name. if (auto *block = cast(value)->getOwner()) { auto *parentRegion = block->getParent(); if (parentRegion && block == &parentRegion->front()) { specialName << "arg" << nextArgumentID++; break; } } // Otherwise number it normally. valueIDs[value] = nextValueID++; return; case Value::Kind::OpResult: // This is an uninteresting result, give it a boring number and be // done with it. valueIDs[value] = nextValueID++; return; } } // Ok, this value had an interesting name. Remember it with a sentinel. valueIDs[value] = nameSentinel; valueNames[value] = uniqueValueName(specialName.str()); } /// Uniques the given value name within the printer. If the given name /// conflicts, it is automatically renamed. StringRef OperationPrinter::uniqueValueName(StringRef name) { // Check to see if this name is already unique. if (!usedNames.count(name)) { name = name.copy(usedNameAllocator); } else { // Otherwise, we had a conflict - probe until we find a unique name. This // is guaranteed to terminate (and usually in a single iteration) because it // generates new names by incrementing nextConflictID. SmallString<64> probeName(name); probeName.push_back('_'); while (true) { probeName.resize(name.size() + 1); probeName += llvm::utostr(nextConflictID++); if (!usedNames.count(probeName)) { name = StringRef(probeName).copy(usedNameAllocator); break; } } } usedNames.insert(name, char()); return name; } void OperationPrinter::print(Block *block, bool printBlockArgs, bool printBlockTerminator) { // Print the block label and argument list if requested. if (printBlockArgs) { os.indent(currentIndent); printBlockName(block); // Print the argument list if non-empty. if (!block->args_empty()) { os << '('; interleaveComma(block->getArguments(), [&](BlockArgument *arg) { printValueID(arg); os << ": "; printType(arg->getType()); }); os << ')'; } os << ':'; // Print out some context information about the predecessors of this block. if (!block->getParent()) { os << "\t// block is not in a region!"; } else if (block->hasNoPredecessors()) { os << "\t// no predecessors"; } else if (auto *pred = block->getSinglePredecessor()) { os << "\t// pred: "; printBlockName(pred); } else { // We want to print the predecessors in increasing numeric order, not in // whatever order the use-list is in, so gather and sort them. SmallVector, 4> predIDs; for (auto *pred : block->getPredecessors()) predIDs.push_back({getBlockID(pred), pred}); llvm::array_pod_sort(predIDs.begin(), predIDs.end()); os << "\t// " << predIDs.size() << " preds: "; interleaveComma(predIDs, [&](std::pair pred) { printBlockName(pred.second); }); } os << '\n'; } currentIndent += indentWidth; auto range = llvm::make_range( block->getOperations().begin(), std::prev(block->getOperations().end(), printBlockTerminator ? 0 : 1)); for (auto &op : range) { print(&op); os << '\n'; } currentIndent -= indentWidth; } void OperationPrinter::print(Operation *op) { os.indent(currentIndent); printOperation(op); printTrailingLocation(op->getLoc()); } void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo, raw_ostream &stream) const { if (!value) { stream << "<>"; return; } int resultNo = -1; auto lookupValue = value; // If this is a reference to the result of a multi-result operation or // operation, print out the # identifier and make sure to map our lookup // to the first result of the operation. if (auto *result = dyn_cast(value)) { if (result->getOwner()->getNumResults() != 1) { resultNo = result->getResultNumber(); lookupValue = result->getOwner()->getResult(0); } } auto it = valueIDs.find(lookupValue); if (it == valueIDs.end()) { stream << "<>"; return; } stream << '%'; if (it->second != nameSentinel) { stream << it->second; } else { auto nameIt = valueNames.find(lookupValue); assert(nameIt != valueNames.end() && "Didn't have a name entry?"); stream << nameIt->second; } if (resultNo != -1 && printResultNo) stream << '#' << resultNo; } /// Renumber the arguments for the specified region to the same names as the /// SSA values in namesToUse. This may only be used for IsolatedFromAbove /// operations. If any entry in namesToUse is null, the corresponding /// argument name is left alone. void OperationPrinter::shadowRegionArgs(Region ®ion, ArrayRef namesToUse) { assert(!region.empty() && "cannot shadow arguments of an empty region"); assert(region.front().getNumArguments() == namesToUse.size() && "incorrect number of names passed in"); assert(region.getParentOp()->isKnownIsolatedFromAbove() && "only KnownIsolatedFromAbove ops can shadow names"); SmallVector nameStr; for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { auto *nameToUse = namesToUse[i]; if (nameToUse == nullptr) continue; auto *nameToReplace = region.front().getArgument(i); nameStr.clear(); llvm::raw_svector_ostream nameStream(nameStr); printValueIDImpl(nameToUse, /*printResultNo=*/true, nameStream); // Entry block arguments should already have a pretty "arg" name. assert(valueIDs[nameToReplace] == nameSentinel); // Use the name without the leading %. auto name = StringRef(nameStream.str()).drop_front(); // Overwrite the name. valueNames[nameToReplace] = name.copy(usedNameAllocator); } } void OperationPrinter::printOperation(Operation *op) { if (size_t numResults = op->getNumResults()) { printValueID(op->getResult(0), /*printResultNo=*/false); if (numResults > 1) os << ':' << numResults; os << " = "; } // TODO(riverriddle): FuncOp cannot be round-tripped currently, as // FunctionType cannot be used in a TypeAttr. if (printGenericOpForm && !isa(op)) return printGenericOp(op); // Check to see if this is a known operation. If so, use the registered // custom printer hook. if (auto *opInfo = op->getAbstractOperation()) { opInfo->printAssembly(op, *this); return; } // Otherwise print with the generic assembly form. printGenericOp(op); } void OperationPrinter::printGenericOp(Operation *op) { os << '"'; printEscapedString(op->getName().getStringRef(), os); os << "\"("; // Get the list of operands that are not successor operands. unsigned totalNumSuccessorOperands = 0; unsigned numSuccessors = op->getNumSuccessors(); for (unsigned i = 0; i < numSuccessors; ++i) totalNumSuccessorOperands += op->getNumSuccessorOperands(i); unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands; SmallVector properOperands( op->operand_begin(), std::next(op->operand_begin(), numProperOperands)); interleaveComma(properOperands, [&](Value *value) { printValueID(value); }); os << ')'; // For terminators, print the list of successors and their operands. if (numSuccessors != 0) { os << '['; for (unsigned i = 0; i < numSuccessors; ++i) { if (i != 0) os << ", "; printSuccessorAndUseList(op, i); } os << ']'; } // Print regions. if (op->getNumRegions() != 0) { os << " ("; interleaveComma(op->getRegions(), [&](Region ®ion) { printRegion(region, /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); }); os << ')'; } auto attrs = op->getAttrs(); printOptionalAttrDict(attrs); // Print the type signature of the operation. os << " : "; printFunctionalType(op); } void OperationPrinter::printSuccessorAndUseList(Operation *term, unsigned index) { printBlockName(term->getSuccessor(index)); auto succOperands = term->getSuccessorOperands(index); if (succOperands.begin() == succOperands.end()) return; os << '('; interleaveComma(succOperands, [this](Value *operand) { printValueID(operand); }); os << " : "; interleaveComma(succOperands, [this](Value *operand) { printType(operand->getType()); }); os << ')'; } void ModulePrinter::print(ModuleOp module) { // Output the aliases at the top level. if (state) { state->printAttributeAliases(os); state->printTypeAliases(os); } // Print the module. OperationPrinter(module, *this).print(module); os << '\n'; } //===----------------------------------------------------------------------===// // print and dump methods //===----------------------------------------------------------------------===// void Attribute::print(raw_ostream &os) const { ModulePrinter(os).printAttribute(*this); } void Attribute::dump() const { print(llvm::errs()); llvm::errs() << "\n"; } void Type::print(raw_ostream &os) { ModulePrinter(os).printType(*this); } void Type::dump() { print(llvm::errs()); } void AffineMap::dump() const { print(llvm::errs()); llvm::errs() << "\n"; } void IntegerSet::dump() const { print(llvm::errs()); llvm::errs() << "\n"; } void AffineExpr::print(raw_ostream &os) const { if (expr == nullptr) { os << "null affine expr"; return; } ModulePrinter(os).printAffineExpr(*this); } void AffineExpr::dump() const { print(llvm::errs()); llvm::errs() << "\n"; } void AffineMap::print(raw_ostream &os) const { if (map == nullptr) { os << "null affine map"; return; } ModulePrinter(os).printAffineMap(*this); } void IntegerSet::print(raw_ostream &os) const { ModulePrinter(os).printIntegerSet(*this); } void Value::print(raw_ostream &os) { switch (getKind()) { case Value::Kind::BlockArgument: // TODO: Improve this. os << "\n"; return; case Value::Kind::OpResult: return getDefiningOp()->print(os); } } void Value::dump() { print(llvm::errs()); llvm::errs() << "\n"; } void Operation::print(raw_ostream &os) { // Handle top-level operations. if (!getParent()) { ModulePrinter modulePrinter(os); OperationPrinter(this, modulePrinter).print(this); return; } auto region = getParentRegion(); if (!region) { os << "<>\n"; return; } // Get the top-level region. while (auto *nextRegion = region->getParentRegion()) region = nextRegion; ModuleState state(getContext()); ModulePrinter modulePrinter(os, &state); OperationPrinter(region, modulePrinter).print(this); } void Operation::dump() { print(llvm::errs()); llvm::errs() << "\n"; } void Block::print(raw_ostream &os) { auto region = getParent(); if (!region) { os << "<>\n"; return; } // Get the top-level region. while (auto *nextRegion = region->getParentRegion()) region = nextRegion; ModuleState state(region->getContext()); ModulePrinter modulePrinter(os, &state); OperationPrinter(region, modulePrinter).print(this); } void Block::dump() { print(llvm::errs()); } /// Print out the name of the block without printing its body. void Block::printAsOperand(raw_ostream &os, bool printType) { auto region = getParent(); if (!region) { os << "<>\n"; return; } // Get the top-level region. while (auto *nextRegion = region->getParentRegion()) region = nextRegion; ModulePrinter modulePrinter(os); OperationPrinter(region, modulePrinter).printBlockName(this); } void ModuleOp::print(raw_ostream &os) { ModuleState state(getContext()); state.initialize(*this); ModulePrinter(os, &state).print(*this); } void ModuleOp::dump() { print(llvm::errs()); }