diff options
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp')
| -rw-r--r-- | mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp | 44 |
1 files changed, 12 insertions, 32 deletions
diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp index b618ac07f17..51f19940dcb 100644 --- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp @@ -32,38 +32,6 @@ using namespace mlir; using namespace mlir::quant; using namespace mlir::quant::detail; -#define GET_OP_CLASSES -#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" - -namespace { - -/// Matches x -> [scast -> scast] -> y, replacing the second scast with the -/// value of x if the casts invert each other. -class RemoveRedundantStorageCastsRewrite - : public OpRewritePattern<StorageCastOp> { -public: - using OpRewritePattern<StorageCastOp>::OpRewritePattern; - - PatternMatchResult matchAndRewrite(StorageCastOp op, - PatternRewriter &rewriter) const override { - if (!matchPattern(op.arg(), m_Op<StorageCastOp>())) - return matchFailure(); - auto srcScastOp = cast<StorageCastOp>(op.arg()->getDefiningOp()); - if (srcScastOp.arg()->getType() != op.getType()) - return matchFailure(); - - rewriter.replaceOp(op, srcScastOp.arg()); - return matchSuccess(); - } -}; - -} // end anonymous namespace - -void StorageCastOp::getCanonicalizationPatterns( - OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert<RemoveRedundantStorageCastsRewrite>(context); -} - QuantizationDialect::QuantizationDialect(MLIRContext *context) : Dialect(/*name=*/"quant", context) { addTypes<AnyQuantizedType, UniformQuantizedType, @@ -73,3 +41,15 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context) #include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" >(); } + +OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) { + /// Matches x -> [scast -> scast] -> y, replacing the second scast with the + /// value of x if the casts invert each other. + auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg()->getDefiningOp()); + if (!srcScastOp || srcScastOp.arg()->getType() != getType()) + return OpFoldResult(); + return srcScastOp.arg(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" |

