Diffstat (limited to 'mlir/lib/Dialect/FxpMathOps')
-rw-r--r--  mlir/lib/Dialect/FxpMathOps/CMakeLists.txt                      |  15
-rw-r--r--  mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp          |  24
-rw-r--r--  mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp                   |  38
-rw-r--r--  mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp | 410
-rw-r--r--  mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h     | 233
5 files changed, 720 insertions(+), 0 deletions(-)
diff --git a/mlir/lib/Dialect/FxpMathOps/CMakeLists.txt b/mlir/lib/Dialect/FxpMathOps/CMakeLists.txt
new file mode 100644
index 00000000000..9eddc5545f5
--- /dev/null
+++ b/mlir/lib/Dialect/FxpMathOps/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_llvm_library(MLIRFxpMathOps
+ IR/FxpMathOps.cpp
+ IR/DialectRegistration.cpp
+ Transforms/LowerUniformRealMath.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/FxpMathOps
+ )
+add_dependencies(MLIRFxpMathOps
+ MLIRFxpMathOpsIncGen
+ MLIRQuantOps
+ MLIRIR
+ MLIRPass
+ MLIRSupport
+ MLIRStandardOps)
diff --git a/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp b/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp
new file mode 100644
index 00000000000..aa6782e1464
--- /dev/null
+++ b/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp
@@ -0,0 +1,24 @@
+//===- DialectRegistration.cpp - Register FxpMathOps dialect --------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
+
+using namespace mlir;
+using namespace mlir::fxpmath;
+
+// Static initialization for the fxpmath ops dialect registration.
+static mlir::DialectRegistration<FxpMathOpsDialect> FxpMathOps;
diff --git a/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp b/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp
new file mode 100644
index 00000000000..18c07b07117
--- /dev/null
+++ b/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp
@@ -0,0 +1,38 @@
+//===- FxpMathOps.cpp - Op implementation for FxpMathOps ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::fxpmath;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.cpp.inc"
+
+FxpMathOpsDialect::FxpMathOpsDialect(MLIRContext *context)
+ : Dialect(/*name=*/"fxpmath", context) {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.cpp.inc"
+ >();
+}
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
new file mode 100644
index 00000000000..32d8de3c25d
--- /dev/null
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -0,0 +1,410 @@
+//===- LowerUniformRealMath.cpp ------------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "UniformKernelUtils.h"
+
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
+#include "mlir/Dialect/FxpMathOps/Passes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::fxpmath;
+using namespace mlir::fxpmath::detail;
+using namespace mlir::quant;
+
+namespace {
+
+struct LowerUniformRealMathPass
+ : public FunctionPass<LowerUniformRealMathPass> {
+ void runOnFunction() override;
+};
+
+struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
+ void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Dequantize
+//===----------------------------------------------------------------------===//
+
+static Value *emitUniformPerLayerDequantize(Location loc, Value *input,
+ UniformQuantizedType elementType,
+ PatternRewriter &rewriter) {
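+  // Dequantization reconstructs the real value from affine-quantized
+  // storage: real = scale * (storage - zeroPoint). For example, with
+  // scale = 0.5 and zeroPoint = 10, a stored value of 14 dequantizes to
+  // 0.5 * (14 - 10) = 2.0.
+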
+ // Pre-conditions.
+ if (!elementType.isSigned()) {
+ // TODO: Support unsigned storage type.
+ rewriter.getContext()->emitWarning(
+        loc, "unimplemented: dequantize unsigned uniform");
+ return nullptr;
+ }
+
+  Type storageType = elementType.castToStorageType(input->getType());
+  Type realType = elementType.castToExpressedType(input->getType());
+  assert(storageType && "cannot cast to storage type");
+  assert(realType && "cannot cast to expressed type");
+  Type intermediateType =
+      castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
+
+ // Cast to storage type.
+ input = rewriter.create<StorageCastOp>(loc, storageType, input);
+
+ // Promote to intermediate type.
+ input = rewriter.create<ConvertISOp>(loc, intermediateType, input);
+
+ // Apply zero-point offset.
+ if (elementType.getZeroPoint() != 0) {
+ Value *negZeroPointConst = rewriter.create<ConstantOp>(
+ loc, broadcastScalarConstIntValue(intermediateType,
+ -elementType.getZeroPoint()));
+ input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
+ }
+
+ // Convert to float.
+ input = rewriter.create<ConvertISToFOp>(loc, realType, input);
+
+ // Mul by scale.
+ Value *scaleConst = rewriter.create<ConstantOp>(
+ loc, broadcastScalarConstFloatValue(realType,
+ APFloat(elementType.getScale())));
+ return rewriter.create<MulFOp>(loc, input, scaleConst);
+}
+
+static Value *
+emitUniformPerAxisDequantize(Location loc, Value *input,
+ UniformQuantizedPerAxisType elementType,
+ PatternRewriter &rewriter) {
+ // TODO: Support per-axis dequantize.
+ rewriter.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Warning)
+ << "unimplemented: per-axis uniform dequantization";
+ return nullptr;
+}
+
+static Value *emitDequantize(Location loc, Value *input,
+ PatternRewriter &rewriter) {
+ Type inputType = input->getType();
+ QuantizedType qElementType =
+ QuantizedType::getQuantizedElementType(inputType);
+  if (auto perLayerElementType =
+          qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
+    return emitUniformPerLayerDequantize(loc, input, perLayerElementType,
+                                         rewriter);
+  } else if (auto perAxisElementType =
+                 qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
+    return emitUniformPerAxisDequantize(loc, input, perAxisElementType,
+                                        rewriter);
+ } else {
+ return nullptr;
+ }
+}
+
+namespace {
+
+struct UniformDequantizePattern : public RewritePattern {
+ UniformDequantizePattern(MLIRContext *context)
+ : RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+ auto dcastOp = cast<DequantizeCastOp>(op);
+ Type inputType = dcastOp.arg()->getType();
+ Type outputType = dcastOp.getResult()->getType();
+
+ QuantizedType inputElementType =
+ QuantizedType::getQuantizedElementType(inputType);
+ Type expressedOutputType = inputElementType.castToExpressedType(inputType);
+ if (expressedOutputType != outputType) {
+ // Not a valid uniform cast.
+ return matchFailure();
+ }
+
+ Value *dequantizedValue =
+ emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter);
+ if (!dequantizedValue) {
+ return matchFailure();
+ }
+
+ rewriter.replaceOp(op, dequantizedValue);
+ return matchSuccess();
+ }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Elementwise add
+//===----------------------------------------------------------------------===//
+
+static LogicalResult
+tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
+ PatternRewriter &rewriter) {
+ if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
+ info.rhsType != info.resultType) {
+ return failure();
+ }
+
+ // Choose a byte aligned intermediate width big enough to perform the
+ // calculation without overflow.
+ // TODO: This should probably be made just big enough to avoid overflow and
+ // leave the downstream tooling to decide how to align that to machine
+ // word sizes.
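+  // For example, adding two i8 values in [-128, 127] plus a zero-point
+  // offset needs at most 10 signed bits, so a 16-bit intermediate always
+  // suffices for 8-bit operands.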
+ unsigned intermediateWidth =
+ info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
+ IntegerType intermediateElementType =
+ IntegerType::get(intermediateWidth, rewriter.getContext());
+ Type intermediateType =
+ castElementType(info.resultStorageType, intermediateElementType);
+
+ // Cast operands to storage type.
+ Value *lhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.lhsStorageType, info.lhs)
+ .getResult();
+ Value *rhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.rhsStorageType, info.rhs)
+ .getResult();
+
+ // Cast to the intermediate sized type.
+ lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+ lhsValue);
+ rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+ rhsValue);
+
+ // Add.
+ Value *resultValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);
+
+ // Zero point offset adjustment.
+ // result = (lhs - zp) + (rhs - zp) + zp
+ // zpOffset = -zp
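+  // Derivation, given that lhs, rhs and result share scale S and zero
+  // point zp (checked above):
+  //   S * (q_out - zp) = S * (q_lhs - zp) + S * (q_rhs - zp)
+  //   q_out = q_lhs + q_rhs - zp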
+ int zpOffset = -1 * info.resultType.getZeroPoint();
+ if (zpOffset != 0) {
+ Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ info.op->getLoc(),
+ broadcastScalarConstIntValue(intermediateType, zpOffset));
+ resultValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
+ }
+
+ // Clamp.
+ auto clampMinMax = info.getClampMinMax(intermediateElementType);
+ resultValue = rewriter.create<ClampISOp>(
+ info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
+
+ // Convert back to original type.
+ resultValue = rewriter.create<ConvertISOp>(
+ info.op->getLoc(), info.resultStorageType, resultValue);
+
+ // Cast back for new result.
+ rewriter.replaceOpWithNewOp<StorageCastOp>(
+ info.op, info.getQuantizedResultType(), resultValue);
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Elementwise mul
+//===----------------------------------------------------------------------===//
+
+static LogicalResult
+tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
+ PatternRewriter &rewriter) {
+ if (!info.resultType.isSigned()) {
+ return failure();
+ }
+
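+  // With r = S * (q - zp), the real product r_out = r_lhs * r_rhs becomes:
+  //   q_out - zp_out =
+  //       (S_lhs * S_rhs / S_out) * (q_lhs - zp_lhs) * (q_rhs - zp_rhs)
+  // so the integer product must be rescaled by S_lhs * S_rhs / S_out.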
+ double outputMultiplierReal = info.lhsType.getScale() *
+ info.rhsType.getScale() /
+ info.resultType.getScale();
+ if (outputMultiplierReal > 1.0) {
+    info.op->emitWarning("unimplemented: cannot multiply with multiplier > 1.0");
+ return failure();
+ }
+
+ // TODO: Choose an appropriate intermediate width for muls > 8 bits to
+ // avoid overflow.
+ unsigned intermediateWidth = 32;
+ IntegerType intermediateElementType =
+ IntegerType::get(intermediateWidth, rewriter.getContext());
+ Type intermediateType =
+ castElementType(info.resultStorageType, intermediateElementType);
+
+ // Cast operands to storage type.
+ Value *lhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.lhsStorageType, info.lhs)
+ .getResult();
+ Value *rhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.rhsStorageType, info.rhs)
+ .getResult();
+
+ // Cast to the intermediate sized type.
+ lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+ lhsValue);
+ rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+ rhsValue);
+
+ // Apply argument zeroPoints.
+ if (info.lhsType.getZeroPoint() != 0) {
+ Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ info.op->getLoc(), broadcastScalarConstIntValue(
+ intermediateType, -info.lhsType.getZeroPoint()));
+ lhsValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
+ }
+
+ if (info.rhsType.getZeroPoint() != 0) {
+ Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ info.op->getLoc(), broadcastScalarConstIntValue(
+ intermediateType, -info.rhsType.getZeroPoint()));
+ rhsValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
+ }
+
+ // Mul.
+ Value *resultValue =
+ rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);
+
+ // Scale output.
+ QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
+ resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
+ info.op->getLoc(), resultValue,
+ IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
+ resultValue = rewriter.create<RoundingDivideByPotISOp>(
+ info.op->getLoc(), resultValue,
+ IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));
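+  // Together these approximate resultValue * outputMultiplierReal: the
+  // saturating rounding doubling high-mul yields
+  // (resultValue * multiplier) / 2^31 with rounding, and the rounding
+  // divide applies the residual 2^exponent factor (exponent <= 0 for
+  // multipliers < 1).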
+
+ // Zero point offset adjustment.
+ if (info.resultType.getZeroPoint() != 0) {
+ Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ info.op->getLoc(),
+ broadcastScalarConstIntValue(intermediateType,
+ info.resultType.getZeroPoint()));
+ resultValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
+ }
+
+ // Clamp.
+ auto clampMinMax = info.getClampMinMax(intermediateElementType);
+ resultValue = rewriter.create<ClampISOp>(
+ info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
+
+ // Convert back to original type.
+ resultValue = rewriter.create<ConvertISOp>(
+ info.op->getLoc(), info.resultStorageType, resultValue);
+
+ // Cast back for new result.
+ rewriter.replaceOpWithNewOp<StorageCastOp>(
+ info.op, info.getQuantizedResultType(), resultValue);
+
+ return success();
+}
+
+namespace {
+
+struct UniformRealAddEwPattern : public RewritePattern {
+ UniformRealAddEwPattern(MLIRContext *context)
+ : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+ auto addOp = cast<RealAddEwOp>(op);
+ const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
+ addOp.clamp_min(), addOp.clamp_max());
+ if (!info.isValid()) {
+ return matchFailure();
+ }
+
+ // Try all of the permutations we support.
+ if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
+ return matchSuccess();
+ }
+
+ return matchFailure();
+ }
+};
+
+struct UniformRealMulEwPattern : public RewritePattern {
+ UniformRealMulEwPattern(MLIRContext *context)
+ : RewritePattern(RealMulEwOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+ auto mulOp = cast<RealMulEwOp>(op);
+ const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
+ mulOp.clamp_min(), mulOp.clamp_max());
+ if (!info.isValid()) {
+ return matchFailure();
+ }
+
+ // Try all of the permutations we support.
+ if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
+ return matchSuccess();
+ }
+
+ return matchFailure();
+ }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// LowerUniformRealMath pass
+//===----------------------------------------------------------------------===//
+
+void LowerUniformRealMathPass::runOnFunction() {
+ auto &fn = getFunction();
+ OwningRewritePatternList patterns;
+ auto *context = &getContext();
+ patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
+ patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
+ applyPatternsGreedily(fn, std::move(patterns));
+}
+
+FunctionPassBase *mlir::fxpmath::createLowerUniformRealMathPass() {
+ return new LowerUniformRealMathPass();
+}
+
+static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
+ "fxpmath-lower-uniform-real-math",
+ "Lowers uniform-quantized real math ops to integer arithmetic.");
+
+//===----------------------------------------------------------------------===//
+// LowerUniformCasts pass
+//===----------------------------------------------------------------------===//
+
+void LowerUniformCastsPass::runOnFunction() {
+ auto &fn = getFunction();
+ OwningRewritePatternList patterns;
+ auto *context = &getContext();
+ patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
+ applyPatternsGreedily(fn, std::move(patterns));
+}
+
+FunctionPassBase *mlir::fxpmath::createLowerUniformCastsPass() {
+ return new LowerUniformCastsPass();
+}
+
+static PassRegistration<LowerUniformCastsPass>
+ lowerUniformCastsPass("fxpmath-lower-uniform-casts",
+ "Lowers uniform-quantized casts.");
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
new file mode 100644
index 00000000000..325b0ca93ca
--- /dev/null
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
@@ -0,0 +1,233 @@
+//===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
+#define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/Dialect/QuantOps/UniformSupport.h"
+#include "mlir/IR/Operation.h"
+
+#include <cmath>
+
+namespace mlir {
+namespace fxpmath {
+namespace detail {
+
+inline quant::UniformQuantizedType getUniformElementType(Type t) {
+ return quant::QuantizedType::getQuantizedElementType(t)
+ .dyn_cast_or_null<quant::UniformQuantizedType>();
+}
+
+inline bool hasStorageBitWidth(quant::QuantizedType t,
+ llvm::ArrayRef<unsigned> checkWidths) {
+ unsigned w = t.getStorageType().getIntOrFloatBitWidth();
+ for (unsigned checkWidth : checkWidths) {
+ if (w == checkWidth)
+ return true;
+ }
+ return false;
+}
+
+/// Computes log2(x), rounded to an integral value. Returns whether 'x' can
+/// be considered an exact integral power of two.
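+/// For example, integralLog2(0.25) sets log2Result to -2 and returns true,
+/// while integralLog2(0.3) returns false.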
+template <typename F> bool integralLog2(F x, int &log2Result) {
+ const F xLog2 = std::log(x) * (1.0 / std::log(2.0));
+ const F xLog2Rounded = std::round(xLog2);
+ const F xLog2Frac = xLog2 - xLog2Rounded;
+ log2Result = static_cast<int>(xLog2Rounded);
+ // Allow small comparison slop below the level that would make a difference
+ // for 2^16 levels.
+ return std::abs(xLog2Frac) < 1e-6;
+}
+
+/// Helper class for operating on binary operations where all operands
+/// and the result are a UniformQuantizedType.
+struct UniformBinaryOpInfo {
+ UniformBinaryOpInfo(Operation *op, Value *lhs, Value *rhs,
+ Optional<APFloat> clampMin, Optional<APFloat> clampMax)
+ : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
+ lhsType(getUniformElementType(lhs->getType())),
+ rhsType(getUniformElementType(rhs->getType())),
+ resultType(getUniformElementType(*op->result_type_begin())),
+ lhsStorageType(quant::QuantizedType::castToStorageType(lhs->getType())),
+ rhsStorageType(quant::QuantizedType::castToStorageType(rhs->getType())),
+ resultStorageType(
+ quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
+ }
+
+ /// Returns whether this info is valid (all types defined, etc).
+ bool isValid() const {
+ return lhsType && rhsType && resultType && lhsStorageType &&
+ rhsStorageType && resultStorageType;
+ }
+
+ /// Gets the final quantized result type of the result.
+ Type getQuantizedResultType() const { return *op->result_type_begin(); }
+
+ /// Returns whether the storage type of all operands is identical.
+ bool isSameStorageType() const {
+ return lhsType.getStorageType() == rhsType.getStorageType() &&
+ lhsType.getStorageType() == resultType.getStorageType();
+ }
+
+ /// Returns whether all operands and result are considered fixedpoint power
+ /// of two, setting the lhs, rhs, and result log2 scale references.
+ bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
+ int &resultLog2Scale) const {
+ if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
+ !resultType.isFixedPoint()) {
+ return false;
+ }
+
+ if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
+ !integralLog2(rhsType.getScale(), rhsLog2Scale) ||
+ !integralLog2(resultType.getScale(), resultLog2Scale)) {
+ return false;
+ }
+
+ return true;
+ }
+
+  /// Gets the result integer clamp range given the result quantized type
+  /// and any explicit clamp provided as attributes.
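+  /// For example, for an i8 result (storage range [-128, 127]) with
+  /// scale 0.5, zero point 0 and an explicit clamp_max of 6.0, the max
+  /// quantizes to round(6.0 / 0.5) = 12, narrowing the range to [-128, 12].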
+ std::pair<IntegerAttr, IntegerAttr> getClampMinMax(IntegerType ty) const {
+ int64_t typeMin = resultType.getStorageTypeMin();
+ int64_t typeMax = resultType.getStorageTypeMax();
+
+ if (clampMin || clampMax) {
+ quant::UniformQuantizedValueConverter conv(resultType);
+ if (clampMin) {
+ typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
+ }
+ if (clampMax) {
+ typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
+ }
+ }
+
+ // The quantized, integral ops expect clamps as 32bit ints.
+ return {
+ IntegerAttr::get(ty, typeMin),
+ IntegerAttr::get(ty, typeMax),
+ };
+ }
+
+ Operation *op;
+ Value *lhs;
+ Value *rhs;
+ Optional<APFloat> clampMin;
+ Optional<APFloat> clampMax;
+
+ // Element UniformQuantizedType for operands/result.
+ quant::UniformQuantizedType lhsType;
+ quant::UniformQuantizedType rhsType;
+ quant::UniformQuantizedType resultType;
+
+ // Full storage-based types.
+ Type lhsStorageType;
+ Type rhsStorageType;
+ Type resultStorageType;
+};
+
+/// Derives a quantized multiplier and shift from a real valued multiplier
+/// less than 1.
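+///
+/// Decomposes realMultiplier as (multiplier / 2^31) * 2^exponent, where
+/// multiplier is a 32-bit fixed-point value in [2^30, 2^31) and
+/// exponent <= 0. For example, 0.375 = 0.75 * 2^-1 yields
+/// multiplier = round(0.75 * 2^31) and exponent = -1.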
+struct QuantizedMultiplierSmallerThanOneExp {
+ QuantizedMultiplierSmallerThanOneExp(double realMultiplier) {
+ assert(realMultiplier < 1.0);
+ assert(realMultiplier > 0.0);
+
+ const double q = std::frexp(realMultiplier, &exponent);
+ auto qFixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
+ assert(qFixed <= (1ll << 31));
+ if (qFixed == (1ll << 31)) {
+ qFixed /= 2;
+ ++exponent;
+ }
+ assert(qFixed <= std::numeric_limits<int32_t>::max());
+ multiplier = static_cast<int32_t>(qFixed);
+ }
+
+ int32_t multiplier;
+ int exponent;
+};
+
+/// Casts an integer or floating point based type to a new element type.
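+/// For example, a tensor<4xi8> with an i32 element type becomes
+/// tensor<4xi32>, while a scalar i8 simply becomes i32.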
+inline Type castElementType(Type t, Type newElementType) {
+ if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
+ switch (vt.getKind()) {
+ case StandardTypes::Kind::Vector:
+ return VectorType::get(vt.getShape(), newElementType);
+ case StandardTypes::Kind::RankedTensor:
+ return RankedTensorType::get(vt.getShape(), newElementType);
+ case StandardTypes::Kind::UnrankedTensor:
+ return UnrankedTensorType::get(newElementType);
+ }
+ }
+ assert(t.isIntOrFloat());
+ return newElementType;
+}
+
+/// Creates an IntegerAttr with a type that matches the shape of 't' (which can
+/// be a primitive/vector/tensor).
+inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
+ if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
+ assert(vt.getElementType().isa<IntegerType>());
+ return SplatElementsAttr::get(vt,
+ IntegerAttr::get(vt.getElementType(), value));
+ }
+
+  auto integerType = t.dyn_cast<IntegerType>();
+  assert(integerType && "integer broadcast must be of integer type");
+ return IntegerAttr::get(integerType, value);
+}
+
+/// Given an APFloat, converts it to the float semantics that matches the
+/// given FloatType, silently ignoring inexact conversions.
+inline APFloat convertFloatToType(FloatType ft, APFloat value) {
+ bool losesInfo;
+ auto status = value.convert(ft.getFloatSemantics(),
+ APFloat::rmNearestTiesToEven, &losesInfo);
+ (void)status; // unused in opt mode
+ assert((status & (APFloat::opDivByZero | APFloat::opInvalidOp)) == 0 &&
+ "could not convert to float const");
+ return value;
+}
+
+/// Creates a FloatAttr with a type that matches the shape of 't' (which can
+/// be a primitive/vector/tensor).
+inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
+ if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
+ FloatType floatElementType = vt.getElementType().dyn_cast<FloatType>();
+ assert(floatElementType &&
+ "float broadcast element type must be float like");
+ APFloat apValue = convertFloatToType(floatElementType, value);
+ return SplatElementsAttr::get(vt,
+ FloatAttr::get(vt.getElementType(), apValue));
+ } else {
+ auto floatType = t.dyn_cast<FloatType>();
+ assert(floatType && "float broadcast must be of float type");
+ APFloat apValue = convertFloatToType(floatType, value);
+ return FloatAttr::get(floatType, apValue);
+ }
+}
+
+} // namespace detail
+} // namespace fxpmath
+} // namespace mlir
+
+#endif // MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_