summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r--mlir/lib/Transforms/CSE.cpp10
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp24
-rw-r--r--mlir/lib/Transforms/LoopUnroll.cpp9
-rw-r--r--mlir/lib/Transforms/LowerAffine.cpp59
-rw-r--r--mlir/lib/Transforms/MaterializeVectors.cpp7
-rw-r--r--mlir/lib/Transforms/SimplifyAffineStructures.cpp37
6 files changed, 79 insertions, 67 deletions
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index afd18a49b79..c2e1636626d 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -188,6 +188,16 @@ void CSE::simplifyBlock(Block *bb) {
simplifyBlock(cast<ForInst>(i).getBody());
break;
}
+ case Instruction::Kind::If: {
+ auto &ifInst = cast<IfInst>(i);
+ if (auto *elseBlock = ifInst.getElse()) {
+ ScopedMapTy::ScopeTy scope(knownValues);
+ simplifyBlock(elseBlock);
+ }
+ ScopedMapTy::ScopeTy scope(knownValues);
+ simplifyBlock(ifInst.getThen());
+ break;
+ }
}
}
}
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index eebbbe9daa7..cee0a08a63c 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -19,7 +19,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
@@ -100,16 +99,16 @@ public:
SmallVector<ForInst *, 4> forInsts;
SmallVector<OperationInst *, 4> loadOpInsts;
SmallVector<OperationInst *, 4> storeOpInsts;
- bool hasNonForRegion = false;
+ bool hasIfInst = false;
void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
+ void visitIfInst(IfInst *ifInst) { hasIfInst = true; }
+
void visitOperationInst(OperationInst *opInst) {
- if (opInst->getNumBlockLists() != 0)
- hasNonForRegion = true;
- else if (opInst->isa<LoadOp>())
+ if (opInst->isa<LoadOp>())
loadOpInsts.push_back(opInst);
- else if (opInst->isa<StoreOp>())
+ if (opInst->isa<StoreOp>())
storeOpInsts.push_back(opInst);
}
};
@@ -411,8 +410,8 @@ bool MemRefDependenceGraph::init(Function *f) {
// all loads and store accesses it contains.
LoopNestStateCollector collector;
collector.walkForInst(forInst);
- // Return false if a non 'for' region was found (not currently supported).
- if (collector.hasNonForRegion)
+ // Return false if IfInsts are found (not currently supported).
+ if (collector.hasIfInst)
return false;
Node node(id++, &inst);
for (auto *opInst : collector.loadOpInsts) {
@@ -435,18 +434,19 @@ bool MemRefDependenceGraph::init(Function *f) {
auto *memref = opInst->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
- } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
+ }
+ if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
// Create graph node for top-level store op.
Node node(id++, &inst);
node.stores.push_back(opInst);
auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
- } else if (opInst->getNumBlockLists() != 0) {
- // Return false if another region is found (not currently supported).
- return false;
}
}
+ // Return false if IfInsts are found (not currently supported).
+ if (isa<IfInst>(&inst))
+ return false;
}
// Walk memref access lists and add graph edges between dependent nodes.
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 6d63e4afd2d..39ef758833b 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -119,6 +119,15 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
return true;
}
+ bool walkIfInstPostOrder(IfInst *ifInst) {
+ bool hasInnerLoops =
+ walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end());
+ if (ifInst->hasElse())
+ hasInnerLoops |=
+ walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end());
+ return hasInnerLoops;
+ }
+
bool walkOpInstPostOrder(OperationInst *opInst) {
for (auto &blockList : opInst->getBlockLists())
for (auto &block : blockList)
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp
index f770684f519..ab37ff63bad 100644
--- a/mlir/lib/Transforms/LowerAffine.cpp
+++ b/mlir/lib/Transforms/LowerAffine.cpp
@@ -20,7 +20,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -247,7 +246,7 @@ public:
PassResult runOnFunction(Function *function) override;
bool lowerForInst(ForInst *forInst);
- bool lowerAffineIf(AffineIfOp *ifOp);
+ bool lowerIfInst(IfInst *ifInst);
bool lowerAffineApply(AffineApplyOp *op);
static char passID;
@@ -410,7 +409,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
// enabling easy nesting of "if" instructions and if-then-else-if chains.
//
// +--------------------------------+
-// | <code before the AffineIfOp> |
+// | <code before the IfInst> |
// | %zero = constant 0 : index |
// | %v = affine_apply #expr1(%ops) |
// | %c = cmpi "sge" %v, %zero |
@@ -454,11 +453,10 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
// v v
// +--------------------------------+
// | continue: |
-// | <code after the AffineIfOp> |
+// | <code after the IfInst> |
// +--------------------------------+
//
-bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) {
- auto *ifInst = ifOp->getInstruction();
+bool LowerAffinePass::lowerIfInst(IfInst *ifInst) {
auto loc = ifInst->getLoc();
// Start by splitting the block containing the 'if' into two parts. The part
@@ -468,38 +466,22 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) {
auto *continueBlock = condBlock->splitBlock(ifInst);
// Create a block for the 'then' code, inserting it between the cond and
- // continue blocks. Move the instructions over from the AffineIfOp and add a
+ // continue blocks. Move the instructions over from the IfInst and add a
// branch to the continuation point.
Block *thenBlock = new Block();
thenBlock->insertBefore(continueBlock);
- // If the 'then' block is not empty, then splice the instructions.
- auto &oldThenBlocks = ifOp->getThenBlocks();
- if (!oldThenBlocks.empty()) {
- // We currently only handle one 'then' block.
- if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end())
- return true;
-
- Block *oldThen = &oldThenBlocks.front();
-
- thenBlock->getInstructions().splice(thenBlock->begin(),
- oldThen->getInstructions(),
- oldThen->begin(), oldThen->end());
- }
-
+ auto *oldThen = ifInst->getThen();
+ thenBlock->getInstructions().splice(thenBlock->begin(),
+ oldThen->getInstructions(),
+ oldThen->begin(), oldThen->end());
FuncBuilder builder(thenBlock);
builder.create<BranchOp>(loc, continueBlock);
// Handle the 'else' block the same way, but we skip it if we have no else
// code.
Block *elseBlock = continueBlock;
- auto &oldElseBlocks = ifOp->getElseBlocks();
- if (!oldElseBlocks.empty()) {
- // We currently only handle one 'else' block.
- if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end())
- return true;
-
- auto *oldElse = &oldElseBlocks.front();
+ if (auto *oldElse = ifInst->getElse()) {
elseBlock = new Block();
elseBlock->insertBefore(continueBlock);
@@ -511,7 +493,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) {
}
// Ok, now we just have to handle the condition logic.
- auto integerSet = ifOp->getIntegerSet();
+ auto integerSet = ifInst->getCondition().getIntegerSet();
// Implement short-circuit logic. For each affine expression in the 'if'
// condition, convert it into an affine map and call `affine_apply` to obtain
@@ -611,30 +593,29 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) {
PassResult LowerAffinePass::runOnFunction(Function *function) {
SmallVector<Instruction *, 8> instsToRewrite;
- // Collect all the For instructions as well as AffineIfOps and AffineApplyOps.
- // We do this as a prepass to avoid invalidating the walker with our rewrite.
+ // Collect all the If and For instructions as well as AffineApplyOps. We do
+ // this as a prepass to avoid invalidating the walker with our rewrite.
function->walkInsts([&](Instruction *inst) {
- if (isa<ForInst>(inst))
+ if (isa<IfInst>(inst) || isa<ForInst>(inst))
instsToRewrite.push_back(inst);
auto op = dyn_cast<OperationInst>(inst);
- if (op && (op->isa<AffineApplyOp>() || op->isa<AffineIfOp>()))
+ if (op && op->isa<AffineApplyOp>())
instsToRewrite.push_back(inst);
});
// Rewrite all of the ifs and fors. We walked the instructions in preorder,
// so we know that we will rewrite them in the same order.
for (auto *inst : instsToRewrite)
- if (auto *forInst = dyn_cast<ForInst>(inst)) {
+ if (auto *ifInst = dyn_cast<IfInst>(inst)) {
+ if (lowerIfInst(ifInst))
+ return failure();
+ } else if (auto *forInst = dyn_cast<ForInst>(inst)) {
if (lowerForInst(forInst))
return failure();
} else {
auto op = cast<OperationInst>(inst);
- if (auto ifOp = op->dyn_cast<AffineIfOp>()) {
- if (lowerAffineIf(ifOp))
- return failure();
- } else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
+ if (lowerAffineApply(op->cast<AffineApplyOp>()))
return failure();
- }
}
return success();
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index 2744b1d624c..09d961f85cd 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -20,7 +20,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/LoopAnalysis.h"
@@ -560,6 +559,9 @@ static bool instantiateMaterialization(Instruction *inst,
if (isa<ForInst>(inst))
return inst->emitError("NYI path ForInst");
+ if (isa<IfInst>(inst))
+ return inst->emitError("NYI path IfInst");
+
// Create a builder here for unroll-and-jam effects.
FuncBuilder b(inst);
auto *opInst = cast<OperationInst>(inst);
@@ -568,9 +570,6 @@ static bool instantiateMaterialization(Instruction *inst,
if (opInst->isa<AffineApplyOp>()) {
return false;
}
- if (opInst->getNumBlockLists() != 0)
- return inst->emitError("NYI path Op with region");
-
if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) {
auto *clone = instantiate(&b, write, state->hwVectorType,
state->hwVectorInstance, state->substitutionsMap);
diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp
index ba59123c700..bd39e47786a 100644
--- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp
@@ -28,6 +28,7 @@
#define DEBUG_TYPE "simplify-affine-structure"
using namespace mlir;
+using llvm::report_fatal_error;
namespace {
@@ -41,6 +42,9 @@ struct SimplifyAffineStructures : public FunctionPass {
PassResult runOnFunction(Function *f) override;
+ void visitIfInst(IfInst *ifInst);
+ void visitOperationInst(OperationInst *opInst);
+
static char passID;
};
@@ -62,19 +66,28 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) {
return set;
}
-PassResult SimplifyAffineStructures::runOnFunction(Function *f) {
- f->walkOps([&](OperationInst *opInst) {
- for (auto attr : opInst->getAttrs()) {
- if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
- MutableAffineMap mMap(mapAttr.getValue());
- mMap.simplify();
- auto map = mMap.getAffineMap();
- opInst->setAttr(attr.first, AffineMapAttr::get(map));
- } else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>()) {
- auto simplified = simplifyIntegerSet(setAttr.getValue());
- opInst->setAttr(attr.first, IntegerSetAttr::get(simplified));
- }
+void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) {
+ auto set = ifInst->getCondition().getIntegerSet();
+ ifInst->setIntegerSet(simplifyIntegerSet(set));
+}
+
+void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) {
+ for (auto attr : opInst->getAttrs()) {
+ if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
+ MutableAffineMap mMap(mapAttr.getValue());
+ mMap.simplify();
+ auto map = mMap.getAffineMap();
+ opInst->setAttr(attr.first, AffineMapAttr::get(map));
}
+ }
+}
+
+PassResult SimplifyAffineStructures::runOnFunction(Function *f) {
+ f->walkInsts([&](Instruction *inst) {
+ if (auto *opInst = dyn_cast<OperationInst>(inst))
+ visitOperationInst(opInst);
+ if (auto *ifInst = dyn_cast<IfInst>(inst))
+ visitIfInst(ifInst);
});
return success();
OpenPOWER on IntegriCloud