summaryrefslogtreecommitdiffstats
path: root/mlir/lib/StandardOps/Ops.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/StandardOps/Ops.cpp')
-rw-r--r--mlir/lib/StandardOps/Ops.cpp87
1 files changed, 36 insertions, 51 deletions
diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp
index 508ebfee889..dd6754645f0 100644
--- a/mlir/lib/StandardOps/Ops.cpp
+++ b/mlir/lib/StandardOps/Ops.cpp
@@ -291,24 +291,19 @@ static LogicalResult verify(AllocOp op) {
namespace {
/// Fold constant dimensions into an alloc operation.
-struct SimplifyAllocConst : public RewritePattern {
- SimplifyAllocConst(MLIRContext *context)
- : RewritePattern(AllocOp::getOperationName(), 1, context) {}
-
- PatternMatchResult match(Operation *op) const override {
- auto alloc = cast<AllocOp>(op);
+struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
+ using OpRewritePattern<AllocOp>::OpRewritePattern;
+ PatternMatchResult matchAndRewrite(AllocOp alloc,
+ PatternRewriter &rewriter) const override {
// Check to see if any dimensions operands are constants. If so, we can
// substitute and drop them.
- for (auto *operand : alloc.getOperands())
- if (matchPattern(operand, m_ConstantIndex()))
- return matchSuccess();
- return matchFailure();
- }
+ if (llvm::none_of(alloc.getOperands(), [](Value *operand) {
+ return matchPattern(operand, m_ConstantIndex());
+ }))
+ return matchFailure();
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- auto allocOp = cast<AllocOp>(op);
- auto memrefType = allocOp.getType();
+ auto memrefType = alloc.getType();
// Ok, we have one or more constant operands. Collect the non-constant ones
// and keep track of the resultant memref type to build.
@@ -325,7 +320,7 @@ struct SimplifyAllocConst : public RewritePattern {
newShapeConstants.push_back(dimSize);
continue;
}
- auto *defOp = allocOp.getOperand(dynamicDimPos)->getDefiningOp();
+ auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp();
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
// Dynamic shape dimension will be folded.
newShapeConstants.push_back(constantIndexOp.getValue());
@@ -334,7 +329,7 @@ struct SimplifyAllocConst : public RewritePattern {
} else {
// Dynamic shape dimension not folded; copy operand from old memref.
newShapeConstants.push_back(-1);
- newOperands.push_back(allocOp.getOperand(dynamicDimPos));
+ newOperands.push_back(alloc.getOperand(dynamicDimPos));
}
dynamicDimPos++;
}
@@ -347,30 +342,29 @@ struct SimplifyAllocConst : public RewritePattern {
// Create and insert the alloc op for the new memref.
auto newAlloc =
- rewriter.create<AllocOp>(allocOp.getLoc(), newMemRefType, newOperands);
+ rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, newOperands);
// Insert a cast so we have the same type as the old alloc.
- auto resultCast = rewriter.create<MemRefCastOp>(allocOp.getLoc(), newAlloc,
- allocOp.getType());
+ auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
+ alloc.getType());
- rewriter.replaceOp(op, {resultCast}, droppedOperands);
+ rewriter.replaceOp(alloc, {resultCast}, droppedOperands);
+ return matchSuccess();
}
};
/// Fold alloc operations with no uses. Alloc has side effects on the heap,
/// but can still be deleted if it has zero uses.
-struct SimplifyDeadAlloc : public RewritePattern {
- SimplifyDeadAlloc(MLIRContext *context)
- : RewritePattern(AllocOp::getOperationName(), 1, context) {}
+struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
+ using OpRewritePattern<AllocOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(AllocOp alloc,
PatternRewriter &rewriter) const override {
// Check if the alloc'ed value has any uses.
- auto alloc = cast<AllocOp>(op);
if (!alloc.use_empty())
return matchFailure();
// If it doesn't, we can eliminate it.
- op->erase();
+ alloc.erase();
return matchSuccess();
}
};
@@ -484,24 +478,22 @@ FunctionType CallOp::getCalleeType() {
//===----------------------------------------------------------------------===//
namespace {
/// Fold indirect calls that have a constant function as the callee operand.
-struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
- SimplifyIndirectCallWithKnownCallee(MLIRContext *context)
- : RewritePattern(CallIndirectOp::getOperationName(), 1, context) {}
+struct SimplifyIndirectCallWithKnownCallee
+ : public OpRewritePattern<CallIndirectOp> {
+ using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall,
PatternRewriter &rewriter) const override {
- auto indirectCall = cast<CallIndirectOp>(op);
-
// Check that the callee is a constant callee.
FunctionAttr calledFn;
if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
return matchFailure();
// Replace with a direct call.
- SmallVector<Type, 8> callResults(op->getResultTypes());
+ SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands());
- rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callResults,
- callOperands);
+ rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn.getValue(),
+ callResults, callOperands);
return matchSuccess();
}
};
@@ -964,14 +956,11 @@ namespace {
/// cond_br true, ^bb1, ^bb2 -> br ^bb1
/// cond_br false, ^bb1, ^bb2 -> br ^bb2
///
-struct SimplifyConstCondBranchPred : public RewritePattern {
- SimplifyConstCondBranchPred(MLIRContext *context)
- : RewritePattern(CondBranchOp::getOperationName(), 1, context) {}
+struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
+ using OpRewritePattern<CondBranchOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
- auto condbr = cast<CondBranchOp>(op);
-
// Check that the condition is a constant.
if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>()))
return matchFailure();
@@ -991,7 +980,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern {
branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end());
}
- rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs);
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs);
return matchSuccess();
}
};
@@ -1230,18 +1219,14 @@ void ConstantIndexOp::build(Builder *builder, OperationState *result,
namespace {
/// Fold Dealloc operations that are deallocating an AllocOp that is only used
/// by other Dealloc operations.
-struct SimplifyDeadDealloc : public RewritePattern {
- SimplifyDeadDealloc(MLIRContext *context)
- : RewritePattern(DeallocOp::getOperationName(), 1, context) {}
+struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
+ using OpRewritePattern<DeallocOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(DeallocOp dealloc,
PatternRewriter &rewriter) const override {
- auto dealloc = cast<DeallocOp>(op);
-
// Check that the memref operand's defining operation is an AllocOp.
Value *memref = dealloc.memref();
- Operation *defOp = memref->getDefiningOp();
- if (!isa_and_nonnull<AllocOp>(defOp))
+ if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp()))
return matchFailure();
// Check that all of the uses of the AllocOp are other DeallocOps.
@@ -1250,7 +1235,7 @@ struct SimplifyDeadDealloc : public RewritePattern {
return matchFailure();
// Erase the dealloc operation.
- op->erase();
+ rewriter.replaceOp(dealloc, llvm::None);
return matchSuccess();
}
};
OpenPOWER on IntegriCloud