diff options
| author | River Riddle <riverriddle@google.com> | 2019-05-25 17:22:27 -0700 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2019-06-01 20:03:22 -0700 |
| commit | 9e21ab8f522265d37159372dbce96f66488c4e34 (patch) | |
| tree | 0ec933521a5894a21543e7ef1684e437b4978380 /mlir/lib/Dialect/QuantOps | |
| parent | 2f50b6c401fd4d6ff63718ef3b889a79ba32a640 (diff) | |
| download | bcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.tar.gz bcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.zip | |
Add a templated wrapper around RewritePattern that allows for defining match/rewrite methods with an instance of the source op instead of a raw Operation*.
--
PiperOrigin-RevId: 250003405
Diffstat (limited to 'mlir/lib/Dialect/QuantOps')
| -rw-r--r-- | mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp | 27 | ||||
| -rw-r--r-- | mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp | 64 |
2 files changed, 36 insertions, 55 deletions
diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp index fb5b2e1b0f7..e237e8b6eb2 100644 --- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp @@ -38,26 +38,21 @@ namespace { /// Matches x -> [scast -> scast] -> y, replacing the second scast with the /// value of x if the casts invert each other. -class RemoveRedundantStorageCastsRewrite : public RewritePattern { +class RemoveRedundantStorageCastsRewrite + : public OpRewritePattern<StorageCastOp> { public: - RemoveRedundantStorageCastsRewrite(MLIRContext *context) - : RewritePattern(StorageCastOp::getOperationName(), 1, context) {} + using OpRewritePattern<StorageCastOp>::OpRewritePattern; - PatternMatchResult match(Operation *op) const override { - auto scastOp = cast<StorageCastOp>(op); - if (matchPattern(scastOp.arg(), m_Op<StorageCastOp>())) { - auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp()); - if (srcScastOp.arg()->getType() == scastOp.getResult()->getType()) { - return matchSuccess(); - } - } - return matchFailure(); - } + PatternMatchResult matchAndRewrite(StorageCastOp op, + PatternRewriter &rewriter) const override { + if (!matchPattern(op.arg(), m_Op<StorageCastOp>())) + return matchFailure(); + auto srcScastOp = cast<StorageCastOp>(op.arg()->getDefiningOp()); + if (srcScastOp.arg()->getType() != op.getType()) + return matchFailure(); - void rewrite(Operation *op, PatternRewriter &rewriter) const override { - auto scastOp = cast<StorageCastOp>(op); - auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp()); rewriter.replaceOp(op, srcScastOp.arg()); + return matchSuccess(); } }; diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 44b1156ec28..0c8ba3171aa 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -36,40 +36,35 @@ public: void runOnFunction() override; }; -class QuantizedConstRewrite : public RewritePattern { -public: - struct State : PatternState { - QuantizedType quantizedElementType; - Attribute value; - }; - - QuantizedConstRewrite(MLIRContext *context) - : RewritePattern(QuantizeCastOp::getOperationName(), 1, context) {} +struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> { + using OpRewritePattern<QuantizeCastOp>::OpRewritePattern; - PatternMatchResult match(Operation *op) const override; - void rewrite(Operation *op, std::unique_ptr<PatternState> baseState, - PatternRewriter &rewriter) const override; + PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier, + PatternRewriter &rewriter) const override; }; } // end anonymous namespace /// Matches a [constant] -> [qbarrier] where the qbarrier results type is /// quantized and the operand type is quantizable. -PatternMatchResult QuantizedConstRewrite::match(Operation *op) const { - State state; + +PatternMatchResult +QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, + PatternRewriter &rewriter) const { + Attribute value; // Is the operand a constant? - auto qbarrier = cast<QuantizeCastOp>(op); - if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) { + if (!matchPattern(qbarrier.arg(), m_Constant(&value))) { return matchFailure(); } + // Does the qbarrier convert to a quantized type. This will not be true // if a quantized type has not yet been chosen or if the cast to an equivalent // storage type is not supported. Type qbarrierResultType = qbarrier.getResult()->getType(); - state.quantizedElementType = + QuantizedType quantizedElementType = QuantizedType::getQuantizedElementType(qbarrierResultType); - if (!state.quantizedElementType) { + if (!quantizedElementType) { return matchFailure(); } if (!QuantizedType::castToStorageType(qbarrierResultType)) { @@ -79,43 +74,34 @@ PatternMatchResult QuantizedConstRewrite::match(Operation *op) const { // Is the operand type compatible with the expressed type of the quantized // type? This will not be true if the qbarrier is superfluous (converts // from and to a quantized type). - if (!state.quantizedElementType.isCompatibleExpressedType( + if (!quantizedElementType.isCompatibleExpressedType( qbarrier.arg()->getType())) { return matchFailure(); } // Is the constant value a type expressed in a way that we support? - if (!state.value.isa<FloatAttr>() && !state.value.isa<SplatElementsAttr>() && - !state.value.isa<DenseElementsAttr>() && - !state.value.isa<SparseElementsAttr>()) { + if (!value.isa<FloatAttr>() && !value.isa<SplatElementsAttr>() && + !value.isa<DenseElementsAttr>() && !value.isa<SparseElementsAttr>()) { return matchFailure(); } - return matchSuccess(llvm::make_unique<State>(std::move(state))); -} - -void QuantizedConstRewrite::rewrite(Operation *op, - std::unique_ptr<PatternState> baseState, - PatternRewriter &rewriter) const { - auto state = static_cast<State *>(baseState.get()); - Type newConstValueType; - Attribute newConstValue = quantizeAttr( - state->value, state->quantizedElementType, newConstValueType); + auto newConstValue = + quantizeAttr(value, quantizedElementType, newConstValueType); if (!newConstValue) { - return; + return matchFailure(); } - auto *origConstOp = op->getOperand(0); // When creating the new const op, use a fused location that combines the // original const and the qbarrier that led to the quantization. - auto fusedLoc = - FusedLoc::get({origConstOp->getDefiningOp()->getLoc(), op->getLoc()}, - rewriter.getContext()); + auto fusedLoc = FusedLoc::get( + {qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()}, + rewriter.getContext()); auto newConstOp = rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue); - rewriter.replaceOpWithNewOp<StorageCastOp>( - {origConstOp}, op, *op->result_type_begin(), newConstOp); + rewriter.replaceOpWithNewOp<StorageCastOp>({qbarrier.arg()}, qbarrier, + qbarrier.getType(), newConstOp); + return matchSuccess(); } void ConvertConstPass::runOnFunction() { |

