summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp')
-rw-r--r--mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp39
1 files changed, 16 insertions, 23 deletions
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
index 32d8de3c25d..2a752c2c865 100644
--- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -118,15 +118,13 @@ static Value *emitDequantize(Location loc, Value *input,
namespace {
-struct UniformDequantizePattern : public RewritePattern {
- UniformDequantizePattern(MLIRContext *context)
- : RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {}
+struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
+ using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(DequantizeCastOp op,
PatternRewriter &rewriter) const {
- auto dcastOp = cast<DequantizeCastOp>(op);
- Type inputType = dcastOp.arg()->getType();
- Type outputType = dcastOp.getResult()->getType();
+ Type inputType = op.arg()->getType();
+ Type outputType = op.getResult()->getType();
QuantizedType inputElementType =
QuantizedType::getQuantizedElementType(inputType);
@@ -136,8 +134,7 @@ struct UniformDequantizePattern : public RewritePattern {
return matchFailure();
}
- Value *dequantizedValue =
- emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter);
+ Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
if (!dequantizedValue) {
return matchFailure();
}
@@ -322,15 +319,13 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
namespace {
-struct UniformRealAddEwPattern : public RewritePattern {
- UniformRealAddEwPattern(MLIRContext *context)
- : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {}
+struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
+ using OpRewritePattern<RealAddEwOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(RealAddEwOp op,
PatternRewriter &rewriter) const {
- auto addOp = cast<RealAddEwOp>(op);
- const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
- addOp.clamp_min(), addOp.clamp_max());
+ const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
+ op.clamp_max());
if (!info.isValid()) {
return matchFailure();
}
@@ -344,15 +339,13 @@ struct UniformRealAddEwPattern : public RewritePattern {
}
};
-struct UniformRealMulEwPattern : public RewritePattern {
- UniformRealMulEwPattern(MLIRContext *context)
- : RewritePattern(RealMulEwOp::getOperationName(), 1, context) {}
+struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
+ using OpRewritePattern<RealMulEwOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(RealMulEwOp op,
PatternRewriter &rewriter) const {
- auto mulOp = cast<RealMulEwOp>(op);
- const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
- mulOp.clamp_min(), mulOp.clamp_max());
+ const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
+ op.clamp_max());
if (!info.isValid()) {
return matchFailure();
}
OpenPOWER on IntegriCloud