Diffstat (limited to 'mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp')
-rw-r--r-- | mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp | 138 |
1 files changed, 102 insertions, 36 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
index 637f6a04988..02f803ac839 100644
--- a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
@@ -18,71 +18,48 @@
 #include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
 #include "mlir/Dialect/QuantOps/QuantTypes.h"
 
-using namespace mlir;
-using namespace mlir::quant;
-
-UniformQuantizedType
-mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
-                                  double rmax, bool narrowRange,
-                                  Type expressedType, bool isSigned) {
-  MLIRContext *ctx = expressedType.getContext();
-  Type storageType;
-  unsigned flags;
-  int64_t qmin;
-  int64_t qmax;
-
+namespace mlir {
+namespace quant {
+namespace {
+bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned,
+                             MLIRContext *ctx, Type &storageType, int64_t &qmin,
+                             int64_t &qmax) {
   // Hard-coded type mapping from TFLite.
   if (numBits <= 8) {
     storageType = IntegerType::get(8, ctx);
     if (isSigned) {
-      flags = QuantizationFlags::Signed;
       qmin = -128;
       qmax = 127;
     } else {
-      flags = 0;
       qmin = 0;
       qmax = 255;
     }
   } else if (numBits <= 16) {
     storageType = IntegerType::get(16, ctx);
     if (isSigned) {
-      flags = QuantizationFlags::Signed;
       qmin = -32768;
       qmax = 32767;
     } else {
-      flags = 0;
       qmin = 0;
       qmax = 65535;
     }
   } else {
-    emitError(loc, "unsupported FakeQuant number of bits: ") << numBits;
-    return nullptr;
+    return true;
   }
 
   // Handle narrowRange.
   if (narrowRange) {
     qmin += 1;
   }
+  return false;
+}
 
-  // Range must straddle zero.
-  if (rmin > 0.0 || rmax < 0.0) {
-    return (emitError(loc, "FakeQuant range must straddle zero: [")
-                << rmin << "," << rmax << "]",
-            nullptr);
-  }
-
-  // Special case where min/max is close enough. The tensor contents are all
-  // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
-  // points and dequantized to 0.0.
-  if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
-    return UniformQuantizedType::getChecked(flags, storageType, expressedType,
-                                            1.0, qmin, qmin, qmax, loc);
-  }
-
+void getScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, double rmax,
+                          double &scale, int64_t &nudgedZeroPoint) {
   // Determine the scale.
   const double qminDouble = qmin;
   const double qmaxDouble = qmax;
-  const double scale = (rmax - rmin) / (qmaxDouble - qminDouble);
+  scale = (rmax - rmin) / (qmaxDouble - qminDouble);
 
   // Zero point computation.
   // In float, solve the affine equation for any known pair
@@ -103,7 +80,7 @@ mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
                                      : zeroPointFromMax;
 
   // Now nudge the zero point to be an integer.
-  int64_t nudgedZeroPoint = 0;
+  nudgedZeroPoint = 0;
   if (zeroPointDouble < qminDouble) {
     nudgedZeroPoint = qmin;
   } else if (zeroPointDouble > qmaxDouble) {
@@ -115,8 +92,97 @@ mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
   // By construction, the nudged zero point should always be in range.
   assert(nudgedZeroPoint >= qmin);
   assert(nudgedZeroPoint <= qmax);
+}
+
+} // end namespace
+
+UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
+                                          double rmin, double rmax,
+                                          bool narrowRange, Type expressedType,
+                                          bool isSigned) {
+  // Range must straddle zero.
+  // TODO(b/140641593): remove this constraint.
+  if (rmin > 0.0 || rmax < 0.0) {
+    return (emitError(loc, "FakeQuant range must straddle zero: [")
+                << rmin << "," << rmax << "]",
+            nullptr);
+  }
+
+  MLIRContext *ctx = expressedType.getContext();
+  unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
+  Type storageType;
+  int64_t qmin;
+  int64_t qmax;
+  if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
+                              qmin, qmax)) {
+    return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
+            nullptr);
+  }
+
+  // Special case where min/max is close enough. The tensor contents are all
+  // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
+  // points and dequantized to 0.0.
+  if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
+    return UniformQuantizedType::getChecked(flags, storageType, expressedType,
+                                            1.0, qmin, qmin, qmax, loc);
+  }
+
+  double scale;
+  int64_t nudgedZeroPoint;
+  getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
 
   return UniformQuantizedType::getChecked(flags, storageType, expressedType,
                                           scale, nudgedZeroPoint, qmin, qmax,
                                           loc);
 }
+
+// TODO(fengliuai): test this method once the quantizeAttr method is fixed.
+UniformQuantizedPerAxisType
+fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
+                     ArrayRef<double> rmins, ArrayRef<double> rmaxs,
+                     bool narrowRange, Type expressedType, bool isSigned) {
+  size_t axis_size = rmins.size();
+  if (axis_size != rmaxs.size()) {
+    return (emitError(loc, "mismatched per-axis min and max size: ")
+                << axis_size << " vs. " << rmaxs.size(),
+            nullptr);
+  }
+
+  MLIRContext *ctx = expressedType.getContext();
+  Type storageType;
+  int64_t qmin;
+  int64_t qmax;
+  if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
+                              qmin, qmax)) {
+    return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
+            nullptr);
+  }
+
+  SmallVector<double, 4> scales;
+  SmallVector<int64_t, 4> zeroPoints;
+  scales.reserve(axis_size);
+  zeroPoints.reserve(axis_size);
+  for (size_t axis = 0; axis != axis_size; ++axis) {
+    double rmin = rmins[axis];
+    double rmax = rmaxs[axis];
+    if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
+      scales.push_back(1.0);
+      zeroPoints.push_back(qmin);
+      continue;
+    }
+
+    double scale;
+    int64_t nudgedZeroPoint;
+    getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
+    scales.push_back(scale);
+    zeroPoints.push_back(nudgedZeroPoint);
+  }
+
+  unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
+  return UniformQuantizedPerAxisType::getChecked(
+      flags, storageType, expressedType, scales, zeroPoints, qmin, qmax,
+      quantizedDimension, loc);
+}
+
+} // namespace quant
+} // namespace mlir
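As a rough standalone illustration of the arithmetic this patch factors into getScaleAndZeroPoint, the sketch below reproduces the scale computation and zero-point nudging without any MLIR dependency. The function name, the example range, and the candidate-selection heuristic for the zero-point lines that the hunk elides are illustrative assumptions, not part of the patch.

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Standalone sketch of the scale / nudged-zero-point computation. Mirrors the
// patched getScaleAndZeroPoint in spirit; names are illustrative only.
static void computeScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
                                     double rmax, double &scale,
                                     int64_t &nudgedZeroPoint) {
  const double qminDouble = static_cast<double>(qmin);
  const double qmaxDouble = static_cast<double>(qmax);
  scale = (rmax - rmin) / (qmaxDouble - qminDouble);

  // Candidate zero points from both ends of the real range. The diff elides
  // how the in-tree code chooses between them; picking the candidate closer
  // to an integer is an assumed simplification for this sketch.
  const double zeroPointFromMin = qminDouble - rmin / scale;
  const double zeroPointFromMax = qmaxDouble - rmax / scale;
  const double errFromMin =
      std::fabs(zeroPointFromMin - std::round(zeroPointFromMin));
  const double errFromMax =
      std::fabs(zeroPointFromMax - std::round(zeroPointFromMax));
  const double zeroPointDouble =
      errFromMin < errFromMax ? zeroPointFromMin : zeroPointFromMax;

  // Nudge the zero point onto an integer inside [qmin, qmax], as in the patch.
  if (zeroPointDouble < qminDouble)
    nudgedZeroPoint = qmin;
  else if (zeroPointDouble > qmaxDouble)
    nudgedZeroPoint = qmax;
  else
    nudgedZeroPoint = static_cast<int64_t>(std::round(zeroPointDouble));
  assert(nudgedZeroPoint >= qmin && nudgedZeroPoint <= qmax);
}

int main() {
  double scale;
  int64_t zeroPoint;
  // 8-bit unsigned storage with real range [-1.0, 1.0]:
  // scale = 2.0 / 255 ~ 0.00784, and the zero point nudges to 128.
  computeScaleAndZeroPoint(/*qmin=*/0, /*qmax=*/255, -1.0, 1.0, scale,
                           zeroPoint);
  std::printf("scale = %f, zeroPoint = %lld\n", scale,
              static_cast<long long>(zeroPoint));
  return 0;
}

Run as written, this prints a scale of roughly 0.00784 and a zero point of 128 for the unsigned 8-bit mapping of [-1.0, 1.0]; the same helper is what both the per-tensor and the new per-axis overloads in the patch delegate to.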