summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/IR')
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp193
-rw-r--r--mlir/lib/IR/CMakeLists.txt2
2 files changed, 141 insertions, 54 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 8ffe9c51a9c..655a776118c 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -58,6 +58,13 @@ DialectAsmPrinter::~DialectAsmPrinter() {}
OpAsmPrinter::~OpAsmPrinter() {}
+//===--------------------------------------------------------------------===//
+// Operation OpAsm interface.
+//===--------------------------------------------------------------------===//
+
+/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
+#include "mlir/IR/OpAsmInterface.cpp.inc"
+
//===----------------------------------------------------------------------===//
// OpPrintingFlags
//===----------------------------------------------------------------------===//
@@ -1490,17 +1497,26 @@ public:
const static unsigned indentWidth = 2;
protected:
- void numberValueID(Value *value);
void numberValuesInRegion(Region &region);
void numberValuesInBlock(Block &block);
+ void numberValuesInOp(Operation &op);
void printValueID(Value *value, bool printResultNo = true) const {
printValueIDImpl(value, printResultNo, os);
}
private:
+ /// Given a result of an operation 'result', find the result group head
+ /// 'lookupValue' and the result of 'result' within that group in
+ /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
+ /// has more than 1 result.
+ void getResultIDAndNumber(OpResult *result, Value *&lookupValue,
+ int &lookupResultNo) const;
void printValueIDImpl(Value *value, bool printResultNo,
raw_ostream &stream) const;
+ /// Set a special value name for the given value.
+ void setValueName(Value *value, StringRef name);
+
/// Uniques the given value name within the printer. If the given name
/// conflicts, it is automatically renamed.
StringRef uniqueValueName(StringRef name);
@@ -1510,6 +1526,11 @@ private:
DenseMap<Value *, unsigned> valueIDs;
DenseMap<Value *, StringRef> valueNames;
+ /// This is a map of operations that contain multiple named result groups,
+ /// i.e. there may be multiple names for the results of the operation. The key
+ /// of this map are the result numbers that start a result group.
+ DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
+
/// This is the block ID for each block in the current.
DenseMap<Block *, unsigned> blockIDs;
@@ -1534,8 +1555,7 @@ private:
OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other)
: ModulePrinter(other) {
llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
- if (op->getNumResults() != 0)
- numberValueID(op->getResult(0));
+ numberValuesInOp(*op);
for (auto &region : op->getRegions())
numberValuesInRegion(region);
@@ -1546,7 +1566,6 @@ OperationPrinter::OperationPrinter(Region *region, ModulePrinter &other)
numberValuesInRegion(*region);
}
-/// Number all of the SSA values in the specified region.
void OperationPrinter::numberValuesInRegion(Region &region) {
// Save the current value ids to allow for numbering values in sibling regions
// the same.
@@ -1580,59 +1599,72 @@ void OperationPrinter::numberValuesInRegion(Region &region) {
nextConflictID = curConflictID;
}
-/// Number all of the SSA values in the specified block, without traversing
-/// nested regions.
void OperationPrinter::numberValuesInBlock(Block &block) {
- // Number the block arguments.
- for (auto *arg : block.getArguments())
- numberValueID(arg);
+ bool isEntryBlock = block.isEntryBlock();
+
+ // Number the block arguments. We give entry block arguments a special name
+ // 'arg'.
+ SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
+ llvm::raw_svector_ostream specialName(specialNameBuffer);
+ for (auto *arg : block.getArguments()) {
+ if (isEntryBlock) {
+ specialNameBuffer.resize(strlen("arg"));
+ specialName << nextArgumentID++;
+ }
+ setValueName(arg, specialName.str());
+ }
- // We number operation that have results, and we only number the first result.
+ // Number the operations in this block.
for (auto &op : block)
- if (op.getNumResults() != 0)
- numberValueID(op.getResult(0));
+ numberValuesInOp(op);
}
-void OperationPrinter::numberValueID(Value *value) {
- assert(!valueIDs.count(value) && "Value numbered multiple times");
-
- SmallString<32> specialNameBuffer;
- llvm::raw_svector_ostream specialName(specialNameBuffer);
+void OperationPrinter::numberValuesInOp(Operation &op) {
+ unsigned numResults = op.getNumResults();
+ if (numResults == 0)
+ return;
+ Value *resultBegin = op.getResult(0);
+
+ // Function used to set the special result names for the operation.
+ SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
+ auto setResultNameFn = [&](Value *result, StringRef name) {
+ assert(!valueIDs.count(result) && "result numbered multiple times");
+ assert(result->getDefiningOp() == &op && "result not defined by 'op'");
+ setValueName(result, name);
+
+ // Record the result number for groups not anchored at 0.
+ if (int resultNo = cast<OpResult>(result)->getResultNumber())
+ resultGroups.push_back(resultNo);
+ };
- // Check to see if this value requested a special name.
- auto *op = value->getDefiningOp();
- if (state && op) {
- if (auto *interface = state->getOpAsmInterface(op->getDialect()))
- interface->getOpResultName(op, specialName);
+ if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
+ asmInterface.getAsmResultNames(setResultNameFn);
+ } else if (auto *dialectAsmInterface =
+ state ? state->getOpAsmInterface(op.getDialect()) : nullptr) {
+ dialectAsmInterface->getAsmResultNames(&op, setResultNameFn);
}
- if (specialNameBuffer.empty()) {
- auto *blockArg = dyn_cast<BlockArgument>(value);
- if (!blockArg) {
- // This is an uninteresting operation result, give it a boring number and
- // be done with it.
- valueIDs[value] = nextValueID++;
- return;
- }
+ // If the first result wasn't numbered, give it a default number.
+ if (valueIDs.try_emplace(resultBegin, nextValueID).second)
+ ++nextValueID;
- // Otherwise, if this is an argument to the entry block of a region, give it
- // an 'arg' name.
- if (auto *block = blockArg->getOwner()) {
- auto *parentRegion = block->getParent();
- if (parentRegion && block == &parentRegion->front())
- specialName << "arg" << nextArgumentID++;
- }
+ // If this operation has multiple result groups, mark it.
+ if (resultGroups.size() != 1) {
+ llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
+ opResultGroups.try_emplace(&op, std::move(resultGroups));
+ }
+}
- // Otherwise number it normally.
- if (specialNameBuffer.empty()) {
- valueIDs[value] = nextValueID++;
- return;
- }
+/// Set a special value name for the given value.
+void OperationPrinter::setValueName(Value *value, StringRef name) {
+ // If the name is empty, the value uses the default numbering.
+ if (name.empty()) {
+ valueIDs[value] = nextValueID++;
+ return;
}
- // Ok, this value had an interesting name. Remember it with a sentinel.
valueIDs[value] = nameSentinel;
- valueNames[value] = uniqueValueName(specialName.str());
+ valueNames[value] = uniqueValueName(name);
}
/// Uniques the given value name within the printer. If the given name
@@ -1722,6 +1754,45 @@ void OperationPrinter::print(Operation *op) {
printTrailingLocation(op->getLoc());
}
+void OperationPrinter::getResultIDAndNumber(OpResult *result,
+ Value *&lookupValue,
+ int &lookupResultNo) const {
+ Operation *owner = result->getOwner();
+ if (owner->getNumResults() == 1)
+ return;
+ int resultNo = result->getResultNumber();
+
+ // If this operation has multiple result groups, we will need to find the
+ // one corresponding to this result.
+ auto resultGroupIt = opResultGroups.find(owner);
+ if (resultGroupIt == opResultGroups.end()) {
+ // If not, just use the first result.
+ lookupResultNo = resultNo;
+ lookupValue = owner->getResult(0);
+ return;
+ }
+
+ // Find the correct index using a binary search, as the groups are ordered.
+ ArrayRef<int> resultGroups = resultGroupIt->second;
+ auto it = llvm::upper_bound(resultGroups, resultNo);
+ int groupResultNo = 0, groupSize = 0;
+
+ // If there are no smaller elements, the last result group is the lookup.
+ if (it == resultGroups.end()) {
+ groupResultNo = resultGroups.back();
+ groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
+ } else {
+ // Otherwise, the previous element is the lookup.
+ groupResultNo = *std::prev(it);
+ groupSize = *it - groupResultNo;
+ }
+
+ // We only record the result number for a group of size greater than 1.
+ if (groupSize != 1)
+ lookupResultNo = resultNo - groupResultNo;
+ lookupValue = owner->getResult(groupResultNo);
+}
+
void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo,
raw_ostream &stream) const {
if (!value) {
@@ -1735,12 +1806,8 @@ void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo,
// If this is a reference to the result of a multi-result operation or
// operation, print out the # identifier and make sure to map our lookup
// to the first result of the operation.
- if (auto *result = dyn_cast<OpResult>(value)) {
- if (result->getOwner()->getNumResults() != 1) {
- resultNo = result->getResultNumber();
- lookupValue = result->getOwner()->getResult(0);
- }
- }
+ if (OpResult *result = dyn_cast<OpResult>(value))
+ getResultIDAndNumber(result, lookupValue, resultNo);
auto it = valueIDs.find(lookupValue);
if (it == valueIDs.end()) {
@@ -1798,9 +1865,29 @@ void OperationPrinter::shadowRegionArgs(Region &region,
void OperationPrinter::printOperation(Operation *op) {
if (size_t numResults = op->getNumResults()) {
- printValueID(op->getResult(0), /*printResultNo=*/false);
- if (numResults > 1)
- os << ':' << numResults;
+ auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
+ printValueID(op->getResult(resultNo), /*printResultNo=*/false);
+ if (resultCount > 1)
+ os << ':' << resultCount;
+ };
+
+ // Check to see if this operation has multiple result groups.
+ auto resultGroupIt = opResultGroups.find(op);
+ if (resultGroupIt != opResultGroups.end()) {
+ ArrayRef<int> resultGroups = resultGroupIt->second;
+ // Interleave the groups excluding the last one, this one will be handled
+ // separately.
+ interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
+ printResultGroup(resultGroups[i],
+ resultGroups[i + 1] - resultGroups[i]);
+ });
+ os << ", ";
+ printResultGroup(resultGroups.back(), numResults - resultGroups.back());
+
+ } else {
+ printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
+ }
+
os << " = ";
}
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 2519b83eede..415d9d66e22 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -5,5 +5,5 @@ add_llvm_library(MLIRIR
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
)
-add_dependencies(MLIRIR MLIRCallOpInterfacesIncGen MLIRSupport LLVMSupport)
+add_dependencies(MLIRIR MLIRCallOpInterfacesIncGen MLIROpAsmInterfacesIncGen MLIRSupport LLVMSupport)
target_link_libraries(MLIRIR MLIRSupport LLVMSupport)
OpenPOWER on IntegriCloud