summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/StandardOps/Ops.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/StandardOps/Ops.cpp')
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp76
1 files changed, 33 insertions, 43 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 713546fc40d..3189e42d061 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -212,32 +212,20 @@ static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//
-namespace {
/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
/// into the root operation directly.
-struct MemRefCastFolder : public RewritePattern {
- /// The rootOpName is the name of the root operation to match against.
- MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
- : RewritePattern(rootOpName, 1, context) {}
-
- PatternMatchResult match(Operation *op) const override {
- for (auto *operand : op->getOperands())
- if (matchPattern(operand, m_Op<MemRefCastOp>()))
- return matchSuccess();
-
- return matchFailure();
- }
-
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
- if (auto *memref = op->getOperand(i)->getDefiningOp())
- if (auto cast = dyn_cast<MemRefCastOp>(memref))
- op->setOperand(i, cast.getOperand());
- rewriter.updatedRootInPlace(op);
+static LogicalResult foldMemRefCast(Operation *op) {
+ bool folded = false;
+ for (OpOperand &operand : op->getOpOperands()) {
+ auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp());
+ if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) {
+ operand.set(cast.getOperand());
+ folded = true;
+ }
}
-};
-} // end anonymous namespace.
+ return success(folded);
+}
//===----------------------------------------------------------------------===//
// AddFOp
@@ -1232,11 +1220,15 @@ static LogicalResult verify(DeallocOp op) {
void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- /// dealloc(memrefcast) -> dealloc
- results.insert<MemRefCastFolder>(getOperationName(), context);
results.insert<SimplifyDeadDealloc>(context);
}
+LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ /// dealloc(memrefcast) -> dealloc
+ return foldMemRefCast(*this);
+}
+
//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//
@@ -1304,7 +1296,6 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return {};
// The size at getIndex() is now a dynamic size of a memref.
-
auto memref = memrefOrTensor()->getDefiningOp();
if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
return *(alloc.getDynamicSizes().begin() +
@@ -1321,13 +1312,11 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return *(sizes.begin() + getIndex());
}
- return {};
-}
-
-void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
/// dim(memrefcast) -> dim
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+
+ return {};
}
//===----------------------------------------------------------------------===//
@@ -1507,10 +1496,10 @@ LogicalResult DmaStartOp::verify() {
return success();
}
-void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
+LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
/// dma_start(memrefcast) -> dma_start
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ return foldMemRefCast(*this);
}
// ---------------------------------------------------------------------------
@@ -1565,10 +1554,10 @@ ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
+LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
/// dma_wait(memrefcast) -> dma_wait
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
@@ -1688,10 +1677,11 @@ static LogicalResult verify(LoadOp op) {
return success();
}
-void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
+OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
/// load(memrefcast) -> load
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+ return OpFoldResult();
}
//===----------------------------------------------------------------------===//
@@ -2092,10 +2082,10 @@ static LogicalResult verify(StoreOp op) {
return success();
}
-void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
+LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
/// store(memrefcast) -> store
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
OpenPOWER on IntegriCloud