author     Feng Liu <fengliuai@google.com>                    2019-09-10 10:50:16 -0700
committer  A. Unique TensorFlower <gardener@tensorflow.org>   2019-09-10 10:50:57 -0700
commit     c68d5467d604d2b1e06a704133370f51a99df11d
tree       4082d544febf58405f56fc89c694beee303d905e
parent     277b6136ee78e621a1737e35956d1a9317ff096d
Convert ConstFakeQuantPerAxis to qcast and dcast pair
This also adds a test of fakeQuantAttrsToType for per-channel fake quant.
PiperOrigin-RevId: 268260032
-rw-r--r--  mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp | 92
-rw-r--r--  mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp     |  5
-rw-r--r--  mlir/test/Dialect/QuantOps/convert-fakequant.mlir        | 19
3 files changed, 89 insertions(+), 27 deletions(-)
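In IR terms, the new pattern rewrites a quant.const_fake_quant_per_axis op into a quant.qcast/quant.dcast pair whose element type is a per-axis !quant.uniform type. A minimal before/after sketch, adapted from the test case added below (the scale:zeroPoint list carries one entry per slice along the quantized axis):

  // Before: per-axis fake quant along axis 2.
  %0 = "quant.const_fake_quant_per_axis"(%arg0) {
    min = [-1.0 : f32, 0.0 : f32, 0.0 : f32],
    max = [0.9921875 : f32, 0.0 : f32, 1.0 : f32],
    num_bits = 8, narrow_range = false, is_signed = true, axis = 2
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>

  // After: a qcast/dcast pair through the inferred per-axis quantized type.
  %q = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
      -> tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>
  %d = "quant.dcast"(%q)
      : (tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>)
      -> tensor<8x4x3xf32>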
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
index 4f6eb8cb985..1000b1fabbf 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
@@ -37,54 +37,53 @@ public:
 } // end anonymous namespace
 
-/// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
-class ConstFakeQuantRewrite : public RewritePattern {
+/// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
+template <typename ConcretRewriteClass, typename FakeQuantOp>
+class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
 public:
-  bool *hadFailure;
+  using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
 
-  ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure)
-      : RewritePattern(ConstFakeQuant::getOperationName(), 1, context),
-        hadFailure(hadFailure) {}
+  FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
+      : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(FakeQuantOp op,
                                      PatternRewriter &rewriter) const override {
     // TODO: If this pattern comes up more frequently, consider adding core
     // support for failable rewrites.
     if (failableRewrite(op, rewriter)) {
       *hadFailure = true;
-      return matchFailure();
+      return Pattern::matchFailure();
     }
-    return matchSuccess();
+    return Pattern::matchSuccess();
   }
 
-  bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
-    auto fqOp = cast<ConstFakeQuant>(op);
+private:
+  bool *hadFailure;
 
-    auto converter =
-        ExpressedToQuantizedConverter::forInputType(fqOp.getType());
+  bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
+    auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
     if (!converter) {
-      return (op->emitError("unsupported quantized type conversion"), true);
+      return (op.emitError("unsupported quantized type conversion"), true);
     }
 
-    UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
-        fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
-        fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
-        fqOp.narrow_range(), converter.expressedType, fqOp.is_signed());
+    QuantizedType elementType =
+        static_cast<const ConcretRewriteClass *>(this)
+            ->convertFakeQuantAttrsToType(op, converter.expressedType);
 
-    if (!uniformElementType) {
+    if (!elementType) {
       // Note that the fakeQuantAttrsToType will have emitted the error.
       return true;
     }
 
-    Type quantizedType = converter.convert(uniformElementType);
+    Type quantizedType = converter.convert(elementType);
     assert(quantizedType &&
            "Converter accepted a type that it did not convert");
 
     // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
     // this is a forced/hard-coded constraint.
-    auto qbarrier = rewriter.create<QuantizeCastOp>(op->getLoc(), quantizedType,
-                                                    fqOp.inputs());
+    auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
+                                                    op.inputs());
     rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
                                                   qbarrier.getResult());
@@ -92,12 +91,57 @@ public:
   }
 };
 
+class ConstFakeQuantRewrite
+    : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
+public:
+  using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
+
+  ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
+      : BaseRewrite(ctx, hadFailure) {}
+
+  QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
+                                            Type expressedType) const {
+    return fakeQuantAttrsToType(
+        fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
+        fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
+        fqOp.narrow_range(), expressedType, fqOp.is_signed());
+  }
+};
+
+class ConstFakeQuantPerAxisRewrite
+    : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
+                              ConstFakeQuantPerAxis> {
+public:
+  using BaseRewrite =
+      FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
+
+  ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
+      : BaseRewrite(ctx, hadFailure) {}
+
+  QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
+                                            Type expressedType) const {
+    SmallVector<double, 4> min, max;
+    min.reserve(fqOp.min().size());
+    max.reserve(fqOp.max().size());
+    for (auto m : fqOp.min())
+      min.push_back(m.cast<FloatAttr>().getValueAsDouble());
+    for (auto m : fqOp.max())
+      max.push_back(m.cast<FloatAttr>().getValueAsDouble());
+
+    return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
+                                fqOp.axis().getSExtValue(), min, max,
+                                fqOp.narrow_range(), expressedType,
+                                fqOp.is_signed());
+  }
+};
+
 void ConvertSimulatedQuantPass::runOnFunction() {
   bool hadFailure = false;
   OwningRewritePatternList patterns;
   auto func = getFunction();
-  auto *context = &getContext();
-  patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
+  auto ctx = func.getContext();
+  patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
+      ctx, &hadFailure);
   applyPatternsGreedily(func, patterns);
   if (hadFailure)
     signalPassFailure();
diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
index 02f803ac839..5d4561be81b 100644
--- a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
@@ -136,7 +136,6 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
                                    loc);
 }
 
-// TODO(fengliuai): test this method once the quantizeAttr method is fixed.
 UniformQuantizedPerAxisType
 fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
                      ArrayRef<double> rmins, ArrayRef<double> rmaxs,
@@ -180,8 +179,8 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
   unsigned flags = isSigned ?
       QuantizationFlags::Signed : 0;
   return UniformQuantizedPerAxisType::getChecked(
-      flags, storageType, expressedType, scales, zeroPoints, qmin, qmax,
-      quantizedDimension, loc);
+      flags, storageType, expressedType, scales, zeroPoints, quantizedDimension,
+      qmin, qmax, loc);
 }
 
 } // namespace quant
diff --git a/mlir/test/Dialect/QuantOps/convert-fakequant.mlir b/mlir/test/Dialect/QuantOps/convert-fakequant.mlir
index 15de088f39c..316702cc528 100644
--- a/mlir/test/Dialect/QuantOps/convert-fakequant.mlir
+++ b/mlir/test/Dialect/QuantOps/convert-fakequant.mlir
@@ -180,3 +180,22 @@ func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
   } : (tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
+
+// -----
+// Verifies a qint8 per axis
+// CHECK_LABEL: fakeQuantPerAxis
+func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+
+  // CHECK: %[[q:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>
+  // CHECK: %[[d:.*]] = "quant.dcast"(%[[q]])
+  // CHECK-SAME: (tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>)
+
+  %0 = "quant.const_fake_quant_per_axis"(%arg0) {
+    min = [-1.0 : f32, 0.0 : f32, 0.0 : f32],
+    max = [0.9921875 : f32, 0.0: f32, 1.0 : f32],
+    num_bits = 8, narrow_range = false, is_signed = true, axis = 2
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
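As a sanity check on the CHECK lines in the new test, the per-slice values follow from the standard affine quantization formulas (a hand computation, not restating the pass's exact rounding behavior):

  scale     = (rmax - rmin) / (qmax - qmin)
  zeroPoint = qmin - rmin / scale

  Axis-2 slice 0: rmin = -1.0, rmax = 0.9921875; signed 8-bit => qmin = -128, qmax = 127
                  scale = 1.9921875 / 255 = 7.8125e-03, zeroPoint = -128 + 128 = 0
  Axis-2 slice 2: rmin = 0.0, rmax = 1.0
                  scale = 1.0 / 255 = 0.0039215686..., zeroPoint = -128 - 0 = -128

The degenerate slice 1 (rmin = rmax = 0.0) comes out as scale 1.0 with zeroPoint -128, as the expected !quant.uniform type in the CHECK lines shows.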