diff options
| author | River Riddle <riverriddle@google.com> | 2019-11-20 10:19:01 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-20 10:45:45 -0800 |
| commit | eb418559ef29716cc34c891c93490c38ac5ea1dd (patch) | |
| tree | be1e4ac3d32e5df31b8668785ba15f92fe1895a9 /mlir/lib/IR | |
| parent | 3c055957de7e47e53d3ee8f5ab283cdb5c0ea535 (diff) | |
| download | bcm5719-llvm-eb418559ef29716cc34c891c93490c38ac5ea1dd.tar.gz bcm5719-llvm-eb418559ef29716cc34c891c93490c38ac5ea1dd.zip | |
Add a new OpAsmOpInterface to allow for ops to directly hook into the AsmPrinter.
This interface provides more fine-grained hooks into the AsmPrinter than the dialect interface, allowing for operations to define the asm name to use for results directly on the operations themselves. The hook is also expanded to enable defining named result "groups". Get a special name to use when printing the results of this operation.
The given callback is invoked with a specific result value that starts a
result "pack", and the name to give this result pack. To signal that a
result pack should use the default naming scheme, a None can be passed
in instead of the name.
For example, if you have an operation that has four results and you want
to split these into three distinct groups you could do the following:
setNameFn(getResult(0), "first_result");
setNameFn(getResult(1), "middle_results");
setNameFn(getResult(3), ""); // use the default numbering.
This would print the operation as follows:
%first_result, %middle_results:2, %0 = "my.op" ...
PiperOrigin-RevId: 281546873
Diffstat (limited to 'mlir/lib/IR')
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 193 | ||||
| -rw-r--r-- | mlir/lib/IR/CMakeLists.txt | 2 |
2 files changed, 141 insertions, 54 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 8ffe9c51a9c..655a776118c 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -58,6 +58,13 @@ DialectAsmPrinter::~DialectAsmPrinter() {} OpAsmPrinter::~OpAsmPrinter() {} +//===--------------------------------------------------------------------===// +// Operation OpAsm interface. +//===--------------------------------------------------------------------===// + +/// The OpAsmOpInterface, see OpAsmInterface.td for more details. +#include "mlir/IR/OpAsmInterface.cpp.inc" + //===----------------------------------------------------------------------===// // OpPrintingFlags //===----------------------------------------------------------------------===// @@ -1490,17 +1497,26 @@ public: const static unsigned indentWidth = 2; protected: - void numberValueID(Value *value); void numberValuesInRegion(Region ®ion); void numberValuesInBlock(Block &block); + void numberValuesInOp(Operation &op); void printValueID(Value *value, bool printResultNo = true) const { printValueIDImpl(value, printResultNo, os); } private: + /// Given a result of an operation 'result', find the result group head + /// 'lookupValue' and the result of 'result' within that group in + /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group + /// has more than 1 result. + void getResultIDAndNumber(OpResult *result, Value *&lookupValue, + int &lookupResultNo) const; void printValueIDImpl(Value *value, bool printResultNo, raw_ostream &stream) const; + /// Set a special value name for the given value. + void setValueName(Value *value, StringRef name); + /// Uniques the given value name within the printer. If the given name /// conflicts, it is automatically renamed. StringRef uniqueValueName(StringRef name); @@ -1510,6 +1526,11 @@ private: DenseMap<Value *, unsigned> valueIDs; DenseMap<Value *, StringRef> valueNames; + /// This is a map of operations that contain multiple named result groups, + /// i.e. there may be multiple names for the results of the operation. The key + /// of this map are the result numbers that start a result group. + DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; + /// This is the block ID for each block in the current. DenseMap<Block *, unsigned> blockIDs; @@ -1534,8 +1555,7 @@ private: OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other) : ModulePrinter(other) { llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames); - if (op->getNumResults() != 0) - numberValueID(op->getResult(0)); + numberValuesInOp(*op); for (auto ®ion : op->getRegions()) numberValuesInRegion(region); @@ -1546,7 +1566,6 @@ OperationPrinter::OperationPrinter(Region *region, ModulePrinter &other) numberValuesInRegion(*region); } -/// Number all of the SSA values in the specified region. void OperationPrinter::numberValuesInRegion(Region ®ion) { // Save the current value ids to allow for numbering values in sibling regions // the same. @@ -1580,59 +1599,72 @@ void OperationPrinter::numberValuesInRegion(Region ®ion) { nextConflictID = curConflictID; } -/// Number all of the SSA values in the specified block, without traversing -/// nested regions. void OperationPrinter::numberValuesInBlock(Block &block) { - // Number the block arguments. - for (auto *arg : block.getArguments()) - numberValueID(arg); + bool isEntryBlock = block.isEntryBlock(); + + // Number the block arguments. We give entry block arguments a special name + // 'arg'. + SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); + llvm::raw_svector_ostream specialName(specialNameBuffer); + for (auto *arg : block.getArguments()) { + if (isEntryBlock) { + specialNameBuffer.resize(strlen("arg")); + specialName << nextArgumentID++; + } + setValueName(arg, specialName.str()); + } - // We number operation that have results, and we only number the first result. + // Number the operations in this block. for (auto &op : block) - if (op.getNumResults() != 0) - numberValueID(op.getResult(0)); + numberValuesInOp(op); } -void OperationPrinter::numberValueID(Value *value) { - assert(!valueIDs.count(value) && "Value numbered multiple times"); - - SmallString<32> specialNameBuffer; - llvm::raw_svector_ostream specialName(specialNameBuffer); +void OperationPrinter::numberValuesInOp(Operation &op) { + unsigned numResults = op.getNumResults(); + if (numResults == 0) + return; + Value *resultBegin = op.getResult(0); + + // Function used to set the special result names for the operation. + SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); + auto setResultNameFn = [&](Value *result, StringRef name) { + assert(!valueIDs.count(result) && "result numbered multiple times"); + assert(result->getDefiningOp() == &op && "result not defined by 'op'"); + setValueName(result, name); + + // Record the result number for groups not anchored at 0. + if (int resultNo = cast<OpResult>(result)->getResultNumber()) + resultGroups.push_back(resultNo); + }; - // Check to see if this value requested a special name. - auto *op = value->getDefiningOp(); - if (state && op) { - if (auto *interface = state->getOpAsmInterface(op->getDialect())) - interface->getOpResultName(op, specialName); + if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) { + asmInterface.getAsmResultNames(setResultNameFn); + } else if (auto *dialectAsmInterface = + state ? state->getOpAsmInterface(op.getDialect()) : nullptr) { + dialectAsmInterface->getAsmResultNames(&op, setResultNameFn); } - if (specialNameBuffer.empty()) { - auto *blockArg = dyn_cast<BlockArgument>(value); - if (!blockArg) { - // This is an uninteresting operation result, give it a boring number and - // be done with it. - valueIDs[value] = nextValueID++; - return; - } + // If the first result wasn't numbered, give it a default number. + if (valueIDs.try_emplace(resultBegin, nextValueID).second) + ++nextValueID; - // Otherwise, if this is an argument to the entry block of a region, give it - // an 'arg' name. - if (auto *block = blockArg->getOwner()) { - auto *parentRegion = block->getParent(); - if (parentRegion && block == &parentRegion->front()) - specialName << "arg" << nextArgumentID++; - } + // If this operation has multiple result groups, mark it. + if (resultGroups.size() != 1) { + llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); + opResultGroups.try_emplace(&op, std::move(resultGroups)); + } +} - // Otherwise number it normally. - if (specialNameBuffer.empty()) { - valueIDs[value] = nextValueID++; - return; - } +/// Set a special value name for the given value. +void OperationPrinter::setValueName(Value *value, StringRef name) { + // If the name is empty, the value uses the default numbering. + if (name.empty()) { + valueIDs[value] = nextValueID++; + return; } - // Ok, this value had an interesting name. Remember it with a sentinel. valueIDs[value] = nameSentinel; - valueNames[value] = uniqueValueName(specialName.str()); + valueNames[value] = uniqueValueName(name); } /// Uniques the given value name within the printer. If the given name @@ -1722,6 +1754,45 @@ void OperationPrinter::print(Operation *op) { printTrailingLocation(op->getLoc()); } +void OperationPrinter::getResultIDAndNumber(OpResult *result, + Value *&lookupValue, + int &lookupResultNo) const { + Operation *owner = result->getOwner(); + if (owner->getNumResults() == 1) + return; + int resultNo = result->getResultNumber(); + + // If this operation has multiple result groups, we will need to find the + // one corresponding to this result. + auto resultGroupIt = opResultGroups.find(owner); + if (resultGroupIt == opResultGroups.end()) { + // If not, just use the first result. + lookupResultNo = resultNo; + lookupValue = owner->getResult(0); + return; + } + + // Find the correct index using a binary search, as the groups are ordered. + ArrayRef<int> resultGroups = resultGroupIt->second; + auto it = llvm::upper_bound(resultGroups, resultNo); + int groupResultNo = 0, groupSize = 0; + + // If there are no smaller elements, the last result group is the lookup. + if (it == resultGroups.end()) { + groupResultNo = resultGroups.back(); + groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); + } else { + // Otherwise, the previous element is the lookup. + groupResultNo = *std::prev(it); + groupSize = *it - groupResultNo; + } + + // We only record the result number for a group of size greater than 1. + if (groupSize != 1) + lookupResultNo = resultNo - groupResultNo; + lookupValue = owner->getResult(groupResultNo); +} + void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo, raw_ostream &stream) const { if (!value) { @@ -1735,12 +1806,8 @@ void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo, // If this is a reference to the result of a multi-result operation or // operation, print out the # identifier and make sure to map our lookup // to the first result of the operation. - if (auto *result = dyn_cast<OpResult>(value)) { - if (result->getOwner()->getNumResults() != 1) { - resultNo = result->getResultNumber(); - lookupValue = result->getOwner()->getResult(0); - } - } + if (OpResult *result = dyn_cast<OpResult>(value)) + getResultIDAndNumber(result, lookupValue, resultNo); auto it = valueIDs.find(lookupValue); if (it == valueIDs.end()) { @@ -1798,9 +1865,29 @@ void OperationPrinter::shadowRegionArgs(Region ®ion, void OperationPrinter::printOperation(Operation *op) { if (size_t numResults = op->getNumResults()) { - printValueID(op->getResult(0), /*printResultNo=*/false); - if (numResults > 1) - os << ':' << numResults; + auto printResultGroup = [&](size_t resultNo, size_t resultCount) { + printValueID(op->getResult(resultNo), /*printResultNo=*/false); + if (resultCount > 1) + os << ':' << resultCount; + }; + + // Check to see if this operation has multiple result groups. + auto resultGroupIt = opResultGroups.find(op); + if (resultGroupIt != opResultGroups.end()) { + ArrayRef<int> resultGroups = resultGroupIt->second; + // Interleave the groups excluding the last one, this one will be handled + // separately. + interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) { + printResultGroup(resultGroups[i], + resultGroups[i + 1] - resultGroups[i]); + }); + os << ", "; + printResultGroup(resultGroups.back(), numResults - resultGroups.back()); + + } else { + printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); + } + os << " = "; } diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt index 2519b83eede..415d9d66e22 100644 --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -5,5 +5,5 @@ add_llvm_library(MLIRIR ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR ) -add_dependencies(MLIRIR MLIRCallOpInterfacesIncGen MLIRSupport LLVMSupport) +add_dependencies(MLIRIR MLIRCallOpInterfacesIncGen MLIROpAsmInterfacesIncGen MLIRSupport LLVMSupport) target_link_libraries(MLIRIR MLIRSupport LLVMSupport) |

