diff options
author | River Riddle <riverriddle@google.com> | 2019-05-25 17:22:27 -0700 |
---|---|---|
committer | Mehdi Amini <joker.eph@gmail.com> | 2019-06-01 20:03:22 -0700 |
commit | 9e21ab8f522265d37159372dbce96f66488c4e34 (patch) | |
tree | 0ec933521a5894a21543e7ef1684e437b4978380 /mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp | |
parent | 2f50b6c401fd4d6ff63718ef3b889a79ba32a640 (diff) | |
download | bcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.tar.gz bcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.zip |
Add a templated wrapper around RewritePattern that allows for defining match/rewrite methods with an instance of the source op instead of a raw Operation*.
--
PiperOrigin-RevId: 250003405
Diffstat (limited to 'mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp')
-rw-r--r-- | mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp | 39 |
1 files changed, 16 insertions, 23 deletions
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index 32d8de3c25d..2a752c2c865 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -118,15 +118,13 @@ static Value *emitDequantize(Location loc, Value *input, namespace { -struct UniformDequantizePattern : public RewritePattern { - UniformDequantizePattern(MLIRContext *context) - : RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {} +struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> { + using OpRewritePattern<DequantizeCastOp>::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(DequantizeCastOp op, PatternRewriter &rewriter) const { - auto dcastOp = cast<DequantizeCastOp>(op); - Type inputType = dcastOp.arg()->getType(); - Type outputType = dcastOp.getResult()->getType(); + Type inputType = op.arg()->getType(); + Type outputType = op.getResult()->getType(); QuantizedType inputElementType = QuantizedType::getQuantizedElementType(inputType); @@ -136,8 +134,7 @@ struct UniformDequantizePattern : public RewritePattern { return matchFailure(); } - Value *dequantizedValue = - emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter); + Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter); if (!dequantizedValue) { return matchFailure(); } @@ -322,15 +319,13 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, namespace { -struct UniformRealAddEwPattern : public RewritePattern { - UniformRealAddEwPattern(MLIRContext *context) - : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {} +struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> { + using OpRewritePattern<RealAddEwOp>::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(RealAddEwOp op, PatternRewriter &rewriter) const { - auto addOp = cast<RealAddEwOp>(op); - const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(), - addOp.clamp_min(), addOp.clamp_max()); + const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(), + op.clamp_max()); if (!info.isValid()) { return matchFailure(); } @@ -344,15 +339,13 @@ struct UniformRealAddEwPattern : public RewritePattern { } }; -struct UniformRealMulEwPattern : public RewritePattern { - UniformRealMulEwPattern(MLIRContext *context) - : RewritePattern(RealMulEwOp::getOperationName(), 1, context) {} +struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> { + using OpRewritePattern<RealMulEwOp>::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(RealMulEwOp op, PatternRewriter &rewriter) const { - auto mulOp = cast<RealMulEwOp>(op); - const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(), - mulOp.clamp_min(), mulOp.clamp_max()); + const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(), + op.clamp_max()); if (!info.isValid()) { return matchFailure(); } |