summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp')
-rw-r--r--mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp68
1 files changed, 68 insertions, 0 deletions
diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
new file mode 100644
index 00000000000..da5bd12ea1c
--- /dev/null
+++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
@@ -0,0 +1,68 @@
+//===- RemoveInstrumentationPass.cpp - Removes instrumentation ------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a pass to remove any instrumentation ops. It is often one
+// of the final steps when performing quantization and is run after any
+// decisions requiring instrumentation have been made.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Quantizer/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace {
+
+class RemoveInstrumentationPass
+ : public FunctionPass<RemoveInstrumentationPass> {
+ void runOnFunction() override;
+};
+
+template <typename OpTy>
+class RemoveIdentityOpRewrite : public RewritePattern {
+public:
+ RemoveIdentityOpRewrite(MLIRContext *context)
+ : RewritePattern(OpTy::getOperationName(), 1, context) {}
+
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ assert(op->getNumOperands() == 1);
+ assert(op->getNumResults() == 1);
+
+ rewriter.replaceOp(op, op->getOperand(0));
+ return matchSuccess();
+ }
+};
+
+} // end anonymous namespace
+
+void RemoveInstrumentationPass::runOnFunction() {
+ OwningRewritePatternList patterns;
+ auto func = getFunction();
+ auto *context = &getContext();
+ patterns.insert<RemoveIdentityOpRewrite<StatisticsOp>,
+ RemoveIdentityOpRewrite<StatisticsRefOp>,
+ RemoveIdentityOpRewrite<CoupledRefOp>>(context);
+ applyPatternsGreedily(func, patterns);
+}
+
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::quantizer::createRemoveInstrumentationPass() {
+ return std::make_unique<RemoveInstrumentationPass>();
+}
+
+static PassRegistration<RemoveInstrumentationPass>
+ pass("quantizer-remove-instrumentation",
+ "Removes instrumentation and hints which have no effect on final "
+ "execution");
OpenPOWER on IntegriCloud