summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/QuantOps/Utils
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/Utils')
-rw-r--r--mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp109
-rw-r--r--mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp182
-rw-r--r--mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp73
3 files changed, 364 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
new file mode 100644
index 00000000000..5562e45bb4a
--- /dev/null
+++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
@@ -0,0 +1,109 @@
+//===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"

#include <cassert>
#include <cmath>
+
+using namespace mlir;
+using namespace mlir::quant;
+
+UniformQuantizedType mlir::quant::fakeQuantAttrsToType(Location loc,
+ unsigned numBits,
+ double rmin, double rmax,
+ bool narrowRange,
+ Type expressedType) {
+ MLIRContext *ctx = expressedType.getContext();
+ Type storageType;
+ unsigned flags;
+ int64_t qmin;
+ int64_t qmax;
+
+ // Hard-coded type mapping from TFLite.
+ if (numBits <= 8) {
+ storageType = IntegerType::get(8, ctx);
+ flags = 0;
+ qmin = 0;
+ qmax = 255;
+ } else if (numBits <= 16) {
+ storageType = IntegerType::get(16, ctx);
+ flags = QuantizationFlags::Signed;
+ qmin = -32768;
+ qmax = 32767;
+ } else {
+ ctx->emitError(loc, "unsupported FakeQuant number of bits: ") << numBits;
+ return nullptr;
+ }
+
+ // Handle narrowRange.
+ if (narrowRange) {
+ qmin += 1;
+ }
+
+ // Range must straddle zero.
+ if (rmin > 0.0 || rmax < 0.0) {
+ return (ctx->emitError(loc, "FakeQuant range must straddle zero: [")
+ << rmin << "," << rmax << "]",
+ nullptr);
+ }
+
+ // Special case where min/max is a point. Must be 0.
+ if (rmin == rmax) {
+ return UniformQuantizedType::getChecked(flags, storageType, expressedType,
+ 0.0, 0, qmin, qmax, loc);
+ }
+
+ // Determine the scale.
+ const double qminDouble = qmin;
+ const double qmaxDouble = qmax;
+ const double scale = (rmax - rmin) / (qmaxDouble - qminDouble);
+
+ // Zero point computation.
+ // In float, solve the affine equation for any known pair
+ // (real value, corresponding quantized value), of which, two such pairs
+ // are known: (rmin, qmin), (rmax, qmax).
+ // The arithmetic error on the zero point computed from either pair will be
+ // roughly machine_epsilon * (sum of absolute values of terms).
+ // Use the variant that adds the smaller error.
+ const double zeroPointFromMin = qminDouble - rmin / scale;
+ const double zeroPointFromMinError =
+ std::abs(qminDouble) + std::abs(rmin / scale);
+ const double zeroPointFromMax = qmaxDouble - rmax / scale;
+ const double zeroPointFromMaxError =
+ std::abs(qmaxDouble) + std::abs(rmax / scale);
+
+ const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError)
+ ? zeroPointFromMin
+ : zeroPointFromMax;
+
+ // Now nudge the zero point to be an integer.
+ int64_t nudgedZeroPoint = 0;
+ if (zeroPointDouble < qminDouble) {
+ nudgedZeroPoint = qmin;
+ } else if (zeroPointDouble > qmaxDouble) {
+ nudgedZeroPoint = qmax;
+ } else {
+ nudgedZeroPoint = round(zeroPointDouble);
+ }
+
+ // By construction, the nudged zero point should always be in range.
+ assert(nudgedZeroPoint >= qmin);
+ assert(nudgedZeroPoint <= qmax);
+
+ return UniformQuantizedType::getChecked(flags, storageType, expressedType,
+ scale, nudgedZeroPoint, qmin, qmax,
+ loc);
+}
diff --git a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
new file mode 100644
index 00000000000..3685a65f2d8
--- /dev/null
+++ b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
@@ -0,0 +1,182 @@
+//===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/QuantizeUtils.h"
+#include "mlir/Dialect/QuantOps/UniformSupport.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace quant {
+/// Converts a possible primitive, real expressed value attribute to a
+/// corresponding storage attribute (typically FloatAttr -> IntegerAttr).
+/// quantizedElementType is the QuantizedType that describes the expressed
+/// origValue.
+/// Returns a converter Attribute or nullptr if conversion is not possible.
+static Attribute convertPrimitiveValueAttr(
+ Attribute origRealValue, QuantizedType quantizedElementType,
+ const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
+ if (origRealValue.isa<FloatAttr>()) {
+ FloatAttr floatAttr = origRealValue.cast<FloatAttr>();
+ outConvertedType = quantizedElementType.getStorageType();
+ return IntegerAttr::get(quantizedElementType.getStorageType(),
+ converter.quantizeFloatToInt(floatAttr.getValue()));
+ }
+
+ return nullptr;
+}
+
+/// Converts a real expressed DenseFPElementsAttr to a corresponding
+/// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized
+/// storage values assuming the given quantizedElementType and converter.
+static DenseElementsAttr
+convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
+ QuantizedType quantizedElementType,
+ const UniformQuantizedValueConverter &converter) {
+ // Convert to corresponding quantized value attributes.
+ SmallVector<APInt, 8> quantValues;
+ quantValues.reserve(realFPElementsAttr.size());
+ for (APFloat realVal : realFPElementsAttr) {
+ quantValues.push_back(converter.quantizeFloatToInt(realVal));
+ }
+
+ // Cast from an expressed-type-based type to storage-type-based type,
+ // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+ VectorOrTensorType newDenseType =
+ quantizedElementType
+ .castExpressedToStorageType(realFPElementsAttr.getType())
+ .dyn_cast_or_null<VectorOrTensorType>();
+ if (!newDenseType) {
+ return nullptr;
+ }
+ return DenseIntElementsAttr::get(newDenseType, quantValues);
+}
+
+/// Converts a real expressed SplatElementsAttr to a corresponding
+/// SplatElementsAttr containing quantized storage values assuming the given
+/// quantizedElementType and converter.
+static SplatElementsAttr
+convertSplatElementsAttr(SplatElementsAttr realSplatAttr,
+ QuantizedType quantizedElementType,
+ const UniformQuantizedValueConverter &converter) {
+ // Since the splat just references a single primitive value, use the
+ // function for converting primitives.
+ // NOTE: When implementing per-channel, we will need to promote the
+ // splat to a dense and handle channels individually.
+ Type unusedPrimitiveType;
+ auto elementAttr =
+ convertPrimitiveValueAttr(realSplatAttr.getValue(), quantizedElementType,
+ converter, unusedPrimitiveType);
+ if (!elementAttr) {
+ return nullptr;
+ }
+
+ // Cast from an expressed-type-based type to storage-type-based type,
+ // preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+ VectorOrTensorType newSplatType =
+ quantizedElementType.castExpressedToStorageType(realSplatAttr.getType())
+ .dyn_cast_or_null<VectorOrTensorType>();
+ if (!newSplatType) {
+ return nullptr;
+ }
+ return SplatElementsAttr::get(newSplatType, elementAttr);
+}
+
+/// Converts a real expressed SplatElementsAttr to a corresponding
+/// SplatElementsAttr containing quantized storage values assuming the given
+/// quantizedElementType and converter.
+static SparseElementsAttr
+convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
+ QuantizedType quantizedElementType,
+ const UniformQuantizedValueConverter &converter) {
+ DenseElementsAttr realDenseAttr = realSparseAttr.getValues();
+ if (!realDenseAttr.isa<DenseFPElementsAttr>()) {
+ return nullptr;
+ }
+ DenseElementsAttr quantDenseAttr =
+ convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(),
+ quantizedElementType, converter);
+ if (!quantDenseAttr) {
+ return nullptr;
+ }
+
+ // Cast from an expressed-type-based type to storage-type-based type,
+ // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+ VectorOrTensorType newSparseType =
+ quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
+ .dyn_cast_or_null<VectorOrTensorType>();
+ if (!newSparseType) {
+ return nullptr;
+ }
+ return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
+ quantDenseAttr);
+}
+
+/// Converts a real expressed Attribute to a corresponding Attribute containing
+/// quantized storage values assuming the given uniform quantizedElementType and
+/// converter.
+Attribute quantizeAttrUniform(Attribute realValue,
+ UniformQuantizedType quantizedElementType,
+ const UniformQuantizedValueConverter &converter,
+ Type &outConvertedType) {
+ // Fork to handle different variants of constants supported.
+ if (realValue.isa<SplatElementsAttr>()) {
+ // Splatted tensor or vector constant.
+ auto converted = convertSplatElementsAttr(
+ realValue.cast<SplatElementsAttr>(), quantizedElementType, converter);
+ outConvertedType = converted.getType();
+ return converted;
+ } else if (realValue.isa<DenseFPElementsAttr>()) {
+ // Dense tensor or vector constant.
+ auto converted = convertDenseFPElementsAttr(
+ realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
+ outConvertedType = converted.getType();
+ return converted;
+ } else if (realValue.isa<SparseElementsAttr>()) {
+ // Sparse tensor or vector constant.
+ auto converted = convertSparseElementsAttr(
+ realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
+ outConvertedType = converted.getType();
+ return converted;
+ } else {
+ // Nothing else matched: try to convert a primitive.
+ return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
+ outConvertedType);
+ }
+}
+
+/// Convert an attribute from a type based on
+/// quantizedElementType.getExpressedType() to one based on
+/// quantizedElementType.getStorageType().
+/// Returns nullptr if the conversion is not supported.
+/// On success, stores the converted type in outConvertedType.
+Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType,
+ Type &outConvertedType) {
+ // Hard-coded to just support UniformQuantizedType. This will need to
+ // be generalized when there is more than one.
+ auto uniformQuantizedType =
+ quantizedElementType.dyn_cast<UniformQuantizedType>();
+ if (!uniformQuantizedType) {
+ return nullptr;
+ }
+ UniformQuantizedValueConverter converter(uniformQuantizedType);
+ return quantizeAttrUniform(realValue, uniformQuantizedType, converter,
+ outConvertedType);
+}
+
+} // namespace quant
+} // namespace mlir
diff --git a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
new file mode 100644
index 00000000000..d791075f5db
--- /dev/null
+++ b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
@@ -0,0 +1,73 @@
+//===- UniformSupport.cpp - Support utilities for uniform quant -----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/UniformSupport.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+static bool isQuantizablePrimitiveType(Type inputType) {
+ return inputType.isa<FloatType>();
+}
+
+const ExpressedToUniformQuantizedConverter
+ExpressedToUniformQuantizedConverter::forInputType(Type inputType) {
+ switch (inputType.getKind()) {
+ default:
+ if (isQuantizablePrimitiveType(inputType)) {
+ // Supported primitive type (which just is the expressed type).
+ return ExpressedToUniformQuantizedConverter{inputType, inputType};
+ }
+ // Unsupported.
+ return ExpressedToUniformQuantizedConverter{inputType, nullptr};
+ case StandardTypes::RankedTensor:
+ case StandardTypes::UnrankedTensor:
+ case StandardTypes::Vector: {
+ Type elementType = inputType.cast<VectorOrTensorType>().getElementType();
+ if (!isQuantizablePrimitiveType(elementType)) {
+ // Unsupported.
+ return ExpressedToUniformQuantizedConverter{inputType, nullptr};
+ }
+ return ExpressedToUniformQuantizedConverter{
+ inputType, inputType.cast<VectorOrTensorType>().getElementType()};
+ }
+ }
+}
+
+Type ExpressedToUniformQuantizedConverter::convert(
+ UniformQuantizedType elementalType) const {
+ assert(expressedType && "convert() on unsupported conversion");
+
+ switch (inputType.getKind()) {
+ default:
+ if (isQuantizablePrimitiveType(elementalType)) {
+ // For primitives, just use the new elemental type.
+ return elementalType;
+ }
+ // Unsupported.
+ return nullptr;
+ case StandardTypes::RankedTensor:
+ return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(),
+ elementalType);
+ case StandardTypes::UnrankedTensor:
+ return UnrankedTensorType::get(elementalType);
+ case StandardTypes::Vector:
+ return VectorType::get(inputType.cast<VectorType>().getShape(),
+ elementalType);
+ }
+}
OpenPOWER on IntegriCloud