summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/QuantOps
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-05-25 17:22:27 -0700
committerMehdi Amini <joker.eph@gmail.com>2019-06-01 20:03:22 -0700
commit9e21ab8f522265d37159372dbce96f66488c4e34 (patch)
tree0ec933521a5894a21543e7ef1684e437b4978380 /mlir/lib/Dialect/QuantOps
parent2f50b6c401fd4d6ff63718ef3b889a79ba32a640 (diff)
downloadbcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.tar.gz
bcm5719-llvm-9e21ab8f522265d37159372dbce96f66488c4e34.zip
Add a templated wrapper around RewritePattern that allows for defining match/rewrite methods with an instance of the source op instead of a raw Operation*.
-- PiperOrigin-RevId: 250003405
Diffstat (limited to 'mlir/lib/Dialect/QuantOps')
-rw-r--r--mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp27
-rw-r--r--mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp64
2 files changed, 36 insertions, 55 deletions
diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
index fb5b2e1b0f7..e237e8b6eb2 100644
--- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
@@ -38,26 +38,21 @@ namespace {
/// Matches x -> [scast -> scast] -> y, replacing the second scast with the
/// value of x if the casts invert each other.
-class RemoveRedundantStorageCastsRewrite : public RewritePattern {
+class RemoveRedundantStorageCastsRewrite
+ : public OpRewritePattern<StorageCastOp> {
public:
- RemoveRedundantStorageCastsRewrite(MLIRContext *context)
- : RewritePattern(StorageCastOp::getOperationName(), 1, context) {}
+ using OpRewritePattern<StorageCastOp>::OpRewritePattern;
- PatternMatchResult match(Operation *op) const override {
- auto scastOp = cast<StorageCastOp>(op);
- if (matchPattern(scastOp.arg(), m_Op<StorageCastOp>())) {
- auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
- if (srcScastOp.arg()->getType() == scastOp.getResult()->getType()) {
- return matchSuccess();
- }
- }
- return matchFailure();
- }
+ PatternMatchResult matchAndRewrite(StorageCastOp op,
+ PatternRewriter &rewriter) const override {
+ if (!matchPattern(op.arg(), m_Op<StorageCastOp>()))
+ return matchFailure();
+ auto srcScastOp = cast<StorageCastOp>(op.arg()->getDefiningOp());
+ if (srcScastOp.arg()->getType() != op.getType())
+ return matchFailure();
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- auto scastOp = cast<StorageCastOp>(op);
- auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
rewriter.replaceOp(op, srcScastOp.arg());
+ return matchSuccess();
}
};
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
index 44b1156ec28..0c8ba3171aa 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
@@ -36,40 +36,35 @@ public:
void runOnFunction() override;
};
-class QuantizedConstRewrite : public RewritePattern {
-public:
- struct State : PatternState {
- QuantizedType quantizedElementType;
- Attribute value;
- };
-
- QuantizedConstRewrite(MLIRContext *context)
- : RewritePattern(QuantizeCastOp::getOperationName(), 1, context) {}
+struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
+ using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
- PatternMatchResult match(Operation *op) const override;
- void rewrite(Operation *op, std::unique_ptr<PatternState> baseState,
- PatternRewriter &rewriter) const override;
+ PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier,
+ PatternRewriter &rewriter) const override;
};
} // end anonymous namespace
/// Matches a [constant] -> [qbarrier] where the qbarrier results type is
/// quantized and the operand type is quantizable.
-PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
- State state;
+
+PatternMatchResult
+QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
+ PatternRewriter &rewriter) const {
+ Attribute value;
// Is the operand a constant?
- auto qbarrier = cast<QuantizeCastOp>(op);
- if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) {
+ if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
return matchFailure();
}
+
// Does the qbarrier convert to a quantized type. This will not be true
// if a quantized type has not yet been chosen or if the cast to an equivalent
// storage type is not supported.
Type qbarrierResultType = qbarrier.getResult()->getType();
- state.quantizedElementType =
+ QuantizedType quantizedElementType =
QuantizedType::getQuantizedElementType(qbarrierResultType);
- if (!state.quantizedElementType) {
+ if (!quantizedElementType) {
return matchFailure();
}
if (!QuantizedType::castToStorageType(qbarrierResultType)) {
@@ -79,43 +74,34 @@ PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
// Is the operand type compatible with the expressed type of the quantized
// type? This will not be true if the qbarrier is superfluous (converts
// from and to a quantized type).
- if (!state.quantizedElementType.isCompatibleExpressedType(
+ if (!quantizedElementType.isCompatibleExpressedType(
qbarrier.arg()->getType())) {
return matchFailure();
}
// Is the constant value a type expressed in a way that we support?
- if (!state.value.isa<FloatAttr>() && !state.value.isa<SplatElementsAttr>() &&
- !state.value.isa<DenseElementsAttr>() &&
- !state.value.isa<SparseElementsAttr>()) {
+ if (!value.isa<FloatAttr>() && !value.isa<SplatElementsAttr>() &&
+ !value.isa<DenseElementsAttr>() && !value.isa<SparseElementsAttr>()) {
return matchFailure();
}
- return matchSuccess(llvm::make_unique<State>(std::move(state)));
-}
-
-void QuantizedConstRewrite::rewrite(Operation *op,
- std::unique_ptr<PatternState> baseState,
- PatternRewriter &rewriter) const {
- auto state = static_cast<State *>(baseState.get());
-
Type newConstValueType;
- Attribute newConstValue = quantizeAttr(
- state->value, state->quantizedElementType, newConstValueType);
+ auto newConstValue =
+ quantizeAttr(value, quantizedElementType, newConstValueType);
if (!newConstValue) {
- return;
+ return matchFailure();
}
- auto *origConstOp = op->getOperand(0);
// When creating the new const op, use a fused location that combines the
// original const and the qbarrier that led to the quantization.
- auto fusedLoc =
- FusedLoc::get({origConstOp->getDefiningOp()->getLoc(), op->getLoc()},
- rewriter.getContext());
+ auto fusedLoc = FusedLoc::get(
+ {qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()},
+ rewriter.getContext());
auto newConstOp =
rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
- rewriter.replaceOpWithNewOp<StorageCastOp>(
- {origConstOp}, op, *op->result_type_begin(), newConstOp);
+ rewriter.replaceOpWithNewOp<StorageCastOp>({qbarrier.arg()}, qbarrier,
+ qbarrier.getType(), newConstOp);
+ return matchSuccess();
}
void ConvertConstPass::runOnFunction() {
OpenPOWER on IntegriCloud