diff options
Diffstat (limited to 'mlir/lib/Dialect/StandardOps/Ops.cpp')
| -rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 76 |
1 files changed, 33 insertions, 43 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 713546fc40d..3189e42d061 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -212,32 +212,20 @@ static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() { // Common canonicalization pattern support logic //===----------------------------------------------------------------------===// -namespace { /// This is a common class used for patterns of the form /// "someop(memrefcast) -> someop". It folds the source of any memref_cast /// into the root operation directly. -struct MemRefCastFolder : public RewritePattern { - /// The rootOpName is the name of the root operation to match against. - MemRefCastFolder(StringRef rootOpName, MLIRContext *context) - : RewritePattern(rootOpName, 1, context) {} - - PatternMatchResult match(Operation *op) const override { - for (auto *operand : op->getOperands()) - if (matchPattern(operand, m_Op<MemRefCastOp>())) - return matchSuccess(); - - return matchFailure(); - } - - void rewrite(Operation *op, PatternRewriter &rewriter) const override { - for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) - if (auto *memref = op->getOperand(i)->getDefiningOp()) - if (auto cast = dyn_cast<MemRefCastOp>(memref)) - op->setOperand(i, cast.getOperand()); - rewriter.updatedRootInPlace(op); +static LogicalResult foldMemRefCast(Operation *op) { + bool folded = false; + for (OpOperand &operand : op->getOpOperands()) { + auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp()); + if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) { + operand.set(cast.getOperand()); + folded = true; + } } -}; -} // end anonymous namespace. + return success(folded); +} //===----------------------------------------------------------------------===// // AddFOp @@ -1232,11 +1220,15 @@ static LogicalResult verify(DeallocOp op) { void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - /// dealloc(memrefcast) -> dealloc - results.insert<MemRefCastFolder>(getOperationName(), context); results.insert<SimplifyDeadDealloc>(context); } +LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, + SmallVectorImpl<OpFoldResult> &results) { + /// dealloc(memrefcast) -> dealloc + return foldMemRefCast(*this); +} + //===----------------------------------------------------------------------===// // DimOp //===----------------------------------------------------------------------===// @@ -1304,7 +1296,6 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { return {}; // The size at getIndex() is now a dynamic size of a memref. - auto memref = memrefOrTensor()->getDefiningOp(); if (auto alloc = dyn_cast_or_null<AllocOp>(memref)) return *(alloc.getDynamicSizes().begin() + @@ -1321,13 +1312,11 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { return *(sizes.begin() + getIndex()); } - return {}; -} - -void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { /// dim(memrefcast) -> dim - results.insert<MemRefCastFolder>(getOperationName(), context); + if (succeeded(foldMemRefCast(*this))) + return getResult(); + + return {}; } //===----------------------------------------------------------------------===// @@ -1507,10 +1496,10 @@ LogicalResult DmaStartOp::verify() { return success(); } -void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { +LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, + SmallVectorImpl<OpFoldResult> &results) { /// dma_start(memrefcast) -> dma_start - results.insert<MemRefCastFolder>(getOperationName(), context); + return foldMemRefCast(*this); } // --------------------------------------------------------------------------- @@ -1565,10 +1554,10 @@ ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } -void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { +LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, + SmallVectorImpl<OpFoldResult> &results) { /// dma_wait(memrefcast) -> dma_wait - results.insert<MemRefCastFolder>(getOperationName(), context); + return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// @@ -1688,10 +1677,11 @@ static LogicalResult verify(LoadOp op) { return success(); } -void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { +OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { /// load(memrefcast) -> load - results.insert<MemRefCastFolder>(getOperationName(), context); + if (succeeded(foldMemRefCast(*this))) + return getResult(); + return OpFoldResult(); } //===----------------------------------------------------------------------===// @@ -2092,10 +2082,10 @@ static LogicalResult verify(StoreOp op) { return success(); } -void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { +LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, + SmallVectorImpl<OpFoldResult> &results) { /// store(memrefcast) -> store - results.insert<MemRefCastFolder>(getOperationName(), context); + return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// |

