author     River Riddle <riverriddle@google.com>              2019-12-13 14:52:39 -0800
committer  A. Unique TensorFlower <gardener@tensorflow.org>   2019-12-13 17:19:02 -0800
commit     7ac42fa26e5ac2c3554eb38b7456c6bd81e69cec (patch)
tree       ceff109190beaf4333847d0b2391491b29275237
parent     27ae92516b925e5b8e416032117ef8922fca4d37 (diff)
Refactor various canonicalization patterns as in-place folds.

This is more efficient, and allows for these to fire in more situations: e.g. createOrFold, DialectConversion, etc.

PiperOrigin-RevId: 285476837
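For context, the shape of the refactoring is condensed in the sketch below. It is assembled from pieces already present in this diff (the static foldMemRefCast helper and the per-op fold hooks) rather than being additional code in the patch; the includes and the StoreOp forwarding shown in the trailing comment are illustrative and assume the MLIR APIs as they existed at this revision.

// Condensed sketch of the approach this commit adopts (assumption: MLIR APIs
// at this revision). The StoreOp example in the trailing comment is one
// representative instance, mirroring the hunks below.
#include "mlir/Dialect/StandardOps/Ops.h" // MemRefCastOp, StoreOp
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

/// In-place fold shared by the memref-consuming ops: rewrite
/// "someop(memref_cast(x))" into "someop(x)" by updating the operands
/// directly, and report whether anything changed.
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp());
    if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

/// Each op then routes its fold hook to the helper, e.g.:
///   LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
///                               SmallVectorImpl<OpFoldResult> &results) {
///     return foldMemRefCast(*this); // success() with empty `results` means
///   }                               // "the op was updated in place".

Because the rewrite now lives in fold rather than in a RewritePattern registered via getCanonicalizationPatterns, it applies not only in the canonicalizer but anywhere the folding hooks run, e.g. OpBuilder::createOrFold or DialectConversion, which is the motivation stated in the commit message.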
-rw-r--r--  mlir/include/mlir/Dialect/AffineOps/AffineOps.h   |  11
-rw-r--r--  mlir/include/mlir/Dialect/AffineOps/AffineOps.td  |   5
-rw-r--r--  mlir/include/mlir/Dialect/QuantOps/QuantOps.td    |   2
-rw-r--r--  mlir/include/mlir/Dialect/StandardOps/Ops.h       |   8
-rw-r--r--  mlir/include/mlir/Dialect/StandardOps/Ops.td      |   6
-rw-r--r--  mlir/lib/Dialect/AffineOps/AffineOps.cpp          | 278
-rw-r--r--  mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp         |  44
-rw-r--r--  mlir/lib/Dialect/StandardOps/Ops.cpp              |  76
8 files changed, 187 insertions(+), 243 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
index 835ac24b4ae..8268f81b856 100644
--- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
+++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
@@ -295,8 +295,8 @@ public:
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
LogicalResult verify();
- static void getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context);
+ LogicalResult fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results);
/// Returns true if this DMA operation is strided, returns false otherwise.
bool isStrided() {
@@ -380,8 +380,8 @@ public:
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
LogicalResult verify();
- static void getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context);
+ LogicalResult fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results);
};
/// The "affine.load" op reads an element from a memref, where the index
@@ -450,6 +450,7 @@ public:
LogicalResult verify();
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
+ OpFoldResult fold(ArrayRef<Attribute> operands);
};
/// The "affine.store" op writes an element to a memref, where the index
@@ -520,6 +521,8 @@ public:
LogicalResult verify();
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
+ LogicalResult fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results);
};
/// Returns true if the given Value can be used as a dimension id.
diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
index 4d4060414dd..cea44b8dacd 100644
--- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
+++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
@@ -177,12 +177,13 @@ def AffineForOp : Affine_Op<"for",
/// Sets the upper bound to the given constant value.
void setConstantUpperBound(int64_t value);
- /// Returns true if both the lower and upper bound have the same operand
+ /// Returns true if both the lower and upper bound have the same operand
/// lists (same operands in the same order).
bool matchingBoundOperandList();
}];
let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
@@ -239,7 +240,7 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
}
}];
- let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
def AffineMinOp : Affine_Op<"min"> {
diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
index 85d5cd2ee90..072715d65aa 100644
--- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
+++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
@@ -93,7 +93,7 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [NoSideEffect]> {
def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> {
let arguments = (ins quant_RealOrStorageValueType:$arg);
let results = (outs quant_RealOrStorageValueType);
- let hasCanonicalizer = 0b1;
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.h b/mlir/include/mlir/Dialect/StandardOps/Ops.h
index c7c8714752f..fcf16c05c33 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.h
@@ -268,8 +268,8 @@ public:
void print(OpAsmPrinter &p);
LogicalResult verify();
- static void getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context);
+ LogicalResult fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results);
bool isStrided() {
return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
@@ -331,8 +331,8 @@ public:
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
- static void getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context);
+ LogicalResult fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results);
};
/// Prints dimension and symbol list.
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index 8e21a8bbbc1..553a612f5a6 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -659,6 +659,7 @@ def DeallocOp : Std_Op<"dealloc"> {
let arguments = (ins AnyMemRef:$memref);
let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
def DimOp : Std_Op<"dim", [NoSideEffect]> {
@@ -691,7 +692,6 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
}];
let hasFolder = 1;
- let hasCanonicalizer = 1;
}
def DivFOp : FloatArithmeticOp<"divf"> {
@@ -834,7 +834,7 @@ def LoadOp : Std_Op<"load"> {
operand_range getIndices() { return {operand_begin() + 1, operand_end()}; }
}];
- let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
def LogOp : FloatUnaryOp<"log"> {
@@ -1137,7 +1137,7 @@ def StoreOp : Std_Op<"store"> {
}
}];
- let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
def SubFOp : FloatArithmeticOp<"subf"> {
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
index 22d4ec10dd0..e58f6f8d6ed 100644
--- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -814,33 +814,20 @@ void AffineApplyOp::getCanonicalizationPatterns(
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//
-namespace {
/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
/// into the root operation directly.
-struct MemRefCastFolder : public RewritePattern {
- /// The rootOpName is the name of the root operation to match against.
- MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
- : RewritePattern(rootOpName, 1, context) {}
-
- PatternMatchResult match(Operation *op) const override {
- for (auto *operand : op->getOperands())
- if (matchPattern(operand, m_Op<MemRefCastOp>()))
- return matchSuccess();
-
- return matchFailure();
- }
-
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
- if (auto *memref = op->getOperand(i)->getDefiningOp())
- if (auto cast = dyn_cast<MemRefCastOp>(memref))
- op->setOperand(i, cast.getOperand());
- rewriter.updatedRootInPlace(op);
+static LogicalResult foldMemRefCast(Operation *op) {
+ bool folded = false;
+ for (OpOperand &operand : op->getOpOperands()) {
+ auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp());
+ if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) {
+ operand.set(cast.getOperand());
+ folded = true;
+ }
}
-};
-
-} // end anonymous namespace.
+ return success(folded);
+}
//===----------------------------------------------------------------------===//
// AffineDmaStartOp
@@ -1005,10 +992,10 @@ LogicalResult AffineDmaStartOp::verify() {
return success();
}
-void AffineDmaStartOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
+LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
/// dma_start(memrefcast) -> dma_start
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
@@ -1081,10 +1068,10 @@ LogicalResult AffineDmaWaitOp::verify() {
return success();
}
-void AffineDmaWaitOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
+LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
/// dma_wait(memrefcast) -> dma_wait
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
@@ -1255,7 +1242,8 @@ static ParseResult parseBound(bool isLower, OperationState &result,
"expected valid affine map representation for loop bounds");
}
-ParseResult parseAffineForOp(OpAsmParser &parser, OperationState &result) {
+static ParseResult parseAffineForOp(OpAsmParser &parser,
+ OperationState &result) {
auto &builder = parser.getBuilder();
OpAsmParser::OperandType inductionVariable;
// Parse the induction variable followed by '='.
@@ -1344,7 +1332,7 @@ static void printBound(AffineMapAttr boundMap,
map.getNumDims(), p);
}
-void print(OpAsmPrinter &p, AffineForOp op) {
+static void print(OpAsmPrinter &p, AffineForOp op) {
p << "affine.for ";
p.printOperand(op.getBody()->getArgument(0));
p << " = ";
@@ -1363,115 +1351,102 @@ void print(OpAsmPrinter &p, AffineForOp op) {
op.getStepAttrName()});
}
-namespace {
-/// This is a pattern to fold trivially empty loops.
-struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
- using OpRewritePattern<AffineForOp>::OpRewritePattern;
+/// Fold the constant bounds of a loop.
+static LogicalResult foldLoopBounds(AffineForOp forOp) {
+ auto foldLowerOrUpperBound = [&forOp](bool lower) {
+ // Check to see if each of the operands is the result of a constant. If
+ // so, get the value. If not, ignore it.
+ SmallVector<Attribute, 8> operandConstants;
+ auto boundOperands =
+ lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
+ for (auto *operand : boundOperands) {
+ Attribute operandCst;
+ matchPattern(operand, m_Constant(&operandCst));
+ operandConstants.push_back(operandCst);
+ }
- PatternMatchResult matchAndRewrite(AffineForOp forOp,
- PatternRewriter &rewriter) const override {
- // Check that the body only contains a terminator.
- if (!has_single_element(*forOp.getBody()))
- return matchFailure();
- rewriter.eraseOp(forOp);
- return matchSuccess();
- }
-};
+ AffineMap boundMap =
+ lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
+ assert(boundMap.getNumResults() >= 1 &&
+ "bound maps should have at least one result");
+ SmallVector<Attribute, 4> foldedResults;
+ if (failed(boundMap.constantFold(operandConstants, foldedResults)))
+ return failure();
-/// This is a pattern to fold constant loop bounds.
-struct AffineForOpBoundFolder : public OpRewritePattern<AffineForOp> {
- using OpRewritePattern<AffineForOp>::OpRewritePattern;
+ // Compute the max or min as applicable over the results.
+ assert(!foldedResults.empty() && "bounds should have at least one result");
+ auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
+ for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
+ auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
+ maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
+ : llvm::APIntOps::smin(maxOrMin, foldedResult);
+ }
+ lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
+ : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
+ return success();
+ };
- PatternMatchResult matchAndRewrite(AffineForOp forOp,
- PatternRewriter &rewriter) const override {
- auto foldLowerOrUpperBound = [&forOp](bool lower) {
- // Check to see if each of the operands is the result of a constant. If
- // so, get the value. If not, ignore it.
- SmallVector<Attribute, 8> operandConstants;
- auto boundOperands =
- lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
- for (auto *operand : boundOperands) {
- Attribute operandCst;
- matchPattern(operand, m_Constant(&operandCst));
- operandConstants.push_back(operandCst);
- }
+ // Try to fold the lower bound.
+ bool folded = false;
+ if (!forOp.hasConstantLowerBound())
+ folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
- AffineMap boundMap =
- lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
- assert(boundMap.getNumResults() >= 1 &&
- "bound maps should have at least one result");
- SmallVector<Attribute, 4> foldedResults;
- if (failed(boundMap.constantFold(operandConstants, foldedResults)))
- return failure();
-
- // Compute the max or min as applicable over the results.
- assert(!foldedResults.empty() &&
- "bounds should have at least one result");
- auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
- for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
- auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
- maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
- : llvm::APIntOps::smin(maxOrMin, foldedResult);
- }
- lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
- : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
- return success();
- };
-
- // Try to fold the lower bound.
- bool folded = false;
- if (!forOp.hasConstantLowerBound())
- folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
-
- // Try to fold the upper bound.
- if (!forOp.hasConstantUpperBound())
- folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
-
- // If any of the bounds were folded we return success.
- if (!folded)
- return matchFailure();
- rewriter.updatedRootInPlace(forOp);
- return matchSuccess();
- }
-};
+ // Try to fold the upper bound.
+ if (!forOp.hasConstantUpperBound())
+ folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
+ return success(folded);
+}
-// This is a pattern to canonicalize affine for op loop bounds.
-struct AffineForOpBoundCanonicalizer : public OpRewritePattern<AffineForOp> {
- using OpRewritePattern<AffineForOp>::OpRewritePattern;
+/// Canonicalize the bounds of the given loop.
+static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
+ SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands());
+ SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands());
- PatternMatchResult matchAndRewrite(AffineForOp forOp,
- PatternRewriter &rewriter) const override {
- SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands());
- SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands());
+ auto lbMap = forOp.getLowerBoundMap();
+ auto ubMap = forOp.getUpperBoundMap();
+ auto prevLbMap = lbMap;
+ auto prevUbMap = ubMap;
- auto lbMap = forOp.getLowerBoundMap();
- auto ubMap = forOp.getUpperBoundMap();
- auto prevLbMap = lbMap;
- auto prevUbMap = ubMap;
+ canonicalizeMapAndOperands(&lbMap, &lbOperands);
+ canonicalizeMapAndOperands(&ubMap, &ubOperands);
- canonicalizeMapAndOperands(&lbMap, &lbOperands);
- canonicalizeMapAndOperands(&ubMap, &ubOperands);
+ // Any canonicalization change always leads to updated map(s).
+ if (lbMap == prevLbMap && ubMap == prevUbMap)
+ return failure();
- // Any canonicalization change always leads to updated map(s).
- if (lbMap == prevLbMap && ubMap == prevUbMap)
- return matchFailure();
+ if (lbMap != prevLbMap)
+ forOp.setLowerBound(lbOperands, lbMap);
+ if (ubMap != prevUbMap)
+ forOp.setUpperBound(ubOperands, ubMap);
+ return success();
+}
- if (lbMap != prevLbMap)
- forOp.setLowerBound(lbOperands, lbMap);
- if (ubMap != prevUbMap)
- forOp.setUpperBound(ubOperands, ubMap);
+namespace {
+/// This is a pattern to fold trivially empty loops.
+struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
+ using OpRewritePattern<AffineForOp>::OpRewritePattern;
- rewriter.updatedRootInPlace(forOp);
+ PatternMatchResult matchAndRewrite(AffineForOp forOp,
+ PatternRewriter &rewriter) const override {
+ // Check that the body only contains a terminator.
+ if (!has_single_element(*forOp.getBody()))
+ return matchFailure();
+ rewriter.eraseOp(forOp);
return matchSuccess();
}
};
-
} // end anonymous namespace
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<AffineForEmptyLoopFolder, AffineForOpBoundFolder,
- AffineForOpBoundCanonicalizer>(context);
+ results.insert<AffineForEmptyLoopFolder>(context);
+}
+
+LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ bool folded = succeeded(foldLoopBounds(*this));
+ folded |= succeeded(canonicalizeLoopBounds(*this));
+ return success(folded);
}
AffineBound AffineForOp::getLowerBound() {
@@ -1741,37 +1716,23 @@ void AffineIfOp::build(Builder *builder, OperationState &result, IntegerSet set,
AffineIfOp::ensureTerminator(*elseRegion, *builder, result.location);
}
-namespace {
-// This is a pattern to canonicalize an affine if op's conditional (integer
-// set + operands).
-struct AffineIfOpCanonicalizer : public OpRewritePattern<AffineIfOp> {
- using OpRewritePattern<AffineIfOp>::OpRewritePattern;
+/// Canonicalize an affine if op's conditional (integer set + operands).
+LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ auto set = getIntegerSet();
+ SmallVector<Value *, 4> operands(getOperands());
+ canonicalizeSetAndOperands(&set, &operands);
- PatternMatchResult matchAndRewrite(AffineIfOp ifOp,
- PatternRewriter &rewriter) const override {
- auto set = ifOp.getIntegerSet();
- SmallVector<Value *, 4> operands(ifOp.getOperands());
-
- canonicalizeSetAndOperands(&set, &operands);
-
- // Any canonicalization change always leads to either a reduction in the
- // number of operands or a change in the number of symbolic operands
- // (promotion of dims to symbols).
- if (operands.size() < ifOp.getIntegerSet().getNumInputs() ||
- set.getNumSymbols() > ifOp.getIntegerSet().getNumSymbols()) {
- ifOp.setConditional(set, operands);
- rewriter.updatedRootInPlace(ifOp);
- return matchSuccess();
- }
-
- return matchFailure();
+ // Any canonicalization change always leads to either a reduction in the
+ // number of operands or a change in the number of symbolic operands
+ // (promotion of dims to symbols).
+ if (operands.size() < getIntegerSet().getNumInputs() ||
+ set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
+ setConditional(set, operands);
+ return success();
}
-};
-} // end anonymous namespace
-void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
- results.insert<AffineIfOpCanonicalizer>(context);
+ return failure();
}
//===----------------------------------------------------------------------===//
@@ -1866,11 +1827,16 @@ LogicalResult AffineLoadOp::verify() {
void AffineLoadOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
- /// load(memrefcast) -> load
- results.insert<MemRefCastFolder>(getOperationName(), context);
results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
}
+OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
+ /// load(memrefcast) -> load
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+ return OpFoldResult();
+}
+
//===----------------------------------------------------------------------===//
// AffineStoreOp
//===----------------------------------------------------------------------===//
@@ -1959,11 +1925,15 @@ LogicalResult AffineStoreOp::verify() {
void AffineStoreOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
- /// load(memrefcast) -> load
- results.insert<MemRefCastFolder>(getOperationName(), context);
results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
}
+LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ /// store(memrefcast) -> store
+ return foldMemRefCast(*this);
+}
+
//===----------------------------------------------------------------------===//
// AffineMinOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
index b618ac07f17..51f19940dcb 100644
--- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
@@ -32,38 +32,6 @@ using namespace mlir;
using namespace mlir::quant;
using namespace mlir::quant::detail;
-#define GET_OP_CLASSES
-#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
-
-namespace {
-
-/// Matches x -> [scast -> scast] -> y, replacing the second scast with the
-/// value of x if the casts invert each other.
-class RemoveRedundantStorageCastsRewrite
- : public OpRewritePattern<StorageCastOp> {
-public:
- using OpRewritePattern<StorageCastOp>::OpRewritePattern;
-
- PatternMatchResult matchAndRewrite(StorageCastOp op,
- PatternRewriter &rewriter) const override {
- if (!matchPattern(op.arg(), m_Op<StorageCastOp>()))
- return matchFailure();
- auto srcScastOp = cast<StorageCastOp>(op.arg()->getDefiningOp());
- if (srcScastOp.arg()->getType() != op.getType())
- return matchFailure();
-
- rewriter.replaceOp(op, srcScastOp.arg());
- return matchSuccess();
- }
-};
-
-} // end anonymous namespace
-
-void StorageCastOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<RemoveRedundantStorageCastsRewrite>(context);
-}
-
QuantizationDialect::QuantizationDialect(MLIRContext *context)
: Dialect(/*name=*/"quant", context) {
addTypes<AnyQuantizedType, UniformQuantizedType,
@@ -73,3 +41,15 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context)
#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
>();
}
+
+OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
+ /// Matches x -> [scast -> scast] -> y, replacing the second scast with the
+ /// value of x if the casts invert each other.
+ auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg()->getDefiningOp());
+ if (!srcScastOp || srcScastOp.arg()->getType() != getType())
+ return OpFoldResult();
+ return srcScastOp.arg();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 713546fc40d..3189e42d061 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -212,32 +212,20 @@ static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//
-namespace {
/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
/// into the root operation directly.
-struct MemRefCastFolder : public RewritePattern {
- /// The rootOpName is the name of the root operation to match against.
- MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
- : RewritePattern(rootOpName, 1, context) {}
-
- PatternMatchResult match(Operation *op) const override {
- for (auto *operand : op->getOperands())
- if (matchPattern(operand, m_Op<MemRefCastOp>()))
- return matchSuccess();
-
- return matchFailure();
- }
-
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
- if (auto *memref = op->getOperand(i)->getDefiningOp())
- if (auto cast = dyn_cast<MemRefCastOp>(memref))
- op->setOperand(i, cast.getOperand());
- rewriter.updatedRootInPlace(op);
+static LogicalResult foldMemRefCast(Operation *op) {
+ bool folded = false;
+ for (OpOperand &operand : op->getOpOperands()) {
+ auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp());
+ if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) {
+ operand.set(cast.getOperand());
+ folded = true;
+ }
}
-};
-} // end anonymous namespace.
+ return success(folded);
+}
//===----------------------------------------------------------------------===//
// AddFOp
@@ -1232,11 +1220,15 @@ static LogicalResult verify(DeallocOp op) {
void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- /// dealloc(memrefcast) -> dealloc
- results.insert<MemRefCastFolder>(getOperationName(), context);
results.insert<SimplifyDeadDealloc>(context);
}
+LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ /// dealloc(memrefcast) -> dealloc
+ return foldMemRefCast(*this);
+}
+
//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//
@@ -1304,7 +1296,6 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return {};
// The size at getIndex() is now a dynamic size of a memref.
-
auto memref = memrefOrTensor()->getDefiningOp();
if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
return *(alloc.getDynamicSizes().begin() +
@@ -1321,13 +1312,11 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return *(sizes.begin() + getIndex());
}
- return {};
-}
-
-void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
/// dim(memrefcast) -> dim
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+
+ return {};
}
//===----------------------------------------------------------------------===//
@@ -1507,10 +1496,10 @@ LogicalResult DmaStartOp::verify() {
return success();
}
-void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
+LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
/// dma_start(memrefcast) -> dma_start
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ return foldMemRefCast(*this);
}
// ---------------------------------------------------------------------------
@@ -1565,10 +1554,10 @@ ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
+LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
/// dma_wait(memrefcast) -> dma_wait
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
@@ -1688,10 +1677,11 @@ static LogicalResult verify(LoadOp op) {
return success();
}
-void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
+OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
/// load(memrefcast) -> load
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+ return OpFoldResult();
}
//===----------------------------------------------------------------------===//
@@ -2092,10 +2082,10 @@ static LogicalResult verify(StoreOp op) {
return success();
}
-void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
+LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
/// store(memrefcast) -> store
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//