diff options
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp')
-rw-r--r-- | mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp new file mode 100644 index 00000000000..5562e45bb4a --- /dev/null +++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp @@ -0,0 +1,109 @@ +//===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" +#include "mlir/Dialect/QuantOps/QuantTypes.h" + +using namespace mlir; +using namespace mlir::quant; + +UniformQuantizedType mlir::quant::fakeQuantAttrsToType(Location loc, + unsigned numBits, + double rmin, double rmax, + bool narrowRange, + Type expressedType) { + MLIRContext *ctx = expressedType.getContext(); + Type storageType; + unsigned flags; + int64_t qmin; + int64_t qmax; + + // Hard-coded type mapping from TFLite. + if (numBits <= 8) { + storageType = IntegerType::get(8, ctx); + flags = 0; + qmin = 0; + qmax = 255; + } else if (numBits <= 16) { + storageType = IntegerType::get(16, ctx); + flags = QuantizationFlags::Signed; + qmin = -32768; + qmax = 32767; + } else { + ctx->emitError(loc, "unsupported FakeQuant number of bits: ") << numBits; + return nullptr; + } + + // Handle narrowRange. + if (narrowRange) { + qmin += 1; + } + + // Range must straddle zero. + if (rmin > 0.0 || rmax < 0.0) { + return (ctx->emitError(loc, "FakeQuant range must straddle zero: [") + << rmin << "," << rmax << "]", + nullptr); + } + + // Special case where min/max is a point. Must be 0. + if (rmin == rmax) { + return UniformQuantizedType::getChecked(flags, storageType, expressedType, + 0.0, 0, qmin, qmax, loc); + } + + // Determine the scale. + const double qminDouble = qmin; + const double qmaxDouble = qmax; + const double scale = (rmax - rmin) / (qmaxDouble - qminDouble); + + // Zero point computation. + // In float, solve the affine equation for any known pair + // (real value, corresponding quantized value), of which, two such pairs + // are known: (rmin, qmin), (rmax, qmax). + // The arithmetic error on the zero point computed from either pair will be + // roughly machine_epsilon * (sum of absolute values of terms). + // Use the variant that adds the smaller error. + const double zeroPointFromMin = qminDouble - rmin / scale; + const double zeroPointFromMinError = + std::abs(qminDouble) + std::abs(rmin / scale); + const double zeroPointFromMax = qmaxDouble - rmax / scale; + const double zeroPointFromMaxError = + std::abs(qmaxDouble) + std::abs(rmax / scale); + + const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError) + ? zeroPointFromMin + : zeroPointFromMax; + + // Now nudge the zero point to be an integer. + int64_t nudgedZeroPoint = 0; + if (zeroPointDouble < qminDouble) { + nudgedZeroPoint = qmin; + } else if (zeroPointDouble > qmaxDouble) { + nudgedZeroPoint = qmax; + } else { + nudgedZeroPoint = round(zeroPointDouble); + } + + // By construction, the nudged zero point should always be in range. + assert(nudgedZeroPoint >= qmin); + assert(nudgedZeroPoint <= qmax); + + return UniformQuantizedType::getChecked(flags, storageType, expressedType, + scale, nudgedZeroPoint, qmin, qmax, + loc); +} |