diff options
Diffstat (limited to 'mlir/lib/Transforms/LowerAffine.cpp')
| -rw-r--r-- | mlir/lib/Transforms/LowerAffine.cpp | 59 |
1 files changed, 39 insertions, 20 deletions
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index ab37ff63bad..f770684f519 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -246,7 +247,7 @@ public: PassResult runOnFunction(Function *function) override; bool lowerForInst(ForInst *forInst); - bool lowerIfInst(IfInst *ifInst); + bool lowerAffineIf(AffineIfOp *ifOp); bool lowerAffineApply(AffineApplyOp *op); static char passID; @@ -409,7 +410,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // enabling easy nesting of "if" instructions and if-then-else-if chains. // // +--------------------------------+ -// | <code before the IfInst> | +// | <code before the AffineIfOp> | // | %zero = constant 0 : index | // | %v = affine_apply #expr1(%ops) | // | %c = cmpi "sge" %v, %zero | @@ -453,10 +454,11 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // v v // +--------------------------------+ // | continue: | -// | <code after the IfInst> | +// | <code after the AffineIfOp> | // +--------------------------------+ // -bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { +bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { + auto *ifInst = ifOp->getInstruction(); auto loc = ifInst->getLoc(); // Start by splitting the block containing the 'if' into two parts. The part @@ -466,22 +468,38 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { auto *continueBlock = condBlock->splitBlock(ifInst); // Create a block for the 'then' code, inserting it between the cond and - // continue blocks. Move the instructions over from the IfInst and add a + // continue blocks. Move the instructions over from the AffineIfOp and add a // branch to the continuation point. Block *thenBlock = new Block(); thenBlock->insertBefore(continueBlock); - auto *oldThen = ifInst->getThen(); - thenBlock->getInstructions().splice(thenBlock->begin(), - oldThen->getInstructions(), - oldThen->begin(), oldThen->end()); + // If the 'then' block is not empty, then splice the instructions. + auto &oldThenBlocks = ifOp->getThenBlocks(); + if (!oldThenBlocks.empty()) { + // We currently only handle one 'then' block. + if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end()) + return true; + + Block *oldThen = &oldThenBlocks.front(); + + thenBlock->getInstructions().splice(thenBlock->begin(), + oldThen->getInstructions(), + oldThen->begin(), oldThen->end()); + } + FuncBuilder builder(thenBlock); builder.create<BranchOp>(loc, continueBlock); // Handle the 'else' block the same way, but we skip it if we have no else // code. Block *elseBlock = continueBlock; - if (auto *oldElse = ifInst->getElse()) { + auto &oldElseBlocks = ifOp->getElseBlocks(); + if (!oldElseBlocks.empty()) { + // We currently only handle one 'else' block. + if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end()) + return true; + + auto *oldElse = &oldElseBlocks.front(); elseBlock = new Block(); elseBlock->insertBefore(continueBlock); @@ -493,7 +511,7 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { } // Ok, now we just have to handle the condition logic. - auto integerSet = ifInst->getCondition().getIntegerSet(); + auto integerSet = ifOp->getIntegerSet(); // Implement short-circuit logic. For each affine expression in the 'if' // condition, convert it into an affine map and call `affine_apply` to obtain @@ -593,29 +611,30 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) { PassResult LowerAffinePass::runOnFunction(Function *function) { SmallVector<Instruction *, 8> instsToRewrite; - // Collect all the If and For instructions as well as AffineApplyOps. We do - // this as a prepass to avoid invalidating the walker with our rewrite. + // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. + // We do this as a prepass to avoid invalidating the walker with our rewrite. function->walkInsts([&](Instruction *inst) { - if (isa<IfInst>(inst) || isa<ForInst>(inst)) + if (isa<ForInst>(inst)) instsToRewrite.push_back(inst); auto op = dyn_cast<OperationInst>(inst); - if (op && op->isa<AffineApplyOp>()) + if (op && (op->isa<AffineApplyOp>() || op->isa<AffineIfOp>())) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. for (auto *inst : instsToRewrite) - if (auto *ifInst = dyn_cast<IfInst>(inst)) { - if (lowerIfInst(ifInst)) - return failure(); - } else if (auto *forInst = dyn_cast<ForInst>(inst)) { + if (auto *forInst = dyn_cast<ForInst>(inst)) { if (lowerForInst(forInst)) return failure(); } else { auto op = cast<OperationInst>(inst); - if (lowerAffineApply(op->cast<AffineApplyOp>())) + if (auto ifOp = op->dyn_cast<AffineIfOp>()) { + if (lowerAffineIf(ifOp)) + return failure(); + } else if (lowerAffineApply(op->cast<AffineApplyOp>())) { return failure(); + } } return success(); |

