diff options
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp')
-rw-r--r-- | mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp | 182 |
1 files changed, 182 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp new file mode 100644 index 00000000000..3685a65f2d8 --- /dev/null +++ b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp @@ -0,0 +1,182 @@ +//===- QuantizeUtils.cpp - Support utilities for quantization -------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Dialect/QuantOps/QuantizeUtils.h" +#include "mlir/Dialect/QuantOps/UniformSupport.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +namespace quant { +/// Converts a possible primitive, real expressed value attribute to a +/// corresponding storage attribute (typically FloatAttr -> IntegerAttr). +/// quantizedElementType is the QuantizedType that describes the expressed +/// origValue. +/// Returns a converter Attribute or nullptr if conversion is not possible. +static Attribute convertPrimitiveValueAttr( + Attribute origRealValue, QuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter, Type &outConvertedType) { + if (origRealValue.isa<FloatAttr>()) { + FloatAttr floatAttr = origRealValue.cast<FloatAttr>(); + outConvertedType = quantizedElementType.getStorageType(); + return IntegerAttr::get(quantizedElementType.getStorageType(), + converter.quantizeFloatToInt(floatAttr.getValue())); + } + + return nullptr; +} + +/// Converts a real expressed DenseFPElementsAttr to a corresponding +/// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized +/// storage values assuming the given quantizedElementType and converter. +static DenseElementsAttr +convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr, + QuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter) { + // Convert to corresponding quantized value attributes. + SmallVector<APInt, 8> quantValues; + quantValues.reserve(realFPElementsAttr.size()); + for (APFloat realVal : realFPElementsAttr) { + quantValues.push_back(converter.quantizeFloatToInt(realVal)); + } + + // Cast from an expressed-type-based type to storage-type-based type, + // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>). + VectorOrTensorType newDenseType = + quantizedElementType + .castExpressedToStorageType(realFPElementsAttr.getType()) + .dyn_cast_or_null<VectorOrTensorType>(); + if (!newDenseType) { + return nullptr; + } + return DenseIntElementsAttr::get(newDenseType, quantValues); +} + +/// Converts a real expressed SplatElementsAttr to a corresponding +/// SplatElementsAttr containing quantized storage values assuming the given +/// quantizedElementType and converter. +static SplatElementsAttr +convertSplatElementsAttr(SplatElementsAttr realSplatAttr, + QuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter) { + // Since the splat just references a single primitive value, use the + // function for converting primitives. + // NOTE: When implementing per-channel, we will need to promote the + // splat to a dense and handle channels individually. + Type unusedPrimitiveType; + auto elementAttr = + convertPrimitiveValueAttr(realSplatAttr.getValue(), quantizedElementType, + converter, unusedPrimitiveType); + if (!elementAttr) { + return nullptr; + } + + // Cast from an expressed-type-based type to storage-type-based type, + // preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>). + VectorOrTensorType newSplatType = + quantizedElementType.castExpressedToStorageType(realSplatAttr.getType()) + .dyn_cast_or_null<VectorOrTensorType>(); + if (!newSplatType) { + return nullptr; + } + return SplatElementsAttr::get(newSplatType, elementAttr); +} + +/// Converts a real expressed SplatElementsAttr to a corresponding +/// SplatElementsAttr containing quantized storage values assuming the given +/// quantizedElementType and converter. +static SparseElementsAttr +convertSparseElementsAttr(SparseElementsAttr realSparseAttr, + QuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter) { + DenseElementsAttr realDenseAttr = realSparseAttr.getValues(); + if (!realDenseAttr.isa<DenseFPElementsAttr>()) { + return nullptr; + } + DenseElementsAttr quantDenseAttr = + convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(), + quantizedElementType, converter); + if (!quantDenseAttr) { + return nullptr; + } + + // Cast from an expressed-type-based type to storage-type-based type, + // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>). + VectorOrTensorType newSparseType = + quantizedElementType.castExpressedToStorageType(realSparseAttr.getType()) + .dyn_cast_or_null<VectorOrTensorType>(); + if (!newSparseType) { + return nullptr; + } + return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(), + quantDenseAttr); +} + +/// Converts a real expressed Attribute to a corresponding Attribute containing +/// quantized storage values assuming the given uniform quantizedElementType and +/// converter. +Attribute quantizeAttrUniform(Attribute realValue, + UniformQuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter, + Type &outConvertedType) { + // Fork to handle different variants of constants supported. + if (realValue.isa<SplatElementsAttr>()) { + // Splatted tensor or vector constant. + auto converted = convertSplatElementsAttr( + realValue.cast<SplatElementsAttr>(), quantizedElementType, converter); + outConvertedType = converted.getType(); + return converted; + } else if (realValue.isa<DenseFPElementsAttr>()) { + // Dense tensor or vector constant. + auto converted = convertDenseFPElementsAttr( + realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter); + outConvertedType = converted.getType(); + return converted; + } else if (realValue.isa<SparseElementsAttr>()) { + // Sparse tensor or vector constant. + auto converted = convertSparseElementsAttr( + realValue.cast<SparseElementsAttr>(), quantizedElementType, converter); + outConvertedType = converted.getType(); + return converted; + } else { + // Nothing else matched: try to convert a primitive. + return convertPrimitiveValueAttr(realValue, quantizedElementType, converter, + outConvertedType); + } +} + +/// Convert an attribute from a type based on +/// quantizedElementType.getExpressedType() to one based on +/// quantizedElementType.getStorageType(). +/// Returns nullptr if the conversion is not supported. +/// On success, stores the converted type in outConvertedType. +Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType, + Type &outConvertedType) { + // Hard-coded to just support UniformQuantizedType. This will need to + // be generalized when there is more than one. + auto uniformQuantizedType = + quantizedElementType.dyn_cast<UniformQuantizedType>(); + if (!uniformQuantizedType) { + return nullptr; + } + UniformQuantizedValueConverter converter(uniformQuantizedType); + return quantizeAttrUniform(realValue, uniformQuantizedType, converter, + outConvertedType); +} + +} // namespace quant +} // namespace mlir |