diff options
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp')
-rw-r--r-- | mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp | 92 |
1 files changed, 68 insertions, 24 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp index 4f6eb8cb985..1000b1fabbf 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp @@ -37,54 +37,53 @@ public: } // end anonymous namespace -/// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair. -class ConstFakeQuantRewrite : public RewritePattern { +/// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair. +template <typename ConcretRewriteClass, typename FakeQuantOp> +class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> { public: - bool *hadFailure; + using OpRewritePattern<FakeQuantOp>::OpRewritePattern; - ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure) - : RewritePattern(ConstFakeQuant::getOperationName(), 1, context), - hadFailure(hadFailure) {} + FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) + : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {} - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(FakeQuantOp op, PatternRewriter &rewriter) const override { // TODO: If this pattern comes up more frequently, consider adding core // support for failable rewrites. if (failableRewrite(op, rewriter)) { *hadFailure = true; - return matchFailure(); + return Pattern::matchFailure(); } - return matchSuccess(); + return Pattern::matchSuccess(); } - bool failableRewrite(Operation *op, PatternRewriter &rewriter) const { - auto fqOp = cast<ConstFakeQuant>(op); +private: + bool *hadFailure; - auto converter = - ExpressedToQuantizedConverter::forInputType(fqOp.getType()); + bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const { + auto converter = ExpressedToQuantizedConverter::forInputType(op.getType()); if (!converter) { - return (op->emitError("unsupported quantized type conversion"), true); + return (op.emitError("unsupported quantized type conversion"), true); } - UniformQuantizedType uniformElementType = fakeQuantAttrsToType( - fqOp.getLoc(), fqOp.num_bits().getSExtValue(), - fqOp.min().convertToFloat(), fqOp.max().convertToFloat(), - fqOp.narrow_range(), converter.expressedType, fqOp.is_signed()); + QuantizedType elementType = + static_cast<const ConcretRewriteClass *>(this) + ->convertFakeQuantAttrsToType(op, converter.expressedType); - if (!uniformElementType) { + if (!elementType) { // Note that the fakeQuantAttrsToType will have emitted the error. return true; } - Type quantizedType = converter.convert(uniformElementType); + Type quantizedType = converter.convert(elementType); assert(quantizedType && "Converter accepted a type that it did not convert"); // TODO: Map to a qbarrier with an attribute like [Forced] to signal that // this is a forced/hard-coded constraint. - auto qbarrier = rewriter.create<QuantizeCastOp>(op->getLoc(), quantizedType, - fqOp.inputs()); + auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType, + op.inputs()); rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType, qbarrier.getResult()); @@ -92,12 +91,57 @@ public: } }; +class ConstFakeQuantRewrite + : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> { +public: + using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>; + + ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) + : BaseRewrite(ctx, hadFailure) {} + + QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp, + Type expressedType) const { + return fakeQuantAttrsToType( + fqOp.getLoc(), fqOp.num_bits().getSExtValue(), + fqOp.min().convertToFloat(), fqOp.max().convertToFloat(), + fqOp.narrow_range(), expressedType, fqOp.is_signed()); + } +}; + +class ConstFakeQuantPerAxisRewrite + : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, + ConstFakeQuantPerAxis> { +public: + using BaseRewrite = + FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>; + + ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure) + : BaseRewrite(ctx, hadFailure) {} + + QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp, + Type expressedType) const { + SmallVector<double, 4> min, max; + min.reserve(fqOp.min().size()); + max.reserve(fqOp.max().size()); + for (auto m : fqOp.min()) + min.push_back(m.cast<FloatAttr>().getValueAsDouble()); + for (auto m : fqOp.max()) + max.push_back(m.cast<FloatAttr>().getValueAsDouble()); + + return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(), + fqOp.axis().getSExtValue(), min, max, + fqOp.narrow_range(), expressedType, + fqOp.is_signed()); + } +}; + void ConvertSimulatedQuantPass::runOnFunction() { bool hadFailure = false; OwningRewritePatternList patterns; auto func = getFunction(); - auto *context = &getContext(); - patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure); + auto ctx = func.getContext(); + patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>( + ctx, &hadFailure); applyPatternsGreedily(func, patterns); if (hadFailure) signalPassFailure(); |