summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR/AsmPrinter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/IR/AsmPrinter.cpp')
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp84
1 files changed, 34 insertions, 50 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index af996213418..21bc3b824b1 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1078,7 +1078,7 @@ public:
void print(const OperationInst *inst);
void print(const ForInst *inst);
void print(const IfInst *inst);
- void print(const Block *block);
+ void print(const Block *block, bool printBlockArgs = true);
void printOperation(const OperationInst *op);
void printGenericOp(const OperationInst *op);
@@ -1125,10 +1125,15 @@ public:
unsigned index) override;
/// Print a block list.
- void printBlockList(const BlockList &blocks) {
+ void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) {
os << " {\n";
- for (auto &b : blocks)
- print(&b);
+ if (!blocks.empty()) {
+ auto *entryBlock = &blocks.front();
+ print(entryBlock,
+ printEntryBlockArgs && entryBlock->getNumArguments() != 0);
+ for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1))
+ print(&b);
+ }
os.indent(currentIndent) << "}";
}
@@ -1164,8 +1169,8 @@ private:
/// This is the next value ID to assign in numbering.
unsigned nextValueID = 0;
- /// This is the ID to assign to the next induction variable.
- unsigned nextLoopID = 0;
+ /// This is the ID to assign to the next region entry block argument.
+ unsigned nextRegionArgumentID = 0;
/// This is the next ID to assign to a Function argument.
unsigned nextArgumentID = 0;
/// This is the next ID to assign when a name conflict is detected.
@@ -1205,14 +1210,10 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) {
numberValuesInBlock(block);
break;
}
- case Instruction::Kind::For: {
- auto *forInst = cast<ForInst>(&inst);
- // Number the induction variable.
- numberValueID(forInst);
+ case Instruction::Kind::For:
// Recursively number the stuff in the body.
- numberValuesInBlock(*forInst->getBody());
+ numberValuesInBlock(*cast<ForInst>(&inst)->getBody());
break;
- }
case Instruction::Kind::If: {
auto *ifInst = cast<IfInst>(&inst);
numberValuesInBlock(*ifInst->getThen());
@@ -1251,13 +1252,19 @@ void FunctionPrinter::numberValueID(const Value *value) {
if (specialNameBuffer.empty()) {
switch (value->getKind()) {
case Value::Kind::BlockArgument:
- // If this is an argument to the function, give it an 'arg' name.
- if (auto *block = cast<BlockArgument>(value)->getOwner())
- if (auto *fn = block->getFunction())
- if (&fn->getBlockList().front() == block) {
+ // If this is an argument to the function, give it an 'arg' name. If the
+ // argument is to an entry block of an operation region, give it an 'i'
+ // name.
+ if (auto *block = cast<BlockArgument>(value)->getOwner()) {
+ auto *parentBlockList = block->getParent();
+ if (parentBlockList && block == &parentBlockList->front()) {
+ if (parentBlockList->getContainingFunction())
specialName << "arg" << nextArgumentID++;
- break;
- }
+ else
+ specialName << "i" << nextRegionArgumentID++;
+ break;
+ }
+ }
// Otherwise number it normally.
valueIDs[value] = nextValueID++;
return;
@@ -1266,9 +1273,6 @@ void FunctionPrinter::numberValueID(const Value *value) {
// done with it.
valueIDs[value] = nextValueID++;
return;
- case Value::Kind::ForInst:
- specialName << 'i' << nextLoopID++;
- break;
}
}
@@ -1312,10 +1316,8 @@ void FunctionPrinter::print() {
printTrailingLocation(function->getLoc());
if (!function->empty()) {
- os << " {\n";
- for (const auto &block : *function)
- print(&block);
- os << "}\n";
+ printBlockList(function->getBlockList(), /*printEntryBlockArgs=*/false);
+ os << "\n";
}
os << '\n';
}
@@ -1357,26 +1359,10 @@ void FunctionPrinter::printFunctionSignature() {
}
}
-/// Return true if the introducer for the specified block should be printed.
-static bool shouldPrintBlockArguments(const Block *block) {
- // Never print the entry block of the function - it is included in the
- // argument list.
- if (block == &block->getFunction()->front())
- return false;
-
- // If this is the first block in a nested region, and if there are no
- // arguments, then we can omit it.
- if (block == &block->getParent()->front() && block->getNumArguments() == 0)
- return false;
-
- // Otherwise print it.
- return true;
-}
-
-void FunctionPrinter::print(const Block *block) {
+void FunctionPrinter::print(const Block *block, bool printBlockArgs) {
// Print the block label and argument list, unless this is the first block of
// the function, or the first block of an IfInst/ForInst with no arguments.
- if (shouldPrintBlockArguments(block)) {
+ if (printBlockArgs) {
os.indent(currentIndent);
printBlockName(block);
@@ -1445,7 +1431,7 @@ void FunctionPrinter::print(const OperationInst *inst) {
void FunctionPrinter::print(const ForInst *inst) {
os.indent(currentIndent) << "for ";
- printOperand(inst);
+ printOperand(inst->getInductionVar());
os << " = ";
printBound(inst->getLowerBound(), "max");
os << " to ";
@@ -1457,7 +1443,7 @@ void FunctionPrinter::print(const ForInst *inst) {
printTrailingLocation(inst->getLoc());
os << " {\n";
- print(inst->getBody());
+ print(inst->getBody(), /*printBlockArgs=*/false);
os.indent(currentIndent) << "}";
}
@@ -1468,11 +1454,11 @@ void FunctionPrinter::print(const IfInst *inst) {
printDimAndSymbolList(inst->getInstOperands(), set.getNumDims());
printTrailingLocation(inst->getLoc());
os << " {\n";
- print(inst->getThen());
+ print(inst->getThen(), /*printBlockArgs=*/false);
os.indent(currentIndent) << "}";
if (inst->hasElse()) {
os << " else {\n";
- print(inst->getElse());
+ print(inst->getElse(), /*printBlockArgs=*/false);
os.indent(currentIndent) << "}";
}
}
@@ -1583,7 +1569,7 @@ void FunctionPrinter::printGenericOp(const OperationInst *op) {
// Print any trailing block lists.
for (auto &blockList : op->getBlockLists())
- printBlockList(blockList);
+ printBlockList(blockList, /*printEntryBlockArgs=*/true);
}
void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term,
@@ -1729,8 +1715,6 @@ void Value::print(raw_ostream &os) const {
return;
case Value::Kind::InstResult:
return getDefiningInst()->print(os);
- case Value::Kind::ForInst:
- return cast<ForInst>(this)->print(os);
}
}
OpenPOWER on IntegriCloud