summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp')
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp21
1 files changed, 11 insertions, 10 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index fbde1fd1692..0af7e52b5b1 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -172,9 +172,10 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
// TODO: If we make terminators into Operations then we could turn this
// into a nice Operation::moveBefore(Operation*) method. We just need the
// guarantee that a block is non-empty.
- if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) {
- auto &entryBB = cfgFunc->front();
- cast<Instruction>(op)->moveBefore(&entryBB, entryBB.begin());
+ // TODO(clattner): This can all be simplified away now.
+ if (currentFunction->isCFG()) {
+ auto &entryBB = currentFunction->front();
+ cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin());
} else {
auto *mlFunc = cast<MLFunction>(currentFunction);
cast<OperationStmt>(op)->moveBefore(mlFunc->getBody(),
@@ -315,7 +316,7 @@ static void processCFGFunction(CFGFunction *fn,
void setInsertionPoint(Operation *op) override {
// Any new operations should be added before this instruction.
- builder.setInsertionPoint(cast<Instruction>(op));
+ builder.setInsertionPoint(cast<OperationInst>(op));
}
private:
@@ -325,7 +326,8 @@ static void processCFGFunction(CFGFunction *fn,
GreedyPatternRewriteDriver driver(std::move(patterns));
for (auto &bb : *fn)
for (auto &op : bb)
- driver.addToWorklist(&op);
+ if (auto *opInst = dyn_cast<OperationStmt>(&op))
+ driver.addToWorklist(opInst);
CFGFuncBuilder cfgBuilder(fn);
CFGFuncRewriter rewriter(driver, cfgBuilder);
@@ -337,9 +339,8 @@ static void processCFGFunction(CFGFunction *fn,
///
void mlir::applyPatternsGreedily(Function *fn,
OwningRewritePatternList &&patterns) {
- if (auto *cfg = dyn_cast<CFGFunction>(fn)) {
- processCFGFunction(cfg, std::move(patterns));
- } else {
- processMLFunction(cast<MLFunction>(fn), std::move(patterns));
- }
+ if (fn->isCFG())
+ processCFGFunction(fn, std::move(patterns));
+ else if (fn->isML())
+ processMLFunction(fn, std::move(patterns));
}
OpenPOWER on IntegriCloud