summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp')
-rw-r--r--mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp114
1 files changed, 114 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
new file mode 100644
index 00000000000..ea8095b791c
--- /dev/null
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
@@ -0,0 +1,114 @@
+//===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
+#include "mlir/Dialect/QuantOps/Passes.h"
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/UniformSupport.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+namespace {
+
+class ConvertSimulatedQuantPass
+ : public FunctionPass<ConvertSimulatedQuantPass> {
+public:
+ void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+/// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
+class ConstFakeQuantRewrite : public RewritePattern {
+public:
+ bool *hadFailure;
+
+ ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure)
+ : RewritePattern(ConstFakeQuant::getOperationName(), 1, context),
+ hadFailure(hadFailure) {}
+
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // TODO: If this pattern comes up more frequently, consider adding core
+ // support for failable rewrites.
+ if (failableRewrite(op, rewriter)) {
+ *hadFailure = true;
+ return matchFailure();
+ }
+
+ return matchSuccess();
+ }
+
+ bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
+ auto fqOp = cast<ConstFakeQuant>(op);
+
+ auto converter =
+ ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType());
+ if (!converter) {
+ return (op->emitError("unsupported quantized type conversion"), true);
+ }
+
+ UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
+ fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
+ fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
+ fqOp.narrow_range(), converter.expressedType);
+
+ if (!uniformElementType) {
+ // Note that the fakeQuantAttrsToType will have emitted the error.
+ return true;
+ }
+
+ Type quantizedType = converter.convert(uniformElementType);
+ assert(quantizedType &&
+ "Converter accepted a type that it did not convert");
+
+ // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
+ // this is a forced/hard-coded constraint.
+ auto qbarrier = rewriter.create<QuantizeCastOp>(op->getLoc(), quantizedType,
+ fqOp.inputs());
+ rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
+ qbarrier.getResult());
+
+ return false;
+ }
+};
+
+void ConvertSimulatedQuantPass::runOnFunction() {
+ bool hadFailure = false;
+ OwningRewritePatternList patterns;
+ auto &func = getFunction();
+ auto *context = &getContext();
+ patterns.push_back(
+ llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure));
+ applyPatternsGreedily(func, std::move(patterns));
+ if (hadFailure)
+ signalPassFailure();
+}
+
+FunctionPassBase *mlir::quant::createConvertSimulatedQuantPass() {
+ return new ConvertSimulatedQuantPass();
+}
+
+static PassRegistration<ConvertSimulatedQuantPass>
+ pass("quant-convert-simulated-quantization",
+ "Converts training-time simulated quantization ops to corresponding "
+ "quantize/dequantize casts.");
OpenPOWER on IntegriCloud