diff options
Diffstat (limited to 'mlir/lib/IR/AsmPrinter.cpp')
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 38 |
1 files changed, 34 insertions, 4 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index cb4c1f0edce..21bc3b824b1 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -145,6 +145,7 @@ private: // Visit functions. void visitInstruction(const Instruction *inst); void visitForInst(const ForInst *forInst); + void visitIfInst(const IfInst *ifInst); void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); @@ -196,6 +197,10 @@ void ModuleState::visitAttribute(Attribute attr) { } } +void ModuleState::visitIfInst(const IfInst *ifInst) { + recordIntegerSetReference(ifInst->getIntegerSet()); +} + void ModuleState::visitForInst(const ForInst *forInst) { AffineMap lbMap = forInst->getLowerBoundMap(); if (!hasCustomForm(lbMap)) @@ -220,6 +225,8 @@ void ModuleState::visitOperationInst(const OperationInst *op) { void ModuleState::visitInstruction(const Instruction *inst) { switch (inst->getKind()) { + case Instruction::Kind::If: + return visitIfInst(cast<IfInst>(inst)); case Instruction::Kind::For: return visitForInst(cast<ForInst>(inst)); case Instruction::Kind::OperationInst: @@ -1070,6 +1077,7 @@ public: void print(const Instruction *inst); void print(const OperationInst *inst); void print(const ForInst *inst); + void print(const IfInst *inst); void print(const Block *block, bool printBlockArgs = true); void printOperation(const OperationInst *op); @@ -1117,9 +1125,6 @@ public: unsigned index) override; /// Print a block list. - void printBlockList(const BlockList &blocks) override { - printBlockList(blocks, /*printEntryBlockArgs=*/true); - } void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) { os << " {\n"; if (!blocks.empty()) { @@ -1209,6 +1214,12 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { // Recursively number the stuff in the body. numberValuesInBlock(*cast<ForInst>(&inst)->getBody()); break; + case Instruction::Kind::If: { + auto *ifInst = cast<IfInst>(&inst); + numberValuesInBlock(*ifInst->getThen()); + if (auto *elseBlock = ifInst->getElse()) + numberValuesInBlock(*elseBlock); + } } } } @@ -1349,7 +1360,8 @@ void FunctionPrinter::printFunctionSignature() { } void FunctionPrinter::print(const Block *block, bool printBlockArgs) { - // Print the block label and argument list if requested. + // Print the block label and argument list, unless this is the first block of + // the function, or the first block of an IfInst/ForInst with no arguments. if (printBlockArgs) { os.indent(currentIndent); printBlockName(block); @@ -1406,6 +1418,8 @@ void FunctionPrinter::print(const Instruction *inst) { return print(cast<OperationInst>(inst)); case Instruction::Kind::For: return print(cast<ForInst>(inst)); + case Instruction::Kind::If: + return print(cast<IfInst>(inst)); } } @@ -1433,6 +1447,22 @@ void FunctionPrinter::print(const ForInst *inst) { os.indent(currentIndent) << "}"; } +void FunctionPrinter::print(const IfInst *inst) { + os.indent(currentIndent) << "if "; + IntegerSet set = inst->getIntegerSet(); + printIntegerSetReference(set); + printDimAndSymbolList(inst->getInstOperands(), set.getNumDims()); + printTrailingLocation(inst->getLoc()); + os << " {\n"; + print(inst->getThen(), /*printBlockArgs=*/false); + os.indent(currentIndent) << "}"; + if (inst->hasElse()) { + os << " else {\n"; + print(inst->getElse(), /*printBlockArgs=*/false); + os.indent(currentIndent) << "}"; + } +} + void FunctionPrinter::printValueID(const Value *value, bool printResultNo) const { int resultNo = -1; |

