diff options
| author | Geoffrey Martin-Noble <gcmn@google.com> | 2019-05-16 00:12:45 -0700 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2019-05-20 13:43:58 -0700 |
| commit | 090662c5f3572c573cf249844748ecbf11d10dbe (patch) | |
| tree | c3b040c726fb56de37f6fa80fcd40ff0ef16f2e3 /mlir/lib/Dialect/QuantOps | |
| parent | b3888fa9cc46dc3607af2dc5cf840ef33d0fcd8d (diff) | |
| download | bcm5719-llvm-090662c5f3572c573cf249844748ecbf11d10dbe.tar.gz bcm5719-llvm-090662c5f3572c573cf249844748ecbf11d10dbe.zip | |
Rename VectorOrTensorType to ShapedType
This is in preparation for making it also support/be a parent class of MemRefType. MemRefs have similar shape/rank/element semantics and it would be useful to be able to use these same utilities for them.
This CL should not change any semantics and only change variables, types, string literals, and comments. In follow-up CLs I will prepare all callers to handle MemRef types or remove their dependence on ShapedType.
Discussion/Rationale in https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/cHLoyfGu8y8
--
PiperOrigin-RevId: 248476449
Diffstat (limited to 'mlir/lib/Dialect/QuantOps')
| -rw-r--r-- | mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp | 43 | ||||
| -rw-r--r-- | mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp | 12 | ||||
| -rw-r--r-- | mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp | 4 |
3 files changed, 29 insertions, 30 deletions
diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp index 6ca3b92d064..1b63b8f4f55 100644 --- a/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp @@ -98,8 +98,8 @@ Type QuantizedType::getExpressedType() const { } bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) { - if (candidateExpressedType.isa<VectorOrTensorType>()) { - return candidateExpressedType.cast<VectorOrTensorType>().getElementType() == + if (candidateExpressedType.isa<ShapedType>()) { + return candidateExpressedType.cast<ShapedType>().getElementType() == getExpressedType(); } return candidateExpressedType == getExpressedType(); @@ -107,9 +107,9 @@ bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) { QuantizedType QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) { - if (primitiveOrContainerType.isa<VectorOrTensorType>()) { + if (primitiveOrContainerType.isa<ShapedType>()) { Type elementType = - primitiveOrContainerType.cast<VectorOrTensorType>().getElementType(); + primitiveOrContainerType.cast<ShapedType>().getElementType(); return elementType.dyn_cast<QuantizedType>(); } return primitiveOrContainerType.dyn_cast<QuantizedType>(); @@ -139,20 +139,20 @@ Type QuantizedType::castToStorageType(Type quantizedType) { if (quantizedType.isa<QuantizedType>()) { // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8 return quantizedType.cast<QuantizedType>().getStorageType(); - } else if (quantizedType.isa<VectorOrTensorType>()) { + } else if (quantizedType.isa<ShapedType>()) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - VectorOrTensorType vtType = quantizedType.cast<VectorOrTensorType>(); - if (!vtType.getElementType().isa<QuantizedType>()) { + ShapedType sType = quantizedType.cast<ShapedType>(); + if (!sType.getElementType().isa<QuantizedType>()) { return nullptr; } Type storageType = - vtType.getElementType().cast<QuantizedType>().getStorageType(); + sType.getElementType().cast<QuantizedType>().getStorageType(); if (quantizedType.isa<RankedTensorType>()) { - return RankedTensorType::get(vtType.getShape(), storageType); + return RankedTensorType::get(sType.getShape(), storageType); } else if (quantizedType.isa<UnrankedTensorType>()) { return UnrankedTensorType::get(storageType); } else if (quantizedType.isa<VectorType>()) { - return VectorType::get(vtType.getShape(), storageType); + return VectorType::get(sType.getShape(), storageType); } } @@ -163,22 +163,21 @@ Type QuantizedType::castFromExpressedType(Type candidateType) { if (candidateType == getExpressedType()) { // i.e. f32 -> quant<"uniform[i8:f32]{1.0}"> return *this; - } else if (candidateType.isa<VectorOrTensorType>()) { - VectorOrTensorType candidateVtType = - candidateType.cast<VectorOrTensorType>(); - if (candidateVtType.getElementType() != getExpressedType()) { + } else if (candidateType.isa<ShapedType>()) { + ShapedType candidateShapedType = candidateType.cast<ShapedType>(); + if (candidateShapedType.getElementType() != getExpressedType()) { return nullptr; } if (candidateType.isa<RankedTensorType>()) { // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - return RankedTensorType::get(candidateVtType.getShape(), *this); + return RankedTensorType::get(candidateShapedType.getShape(), *this); } else if (candidateType.isa<UnrankedTensorType>()) { // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">> return UnrankedTensorType::get(*this); } else if (candidateType.isa<VectorType>()) { // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - return VectorType::get(candidateVtType.getShape(), *this); + return VectorType::get(candidateShapedType.getShape(), *this); } } @@ -189,20 +188,20 @@ Type QuantizedType::castToExpressedType(Type quantizedType) { if (quantizedType.isa<QuantizedType>()) { // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32 return quantizedType.cast<QuantizedType>().getExpressedType(); - } else if (quantizedType.isa<VectorOrTensorType>()) { + } else if (quantizedType.isa<ShapedType>()) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - VectorOrTensorType vtType = quantizedType.cast<VectorOrTensorType>(); - if (!vtType.getElementType().isa<QuantizedType>()) { + ShapedType sType = quantizedType.cast<ShapedType>(); + if (!sType.getElementType().isa<QuantizedType>()) { return nullptr; } Type expressedType = - vtType.getElementType().cast<QuantizedType>().getExpressedType(); + sType.getElementType().cast<QuantizedType>().getExpressedType(); if (quantizedType.isa<RankedTensorType>()) { - return RankedTensorType::get(vtType.getShape(), expressedType); + return RankedTensorType::get(sType.getShape(), expressedType); } else if (quantizedType.isa<UnrankedTensorType>()) { return UnrankedTensorType::get(expressedType); } else if (quantizedType.isa<VectorType>()) { - return VectorType::get(vtType.getShape(), expressedType); + return VectorType::get(sType.getShape(), expressedType); } } diff --git a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp index 3685a65f2d8..c50b3075b69 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp @@ -56,10 +56,10 @@ convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr, // Cast from an expressed-type-based type to storage-type-based type, // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>). - VectorOrTensorType newDenseType = + ShapedType newDenseType = quantizedElementType .castExpressedToStorageType(realFPElementsAttr.getType()) - .dyn_cast_or_null<VectorOrTensorType>(); + .dyn_cast_or_null<ShapedType>(); if (!newDenseType) { return nullptr; } @@ -87,9 +87,9 @@ convertSplatElementsAttr(SplatElementsAttr realSplatAttr, // Cast from an expressed-type-based type to storage-type-based type, // preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>). - VectorOrTensorType newSplatType = + ShapedType newSplatType = quantizedElementType.castExpressedToStorageType(realSplatAttr.getType()) - .dyn_cast_or_null<VectorOrTensorType>(); + .dyn_cast_or_null<ShapedType>(); if (!newSplatType) { return nullptr; } @@ -116,9 +116,9 @@ convertSparseElementsAttr(SparseElementsAttr realSparseAttr, // Cast from an expressed-type-based type to storage-type-based type, // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>). - VectorOrTensorType newSparseType = + ShapedType newSparseType = quantizedElementType.castExpressedToStorageType(realSparseAttr.getType()) - .dyn_cast_or_null<VectorOrTensorType>(); + .dyn_cast_or_null<ShapedType>(); if (!newSparseType) { return nullptr; } diff --git a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp index d791075f5db..db8a5848981 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp @@ -38,13 +38,13 @@ ExpressedToUniformQuantizedConverter::forInputType(Type inputType) { case StandardTypes::RankedTensor: case StandardTypes::UnrankedTensor: case StandardTypes::Vector: { - Type elementType = inputType.cast<VectorOrTensorType>().getElementType(); + Type elementType = inputType.cast<ShapedType>().getElementType(); if (!isQuantizablePrimitiveType(elementType)) { // Unsupported. return ExpressedToUniformQuantizedConverter{inputType, nullptr}; } return ExpressedToUniformQuantizedConverter{ - inputType, inputType.cast<VectorOrTensorType>().getElementType()}; + inputType, inputType.cast<ShapedType>().getElementType()}; } } } |

