summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR/AsmPrinter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/IR/AsmPrinter.cpp')
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp38
1 files changed, 34 insertions, 4 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index cb4c1f0edce..21bc3b824b1 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -145,6 +145,7 @@ private:
// Visit functions.
void visitInstruction(const Instruction *inst);
void visitForInst(const ForInst *forInst);
+ void visitIfInst(const IfInst *ifInst);
void visitOperationInst(const OperationInst *opInst);
void visitType(Type type);
void visitAttribute(Attribute attr);
@@ -196,6 +197,10 @@ void ModuleState::visitAttribute(Attribute attr) {
}
}
+void ModuleState::visitIfInst(const IfInst *ifInst) {
+ recordIntegerSetReference(ifInst->getIntegerSet());
+}
+
void ModuleState::visitForInst(const ForInst *forInst) {
AffineMap lbMap = forInst->getLowerBoundMap();
if (!hasCustomForm(lbMap))
@@ -220,6 +225,8 @@ void ModuleState::visitOperationInst(const OperationInst *op) {
void ModuleState::visitInstruction(const Instruction *inst) {
switch (inst->getKind()) {
+ case Instruction::Kind::If:
+ return visitIfInst(cast<IfInst>(inst));
case Instruction::Kind::For:
return visitForInst(cast<ForInst>(inst));
case Instruction::Kind::OperationInst:
@@ -1070,6 +1077,7 @@ public:
void print(const Instruction *inst);
void print(const OperationInst *inst);
void print(const ForInst *inst);
+ void print(const IfInst *inst);
void print(const Block *block, bool printBlockArgs = true);
void printOperation(const OperationInst *op);
@@ -1117,9 +1125,6 @@ public:
unsigned index) override;
/// Print a block list.
- void printBlockList(const BlockList &blocks) override {
- printBlockList(blocks, /*printEntryBlockArgs=*/true);
- }
void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) {
os << " {\n";
if (!blocks.empty()) {
@@ -1209,6 +1214,12 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) {
// Recursively number the stuff in the body.
numberValuesInBlock(*cast<ForInst>(&inst)->getBody());
break;
+ case Instruction::Kind::If: {
+ auto *ifInst = cast<IfInst>(&inst);
+ numberValuesInBlock(*ifInst->getThen());
+ if (auto *elseBlock = ifInst->getElse())
+ numberValuesInBlock(*elseBlock);
+ }
}
}
}
@@ -1349,7 +1360,8 @@ void FunctionPrinter::printFunctionSignature() {
}
void FunctionPrinter::print(const Block *block, bool printBlockArgs) {
- // Print the block label and argument list if requested.
+ // Print the block label and argument list, unless this is the first block of
+ // the function, or the first block of an IfInst/ForInst with no arguments.
if (printBlockArgs) {
os.indent(currentIndent);
printBlockName(block);
@@ -1406,6 +1418,8 @@ void FunctionPrinter::print(const Instruction *inst) {
return print(cast<OperationInst>(inst));
case Instruction::Kind::For:
return print(cast<ForInst>(inst));
+ case Instruction::Kind::If:
+ return print(cast<IfInst>(inst));
}
}
@@ -1433,6 +1447,22 @@ void FunctionPrinter::print(const ForInst *inst) {
os.indent(currentIndent) << "}";
}
+void FunctionPrinter::print(const IfInst *inst) {
+ os.indent(currentIndent) << "if ";
+ IntegerSet set = inst->getIntegerSet();
+ printIntegerSetReference(set);
+ printDimAndSymbolList(inst->getInstOperands(), set.getNumDims());
+ printTrailingLocation(inst->getLoc());
+ os << " {\n";
+ print(inst->getThen(), /*printBlockArgs=*/false);
+ os.indent(currentIndent) << "}";
+ if (inst->hasElse()) {
+ os << " else {\n";
+ print(inst->getElse(), /*printBlockArgs=*/false);
+ os.indent(currentIndent) << "}";
+ }
+}
+
void FunctionPrinter::printValueID(const Value *value,
bool printResultNo) const {
int resultNo = -1;
OpenPOWER on IntegriCloud