summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/QuantOps
diff options
context:
space:
mode:
authorGeoffrey Martin-Noble <gcmn@google.com>2019-05-16 00:12:45 -0700
committerMehdi Amini <joker.eph@gmail.com>2019-05-20 13:43:58 -0700
commit090662c5f3572c573cf249844748ecbf11d10dbe (patch)
treec3b040c726fb56de37f6fa80fcd40ff0ef16f2e3 /mlir/lib/Dialect/QuantOps
parentb3888fa9cc46dc3607af2dc5cf840ef33d0fcd8d (diff)
downloadbcm5719-llvm-090662c5f3572c573cf249844748ecbf11d10dbe.tar.gz
bcm5719-llvm-090662c5f3572c573cf249844748ecbf11d10dbe.zip
Rename VectorOrTensorType to ShapedType
This is in preparation for making it also support/be a parent class of MemRefType. MemRefs have similar shape/rank/element semantics and it would be useful to be able to use these same utilities for them. This CL should not change any semantics and only change variables, types, string literals, and comments. In follow-up CLs I will prepare all callers to handle MemRef types or remove their dependence on ShapedType. Discussion/Rationale in https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/cHLoyfGu8y8 -- PiperOrigin-RevId: 248476449
Diffstat (limited to 'mlir/lib/Dialect/QuantOps')
-rw-r--r--mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp43
-rw-r--r--mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp12
-rw-r--r--mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp4
3 files changed, 29 insertions, 30 deletions
diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
index 6ca3b92d064..1b63b8f4f55 100644
--- a/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
@@ -98,8 +98,8 @@ Type QuantizedType::getExpressedType() const {
}
bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
- if (candidateExpressedType.isa<VectorOrTensorType>()) {
- return candidateExpressedType.cast<VectorOrTensorType>().getElementType() ==
+ if (candidateExpressedType.isa<ShapedType>()) {
+ return candidateExpressedType.cast<ShapedType>().getElementType() ==
getExpressedType();
}
return candidateExpressedType == getExpressedType();
@@ -107,9 +107,9 @@ bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
QuantizedType
QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
- if (primitiveOrContainerType.isa<VectorOrTensorType>()) {
+ if (primitiveOrContainerType.isa<ShapedType>()) {
Type elementType =
- primitiveOrContainerType.cast<VectorOrTensorType>().getElementType();
+ primitiveOrContainerType.cast<ShapedType>().getElementType();
return elementType.dyn_cast<QuantizedType>();
}
return primitiveOrContainerType.dyn_cast<QuantizedType>();
@@ -139,20 +139,20 @@ Type QuantizedType::castToStorageType(Type quantizedType) {
if (quantizedType.isa<QuantizedType>()) {
// i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
return quantizedType.cast<QuantizedType>().getStorageType();
- } else if (quantizedType.isa<VectorOrTensorType>()) {
+ } else if (quantizedType.isa<ShapedType>()) {
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
- VectorOrTensorType vtType = quantizedType.cast<VectorOrTensorType>();
- if (!vtType.getElementType().isa<QuantizedType>()) {
+ ShapedType sType = quantizedType.cast<ShapedType>();
+ if (!sType.getElementType().isa<QuantizedType>()) {
return nullptr;
}
Type storageType =
- vtType.getElementType().cast<QuantizedType>().getStorageType();
+ sType.getElementType().cast<QuantizedType>().getStorageType();
if (quantizedType.isa<RankedTensorType>()) {
- return RankedTensorType::get(vtType.getShape(), storageType);
+ return RankedTensorType::get(sType.getShape(), storageType);
} else if (quantizedType.isa<UnrankedTensorType>()) {
return UnrankedTensorType::get(storageType);
} else if (quantizedType.isa<VectorType>()) {
- return VectorType::get(vtType.getShape(), storageType);
+ return VectorType::get(sType.getShape(), storageType);
}
}
@@ -163,22 +163,21 @@ Type QuantizedType::castFromExpressedType(Type candidateType) {
if (candidateType == getExpressedType()) {
// i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
return *this;
- } else if (candidateType.isa<VectorOrTensorType>()) {
- VectorOrTensorType candidateVtType =
- candidateType.cast<VectorOrTensorType>();
- if (candidateVtType.getElementType() != getExpressedType()) {
+ } else if (candidateType.isa<ShapedType>()) {
+ ShapedType candidateShapedType = candidateType.cast<ShapedType>();
+ if (candidateShapedType.getElementType() != getExpressedType()) {
return nullptr;
}
if (candidateType.isa<RankedTensorType>()) {
// i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
- return RankedTensorType::get(candidateVtType.getShape(), *this);
+ return RankedTensorType::get(candidateShapedType.getShape(), *this);
} else if (candidateType.isa<UnrankedTensorType>()) {
// i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
return UnrankedTensorType::get(*this);
} else if (candidateType.isa<VectorType>()) {
// i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
- return VectorType::get(candidateVtType.getShape(), *this);
+ return VectorType::get(candidateShapedType.getShape(), *this);
}
}
@@ -189,20 +188,20 @@ Type QuantizedType::castToExpressedType(Type quantizedType) {
if (quantizedType.isa<QuantizedType>()) {
// i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
return quantizedType.cast<QuantizedType>().getExpressedType();
- } else if (quantizedType.isa<VectorOrTensorType>()) {
+ } else if (quantizedType.isa<ShapedType>()) {
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
- VectorOrTensorType vtType = quantizedType.cast<VectorOrTensorType>();
- if (!vtType.getElementType().isa<QuantizedType>()) {
+ ShapedType sType = quantizedType.cast<ShapedType>();
+ if (!sType.getElementType().isa<QuantizedType>()) {
return nullptr;
}
Type expressedType =
- vtType.getElementType().cast<QuantizedType>().getExpressedType();
+ sType.getElementType().cast<QuantizedType>().getExpressedType();
if (quantizedType.isa<RankedTensorType>()) {
- return RankedTensorType::get(vtType.getShape(), expressedType);
+ return RankedTensorType::get(sType.getShape(), expressedType);
} else if (quantizedType.isa<UnrankedTensorType>()) {
return UnrankedTensorType::get(expressedType);
} else if (quantizedType.isa<VectorType>()) {
- return VectorType::get(vtType.getShape(), expressedType);
+ return VectorType::get(sType.getShape(), expressedType);
}
}
diff --git a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
index 3685a65f2d8..c50b3075b69 100644
--- a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
+++ b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
@@ -56,10 +56,10 @@ convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
// Cast from an expressed-type-based type to storage-type-based type,
// preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
- VectorOrTensorType newDenseType =
+ ShapedType newDenseType =
quantizedElementType
.castExpressedToStorageType(realFPElementsAttr.getType())
- .dyn_cast_or_null<VectorOrTensorType>();
+ .dyn_cast_or_null<ShapedType>();
if (!newDenseType) {
return nullptr;
}
@@ -87,9 +87,9 @@ convertSplatElementsAttr(SplatElementsAttr realSplatAttr,
// Cast from an expressed-type-based type to storage-type-based type,
// preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>).
- VectorOrTensorType newSplatType =
+ ShapedType newSplatType =
quantizedElementType.castExpressedToStorageType(realSplatAttr.getType())
- .dyn_cast_or_null<VectorOrTensorType>();
+ .dyn_cast_or_null<ShapedType>();
if (!newSplatType) {
return nullptr;
}
@@ -116,9 +116,9 @@ convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
// Cast from an expressed-type-based type to storage-type-based type,
// preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
- VectorOrTensorType newSparseType =
+ ShapedType newSparseType =
quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
- .dyn_cast_or_null<VectorOrTensorType>();
+ .dyn_cast_or_null<ShapedType>();
if (!newSparseType) {
return nullptr;
}
diff --git a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
index d791075f5db..db8a5848981 100644
--- a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
@@ -38,13 +38,13 @@ ExpressedToUniformQuantizedConverter::forInputType(Type inputType) {
case StandardTypes::RankedTensor:
case StandardTypes::UnrankedTensor:
case StandardTypes::Vector: {
- Type elementType = inputType.cast<VectorOrTensorType>().getElementType();
+ Type elementType = inputType.cast<ShapedType>().getElementType();
if (!isQuantizablePrimitiveType(elementType)) {
// Unsupported.
return ExpressedToUniformQuantizedConverter{inputType, nullptr};
}
return ExpressedToUniformQuantizedConverter{
- inputType, inputType.cast<VectorOrTensorType>().getElementType()};
+ inputType, inputType.cast<ShapedType>().getElementType()};
}
}
}
OpenPOWER on IntegriCloud