diff options
| author | River Riddle <riverriddle@google.com> | 2020-01-09 12:39:26 -0800 |
|---|---|---|
| committer | River Riddle <riverriddle@google.com> | 2020-01-09 12:48:35 -0800 |
| commit | fc3367dd5ed4698036c421b23cf4f52cf8aedcae (patch) | |
| tree | 18c98942fc9ebf1fa765f11fde93c5738ce5bfd2 /mlir | |
| parent | 646ca7d7e72e8408b3fa3472018eb9d1c2643ff5 (diff) | |
| download | bcm5719-llvm-fc3367dd5ed4698036c421b23cf4f52cf8aedcae.tar.gz bcm5719-llvm-fc3367dd5ed4698036c421b23cf4f52cf8aedcae.zip | |
[mlir] NFC: Move the state for managing SSA value names out of OperationPrinter and into a new class SSANameState.
Summary:
This reduces the complexity of OperationPrinter and simplifies the code by quite a bit. The SSANameState is now held by ModuleState. This is in preparation for a future revision that molds ModuleState into something that can be used by users for caching the printer state, as well as for implementing printAsOperand style methods.
Depends On D72292
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D72293
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 1081 |
1 files changed, 571 insertions, 510 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 2eb436772e9..6c58d74e2ee 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -35,6 +35,7 @@ #include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Regex.h" +#include "llvm/Support/SaveAndRestore.h" using namespace mlir; void Identifier::print(raw_ostream &os) const { os << str(); } @@ -412,13 +413,357 @@ void AliasState::visitOperation(Operation *op) { } //===----------------------------------------------------------------------===// +// SSANameState +//===----------------------------------------------------------------------===// + +namespace { +/// This class manages the state of SSA value names. +class SSANameState { +public: + /// A sentinal value used for values with names set. + enum : unsigned { NameSentinel = ~0U }; + + SSANameState(Operation *op, + DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); + + /// Print the SSA identifier for the given value to 'stream'. If + /// 'printResultNo' is true, it also presents the result number ('#' number) + /// of this value. + void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; + + /// Return the result indices for each of the result groups registered by this + /// operation, or empty if none exist. + ArrayRef<int> getOpResultGroups(Operation *op); + + /// Get the ID for the given block. + unsigned getBlockID(Block *block); + + /// Renumber the arguments for the specified region to the same names as the + /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for + /// details. + void shadowRegionArgs(Region ®ion, ValueRange namesToUse); + +private: + /// Number the SSA values within the given IR unit. + void numberValuesInRegion( + Region ®ion, + DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); + void numberValuesInBlock( + Block &block, + DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); + void numberValuesInOp( + Operation &op, + DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); + + /// Given a result of an operation 'result', find the result group head + /// 'lookupValue' and the result of 'result' within that group in + /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group + /// has more than 1 result. + void getResultIDAndNumber(OpResult result, Value &lookupValue, + Optional<int> &lookupResultNo) const; + + /// Set a special value name for the given value. + void setValueName(Value value, StringRef name); + + /// Uniques the given value name within the printer. If the given name + /// conflicts, it is automatically renamed. + StringRef uniqueValueName(StringRef name); + + /// This is the value ID for each SSA value. If this returns NameSentinel, + /// then the valueID has an entry in valueNames. + DenseMap<Value, unsigned> valueIDs; + DenseMap<Value, StringRef> valueNames; + + /// This is a map of operations that contain multiple named result groups, + /// i.e. there may be multiple names for the results of the operation. The + /// value of this map are the result numbers that start a result group. + DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; + + /// This is the block ID for each block in the current. + DenseMap<Block *, unsigned> blockIDs; + + /// This keeps track of all of the non-numeric names that are in flight, + /// allowing us to check for duplicates. + /// Note: the value of the map is unused. + llvm::ScopedHashTable<StringRef, char> usedNames; + llvm::BumpPtrAllocator usedNameAllocator; + + /// This is the next value ID to assign in numbering. + unsigned nextValueID = 0; + /// This is the next ID to assign to a region entry block argument. + unsigned nextArgumentID = 0; + /// This is the next ID to assign when a name conflict is detected. + unsigned nextConflictID = 0; +}; +} // end anonymous namespace + +SSANameState::SSANameState( + Operation *op, + DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { + llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames); + numberValuesInOp(*op, interfaces); + + for (auto ®ion : op->getRegions()) + numberValuesInRegion(region, interfaces); +} + +void SSANameState::printValueID(Value value, bool printResultNo, + raw_ostream &stream) const { + if (!value) { + stream << "<<NULL>>"; + return; + } + + Optional<int> resultNo; + auto lookupValue = value; + + // If this is an operation result, collect the head lookup value of the result + // group and the result number of 'result' within that group. + if (OpResult result = value.dyn_cast<OpResult>()) + getResultIDAndNumber(result, lookupValue, resultNo); + + auto it = valueIDs.find(lookupValue); + if (it == valueIDs.end()) { + stream << "<<UNKNOWN SSA VALUE>>"; + return; + } + + stream << '%'; + if (it->second != NameSentinel) { + stream << it->second; + } else { + auto nameIt = valueNames.find(lookupValue); + assert(nameIt != valueNames.end() && "Didn't have a name entry?"); + stream << nameIt->second; + } + + if (resultNo.hasValue() && printResultNo) + stream << '#' << resultNo; +} + +ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) { + auto it = opResultGroups.find(op); + return it == opResultGroups.end() ? ArrayRef<int>() : it->second; +} + +unsigned SSANameState::getBlockID(Block *block) { + auto it = blockIDs.find(block); + return it != blockIDs.end() ? it->second : NameSentinel; +} + +void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { + assert(!region.empty() && "cannot shadow arguments of an empty region"); + assert(region.front().getNumArguments() == namesToUse.size() && + "incorrect number of names passed in"); + assert(region.getParentOp()->isKnownIsolatedFromAbove() && + "only KnownIsolatedFromAbove ops can shadow names"); + + SmallVector<char, 16> nameStr; + for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { + auto nameToUse = namesToUse[i]; + if (nameToUse == nullptr) + continue; + auto nameToReplace = region.front().getArgument(i); + + nameStr.clear(); + llvm::raw_svector_ostream nameStream(nameStr); + printValueID(nameToUse, /*printResultNo=*/true, nameStream); + + // Entry block arguments should already have a pretty "arg" name. + assert(valueIDs[nameToReplace] == NameSentinel); + + // Use the name without the leading %. + auto name = StringRef(nameStream.str()).drop_front(); + + // Overwrite the name. + valueNames[nameToReplace] = name.copy(usedNameAllocator); + } +} + +void SSANameState::numberValuesInRegion( + Region ®ion, + DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { + // Save the current value ids to allow for numbering values in sibling regions + // the same. + llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID); + llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID); + llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID); + + // Push a new used names scope. + llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames); + + // Number the values within this region in a breadth-first order. + unsigned nextBlockID = 0; + for (auto &block : region) { + // Each block gets a unique ID, and all of the operations within it get + // numbered as well. + blockIDs[&block] = nextBlockID++; + numberValuesInBlock(block, interfaces); + } + + // After that we traverse the nested regions. + // TODO: Rework this loop to not use recursion. + for (auto &block : region) { + for (auto &op : block) + for (auto &nestedRegion : op.getRegions()) + numberValuesInRegion(nestedRegion, interfaces); + } +} + +void SSANameState::numberValuesInBlock( + Block &block, + DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { + auto setArgNameFn = [&](Value arg, StringRef name) { + assert(!valueIDs.count(arg) && "arg numbered multiple times"); + assert(arg.cast<BlockArgument>()->getOwner() == &block && + "arg not defined in 'block'"); + setValueName(arg, name); + }; + + bool isEntryBlock = block.isEntryBlock(); + if (isEntryBlock) { + if (auto *op = block.getParentOp()) { + if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect())) + asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn); + } + } + + // Number the block arguments. We give entry block arguments a special name + // 'arg'. + SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); + llvm::raw_svector_ostream specialName(specialNameBuffer); + for (auto arg : block.getArguments()) { + if (valueIDs.count(arg)) + continue; + if (isEntryBlock) { + specialNameBuffer.resize(strlen("arg")); + specialName << nextArgumentID++; + } + setValueName(arg, specialName.str()); + } + + // Number the operations in this block. + for (auto &op : block) + numberValuesInOp(op, interfaces); +} + +void SSANameState::numberValuesInOp( + Operation &op, + DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { + unsigned numResults = op.getNumResults(); + if (numResults == 0) + return; + Value resultBegin = op.getResult(0); + + // Function used to set the special result names for the operation. + SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); + auto setResultNameFn = [&](Value result, StringRef name) { + assert(!valueIDs.count(result) && "result numbered multiple times"); + assert(result->getDefiningOp() == &op && "result not defined by 'op'"); + setValueName(result, name); + + // Record the result number for groups not anchored at 0. + if (int resultNo = result.cast<OpResult>()->getResultNumber()) + resultGroups.push_back(resultNo); + }; + if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) + asmInterface.getAsmResultNames(setResultNameFn); + else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect())) + asmInterface->getAsmResultNames(&op, setResultNameFn); + + // If the first result wasn't numbered, give it a default number. + if (valueIDs.try_emplace(resultBegin, nextValueID).second) + ++nextValueID; + + // If this operation has multiple result groups, mark it. + if (resultGroups.size() != 1) { + llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); + opResultGroups.try_emplace(&op, std::move(resultGroups)); + } +} + +void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue, + Optional<int> &lookupResultNo) const { + Operation *owner = result->getOwner(); + if (owner->getNumResults() == 1) + return; + int resultNo = result->getResultNumber(); + + // If this operation has multiple result groups, we will need to find the + // one corresponding to this result. + auto resultGroupIt = opResultGroups.find(owner); + if (resultGroupIt == opResultGroups.end()) { + // If not, just use the first result. + lookupResultNo = resultNo; + lookupValue = owner->getResult(0); + return; + } + + // Find the correct index using a binary search, as the groups are ordered. + ArrayRef<int> resultGroups = resultGroupIt->second; + auto it = llvm::upper_bound(resultGroups, resultNo); + int groupResultNo = 0, groupSize = 0; + + // If there are no smaller elements, the last result group is the lookup. + if (it == resultGroups.end()) { + groupResultNo = resultGroups.back(); + groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); + } else { + // Otherwise, the previous element is the lookup. + groupResultNo = *std::prev(it); + groupSize = *it - groupResultNo; + } + + // We only record the result number for a group of size greater than 1. + if (groupSize != 1) + lookupResultNo = resultNo - groupResultNo; + lookupValue = owner->getResult(groupResultNo); +} + +void SSANameState::setValueName(Value value, StringRef name) { + // If the name is empty, the value uses the default numbering. + if (name.empty()) { + valueIDs[value] = nextValueID++; + return; + } + + valueIDs[value] = NameSentinel; + valueNames[value] = uniqueValueName(name); +} + +StringRef SSANameState::uniqueValueName(StringRef name) { + // Check to see if this name is already unique. + if (!usedNames.count(name)) { + name = name.copy(usedNameAllocator); + } else { + // Otherwise, we had a conflict - probe until we find a unique name. This + // is guaranteed to terminate (and usually in a single iteration) because it + // generates new names by incrementing nextConflictID. + SmallString<64> probeName(name); + probeName.push_back('_'); + while (true) { + probeName.resize(name.size() + 1); + probeName += llvm::utostr(nextConflictID++); + if (!usedNames.count(probeName)) { + name = StringRef(probeName).copy(usedNameAllocator); + break; + } + } + } + + usedNames.insert(name, char()); + return name; +} + +//===----------------------------------------------------------------------===// // ModuleState //===----------------------------------------------------------------------===// namespace { class ModuleState { public: - explicit ModuleState(MLIRContext *context) : interfaces(context) {} + explicit ModuleState(Operation *op) + : interfaces(op->getContext()), nameState(op, interfaces) {} /// Initialize the alias state to enable the printing of aliases. void initializeAliases(Operation *op) { @@ -434,12 +779,18 @@ public: /// Get the state used for aliases. AliasState &getAliasState() { return aliasState; } + /// Get the state used for SSA names. + SSANameState &getSSANameState() { return nameState; } + private: /// Collection of OpAsm interfaces implemented in the context. DialectInterfaceCollection<OpAsmDialectInterface> interfaces; /// The state used for attribute and type aliases. AliasState aliasState; + + /// The state used for SSA value names. + SSANameState nameState; }; } // end anonymous namespace @@ -1150,6 +1501,42 @@ void ModulePrinter::printType(Type type) { } } +void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, + ArrayRef<StringRef> elidedAttrs, + bool withKeyword) { + // If there are no attributes, then there is nothing to be done. + if (attrs.empty()) + return; + + // Filter out any attributes that shouldn't be included. + SmallVector<NamedAttribute, 8> filteredAttrs( + llvm::make_filter_range(attrs, [&](NamedAttribute attr) { + return !llvm::is_contained(elidedAttrs, attr.first.strref()); + })); + + // If there are no attributes left to print after filtering, then we're done. + if (filteredAttrs.empty()) + return; + + // Print the 'attributes' keyword if necessary. + if (withKeyword) + os << " attributes"; + + // Otherwise, print them all out in braces. + os << " {"; + interleaveComma(filteredAttrs, [&](NamedAttribute attr) { + os << attr.first; + + // Pretty printing elides the attribute value for unit attributes. + if (attr.second.isa<UnitAttr>()) + return; + + os << " = "; + printAttribute(attr.second); + }); + os << '}'; +} + //===----------------------------------------------------------------------===// // CustomDialectAsmPrinter //===----------------------------------------------------------------------===// @@ -1415,69 +1802,55 @@ void ModulePrinter::printIntegerSet(IntegerSet set) { } //===----------------------------------------------------------------------===// -// Operation printing +// OperationPrinter //===----------------------------------------------------------------------===// -void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, - ArrayRef<StringRef> elidedAttrs, - bool withKeyword) { - // If there are no attributes, then there is nothing to be done. - if (attrs.empty()) - return; - - // Filter out any attributes that shouldn't be included. - SmallVector<NamedAttribute, 8> filteredAttrs( - llvm::make_filter_range(attrs, [&](NamedAttribute attr) { - return !llvm::is_contained(elidedAttrs, attr.first.strref()); - })); - - // If there are no attributes left to print after filtering, then we're done. - if (filteredAttrs.empty()) - return; - - // Print the 'attributes' keyword if necessary. - if (withKeyword) - os << " attributes"; - - // Otherwise, print them all out in braces. - os << " {"; - interleaveComma(filteredAttrs, [&](NamedAttribute attr) { - os << attr.first; - - // Pretty printing elides the attribute value for unit attributes. - if (attr.second.isa<UnitAttr>()) - return; - - os << " = "; - printAttribute(attr.second); - }); - os << '}'; -} - namespace { - -// OperationPrinter contains common functionality for printing operations. +/// This class contains the logic for printing operations, regions, and blocks. class OperationPrinter : public ModulePrinter, private OpAsmPrinter { public: - OperationPrinter(Operation *op, ModulePrinter &other); - OperationPrinter(Region *region, ModulePrinter &other); + explicit OperationPrinter(ModulePrinter &other) : ModulePrinter(other) { + assert(state && "expected valid state when printing operation"); + } - // Methods to print operations. + /// Print the given operation with its indent and location. void print(Operation *op); + /// Print the bare location, not including indentation/location/etc. + void printOperation(Operation *op); + /// Print the given operation in the generic form. + void printGenericOp(Operation *op) override; + + /// Print the name of the given block. + void printBlockName(Block *block); + + /// Print the given block. If 'printBlockArgs' is false, the arguments of the + /// block are not printed. If 'printBlockTerminator' is false, the terminator + /// operation of the block is not printed. void print(Block *block, bool printBlockArgs = true, bool printBlockTerminator = true); - void printOperation(Operation *op); - void printGenericOp(Operation *op) override; + /// Print the ID of the given value, optionally with its result number. + void printValueID(Value value, bool printResultNo = true) const; + + //===--------------------------------------------------------------------===// + // OpAsmPrinter methods + //===--------------------------------------------------------------------===// - // Implement OpAsmPrinter. + /// Return the current stream of the printer. raw_ostream &getStream() const override { return os; } + + /// Print the given type. void printType(Type type) override { ModulePrinter::printType(type); } + + /// Print the given attribute. void printAttribute(Attribute attr) override { ModulePrinter::printAttribute(attr); } + + /// Print the ID for the given value. void printOperand(Value value) override { printValueID(value); } + /// Print an optional attribute dictionary with a given set of elided values. void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {}) override { ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs); @@ -1489,451 +1862,47 @@ public: /*withKeyword=*/true); } - enum { nameSentinel = ~0U }; - - void printBlockName(Block *block) { - auto id = getBlockID(block); - if (id != ~0U) - os << "^bb" << id; - else - os << "^INVALIDBLOCK"; - } - - unsigned getBlockID(Block *block) { - auto it = blockIDs.find(block); - return it != blockIDs.end() ? it->second : ~0U; - } - + /// Print an operation successor with the operands used for the block + /// arguments. void printSuccessorAndUseList(Operation *term, unsigned index) override; - /// Print a region. - void printRegion(Region &blocks, bool printEntryBlockArgs, - bool printBlockTerminators) override { - os << " {\n"; - if (!blocks.empty()) { - auto *entryBlock = &blocks.front(); - print(entryBlock, - printEntryBlockArgs && entryBlock->getNumArguments() != 0, - printBlockTerminators); - for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1)) - print(&b); - } - os.indent(currentIndent) << "}"; - } + /// Print the given region. + void printRegion(Region ®ion, bool printEntryBlockArgs, + bool printBlockTerminators) override; /// Renumber the arguments for the specified region to the same names as the - /// SSA values in namesToUse. This may only be used for IsolatedFromAbove - /// operations. If any entry in namesToUse is null, the corresponding + /// SSA values in namesToUse. This may only be used for IsolatedFromAbove + /// operations. If any entry in namesToUse is null, the corresponding /// argument name is left alone. - void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override; + void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { + state->getSSANameState().shadowRegionArgs(region, namesToUse); + } + /// Print the given affine map with the smybol and dimension operands printed + /// inline with the map. void printAffineMapOfSSAIds(AffineMapAttr mapAttr, - ValueRange operands) override { - AffineMap map = mapAttr.getValue(); - unsigned numDims = map.getNumDims(); - auto printValueName = [&](unsigned pos, bool isSymbol) { - unsigned index = isSymbol ? numDims + pos : pos; - assert(index < operands.size()); - if (isSymbol) - os << "symbol("; - printValueID(operands[index]); - if (isSymbol) - os << ')'; - }; - - interleaveComma(map.getResults(), [&](AffineExpr expr) { - printAffineExpr(expr, printValueName); - }); - } + ValueRange operands) override; /// Print the given string as a symbol reference. void printSymbolName(StringRef symbolRef) override { ::printSymbolReference(symbolRef, os); } - // Number of spaces used for indenting nested operations. - const static unsigned indentWidth = 2; - -protected: - void numberValuesInRegion(Region ®ion); - void numberValuesInBlock(Block &block); - void numberValuesInOp(Operation &op); - void printValueID(Value value, bool printResultNo = true) const { - printValueIDImpl(value, printResultNo, os); - } - private: - /// Given a result of an operation 'result', find the result group head - /// 'lookupValue' and the result of 'result' within that group in - /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group - /// has more than 1 result. - void getResultIDAndNumber(OpResult result, Value &lookupValue, - int &lookupResultNo) const; - void printValueIDImpl(Value value, bool printResultNo, - raw_ostream &stream) const; - - /// Set a special value name for the given value. - void setValueName(Value value, StringRef name); - - /// Uniques the given value name within the printer. If the given name - /// conflicts, it is automatically renamed. - StringRef uniqueValueName(StringRef name); - - /// This is the value ID for each SSA value. If this returns ~0, then the - /// valueID has an entry in valueNames. - DenseMap<Value, unsigned> valueIDs; - DenseMap<Value, StringRef> valueNames; - - /// This is a map of operations that contain multiple named result groups, - /// i.e. there may be multiple names for the results of the operation. The key - /// of this map are the result numbers that start a result group. - DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; - - /// This is the block ID for each block in the current. - DenseMap<Block *, unsigned> blockIDs; - - /// This keeps track of all of the non-numeric names that are in flight, - /// allowing us to check for duplicates. - /// Note: the value of the map is unused. - llvm::ScopedHashTable<StringRef, char> usedNames; - llvm::BumpPtrAllocator usedNameAllocator; + /// The number of spaces used for indenting nested operations. + const static unsigned indentWidth = 2; // This is the current indentation level for nested structures. unsigned currentIndent = 0; - - /// This is the next value ID to assign in numbering. - unsigned nextValueID = 0; - /// This is the next ID to assign to a region entry block argument. - unsigned nextArgumentID = 0; - /// This is the next ID to assign when a name conflict is detected. - unsigned nextConflictID = 0; }; } // end anonymous namespace -OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other) - : ModulePrinter(other) { - llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames); - numberValuesInOp(*op); - - for (auto ®ion : op->getRegions()) - numberValuesInRegion(region); -} - -OperationPrinter::OperationPrinter(Region *region, ModulePrinter &other) - : ModulePrinter(other) { - numberValuesInRegion(*region); -} - -void OperationPrinter::numberValuesInRegion(Region ®ion) { - // Save the current value ids to allow for numbering values in sibling regions - // the same. - unsigned curValueID = nextValueID; - unsigned curArgumentID = nextArgumentID; - unsigned curConflictID = nextConflictID; - - // Push a new used names scope. - llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames); - - // Number the values within this region in a breadth-first order. - unsigned nextBlockID = 0; - for (auto &block : region) { - // Each block gets a unique ID, and all of the operations within it get - // numbered as well. - blockIDs[&block] = nextBlockID++; - numberValuesInBlock(block); - } - - // After that we traverse the nested regions. - // TODO: Rework this loop to not use recursion. - for (auto &block : region) { - for (auto &op : block) - for (auto &nestedRegion : op.getRegions()) - numberValuesInRegion(nestedRegion); - } - - // Restore the original value ids. - nextValueID = curValueID; - nextArgumentID = curArgumentID; - nextConflictID = curConflictID; -} - -void OperationPrinter::numberValuesInBlock(Block &block) { - auto setArgNameFn = [&](Value arg, StringRef name) { - assert(!valueIDs.count(arg) && "arg numbered multiple times"); - assert(arg.cast<BlockArgument>()->getOwner() == &block && - "arg not defined in 'block'"); - setValueName(arg, name); - }; - - bool isEntryBlock = block.isEntryBlock(); - if (isEntryBlock && state) { - if (auto *op = block.getParentOp()) { - if (auto dialectAsmInterface = state->getOpAsmInterface(op->getDialect())) - dialectAsmInterface->getAsmBlockArgumentNames(&block, setArgNameFn); - } - } - - // Number the block arguments. We give entry block arguments a special name - // 'arg'. - SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); - llvm::raw_svector_ostream specialName(specialNameBuffer); - for (auto arg : block.getArguments()) { - if (valueIDs.count(arg)) - continue; - if (isEntryBlock) { - specialNameBuffer.resize(strlen("arg")); - specialName << nextArgumentID++; - } - setValueName(arg, specialName.str()); - } - - // Number the operations in this block. - for (auto &op : block) - numberValuesInOp(op); -} - -void OperationPrinter::numberValuesInOp(Operation &op) { - unsigned numResults = op.getNumResults(); - if (numResults == 0) - return; - Value resultBegin = op.getResult(0); - - // Function used to set the special result names for the operation. - SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); - auto setResultNameFn = [&](Value result, StringRef name) { - assert(!valueIDs.count(result) && "result numbered multiple times"); - assert(result->getDefiningOp() == &op && "result not defined by 'op'"); - setValueName(result, name); - - // Record the result number for groups not anchored at 0. - if (int resultNo = result.cast<OpResult>()->getResultNumber()) - resultGroups.push_back(resultNo); - }; - - if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) { - asmInterface.getAsmResultNames(setResultNameFn); - } else if (auto *dialectAsmInterface = - state ? state->getOpAsmInterface(op.getDialect()) : nullptr) { - dialectAsmInterface->getAsmResultNames(&op, setResultNameFn); - } - - // If the first result wasn't numbered, give it a default number. - if (valueIDs.try_emplace(resultBegin, nextValueID).second) - ++nextValueID; - - // If this operation has multiple result groups, mark it. - if (resultGroups.size() != 1) { - llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); - opResultGroups.try_emplace(&op, std::move(resultGroups)); - } -} - -/// Set a special value name for the given value. -void OperationPrinter::setValueName(Value value, StringRef name) { - // If the name is empty, the value uses the default numbering. - if (name.empty()) { - valueIDs[value] = nextValueID++; - return; - } - - valueIDs[value] = nameSentinel; - valueNames[value] = uniqueValueName(name); -} - -/// Uniques the given value name within the printer. If the given name -/// conflicts, it is automatically renamed. -StringRef OperationPrinter::uniqueValueName(StringRef name) { - // Check to see if this name is already unique. - if (!usedNames.count(name)) { - name = name.copy(usedNameAllocator); - } else { - // Otherwise, we had a conflict - probe until we find a unique name. This - // is guaranteed to terminate (and usually in a single iteration) because it - // generates new names by incrementing nextConflictID. - SmallString<64> probeName(name); - probeName.push_back('_'); - while (true) { - probeName.resize(name.size() + 1); - probeName += llvm::utostr(nextConflictID++); - if (!usedNames.count(probeName)) { - name = StringRef(probeName).copy(usedNameAllocator); - break; - } - } - } - - usedNames.insert(name, char()); - return name; -} - -void OperationPrinter::print(Block *block, bool printBlockArgs, - bool printBlockTerminator) { - // Print the block label and argument list if requested. - if (printBlockArgs) { - os.indent(currentIndent); - printBlockName(block); - - // Print the argument list if non-empty. - if (!block->args_empty()) { - os << '('; - interleaveComma(block->getArguments(), [&](BlockArgument arg) { - printValueID(arg); - os << ": "; - printType(arg->getType()); - }); - os << ')'; - } - os << ':'; - - // Print out some context information about the predecessors of this block. - if (!block->getParent()) { - os << "\t// block is not in a region!"; - } else if (block->hasNoPredecessors()) { - os << "\t// no predecessors"; - } else if (auto *pred = block->getSinglePredecessor()) { - os << "\t// pred: "; - printBlockName(pred); - } else { - // We want to print the predecessors in increasing numeric order, not in - // whatever order the use-list is in, so gather and sort them. - SmallVector<std::pair<unsigned, Block *>, 4> predIDs; - for (auto *pred : block->getPredecessors()) - predIDs.push_back({getBlockID(pred), pred}); - llvm::array_pod_sort(predIDs.begin(), predIDs.end()); - - os << "\t// " << predIDs.size() << " preds: "; - - interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) { - printBlockName(pred.second); - }); - } - os << '\n'; - } - - currentIndent += indentWidth; - auto range = llvm::make_range( - block->getOperations().begin(), - std::prev(block->getOperations().end(), printBlockTerminator ? 0 : 1)); - for (auto &op : range) { - print(&op); - os << '\n'; - } - currentIndent -= indentWidth; -} - void OperationPrinter::print(Operation *op) { os.indent(currentIndent); printOperation(op); printTrailingLocation(op->getLoc()); } -void OperationPrinter::getResultIDAndNumber(OpResult result, Value &lookupValue, - int &lookupResultNo) const { - Operation *owner = result->getOwner(); - if (owner->getNumResults() == 1) - return; - int resultNo = result->getResultNumber(); - - // If this operation has multiple result groups, we will need to find the - // one corresponding to this result. - auto resultGroupIt = opResultGroups.find(owner); - if (resultGroupIt == opResultGroups.end()) { - // If not, just use the first result. - lookupResultNo = resultNo; - lookupValue = owner->getResult(0); - return; - } - - // Find the correct index using a binary search, as the groups are ordered. - ArrayRef<int> resultGroups = resultGroupIt->second; - auto it = llvm::upper_bound(resultGroups, resultNo); - int groupResultNo = 0, groupSize = 0; - - // If there are no smaller elements, the last result group is the lookup. - if (it == resultGroups.end()) { - groupResultNo = resultGroups.back(); - groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); - } else { - // Otherwise, the previous element is the lookup. - groupResultNo = *std::prev(it); - groupSize = *it - groupResultNo; - } - - // We only record the result number for a group of size greater than 1. - if (groupSize != 1) - lookupResultNo = resultNo - groupResultNo; - lookupValue = owner->getResult(groupResultNo); -} - -void OperationPrinter::printValueIDImpl(Value value, bool printResultNo, - raw_ostream &stream) const { - if (!value) { - stream << "<<NULL>>"; - return; - } - - int resultNo = -1; - auto lookupValue = value; - - // If this is a reference to the result of a multi-result operation or - // operation, print out the # identifier and make sure to map our lookup - // to the first result of the operation. - if (OpResult result = value.dyn_cast<OpResult>()) - getResultIDAndNumber(result, lookupValue, resultNo); - - auto it = valueIDs.find(lookupValue); - if (it == valueIDs.end()) { - stream << "<<UNKNOWN SSA VALUE>>"; - return; - } - - stream << '%'; - if (it->second != (unsigned)nameSentinel) { - stream << it->second; - } else { - auto nameIt = valueNames.find(lookupValue); - assert(nameIt != valueNames.end() && "Didn't have a name entry?"); - stream << nameIt->second; - } - - if (resultNo != -1 && printResultNo) - stream << '#' << resultNo; -} - -/// Renumber the arguments for the specified region to the same names as the -/// SSA values in namesToUse. This may only be used for IsolatedFromAbove -/// operations. If any entry in namesToUse is null, the corresponding -/// argument name is left alone. -void OperationPrinter::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { - assert(!region.empty() && "cannot shadow arguments of an empty region"); - assert(region.front().getNumArguments() == namesToUse.size() && - "incorrect number of names passed in"); - assert(region.getParentOp()->isKnownIsolatedFromAbove() && - "only KnownIsolatedFromAbove ops can shadow names"); - - SmallVector<char, 16> nameStr; - for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { - auto nameToUse = namesToUse[i]; - if (nameToUse == nullptr) - continue; - - auto nameToReplace = region.front().getArgument(i); - - nameStr.clear(); - llvm::raw_svector_ostream nameStream(nameStr); - printValueIDImpl(nameToUse, /*printResultNo=*/true, nameStream); - - // Entry block arguments should already have a pretty "arg" name. - assert(valueIDs[nameToReplace] == (unsigned)nameSentinel); - - // Use the name without the leading %. - auto name = StringRef(nameStream.str()).drop_front(); - - // Overwrite the name. - valueNames[nameToReplace] = name.copy(usedNameAllocator); - } -} - void OperationPrinter::printOperation(Operation *op) { if (size_t numResults = op->getNumResults()) { auto printResultGroup = [&](size_t resultNo, size_t resultCount) { @@ -1943,9 +1912,8 @@ void OperationPrinter::printOperation(Operation *op) { }; // Check to see if this operation has multiple result groups. - auto resultGroupIt = opResultGroups.find(op); - if (resultGroupIt != opResultGroups.end()) { - ArrayRef<int> resultGroups = resultGroupIt->second; + ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op); + if (!resultGroups.empty()) { // Interleave the groups excluding the last one, this one will be handled // separately. interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) { @@ -1989,21 +1957,16 @@ void OperationPrinter::printGenericOp(Operation *op) { for (unsigned i = 0; i < numSuccessors; ++i) totalNumSuccessorOperands += op->getNumSuccessorOperands(i); unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands; - SmallVector<Value, 8> properOperands( - op->operand_begin(), std::next(op->operand_begin(), numProperOperands)); - - interleaveComma(properOperands, [&](Value value) { printValueID(value); }); + interleaveComma(op->getOperands().take_front(numProperOperands), + [&](Value value) { printValueID(value); }); os << ')'; // For terminators, print the list of successors and their operands. if (numSuccessors != 0) { os << '['; - for (unsigned i = 0; i < numSuccessors; ++i) { - if (i != 0) - os << ", "; - printSuccessorAndUseList(op, i); - } + interleaveComma(llvm::seq<unsigned>(0, numSuccessors), + [&](unsigned i) { printSuccessorAndUseList(op, i); }); os << ']'; } @@ -2025,6 +1988,73 @@ void OperationPrinter::printGenericOp(Operation *op) { printFunctionalType(op); } +void OperationPrinter::printBlockName(Block *block) { + auto id = state->getSSANameState().getBlockID(block); + if (id != SSANameState::NameSentinel) + os << "^bb" << id; + else + os << "^INVALIDBLOCK"; +} + +void OperationPrinter::print(Block *block, bool printBlockArgs, + bool printBlockTerminator) { + // Print the block label and argument list if requested. + if (printBlockArgs) { + os.indent(currentIndent); + printBlockName(block); + + // Print the argument list if non-empty. + if (!block->args_empty()) { + os << '('; + interleaveComma(block->getArguments(), [&](BlockArgument arg) { + printValueID(arg); + os << ": "; + printType(arg->getType()); + }); + os << ')'; + } + os << ':'; + + // Print out some context information about the predecessors of this block. + if (!block->getParent()) { + os << "\t// block is not in a region!"; + } else if (block->hasNoPredecessors()) { + os << "\t// no predecessors"; + } else if (auto *pred = block->getSinglePredecessor()) { + os << "\t// pred: "; + printBlockName(pred); + } else { + // We want to print the predecessors in increasing numeric order, not in + // whatever order the use-list is in, so gather and sort them. + SmallVector<std::pair<unsigned, Block *>, 4> predIDs; + for (auto *pred : block->getPredecessors()) + predIDs.push_back({state->getSSANameState().getBlockID(pred), pred}); + llvm::array_pod_sort(predIDs.begin(), predIDs.end()); + + os << "\t// " << predIDs.size() << " preds: "; + + interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) { + printBlockName(pred.second); + }); + } + os << '\n'; + } + + currentIndent += indentWidth; + auto range = llvm::make_range( + block->getOperations().begin(), + std::prev(block->getOperations().end(), printBlockTerminator ? 0 : 1)); + for (auto &op : range) { + print(&op); + os << '\n'; + } + currentIndent -= indentWidth; +} + +void OperationPrinter::printValueID(Value value, bool printResultNo) const { + state->getSSANameState().printValueID(value, printResultNo, os); +} + void OperationPrinter::printSuccessorAndUseList(Operation *term, unsigned index) { printBlockName(term->getSuccessor(index)); @@ -2042,15 +2072,47 @@ void OperationPrinter::printSuccessorAndUseList(Operation *term, os << ')'; } +void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, + bool printBlockTerminators) { + os << " {\n"; + if (!region.empty()) { + auto *entryBlock = ®ion.front(); + print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0, + printBlockTerminators); + for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) + print(&b); + } + os.indent(currentIndent) << "}"; +} + +void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, + ValueRange operands) { + AffineMap map = mapAttr.getValue(); + unsigned numDims = map.getNumDims(); + auto printValueName = [&](unsigned pos, bool isSymbol) { + unsigned index = isSymbol ? numDims + pos : pos; + assert(index < operands.size()); + if (isSymbol) + os << "symbol("; + printValueID(operands[index]); + if (isSymbol) + os << ')'; + }; + + interleaveComma(map.getResults(), [&](AffineExpr expr) { + printAffineExpr(expr, printValueName); + }); +} + void ModulePrinter::print(ModuleOp module) { + assert(state && "expected valid state when printing an operation"); + // Output the aliases at the top level. - if (state) { - state->getAliasState().printAttributeAliases(os); - state->getAliasState().printTypeAliases(os); - } + state->getAliasState().printAttributeAliases(os); + state->getAliasState().printTypeAliases(os); // Print the module. - OperationPrinter(module, *this).print(module); + OperationPrinter(*this).print(module); os << '\n'; } @@ -2122,25 +2184,24 @@ void Value::dump() { void Operation::print(raw_ostream &os, OpPrintingFlags flags) { // Handle top-level operations or local printing. if (!getParent() || flags.shouldUseLocalScope()) { - ModuleState state(getContext()); + ModuleState state(this); ModulePrinter modulePrinter(os, flags, &state); - OperationPrinter(this, modulePrinter).print(this); + OperationPrinter(modulePrinter).print(this); return; } - auto region = getParentRegion(); - if (!region) { - os << "<<UNLINKED INSTRUCTION>>\n"; + Operation *parentOp = getParentOp(); + if (!parentOp) { + os << "<<UNLINKED OPERATION>>\n"; return; } + // Get the top-level op. + while (auto *nextOp = parentOp->getParentOp()) + parentOp = nextOp; - // Get the top-level region. - while (auto *nextRegion = region->getParentRegion()) - region = nextRegion; - - ModuleState state(getContext()); + ModuleState state(parentOp); ModulePrinter modulePrinter(os, flags, &state); - OperationPrinter(region, modulePrinter).print(this); + OperationPrinter(modulePrinter).print(this); } void Operation::dump() { @@ -2149,41 +2210,41 @@ void Operation::dump() { } void Block::print(raw_ostream &os) { - auto region = getParent(); - if (!region) { + Operation *parentOp = getParentOp(); + if (!parentOp) { os << "<<UNLINKED BLOCK>>\n"; return; } + // Get the top-level op. + while (auto *nextOp = parentOp->getParentOp()) + parentOp = nextOp; - // Get the top-level region. - while (auto *nextRegion = region->getParentRegion()) - region = nextRegion; - - ModuleState state(region->getContext()); + ModuleState state(parentOp); ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state); - OperationPrinter(region, modulePrinter).print(this); + OperationPrinter(modulePrinter).print(this); } void Block::dump() { print(llvm::errs()); } /// Print out the name of the block without printing its body. void Block::printAsOperand(raw_ostream &os, bool printType) { - auto region = getParent(); - if (!region) { + Operation *parentOp = getParentOp(); + if (!parentOp) { os << "<<UNLINKED BLOCK>>\n"; return; } + // Get the top-level op. + while (auto *nextOp = parentOp->getParentOp()) + parentOp = nextOp; - // Get the top-level region. - while (auto *nextRegion = region->getParentRegion()) - region = nextRegion; - - ModulePrinter modulePrinter(os); - OperationPrinter(region, modulePrinter).printBlockName(this); + ModuleState state(parentOp); + ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state); + OperationPrinter(modulePrinter).printBlockName(this); } void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) { - ModuleState state(getContext()); + ModuleState state(*this); + // Don't populate aliases when printing at local scope. if (!flags.shouldUseLocalScope()) state.initializeAliases(*this); |

