diff options
Diffstat (limited to 'mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp')
-rw-r--r-- | mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp | 64 |
1 files changed, 32 insertions, 32 deletions
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index 725751eb6c1..df6015de1b9 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -37,9 +37,9 @@ struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> { // Dequantize //===----------------------------------------------------------------------===// -static ValuePtr emitUniformPerLayerDequantize(Location loc, ValuePtr input, - UniformQuantizedType elementType, - PatternRewriter &rewriter) { +static Value emitUniformPerLayerDequantize(Location loc, Value input, + UniformQuantizedType elementType, + PatternRewriter &rewriter) { // Pre-conditions. if (!elementType.isSigned()) { // TODO: Support unsigned storage type. @@ -62,7 +62,7 @@ static ValuePtr emitUniformPerLayerDequantize(Location loc, ValuePtr input, // Apply zero-point offset. if (elementType.getZeroPoint() != 0) { - ValuePtr negZeroPointConst = rewriter.create<ConstantOp>( + Value negZeroPointConst = rewriter.create<ConstantOp>( loc, broadcastScalarConstIntValue(intermediateType, -elementType.getZeroPoint())); input = rewriter.create<AddIOp>(loc, input, negZeroPointConst); @@ -72,14 +72,14 @@ static ValuePtr emitUniformPerLayerDequantize(Location loc, ValuePtr input, input = rewriter.create<ConvertISToFOp>(loc, realType, input); // Mul by scale. - ValuePtr scaleConst = rewriter.create<ConstantOp>( + Value scaleConst = rewriter.create<ConstantOp>( loc, broadcastScalarConstFloatValue(realType, APFloat(elementType.getScale()))); return rewriter.create<MulFOp>(loc, input, scaleConst); } -static ValuePtr -emitUniformPerAxisDequantize(Location loc, ValuePtr input, +static Value +emitUniformPerAxisDequantize(Location loc, Value input, UniformQuantizedPerAxisType elementType, PatternRewriter &rewriter) { // TODO: Support per-axis dequantize. @@ -88,8 +88,8 @@ emitUniformPerAxisDequantize(Location loc, ValuePtr input, return nullptr; } -static ValuePtr emitDequantize(Location loc, ValuePtr input, - PatternRewriter &rewriter) { +static Value emitDequantize(Location loc, Value input, + PatternRewriter &rewriter) { Type inputType = input->getType(); QuantizedType qElementType = QuantizedType::getQuantizedElementType(inputType); @@ -124,7 +124,7 @@ struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> { return matchFailure(); } - ValuePtr dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter); + Value dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter); if (!dequantizedValue) { return matchFailure(); } @@ -161,14 +161,14 @@ tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info, castElementType(info.resultStorageType, intermediateElementType); // Cast operands to storage type. - ValuePtr lhsValue = rewriter - .create<StorageCastOp>(info.op->getLoc(), - info.lhsStorageType, info.lhs) - .getResult(); - ValuePtr rhsValue = rewriter - .create<StorageCastOp>(info.op->getLoc(), - info.rhsStorageType, info.rhs) - .getResult(); + Value lhsValue = rewriter + .create<StorageCastOp>(info.op->getLoc(), + info.lhsStorageType, info.lhs) + .getResult(); + Value rhsValue = rewriter + .create<StorageCastOp>(info.op->getLoc(), + info.rhsStorageType, info.rhs) + .getResult(); // Cast to the intermediate sized type. lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType, @@ -177,7 +177,7 @@ tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info, rhsValue); // Add. - ValuePtr resultValue = + Value resultValue = rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue); // Zero point offset adjustment. @@ -185,7 +185,7 @@ tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info, // zpOffset = -zp int zpOffset = -1 * info.resultType.getZeroPoint(); if (zpOffset != 0) { - ValuePtr zpOffsetConst = rewriter.create<ConstantOp>( + Value zpOffsetConst = rewriter.create<ConstantOp>( info.op->getLoc(), broadcastScalarConstIntValue(intermediateType, zpOffset)); resultValue = @@ -237,14 +237,14 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, castElementType(info.resultStorageType, intermediateElementType); // Cast operands to storage type. - ValuePtr lhsValue = rewriter - .create<StorageCastOp>(info.op->getLoc(), - info.lhsStorageType, info.lhs) - .getResult(); - ValuePtr rhsValue = rewriter - .create<StorageCastOp>(info.op->getLoc(), - info.rhsStorageType, info.rhs) - .getResult(); + Value lhsValue = rewriter + .create<StorageCastOp>(info.op->getLoc(), + info.lhsStorageType, info.lhs) + .getResult(); + Value rhsValue = rewriter + .create<StorageCastOp>(info.op->getLoc(), + info.rhsStorageType, info.rhs) + .getResult(); // Cast to the intermediate sized type. lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType, @@ -254,7 +254,7 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, // Apply argument zeroPoints. if (info.lhsType.getZeroPoint() != 0) { - ValuePtr zpOffsetConst = rewriter.create<ConstantOp>( + Value zpOffsetConst = rewriter.create<ConstantOp>( info.op->getLoc(), broadcastScalarConstIntValue( intermediateType, -info.lhsType.getZeroPoint())); lhsValue = @@ -262,7 +262,7 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, } if (info.rhsType.getZeroPoint() != 0) { - ValuePtr zpOffsetConst = rewriter.create<ConstantOp>( + Value zpOffsetConst = rewriter.create<ConstantOp>( info.op->getLoc(), broadcastScalarConstIntValue( intermediateType, -info.rhsType.getZeroPoint())); rhsValue = @@ -270,7 +270,7 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, } // Mul. - ValuePtr resultValue = + Value resultValue = rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue); // Scale output. @@ -284,7 +284,7 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, // Zero point offset adjustment. if (info.resultType.getZeroPoint() != 0) { - ValuePtr zpOffsetConst = rewriter.create<ConstantOp>( + Value zpOffsetConst = rewriter.create<ConstantOp>( info.op->getLoc(), broadcastScalarConstIntValue(intermediateType, info.resultType.getZeroPoint())); |