summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Quantizer/Configurations
diff options
context:
space:
mode:
authorStella Laurenzo <laurenzo@google.com>2019-05-17 17:43:50 -0700
committerMehdi Amini <joker.eph@gmail.com>2019-05-20 13:46:43 -0700
commit8e5bfb85c44916ef581675ea5a68e062d85624fd (patch)
treeaeec4f46106cb44c5d8bb6483c7a86da4e884932 /mlir/lib/Quantizer/Configurations
parent1a100849c46e1a1c2cf0cba04aaad64e689d06d1 (diff)
downloadbcm5719-llvm-8e5bfb85c44916ef581675ea5a68e062d85624fd.tar.gz
bcm5719-llvm-8e5bfb85c44916ef581675ea5a68e062d85624fd.zip
Upstream the Quantizer tool (part 3).
This upstreams the config and constraints for a reference quantization scheme based on the FxpMathOps dialect. There are probably two more CLs to get the rest: one with the passes/tests, and one with the tool main() itself. -- PiperOrigin-RevId: 248817505
Diffstat (limited to 'mlir/lib/Quantizer/Configurations')
-rw-r--r--mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp289
1 file changed, 289 insertions, 0 deletions
diff --git a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp
new file mode 100644
index 00000000000..8623df92c29
--- /dev/null
+++ b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp
@@ -0,0 +1,289 @@
+//===- FxpMathConfig.cpp - Reference fixed point config -------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a TargetConfiguration for reference fixed-point math
+// quantization scheme based on the FxpMathOps (plus a small category of
+// extension ops that can be added from other dialects).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
+
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "mlir/Quantizer/Support/Metadata.h"
+#include "mlir/Quantizer/Support/Statistics.h"
+#include "mlir/Quantizer/Support/UniformConstraints.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::fxpmath;
+using namespace mlir::quant;
+using namespace std::placeholders;
+
+namespace {
+
+struct FxpMathTargetConfigImpl : public FxpMathTargetConfig {
+ FxpMathTargetConfigImpl(SolverContext &context)
+ : FxpMathTargetConfig(context) {
+ Builder b(&context.getMlirContext());
+ IntegerType i8Type = b.getIntegerType(8);
+ IntegerType i16Type = b.getIntegerType(16);
+ IntegerType i32Type = b.getIntegerType(32);
+
+ q8 = addCandidateType(
+ AnyQuantizedType::get(QuantizationFlags::Signed, i8Type, nullptr,
+ std::numeric_limits<int8_t>::min(),
+ std::numeric_limits<int8_t>::max()),
+ CandidateQuantizedType::Scheme::UniformPerLayer);
+ q16 = addCandidateType(
+ AnyQuantizedType::get(QuantizationFlags::Signed, i16Type, nullptr,
+ std::numeric_limits<int16_t>::min(),
+ std::numeric_limits<int16_t>::max()),
+ CandidateQuantizedType::Scheme::UniformPerLayer);
+ q32ExplicitFixedPoint = addCandidateType(
+ AnyQuantizedType::get(QuantizationFlags::Signed, i32Type, nullptr,
+ std::numeric_limits<int32_t>::min(),
+ std::numeric_limits<int32_t>::max()),
+ CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale);
+
+ // Op handlers.
+ addOpHandler<ConstantOp>(
+ std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2));
+ addOpHandler<ReturnOp>(
+ std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2));
+ addOpHandler<quant::StatisticsOp>(
+ std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2));
+
+ // FxpMathOps.
+ addOpHandler<RealAddEwOp>(
+ std::bind(&FxpMathTargetConfigImpl::handleAdd, this, _1, _2));
+ addOpHandler<RealMulEwOp>(
+ std::bind(&FxpMathTargetConfigImpl::handleMul, this, _1, _2));
+ addOpHandler<RealMatMulOp>(
+ std::bind(&FxpMathTargetConfigImpl::handleMatMul, this, _1, _2));
+ addOpHandler<RealMatMulBiasOp>(
+ std::bind(&FxpMathTargetConfigImpl::handleMatMulBias, this, _1, _2));
+
+ // Require stats ops.
+ addRequireStatsOp<RealAddEwOp>();
+ addRequireStatsOp<RealSubEwOp>();
+ addRequireStatsOp<RealDivEwOp>();
+ addRequireStatsOp<RealMulEwOp>();
+ addRequireStatsOp<RealMatMulOp>();
+ addRequireStatsOp<RealMatMulBiasOp>();
+ }
+
+ bool isHandledType(Type t) const final {
+ if (t.isa<FloatType>())
+ return true;
+ auto shapedType = t.dyn_cast<ShapedType>();
+ return (shapedType && shapedType.getElementType().isa<FloatType>() &&
+ (t.isa<VectorType>() || t.isa<TensorType>()));
+ }
+
+ void finalizeAnchors(CAGSlice &cag) const override {
+ cag.enumerateImpliedConnections(
+ [&](CAGAnchorNode *from, CAGAnchorNode *to) {
+ UniformConstraintsBuilder(cag).coupleAnchors(from, to);
+ });
+ }
+
+ void addValueIdentityOpByName(StringRef opName) override {
+ addOpHandlerByName(
+ opName,
+ std::bind(&FxpMathTargetConfigImpl::handleValueIdentity, this, _1, _2));
+ }
+
+ void handleValueIdentity(Operation *op, CAGSlice &cag) const {
+ assert(op->getNumResults() == 1);
+ if (!isHandledType(op->getResult(0)->getType()))
+ return;
+
+ auto resultNode = cag.getResultAnchor(op, 0);
+ resultNode->setTypeTransformRule(
+ CAGAnchorNode::TypeTransformRule::DirectStorage);
+
+ for (unsigned opIdx = 0, e = op->getNumOperands(); opIdx < e; ++opIdx) {
+ if (!isHandledType(op->getOperand(opIdx)->getType()))
+ continue;
+ auto operandNode = cag.getOperandAnchor(op, opIdx);
+ operandNode->setTypeTransformRule(
+ CAGAnchorNode::TypeTransformRule::DirectStorage);
+ UniformConstraintsBuilder(cag).coupleAnchors(operandNode, resultNode);
+ }
+ }
+
+ void handleConstant(Operation *op, CAGSlice &cag) const {
+ if (!isHandledType(op->getResult(0)->getType()))
+ return;
+
+ auto resultNode = cag.getResultAnchor(op, 0);
+ resultNode->setTypeTransformRule(
+ CAGAnchorNode::TypeTransformRule::ExpressedOnly);
+ Attribute valueAttr;
+ if (!matchPattern(op, m_Constant(&valueAttr))) {
+ return;
+ }
+
+ AttributeTensorStatistics stats(valueAttr);
+ TensorAxisStatistics layerStats;
+ if (!stats.get(layerStats)) {
+ op->emitOpError("could not compute statistics");
+ return;
+ }
+
+ UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
+ }
+
+ void handleTerminal(Operation *op, CAGSlice &cag) const {
+ if (!isHandledType(op->getOperand(0)->getType()))
+ return;
+ auto operandNode = cag.getOperandAnchor(op, 0);
+ operandNode->setTypeTransformRule(
+ CAGAnchorNode::TypeTransformRule::ExpressedOnly);
+ }
+
+ void handleStats(Operation *op, CAGSlice &cag) const {
+ if (!isHandledType(op->getResult(0)->getType()))
+ return;
+
+ auto argNode = cag.getOperandAnchor(op, 0);
+ auto resultNode = cag.getResultAnchor(op, 0);
+ UniformConstraintsBuilder(cag).coupleAnchors(argNode, resultNode);
+
+ TensorAxisStatistics layerStats;
+ auto statsOp = cast<quant::StatisticsOp>(op);
+ auto layerStatsAttr = statsOp.layerStats();
+ layerStats.minValue =
+ layerStatsAttr.getValue({0}).cast<FloatAttr>().getValueAsDouble();
+ layerStats.maxValue =
+ layerStatsAttr.getValue({1}).cast<FloatAttr>().getValueAsDouble();
+ UniformConstraintsBuilder(cag).applyStats(resultNode,
+ std::move(layerStats));
+ }
+
+ void handleAdd(Operation *op, CAGSlice &cag) const {
+ if (!isHandledType(op->getResult(0)->getType()))
+ return;
+
+ auto lhs = cag.getOperandAnchor(op, 0);
+ auto rhs = cag.getOperandAnchor(op, 1);
+ auto resultNode = cag.getResultAnchor(op, 0);
+ // Add supports 8/16 bit math.
+ llvm::SmallBitVector disableMask =
+ getCandidateTypeDisabledExceptMask({q8, q16});
+ lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+ rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+ resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
+ // NOTE: We couple the add such that the scale/zeroPoint match between
+ // both args and the result. This is overly constrained in that it is
+ // possible to write efficient add kernels with a bit more freedom (i.e.
+ // zeroPoints can vary, scales can differ by a power of two, etc).
+ // However, fully coupled yields the simples solutions on the fast path.
+ // Further efficiency can be had by constraining the zeroPoint to 0, but
+ // there isn't a constraint for this yet (and there are tradeoffs).
+ UniformConstraintsBuilder(cag).coupleAnchors(lhs, resultNode);
+ UniformConstraintsBuilder(cag).coupleAnchors(rhs, resultNode);
+ addRealMathOptionalConstraints(op, resultNode, cag);
+ }
+
+ void handleMul(Operation *op, CAGSlice &cag) const {
+ if (!isHandledType(op->getResult(0)->getType()))
+ return;
+
+ auto lhs = cag.getOperandAnchor(op, 0);
+ auto rhs = cag.getOperandAnchor(op, 1);
+ auto resultNode = cag.getResultAnchor(op, 0);
+ // Mul supports 8/16 bit math.
+ llvm::SmallBitVector disableMask =
+ getCandidateTypeDisabledExceptMask({q8, q16});
+ lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+ rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+ resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
+ addRealMathOptionalConstraints(op, resultNode, cag);
+ }
+
+ void handleMatMul(Operation *op, CAGSlice &cag) const {
+ if (!isHandledType(op->getResult(0)->getType()))
+ return;
+
+ auto lhs = cag.getOperandAnchor(op, 0);
+ auto rhs = cag.getOperandAnchor(op, 1);
+ auto resultNode = cag.getResultAnchor(op, 0);
+ // Mul supports 8/16 bit math.
+ llvm::SmallBitVector disableMask =
+ getCandidateTypeDisabledExceptMask({q8, q16});
+ lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+ rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+ resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
+ addRealMathOptionalConstraints(op, resultNode, cag);
+ }
+
+ void handleMatMulBias(Operation *op, CAGSlice &cag) const {
+ if (!isHandledType(op->getResult(0)->getType()))
+ return;
+
+ auto lhs = cag.getOperandAnchor(op, 0);
+ auto rhs = cag.getOperandAnchor(op, 1);
+ auto bias = cag.getOperandAnchor(op, 2);
+ bias->getUniformMetadata().disabledCandidateTypes =
+ getCandidateTypeDisabledExceptMask({q32ExplicitFixedPoint});
+
+ auto resultNode = cag.getResultAnchor(op, 0);
+ UniformConstraintsBuilder(cag).propagateExplicitScale(resultNode, bias);
+
+ // Mul supports 8/16 bit math.
+ llvm::SmallBitVector disableMask =
+ getCandidateTypeDisabledExceptMask({q8, q16});
+ lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+ rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+ resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
+ addRealMathOptionalConstraints(op, resultNode, cag);
+ }
+
+ void addRealMathOptionalConstraints(Operation *op, CAGAnchorNode *anchor,
+ CAGSlice &cag) const {
+ // TODO: It would be nice if these all extended some base trait instead
+ // of requiring name lookup.
+ auto clampMinAttr = op->getAttrOfType<FloatAttr>("clamp_min");
+ auto clampMaxAttr = op->getAttrOfType<FloatAttr>("clamp_max");
+
+ if (clampMinAttr || clampMaxAttr) {
+ auto nan = APFloat::getQNaN(APFloat::IEEEdouble());
+ auto clampMin = clampMinAttr ? clampMinAttr.getValue() : nan;
+ auto clampMax = clampMaxAttr ? clampMaxAttr.getValue() : nan;
+ UniformConstraintsBuilder(cag).clamp(anchor, clampMin, clampMax);
+ }
+ }
+
+ unsigned q8;
+ unsigned q16;
+ unsigned q32ExplicitFixedPoint;
+};
+
+} // anonymous namespace
+
+std::unique_ptr<FxpMathTargetConfig>
+FxpMathTargetConfig::create(SolverContext &context) {
+ return llvm::make_unique<FxpMathTargetConfigImpl>(context);
+}
OpenPOWER on IntegriCloud