From 7ac42fa26e5ac2c3554eb38b7456c6bd81e69cec Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 13 Dec 2019 14:52:39 -0800 Subject: Refactor various canonicalization patterns as in-place folds. This is more efficient, and allows for these to fire in more situations: e.g. createOrFold, DialectConversion, etc. PiperOrigin-RevId: 285476837 --- mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp | 44 +++++++++---------------------- 1 file changed, 12 insertions(+), 32 deletions(-) (limited to 'mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp') diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp index b618ac07f17..51f19940dcb 100644 --- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp @@ -32,38 +32,6 @@ using namespace mlir; using namespace mlir::quant; using namespace mlir::quant::detail; -#define GET_OP_CLASSES -#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" - -namespace { - -/// Matches x -> [scast -> scast] -> y, replacing the second scast with the -/// value of x if the casts invert each other. -class RemoveRedundantStorageCastsRewrite - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - PatternMatchResult matchAndRewrite(StorageCastOp op, - PatternRewriter &rewriter) const override { - if (!matchPattern(op.arg(), m_Op())) - return matchFailure(); - auto srcScastOp = cast(op.arg()->getDefiningOp()); - if (srcScastOp.arg()->getType() != op.getType()) - return matchFailure(); - - rewriter.replaceOp(op, srcScastOp.arg()); - return matchSuccess(); - } -}; - -} // end anonymous namespace - -void StorageCastOp::getCanonicalizationPatterns( - OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert(context); -} - QuantizationDialect::QuantizationDialect(MLIRContext *context) : Dialect(/*name=*/"quant", context) { addTypes(); } + +OpFoldResult StorageCastOp::fold(ArrayRef operands) { + /// Matches x -> [scast -> scast] -> y, replacing the second scast with the + /// value of x if the casts invert each other. + auto srcScastOp = dyn_cast_or_null(arg()->getDefiningOp()); + if (!srcScastOp || srcScastOp.arg()->getType() != getType()) + return OpFoldResult(); + return srcScastOp.arg(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" -- cgit v1.2.3