diff options
Diffstat (limited to 'mlir/lib/IR/AsmPrinter.cpp')
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 136 |
1 files changed, 68 insertions, 68 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index daaaee7010c..cf822e025b8 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -25,11 +25,11 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Function.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/Instructions.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Statements.h" -#include "mlir/IR/StmtVisitor.h" #include "mlir/IR/Types.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/APFloat.h" @@ -117,10 +117,10 @@ private: void visitExtFunction(const Function *fn); void visitCFGFunction(const Function *fn); void visitMLFunction(const Function *fn); - void visitStatement(const Statement *stmt); - void visitForStmt(const ForStmt *forStmt); - void visitIfStmt(const IfStmt *ifStmt); - void visitOperationInst(const OperationInst *opStmt); + void visitInstruction(const Instruction *inst); + void visitForInst(const ForInst *forInst); + void visitIfInst(const IfInst *ifInst); + void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); void visitOperation(const OperationInst *op); @@ -184,47 +184,47 @@ void ModuleState::visitCFGFunction(const Function *fn) { if (auto *opInst = dyn_cast<OperationInst>(&op)) visitOperation(opInst); else { - llvm_unreachable("IfStmt/ForStmt in a CFG Function isn't supported"); + llvm_unreachable("IfInst/ForInst in a CFG Function isn't supported"); } } } } -void ModuleState::visitIfStmt(const IfStmt *ifStmt) { - recordIntegerSetReference(ifStmt->getIntegerSet()); - for (auto &childStmt : *ifStmt->getThen()) - visitStatement(&childStmt); - if (ifStmt->hasElse()) - for (auto &childStmt : *ifStmt->getElse()) - visitStatement(&childStmt); +void ModuleState::visitIfInst(const IfInst *ifInst) { + recordIntegerSetReference(ifInst->getIntegerSet()); + for (auto &childInst : *ifInst->getThen()) + visitInstruction(&childInst); + if (ifInst->hasElse()) + for (auto &childInst : *ifInst->getElse()) + visitInstruction(&childInst); } -void ModuleState::visitForStmt(const ForStmt *forStmt) { - AffineMap lbMap = forStmt->getLowerBoundMap(); +void ModuleState::visitForInst(const ForInst *forInst) { + AffineMap lbMap = forInst->getLowerBoundMap(); if (!hasShorthandForm(lbMap)) recordAffineMapReference(lbMap); - AffineMap ubMap = forStmt->getUpperBoundMap(); + AffineMap ubMap = forInst->getUpperBoundMap(); if (!hasShorthandForm(ubMap)) recordAffineMapReference(ubMap); - for (auto &childStmt : *forStmt->getBody()) - visitStatement(&childStmt); + for (auto &childInst : *forInst->getBody()) + visitInstruction(&childInst); } -void ModuleState::visitOperationInst(const OperationInst *opStmt) { - for (auto attr : opStmt->getAttrs()) +void ModuleState::visitOperationInst(const OperationInst *opInst) { + for (auto attr : opInst->getAttrs()) visitAttribute(attr.second); } -void ModuleState::visitStatement(const Statement *stmt) { - switch (stmt->getKind()) { - case Statement::Kind::If: - return visitIfStmt(cast<IfStmt>(stmt)); - case Statement::Kind::For: - return visitForStmt(cast<ForStmt>(stmt)); - case Statement::Kind::OperationInst: - return visitOperationInst(cast<OperationInst>(stmt)); +void ModuleState::visitInstruction(const Instruction *inst) { + switch (inst->getKind()) { + case Instruction::Kind::If: + return visitIfInst(cast<IfInst>(inst)); + case Instruction::Kind::For: + return visitForInst(cast<ForInst>(inst)); + case Instruction::Kind::OperationInst: + return visitOperationInst(cast<OperationInst>(inst)); default: return; } @@ -232,8 +232,8 @@ void ModuleState::visitStatement(const Statement *stmt) { void ModuleState::visitMLFunction(const Function *fn) { visitType(fn->getType()); - for (auto &stmt : *fn->getBody()) { - ModuleState::visitStatement(&stmt); + for (auto &inst : *fn->getBody()) { + ModuleState::visitInstruction(&inst); } } @@ -909,11 +909,11 @@ public: void printMLFunctionSignature(); void printOtherFunctionSignature(); - // Methods to print statements. - void print(const Statement *stmt); + // Methods to print instructions. + void print(const Instruction *inst); void print(const OperationInst *inst); - void print(const ForStmt *stmt); - void print(const IfStmt *stmt); + void print(const ForInst *inst); + void print(const IfInst *inst); void print(const Block *block); void printOperation(const OperationInst *op); @@ -959,7 +959,7 @@ public: void printDimAndSymbolList(ArrayRef<InstOperand> ops, unsigned numDims); void printBound(AffineBound bound, const char *prefix); - // Number of spaces used for indenting nested statements. + // Number of spaces used for indenting nested instructions. const static unsigned indentWidth = 2; protected: @@ -1019,22 +1019,22 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { // We number instruction that have results, and we only number the first // result. switch (inst.getKind()) { - case Statement::Kind::OperationInst: { + case Instruction::Kind::OperationInst: { auto *opInst = cast<OperationInst>(&inst); if (opInst->getNumResults() != 0) numberValueID(opInst->getResult(0)); break; } - case Statement::Kind::For: { - auto *forInst = cast<ForStmt>(&inst); + case Instruction::Kind::For: { + auto *forInst = cast<ForInst>(&inst); // Number the induction variable. numberValueID(forInst); // Recursively number the stuff in the body. numberValuesInBlock(*forInst->getBody()); break; } - case Statement::Kind::If: { - auto *ifInst = cast<IfStmt>(&inst); + case Instruction::Kind::If: { + auto *ifInst = cast<IfInst>(&inst); numberValuesInBlock(*ifInst->getThen()); if (auto *elseBlock = ifInst->getElse()) numberValuesInBlock(*elseBlock); @@ -1086,7 +1086,7 @@ void FunctionPrinter::numberValueID(const Value *value) { // done with it. valueIDs[value] = nextValueID++; return; - case Value::Kind::ForStmt: + case Value::Kind::ForInst: specialName << 'i' << nextLoopID++; break; } @@ -1220,21 +1220,21 @@ void FunctionPrinter::print(const Block *block) { currentIndent += indentWidth; - for (auto &stmt : block->getInstructions()) { - print(&stmt); + for (auto &inst : block->getInstructions()) { + print(&inst); os << '\n'; } currentIndent -= indentWidth; } -void FunctionPrinter::print(const Statement *stmt) { - switch (stmt->getKind()) { - case Statement::Kind::OperationInst: - return print(cast<OperationInst>(stmt)); - case Statement::Kind::For: - return print(cast<ForStmt>(stmt)); - case Statement::Kind::If: - return print(cast<IfStmt>(stmt)); +void FunctionPrinter::print(const Instruction *inst) { + switch (inst->getKind()) { + case Instruction::Kind::OperationInst: + return print(cast<OperationInst>(inst)); + case Instruction::Kind::For: + return print(cast<ForInst>(inst)); + case Instruction::Kind::If: + return print(cast<IfInst>(inst)); } } @@ -1243,33 +1243,33 @@ void FunctionPrinter::print(const OperationInst *inst) { printOperation(inst); } -void FunctionPrinter::print(const ForStmt *stmt) { +void FunctionPrinter::print(const ForInst *inst) { os.indent(currentIndent) << "for "; - printOperand(stmt); + printOperand(inst); os << " = "; - printBound(stmt->getLowerBound(), "max"); + printBound(inst->getLowerBound(), "max"); os << " to "; - printBound(stmt->getUpperBound(), "min"); + printBound(inst->getUpperBound(), "min"); - if (stmt->getStep() != 1) - os << " step " << stmt->getStep(); + if (inst->getStep() != 1) + os << " step " << inst->getStep(); os << " {\n"; - print(stmt->getBody()); + print(inst->getBody()); os.indent(currentIndent) << "}"; } -void FunctionPrinter::print(const IfStmt *stmt) { +void FunctionPrinter::print(const IfInst *inst) { os.indent(currentIndent) << "if "; - IntegerSet set = stmt->getIntegerSet(); + IntegerSet set = inst->getIntegerSet(); printIntegerSetReference(set); - printDimAndSymbolList(stmt->getInstOperands(), set.getNumDims()); + printDimAndSymbolList(inst->getInstOperands(), set.getNumDims()); os << " {\n"; - print(stmt->getThen()); + print(inst->getThen()); os.indent(currentIndent) << "}"; - if (stmt->hasElse()) { + if (inst->hasElse()) { os << " else {\n"; - print(stmt->getElse()); + print(inst->getElse()); os.indent(currentIndent) << "}"; } } @@ -1280,7 +1280,7 @@ void FunctionPrinter::printValueID(const Value *value, auto lookupValue = value; // If this is a reference to the result of a multi-result instruction or - // statement, print out the # identifier and make sure to map our lookup + // instruction, print out the # identifier and make sure to map our lookup // to the first result of the instruction. if (auto *result = dyn_cast<InstResult>(value)) { if (result->getOwner()->getNumResults() != 1) { @@ -1493,8 +1493,8 @@ void Value::print(raw_ostream &os) const { return; case Value::Kind::InstResult: return getDefiningInst()->print(os); - case Value::Kind::ForStmt: - return cast<ForStmt>(this)->print(os); + case Value::Kind::ForInst: + return cast<ForInst>(this)->print(os); } } |

