diff options
| author | Stella Laurenzo <laurenzo@google.com> | 2019-05-17 17:43:50 -0700 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2019-05-20 13:46:43 -0700 |
| commit | 8e5bfb85c44916ef581675ea5a68e062d85624fd (patch) | |
| tree | aeec4f46106cb44c5d8bb6483c7a86da4e884932 /mlir/lib/Quantizer/Configurations | |
| parent | 1a100849c46e1a1c2cf0cba04aaad64e689d06d1 (diff) | |
| download | bcm5719-llvm-8e5bfb85c44916ef581675ea5a68e062d85624fd.tar.gz bcm5719-llvm-8e5bfb85c44916ef581675ea5a68e062d85624fd.zip | |
Upstream the Quantizer tool (part 3).
This upstreams the config and constraints for a reference quantization scheme based on the FxpMathOps dialect.
There are probably two more CLs to get the rest: one with the passes/tests, and one with the tool main() itself.
--
PiperOrigin-RevId: 248817505
Diffstat (limited to 'mlir/lib/Quantizer/Configurations')
| -rw-r--r-- | mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp | 289 |
1 files changed, 289 insertions, 0 deletions
diff --git a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp new file mode 100644 index 00000000000..8623df92c29 --- /dev/null +++ b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp @@ -0,0 +1,289 @@ +//===- FxpMathConfig.cpp - Reference fixed point config -------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines a TargetConfiguration for reference fixed-point math +// quantization scheme based on the FxpMathOps (plus a small category of +// extension ops that can be added from other dialects). 
//
//===----------------------------------------------------------------------===//

#include "mlir/Quantizer/Configurations/FxpMathConfig.h"

#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/Metadata.h"
#include "mlir/Quantizer/Support/Statistics.h"
#include "mlir/Quantizer/Support/UniformConstraints.h"
#include "mlir/StandardOps/Ops.h"

using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::fxpmath;
using namespace mlir::quant;
using namespace std::placeholders;

