diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 33 |
1 files changed, 13 insertions, 20 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 5a452c5242a..161a6c409c7 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -1070,27 +1070,20 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { PatternMatchResult matchAndRewrite(CondBranchOp condbr, PatternRewriter &rewriter) const override { - // Check that the condition is a constant. - if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>())) - return matchFailure(); - - Block *foldedDest; - SmallVector<Value *, 4> branchArgs; - - // If the condition is known to evaluate to false we fold to a branch to the - // false destination. Otherwise, we fold to a branch to the true - // destination. - if (matchPattern(condbr.getCondition(), m_Zero())) { - foldedDest = condbr.getFalseDest(); - branchArgs.assign(condbr.false_operand_begin(), - condbr.false_operand_end()); - } else { - foldedDest = condbr.getTrueDest(); - branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end()); + if (matchPattern(condbr.getCondition(), m_NonZero())) { + // True branch taken. + rewriter.replaceOpWithNewOp<BranchOp>( + condbr, condbr.getTrueDest(), + llvm::to_vector<4>(condbr.getTrueOperands())); + return matchSuccess(); + } else if (matchPattern(condbr.getCondition(), m_Zero())) { + // False branch taken. + rewriter.replaceOpWithNewOp<BranchOp>( + condbr, condbr.getFalseDest(), + llvm::to_vector<4>(condbr.getFalseOperands())); + return matchSuccess(); } - - rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs); - return matchSuccess(); + return matchFailure(); } }; } // end anonymous namespace. |

