diff options
author | River Riddle <riverriddle@google.com> | 2019-05-25 17:22:27 -0700 |
---|---|---|
committer | Mehdi Amini <joker.eph@gmail.com> | 2019-06-01 20:03:22 -0700 |
commit | 9e21ab8f522265d37159372dbce96f66488c4e34 (patch) | |
tree | 0ec933521a5894a21543e7ef1684e437b4978380 /mlir/lib/StandardOps/Ops.cpp | |
parent | 2f50b6c401fd4d6ff63718ef3b889a79ba32a640 (diff) | |
download | bcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.tar.gz bcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.zip |
Add a templated wrapper around RewritePattern that allows for defining match/rewrite methods with an instance of the source op instead of a raw Operation*.
--
PiperOrigin-RevId: 250003405
Diffstat (limited to 'mlir/lib/StandardOps/Ops.cpp')
-rw-r--r-- | mlir/lib/StandardOps/Ops.cpp | 87 |
1 files changed, 36 insertions, 51 deletions
diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 508ebfee889..dd6754645f0 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -291,24 +291,19 @@ static LogicalResult verify(AllocOp op) { namespace { /// Fold constant dimensions into an alloc operation. -struct SimplifyAllocConst : public RewritePattern { - SimplifyAllocConst(MLIRContext *context) - : RewritePattern(AllocOp::getOperationName(), 1, context) {} - - PatternMatchResult match(Operation *op) const override { - auto alloc = cast<AllocOp>(op); +struct SimplifyAllocConst : public OpRewritePattern<AllocOp> { + using OpRewritePattern<AllocOp>::OpRewritePattern; + PatternMatchResult matchAndRewrite(AllocOp alloc, + PatternRewriter &rewriter) const override { // Check to see if any dimensions operands are constants. If so, we can // substitute and drop them. - for (auto *operand : alloc.getOperands()) - if (matchPattern(operand, m_ConstantIndex())) - return matchSuccess(); - return matchFailure(); - } + if (llvm::none_of(alloc.getOperands(), [](Value *operand) { + return matchPattern(operand, m_ConstantIndex()); + })) + return matchFailure(); - void rewrite(Operation *op, PatternRewriter &rewriter) const override { - auto allocOp = cast<AllocOp>(op); - auto memrefType = allocOp.getType(); + auto memrefType = alloc.getType(); // Ok, we have one or more constant operands. Collect the non-constant ones // and keep track of the resultant memref type to build. @@ -325,7 +320,7 @@ struct SimplifyAllocConst : public RewritePattern { newShapeConstants.push_back(dimSize); continue; } - auto *defOp = allocOp.getOperand(dynamicDimPos)->getDefiningOp(); + auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp(); if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { // Dynamic shape dimension will be folded. newShapeConstants.push_back(constantIndexOp.getValue()); @@ -334,7 +329,7 @@ struct SimplifyAllocConst : public RewritePattern { } else { // Dynamic shape dimension not folded; copy operand from old memref. newShapeConstants.push_back(-1); - newOperands.push_back(allocOp.getOperand(dynamicDimPos)); + newOperands.push_back(alloc.getOperand(dynamicDimPos)); } dynamicDimPos++; } @@ -347,30 +342,29 @@ struct SimplifyAllocConst : public RewritePattern { // Create and insert the alloc op for the new memref. auto newAlloc = - rewriter.create<AllocOp>(allocOp.getLoc(), newMemRefType, newOperands); + rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, newOperands); // Insert a cast so we have the same type as the old alloc. - auto resultCast = rewriter.create<MemRefCastOp>(allocOp.getLoc(), newAlloc, - allocOp.getType()); + auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc, + alloc.getType()); - rewriter.replaceOp(op, {resultCast}, droppedOperands); + rewriter.replaceOp(alloc, {resultCast}, droppedOperands); + return matchSuccess(); } }; /// Fold alloc operations with no uses. Alloc has side effects on the heap, /// but can still be deleted if it has zero uses. -struct SimplifyDeadAlloc : public RewritePattern { - SimplifyDeadAlloc(MLIRContext *context) - : RewritePattern(AllocOp::getOperationName(), 1, context) {} +struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> { + using OpRewritePattern<AllocOp>::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(AllocOp alloc, PatternRewriter &rewriter) const override { // Check if the alloc'ed value has any uses. - auto alloc = cast<AllocOp>(op); if (!alloc.use_empty()) return matchFailure(); // If it doesn't, we can eliminate it. - op->erase(); + alloc.erase(); return matchSuccess(); } }; @@ -484,24 +478,22 @@ FunctionType CallOp::getCalleeType() { //===----------------------------------------------------------------------===// namespace { /// Fold indirect calls that have a constant function as the callee operand. -struct SimplifyIndirectCallWithKnownCallee : public RewritePattern { - SimplifyIndirectCallWithKnownCallee(MLIRContext *context) - : RewritePattern(CallIndirectOp::getOperationName(), 1, context) {} +struct SimplifyIndirectCallWithKnownCallee + : public OpRewritePattern<CallIndirectOp> { + using OpRewritePattern<CallIndirectOp>::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall, PatternRewriter &rewriter) const override { - auto indirectCall = cast<CallIndirectOp>(op); - // Check that the callee is a constant callee. FunctionAttr calledFn; if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) return matchFailure(); // Replace with a direct call. - SmallVector<Type, 8> callResults(op->getResultTypes()); + SmallVector<Type, 8> callResults(indirectCall.getResultTypes()); SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands()); - rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callResults, - callOperands); + rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn.getValue(), + callResults, callOperands); return matchSuccess(); } }; @@ -964,14 +956,11 @@ namespace { /// cond_br true, ^bb1, ^bb2 -> br ^bb1 /// cond_br false, ^bb1, ^bb2 -> br ^bb2 /// -struct SimplifyConstCondBranchPred : public RewritePattern { - SimplifyConstCondBranchPred(MLIRContext *context) - : RewritePattern(CondBranchOp::getOperationName(), 1, context) {} +struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { + using OpRewritePattern<CondBranchOp>::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(CondBranchOp condbr, PatternRewriter &rewriter) const override { - auto condbr = cast<CondBranchOp>(op); - // Check that the condition is a constant. if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>())) return matchFailure(); @@ -991,7 +980,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern { branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end()); } - rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs); + rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs); return matchSuccess(); } }; @@ -1230,18 +1219,14 @@ void ConstantIndexOp::build(Builder *builder, OperationState *result, namespace { /// Fold Dealloc operations that are deallocating an AllocOp that is only used /// by other Dealloc operations. -struct SimplifyDeadDealloc : public RewritePattern { - SimplifyDeadDealloc(MLIRContext *context) - : RewritePattern(DeallocOp::getOperationName(), 1, context) {} +struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> { + using OpRewritePattern<DeallocOp>::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(DeallocOp dealloc, PatternRewriter &rewriter) const override { - auto dealloc = cast<DeallocOp>(op); - // Check that the memref operand's defining operation is an AllocOp. Value *memref = dealloc.memref(); - Operation *defOp = memref->getDefiningOp(); - if (!isa_and_nonnull<AllocOp>(defOp)) + if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp())) return matchFailure(); // Check that all of the uses of the AllocOp are other DeallocOps. @@ -1250,7 +1235,7 @@ struct SimplifyDeadDealloc : public RewritePattern { return matchFailure(); // Erase the dealloc operation. - op->erase(); + rewriter.replaceOp(dealloc, llvm::None); return matchSuccess(); } }; |