namespace {

/// Concrete TargetConfiguration for the reference FxpMathOps fixed-point
/// scheme. Registers three candidate quantized types (signed 8/16-bit
/// per-layer, and signed 32-bit with an explicit fixed-point scale for
/// biases) and installs per-op handlers that add constraints to the
/// constraint-analysis graph (CAG).
struct FxpMathTargetConfigImpl : public FxpMathTargetConfig {
  FxpMathTargetConfigImpl(SolverContext &context)
      : FxpMathTargetConfig(context) {
    Builder b(&context.getMlirContext());
    IntegerType i8Type = b.getIntegerType(8);
    IntegerType i16Type = b.getIntegerType(16);
    IntegerType i32Type = b.getIntegerType(32);

    // Candidate storage types. Each uses the full signed range of its
    // storage width; the indices returned by addCandidateType are kept in
    // members so handlers can build enable/disable masks.
    q8 = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i8Type, nullptr,
                              std::numeric_limits<int8_t>::min(),
                              std::numeric_limits<int8_t>::max()),
        CandidateQuantizedType::Scheme::UniformPerLayer);
    q16 = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i16Type, nullptr,
                              std::numeric_limits<int16_t>::min(),
                              std::numeric_limits<int16_t>::max()),
        CandidateQuantizedType::Scheme::UniformPerLayer);
    q32ExplicitFixedPoint = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i32Type, nullptr,
                              std::numeric_limits<int32_t>::min(),
                              std::numeric_limits<int32_t>::max()),
        CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale);

    // Op handlers.
    addOpHandler<ConstantOp>(
        std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2));
    addOpHandler<ReturnOp>(
        std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2));
    addOpHandler<quant::StatisticsOp>(
        std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2));

    // FxpMathOps.
    addOpHandler<RealAddEwOp>(
        std::bind(&FxpMathTargetConfigImpl::handleAdd, this, _1, _2));
    addOpHandler<RealMulEwOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMul, this, _1, _2));
    addOpHandler<RealMatMulOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMatMul, this, _1, _2));
    addOpHandler<RealMatMulBiasOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMatMulBias, this, _1, _2));

    // Require stats ops. These real-math ops need observed statistics on
    // their results before they can be assigned a quantized type.
    addRequireStatsOp<RealAddEwOp>();
    addRequireStatsOp<RealSubEwOp>();
    addRequireStatsOp<RealDivEwOp>();
    addRequireStatsOp<RealMulEwOp>();
    addRequireStatsOp<RealMatMulOp>();
    addRequireStatsOp<RealMatMulBiasOp>();
  }

  /// Returns true for types this scheme can quantize: scalar floats, and
  /// vectors/tensors with a float element type.
  bool isHandledType(Type t) const final {
    if (t.isa<FloatType>())
      return true;
    auto shapedType = t.dyn_cast<ShapedType>();
    return (shapedType && shapedType.getElementType().isa<FloatType>() &&
            (t.isa<VectorType>() || t.isa<TensorType>()));
  }

  /// Couples every implied (operand -> result) connection in the CAG so
  /// the connected anchors resolve to the same uniform type.
  void finalizeAnchors(CAGSlice &cag) const override {
    cag.enumerateImpliedConnections(
        [&](CAGAnchorNode *from, CAGAnchorNode *to) {
          UniformConstraintsBuilder(cag).coupleAnchors(from, to);
        });
  }

  /// Registers an op (by name, for ops from dialects not linked here) that
  /// passes its operands through unchanged; see handleValueIdentity.
  void addValueIdentityOpByName(StringRef opName) override {
    addOpHandlerByName(
        opName,
        std::bind(&FxpMathTargetConfigImpl::handleValueIdentity, this, _1, _2));
  }

  /// Handler for identity-like ops: the single result and each handled
  /// operand keep their storage type (DirectStorage) and are coupled so
  /// they all agree on one quantized type.
  void handleValueIdentity(Operation *op, CAGSlice &cag) const {
    assert(op->getNumResults() == 1);
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto resultNode = cag.getResultAnchor(op, 0);
    resultNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::DirectStorage);

    for (unsigned opIdx = 0, e = op->getNumOperands(); opIdx < e; ++opIdx) {
      if (!isHandledType(op->getOperand(opIdx)->getType()))
        continue;
      auto operandNode = cag.getOperandAnchor(op, opIdx);
      operandNode->setTypeTransformRule(
          CAGAnchorNode::TypeTransformRule::DirectStorage);
      UniformConstraintsBuilder(cag).coupleAnchors(operandNode, resultNode);
    }
  }

  /// Handler for constants: derives layer statistics directly from the
  /// constant's attribute value and applies them to the result anchor.
  void handleConstant(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto resultNode = cag.getResultAnchor(op, 0);
    resultNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
    Attribute valueAttr;
    if (!matchPattern(op, m_Constant(&valueAttr))) {
      return;
    }

    AttributeTensorStatistics stats(valueAttr);
    TensorAxisStatistics layerStats;
    if (!stats.get(layerStats)) {
      op->emitOpError("could not compute statistics");
      return;
    }

    UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
  }

  /// Handler for terminators (e.g. return): the operand must stay in the
  /// expressed (float) type at the boundary.
  void handleTerminal(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getOperand(0)->getType()))
      return;
    auto operandNode = cag.getOperandAnchor(op, 0);
    operandNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
  }

  /// Handler for quant.stats ops: couples operand and result, then applies
  /// the recorded [min, max] layer statistics (elements 0 and 1 of the
  /// layerStats attribute) to the result anchor.
  void handleStats(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto argNode = cag.getOperandAnchor(op, 0);
    auto resultNode = cag.getResultAnchor(op, 0);
    UniformConstraintsBuilder(cag).coupleAnchors(argNode, resultNode);

    TensorAxisStatistics layerStats;
    auto statsOp = cast<quant::StatisticsOp>(op);
    auto layerStatsAttr = statsOp.layerStats();
    layerStats.minValue =
        layerStatsAttr.getValue({0}).cast<FloatAttr>().getValueAsDouble();
    layerStats.maxValue =
        layerStatsAttr.getValue({1}).cast<FloatAttr>().getValueAsDouble();
    UniformConstraintsBuilder(cag).applyStats(resultNode,
                                              std::move(layerStats));
  }

  /// Handler for real elementwise add.
  void handleAdd(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // Add supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    // NOTE: We couple the add such that the scale/zeroPoint match between
    // both args and the result. This is overly constrained in that it is
    // possible to write efficient add kernels with a bit more freedom (i.e.
    // zeroPoints can vary, scales can differ by a power of two, etc).
    // However, fully coupled yields the simplest solutions on the fast path.
    // Further efficiency can be had by constraining the zeroPoint to 0, but
    // there isn't a constraint for this yet (and there are tradeoffs).
    UniformConstraintsBuilder(cag).coupleAnchors(lhs, resultNode);
    UniformConstraintsBuilder(cag).coupleAnchors(rhs, resultNode);
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  /// Handler for real elementwise mul. Unlike add, the anchors are not
  /// coupled: operand and result scales may differ for multiplication.
  void handleMul(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // Mul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  /// Handler for real matmul (no bias).
  void handleMatMul(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // MatMul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  /// Handler for real matmul + bias. The bias operand is restricted to the
  /// 32-bit explicit fixed-point candidate type and its scale is propagated
  /// from the result.
  void handleMatMulBias(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto bias = cag.getOperandAnchor(op, 2);
    bias->getUniformMetadata().disabledCandidateTypes =
        getCandidateTypeDisabledExceptMask({q32ExplicitFixedPoint});

    auto resultNode = cag.getResultAnchor(op, 0);
    UniformConstraintsBuilder(cag).propagateExplicitScale(resultNode, bias);

    // MatMul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  /// Applies optional clamp constraints from the op's clamp_min/clamp_max
  /// float attributes, if present. A missing side is represented as QNaN,
  /// which the clamp constraint treats as unbounded on that side —
  /// TODO confirm against UniformConstraintsBuilder::clamp.
  void addRealMathOptionalConstraints(Operation *op, CAGAnchorNode *anchor,
                                      CAGSlice &cag) const {
    // TODO: It would be nice if these all extended some base trait instead
    // of requiring name lookup.
    auto clampMinAttr = op->getAttrOfType<FloatAttr>("clamp_min");
    auto clampMaxAttr = op->getAttrOfType<FloatAttr>("clamp_max");

    if (clampMinAttr || clampMaxAttr) {
      auto nan = APFloat::getQNaN(APFloat::IEEEdouble());
      auto clampMin = clampMinAttr ? clampMinAttr.getValue() : nan;
      auto clampMax = clampMaxAttr ? clampMaxAttr.getValue() : nan;
      UniformConstraintsBuilder(cag).clamp(anchor, clampMin, clampMax);
    }
  }

  // Candidate-type indices assigned in the constructor.
  unsigned q8;                    // signed 8-bit, uniform per-layer
  unsigned q16;                   // signed 16-bit, uniform per-layer
  unsigned q32ExplicitFixedPoint; // signed 32-bit, explicit fixed-point scale
};

} // anonymous namespace

std::unique_ptr<FxpMathTargetConfig>
FxpMathTargetConfig::create(SolverContext &context) {
  return llvm::make_unique<FxpMathTargetConfigImpl>(context);
}

