summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-05-25 17:22:27 -0700
committerMehdi Amini <joker.eph@gmail.com>2019-06-01 20:03:22 -0700
commit9e21ab8f522265d37159372dbce96f66488c4e34 (patch)
tree0ec933521a5894a21543e7ef1684e437b4978380 /mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
parent2f50b6c401fd4d6ff63718ef3b889a79ba32a640 (diff)
downloadbcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.tar.gz
bcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.zip
Add a templated wrapper around RewritePattern that allows for defining match/rewrite methods with an instance of the source op instead of a raw Operation*.
-- PiperOrigin-RevId: 250003405
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp')
-rw-r--r--mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp64
1 files changed, 25 insertions, 39 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
index 44b1156ec28..0c8ba3171aa 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
@@ -36,40 +36,35 @@ public:
void runOnFunction() override;
};
-class QuantizedConstRewrite : public RewritePattern {
-public:
- struct State : PatternState {
- QuantizedType quantizedElementType;
- Attribute value;
- };
-
- QuantizedConstRewrite(MLIRContext *context)
- : RewritePattern(QuantizeCastOp::getOperationName(), 1, context) {}
+struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
+ using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
- PatternMatchResult match(Operation *op) const override;
- void rewrite(Operation *op, std::unique_ptr<PatternState> baseState,
- PatternRewriter &rewriter) const override;
+ PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier,
+ PatternRewriter &rewriter) const override;
};
} // end anonymous namespace
/// Matches a [constant] -> [qbarrier] where the qbarrier results type is
/// quantized and the operand type is quantizable.
-PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
- State state;
+
+PatternMatchResult
+QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
+ PatternRewriter &rewriter) const {
+ Attribute value;
// Is the operand a constant?
- auto qbarrier = cast<QuantizeCastOp>(op);
- if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) {
+ if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
return matchFailure();
}
+
// Does the qbarrier convert to a quantized type. This will not be true
// if a quantized type has not yet been chosen or if the cast to an equivalent
// storage type is not supported.
Type qbarrierResultType = qbarrier.getResult()->getType();
- state.quantizedElementType =
+ QuantizedType quantizedElementType =
QuantizedType::getQuantizedElementType(qbarrierResultType);
- if (!state.quantizedElementType) {
+ if (!quantizedElementType) {
return matchFailure();
}
if (!QuantizedType::castToStorageType(qbarrierResultType)) {
@@ -79,43 +74,34 @@ PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
// Is the operand type compatible with the expressed type of the quantized
// type? This will not be true if the qbarrier is superfluous (converts
// from and to a quantized type).
- if (!state.quantizedElementType.isCompatibleExpressedType(
+ if (!quantizedElementType.isCompatibleExpressedType(
qbarrier.arg()->getType())) {
return matchFailure();
}
// Is the constant value a type expressed in a way that we support?
- if (!state.value.isa<FloatAttr>() && !state.value.isa<SplatElementsAttr>() &&
- !state.value.isa<DenseElementsAttr>() &&
- !state.value.isa<SparseElementsAttr>()) {
+ if (!value.isa<FloatAttr>() && !value.isa<SplatElementsAttr>() &&
+ !value.isa<DenseElementsAttr>() && !value.isa<SparseElementsAttr>()) {
return matchFailure();
}
- return matchSuccess(llvm::make_unique<State>(std::move(state)));
-}
-
-void QuantizedConstRewrite::rewrite(Operation *op,
- std::unique_ptr<PatternState> baseState,
- PatternRewriter &rewriter) const {
- auto state = static_cast<State *>(baseState.get());
-
Type newConstValueType;
- Attribute newConstValue = quantizeAttr(
- state->value, state->quantizedElementType, newConstValueType);
+ auto newConstValue =
+ quantizeAttr(value, quantizedElementType, newConstValueType);
if (!newConstValue) {
- return;
+ return matchFailure();
}
- auto *origConstOp = op->getOperand(0);
// When creating the new const op, use a fused location that combines the
// original const and the qbarrier that led to the quantization.
- auto fusedLoc =
- FusedLoc::get({origConstOp->getDefiningOp()->getLoc(), op->getLoc()},
- rewriter.getContext());
+ auto fusedLoc = FusedLoc::get(
+ {qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()},
+ rewriter.getContext());
auto newConstOp =
rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
- rewriter.replaceOpWithNewOp<StorageCastOp>(
- {origConstOp}, op, *op->result_type_begin(), newConstOp);
+ rewriter.replaceOpWithNewOp<StorageCastOp>({qbarrier.arg()}, qbarrier,
+ qbarrier.getType(), newConstOp);
+ return matchSuccess();
}
void ConvertConstPass::runOnFunction() {
OpenPOWER on IntegriCloud