diff options
Diffstat (limited to 'mlir/lib/StandardOps/Ops.cpp')
| -rw-r--r-- | mlir/lib/StandardOps/Ops.cpp | 76 |
1 files changed, 33 insertions, 43 deletions
diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index a14f3a24e82..50db72faea1 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -356,15 +356,16 @@ struct SimplifyDeadAlloc : public RewritePattern { SimplifyDeadAlloc(MLIRContext *context) : RewritePattern(AllocOp::getOperationName(), 1, context) {} - PatternMatchResult match(Instruction *op) const override { + PatternMatchResult matchAndRewrite(Instruction *op, + PatternRewriter &rewriter) const override { + // Check if the alloc'ed value has any uses. auto alloc = op->cast<AllocOp>(); - // Check if the alloc'ed value has no uses. - return alloc->use_empty() ? matchSuccess() : matchFailure(); - } + if (!alloc->use_empty()) + return matchFailure(); - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { - // Erase the alloc operation. + // If it doesn't, we can eliminate it. op->erase(); + return matchSuccess(); } }; } // end anonymous namespace. @@ -486,29 +487,24 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern { SimplifyIndirectCallWithKnownCallee(MLIRContext *context) : RewritePattern(CallIndirectOp::getOperationName(), 1, context) {} - PatternMatchResult match(Instruction *op) const override { + PatternMatchResult matchAndRewrite(Instruction *op, + PatternRewriter &rewriter) const override { auto indirectCall = op->cast<CallIndirectOp>(); // Check that the callee is a constant operation. - Value *callee = indirectCall->getCallee(); - Instruction *calleeInst = callee->getDefiningInst(); - if (!calleeInst || !calleeInst->isa<ConstantOp>()) + Attribute callee; + if (!matchPattern(indirectCall->getCallee(), m_Constant(&callee))) return matchFailure(); // Check that the constant callee is a function. - if (calleeInst->cast<ConstantOp>()->getValue().isa<FunctionAttr>()) - return matchSuccess(); - return matchFailure(); - } - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { - auto indirectCall = op->cast<CallIndirectOp>(); - auto calleeOp = - indirectCall->getCallee()->getDefiningInst()->cast<ConstantOp>(); + FunctionAttr calledFn = callee.dyn_cast<FunctionAttr>(); + if (!calledFn) + return matchFailure(); // Replace with a direct call. - Function *calledFn = calleeOp->getValue().cast<FunctionAttr>().getValue(); SmallVector<Value *, 8> callOperands(indirectCall->getArgOperands()); - rewriter.replaceOpWithNewOp<CallOp>(op, calledFn, callOperands); + rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callOperands); + return matchSuccess(); } }; } // end anonymous namespace. @@ -802,15 +798,14 @@ struct SimplifyConstCondBranchPred : public RewritePattern { SimplifyConstCondBranchPred(MLIRContext *context) : RewritePattern(CondBranchOp::getOperationName(), 1, context) {} - PatternMatchResult match(Instruction *op) const override { + PatternMatchResult matchAndRewrite(Instruction *op, + PatternRewriter &rewriter) const override { auto condbr = op->cast<CondBranchOp>(); - if (matchPattern(condbr->getCondition(), m_Op<ConstantOp>())) - return matchSuccess(); - return matchFailure(); - } - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { - auto condbr = op->cast<CondBranchOp>(); + // Check that the condition is a constant. + if (!matchPattern(condbr->getCondition(), m_Op<ConstantOp>())) + return matchFailure(); + Block *foldedDest; SmallVector<Value *, 4> branchArgs; @@ -828,6 +823,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern { } rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs); + return matchSuccess(); } }; } // end anonymous namespace. @@ -1094,7 +1090,8 @@ struct SimplifyDeadDealloc : public RewritePattern { SimplifyDeadDealloc(MLIRContext *context) : RewritePattern(DeallocOp::getOperationName(), 1, context) {} - PatternMatchResult match(Instruction *op) const override { + PatternMatchResult matchAndRewrite(Instruction *op, + PatternRewriter &rewriter) const override { auto dealloc = op->cast<DeallocOp>(); // Check that the memref operand's defining instruction is an AllocOp. @@ -1107,12 +1104,10 @@ struct SimplifyDeadDealloc : public RewritePattern { for (auto &use : memref->getUses()) if (!use.getOwner()->isa<DeallocOp>()) return matchFailure(); - return matchSuccess(); - } - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { // Erase the dealloc operation. op->erase(); + return matchSuccess(); } }; } // end anonymous namespace. @@ -1991,21 +1986,16 @@ namespace { /// struct SimplifyXMinusX : public RewritePattern { SimplifyXMinusX(MLIRContext *context) - : RewritePattern(SubIOp::getOperationName(), 1, context) {} + : RewritePattern(SubIOp::getOperationName(), 10, context) {} - PatternMatchResult match(Instruction *op) const override { - auto subi = op->cast<SubIOp>(); - if (subi->getOperand(0) == subi->getOperand(1)) - return matchSuccess(); - - return matchFailure(); - } - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Instruction *op, + PatternRewriter &rewriter) const override { auto subi = op->cast<SubIOp>(); - auto result = - rewriter.create<ConstantIntOp>(op->getLoc(), 0, subi->getType()); + if (subi->getOperand(0) != subi->getOperand(1)) + return matchFailure(); - rewriter.replaceOp(op, {result}); + rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, subi->getType()); + return matchSuccess(); } }; } // end anonymous namespace. |

