diff options
| author | River Riddle <riverriddle@google.com> | 2019-12-13 14:52:39 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-13 17:19:02 -0800 |
| commit | 7ac42fa26e5ac2c3554eb38b7456c6bd81e69cec (patch) | |
| tree | ceff109190beaf4333847d0b2391491b29275237 /mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp | |
| parent | 27ae92516b925e5b8e416032117ef8922fca4d37 (diff) | |
| download | bcm5719-llvm-7ac42fa26e5ac2c3554eb38b7456c6bd81e69cec.tar.gz bcm5719-llvm-7ac42fa26e5ac2c3554eb38b7456c6bd81e69cec.zip | |
Refactor various canonicalization patterns as in-place folds.
This is more efficient, and allows for these to fire in more situations: e.g. createOrFold, DialectConversion, etc.
PiperOrigin-RevId: 285476837
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp')
| -rw-r--r-- | mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp | 44 |
1 files changed, 12 insertions, 32 deletions
diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp index b618ac07f17..51f19940dcb 100644 --- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp @@ -32,38 +32,6 @@ using namespace mlir; using namespace mlir::quant; using namespace mlir::quant::detail; -#define GET_OP_CLASSES -#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" - -namespace { - -/// Matches x -> [scast -> scast] -> y, replacing the second scast with the -/// value of x if the casts invert each other. -class RemoveRedundantStorageCastsRewrite - : public OpRewritePattern<StorageCastOp> { -public: - using OpRewritePattern<StorageCastOp>::OpRewritePattern; - - PatternMatchResult matchAndRewrite(StorageCastOp op, - PatternRewriter &rewriter) const override { - if (!matchPattern(op.arg(), m_Op<StorageCastOp>())) - return matchFailure(); - auto srcScastOp = cast<StorageCastOp>(op.arg()->getDefiningOp()); - if (srcScastOp.arg()->getType() != op.getType()) - return matchFailure(); - - rewriter.replaceOp(op, srcScastOp.arg()); - return matchSuccess(); - } -}; - -} // end anonymous namespace - -void StorageCastOp::getCanonicalizationPatterns( - OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert<RemoveRedundantStorageCastsRewrite>(context); -} - QuantizationDialect::QuantizationDialect(MLIRContext *context) : Dialect(/*name=*/"quant", context) { addTypes<AnyQuantizedType, UniformQuantizedType, @@ -73,3 +41,15 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context) #include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" >(); } + +OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) { + /// Matches x -> [scast -> scast] -> y, replacing the second scast with the + /// value of x if the casts invert each other. + auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg()->getDefiningOp()); + if (!srcScastOp || srcScastOp.arg()->getType() != getType()) + return OpFoldResult(); + return srcScastOp.arg(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" |

