diff options
Diffstat (limited to 'mlir/lib/Transforms/ConvertToCFG.cpp')
| -rw-r--r-- | mlir/lib/Transforms/ConvertToCFG.cpp | 122 |
1 files changed, 61 insertions, 61 deletions
diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index 821f35ca539..abce624b06f 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -21,9 +21,9 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" -#include "mlir/IR/StmtVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Support/Functional.h" @@ -39,14 +39,14 @@ using namespace mlir; namespace { // Generates CFG function equivalent to the given ML function. -class FunctionConverter : public StmtVisitor<FunctionConverter> { +class FunctionConverter : public InstVisitor<FunctionConverter> { public: FunctionConverter(Function *cfgFunc) : cfgFunc(cfgFunc), builder(cfgFunc) {} Function *convert(Function *mlFunc); - void visitForStmt(ForStmt *forStmt); - void visitIfStmt(IfStmt *ifStmt); - void visitOperationInst(OperationInst *opStmt); + void visitForInst(ForInst *forInst); + void visitIfInst(IfInst *ifInst); + void visitOperationInst(OperationInst *opInst); private: Value *getConstantIndexValue(int64_t value); @@ -64,49 +64,49 @@ private: } // end anonymous namespace // Return a vector of OperationInst's arguments as Values. For each -// statement operands, represented as Value, lookup its Value conterpart in +// instruction operands, represented as Value, lookup its Value conterpart in // the valueRemapping table. static llvm::SmallVector<mlir::Value *, 4> -operandsAs(Statement *opStmt, +operandsAs(Instruction *opInst, const llvm::DenseMap<const Value *, Value *> &valueRemapping) { llvm::SmallVector<Value *, 4> operands; - for (const Value *operand : opStmt->getOperands()) { + for (const Value *operand : opInst->getOperands()) { assert(valueRemapping.count(operand) != 0 && "operand is not defined"); operands.push_back(valueRemapping.lookup(operand)); } return operands; } -// Convert an operation statement into an operation instruction. +// Convert an operation instruction into an operation instruction. // // The operation description (name, number and types of operands or results) // remains the same but the values must be updated to be Values. Update the // mapping Value->Value as the conversion is performed. The operation // instruction is appended to current block (end of SESE region). -void FunctionConverter::visitOperationInst(OperationInst *opStmt) { +void FunctionConverter::visitOperationInst(OperationInst *opInst) { // Set up basic operation state (context, name, operands). - OperationState state(cfgFunc->getContext(), opStmt->getLoc(), - opStmt->getName()); - state.addOperands(operandsAs(opStmt, valueRemapping)); + OperationState state(cfgFunc->getContext(), opInst->getLoc(), + opInst->getName()); + state.addOperands(operandsAs(opInst, valueRemapping)); // Set up operation return types. The corresponding Values will become // available after the operation is created. state.addTypes(functional::map( - [](Value *result) { return result->getType(); }, opStmt->getResults())); + [](Value *result) { return result->getType(); }, opInst->getResults())); // Copy attributes. - for (auto attr : opStmt->getAttrs()) { + for (auto attr : opInst->getAttrs()) { state.addAttribute(attr.first.strref(), attr.second); } - auto opInst = builder.createOperation(state); + auto op = builder.createOperation(state); // Make results of the operation accessible to the following operations // through remapping. - assert(opInst->getNumResults() == opStmt->getNumResults()); + assert(opInst->getNumResults() == op->getNumResults()); for (unsigned i = 0, n = opInst->getNumResults(); i < n; ++i) { valueRemapping.insert( - std::make_pair(opStmt->getResult(i), opInst->getResult(i))); + std::make_pair(opInst->getResult(i), op->getResult(i))); } } @@ -116,10 +116,10 @@ Value *FunctionConverter::getConstantIndexValue(int64_t value) { return op->getResult(); } -// Visit all statements in the given statement block. +// Visit all instructions in the given instruction block. void FunctionConverter::visitBlock(Block *Block) { - for (auto &stmt : *Block) - this->visit(&stmt); + for (auto &inst : *Block) + this->visit(&inst); } // Given a range of values, emit the code that reduces them with "min" or "max" @@ -211,7 +211,7 @@ Value *FunctionConverter::buildMinMaxReductionSeq( // | <new insertion point> | // +--------------------------------+ // -void FunctionConverter::visitForStmt(ForStmt *forStmt) { +void FunctionConverter::visitForInst(ForInst *forInst) { // First, store the loop insertion location so that we can go back to it after // creating the new blocks (block creation updates the insertion point). Block *loopInsertionPoint = builder.getInsertionBlock(); @@ -228,27 +228,27 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { // The loop condition block has an argument for loop induction variable. // Create it upfront and make the loop induction variable -> basic block - // argument remapping available to the following instructions. ForStatement + // argument remapping available to the following instructions. ForInstruction // is-a Value corresponding to the loop induction variable. builder.setInsertionPointToEnd(loopConditionBlock); Value *iv = loopConditionBlock->addArgument(builder.getIndexType()); - valueRemapping.insert(std::make_pair(forStmt, iv)); + valueRemapping.insert(std::make_pair(forInst, iv)); // Recursively construct loop body region. // Walking manually because we need custom logic before and after traversing // the list of children. builder.setInsertionPointToEnd(loopBodyFirstBlock); - visitBlock(forStmt->getBody()); + visitBlock(forInst->getBody()); // Builder point is currently at the last block of the loop body. Append the // induction variable stepping to this block and branch back to the exit // condition block. Construct an affine map f : (x -> x+step) and apply this // map to the induction variable. - auto affStep = builder.getAffineConstantExpr(forStmt->getStep()); + auto affStep = builder.getAffineConstantExpr(forInst->getStep()); auto affDim = builder.getAffineDimExpr(0); auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {}); auto stepOp = - builder.create<AffineApplyOp>(forStmt->getLoc(), affStepMap, iv); + builder.create<AffineApplyOp>(forInst->getLoc(), affStepMap, iv); Value *nextIvValue = stepOp->getResult(0); builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock, nextIvValue); @@ -262,22 +262,22 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { return valueRemapping.lookup(value); }; auto operands = - functional::map(remapOperands, forStmt->getLowerBoundOperands()); + functional::map(remapOperands, forInst->getLowerBoundOperands()); auto lbAffineApply = builder.create<AffineApplyOp>( - forStmt->getLoc(), forStmt->getLowerBoundMap(), operands); + forInst->getLoc(), forInst->getLowerBoundMap(), operands); Value *lowerBound = buildMinMaxReductionSeq( - forStmt->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults()); - operands = functional::map(remapOperands, forStmt->getUpperBoundOperands()); + forInst->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults()); + operands = functional::map(remapOperands, forInst->getUpperBoundOperands()); auto ubAffineApply = builder.create<AffineApplyOp>( - forStmt->getLoc(), forStmt->getUpperBoundMap(), operands); + forInst->getLoc(), forInst->getUpperBoundMap(), operands); Value *upperBound = buildMinMaxReductionSeq( - forStmt->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults()); + forInst->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults()); builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock, lowerBound); builder.setInsertionPointToEnd(loopConditionBlock); auto comparisonOp = builder.create<CmpIOp>( - forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound); + forInst->getLoc(), CmpIPredicate::SLT, iv, upperBound); auto comparisonResult = comparisonOp->getResult(); builder.create<CondBranchOp>(builder.getUnknownLoc(), comparisonResult, loopBodyFirstBlock, ArrayRef<Value *>(), @@ -288,16 +288,16 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { builder.setInsertionPointToEnd(postLoopBlock); } -// Convert an "if" statement into a flow of basic blocks. +// Convert an "if" instruction into a flow of basic blocks. // -// Create an SESE region for the if statement (including its "then" and optional -// "else" statement blocks) and append it to the end of the current region. The -// conditional region consists of a sequence of condition-checking blocks that -// implement the short-circuit scheme, followed by a "then" SESE region and an -// "else" SESE region, and the continuation block that post-dominates all blocks -// of the "if" statement. The flow of blocks that correspond to the "then" and -// "else" clauses are constructed recursively, enabling easy nesting of "if" -// statements and if-then-else-if chains. +// Create an SESE region for the if instruction (including its "then" and +// optional "else" instruction blocks) and append it to the end of the current +// region. The conditional region consists of a sequence of condition-checking +// blocks that implement the short-circuit scheme, followed by a "then" SESE +// region and an "else" SESE region, and the continuation block that +// post-dominates all blocks of the "if" instruction. The flow of blocks that +// correspond to the "then" and "else" clauses are constructed recursively, +// enabling easy nesting of "if" instructions and if-then-else-if chains. // // +--------------------------------+ // | <end of current SESE region> | @@ -365,17 +365,17 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { // | <new insertion point> | // +--------------------------------+ // -void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { - assert(ifStmt != nullptr); +void FunctionConverter::visitIfInst(IfInst *ifInst) { + assert(ifInst != nullptr); - auto integerSet = ifStmt->getCondition().getIntegerSet(); + auto integerSet = ifInst->getCondition().getIntegerSet(); // Create basic blocks for the 'then' block and for the 'else' block. // Although 'else' block may be empty in absence of an 'else' clause, create // it anyway for the sake of consistency and output IR readability. Also // create extra blocks for condition checking to prepare for short-circuit - // logic: conditions in the 'if' statement are conjunctive, so we can jump to - // the false branch as soon as one condition fails. `cond_br` requires + // logic: conditions in the 'if' instruction are conjunctive, so we can jump + // to the false branch as soon as one condition fails. `cond_br` requires // another block as a target when the condition is true, and that block will // contain the next condition. Block *ifInsertionBlock = builder.getInsertionBlock(); @@ -412,14 +412,14 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { builder.getAffineMap(integerSet.getNumDims(), integerSet.getNumSymbols(), constraintExpr, {}); auto affineApplyOp = builder.create<AffineApplyOp>( - ifStmt->getLoc(), affineMap, operandsAs(ifStmt, valueRemapping)); + ifInst->getLoc(), affineMap, operandsAs(ifInst, valueRemapping)); Value *affResult = affineApplyOp->getResult(0); // Compare the result of the apply and branch. auto comparisonOp = builder.create<CmpIOp>( - ifStmt->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE, + ifInst->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE, affResult, zeroConstant); - builder.create<CondBranchOp>(ifStmt->getLoc(), comparisonOp->getResult(), + builder.create<CondBranchOp>(ifInst->getLoc(), comparisonOp->getResult(), nextBlock, /*trueArgs*/ ArrayRef<Value *>(), elseBlock, /*falseArgs*/ ArrayRef<Value *>()); @@ -429,13 +429,13 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // Recursively traverse the 'then' block. builder.setInsertionPointToEnd(thenBlock); - visitBlock(ifStmt->getThen()); + visitBlock(ifInst->getThen()); Block *lastThenBlock = builder.getInsertionBlock(); // Recursively traverse the 'else' block if present. builder.setInsertionPointToEnd(elseBlock); - if (ifStmt->hasElse()) - visitBlock(ifStmt->getElse()); + if (ifInst->hasElse()) + visitBlock(ifInst->getElse()); Block *lastElseBlock = builder.getInsertionBlock(); // Create the continuation block here so that it appears lexically after the @@ -443,9 +443,9 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // to the continuation block. Block *continuationBlock = builder.createBlock(); builder.setInsertionPointToEnd(lastThenBlock); - builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock); + builder.create<BranchOp>(ifInst->getLoc(), continuationBlock); builder.setInsertionPointToEnd(lastElseBlock); - builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock); + builder.create<BranchOp>(ifInst->getLoc(), continuationBlock); // Make sure building can continue by setting up the continuation block as the // insertion point. @@ -454,12 +454,12 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // Entry point of the function convertor. // -// Conversion is performed by recursively visiting statements of a Function. +// Conversion is performed by recursively visiting instructions of a Function. // It reasons in terms of single-entry single-exit (SESE) regions that are not // materialized in the code. Instead, the pointer to the last block of the // region is maintained throughout the conversion as the insertion point of the // IR builder since we never change the first block after its creation. "Block" -// statements such as loops and branches create new SESE regions for their +// instructions such as loops and branches create new SESE regions for their // bodies, and surround them with additional basic blocks for the control flow. // Individual operations are simply appended to the end of the last basic block // of the current region. The SESE invariant allows us to easily handle nested @@ -484,9 +484,9 @@ Function *FunctionConverter::convert(Function *mlFunc) { valueRemapping.insert(std::make_pair(mlArgument, cfgArgument)); } - // Convert statements in order. - for (auto &stmt : *mlFunc->getBody()) { - visit(&stmt); + // Convert instructions in order. + for (auto &inst : *mlFunc->getBody()) { + visit(&inst); } return cfgFunc; |

