path: root/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
author    Stella Laurenzo <laurenzo@google.com>    2019-05-14 11:03:55 -0700
committer Mehdi Amini <joker.eph@gmail.com>        2019-05-20 13:41:55 -0700
commit    d4dcf7de9e6f5f00177c534d765c5b24d9db8ed8 (patch)
tree      7f819d67376f8d08681b81740082c193b6aa7ad9 /mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
parent    6264fccd3a4af9edc37f9b6d0f37763e61800ba5 (diff)
Move Quantization -> Dialect/QuantOps, FxpMathOps -> Dialect/FxpMathOps.
Adding the additional layer of directory was discussed offline and matches the Target/ tree. The names match the de facto convention we seem to be following, where the C++ namespace is ^(.+)Ops/$ matched against the directory name.

This is in preparation for patching the Quantizer into this tree, which would have been confusing without first moving the Quantization dialect to its more proper home. It is left to others to move other dialects if desired.

Tested: ninja check-mlir

PiperOrigin-RevId: 248171982
Diffstat (limited to 'mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp')
-rw-r--r--    mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp    410
1 file changed, 410 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
new file mode 100644
index 00000000000..32d8de3c25d
--- /dev/null
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -0,0 +1,410 @@
+//===- LowerUniformRealMath.cpp ------------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "UniformKernelUtils.h"
+
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
+#include "mlir/Dialect/FxpMathOps/Passes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::fxpmath;
+using namespace mlir::fxpmath::detail;
+using namespace mlir::quant;
+
+namespace {
+
+struct LowerUniformRealMathPass
+ : public FunctionPass<LowerUniformRealMathPass> {
+ void runOnFunction() override;
+};
+
+struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
+ void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Dequantize
+//===----------------------------------------------------------------------===//
+
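+// A sketch of the affine per-layer mapping emitted below (not an additional
+// API): realValue = scale * (storageValue - zeroPoint). The sequence casts
+// to the storage type, widens to i32, subtracts the zero point, converts to
+// float, and multiplies by the scale.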
+static Value *emitUniformPerLayerDequantize(Location loc, Value *input,
+ UniformQuantizedType elementType,
+ PatternRewriter &rewriter) {
+ // Pre-conditions.
+ if (!elementType.isSigned()) {
+ // TODO: Support unsigned storage type.
+ rewriter.getContext()->emitWarning(
+ loc, "unimplemented: dequantize signed uniform");
+ return nullptr;
+ }
+
+  Type storageType = elementType.castToStorageType(input->getType());
+  Type realType = elementType.castToExpressedType(input->getType());
+  assert(storageType && "cannot cast to storage type");
+  assert(realType && "cannot cast to expressed type");
+  Type intermediateType =
+      castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
+
+ // Cast to storage type.
+ input = rewriter.create<StorageCastOp>(loc, storageType, input);
+
+ // Promote to intermediate type.
+ input = rewriter.create<ConvertISOp>(loc, intermediateType, input);
+
+ // Apply zero-point offset.
+ if (elementType.getZeroPoint() != 0) {
+ Value *negZeroPointConst = rewriter.create<ConstantOp>(
+ loc, broadcastScalarConstIntValue(intermediateType,
+ -elementType.getZeroPoint()));
+ input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
+ }
+
+ // Convert to float.
+ input = rewriter.create<ConvertISToFOp>(loc, realType, input);
+
+ // Mul by scale.
+ Value *scaleConst = rewriter.create<ConstantOp>(
+ loc, broadcastScalarConstFloatValue(realType,
+ APFloat(elementType.getScale())));
+ return rewriter.create<MulFOp>(loc, input, scaleConst);
+}
+
+static Value *
+emitUniformPerAxisDequantize(Location loc, Value *input,
+ UniformQuantizedPerAxisType elementType,
+ PatternRewriter &rewriter) {
+ // TODO: Support per-axis dequantize.
+ rewriter.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Warning)
+ << "unimplemented: per-axis uniform dequantization";
+ return nullptr;
+}
+
+static Value *emitDequantize(Location loc, Value *input,
+ PatternRewriter &rewriter) {
+ Type inputType = input->getType();
+ QuantizedType qElementType =
+ QuantizedType::getQuantizedElementType(inputType);
+ if (auto uperLayerElementType =
+ qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
+ return emitUniformPerLayerDequantize(loc, input, uperLayerElementType,
+ rewriter);
+ } else if (auto uperAxisElementType =
+ qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
+ return emitUniformPerAxisDequantize(loc, input, uperAxisElementType,
+ rewriter);
+ } else {
+ return nullptr;
+ }
+}
+
+namespace {
+
+struct UniformDequantizePattern : public RewritePattern {
+ UniformDequantizePattern(MLIRContext *context)
+ : RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+ auto dcastOp = cast<DequantizeCastOp>(op);
+ Type inputType = dcastOp.arg()->getType();
+ Type outputType = dcastOp.getResult()->getType();
+
+ QuantizedType inputElementType =
+ QuantizedType::getQuantizedElementType(inputType);
+ Type expressedOutputType = inputElementType.castToExpressedType(inputType);
+ if (expressedOutputType != outputType) {
+ // Not a valid uniform cast.
+ return matchFailure();
+ }
+
+ Value *dequantizedValue =
+ emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter);
+ if (!dequantizedValue) {
+ return matchFailure();
+ }
+
+ rewriter.replaceOp(op, dequantizedValue);
+ return matchSuccess();
+ }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Elementwise add
+//===----------------------------------------------------------------------===//
+
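+// With identically quantized operands and result (shared scale s and zero
+// point zp), real addition reduces to integer addition plus one zero-point
+// correction:
+//   s*(lhs - zp) + s*(rhs - zp) = s*((lhs + rhs - zp) - zp)
+// so the result storage value is lhs + rhs - zp, computed in a wider
+// intermediate type and clamped back to the storage range. (A sketch of the
+// math that the rewrite below implements.)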
+static LogicalResult
+tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
+ PatternRewriter &rewriter) {
+ if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
+ info.rhsType != info.resultType) {
+ return failure();
+ }
+
+ // Choose a byte aligned intermediate width big enough to perform the
+ // calculation without overflow.
+ // TODO: This should probably be made just big enough to avoid overflow and
+ // leave the downstream tooling to decide how to align that to machine
+ // word sizes.
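+  // (For example, adding two 8-bit values needs at most 9 bits, so a 16-bit
+  // intermediate is sufficient; wider storage types get 32 bits.)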
+ unsigned intermediateWidth =
+ info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
+ IntegerType intermediateElementType =
+ IntegerType::get(intermediateWidth, rewriter.getContext());
+ Type intermediateType =
+ castElementType(info.resultStorageType, intermediateElementType);
+
+ // Cast operands to storage type.
+ Value *lhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.lhsStorageType, info.lhs)
+ .getResult();
+ Value *rhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.rhsStorageType, info.rhs)
+ .getResult();
+
+ // Cast to the intermediate sized type.
+ lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+ lhsValue);
+ rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+ rhsValue);
+
+ // Add.
+ Value *resultValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);
+
+ // Zero point offset adjustment.
+ // result = (lhs - zp) + (rhs - zp) + zp
+ // zpOffset = -zp
+ int zpOffset = -1 * info.resultType.getZeroPoint();
+ if (zpOffset != 0) {
+ Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ info.op->getLoc(),
+ broadcastScalarConstIntValue(intermediateType, zpOffset));
+ resultValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
+ }
+
+ // Clamp.
+ auto clampMinMax = info.getClampMinMax(intermediateElementType);
+ resultValue = rewriter.create<ClampISOp>(
+ info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
+
+ // Convert back to original type.
+ resultValue = rewriter.create<ConvertISOp>(
+ info.op->getLoc(), info.resultStorageType, resultValue);
+
+ // Cast back for new result.
+ rewriter.replaceOpWithNewOp<StorageCastOp>(
+ info.op, info.getQuantizedResultType(), resultValue);
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Elementwise mul
+//===----------------------------------------------------------------------===//
+
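+// For real multiplication with scales (sL, sR, sO) and zero points
+// (zL, zR, zO):
+//   sO*(out - zO) = sL*(lhs - zL) * sR*(rhs - zR)
+//   out = zO + (sL*sR/sO) * (lhs - zL) * (rhs - zR)
+// The combined multiplier sL*sR/sO is applied as a fixed-point multiply and
+// shift; only multipliers <= 1.0 are currently handled. (A sketch of the
+// math that the rewrite below implements.)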
+static LogicalResult
+tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
+ PatternRewriter &rewriter) {
+ if (!info.resultType.isSigned()) {
+ return failure();
+ }
+
+ double outputMultiplierReal = info.lhsType.getScale() *
+ info.rhsType.getScale() /
+ info.resultType.getScale();
+ if (outputMultiplierReal > 1.0) {
+    info.op->emitWarning(
+        "unimplemented: cannot multiply with multiplier > 1.0");
+ return failure();
+ }
+
+ // TODO: Choose an appropriate intermediate width for muls > 8 bits to
+ // avoid overflow.
+ unsigned intermediateWidth = 32;
+ IntegerType intermediateElementType =
+ IntegerType::get(intermediateWidth, rewriter.getContext());
+ Type intermediateType =
+ castElementType(info.resultStorageType, intermediateElementType);
+
+ // Cast operands to storage type.
+ Value *lhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.lhsStorageType, info.lhs)
+ .getResult();
+ Value *rhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.rhsStorageType, info.rhs)
+ .getResult();
+
+ // Cast to the intermediate sized type.
+ lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+ lhsValue);
+ rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+ rhsValue);
+
+ // Apply argument zeroPoints.
+ if (info.lhsType.getZeroPoint() != 0) {
+ Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ info.op->getLoc(), broadcastScalarConstIntValue(
+ intermediateType, -info.lhsType.getZeroPoint()));
+ lhsValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
+ }
+
+ if (info.rhsType.getZeroPoint() != 0) {
+ Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ info.op->getLoc(), broadcastScalarConstIntValue(
+ intermediateType, -info.rhsType.getZeroPoint()));
+ rhsValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
+ }
+
+ // Mul.
+ Value *resultValue =
+ rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);
+
+ // Scale output.
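+  // The real multiplier in (0, 1] is decomposed as an integer multiplier and
+  // a power-of-two exponent (the usual quantized-multiplier scheme): the
+  // saturating doubling high-mul applies the integer multiplier and the
+  // rounding divide applies the 2^exponent shift.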
+ QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
+ resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
+ info.op->getLoc(), resultValue,
+ IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
+ resultValue = rewriter.create<RoundingDivideByPotISOp>(
+ info.op->getLoc(), resultValue,
+ IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));
+
+ // Zero point offset adjustment.
+ if (info.resultType.getZeroPoint() != 0) {
+ Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ info.op->getLoc(),
+ broadcastScalarConstIntValue(intermediateType,
+ info.resultType.getZeroPoint()));
+ resultValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
+ }
+
+ // Clamp.
+ auto clampMinMax = info.getClampMinMax(intermediateElementType);
+ resultValue = rewriter.create<ClampISOp>(
+ info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
+
+ // Convert back to original type.
+ resultValue = rewriter.create<ConvertISOp>(
+ info.op->getLoc(), info.resultStorageType, resultValue);
+
+ // Cast back for new result.
+ rewriter.replaceOpWithNewOp<StorageCastOp>(
+ info.op, info.getQuantizedResultType(), resultValue);
+
+ return success();
+}
+
+namespace {
+
+struct UniformRealAddEwPattern : public RewritePattern {
+ UniformRealAddEwPattern(MLIRContext *context)
+ : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+ auto addOp = cast<RealAddEwOp>(op);
+ const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
+ addOp.clamp_min(), addOp.clamp_max());
+ if (!info.isValid()) {
+ return matchFailure();
+ }
+
+ // Try all of the permutations we support.
+ if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
+ return matchSuccess();
+ }
+
+ return matchFailure();
+ }
+};
+
+struct UniformRealMulEwPattern : public RewritePattern {
+ UniformRealMulEwPattern(MLIRContext *context)
+ : RewritePattern(RealMulEwOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+ auto mulOp = cast<RealMulEwOp>(op);
+ const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
+ mulOp.clamp_min(), mulOp.clamp_max());
+ if (!info.isValid()) {
+ return matchFailure();
+ }
+
+ // Try all of the permutations we support.
+ if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
+ return matchSuccess();
+ }
+
+ return matchFailure();
+ }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// LowerUniformRealMath pass
+//===----------------------------------------------------------------------===//
+
+void LowerUniformRealMathPass::runOnFunction() {
+ auto &fn = getFunction();
+ OwningRewritePatternList patterns;
+ auto *context = &getContext();
+ patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
+ patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
+ applyPatternsGreedily(fn, std::move(patterns));
+}
+
+FunctionPassBase *mlir::fxpmath::createLowerUniformRealMathPass() {
+ return new LowerUniformRealMathPass();
+}
+
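+// Registration exposes the pass to pass-driver tools by flag name (e.g., an
+// invocation such as `mlir-opt -fxpmath-lower-uniform-real-math`, shown for
+// illustration).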
+static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
+ "fxpmath-lower-uniform-real-math",
+ "Lowers uniform-quantized real math ops to integer arithmetic.");
+
+//===----------------------------------------------------------------------===//
+// LowerUniformCasts pass
+//===----------------------------------------------------------------------===//
+
+void LowerUniformCastsPass::runOnFunction() {
+ auto &fn = getFunction();
+ OwningRewritePatternList patterns;
+ auto *context = &getContext();
+ patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
+ applyPatternsGreedily(fn, std::move(patterns));
+}
+
+FunctionPassBase *mlir::fxpmath::createLowerUniformCastsPass() {
+ return new LowerUniformCastsPass();
+}
+
+static PassRegistration<LowerUniformCastsPass>
+ lowerUniformCastsPass("fxpmath-lower-uniform-casts",
+ "Lowers uniform-quantized casts.");