summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp')
-rw-r--r--mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp25
1 files changed, 13 insertions, 12 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
index 5d4561be81b..2e1bd958b79 100644
--- a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
@@ -54,8 +54,17 @@ bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned,
return false;
}
-void getScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, double rmax,
- double &scale, int64_t &nudgedZeroPoint) {
+// This is a specific implementation of nudging:
+// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted
+// to include 0.0, but the range width size (rmax-rmin) isn't changed. The zero
+// point is derived from the shifted range, and the scale isn't changed. As
+// a consequence some values, which are supposeed in the original [rmin, rmax]
+// range will be outside the shifted range and be clamped during quantization.
+// TODO(fengliuai): we should nudge the scale as well, but that requires the
+// fake quant op used in the training to use the nudged scale as well.
+void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
+ double rmax, double &scale,
+ int64_t &nudgedZeroPoint) {
// Determine the scale.
const double qminDouble = qmin;
const double qmaxDouble = qmax;
@@ -100,14 +109,6 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
double rmin, double rmax,
bool narrowRange, Type expressedType,
bool isSigned) {
- // Range must straddle zero.
- // TODO(b/140641593): remove this constraint.
- if (rmin > 0.0 || rmax < 0.0) {
- return (emitError(loc, "FakeQuant range must straddle zero: [")
- << rmin << "," << rmax << "]",
- nullptr);
- }
-
MLIRContext *ctx = expressedType.getContext();
unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
Type storageType;
@@ -129,7 +130,7 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
double scale;
int64_t nudgedZeroPoint;
- getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
+ getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
return UniformQuantizedType::getChecked(flags, storageType, expressedType,
scale, nudgedZeroPoint, qmin, qmax,
@@ -172,7 +173,7 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
double scale;
int64_t nudgedZeroPoint;
- getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
+ getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
scales.push_back(scale);
zeroPoints.push_back(nudgedZeroPoint);
}
OpenPOWER on IntegriCloud