Diffstat (limited to 'mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp')
-rw-r--r--  mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp  410
 1 file changed, 410 insertions(+), 0 deletions(-)
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
new file mode 100644
index 00000000000..32d8de3c25d
--- /dev/null
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -0,0 +1,410 @@
+//===- LowerUniformRealMath.cpp ------------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "UniformKernelUtils.h"
+
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
+#include "mlir/Dialect/FxpMathOps/Passes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::fxpmath;
+using namespace mlir::fxpmath::detail;
+using namespace mlir::quant;
+
+namespace {
+
+struct LowerUniformRealMathPass
+ : public FunctionPass<LowerUniformRealMathPass> {
+ void runOnFunction() override;
+};
+
+struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
+ void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Dequantize
+//===----------------------------------------------------------------------===//
+
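+// Emits IR computing the affine dequantization of a per-layer uniform
+// quantized value:
+//   real = scale * (storage - zeroPoint)
+// e.g. with scale = 0.5 and zeroPoint = -128, storage value -28 dequantizes
+// to 0.5 * (-28 - (-128)) = 50.0.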
+static Value *emitUniformPerLayerDequantize(Location loc, Value *input,
+ UniformQuantizedType elementType,
+ PatternRewriter &rewriter) {
+ // Pre-conditions.
+  if (!elementType.isSigned()) {
+    // TODO: Support unsigned storage type.
+    rewriter.getContext()->getDiagEngine().emit(loc,
+                                                DiagnosticSeverity::Warning)
+        << "unimplemented: dequantize unsigned uniform";
+    return nullptr;
+  }
+
+  Type storageType = elementType.castToStorageType(input->getType());
+  Type realType = elementType.castToExpressedType(input->getType());
+  assert(storageType && "cannot cast to storage type");
+  assert(realType && "cannot cast to expressed type");
+  Type intermediateType =
+      castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
+
+ // Cast to storage type.
+ input = rewriter.create<StorageCastOp>(loc, storageType, input);
+
+ // Promote to intermediate type.
+ input = rewriter.create<ConvertISOp>(loc, intermediateType, input);
+
+ // Apply zero-point offset.
+ if (elementType.getZeroPoint() != 0) {
+ Value *negZeroPointConst = rewriter.create<ConstantOp>(
+ loc, broadcastScalarConstIntValue(intermediateType,
+ -elementType.getZeroPoint()));
+ input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
+ }
+
+ // Convert to float.
+ input = rewriter.create<ConvertISToFOp>(loc, realType, input);
+
+ // Mul by scale.
+ Value *scaleConst = rewriter.create<ConstantOp>(
+ loc, broadcastScalarConstFloatValue(realType,
+ APFloat(elementType.getScale())));
+ return rewriter.create<MulFOp>(loc, input, scaleConst);
+}
+
+static Value *
+emitUniformPerAxisDequantize(Location loc, Value *input,
+ UniformQuantizedPerAxisType elementType,
+ PatternRewriter &rewriter) {
+ // TODO: Support per-axis dequantize.
+ rewriter.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Warning)
+ << "unimplemented: per-axis uniform dequantization";
+ return nullptr;
+}
+
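+// Dispatches to the per-layer or per-axis dequantize emitter based on the
+// quantized element type of |input|, returning nullptr for unsupported types.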
+static Value *emitDequantize(Location loc, Value *input,
+ PatternRewriter &rewriter) {
+ Type inputType = input->getType();
+ QuantizedType qElementType =
+ QuantizedType::getQuantizedElementType(inputType);
+  if (auto perLayerElementType =
+          qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
+    return emitUniformPerLayerDequantize(loc, input, perLayerElementType,
+                                         rewriter);
+  } else if (auto perAxisElementType =
+                 qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
+    return emitUniformPerAxisDequantize(loc, input, perAxisElementType,
+                                        rewriter);
+ } else {
+ return nullptr;
+ }
+}
+
+namespace {
+
+struct UniformDequantizePattern : public RewritePattern {
+ UniformDequantizePattern(MLIRContext *context)
+ : RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+ auto dcastOp = cast<DequantizeCastOp>(op);
+ Type inputType = dcastOp.arg()->getType();
+ Type outputType = dcastOp.getResult()->getType();
+
+    QuantizedType inputElementType =
+        QuantizedType::getQuantizedElementType(inputType);
+    if (!inputElementType) {
+      // Input is not quantized.
+      return matchFailure();
+    }
+    Type expressedOutputType = inputElementType.castToExpressedType(inputType);
+ if (expressedOutputType != outputType) {
+ // Not a valid uniform cast.
+ return matchFailure();
+ }
+
+ Value *dequantizedValue =
+ emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter);
+ if (!dequantizedValue) {
+ return matchFailure();
+ }
+
+ rewriter.replaceOp(op, dequantizedValue);
+ return matchSuccess();
+ }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Elementwise add
+//===----------------------------------------------------------------------===//
+
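+// Rewrites a real elementwise add where lhs, rhs, and result all share the
+// same signed uniform quantized type ("isomorphic"): because scales and zero
+// points match, the sum can be computed directly on storage values with a
+// single zero-point correction.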
+static LogicalResult
+tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
+ PatternRewriter &rewriter) {
+ if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
+ info.rhsType != info.resultType) {
+ return failure();
+ }
+
+  // Choose a byte-aligned intermediate width big enough to perform the
+ // calculation without overflow.
+ // TODO: This should probably be made just big enough to avoid overflow and
+ // leave the downstream tooling to decide how to align that to machine
+ // word sizes.
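+  // For example, the sum of two i8 values plus a zero-point correction fits
+  // in i16.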
+ unsigned intermediateWidth =
+ info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
+ IntegerType intermediateElementType =
+ IntegerType::get(intermediateWidth, rewriter.getContext());
+ Type intermediateType =
+ castElementType(info.resultStorageType, intermediateElementType);
+
+ // Cast operands to storage type.
+ Value *lhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.lhsStorageType, info.lhs)
+ .getResult();
+ Value *rhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.rhsStorageType, info.rhs)
+ .getResult();
+
+ // Cast to the intermediate sized type.
+ lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+ lhsValue);
+ rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+ rhsValue);
+
+ // Add.
+ Value *resultValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);
+
+ // Zero point offset adjustment.
+ // result = (lhs - zp) + (rhs - zp) + zp
+ // zpOffset = -zp
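+  // Derivation, with shared scale s and zero point zp:
+  //   realSum = s * (lhs - zp) + s * (rhs - zp)
+  //   storageSum = realSum / s + zp = lhs + rhs - zp
+  // so the raw integer sum needs a single -zp correction.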
+ int zpOffset = -1 * info.resultType.getZeroPoint();
+ if (zpOffset != 0) {
+ Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ info.op->getLoc(),
+ broadcastScalarConstIntValue(intermediateType, zpOffset));
+ resultValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
+ }
+
+ // Clamp.
+ auto clampMinMax = info.getClampMinMax(intermediateElementType);
+ resultValue = rewriter.create<ClampISOp>(
+ info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
+
+ // Convert back to original type.
+ resultValue = rewriter.create<ConvertISOp>(
+ info.op->getLoc(), info.resultStorageType, resultValue);
+
+ // Cast back for new result.
+ rewriter.replaceOpWithNewOp<StorageCastOp>(
+ info.op, info.getQuantizedResultType(), resultValue);
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Elementwise mul
+//===----------------------------------------------------------------------===//
+
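+// Rewrites a real elementwise mul of signed uniform quantized operands by
+// multiplying zero-point-adjusted storage values in a 32-bit intermediate
+// type and rescaling by lhsScale * rhsScale / resultScale.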
+static LogicalResult
+tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
+ PatternRewriter &rewriter) {
+ if (!info.resultType.isSigned()) {
+ return failure();
+ }
+
+ double outputMultiplierReal = info.lhsType.getScale() *
+ info.rhsType.getScale() /
+ info.resultType.getScale();
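+  // Derivation: realProduct = (sL * (lhs - zpL)) * (sR * (rhs - zpR)) and
+  // storageResult = realProduct / sOut + zpOut, so the integer product of the
+  // zero-point-adjusted operands must be scaled by sL * sR / sOut.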
+ if (outputMultiplierReal > 1.0) {
+    info.op->emitWarning(
+        "unimplemented: cannot multiply with multiplier > 1.0");
+ return failure();
+ }
+
+ // TODO: Choose an appropriate intermediate width for muls > 8 bits to
+ // avoid overflow.
+ unsigned intermediateWidth = 32;
+ IntegerType intermediateElementType =
+ IntegerType::get(intermediateWidth, rewriter.getContext());
+ Type intermediateType =
+ castElementType(info.resultStorageType, intermediateElementType);
+
+ // Cast operands to storage type.
+ Value *lhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.lhsStorageType, info.lhs)
+ .getResult();
+ Value *rhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.rhsStorageType, info.rhs)
+ .getResult();
+
+ // Cast to the intermediate sized type.
+ lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+ lhsValue);
+ rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+ rhsValue);
+
+ // Apply argument zeroPoints.
+ if (info.lhsType.getZeroPoint() != 0) {
+ Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ info.op->getLoc(), broadcastScalarConstIntValue(
+ intermediateType, -info.lhsType.getZeroPoint()));
+ lhsValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
+ }
+
+ if (info.rhsType.getZeroPoint() != 0) {
+ Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ info.op->getLoc(), broadcastScalarConstIntValue(
+ intermediateType, -info.rhsType.getZeroPoint()));
+ rhsValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
+ }
+
+ // Mul.
+ Value *resultValue =
+ rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);
+
+ // Scale output.
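+  // The real multiplier (< 1.0) is decomposed gemmlowp-style as
+  // m = m0 * 2^exponent, with m0 in [0.5, 1.0) held as a fixed-point i32
+  // (assumed Q0.31 here) and exponent <= 0; e.g. m = 0.375 gives m0 = 0.75
+  // and exponent = -1. The doubling high-mul applies m0 and the rounding
+  // divide-by-power-of-two applies 2^exponent.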
+ QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
+ resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
+ info.op->getLoc(), resultValue,
+ IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
+ resultValue = rewriter.create<RoundingDivideByPotISOp>(
+ info.op->getLoc(), resultValue,
+ IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));
+
+ // Zero point offset adjustment.
+ if (info.resultType.getZeroPoint() != 0) {
+ Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ info.op->getLoc(),
+ broadcastScalarConstIntValue(intermediateType,
+ info.resultType.getZeroPoint()));
+ resultValue =
+ rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
+ }
+
+ // Clamp.
+ auto clampMinMax = info.getClampMinMax(intermediateElementType);
+ resultValue = rewriter.create<ClampISOp>(
+ info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
+
+ // Convert back to original type.
+ resultValue = rewriter.create<ConvertISOp>(
+ info.op->getLoc(), info.resultStorageType, resultValue);
+
+ // Cast back for new result.
+ rewriter.replaceOpWithNewOp<StorageCastOp>(
+ info.op, info.getQuantizedResultType(), resultValue);
+
+ return success();
+}
+
+namespace {
+
+struct UniformRealAddEwPattern : public RewritePattern {
+ UniformRealAddEwPattern(MLIRContext *context)
+ : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+ auto addOp = cast<RealAddEwOp>(op);
+ const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
+ addOp.clamp_min(), addOp.clamp_max());
+ if (!info.isValid()) {
+ return matchFailure();
+ }
+
+ // Try all of the permutations we support.
+ if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
+ return matchSuccess();
+ }
+
+ return matchFailure();
+ }
+};
+
+struct UniformRealMulEwPattern : public RewritePattern {
+ UniformRealMulEwPattern(MLIRContext *context)
+ : RewritePattern(RealMulEwOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+ auto mulOp = cast<RealMulEwOp>(op);
+ const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
+ mulOp.clamp_min(), mulOp.clamp_max());
+ if (!info.isValid()) {
+ return matchFailure();
+ }
+
+ // Try all of the permutations we support.
+ if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
+ return matchSuccess();
+ }
+
+ return matchFailure();
+ }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// LowerUniformRealMath pass
+//===----------------------------------------------------------------------===//
+
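+// Greedily applies the real-math patterns (elementwise add and mul) over the
+// function, lowering supported uniform-quantized real ops to integer
+// arithmetic on their storage types.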
+void LowerUniformRealMathPass::runOnFunction() {
+ auto &fn = getFunction();
+ OwningRewritePatternList patterns;
+ auto *context = &getContext();
+ patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
+ patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
+ applyPatternsGreedily(fn, std::move(patterns));
+}
+
+FunctionPassBase *mlir::fxpmath::createLowerUniformRealMathPass() {
+ return new LowerUniformRealMathPass();
+}
+
+static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
+ "fxpmath-lower-uniform-real-math",
+ "Lowers uniform-quantized real math ops to integer arithmetic.");
+
+//===----------------------------------------------------------------------===//
+// LowerUniformCasts pass
+//===----------------------------------------------------------------------===//
+
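+// Lowers dcast ops on uniform quantized types to explicit storage-type
+// integer and float arithmetic via the dequantize pattern.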
+void LowerUniformCastsPass::runOnFunction() {
+ auto &fn = getFunction();
+ OwningRewritePatternList patterns;
+ auto *context = &getContext();
+ patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
+ applyPatternsGreedily(fn, std::move(patterns));
+}
+
+FunctionPassBase *mlir::fxpmath::createLowerUniformCastsPass() {
+ return new LowerUniformCastsPass();
+}
+
+static PassRegistration<LowerUniformCastsPass>
+ lowerUniformCastsPass("fxpmath-lower-uniform-casts",
+ "Lowers uniform-quantized casts.");