summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp')
-rw-r--r--mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp44
1 files changed, 12 insertions, 32 deletions
diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
index b618ac07f17..51f19940dcb 100644
--- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
@@ -32,38 +32,6 @@ using namespace mlir;
using namespace mlir::quant;
using namespace mlir::quant::detail;
-#define GET_OP_CLASSES
-#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
-
-namespace {
-
-/// Matches x -> [scast -> scast] -> y, replacing the second scast with the
-/// value of x if the casts invert each other.
-class RemoveRedundantStorageCastsRewrite
- : public OpRewritePattern<StorageCastOp> {
-public:
- using OpRewritePattern<StorageCastOp>::OpRewritePattern;
-
- PatternMatchResult matchAndRewrite(StorageCastOp op,
- PatternRewriter &rewriter) const override {
- if (!matchPattern(op.arg(), m_Op<StorageCastOp>()))
- return matchFailure();
- auto srcScastOp = cast<StorageCastOp>(op.arg()->getDefiningOp());
- if (srcScastOp.arg()->getType() != op.getType())
- return matchFailure();
-
- rewriter.replaceOp(op, srcScastOp.arg());
- return matchSuccess();
- }
-};
-
-} // end anonymous namespace
-
-void StorageCastOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<RemoveRedundantStorageCastsRewrite>(context);
-}
-
QuantizationDialect::QuantizationDialect(MLIRContext *context)
: Dialect(/*name=*/"quant", context) {
addTypes<AnyQuantizedType, UniformQuantizedType,
@@ -73,3 +41,15 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context)
#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
>();
}
+
+OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
+ /// Matches x -> [scast -> scast] -> y, replacing the second scast with the
+ /// value of x if the casts invert each other.
+ auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg()->getDefiningOp());
+ if (!srcScastOp || srcScastOp.arg()->getType() != getType())
+ return OpFoldResult();
+ return srcScastOp.arg();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
OpenPOWER on IntegriCloud