summary | refs | log | tree | commit | diff | stats
path: root/mlir/lib/Dialect/QuantOps/Utils
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/Utils')
-rw-r--r-- mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp 138
1 files changed, 102 insertions, 36 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
index 637f6a04988..02f803ac839 100644
--- a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
@@ -18,71 +18,48 @@
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"
-using namespace mlir;
-using namespace mlir::quant;
-
-UniformQuantizedType
-mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
- double rmax, bool narrowRange,
- Type expressedType, bool isSigned) {
- MLIRContext *ctx = expressedType.getContext();
- Type storageType;
- unsigned flags;
- int64_t qmin;
- int64_t qmax;
-
+namespace mlir {
+namespace quant {
+namespace {
+bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned,
+ MLIRContext *ctx, Type &storageType, int64_t &qmin,
+ int64_t &qmax) {
// Hard-coded type mapping from TFLite.
if (numBits <= 8) {
storageType = IntegerType::get(8, ctx);
if (isSigned) {
- flags = QuantizationFlags::Signed;
qmin = -128;
qmax = 127;
} else {
- flags = 0;
qmin = 0;
qmax = 255;
}
} else if (numBits <= 16) {
storageType = IntegerType::get(16, ctx);
if (isSigned) {
- flags = QuantizationFlags::Signed;
qmin = -32768;
qmax = 32767;
} else {
- flags = 0;
qmin = 0;
qmax = 65535;
}
} else {
- emitError(loc, "unsupported FakeQuant number of bits: ") << numBits;
- return nullptr;
+ return true;
}
// Handle narrowRange.
if (narrowRange) {
qmin += 1;
}
+ return false;
+}
- // Range must straddle zero.
- if (rmin > 0.0 || rmax < 0.0) {
- return (emitError(loc, "FakeQuant range must straddle zero: [")
- << rmin << "," << rmax << "]",
- nullptr);
- }
-
- // Special case where min/max is close enough. The tensor contents are all
- // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
- // points and dequantized to 0.0.
- if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
- return UniformQuantizedType::getChecked(flags, storageType, expressedType,
- 1.0, qmin, qmin, qmax, loc);
- }
-
+void getScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, double rmax,
+ double &scale, int64_t &nudgedZeroPoint) {
// Determine the scale.
const double qminDouble = qmin;
const double qmaxDouble = qmax;
- const double scale = (rmax - rmin) / (qmaxDouble - qminDouble);
+ scale = (rmax - rmin) / (qmaxDouble - qminDouble);
// Zero point computation.
// In float, solve the affine equation for any known pair
@@ -103,7 +80,7 @@ mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
: zeroPointFromMax;
// Now nudge the zero point to be an integer.
- int64_t nudgedZeroPoint = 0;
+ nudgedZeroPoint = 0;
if (zeroPointDouble < qminDouble) {
nudgedZeroPoint = qmin;
} else if (zeroPointDouble > qmaxDouble) {
@@ -115,8 +92,97 @@ mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
// By construction, the nudged zero point should always be in range.
assert(nudgedZeroPoint >= qmin);
assert(nudgedZeroPoint <= qmax);
+}
+
+} // end namespace
+
+UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
+ double rmin, double rmax,
+ bool narrowRange, Type expressedType,
+ bool isSigned) {
+ // Range must straddle zero.
+ // TODO(b/140641593): remove this constraint.
+ if (rmin > 0.0 || rmax < 0.0) {
+ return (emitError(loc, "FakeQuant range must straddle zero: [")
+ << rmin << "," << rmax << "]",
+ nullptr);
+ }
+
+ MLIRContext *ctx = expressedType.getContext();
+ unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
+ Type storageType;
+ int64_t qmin;
+ int64_t qmax;
+ if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
+ qmin, qmax)) {
+ return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
+ nullptr);
+ }
+
+ // Special case where rmin and rmax are within epsilon of each other. The
+ // range straddles zero, so the tensor is all 0.0s: use scale 1.0 and zero
+ // point qmin, so every value quantizes to qmin and dequantizes to 0.0.
+ if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
+ return UniformQuantizedType::getChecked(flags, storageType, expressedType,
+ 1.0, qmin, qmin, qmax, loc);
+ }
+
+ double scale;
+ int64_t nudgedZeroPoint;
+ getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
return UniformQuantizedType::getChecked(flags, storageType, expressedType,
scale, nudgedZeroPoint, qmin, qmax,
loc);
}
+
+// TODO(fengliuai): test this method once the quantizeAttr method is fixed.
+UniformQuantizedPerAxisType
+fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
+ ArrayRef<double> rmins, ArrayRef<double> rmaxs,
+ bool narrowRange, Type expressedType, bool isSigned) {
+ size_t axis_size = rmins.size();
+ if (axis_size != rmaxs.size()) {
+ return (emitError(loc, "mismatched per-axis min and max size: ")
+ << axis_size << " vs. " << rmaxs.size(),
+ nullptr);
+ }
+
+ MLIRContext *ctx = expressedType.getContext();
+ Type storageType;
+ int64_t qmin;
+ int64_t qmax;
+ if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
+ qmin, qmax)) {
+ return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
+ nullptr);
+ }
+
+ SmallVector<double, 4> scales;
+ SmallVector<int64_t, 4> zeroPoints;
+ scales.reserve(axis_size);
+ zeroPoints.reserve(axis_size);
+ for (size_t axis = 0; axis != axis_size; ++axis) {
+ double rmin = rmins[axis];
+ double rmax = rmaxs[axis];
+ if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
+ scales.push_back(1.0);
+ zeroPoints.push_back(qmin);
+ continue;
+ }
+
+ double scale;
+ int64_t nudgedZeroPoint;
+ getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
+ scales.push_back(scale);
+ zeroPoints.push_back(nudgedZeroPoint);
+ }
+
+ unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
+ return UniformQuantizedPerAxisType::getChecked(
+ flags, storageType, expressedType, scales, zeroPoints, qmin, qmax,
+ quantizedDimension, loc);
+}
+
+} // namespace quant
+} // namespace mlir
OpenPOWER on IntegriCloud