diff options
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/Utils')
-rw-r--r--   mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp   109
-rw-r--r--   mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp      182
-rw-r--r--   mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp      73
3 files changed, 364 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp new file mode 100644 index 00000000000..5562e45bb4a --- /dev/null +++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp @@ -0,0 +1,109 @@ +//===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" +#include "mlir/Dialect/QuantOps/QuantTypes.h" + +using namespace mlir; +using namespace mlir::quant; + +UniformQuantizedType mlir::quant::fakeQuantAttrsToType(Location loc, + unsigned numBits, + double rmin, double rmax, + bool narrowRange, + Type expressedType) { + MLIRContext *ctx = expressedType.getContext(); + Type storageType; + unsigned flags; + int64_t qmin; + int64_t qmax; + + // Hard-coded type mapping from TFLite. + if (numBits <= 8) { + storageType = IntegerType::get(8, ctx); + flags = 0; + qmin = 0; + qmax = 255; + } else if (numBits <= 16) { + storageType = IntegerType::get(16, ctx); + flags = QuantizationFlags::Signed; + qmin = -32768; + qmax = 32767; + } else { + ctx->emitError(loc, "unsupported FakeQuant number of bits: ") << numBits; + return nullptr; + } + + // Handle narrowRange. + if (narrowRange) { + qmin += 1; + } + + // Range must straddle zero. 
+ if (rmin > 0.0 || rmax < 0.0) { + return (ctx->emitError(loc, "FakeQuant range must straddle zero: [") + << rmin << "," << rmax << "]", + nullptr); + } + + // Special case where min/max is a point. Must be 0. + if (rmin == rmax) { + return UniformQuantizedType::getChecked(flags, storageType, expressedType, + 0.0, 0, qmin, qmax, loc); + } + + // Determine the scale. + const double qminDouble = qmin; + const double qmaxDouble = qmax; + const double scale = (rmax - rmin) / (qmaxDouble - qminDouble); + + // Zero point computation. + // In float, solve the affine equation for any known pair + // (real value, corresponding quantized value), of which, two such pairs + // are known: (rmin, qmin), (rmax, qmax). + // The arithmetic error on the zero point computed from either pair will be + // roughly machine_epsilon * (sum of absolute values of terms). + // Use the variant that adds the smaller error. + const double zeroPointFromMin = qminDouble - rmin / scale; + const double zeroPointFromMinError = + std::abs(qminDouble) + std::abs(rmin / scale); + const double zeroPointFromMax = qmaxDouble - rmax / scale; + const double zeroPointFromMaxError = + std::abs(qmaxDouble) + std::abs(rmax / scale); + + const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError) + ? zeroPointFromMin + : zeroPointFromMax; + + // Now nudge the zero point to be an integer. + int64_t nudgedZeroPoint = 0; + if (zeroPointDouble < qminDouble) { + nudgedZeroPoint = qmin; + } else if (zeroPointDouble > qmaxDouble) { + nudgedZeroPoint = qmax; + } else { + nudgedZeroPoint = round(zeroPointDouble); + } + + // By construction, the nudged zero point should always be in range. 
+ assert(nudgedZeroPoint >= qmin); + assert(nudgedZeroPoint <= qmax); + + return UniformQuantizedType::getChecked(flags, storageType, expressedType, + scale, nudgedZeroPoint, qmin, qmax, + loc); +} diff --git a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp new file mode 100644 index 00000000000..3685a65f2d8 --- /dev/null +++ b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp @@ -0,0 +1,182 @@ +//===- QuantizeUtils.cpp - Support utilities for quantization -------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Dialect/QuantOps/QuantizeUtils.h" +#include "mlir/Dialect/QuantOps/UniformSupport.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +namespace quant { +/// Converts a possible primitive, real expressed value attribute to a +/// corresponding storage attribute (typically FloatAttr -> IntegerAttr). +/// quantizedElementType is the QuantizedType that describes the expressed +/// origValue. +/// Returns a converter Attribute or nullptr if conversion is not possible. 
+static Attribute convertPrimitiveValueAttr( + Attribute origRealValue, QuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter, Type &outConvertedType) { + if (origRealValue.isa<FloatAttr>()) { + FloatAttr floatAttr = origRealValue.cast<FloatAttr>(); + outConvertedType = quantizedElementType.getStorageType(); + return IntegerAttr::get(quantizedElementType.getStorageType(), + converter.quantizeFloatToInt(floatAttr.getValue())); + } + + return nullptr; +} + +/// Converts a real expressed DenseFPElementsAttr to a corresponding +/// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized +/// storage values assuming the given quantizedElementType and converter. +static DenseElementsAttr +convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr, + QuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter) { + // Convert to corresponding quantized value attributes. + SmallVector<APInt, 8> quantValues; + quantValues.reserve(realFPElementsAttr.size()); + for (APFloat realVal : realFPElementsAttr) { + quantValues.push_back(converter.quantizeFloatToInt(realVal)); + } + + // Cast from an expressed-type-based type to storage-type-based type, + // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>). + VectorOrTensorType newDenseType = + quantizedElementType + .castExpressedToStorageType(realFPElementsAttr.getType()) + .dyn_cast_or_null<VectorOrTensorType>(); + if (!newDenseType) { + return nullptr; + } + return DenseIntElementsAttr::get(newDenseType, quantValues); +} + +/// Converts a real expressed SplatElementsAttr to a corresponding +/// SplatElementsAttr containing quantized storage values assuming the given +/// quantizedElementType and converter. 
+static SplatElementsAttr +convertSplatElementsAttr(SplatElementsAttr realSplatAttr, + QuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter) { + // Since the splat just references a single primitive value, use the + // function for converting primitives. + // NOTE: When implementing per-channel, we will need to promote the + // splat to a dense and handle channels individually. + Type unusedPrimitiveType; + auto elementAttr = + convertPrimitiveValueAttr(realSplatAttr.getValue(), quantizedElementType, + converter, unusedPrimitiveType); + if (!elementAttr) { + return nullptr; + } + + // Cast from an expressed-type-based type to storage-type-based type, + // preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>). + VectorOrTensorType newSplatType = + quantizedElementType.castExpressedToStorageType(realSplatAttr.getType()) + .dyn_cast_or_null<VectorOrTensorType>(); + if (!newSplatType) { + return nullptr; + } + return SplatElementsAttr::get(newSplatType, elementAttr); +} + +/// Converts a real expressed SplatElementsAttr to a corresponding +/// SplatElementsAttr containing quantized storage values assuming the given +/// quantizedElementType and converter. +static SparseElementsAttr +convertSparseElementsAttr(SparseElementsAttr realSparseAttr, + QuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter) { + DenseElementsAttr realDenseAttr = realSparseAttr.getValues(); + if (!realDenseAttr.isa<DenseFPElementsAttr>()) { + return nullptr; + } + DenseElementsAttr quantDenseAttr = + convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(), + quantizedElementType, converter); + if (!quantDenseAttr) { + return nullptr; + } + + // Cast from an expressed-type-based type to storage-type-based type, + // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>). 
+ VectorOrTensorType newSparseType = + quantizedElementType.castExpressedToStorageType(realSparseAttr.getType()) + .dyn_cast_or_null<VectorOrTensorType>(); + if (!newSparseType) { + return nullptr; + } + return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(), + quantDenseAttr); +} + +/// Converts a real expressed Attribute to a corresponding Attribute containing +/// quantized storage values assuming the given uniform quantizedElementType and +/// converter. +Attribute quantizeAttrUniform(Attribute realValue, + UniformQuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter, + Type &outConvertedType) { + // Fork to handle different variants of constants supported. + if (realValue.isa<SplatElementsAttr>()) { + // Splatted tensor or vector constant. + auto converted = convertSplatElementsAttr( + realValue.cast<SplatElementsAttr>(), quantizedElementType, converter); + outConvertedType = converted.getType(); + return converted; + } else if (realValue.isa<DenseFPElementsAttr>()) { + // Dense tensor or vector constant. + auto converted = convertDenseFPElementsAttr( + realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter); + outConvertedType = converted.getType(); + return converted; + } else if (realValue.isa<SparseElementsAttr>()) { + // Sparse tensor or vector constant. + auto converted = convertSparseElementsAttr( + realValue.cast<SparseElementsAttr>(), quantizedElementType, converter); + outConvertedType = converted.getType(); + return converted; + } else { + // Nothing else matched: try to convert a primitive. + return convertPrimitiveValueAttr(realValue, quantizedElementType, converter, + outConvertedType); + } +} + +/// Convert an attribute from a type based on +/// quantizedElementType.getExpressedType() to one based on +/// quantizedElementType.getStorageType(). +/// Returns nullptr if the conversion is not supported. +/// On success, stores the converted type in outConvertedType. 
+Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType, + Type &outConvertedType) { + // Hard-coded to just support UniformQuantizedType. This will need to + // be generalized when there is more than one. + auto uniformQuantizedType = + quantizedElementType.dyn_cast<UniformQuantizedType>(); + if (!uniformQuantizedType) { + return nullptr; + } + UniformQuantizedValueConverter converter(uniformQuantizedType); + return quantizeAttrUniform(realValue, uniformQuantizedType, converter, + outConvertedType); +} + +} // namespace quant +} // namespace mlir diff --git a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp new file mode 100644 index 00000000000..d791075f5db --- /dev/null +++ b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp @@ -0,0 +1,73 @@ +//===- UniformSupport.cpp - Support utilities for uniform quant -----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "mlir/Dialect/QuantOps/UniformSupport.h" +#include "mlir/IR/StandardTypes.h" + +using namespace mlir; +using namespace mlir::quant; + +static bool isQuantizablePrimitiveType(Type inputType) { + return inputType.isa<FloatType>(); +} + +const ExpressedToUniformQuantizedConverter +ExpressedToUniformQuantizedConverter::forInputType(Type inputType) { + switch (inputType.getKind()) { + default: + if (isQuantizablePrimitiveType(inputType)) { + // Supported primitive type (which just is the expressed type). + return ExpressedToUniformQuantizedConverter{inputType, inputType}; + } + // Unsupported. + return ExpressedToUniformQuantizedConverter{inputType, nullptr}; + case StandardTypes::RankedTensor: + case StandardTypes::UnrankedTensor: + case StandardTypes::Vector: { + Type elementType = inputType.cast<VectorOrTensorType>().getElementType(); + if (!isQuantizablePrimitiveType(elementType)) { + // Unsupported. + return ExpressedToUniformQuantizedConverter{inputType, nullptr}; + } + return ExpressedToUniformQuantizedConverter{ + inputType, inputType.cast<VectorOrTensorType>().getElementType()}; + } + } +} + +Type ExpressedToUniformQuantizedConverter::convert( + UniformQuantizedType elementalType) const { + assert(expressedType && "convert() on unsupported conversion"); + + switch (inputType.getKind()) { + default: + if (isQuantizablePrimitiveType(elementalType)) { + // For primitives, just use the new elemental type. + return elementalType; + } + // Unsupported. + return nullptr; + case StandardTypes::RankedTensor: + return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(), + elementalType); + case StandardTypes::UnrankedTensor: + return UnrankedTensorType::get(elementalType); + case StandardTypes::Vector: + return VectorType::get(inputType.cast<VectorType>().getShape(), + elementalType); + } +} |