diff options
Diffstat (limited to 'mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp')
-rw-r--r-- | mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp | 39 |
1 files changed, 16 insertions, 23 deletions
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index 32d8de3c25d..2a752c2c865 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -118,15 +118,13 @@ static Value *emitDequantize(Location loc, Value *input, namespace { -struct UniformDequantizePattern : public RewritePattern { - UniformDequantizePattern(MLIRContext *context) - : RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {} +struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> { + using OpRewritePattern<DequantizeCastOp>::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(DequantizeCastOp op, PatternRewriter &rewriter) const { - auto dcastOp = cast<DequantizeCastOp>(op); - Type inputType = dcastOp.arg()->getType(); - Type outputType = dcastOp.getResult()->getType(); + Type inputType = op.arg()->getType(); + Type outputType = op.getResult()->getType(); QuantizedType inputElementType = QuantizedType::getQuantizedElementType(inputType); @@ -136,8 +134,7 @@ struct UniformDequantizePattern : public RewritePattern { return matchFailure(); } - Value *dequantizedValue = - emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter); + Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter); if (!dequantizedValue) { return matchFailure(); } @@ -322,15 +319,13 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, namespace { -struct UniformRealAddEwPattern : public RewritePattern { - UniformRealAddEwPattern(MLIRContext *context) - : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {} +struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> { + using OpRewritePattern<RealAddEwOp>::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(RealAddEwOp op, PatternRewriter &rewriter) const { - auto addOp = cast<RealAddEwOp>(op); - const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(), - addOp.clamp_min(), addOp.clamp_max()); + const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(), + op.clamp_max()); if (!info.isValid()) { return matchFailure(); } @@ -344,15 +339,13 @@ struct UniformRealAddEwPattern : public RewritePattern { } }; -struct UniformRealMulEwPattern : public RewritePattern { - UniformRealMulEwPattern(MLIRContext *context) - : RewritePattern(RealMulEwOp::getOperationName(), 1, context) {} +struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> { + using OpRewritePattern<RealMulEwOp>::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(RealMulEwOp op, PatternRewriter &rewriter) const { - auto mulOp = cast<RealMulEwOp>(op); - const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(), - mulOp.clamp_min(), mulOp.clamp_max()); + const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(), + op.clamp_max()); if (!info.isValid()) { return matchFailure(); } |