Diffstat (limited to 'mlir/lib/Dialect/FxpMathOps')
-rw-r--r--  mlir/lib/Dialect/FxpMathOps/CMakeLists.txt                        15
-rw-r--r--  mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp            24
-rw-r--r--  mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp                     38
-rw-r--r--  mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp  410
-rw-r--r--  mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h      233
5 files changed, 720 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/FxpMathOps/CMakeLists.txt b/mlir/lib/Dialect/FxpMathOps/CMakeLists.txt
new file mode 100644
index 00000000000..9eddc5545f5
--- /dev/null
+++ b/mlir/lib/Dialect/FxpMathOps/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_llvm_library(MLIRFxpMathOps
+  IR/FxpMathOps.cpp
+  IR/DialectRegistration.cpp
+  Transforms/LowerUniformRealMath.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/FxpMathOps
+  )
+add_dependencies(MLIRFxpMathOps
+                 MLIRFxpMathOpsIncGen
+                 MLIRQuantOps
+                 MLIRIR
+                 MLIRPass
+                 MLIRSupport
+                 MLIRStandardOps)
diff --git a/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp b/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp
new file mode 100644
index 00000000000..aa6782e1464
--- /dev/null
+++ b/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp
@@ -0,0 +1,24 @@
+//===- DialectRegistration.cpp - Register FxpMathOps dialect --------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
+
+using namespace mlir;
+using namespace mlir::fxpmath;
+
+// Static initialization for the fxpmath ops dialect registration.
+static mlir::DialectRegistration<FxpMathOpsDialect> FxpMathOps;
diff --git a/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp b/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp
new file mode 100644
index 00000000000..18c07b07117
--- /dev/null
+++ b/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp
@@ -0,0 +1,38 @@
+//===- FxpMathOps.cpp - Op implementation for FxpMathOps ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::fxpmath;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.cpp.inc"
+
+FxpMathOpsDialect::FxpMathOpsDialect(MLIRContext *context)
+    : Dialect(/*name=*/"fxpmath", context) {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.cpp.inc"
+      >();
+}
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
new file mode 100644
index 00000000000..32d8de3c25d
--- /dev/null
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -0,0 +1,410 @@
+//===- LowerUniformRealMath.cpp ------------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "UniformKernelUtils.h"
+
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
+#include "mlir/Dialect/FxpMathOps/Passes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::fxpmath;
+using namespace mlir::fxpmath::detail;
+using namespace mlir::quant;
+
+namespace {
+
+struct LowerUniformRealMathPass
+    : public FunctionPass<LowerUniformRealMathPass> {
+  void runOnFunction() override;
+};
+
+struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
+  void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Dequantize
+//===----------------------------------------------------------------------===//
+
+static Value *emitUniformPerLayerDequantize(Location loc, Value *input,
+                                            UniformQuantizedType elementType,
+                                            PatternRewriter &rewriter) {
+  // Pre-conditions.
+  if (!elementType.isSigned()) {
+    // TODO: Support unsigned storage type.
+    rewriter.getContext()->emitWarning(
+        loc, "unimplemented: dequantize unsigned uniform");
+    return nullptr;
+  }
+
+  Type storageType = elementType.castToStorageType(input->getType());
+  Type realType = elementType.castToExpressedType(input->getType());
+  Type intermediateType =
+      castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
+  assert(storageType && "cannot cast to storage type");
+  assert(realType && "cannot cast to expressed type");
+
+  // Cast to storage type.
+  input = rewriter.create<StorageCastOp>(loc, storageType, input);
+
+  // Promote to intermediate type.
+  input = rewriter.create<ConvertISOp>(loc, intermediateType, input);
+
+  // Apply zero-point offset.
+  if (elementType.getZeroPoint() != 0) {
+    Value *negZeroPointConst = rewriter.create<ConstantOp>(
+        loc, broadcastScalarConstIntValue(intermediateType,
+                                          -elementType.getZeroPoint()));
+    input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
+  }
+
+  // Convert to float.
+  input = rewriter.create<ConvertISToFOp>(loc, realType, input);
+
+  // Mul by scale.
+  Value *scaleConst = rewriter.create<ConstantOp>(
+      loc, broadcastScalarConstFloatValue(realType,
+                                          APFloat(elementType.getScale())));
+  return rewriter.create<MulFOp>(loc, input, scaleConst);
+}
+
+static Value *
+emitUniformPerAxisDequantize(Location loc, Value *input,
+                             UniformQuantizedPerAxisType elementType,
+                             PatternRewriter &rewriter) {
+  // TODO: Support per-axis dequantize.
+  rewriter.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Warning)
+      << "unimplemented: per-axis uniform dequantization";
+  return nullptr;
+}
+
+static Value *emitDequantize(Location loc, Value *input,
+                             PatternRewriter &rewriter) {
+  Type inputType = input->getType();
+  QuantizedType qElementType =
+      QuantizedType::getQuantizedElementType(inputType);
+  if (auto uperLayerElementType =
+          qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
+    return emitUniformPerLayerDequantize(loc, input, uperLayerElementType,
+                                         rewriter);
+  } else if (auto uperAxisElementType =
+                 qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
+    return emitUniformPerAxisDequantize(loc, input, uperAxisElementType,
+                                        rewriter);
+  } else {
+    return nullptr;
+  }
+}
+
+namespace {
+
+struct UniformDequantizePattern : public RewritePattern {
+  UniformDequantizePattern(MLIRContext *context)
+      : RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const {
+    auto dcastOp = cast<DequantizeCastOp>(op);
+    Type inputType = dcastOp.arg()->getType();
+    Type outputType = dcastOp.getResult()->getType();
+
+    QuantizedType inputElementType =
+        QuantizedType::getQuantizedElementType(inputType);
+    Type expressedOutputType = inputElementType.castToExpressedType(inputType);
+    if (expressedOutputType != outputType) {
+      // Not a valid uniform cast.
+      return matchFailure();
+    }
+
+    Value *dequantizedValue =
+        emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter);
+    if (!dequantizedValue) {
+      return matchFailure();
+    }
+
+    rewriter.replaceOp(op, dequantizedValue);
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Elementwise add
+//===----------------------------------------------------------------------===//
+
+static LogicalResult
+tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
+                                      PatternRewriter &rewriter) {
+  if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
+      info.rhsType != info.resultType) {
+    return failure();
+  }
+
+  // Choose a byte aligned intermediate width big enough to perform the
+  // calculation without overflow.
+  // TODO: This should probably be made just big enough to avoid overflow and
+  // leave the downstream tooling to decide how to align that to machine
+  // word sizes.
+  unsigned intermediateWidth =
+      info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
+  IntegerType intermediateElementType =
+      IntegerType::get(intermediateWidth, rewriter.getContext());
+  Type intermediateType =
+      castElementType(info.resultStorageType, intermediateElementType);
+
+  // Cast operands to storage type.
+  Value *lhsValue = rewriter
+                        .create<StorageCastOp>(info.op->getLoc(),
+                                               info.lhsStorageType, info.lhs)
+                        .getResult();
+  Value *rhsValue = rewriter
+                        .create<StorageCastOp>(info.op->getLoc(),
+                                               info.rhsStorageType, info.rhs)
+                        .getResult();
+
+  // Cast to the intermediate sized type.
+  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+                                          lhsValue);
+  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+                                          rhsValue);
+
+  // Add.
+  Value *resultValue =
+      rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);
+
+  // Zero point offset adjustment.
+  // result = (lhs - zp) + (rhs - zp) + zp
+  // zpOffset = -zp
+  int zpOffset = -1 * info.resultType.getZeroPoint();
+  if (zpOffset != 0) {
+    Value *zpOffsetConst = rewriter.create<ConstantOp>(
+        info.op->getLoc(),
+        broadcastScalarConstIntValue(intermediateType, zpOffset));
+    resultValue =
+        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
+  }
+
+  // Clamp.
+  auto clampMinMax = info.getClampMinMax(intermediateElementType);
+  resultValue = rewriter.create<ClampISOp>(
+      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
+
+  // Convert back to original type.
+  resultValue = rewriter.create<ConvertISOp>(
+      info.op->getLoc(), info.resultStorageType, resultValue);
+
+  // Cast back for new result.
+  rewriter.replaceOpWithNewOp<StorageCastOp>(
+      info.op, info.getQuantizedResultType(), resultValue);
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Elementwise mul
+//===----------------------------------------------------------------------===//
+
+static LogicalResult
+tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
+                            PatternRewriter &rewriter) {
+  if (!info.resultType.isSigned()) {
+    return failure();
+  }
+
+  double outputMultiplierReal = info.lhsType.getScale() *
+                                info.rhsType.getScale() /
+                                info.resultType.getScale();
+  if (outputMultiplierReal > 1.0) {
+    info.op->emitWarning(
+        "unimplemented: cannot multiply with multiplier > 1.0");
+    return failure();
+  }
+
+  // TODO: Choose an appropriate intermediate width for muls > 8 bits to
+  // avoid overflow.
+  unsigned intermediateWidth = 32;
+  IntegerType intermediateElementType =
+      IntegerType::get(intermediateWidth, rewriter.getContext());
+  Type intermediateType =
+      castElementType(info.resultStorageType, intermediateElementType);
+
+  // Cast operands to storage type.
+  Value *lhsValue = rewriter
+                        .create<StorageCastOp>(info.op->getLoc(),
+                                               info.lhsStorageType, info.lhs)
+                        .getResult();
+  Value *rhsValue = rewriter
+                        .create<StorageCastOp>(info.op->getLoc(),
+                                               info.rhsStorageType, info.rhs)
+                        .getResult();
+
+  // Cast to the intermediate sized type.
+  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+                                          lhsValue);
+  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+                                          rhsValue);
+
+  // Apply argument zeroPoints.
+  if (info.lhsType.getZeroPoint() != 0) {
+    Value *zpOffsetConst = rewriter.create<ConstantOp>(
+        info.op->getLoc(), broadcastScalarConstIntValue(
+                               intermediateType, -info.lhsType.getZeroPoint()));
+    lhsValue =
+        rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
+  }
+
+  if (info.rhsType.getZeroPoint() != 0) {
+    Value *zpOffsetConst = rewriter.create<ConstantOp>(
+        info.op->getLoc(), broadcastScalarConstIntValue(
+                               intermediateType, -info.rhsType.getZeroPoint()));
+    rhsValue =
+        rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
+  }
+
+  // Mul.
+  Value *resultValue =
+      rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);
+
+  // Scale output.
+  QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
+  resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
+      info.op->getLoc(), resultValue,
+      IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
+  resultValue = rewriter.create<RoundingDivideByPotISOp>(
+      info.op->getLoc(), resultValue,
+      IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));
+
+  // Zero point offset adjustment.
+  if (info.resultType.getZeroPoint() != 0) {
+    Value *zpOffsetConst = rewriter.create<ConstantOp>(
+        info.op->getLoc(),
+        broadcastScalarConstIntValue(intermediateType,
+                                     info.resultType.getZeroPoint()));
+    resultValue =
+        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
+  }
+
+  // Clamp.
+  auto clampMinMax = info.getClampMinMax(intermediateElementType);
+  resultValue = rewriter.create<ClampISOp>(
+      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
+
+  // Convert back to original type.
+  resultValue = rewriter.create<ConvertISOp>(
+      info.op->getLoc(), info.resultStorageType, resultValue);
+
+  // Cast back for new result.
+  rewriter.replaceOpWithNewOp<StorageCastOp>(
+      info.op, info.getQuantizedResultType(), resultValue);
+
+  return success();
+}
+
+namespace {
+
+struct UniformRealAddEwPattern : public RewritePattern {
+  UniformRealAddEwPattern(MLIRContext *context)
+      : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const {
+    auto addOp = cast<RealAddEwOp>(op);
+    const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
+                                   addOp.clamp_min(), addOp.clamp_max());
+    if (!info.isValid()) {
+      return matchFailure();
+    }
+
+    // Try all of the permutations we support.
+    if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
+      return matchSuccess();
+    }
+
+    return matchFailure();
+  }
+};
+
+struct UniformRealMulEwPattern : public RewritePattern {
+  UniformRealMulEwPattern(MLIRContext *context)
+      : RewritePattern(RealMulEwOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const {
+    auto mulOp = cast<RealMulEwOp>(op);
+    const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
+                                   mulOp.clamp_min(), mulOp.clamp_max());
+    if (!info.isValid()) {
+      return matchFailure();
+    }
+
+    // Try all of the permutations we support.
+    if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
+      return matchSuccess();
+    }
+
+    return matchFailure();
+  }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// LowerUniformRealMath pass
+//===----------------------------------------------------------------------===//
+
+void LowerUniformRealMathPass::runOnFunction() {
+  auto &fn = getFunction();
+  OwningRewritePatternList patterns;
+  auto *context = &getContext();
+  patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
+  patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
+  applyPatternsGreedily(fn, std::move(patterns));
+}
+
+FunctionPassBase *mlir::fxpmath::createLowerUniformRealMathPass() {
+  return new LowerUniformRealMathPass();
+}
+
+static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
+    "fxpmath-lower-uniform-real-math",
+    "Lowers uniform-quantized real math ops to integer arithmetic.");
+
+//===----------------------------------------------------------------------===//
+// LowerUniformCasts pass
+//===----------------------------------------------------------------------===//
+
+void LowerUniformCastsPass::runOnFunction() {
+  auto &fn = getFunction();
+  OwningRewritePatternList patterns;
+  auto *context = &getContext();
+  patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
+  applyPatternsGreedily(fn, std::move(patterns));
+}
+
+FunctionPassBase *mlir::fxpmath::createLowerUniformCastsPass() {
+  return new LowerUniformCastsPass();
+}
+
+static PassRegistration<LowerUniformCastsPass>
+    lowerUniformCastsPass("fxpmath-lower-uniform-casts",
+                          "Lowers uniform-quantized casts.");
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
new file mode 100644
index 00000000000..325b0ca93ca
--- /dev/null
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
@@ -0,0 +1,233 @@
+//===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
+#define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/Dialect/QuantOps/UniformSupport.h"
+#include "mlir/IR/Operation.h"
+
+#include <cmath>
+#include <limits>
+
+namespace mlir {
+namespace fxpmath {
+namespace detail {
+
+inline quant::UniformQuantizedType getUniformElementType(Type t) {
+  return quant::QuantizedType::getQuantizedElementType(t)
+      .dyn_cast_or_null<quant::UniformQuantizedType>();
+}
+
+inline bool hasStorageBitWidth(quant::QuantizedType t,
+                               llvm::ArrayRef<unsigned> checkWidths) {
+  unsigned w = t.getStorageType().getIntOrFloatBitWidth();
+  for (unsigned checkWidth : checkWidths) {
+    if (w == checkWidth)
+      return true;
+  }
+  return false;
+}
+
+/// Computes log2(x), rounded to an integral value. Returns whether log2(x)
+/// can be considered an exact integral value.
+template <typename F> bool integralLog2(F x, int &log2Result) {
+  const F xLog2 = std::log(x) * (1.0 / std::log(2.0));
+  const F xLog2Rounded = std::round(xLog2);
+  const F xLog2Frac = xLog2 - xLog2Rounded;
+  log2Result = static_cast<int>(xLog2Rounded);
+  // Allow small comparison slop below the level that would make a difference
+  // for 2^16 levels.
+  return std::abs(xLog2Frac) < 1e-6;
+}
+
+/// Helper class for operating on binary operations where all operands
+/// and the result are a UniformQuantizedType.
+struct UniformBinaryOpInfo {
+  UniformBinaryOpInfo(Operation *op, Value *lhs, Value *rhs,
+                      Optional<APFloat> clampMin, Optional<APFloat> clampMax)
+      : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
+        lhsType(getUniformElementType(lhs->getType())),
+        rhsType(getUniformElementType(rhs->getType())),
+        resultType(getUniformElementType(*op->result_type_begin())),
+        lhsStorageType(quant::QuantizedType::castToStorageType(lhs->getType())),
+        rhsStorageType(quant::QuantizedType::castToStorageType(rhs->getType())),
+        resultStorageType(
+            quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
+  }
+
+  /// Returns whether this info is valid (all types defined, etc).
+  bool isValid() const {
+    return lhsType && rhsType && resultType && lhsStorageType &&
+           rhsStorageType && resultStorageType;
+  }
+
+  /// Gets the final quantized result type of the result.
+  Type getQuantizedResultType() const { return *op->result_type_begin(); }
+
+  /// Returns whether the storage type of all operands is identical.
+  bool isSameStorageType() const {
+    return lhsType.getStorageType() == rhsType.getStorageType() &&
+           lhsType.getStorageType() == resultType.getStorageType();
+  }
+
+  /// Returns whether all operands and result are considered fixedpoint power
+  /// of two, setting the lhs, rhs, and result log2 scale references.
+  bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
+                       int &resultLog2Scale) const {
+    if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
+        !resultType.isFixedPoint()) {
+      return false;
+    }
+
+    if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
+        !integralLog2(rhsType.getScale(), rhsLog2Scale) ||
+        !integralLog2(resultType.getScale(), resultLog2Scale)) {
+      return false;
+    }
+
+    return true;
+  }
+
+  /// Gets the result integer clamp range given the result quantized type
+  /// and any explicit clamp provided as attributes.
+  std::pair<IntegerAttr, IntegerAttr> getClampMinMax(IntegerType ty) const {
+    int64_t typeMin = resultType.getStorageTypeMin();
+    int64_t typeMax = resultType.getStorageTypeMax();
+
+    if (clampMin || clampMax) {
+      quant::UniformQuantizedValueConverter conv(resultType);
+      if (clampMin) {
+        typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
+      }
+      if (clampMax) {
+        typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
+      }
+    }
+
+    // The quantized, integral ops expect clamps as 32bit ints.
+    return {
+        IntegerAttr::get(ty, typeMin),
+        IntegerAttr::get(ty, typeMax),
+    };
+  }
+
+  Operation *op;
+  Value *lhs;
+  Value *rhs;
+  Optional<APFloat> clampMin;
+  Optional<APFloat> clampMax;
+
+  // Element UniformQuantizedType for operands/result.
+  quant::UniformQuantizedType lhsType;
+  quant::UniformQuantizedType rhsType;
+  quant::UniformQuantizedType resultType;
+
+  // Full storage-based types.
+  Type lhsStorageType;
+  Type rhsStorageType;
+  Type resultStorageType;
+};
+
+/// Derives a quantized multiplier and shift from a real valued multiplier
+/// less than 1.
+struct QuantizedMultiplierSmallerThanOneExp {
+  QuantizedMultiplierSmallerThanOneExp(double realMultiplier) {
+    assert(realMultiplier < 1.0);
+    assert(realMultiplier > 0.0);
+
+    const double q = std::frexp(realMultiplier, &exponent);
+    auto qFixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
+    assert(qFixed <= (1ll << 31));
+    if (qFixed == (1ll << 31)) {
+      qFixed /= 2;
+      ++exponent;
+    }
+    assert(qFixed <= std::numeric_limits<int32_t>::max());
+    multiplier = static_cast<int32_t>(qFixed);
+  }
+
+  int32_t multiplier;
+  int exponent;
+};
+
+/// Casts an integer or floating point based type to a new element type.
+inline Type castElementType(Type t, Type newElementType) {
+  if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
+    switch (vt.getKind()) {
+    case StandardTypes::Kind::Vector:
+      return VectorType::get(vt.getShape(), newElementType);
+    case StandardTypes::Kind::RankedTensor:
+      return RankedTensorType::get(vt.getShape(), newElementType);
+    case StandardTypes::Kind::UnrankedTensor:
+      return UnrankedTensorType::get(newElementType);
+    }
+  }
+  assert(t.isIntOrFloat());
+  return newElementType;
+}
+
+/// Creates an IntegerAttr with a type that matches the shape of 't' (which can
+/// be a primitive/vector/tensor).
+inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
+  if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
+    assert(vt.getElementType().isa<IntegerType>());
+    return SplatElementsAttr::get(vt,
+                                  IntegerAttr::get(vt.getElementType(), value));
+  }
+
+  auto integerType = t.dyn_cast<IntegerType>();
+  assert(integerType && "integer broadcast must be of integer type");
+  return IntegerAttr::get(integerType, value);
+}
+
+/// Given an APFloat, converts it to the float semantics that matches the
+/// given FloatType, silently ignoring inexact conversions.
+inline APFloat convertFloatToType(FloatType ft, APFloat value) {
+  bool losesInfo;
+  auto status = value.convert(ft.getFloatSemantics(),
+                              APFloat::rmNearestTiesToEven, &losesInfo);
+  (void)status; // unused in opt mode
+  assert((status & (APFloat::opDivByZero | APFloat::opInvalidOp)) == 0 &&
+         "could not convert to float const");
+  return value;
+}
+
+/// Creates a FloatAttr with a type that matches the shape of 't' (which can
+/// be a primitive/vector/tensor).
+inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
+  if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
+    FloatType floatElementType = vt.getElementType().dyn_cast<FloatType>();
+    assert(floatElementType &&
+           "float broadcast element type must be float like");
+    APFloat apValue = convertFloatToType(floatElementType, value);
+    return SplatElementsAttr::get(vt,
+                                  FloatAttr::get(vt.getElementType(), apValue));
+  } else {
+    auto floatType = t.dyn_cast<FloatType>();
+    assert(floatType && "float broadcast must be of float type");
+    APFloat apValue = convertFloatToType(floatType, value);
+    return FloatAttr::get(floatType, apValue);
+  }
+}
+
+} // namespace detail
+} // namespace fxpmath
+} // namespace mlir
+
+#endif // MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
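
Note on the mul lowering: the "Scale output" step rewrites a real-valued multiplier (lhsScale * rhsScale / resultScale, always < 1.0 on the supported path) as an int32 fixed-point multiplier plus a power-of-two exponent, then applies it with a saturating rounding doubling high-multiply followed by a rounding right shift. The standalone sketch below (not part of the patch) reproduces that arithmetic; the two helpers follow the usual gemmlowp-style formulation that VecScalarSaturatingRoundingDoublingHighMulISOp and RoundingDivideByPotISOp are modeled on, and the scale values in main() are invented for illustration.

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

// Mirrors QuantizedMultiplierSmallerThanOneExp: approximate realMultiplier
// as multiplier * 2^(exponent - 31), where frexp yields q in [0.5, 1.0).
static void decompose(double realMultiplier, int32_t &multiplier,
                      int &exponent) {
  assert(realMultiplier > 0.0 && realMultiplier < 1.0);
  const double q = std::frexp(realMultiplier, &exponent);
  auto qFixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
  if (qFixed == (1ll << 31)) { // Rounding carried into 2^31; renormalize.
    qFixed /= 2;
    ++exponent;
  }
  multiplier = static_cast<int32_t>(qFixed);
}

// round(a * b / 2^31) with saturation: the fixed-point product of two Q0.31
// values. Only INT32_MIN * INT32_MIN would overflow the int32 result.
static int32_t saturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
  if (a == std::numeric_limits<int32_t>::min() &&
      b == std::numeric_limits<int32_t>::min())
    return std::numeric_limits<int32_t>::max();
  const int64_t ab = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  const int64_t nudge = ab >= 0 ? (1ll << 30) : 1 - (1ll << 30);
  return static_cast<int32_t>((ab + nudge) / (1ll << 31));
}

// Round-to-nearest division by 2^shift for shift in [0, 31].
static int32_t roundingDivideByPOT(int32_t x, int shift) {
  assert(shift >= 0 && shift <= 31);
  const int32_t mask = static_cast<int32_t>((1ll << shift) - 1);
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> shift) + (remainder > threshold ? 1 : 0);
}

int main() {
  // Hypothetical lhsScale * rhsScale / resultScale, as computed for
  // outputMultiplierReal in tryRewriteAffineMulEwSigned.
  const double real = 0.00784 * 0.00784 / 0.0235;
  int32_t multiplier;
  int exponent;
  decompose(real, multiplier, exponent);

  // Rescale an example i32 accumulator. Since real < 1, exponent <= 0, so
  // -exponent is the nonnegative shift amount the lowering passes to
  // RoundingDivideByPotISOp.
  const int32_t acc = 20000;
  const int32_t scaled = roundingDivideByPOT(
      saturatingRoundingDoublingHighMul(acc, multiplier), -exponent);
  std::printf("multiplier=%d exponent=%d integer=%d reference=%.3f\n",
              multiplier, exponent, scaled, acc * real);
  return 0;
}

For the values above this prints an integer result of 52 against a floating-point reference of 52.31, illustrating how the emitted integer op sequence tracks the real-valued multiply to within rounding.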