diff options
Diffstat (limited to 'mlir/lib/IR/AsmPrinter.cpp')
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 84 |
1 files changed, 34 insertions, 50 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index af996213418..21bc3b824b1 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1078,7 +1078,7 @@ public: void print(const OperationInst *inst); void print(const ForInst *inst); void print(const IfInst *inst); - void print(const Block *block); + void print(const Block *block, bool printBlockArgs = true); void printOperation(const OperationInst *op); void printGenericOp(const OperationInst *op); @@ -1125,10 +1125,15 @@ public: unsigned index) override; /// Print a block list. - void printBlockList(const BlockList &blocks) { + void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) { os << " {\n"; - for (auto &b : blocks) - print(&b); + if (!blocks.empty()) { + auto *entryBlock = &blocks.front(); + print(entryBlock, + printEntryBlockArgs && entryBlock->getNumArguments() != 0); + for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1)) + print(&b); + } os.indent(currentIndent) << "}"; } @@ -1164,8 +1169,8 @@ private: /// This is the next value ID to assign in numbering. unsigned nextValueID = 0; - /// This is the ID to assign to the next induction variable. - unsigned nextLoopID = 0; + /// This is the ID to assign to the next region entry block argument. + unsigned nextRegionArgumentID = 0; /// This is the next ID to assign to a Function argument. unsigned nextArgumentID = 0; /// This is the next ID to assign when a name conflict is detected. @@ -1205,14 +1210,10 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { numberValuesInBlock(block); break; } - case Instruction::Kind::For: { - auto *forInst = cast<ForInst>(&inst); - // Number the induction variable. - numberValueID(forInst); + case Instruction::Kind::For: // Recursively number the stuff in the body. - numberValuesInBlock(*forInst->getBody()); + numberValuesInBlock(*cast<ForInst>(&inst)->getBody()); break; - } case Instruction::Kind::If: { auto *ifInst = cast<IfInst>(&inst); numberValuesInBlock(*ifInst->getThen()); @@ -1251,13 +1252,19 @@ void FunctionPrinter::numberValueID(const Value *value) { if (specialNameBuffer.empty()) { switch (value->getKind()) { case Value::Kind::BlockArgument: - // If this is an argument to the function, give it an 'arg' name. - if (auto *block = cast<BlockArgument>(value)->getOwner()) - if (auto *fn = block->getFunction()) - if (&fn->getBlockList().front() == block) { + // If this is an argument to the function, give it an 'arg' name. If the + // argument is to an entry block of an operation region, give it an 'i' + // name. + if (auto *block = cast<BlockArgument>(value)->getOwner()) { + auto *parentBlockList = block->getParent(); + if (parentBlockList && block == &parentBlockList->front()) { + if (parentBlockList->getContainingFunction()) specialName << "arg" << nextArgumentID++; - break; - } + else + specialName << "i" << nextRegionArgumentID++; + break; + } + } // Otherwise number it normally. valueIDs[value] = nextValueID++; return; @@ -1266,9 +1273,6 @@ void FunctionPrinter::numberValueID(const Value *value) { // done with it. valueIDs[value] = nextValueID++; return; - case Value::Kind::ForInst: - specialName << 'i' << nextLoopID++; - break; } } @@ -1312,10 +1316,8 @@ void FunctionPrinter::print() { printTrailingLocation(function->getLoc()); if (!function->empty()) { - os << " {\n"; - for (const auto &block : *function) - print(&block); - os << "}\n"; + printBlockList(function->getBlockList(), /*printEntryBlockArgs=*/false); + os << "\n"; } os << '\n'; } @@ -1357,26 +1359,10 @@ void FunctionPrinter::printFunctionSignature() { } } -/// Return true if the introducer for the specified block should be printed. -static bool shouldPrintBlockArguments(const Block *block) { - // Never print the entry block of the function - it is included in the - // argument list. - if (block == &block->getFunction()->front()) - return false; - - // If this is the first block in a nested region, and if there are no - // arguments, then we can omit it. - if (block == &block->getParent()->front() && block->getNumArguments() == 0) - return false; - - // Otherwise print it. - return true; -} - -void FunctionPrinter::print(const Block *block) { +void FunctionPrinter::print(const Block *block, bool printBlockArgs) { // Print the block label and argument list, unless this is the first block of // the function, or the first block of an IfInst/ForInst with no arguments. - if (shouldPrintBlockArguments(block)) { + if (printBlockArgs) { os.indent(currentIndent); printBlockName(block); @@ -1445,7 +1431,7 @@ void FunctionPrinter::print(const OperationInst *inst) { void FunctionPrinter::print(const ForInst *inst) { os.indent(currentIndent) << "for "; - printOperand(inst); + printOperand(inst->getInductionVar()); os << " = "; printBound(inst->getLowerBound(), "max"); os << " to "; @@ -1457,7 +1443,7 @@ void FunctionPrinter::print(const ForInst *inst) { printTrailingLocation(inst->getLoc()); os << " {\n"; - print(inst->getBody()); + print(inst->getBody(), /*printBlockArgs=*/false); os.indent(currentIndent) << "}"; } @@ -1468,11 +1454,11 @@ void FunctionPrinter::print(const IfInst *inst) { printDimAndSymbolList(inst->getInstOperands(), set.getNumDims()); printTrailingLocation(inst->getLoc()); os << " {\n"; - print(inst->getThen()); + print(inst->getThen(), /*printBlockArgs=*/false); os.indent(currentIndent) << "}"; if (inst->hasElse()) { os << " else {\n"; - print(inst->getElse()); + print(inst->getElse(), /*printBlockArgs=*/false); os.indent(currentIndent) << "}"; } } @@ -1583,7 +1569,7 @@ void FunctionPrinter::printGenericOp(const OperationInst *op) { // Print any trailing block lists. for (auto &blockList : op->getBlockLists()) - printBlockList(blockList); + printBlockList(blockList, /*printEntryBlockArgs=*/true); } void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term, @@ -1729,8 +1715,6 @@ void Value::print(raw_ostream &os) const { return; case Value::Kind::InstResult: return getDefiningInst()->print(os); - case Value::Kind::ForInst: - return cast<ForInst>(this)->print(os); } } |

