path: root/mlir/lib/Dialect/QuantOps/Transforms
author     Stella Laurenzo <laurenzo@google.com>    2019-05-14 11:03:55 -0700
committer  Mehdi Amini <joker.eph@gmail.com>        2019-05-20 13:41:55 -0700
commit     d4dcf7de9e6f5f00177c534d765c5b24d9db8ed8 (patch)
tree       7f819d67376f8d08681b81740082c193b6aa7ad9 /mlir/lib/Dialect/QuantOps/Transforms
parent     6264fccd3a4af9edc37f9b6d0f37763e61800ba5 (diff)
Move Quantization -> Dialect/QuantOps, FxpMathOps -> Dialect/FxpMathOps.
Adding the additional layer of directory was discussed offline and matches the Target/ tree. The names match the de facto convention we seem to be following, where the C++ namespace is ^(.+)Ops/$ matched against the directory name.

This is in preparation for patching the Quantizer into this tree, which would have been confusing without moving the Quantization dialect to its more proper home. It is left to others to move other dialects if desired.

Tested: ninja check-mlir

PiperOrigin-RevId: 248171982
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/Transforms')
-rw-r--r--  mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp     135
-rw-r--r--  mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp  114
2 files changed, 249 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
new file mode 100644
index 00000000000..228e16d752e
--- /dev/null
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
@@ -0,0 +1,135 @@
+//===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/Passes.h"
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantizeUtils.h"
+#include "mlir/Dialect/QuantOps/UniformSupport.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+namespace {
+
+class ConvertConstPass : public FunctionPass<ConvertConstPass> {
+public:
+ void runOnFunction() override;
+};
+
+class QuantizedConstRewrite : public RewritePattern {
+public:
+ struct State : PatternState {
+ QuantizedType quantizedElementType;
+ Attribute value;
+ };
+
+ QuantizedConstRewrite(MLIRContext *context)
+ : RewritePattern(QuantizeCastOp::getOperationName(), 1, context) {}
+
+ PatternMatchResult match(Operation *op) const override;
+ void rewrite(Operation *op, std::unique_ptr<PatternState> baseState,
+ PatternRewriter &rewriter) const override;
+};
+
+} // end anonymous namespace
+
+/// Matches a [constant] -> [qbarrier] pair where the qbarrier result type is
+/// quantized and the operand type is quantizable.
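+///
+/// For example (illustrative only; the op spellings and type syntax below are
+/// approximate and not taken from this patch):
+///   %cst = constant dense<...> : tensor<4xf32>
+///   %q = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant.uniform<...>>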
+PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
+ State state;
+
+ // Is the operand a constant?
+ auto qbarrier = cast<QuantizeCastOp>(op);
+ if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) {
+ return matchFailure();
+ }
+ // Does the qbarrier convert to a quantized type? This will not be true
+ // if a quantized type has not yet been chosen or if the cast to an
+ // equivalent storage type is not supported.
+ Type qbarrierResultType = qbarrier.getResult()->getType();
+ state.quantizedElementType =
+ QuantizedType::getQuantizedElementType(qbarrierResultType);
+ if (!state.quantizedElementType) {
+ return matchFailure();
+ }
+ if (!QuantizedType::castToStorageType(qbarrierResultType)) {
+ return matchFailure();
+ }
+
+ // Is the operand type compatible with the expressed type of the quantized
+ // type? This will not be true if the qbarrier is superfluous (converts
+ // from and to a quantized type).
+ if (!state.quantizedElementType.isCompatibleExpressedType(
+ qbarrier.arg()->getType())) {
+ return matchFailure();
+ }
+
+ // Is the constant value held in an attribute kind that we support?
+ if (!state.value.isa<FloatAttr>() && !state.value.isa<SplatElementsAttr>() &&
+ !state.value.isa<DenseElementsAttr>() &&
+ !state.value.isa<SparseElementsAttr>()) {
+ return matchFailure();
+ }
+
+ return matchSuccess(llvm::make_unique<State>(std::move(state)));
+}
+
+void QuantizedConstRewrite::rewrite(Operation *op,
+ std::unique_ptr<PatternState> baseState,
+ PatternRewriter &rewriter) const {
+ auto state = static_cast<State *>(baseState.get());
+
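+ // Quantize the matched constant's value into the target quantized element
+ // type, also computing the corresponding storage type. If the value cannot
+ // be quantized, bail out and leave the IR unchanged.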
+ Type newConstValueType;
+ Attribute newConstValue = quantizeAttr(
+ state->value, state->quantizedElementType, newConstValueType);
+ if (!newConstValue) {
+ return;
+ }
+
+ auto *origConstOp = op->getOperand(0);
+ // When creating the new const op, use a fused location that combines the
+ // original const and the qbarrier that led to the quantization.
+ auto fusedLoc =
+ FusedLoc::get({origConstOp->getDefiningOp()->getLoc(), op->getLoc()},
+ rewriter.getContext());
+ auto newConstOp =
+ rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
+ rewriter.replaceOpWithNewOp<StorageCastOp>(
+ op, {origConstOp}, *op->result_type_begin(), newConstOp);
+}
+
+void ConvertConstPass::runOnFunction() {
+ OwningRewritePatternList patterns;
+ auto &func = getFunction();
+ auto *context = &getContext();
+ patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context));
+ applyPatternsGreedily(func, std::move(patterns));
+}
+
+FunctionPassBase *mlir::quant::createConvertConstPass() {
+ return new ConvertConstPass();
+}
+
+static PassRegistration<ConvertConstPass>
+ pass("quant-convert-const",
+ "Converts constants followed by qbarrier to actual quantized values");
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
new file mode 100644
index 00000000000..ea8095b791c
--- /dev/null
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
@@ -0,0 +1,114 @@
+//===- ConvertSimQuant.cpp - Converts simulated quant ops -----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
+#include "mlir/Dialect/QuantOps/Passes.h"
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/UniformSupport.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+namespace {
+
+class ConvertSimulatedQuantPass
+ : public FunctionPass<ConvertSimulatedQuantPass> {
+public:
+ void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+/// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
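+///
+/// For example (illustrative only; op spellings and type syntax are
+/// approximate), a fake-quant op such as
+///   %0 = "quant.const_fake_quant"(%arg0) {min = ..., max = ..., num_bits = 8}
+/// is rewritten into a quantize/dequantize cast pair through a uniform
+/// quantized type:
+///   %1 = "quant.qcast"(%arg0) : (tensor<4xf32>) -> tensor<4x!quant.uniform<...>>
+///   %2 = "quant.dcast"(%1) : (tensor<4x!quant.uniform<...>>) -> tensor<4xf32>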
+class ConstFakeQuantRewrite : public RewritePattern {
+public:
+ bool *hadFailure;
+
+ ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure)
+ : RewritePattern(ConstFakeQuant::getOperationName(), 1, context),
+ hadFailure(hadFailure) {}
+
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // TODO: If this pattern comes up more frequently, consider adding core
+ // support for failable rewrites.
+ if (failableRewrite(op, rewriter)) {
+ *hadFailure = true;
+ return matchFailure();
+ }
+
+ return matchSuccess();
+ }
+
+ bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
+ auto fqOp = cast<ConstFakeQuant>(op);
+
+ auto converter =
+ ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType());
+ if (!converter) {
+ return (op->emitError("unsupported quantized type conversion"), true);
+ }
+
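+ // Derive a uniform quantized element type from the fake-quant op's
+ // attributes: bit width, [min, max] range, and the narrow_range flag.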
+ UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
+ fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
+ fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
+ fqOp.narrow_range(), converter.expressedType);
+
+ if (!uniformElementType) {
+ // Note that fakeQuantAttrsToType will already have emitted an error.
+ return true;
+ }
+
+ Type quantizedType = converter.convert(uniformElementType);
+ assert(quantizedType &&
+ "Converter accepted a type that it did not convert");
+
+ // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
+ // this is a forced/hard-coded constraint.
+ auto qbarrier = rewriter.create<QuantizeCastOp>(op->getLoc(), quantizedType,
+ fqOp.inputs());
+ rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
+ qbarrier.getResult());
+
+ return false;
+ }
+};
+
+void ConvertSimulatedQuantPass::runOnFunction() {
+ bool hadFailure = false;
+ OwningRewritePatternList patterns;
+ auto &func = getFunction();
+ auto *context = &getContext();
+ patterns.push_back(
+ llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure));
+ applyPatternsGreedily(func, std::move(patterns));
+ if (hadFailure)
+ signalPassFailure();
+}
+
+FunctionPassBase *mlir::quant::createConvertSimulatedQuantPass() {
+ return new ConvertSimulatedQuantPass();
+}
+
+static PassRegistration<ConvertSimulatedQuantPass>
+ pass("quant-convert-simulated-quantization",
+ "Converts training-time simulated quantization ops to corresponding "
+ "quantize/dequantize casts.");