diff options
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp')
-rw-r--r-- | mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp | 25 |
1 files changed, 13 insertions, 12 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp index 5d4561be81b..2e1bd958b79 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp @@ -54,8 +54,17 @@ bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned, return false; } -void getScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, double rmax, - double &scale, int64_t &nudgedZeroPoint) { +// This is a specific implementation of nudging: +// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted +// to include 0.0, but the range width (rmax-rmin) isn't changed. The zero +// point is derived from the shifted range, and the scale isn't changed. As +// a consequence some values, which are supposed to be in the original [rmin, rmax] +// range will be outside the shifted range and be clamped during quantization. +// TODO(fengliuai): we should nudge the scale as well, but that requires the +// fake quant op used in the training to use the nudged scale as well. +void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, + double rmax, double &scale, + int64_t &nudgedZeroPoint) { // Determine the scale. const double qminDouble = qmin; const double qmaxDouble = qmax; @@ -100,14 +109,6 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin, double rmax, bool narrowRange, Type expressedType, bool isSigned) { - // Range must straddle zero. - // TODO(b/140641593): remove this constraint. - if (rmin > 0.0 || rmax < 0.0) { - return (emitError(loc, "FakeQuant range must straddle zero: [") - << rmin << "," << rmax << "]", - nullptr); - } - MLIRContext *ctx = expressedType.getContext(); unsigned flags = isSigned ? 
QuantizationFlags::Signed : 0; Type storageType; @@ -129,7 +130,7 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, double scale; int64_t nudgedZeroPoint; - getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint); + getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint); return UniformQuantizedType::getChecked(flags, storageType, expressedType, scale, nudgedZeroPoint, qmin, qmax, @@ -172,7 +173,7 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension, double scale; int64_t nudgedZeroPoint; - getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint); + getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint); scales.push_back(scale); zeroPoints.push_back(nudgedZeroPoint); } |