diff options
Diffstat (limited to 'mlir/lib/Transforms/LowerAffine.cpp')
| -rw-r--r-- | mlir/lib/Transforms/LowerAffine.cpp | 68 |
1 files changed, 34 insertions, 34 deletions
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index f770684f519..24ca4e95082 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" @@ -246,7 +247,7 @@ public: LowerAffinePass() : FunctionPass(&passID) {} PassResult runOnFunction(Function *function) override; - bool lowerForInst(ForInst *forInst); + bool lowerAffineFor(OpPointer<AffineForOp> forOp); bool lowerAffineIf(AffineIfOp *ifOp); bool lowerAffineApply(AffineApplyOp *op); @@ -295,11 +296,11 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, // a nested loop). Induction variable modification is appended to the body SESE // region that always loops back to the condition block. // -// +--------------------------------+ -// | <code before the ForInst> | -// | <compute initial %iv value> | -// | br cond(%iv) | -// +--------------------------------+ +// +---------------------------------+ +// | <code before the AffineForOp> | +// | <compute initial %iv value> | +// | br cond(%iv) | +// +---------------------------------+ // | // -------| | // | v v @@ -322,11 +323,12 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, // v // +--------------------------------+ // | end: | -// | <code after the ForInst> | +// | <code after the AffineForOp> | // +--------------------------------+ // -bool LowerAffinePass::lowerForInst(ForInst *forInst) { - auto loc = forInst->getLoc(); +bool LowerAffinePass::lowerAffineFor(OpPointer<AffineForOp> forOp) { + auto loc = forOp->getLoc(); + auto *forInst = forOp->getInstruction(); // Start by splitting the block containing the 'for' into two parts. The part // before will get the init code, the part after will be the end point. @@ -339,23 +341,23 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { conditionBlock->insertBefore(endBlock); auto *iv = conditionBlock->addArgument(IndexType::get(forInst->getContext())); - // Create the body block, moving the body of the forInst over to it. + // Create the body block, moving the body of the forOp over to it. auto *bodyBlock = new Block(); bodyBlock->insertBefore(endBlock); - auto *oldBody = forInst->getBody(); + auto *oldBody = forOp->getBody(); bodyBlock->getInstructions().splice(bodyBlock->begin(), oldBody->getInstructions(), oldBody->begin(), oldBody->end()); - // The code in the body of the forInst now uses 'iv' as its indvar. - forInst->getInductionVar()->replaceAllUsesWith(iv); + // The code in the body of the forOp now uses 'iv' as its indvar. + forOp->getInductionVar()->replaceAllUsesWith(iv); // Append the induction variable stepping logic and branch back to the exit // condition block. Construct an affine expression f : (x -> x+step) and // apply this expression to the induction variable. FuncBuilder builder(bodyBlock); - auto affStep = builder.getAffineConstantExpr(forInst->getStep()); + auto affStep = builder.getAffineConstantExpr(forOp->getStep()); auto affDim = builder.getAffineDimExpr(0); auto stepped = expandAffineExpr(&builder, loc, affDim + affStep, iv, {}); if (!stepped) @@ -368,18 +370,18 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { builder.setInsertionPointToEnd(initBlock); // Compute loop bounds. - SmallVector<Value *, 8> operands(forInst->getLowerBoundOperands()); + SmallVector<Value *, 8> operands(forOp->getLowerBoundOperands()); auto lbValues = expandAffineMap(&builder, forInst->getLoc(), - forInst->getLowerBoundMap(), operands); + forOp->getLowerBoundMap(), operands); if (!lbValues) return true; Value *lowerBound = buildMinMaxReductionSeq(loc, CmpIPredicate::SGT, *lbValues, builder); - operands.assign(forInst->getUpperBoundOperands().begin(), - forInst->getUpperBoundOperands().end()); + operands.assign(forOp->getUpperBoundOperands().begin(), + forOp->getUpperBoundOperands().end()); auto ubValues = expandAffineMap(&builder, forInst->getLoc(), - forInst->getUpperBoundMap(), operands); + forOp->getUpperBoundMap(), operands); if (!ubValues) return true; Value *upperBound = @@ -394,7 +396,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { endBlock, ArrayRef<Value *>()); // Ok, we're done! - forInst->erase(); + forOp->erase(); return false; } @@ -614,28 +616,26 @@ PassResult LowerAffinePass::runOnFunction(Function *function) { // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. // We do this as a prepass to avoid invalidating the walker with our rewrite. function->walkInsts([&](Instruction *inst) { - if (isa<ForInst>(inst)) - instsToRewrite.push_back(inst); - auto op = dyn_cast<OperationInst>(inst); - if (op && (op->isa<AffineApplyOp>() || op->isa<AffineIfOp>())) + auto op = cast<OperationInst>(inst); + if (op->isa<AffineApplyOp>() || op->isa<AffineForOp>() || + op->isa<AffineIfOp>()) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. - for (auto *inst : instsToRewrite) - if (auto *forInst = dyn_cast<ForInst>(inst)) { - if (lowerForInst(forInst)) + for (auto *inst : instsToRewrite) { + auto op = cast<OperationInst>(inst); + if (auto ifOp = op->dyn_cast<AffineIfOp>()) { + if (lowerAffineIf(ifOp)) return failure(); - } else { - auto op = cast<OperationInst>(inst); - if (auto ifOp = op->dyn_cast<AffineIfOp>()) { - if (lowerAffineIf(ifOp)) - return failure(); - } else if (lowerAffineApply(op->cast<AffineApplyOp>())) { + } else if (auto forOp = op->dyn_cast<AffineForOp>()) { + if (lowerAffineFor(forOp)) return failure(); - } + } else if (lowerAffineApply(op->cast<AffineApplyOp>())) { + return failure(); } + } return success(); } |

