summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/LowerAffine.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/LowerAffine.cpp')
-rw-r--r--mlir/lib/Transforms/LowerAffine.cpp68
1 files changed, 34 insertions, 34 deletions
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp
index f770684f519..24ca4e95082 100644
--- a/mlir/lib/Transforms/LowerAffine.cpp
+++ b/mlir/lib/Transforms/LowerAffine.cpp
@@ -24,6 +24,7 @@
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
@@ -246,7 +247,7 @@ public:
LowerAffinePass() : FunctionPass(&passID) {}
PassResult runOnFunction(Function *function) override;
- bool lowerForInst(ForInst *forInst);
+ bool lowerAffineFor(OpPointer<AffineForOp> forOp);
bool lowerAffineIf(AffineIfOp *ifOp);
bool lowerAffineApply(AffineApplyOp *op);
@@ -295,11 +296,11 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
// a nested loop). Induction variable modification is appended to the body SESE
// region that always loops back to the condition block.
//
-// +--------------------------------+
-// | <code before the ForInst> |
-// | <compute initial %iv value> |
-// | br cond(%iv) |
-// +--------------------------------+
+// +---------------------------------+
+// | <code before the AffineForOp> |
+// | <compute initial %iv value> |
+// | br cond(%iv) |
+// +---------------------------------+
// |
// -------| |
// | v v
@@ -322,11 +323,12 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
// v
// +--------------------------------+
// | end: |
-// | <code after the ForInst> |
+// | <code after the AffineForOp> |
// +--------------------------------+
//
-bool LowerAffinePass::lowerForInst(ForInst *forInst) {
- auto loc = forInst->getLoc();
+bool LowerAffinePass::lowerAffineFor(OpPointer<AffineForOp> forOp) {
+ auto loc = forOp->getLoc();
+ auto *forInst = forOp->getInstruction();
// Start by splitting the block containing the 'for' into two parts. The part
// before will get the init code, the part after will be the end point.
@@ -339,23 +341,23 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
conditionBlock->insertBefore(endBlock);
auto *iv = conditionBlock->addArgument(IndexType::get(forInst->getContext()));
- // Create the body block, moving the body of the forInst over to it.
+ // Create the body block, moving the body of the forOp over to it.
auto *bodyBlock = new Block();
bodyBlock->insertBefore(endBlock);
- auto *oldBody = forInst->getBody();
+ auto *oldBody = forOp->getBody();
bodyBlock->getInstructions().splice(bodyBlock->begin(),
oldBody->getInstructions(),
oldBody->begin(), oldBody->end());
- // The code in the body of the forInst now uses 'iv' as its indvar.
- forInst->getInductionVar()->replaceAllUsesWith(iv);
+ // The code in the body of the forOp now uses 'iv' as its indvar.
+ forOp->getInductionVar()->replaceAllUsesWith(iv);
// Append the induction variable stepping logic and branch back to the exit
// condition block. Construct an affine expression f : (x -> x+step) and
// apply this expression to the induction variable.
FuncBuilder builder(bodyBlock);
- auto affStep = builder.getAffineConstantExpr(forInst->getStep());
+ auto affStep = builder.getAffineConstantExpr(forOp->getStep());
auto affDim = builder.getAffineDimExpr(0);
auto stepped = expandAffineExpr(&builder, loc, affDim + affStep, iv, {});
if (!stepped)
@@ -368,18 +370,18 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
builder.setInsertionPointToEnd(initBlock);
// Compute loop bounds.
- SmallVector<Value *, 8> operands(forInst->getLowerBoundOperands());
+ SmallVector<Value *, 8> operands(forOp->getLowerBoundOperands());
auto lbValues = expandAffineMap(&builder, forInst->getLoc(),
- forInst->getLowerBoundMap(), operands);
+ forOp->getLowerBoundMap(), operands);
if (!lbValues)
return true;
Value *lowerBound =
buildMinMaxReductionSeq(loc, CmpIPredicate::SGT, *lbValues, builder);
- operands.assign(forInst->getUpperBoundOperands().begin(),
- forInst->getUpperBoundOperands().end());
+ operands.assign(forOp->getUpperBoundOperands().begin(),
+ forOp->getUpperBoundOperands().end());
auto ubValues = expandAffineMap(&builder, forInst->getLoc(),
- forInst->getUpperBoundMap(), operands);
+ forOp->getUpperBoundMap(), operands);
if (!ubValues)
return true;
Value *upperBound =
@@ -394,7 +396,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
endBlock, ArrayRef<Value *>());
// Ok, we're done!
- forInst->erase();
+ forOp->erase();
return false;
}
@@ -614,28 +616,26 @@ PassResult LowerAffinePass::runOnFunction(Function *function) {
// Collect all the For instructions as well as AffineIfOps and AffineApplyOps.
// We do this as a prepass to avoid invalidating the walker with our rewrite.
function->walkInsts([&](Instruction *inst) {
- if (isa<ForInst>(inst))
- instsToRewrite.push_back(inst);
- auto op = dyn_cast<OperationInst>(inst);
- if (op && (op->isa<AffineApplyOp>() || op->isa<AffineIfOp>()))
+ auto op = cast<OperationInst>(inst);
+ if (op->isa<AffineApplyOp>() || op->isa<AffineForOp>() ||
+ op->isa<AffineIfOp>())
instsToRewrite.push_back(inst);
});
// Rewrite all of the ifs and fors. We walked the instructions in preorder,
// so we know that we will rewrite them in the same order.
- for (auto *inst : instsToRewrite)
- if (auto *forInst = dyn_cast<ForInst>(inst)) {
- if (lowerForInst(forInst))
+ for (auto *inst : instsToRewrite) {
+ auto op = cast<OperationInst>(inst);
+ if (auto ifOp = op->dyn_cast<AffineIfOp>()) {
+ if (lowerAffineIf(ifOp))
return failure();
- } else {
- auto op = cast<OperationInst>(inst);
- if (auto ifOp = op->dyn_cast<AffineIfOp>()) {
- if (lowerAffineIf(ifOp))
- return failure();
- } else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
+ } else if (auto forOp = op->dyn_cast<AffineForOp>()) {
+ if (lowerAffineFor(forOp))
return failure();
- }
+ } else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
+ return failure();
}
+ }
return success();
}
OpenPOWER on IntegriCloud