summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/FxpMathOps/Transforms
diff options
context:
space:
mode:
authorGeoffrey Martin-Noble <gcmn@google.com>2019-05-16 00:12:45 -0700
committerMehdi Amini <joker.eph@gmail.com>2019-05-20 13:43:58 -0700
commit090662c5f3572c573cf249844748ecbf11d10dbe (patch)
treec3b040c726fb56de37f6fa80fcd40ff0ef16f2e3 /mlir/lib/Dialect/FxpMathOps/Transforms
parentb3888fa9cc46dc3607af2dc5cf840ef33d0fcd8d (diff)
downloadbcm5719-llvm-090662c5f3572c573cf249844748ecbf11d10dbe.tar.gz
bcm5719-llvm-090662c5f3572c573cf249844748ecbf11d10dbe.zip
Rename VectorOrTensorType to ShapedType
This is in preparation for making it also support/be a parent class of MemRefType. MemRefs have similar shape/rank/element semantics and it would be useful to be able to use these same utilities for them. This CL should not change any semantics and only change variables, types, string literals, and comments. In follow-up CLs I will prepare all callers to handle MemRef types or remove their dependence on ShapedType. Discussion/Rationale in https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/cHLoyfGu8y8 -- PiperOrigin-RevId: 248476449
Diffstat (limited to 'mlir/lib/Dialect/FxpMathOps/Transforms')
-rw-r--r--mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h24
1 files changed, 12 insertions, 12 deletions
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
index 325b0ca93ca..2fdac8ef26a 100644
--- a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
@@ -168,12 +168,12 @@ struct QuantizedMultiplierSmallerThanOneExp {
/// Casts an integer or floating point based type to a new element type.
inline Type castElementType(Type t, Type newElementType) {
- if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
- switch (vt.getKind()) {
+ if (auto st = t.dyn_cast<ShapedType>()) {
+ switch (st.getKind()) {
case StandardTypes::Kind::Vector:
- return VectorType::get(vt.getShape(), newElementType);
+ return VectorType::get(st.getShape(), newElementType);
case StandardTypes::Kind::RankedTensor:
- return RankedTensorType::get(vt.getShape(), newElementType);
+ return RankedTensorType::get(st.getShape(), newElementType);
case StandardTypes::Kind::UnrankedTensor:
return UnrankedTensorType::get(newElementType);
}
@@ -185,10 +185,10 @@ inline Type castElementType(Type t, Type newElementType) {
/// Creates an IntegerAttr with a type that matches the shape of 't' (which can
/// be a primitive/vector/tensor).
inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
- if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
- assert(vt.getElementType().isa<IntegerType>());
- return SplatElementsAttr::get(vt,
- IntegerAttr::get(vt.getElementType(), value));
+ if (auto st = t.dyn_cast<ShapedType>()) {
+ assert(st.getElementType().isa<IntegerType>());
+ return SplatElementsAttr::get(st,
+ IntegerAttr::get(st.getElementType(), value));
}
auto integerType = t.cast<IntegerType>();
@@ -211,13 +211,13 @@ inline APFloat convertFloatToType(FloatType ft, APFloat value) {
/// Creates an IntegerAttr with a type that matches the shape of 't' (which can
/// be a primitive/vector/tensor).
inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
- if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
- FloatType floatElementType = vt.getElementType().dyn_cast<FloatType>();
+ if (auto st = t.dyn_cast<ShapedType>()) {
+ FloatType floatElementType = st.getElementType().dyn_cast<FloatType>();
assert(floatElementType &&
"float broadcast element type must be float like");
APFloat apValue = convertFloatToType(floatElementType, value);
- return SplatElementsAttr::get(vt,
- FloatAttr::get(vt.getElementType(), apValue));
+ return SplatElementsAttr::get(st,
+ FloatAttr::get(st.getElementType(), apValue));
} else {
auto floatType = t.dyn_cast<FloatType>();
assert(floatType && "float broadcast must be of float type");
OpenPOWER on IntegriCloud