diff options
Diffstat (limited to 'mlir/lib/IR')
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 193 | ||||
| -rw-r--r-- | mlir/lib/IR/CMakeLists.txt | 2 |
2 files changed, 141 insertions, 54 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 8ffe9c51a9c..655a776118c 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -58,6 +58,13 @@ DialectAsmPrinter::~DialectAsmPrinter() {} OpAsmPrinter::~OpAsmPrinter() {} +//===--------------------------------------------------------------------===// +// Operation OpAsm interface. +//===--------------------------------------------------------------------===// + +/// The OpAsmOpInterface, see OpAsmInterface.td for more details. +#include "mlir/IR/OpAsmInterface.cpp.inc" + //===----------------------------------------------------------------------===// // OpPrintingFlags //===----------------------------------------------------------------------===// @@ -1490,17 +1497,26 @@ public: const static unsigned indentWidth = 2; protected: - void numberValueID(Value *value); void numberValuesInRegion(Region ®ion); void numberValuesInBlock(Block &block); + void numberValuesInOp(Operation &op); void printValueID(Value *value, bool printResultNo = true) const { printValueIDImpl(value, printResultNo, os); } private: + /// Given a result of an operation 'result', find the result group head + /// 'lookupValue' and the result of 'result' within that group in + /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group + /// has more than 1 result. + void getResultIDAndNumber(OpResult *result, Value *&lookupValue, + int &lookupResultNo) const; void printValueIDImpl(Value *value, bool printResultNo, raw_ostream &stream) const; + /// Set a special value name for the given value. + void setValueName(Value *value, StringRef name); + /// Uniques the given value name within the printer. If the given name /// conflicts, it is automatically renamed. StringRef uniqueValueName(StringRef name); @@ -1510,6 +1526,11 @@ private: DenseMap<Value *, unsigned> valueIDs; DenseMap<Value *, StringRef> valueNames; + /// This is a map of operations that contain multiple named result groups, + /// i.e. there may be multiple names for the results of the operation. The key + /// of this map are the result numbers that start a result group. + DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; + /// This is the block ID for each block in the current. DenseMap<Block *, unsigned> blockIDs; @@ -1534,8 +1555,7 @@ private: OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other) : ModulePrinter(other) { llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames); - if (op->getNumResults() != 0) - numberValueID(op->getResult(0)); + numberValuesInOp(*op); for (auto ®ion : op->getRegions()) numberValuesInRegion(region); @@ -1546,7 +1566,6 @@ OperationPrinter::OperationPrinter(Region *region, ModulePrinter &other) numberValuesInRegion(*region); } -/// Number all of the SSA values in the specified region. void OperationPrinter::numberValuesInRegion(Region ®ion) { // Save the current value ids to allow for numbering values in sibling regions // the same. @@ -1580,59 +1599,72 @@ void OperationPrinter::numberValuesInRegion(Region ®ion) { nextConflictID = curConflictID; } -/// Number all of the SSA values in the specified block, without traversing -/// nested regions. void OperationPrinter::numberValuesInBlock(Block &block) { - // Number the block arguments. - for (auto *arg : block.getArguments()) - numberValueID(arg); + bool isEntryBlock = block.isEntryBlock(); + + // Number the block arguments. We give entry block arguments a special name + // 'arg'. + SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); + llvm::raw_svector_ostream specialName(specialNameBuffer); + for (auto *arg : block.getArguments()) { + if (isEntryBlock) { + specialNameBuffer.resize(strlen("arg")); + specialName << nextArgumentID++; + } + setValueName(arg, specialName.str()); + } - // We number operation that have results, and we only number the first result. + // Number the operations in this block. for (auto &op : block) - if (op.getNumResults() != 0) - numberValueID(op.getResult(0)); + numberValuesInOp(op); } -void OperationPrinter::numberValueID(Value *value) { - assert(!valueIDs.count(value) && "Value numbered multiple times"); - - SmallString<32> specialNameBuffer; - llvm::raw_svector_ostream specialName(specialNameBuffer); +void OperationPrinter::numberValuesInOp(Operation &op) { + unsigned numResults = op.getNumResults(); + if (numResults == 0) + return; + Value *resultBegin = op.getResult(0); + + // Function used to set the special result names for the operation. + SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); + auto setResultNameFn = [&](Value *result, StringRef name) { + assert(!valueIDs.count(result) && "result numbered multiple times"); + assert(result->getDefiningOp() == &op && "result not defined by 'op'"); + setValueName(result, name); + + // Record the result number for groups not anchored at 0. + if (int resultNo = cast<OpResult>(result)->getResultNumber()) + resultGroups.push_back(resultNo); + }; - // Check to see if this value requested a special name. - auto *op = value->getDefiningOp(); - if (state && op) { - if (auto *interface = state->getOpAsmInterface(op->getDialect())) - interface->getOpResultName(op, specialName); + if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) { + asmInterface.getAsmResultNames(setResultNameFn); + } else if (auto *dialectAsmInterface = + state ? state->getOpAsmInterface(op.getDialect()) : nullptr) { + dialectAsmInterface->getAsmResultNames(&op, setResultNameFn); } - if (specialNameBuffer.empty()) { - auto *blockArg = dyn_cast<BlockArgument>(value); - if (!blockArg) { - // This is an uninteresting operation result, give it a boring number and - // be done with it. - valueIDs[value] = nextValueID++; - return; - } + // If the first result wasn't numbered, give it a default number. + if (valueIDs.try_emplace(resultBegin, nextValueID).second) + ++nextValueID; - // Otherwise, if this is an argument to the entry block of a region, give it - // an 'arg' name. - if (auto *block = blockArg->getOwner()) { - auto *parentRegion = block->getParent(); - if (parentRegion && block == &parentRegion->front()) - specialName << "arg" << nextArgumentID++; - } + // If this operation has multiple result groups, mark it. + if (resultGroups.size() != 1) { + llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); + opResultGroups.try_emplace(&op, std::move(resultGroups)); + } +} - // Otherwise number it normally. - if (specialNameBuffer.empty()) { - valueIDs[value] = nextValueID++; - return; - } +/// Set a special value name for the given value. +void OperationPrinter::setValueName(Value *value, StringRef name) { + // If the name is empty, the value uses the default numbering. + if (name.empty()) { + valueIDs[value] = nextValueID++; + return; } - // Ok, this value had an interesting name. Remember it with a sentinel. valueIDs[value] = nameSentinel; - valueNames[value] = uniqueValueName(specialName.str()); + valueNames[value] = uniqueValueName(name); } /// Uniques the given value name within the printer. If the given name @@ -1722,6 +1754,45 @@ void OperationPrinter::print(Operation *op) { printTrailingLocation(op->getLoc()); } +void OperationPrinter::getResultIDAndNumber(OpResult *result, + Value *&lookupValue, + int &lookupResultNo) const { + Operation *owner = result->getOwner(); + if (owner->getNumResults() == 1) + return; + int resultNo = result->getResultNumber(); + + // If this operation has multiple result groups, we will need to find the + // one corresponding to this result. + auto resultGroupIt = opResultGroups.find(owner); + if (resultGroupIt == opResultGroups.end()) { + // If not, just use the first result. + lookupResultNo = resultNo; + lookupValue = owner->getResult(0); + return; + } + + // Find the correct index using a binary search, as the groups are ordered. + ArrayRef<int> resultGroups = resultGroupIt->second; + auto it = llvm::upper_bound(resultGroups, resultNo); + int groupResultNo = 0, groupSize = 0; + + // If there are no smaller elements, the last result group is the lookup. + if (it == resultGroups.end()) { + groupResultNo = resultGroups.back(); + groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); + } else { + // Otherwise, the previous element is the lookup. + groupResultNo = *std::prev(it); + groupSize = *it - groupResultNo; + } + + // We only record the result number for a group of size greater than 1. + if (groupSize != 1) + lookupResultNo = resultNo - groupResultNo; + lookupValue = owner->getResult(groupResultNo); +} + void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo, raw_ostream &stream) const { if (!value) { @@ -1735,12 +1806,8 @@ void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo, // If this is a reference to the result of a multi-result operation or // operation, print out the # identifier and make sure to map our lookup // to the first result of the operation. - if (auto *result = dyn_cast<OpResult>(value)) { - if (result->getOwner()->getNumResults() != 1) { - resultNo = result->getResultNumber(); - lookupValue = result->getOwner()->getResult(0); - } - } + if (OpResult *result = dyn_cast<OpResult>(value)) + getResultIDAndNumber(result, lookupValue, resultNo); auto it = valueIDs.find(lookupValue); if (it == valueIDs.end()) { @@ -1798,9 +1865,29 @@ void OperationPrinter::shadowRegionArgs(Region ®ion, void OperationPrinter::printOperation(Operation *op) { if (size_t numResults = op->getNumResults()) { - printValueID(op->getResult(0), /*printResultNo=*/false); - if (numResults > 1) - os << ':' << numResults; + auto printResultGroup = [&](size_t resultNo, size_t resultCount) { + printValueID(op->getResult(resultNo), /*printResultNo=*/false); + if (resultCount > 1) + os << ':' << resultCount; + }; + + // Check to see if this operation has multiple result groups. + auto resultGroupIt = opResultGroups.find(op); + if (resultGroupIt != opResultGroups.end()) { + ArrayRef<int> resultGroups = resultGroupIt->second; + // Interleave the groups excluding the last one, this one will be handled + // separately. + interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) { + printResultGroup(resultGroups[i], + resultGroups[i + 1] - resultGroups[i]); + }); + os << ", "; + printResultGroup(resultGroups.back(), numResults - resultGroups.back()); + + } else { + printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); + } + os << " = "; } diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt index 2519b83eede..415d9d66e22 100644 --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -5,5 +5,5 @@ add_llvm_library(MLIRIR ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR ) -add_dependencies(MLIRIR MLIRCallOpInterfacesIncGen MLIRSupport LLVMSupport) +add_dependencies(MLIRIR MLIRCallOpInterfacesIncGen MLIROpAsmInterfacesIncGen MLIRSupport LLVMSupport) target_link_libraries(MLIRIR MLIRSupport LLVMSupport) |

