summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp')
-rw-r--r--mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp182
1 files changed, 182 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
new file mode 100644
index 00000000000..3685a65f2d8
--- /dev/null
+++ b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
@@ -0,0 +1,182 @@
+//===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/QuantizeUtils.h"
+#include "mlir/Dialect/QuantOps/UniformSupport.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace quant {
+/// Quantizes a single primitive attribute expressed in real terms into its
+/// storage representation (in practice FloatAttr -> IntegerAttr).
+/// quantizedElementType supplies the storage type of the result and converter
+/// performs the numeric real -> storage conversion.
+/// On success, outConvertedType is set to the storage type and the converted
+/// Attribute is returned. Any input that is not a FloatAttr yields nullptr and
+/// leaves outConvertedType untouched.
+static Attribute convertPrimitiveValueAttr(
+    Attribute origRealValue, QuantizedType quantizedElementType,
+    const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
+  // Floating point primitives are the only kind handled here.
+  if (!origRealValue.isa<FloatAttr>())
+    return nullptr;
+
+  auto realAttr = origRealValue.cast<FloatAttr>();
+  Type storageType = quantizedElementType.getStorageType();
+  outConvertedType = storageType;
+  return IntegerAttr::get(storageType,
+                          converter.quantizeFloatToInt(realAttr.getValue()));
+}
+
+/// Quantizes every element of a real expressed DenseFPElementsAttr and
+/// returns a DenseElementsAttr (built as a DenseIntElementsAttr) holding the
+/// resulting storage values.
+/// quantizedElementType determines the storage-based result type and converter
+/// performs the per-element real -> storage conversion.
+/// Returns nullptr if the attribute's type cannot be mapped to a
+/// storage-based vector or tensor type.
+static DenseElementsAttr
+convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
+                           QuantizedType quantizedElementType,
+                           const UniformQuantizedValueConverter &converter) {
+  // Quantize element by element, sizing the buffer once up front.
+  SmallVector<APInt, 8> storageValues;
+  storageValues.reserve(realFPElementsAttr.size());
+  for (APFloat element : realFPElementsAttr)
+    storageValues.push_back(converter.quantizeFloatToInt(element));
+
+  // Swap the expressed element type for the storage element type while
+  // keeping the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+  Type storageBasedType = quantizedElementType.castExpressedToStorageType(
+      realFPElementsAttr.getType());
+  auto denseStorageType =
+      storageBasedType.dyn_cast_or_null<VectorOrTensorType>();
+  if (!denseStorageType)
+    return nullptr;
+
+  return DenseIntElementsAttr::get(denseStorageType, storageValues);
+}
+
+/// Quantizes a real expressed SplatElementsAttr into a SplatElementsAttr
+/// whose single splatted value is the quantized storage value.
+/// quantizedElementType determines the storage-based result type and converter
+/// performs the real -> storage conversion of the splat value.
+/// Returns nullptr if the splat value cannot be converted or if the
+/// attribute's type cannot be mapped to a storage-based vector or tensor type.
+static SplatElementsAttr
+convertSplatElementsAttr(SplatElementsAttr realSplatAttr,
+                         QuantizedType quantizedElementType,
+                         const UniformQuantizedValueConverter &converter) {
+  // A splat carries exactly one primitive value, so the primitive conversion
+  // routine does all of the numeric work.
+  // NOTE: Per-channel quantization will require promoting the splat to a
+  // dense attribute and converting each channel separately.
+  Type ignoredStorageType;
+  Attribute quantizedValue = convertPrimitiveValueAttr(
+      realSplatAttr.getValue(), quantizedElementType, converter,
+      ignoredStorageType);
+  if (!quantizedValue)
+    return nullptr;
+
+  // Swap the expressed element type for the storage element type while
+  // keeping the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+  Type storageBasedType =
+      quantizedElementType.castExpressedToStorageType(realSplatAttr.getType());
+  auto splatStorageType =
+      storageBasedType.dyn_cast_or_null<VectorOrTensorType>();
+  if (!splatStorageType)
+    return nullptr;
+
+  return SplatElementsAttr::get(splatStorageType, quantizedValue);
+}
+
+/// Quantizes a real expressed SparseElementsAttr into a SparseElementsAttr
+/// that reuses the original indices and holds quantized storage values.
+/// quantizedElementType determines the storage-based result type and converter
+/// performs the real -> storage conversion of the stored values.
+/// Returns nullptr if the sparse values are not a DenseFPElementsAttr, if
+/// converting those values fails, or if the attribute's type cannot be mapped
+/// to a storage-based vector or tensor type.
+static SparseElementsAttr
+convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
+                          QuantizedType quantizedElementType,
+                          const UniformQuantizedValueConverter &converter) {
+  // Only floating point value payloads can be quantized.
+  DenseElementsAttr realValues = realSparseAttr.getValues();
+  if (!realValues.isa<DenseFPElementsAttr>())
+    return nullptr;
+
+  // The values are themselves a dense attribute, so delegate to the dense
+  // conversion routine.
+  DenseElementsAttr quantizedValues = convertDenseFPElementsAttr(
+      realValues.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
+  if (!quantizedValues)
+    return nullptr;
+
+  // Swap the expressed element type for the storage element type while
+  // keeping the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+  Type storageBasedType =
+      quantizedElementType.castExpressedToStorageType(realSparseAttr.getType());
+  auto sparseStorageType =
+      storageBasedType.dyn_cast_or_null<VectorOrTensorType>();
+  if (!sparseStorageType)
+    return nullptr;
+
+  return SparseElementsAttr::get(sparseStorageType,
+                                 realSparseAttr.getIndices(), quantizedValues);
+}
+
+/// Converts a real expressed Attribute to a corresponding Attribute containing
+/// quantized storage values assuming the given uniform quantizedElementType and
+/// converter.
+/// Splat, dense floating point and sparse element attributes are dispatched
+/// to their dedicated conversion routines; anything else is treated as a
+/// primitive value.
+/// Returns nullptr if the conversion is not possible. outConvertedType is
+/// only written when the conversion succeeds.
+Attribute quantizeAttrUniform(Attribute realValue,
+                              UniformQuantizedType quantizedElementType,
+                              const UniformQuantizedValueConverter &converter,
+                              Type &outConvertedType) {
+  // Fork to handle different variants of constants supported. Each container
+  // conversion routine may fail and return a null attribute, which must not
+  // be asked for its type, so every result is checked before it is used.
+  if (realValue.isa<SplatElementsAttr>()) {
+    // Splatted tensor or vector constant.
+    auto converted = convertSplatElementsAttr(
+        realValue.cast<SplatElementsAttr>(), quantizedElementType, converter);
+    if (!converted)
+      return nullptr;
+    outConvertedType = converted.getType();
+    return converted;
+  }
+
+  if (realValue.isa<DenseFPElementsAttr>()) {
+    // Dense tensor or vector constant.
+    auto converted = convertDenseFPElementsAttr(
+        realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
+    if (!converted)
+      return nullptr;
+    outConvertedType = converted.getType();
+    return converted;
+  }
+
+  if (realValue.isa<SparseElementsAttr>()) {
+    // Sparse tensor or vector constant.
+    auto converted = convertSparseElementsAttr(
+        realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
+    if (!converted)
+      return nullptr;
+    outConvertedType = converted.getType();
+    return converted;
+  }
+
+  // Nothing else matched: try to convert a primitive. The primitive routine
+  // sets outConvertedType itself on success and returns nullptr otherwise.
+  return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
+                                   outConvertedType);
+}
+
+/// Quantizes realValue, whose type is based on
+/// quantizedElementType.getExpressedType(), into an attribute whose type is
+/// based on quantizedElementType.getStorageType().
+/// On success the converted attribute is returned and its type is stored in
+/// outConvertedType. Returns nullptr if the conversion is not supported.
+Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType,
+                       Type &outConvertedType) {
+  // UniformQuantizedType is the only quantized type handled so far; this
+  // dispatch will need to grow when further quantized types are added.
+  if (auto uniformType =
+          quantizedElementType.dyn_cast<UniformQuantizedType>()) {
+    UniformQuantizedValueConverter uniformConverter(uniformType);
+    return quantizeAttrUniform(realValue, uniformType, uniformConverter,
+                               outConvertedType);
+  }
+  return nullptr;
+}
+
+} // namespace quant
+} // namespace mlir
OpenPOWER on IntegriCloud