diff options
author | River Riddle <riverriddle@google.com> | 2019-05-25 17:22:27 -0700 |
---|---|---|
committer | Mehdi Amini <joker.eph@gmail.com> | 2019-06-01 20:03:22 -0700 |
commit | 9e21ab8f522265d37159372dbce96f66488c4e34 (patch) | |
tree | 0ec933521a5894a21543e7ef1684e437b4978380 /mlir/lib/Dialect/QuantOps/Transforms | |
parent | 2f50b6c401fd4d6ff63718ef3b889a79ba32a640 (diff) | |
download | bcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.tar.gz bcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.zip |
Add a templated wrapper around RewritePattern that allows for defining match/rewrite methods with an instance of the source op instead of a raw Operation*.
--
PiperOrigin-RevId: 250003405
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/Transforms')
-rw-r--r-- | mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp | 64 |
1 files changed, 25 insertions, 39 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 44b1156ec28..0c8ba3171aa 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -36,40 +36,35 @@ public: void runOnFunction() override; }; -class QuantizedConstRewrite : public RewritePattern { -public: - struct State : PatternState { - QuantizedType quantizedElementType; - Attribute value; - }; - - QuantizedConstRewrite(MLIRContext *context) - : RewritePattern(QuantizeCastOp::getOperationName(), 1, context) {} +struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> { + using OpRewritePattern<QuantizeCastOp>::OpRewritePattern; - PatternMatchResult match(Operation *op) const override; - void rewrite(Operation *op, std::unique_ptr<PatternState> baseState, - PatternRewriter &rewriter) const override; + PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier, + PatternRewriter &rewriter) const override; }; } // end anonymous namespace /// Matches a [constant] -> [qbarrier] where the qbarrier results type is /// quantized and the operand type is quantizable. -PatternMatchResult QuantizedConstRewrite::match(Operation *op) const { - State state; + +PatternMatchResult +QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, + PatternRewriter &rewriter) const { + Attribute value; // Is the operand a constant? - auto qbarrier = cast<QuantizeCastOp>(op); - if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) { + if (!matchPattern(qbarrier.arg(), m_Constant(&value))) { return matchFailure(); } + // Does the qbarrier convert to a quantized type. This will not be true // if a quantized type has not yet been chosen or if the cast to an equivalent // storage type is not supported. Type qbarrierResultType = qbarrier.getResult()->getType(); - state.quantizedElementType = + QuantizedType quantizedElementType = QuantizedType::getQuantizedElementType(qbarrierResultType); - if (!state.quantizedElementType) { + if (!quantizedElementType) { return matchFailure(); } if (!QuantizedType::castToStorageType(qbarrierResultType)) { @@ -79,43 +74,34 @@ PatternMatchResult QuantizedConstRewrite::match(Operation *op) const { // Is the operand type compatible with the expressed type of the quantized // type? This will not be true if the qbarrier is superfluous (converts // from and to a quantized type). - if (!state.quantizedElementType.isCompatibleExpressedType( + if (!quantizedElementType.isCompatibleExpressedType( qbarrier.arg()->getType())) { return matchFailure(); } // Is the constant value a type expressed in a way that we support? - if (!state.value.isa<FloatAttr>() && !state.value.isa<SplatElementsAttr>() && - !state.value.isa<DenseElementsAttr>() && - !state.value.isa<SparseElementsAttr>()) { + if (!value.isa<FloatAttr>() && !value.isa<SplatElementsAttr>() && + !value.isa<DenseElementsAttr>() && !value.isa<SparseElementsAttr>()) { return matchFailure(); } - return matchSuccess(llvm::make_unique<State>(std::move(state))); -} - -void QuantizedConstRewrite::rewrite(Operation *op, - std::unique_ptr<PatternState> baseState, - PatternRewriter &rewriter) const { - auto state = static_cast<State *>(baseState.get()); - Type newConstValueType; - Attribute newConstValue = quantizeAttr( - state->value, state->quantizedElementType, newConstValueType); + auto newConstValue = + quantizeAttr(value, quantizedElementType, newConstValueType); if (!newConstValue) { - return; + return matchFailure(); } - auto *origConstOp = op->getOperand(0); // When creating the new const op, use a fused location that combines the // original const and the qbarrier that led to the quantization. - auto fusedLoc = - FusedLoc::get({origConstOp->getDefiningOp()->getLoc(), op->getLoc()}, - rewriter.getContext()); + auto fusedLoc = FusedLoc::get( + {qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()}, + rewriter.getContext()); auto newConstOp = rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue); - rewriter.replaceOpWithNewOp<StorageCastOp>( - {origConstOp}, op, *op->result_type_begin(), newConstOp); + rewriter.replaceOpWithNewOp<StorageCastOp>({qbarrier.arg()}, qbarrier, + qbarrier.getType(), newConstOp); + return matchSuccess(); } void ConvertConstPass::runOnFunction() { |