summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-05-25 17:22:27 -0700
committerMehdi Amini <joker.eph@gmail.com>2019-06-01 20:03:22 -0700
commit9e21ab8f522265d37159372dbce96f66488c4e34 (patch)
tree0ec933521a5894a21543e7ef1684e437b4978380 /mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
parent2f50b6c401fd4d6ff63718ef3b889a79ba32a640 (diff)
downloadbcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.tar.gz
bcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.zip
Add a templated wrapper around RewritePattern that allows for defining match/rewrite methods with an instance of the source op instead of a raw Operation*.
-- PiperOrigin-RevId: 250003405
Diffstat (limited to 'mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp')
-rw-r--r--mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp39
1 files changed, 16 insertions, 23 deletions
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
index 32d8de3c25d..2a752c2c865 100644
--- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -118,15 +118,13 @@ static Value *emitDequantize(Location loc, Value *input,
namespace {
-struct UniformDequantizePattern : public RewritePattern {
- UniformDequantizePattern(MLIRContext *context)
- : RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {}
+struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
+ using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(DequantizeCastOp op,
PatternRewriter &rewriter) const {
- auto dcastOp = cast<DequantizeCastOp>(op);
- Type inputType = dcastOp.arg()->getType();
- Type outputType = dcastOp.getResult()->getType();
+ Type inputType = op.arg()->getType();
+ Type outputType = op.getResult()->getType();
QuantizedType inputElementType =
QuantizedType::getQuantizedElementType(inputType);
@@ -136,8 +134,7 @@ struct UniformDequantizePattern : public RewritePattern {
return matchFailure();
}
- Value *dequantizedValue =
- emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter);
+ Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
if (!dequantizedValue) {
return matchFailure();
}
@@ -322,15 +319,13 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
namespace {
-struct UniformRealAddEwPattern : public RewritePattern {
- UniformRealAddEwPattern(MLIRContext *context)
- : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {}
+struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
+ using OpRewritePattern<RealAddEwOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(RealAddEwOp op,
PatternRewriter &rewriter) const {
- auto addOp = cast<RealAddEwOp>(op);
- const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
- addOp.clamp_min(), addOp.clamp_max());
+ const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
+ op.clamp_max());
if (!info.isValid()) {
return matchFailure();
}
@@ -344,15 +339,13 @@ struct UniformRealAddEwPattern : public RewritePattern {
}
};
-struct UniformRealMulEwPattern : public RewritePattern {
- UniformRealMulEwPattern(MLIRContext *context)
- : RewritePattern(RealMulEwOp::getOperationName(), 1, context) {}
+struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
+ using OpRewritePattern<RealMulEwOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(RealMulEwOp op,
PatternRewriter &rewriter) const {
- auto mulOp = cast<RealMulEwOp>(op);
- const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
- mulOp.clamp_min(), mulOp.clamp_max());
+ const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
+ op.clamp_max());
if (!info.isValid()) {
return matchFailure();
}
OpenPOWER on IntegriCloud