diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp')
| -rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index fbde1fd1692..0af7e52b5b1 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -172,9 +172,10 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, // TODO: If we make terminators into Operations then we could turn this // into a nice Operation::moveBefore(Operation*) method. We just need the // guarantee that a block is non-empty. - if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) { - auto &entryBB = cfgFunc->front(); - cast<Instruction>(op)->moveBefore(&entryBB, entryBB.begin()); + // TODO(clattner): This can all be simplified away now. + if (currentFunction->isCFG()) { + auto &entryBB = currentFunction->front(); + cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin()); } else { auto *mlFunc = cast<MLFunction>(currentFunction); cast<OperationStmt>(op)->moveBefore(mlFunc->getBody(), @@ -315,7 +316,7 @@ static void processCFGFunction(CFGFunction *fn, void setInsertionPoint(Operation *op) override { // Any new operations should be added before this instruction. - builder.setInsertionPoint(cast<Instruction>(op)); + builder.setInsertionPoint(cast<OperationInst>(op)); } private: @@ -325,7 +326,8 @@ static void processCFGFunction(CFGFunction *fn, GreedyPatternRewriteDriver driver(std::move(patterns)); for (auto &bb : *fn) for (auto &op : bb) - driver.addToWorklist(&op); + if (auto *opInst = dyn_cast<OperationStmt>(&op)) + driver.addToWorklist(opInst); CFGFuncBuilder cfgBuilder(fn); CFGFuncRewriter rewriter(driver, cfgBuilder); @@ -337,9 +339,8 @@ static void processCFGFunction(CFGFunction *fn, /// void mlir::applyPatternsGreedily(Function *fn, OwningRewritePatternList &&patterns) { - if (auto *cfg = dyn_cast<CFGFunction>(fn)) { - processCFGFunction(cfg, std::move(patterns)); - } else { - processMLFunction(cast<MLFunction>(fn), std::move(patterns)); - } + if (fn->isCFG()) + processCFGFunction(fn, std::move(patterns)); + else if (fn->isML()) + processMLFunction(fn, std::move(patterns)); } |

