summaryrefslogtreecommitdiffstats
path: root/mlir/lib/StandardOps/Ops.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/StandardOps/Ops.cpp')
-rw-r--r--mlir/lib/StandardOps/Ops.cpp76
1 files changed, 33 insertions, 43 deletions
diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp
index a14f3a24e82..50db72faea1 100644
--- a/mlir/lib/StandardOps/Ops.cpp
+++ b/mlir/lib/StandardOps/Ops.cpp
@@ -356,15 +356,16 @@ struct SimplifyDeadAlloc : public RewritePattern {
SimplifyDeadAlloc(MLIRContext *context)
: RewritePattern(AllocOp::getOperationName(), 1, context) {}
- PatternMatchResult match(Instruction *op) const override {
+ PatternMatchResult matchAndRewrite(Instruction *op,
+ PatternRewriter &rewriter) const override {
+ // Check if the alloc'ed value has any uses.
auto alloc = op->cast<AllocOp>();
- // Check if the alloc'ed value has no uses.
- return alloc->use_empty() ? matchSuccess() : matchFailure();
- }
+ if (!alloc->use_empty())
+ return matchFailure();
- void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
- // Erase the alloc operation.
+ // If it doesn't, we can eliminate it.
op->erase();
+ return matchSuccess();
}
};
} // end anonymous namespace.
@@ -486,29 +487,24 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
SimplifyIndirectCallWithKnownCallee(MLIRContext *context)
: RewritePattern(CallIndirectOp::getOperationName(), 1, context) {}
- PatternMatchResult match(Instruction *op) const override {
+ PatternMatchResult matchAndRewrite(Instruction *op,
+ PatternRewriter &rewriter) const override {
auto indirectCall = op->cast<CallIndirectOp>();
// Check that the callee is a constant operation.
- Value *callee = indirectCall->getCallee();
- Instruction *calleeInst = callee->getDefiningInst();
- if (!calleeInst || !calleeInst->isa<ConstantOp>())
+ Attribute callee;
+ if (!matchPattern(indirectCall->getCallee(), m_Constant(&callee)))
return matchFailure();
// Check that the constant callee is a function.
- if (calleeInst->cast<ConstantOp>()->getValue().isa<FunctionAttr>())
- return matchSuccess();
- return matchFailure();
- }
- void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
- auto indirectCall = op->cast<CallIndirectOp>();
- auto calleeOp =
- indirectCall->getCallee()->getDefiningInst()->cast<ConstantOp>();
+ FunctionAttr calledFn = callee.dyn_cast<FunctionAttr>();
+ if (!calledFn)
+ return matchFailure();
// Replace with a direct call.
- Function *calledFn = calleeOp->getValue().cast<FunctionAttr>().getValue();
SmallVector<Value *, 8> callOperands(indirectCall->getArgOperands());
- rewriter.replaceOpWithNewOp<CallOp>(op, calledFn, callOperands);
+ rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callOperands);
+ return matchSuccess();
}
};
} // end anonymous namespace.
@@ -802,15 +798,14 @@ struct SimplifyConstCondBranchPred : public RewritePattern {
SimplifyConstCondBranchPred(MLIRContext *context)
: RewritePattern(CondBranchOp::getOperationName(), 1, context) {}
- PatternMatchResult match(Instruction *op) const override {
+ PatternMatchResult matchAndRewrite(Instruction *op,
+ PatternRewriter &rewriter) const override {
auto condbr = op->cast<CondBranchOp>();
- if (matchPattern(condbr->getCondition(), m_Op<ConstantOp>()))
- return matchSuccess();
- return matchFailure();
- }
- void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
- auto condbr = op->cast<CondBranchOp>();
+ // Check that the condition is a constant.
+ if (!matchPattern(condbr->getCondition(), m_Op<ConstantOp>()))
+ return matchFailure();
+
Block *foldedDest;
SmallVector<Value *, 4> branchArgs;
@@ -828,6 +823,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern {
}
rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs);
+ return matchSuccess();
}
};
} // end anonymous namespace.
@@ -1094,7 +1090,8 @@ struct SimplifyDeadDealloc : public RewritePattern {
SimplifyDeadDealloc(MLIRContext *context)
: RewritePattern(DeallocOp::getOperationName(), 1, context) {}
- PatternMatchResult match(Instruction *op) const override {
+ PatternMatchResult matchAndRewrite(Instruction *op,
+ PatternRewriter &rewriter) const override {
auto dealloc = op->cast<DeallocOp>();
// Check that the memref operand's defining instruction is an AllocOp.
@@ -1107,12 +1104,10 @@ struct SimplifyDeadDealloc : public RewritePattern {
for (auto &use : memref->getUses())
if (!use.getOwner()->isa<DeallocOp>())
return matchFailure();
- return matchSuccess();
- }
- void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
// Erase the dealloc operation.
op->erase();
+ return matchSuccess();
}
};
} // end anonymous namespace.
@@ -1991,21 +1986,16 @@ namespace {
///
struct SimplifyXMinusX : public RewritePattern {
SimplifyXMinusX(MLIRContext *context)
- : RewritePattern(SubIOp::getOperationName(), 1, context) {}
+ : RewritePattern(SubIOp::getOperationName(), 10, context) {}
- PatternMatchResult match(Instruction *op) const override {
- auto subi = op->cast<SubIOp>();
- if (subi->getOperand(0) == subi->getOperand(1))
- return matchSuccess();
-
- return matchFailure();
- }
- void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
+ PatternMatchResult matchAndRewrite(Instruction *op,
+ PatternRewriter &rewriter) const override {
auto subi = op->cast<SubIOp>();
- auto result =
- rewriter.create<ConstantIntOp>(op->getLoc(), 0, subi->getType());
+ if (subi->getOperand(0) != subi->getOperand(1))
+ return matchFailure();
- rewriter.replaceOp(op, {result});
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, subi->getType());
+ return matchSuccess();
}
};
} // end anonymous namespace.
OpenPOWER on IntegriCloud