diff options
| author | Geoffrey Martin-Noble <gcmn@google.com> | 2019-05-16 00:12:45 -0700 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2019-05-20 13:43:58 -0700 |
| commit | 090662c5f3572c573cf249844748ecbf11d10dbe (patch) | |
| tree | c3b040c726fb56de37f6fa80fcd40ff0ef16f2e3 /mlir/lib/Dialect/FxpMathOps/Transforms | |
| parent | b3888fa9cc46dc3607af2dc5cf840ef33d0fcd8d (diff) | |
| download | bcm5719-llvm-090662c5f3572c573cf249844748ecbf11d10dbe.tar.gz bcm5719-llvm-090662c5f3572c573cf249844748ecbf11d10dbe.zip | |
Rename VectorOrTensorType to ShapedType
This is in preparation for making it also support/be a parent class of MemRefType. MemRefs have similar shape/rank/element semantics and it would be useful to be able to use these same utilities for them.
This CL should not change any semantics and only change variables, types, string literals, and comments. In follow-up CLs I will prepare all callers to handle MemRef types or remove their dependence on ShapedType.
Discussion/Rationale in https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/cHLoyfGu8y8
--
PiperOrigin-RevId: 248476449
Diffstat (limited to 'mlir/lib/Dialect/FxpMathOps/Transforms')
| -rw-r--r-- | mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h | 24 |
1 files changed, 12 insertions, 12 deletions
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
index 325b0ca93ca..2fdac8ef26a 100644
--- a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
@@ -168,12 +168,12 @@ struct QuantizedMultiplierSmallerThanOneExp {
 /// Casts an integer or floating point based type to a new element type.
 inline Type castElementType(Type t, Type newElementType) {
-  if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
-    switch (vt.getKind()) {
+  if (auto st = t.dyn_cast<ShapedType>()) {
+    switch (st.getKind()) {
     case StandardTypes::Kind::Vector:
-      return VectorType::get(vt.getShape(), newElementType);
+      return VectorType::get(st.getShape(), newElementType);
     case StandardTypes::Kind::RankedTensor:
-      return RankedTensorType::get(vt.getShape(), newElementType);
+      return RankedTensorType::get(st.getShape(), newElementType);
     case StandardTypes::Kind::UnrankedTensor:
       return UnrankedTensorType::get(newElementType);
     }
@@ -185,10 +185,10 @@ inline Type castElementType(Type t, Type newElementType) {
 /// Creates an IntegerAttr with a type that matches the shape of 't' (which can
 /// be a primitive/vector/tensor).
 inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
-  if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
-    assert(vt.getElementType().isa<IntegerType>());
-    return SplatElementsAttr::get(vt,
-                                  IntegerAttr::get(vt.getElementType(), value));
+  if (auto st = t.dyn_cast<ShapedType>()) {
+    assert(st.getElementType().isa<IntegerType>());
+    return SplatElementsAttr::get(st,
+                                  IntegerAttr::get(st.getElementType(), value));
   }

   auto integerType = t.cast<IntegerType>();
@@ -211,13 +211,13 @@ inline APFloat convertFloatToType(FloatType ft, APFloat value) {
 /// Creates an IntegerAttr with a type that matches the shape of 't' (which can
 /// be a primitive/vector/tensor).
 inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
-  if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
-    FloatType floatElementType = vt.getElementType().dyn_cast<FloatType>();
+  if (auto st = t.dyn_cast<ShapedType>()) {
+    FloatType floatElementType = st.getElementType().dyn_cast<FloatType>();
     assert(floatElementType &&
            "float broadcast element type must be float like");
     APFloat apValue = convertFloatToType(floatElementType, value);
-    return SplatElementsAttr::get(vt,
-                                  FloatAttr::get(vt.getElementType(), apValue));
+    return SplatElementsAttr::get(st,
+                                  FloatAttr::get(st.getElementType(), apValue));
   } else {
     auto floatType = t.dyn_cast<FloatType>();
     assert(floatType && "float broadcast must be of float type");

