summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorFeng Liu <fengliuai@google.com>2019-09-10 10:50:16 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-09-10 10:50:57 -0700
commitc68d5467d604d2b1e06a704133370f51a99df11d (patch)
tree4082d544febf58405f56fc89c694beee303d905e
parent277b6136ee78e621a1737e35956d1a9317ff096d (diff)
downloadbcm5719-llvm-c68d5467d604d2b1e06a704133370f51a99df11d.tar.gz
bcm5719-llvm-c68d5467d604d2b1e06a704133370f51a99df11d.zip
Convert ConstFakeQuantPerAxis to qcast and dcast pair
This also adds a test of fakeQuantAttrsToType for per-channel fake quant. PiperOrigin-RevId: 268260032
-rw-r--r--mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp92
-rw-r--r--mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp5
-rw-r--r--mlir/test/Dialect/QuantOps/convert-fakequant.mlir19
3 files changed, 89 insertions, 27 deletions
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
index 4f6eb8cb985..1000b1fabbf 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
@@ -37,54 +37,53 @@ public:
} // end anonymous namespace
-/// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
-class ConstFakeQuantRewrite : public RewritePattern {
+/// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
+template <typename ConcretRewriteClass, typename FakeQuantOp>
+class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
public:
- bool *hadFailure;
+ using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
- ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure)
- : RewritePattern(ConstFakeQuant::getOperationName(), 1, context),
- hadFailure(hadFailure) {}
+ FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
+ : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(FakeQuantOp op,
PatternRewriter &rewriter) const override {
// TODO: If this pattern comes up more frequently, consider adding core
// support for failable rewrites.
if (failableRewrite(op, rewriter)) {
*hadFailure = true;
- return matchFailure();
+ return Pattern::matchFailure();
}
- return matchSuccess();
+ return Pattern::matchSuccess();
}
- bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
- auto fqOp = cast<ConstFakeQuant>(op);
+private:
+ bool *hadFailure;
- auto converter =
- ExpressedToQuantizedConverter::forInputType(fqOp.getType());
+ bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
+ auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
if (!converter) {
- return (op->emitError("unsupported quantized type conversion"), true);
+ return (op.emitError("unsupported quantized type conversion"), true);
}
- UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
- fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
- fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
- fqOp.narrow_range(), converter.expressedType, fqOp.is_signed());
+ QuantizedType elementType =
+ static_cast<const ConcretRewriteClass *>(this)
+ ->convertFakeQuantAttrsToType(op, converter.expressedType);
- if (!uniformElementType) {
+ if (!elementType) {
// Note that the fakeQuantAttrsToType will have emitted the error.
return true;
}
- Type quantizedType = converter.convert(uniformElementType);
+ Type quantizedType = converter.convert(elementType);
assert(quantizedType &&
"Converter accepted a type that it did not convert");
// TODO: Map to a qbarrier with an attribute like [Forced] to signal that
// this is a forced/hard-coded constraint.
- auto qbarrier = rewriter.create<QuantizeCastOp>(op->getLoc(), quantizedType,
- fqOp.inputs());
+ auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
+ op.inputs());
rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
qbarrier.getResult());
@@ -92,12 +91,57 @@ public:
}
};
+class ConstFakeQuantRewrite
+ : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
+public:
+ using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
+
+ ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
+ : BaseRewrite(ctx, hadFailure) {}
+
+ QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
+ Type expressedType) const {
+ return fakeQuantAttrsToType(
+ fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
+ fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
+ fqOp.narrow_range(), expressedType, fqOp.is_signed());
+ }
+};
+
+class ConstFakeQuantPerAxisRewrite
+ : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
+ ConstFakeQuantPerAxis> {
+public:
+ using BaseRewrite =
+ FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
+
+ ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
+ : BaseRewrite(ctx, hadFailure) {}
+
+ QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
+ Type expressedType) const {
+ SmallVector<double, 4> min, max;
+ min.reserve(fqOp.min().size());
+ max.reserve(fqOp.max().size());
+ for (auto m : fqOp.min())
+ min.push_back(m.cast<FloatAttr>().getValueAsDouble());
+ for (auto m : fqOp.max())
+ max.push_back(m.cast<FloatAttr>().getValueAsDouble());
+
+ return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
+ fqOp.axis().getSExtValue(), min, max,
+ fqOp.narrow_range(), expressedType,
+ fqOp.is_signed());
+ }
+};
+
void ConvertSimulatedQuantPass::runOnFunction() {
bool hadFailure = false;
OwningRewritePatternList patterns;
auto func = getFunction();
- auto *context = &getContext();
- patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
+ auto ctx = func.getContext();
+ patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
+ ctx, &hadFailure);
applyPatternsGreedily(func, patterns);
if (hadFailure)
signalPassFailure();
diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
index 02f803ac839..5d4561be81b 100644
--- a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
@@ -136,7 +136,6 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
loc);
}
-// TODO(fengliuai): test this method once the quantizeAttr method is fixed.
UniformQuantizedPerAxisType
fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
ArrayRef<double> rmins, ArrayRef<double> rmaxs,
@@ -180,8 +179,8 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
return UniformQuantizedPerAxisType::getChecked(
- flags, storageType, expressedType, scales, zeroPoints, qmin, qmax,
- quantizedDimension, loc);
+ flags, storageType, expressedType, scales, zeroPoints, quantizedDimension,
+ qmin, qmax, loc);
}
} // namespace quant
diff --git a/mlir/test/Dialect/QuantOps/convert-fakequant.mlir b/mlir/test/Dialect/QuantOps/convert-fakequant.mlir
index 15de088f39c..316702cc528 100644
--- a/mlir/test/Dialect/QuantOps/convert-fakequant.mlir
+++ b/mlir/test/Dialect/QuantOps/convert-fakequant.mlir
@@ -180,3 +180,22 @@ func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
} : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
+
+// -----
+// Verifies a qint8 per axis
+// CHECK_LABEL: fakeQuantPerAxis
+func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+
+ // CHECK: %[[q:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+ // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>
+ // CHECK: %[[d:.*]] = "quant.dcast"(%[[q]])
+ // CHECK-SAME: (tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>)
+
+ %0 = "quant.const_fake_quant_per_axis"(%arg0) {
+ min = [-1.0 : f32, 0.0 : f32, 0.0 : f32],
+ max = [0.9921875 : f32, 0.0: f32, 1.0 : f32],
+ num_bits = 8, narrow_range = false, is_signed = true, axis = 2
+ } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+ return %0 : tensor<8x4x3xf32>
+}
OpenPOWER on IntegriCloud