summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--llvm/include/llvm/CodeGen/MachineValueType.h134
-rw-r--r--llvm/include/llvm/CodeGen/ValueTypes.h59
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp4
-rw-r--r--llvm/unittests/CodeGen/CMakeLists.txt1
-rw-r--r--llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp88
5 files changed, 274 insertions, 12 deletions
diff --git a/llvm/include/llvm/CodeGen/MachineValueType.h b/llvm/include/llvm/CodeGen/MachineValueType.h
index 743fe1c1015..d5de22255a6 100644
--- a/llvm/include/llvm/CodeGen/MachineValueType.h
+++ b/llvm/include/llvm/CodeGen/MachineValueType.h
@@ -232,6 +232,42 @@ class MVT {
SimpleValueType SimpleTy;
+
+ // A class to represent the number of elements in a vector
+ //
+ // For fixed-length vectors, the total number of elements is equal to 'Min'
+ // For scalable vectors, the total number of elements is a multiple of 'Min'
+ class ElementCount {
+ public:
+ unsigned Min;
+ bool Scalable;
+
+ ElementCount(unsigned Min, bool Scalable)
+ : Min(Min), Scalable(Scalable) {}
+
+ ElementCount operator*(unsigned RHS) {
+ return { Min * RHS, Scalable };
+ }
+
+ ElementCount& operator*=(unsigned RHS) {
+ Min *= RHS;
+ return *this;
+ }
+
+ ElementCount operator/(unsigned RHS) {
+ return { Min / RHS, Scalable };
+ }
+
+ ElementCount& operator/=(unsigned RHS) {
+ Min /= RHS;
+ return *this;
+ }
+
+ bool operator==(const ElementCount& RHS) {
+ return Min == RHS.Min && Scalable == RHS.Scalable;
+ }
+ };
+
constexpr MVT() : SimpleTy(INVALID_SIMPLE_VALUE_TYPE) {}
constexpr MVT(SimpleValueType SVT) : SimpleTy(SVT) {}
@@ -276,6 +312,15 @@ class MVT {
SimpleTy <= MVT::LAST_VECTOR_VALUETYPE);
}
+ /// Return true if this is a vector value type where the
+ /// runtime length is machine dependent
+ bool isScalableVector() const {
+ return ((SimpleTy >= MVT::FIRST_INTEGER_SCALABLE_VALUETYPE &&
+ SimpleTy <= MVT::LAST_INTEGER_SCALABLE_VALUETYPE) ||
+ (SimpleTy >= MVT::FIRST_FP_SCALABLE_VALUETYPE &&
+ SimpleTy <= MVT::LAST_FP_SCALABLE_VALUETYPE));
+ }
+
/// Return true if this is a 16-bit vector type.
bool is16BitVector() const {
return (SimpleTy == MVT::v2i8 || SimpleTy == MVT::v1i16 ||
@@ -560,6 +605,10 @@ class MVT {
}
}
+ MVT::ElementCount getVectorElementCount() const {
+ return { getVectorNumElements(), isScalableVector() };
+ }
+
unsigned getSizeInBits() const {
switch (SimpleTy) {
default:
@@ -837,6 +886,83 @@ class MVT {
return (MVT::SimpleValueType)(MVT::INVALID_SIMPLE_VALUE_TYPE);
}
+ static MVT getScalableVectorVT(MVT VT, unsigned NumElements) {
+ switch(VT.SimpleTy) {
+ default:
+ break;
+ case MVT::i1:
+ if (NumElements == 2) return MVT::nxv2i1;
+ if (NumElements == 4) return MVT::nxv4i1;
+ if (NumElements == 8) return MVT::nxv8i1;
+ if (NumElements == 16) return MVT::nxv16i1;
+ if (NumElements == 32) return MVT::nxv32i1;
+ break;
+ case MVT::i8:
+ if (NumElements == 1) return MVT::nxv1i8;
+ if (NumElements == 2) return MVT::nxv2i8;
+ if (NumElements == 4) return MVT::nxv4i8;
+ if (NumElements == 8) return MVT::nxv8i8;
+ if (NumElements == 16) return MVT::nxv16i8;
+ if (NumElements == 32) return MVT::nxv32i8;
+ break;
+ case MVT::i16:
+ if (NumElements == 1) return MVT::nxv1i16;
+ if (NumElements == 2) return MVT::nxv2i16;
+ if (NumElements == 4) return MVT::nxv4i16;
+ if (NumElements == 8) return MVT::nxv8i16;
+ if (NumElements == 16) return MVT::nxv16i16;
+ if (NumElements == 32) return MVT::nxv32i16;
+ break;
+ case MVT::i32:
+ if (NumElements == 1) return MVT::nxv1i32;
+ if (NumElements == 2) return MVT::nxv2i32;
+ if (NumElements == 4) return MVT::nxv4i32;
+ if (NumElements == 8) return MVT::nxv8i32;
+ if (NumElements == 16) return MVT::nxv16i32;
+ if (NumElements == 32) return MVT::nxv32i32;
+ break;
+ case MVT::i64:
+ if (NumElements == 1) return MVT::nxv1i64;
+ if (NumElements == 2) return MVT::nxv2i64;
+ if (NumElements == 4) return MVT::nxv4i64;
+ if (NumElements == 8) return MVT::nxv8i64;
+ if (NumElements == 16) return MVT::nxv16i64;
+ if (NumElements == 32) return MVT::nxv32i64;
+ break;
+ case MVT::f16:
+ if (NumElements == 2) return MVT::nxv2f16;
+ if (NumElements == 4) return MVT::nxv4f16;
+ if (NumElements == 8) return MVT::nxv8f16;
+ break;
+ case MVT::f32:
+ if (NumElements == 1) return MVT::nxv1f32;
+ if (NumElements == 2) return MVT::nxv2f32;
+ if (NumElements == 4) return MVT::nxv4f32;
+ if (NumElements == 8) return MVT::nxv8f32;
+ if (NumElements == 16) return MVT::nxv16f32;
+ break;
+ case MVT::f64:
+ if (NumElements == 1) return MVT::nxv1f64;
+ if (NumElements == 2) return MVT::nxv2f64;
+ if (NumElements == 4) return MVT::nxv4f64;
+ if (NumElements == 8) return MVT::nxv8f64;
+ break;
+ }
+ return (MVT::SimpleValueType)(MVT::INVALID_SIMPLE_VALUE_TYPE);
+ }
+
+ static MVT getVectorVT(MVT VT, unsigned NumElements, bool IsScalable) {
+ if (IsScalable)
+ return getScalableVectorVT(VT, NumElements);
+ return getVectorVT(VT, NumElements);
+ }
+
+ static MVT getVectorVT(MVT VT, MVT::ElementCount EC) {
+ if (EC.Scalable)
+ return getScalableVectorVT(VT, EC.Min);
+ return getVectorVT(VT, EC.Min);
+ }
+
/// Return the value type corresponding to the specified type. This returns
/// all pointers as iPTR. If HandleUnknown is true, unknown types are
/// returned as Other, otherwise they are invalid.
@@ -887,6 +1013,14 @@ class MVT {
MVT::FIRST_FP_VECTOR_VALUETYPE,
(MVT::SimpleValueType)(MVT::LAST_FP_VECTOR_VALUETYPE + 1));
}
+ static mvt_range integer_scalable_vector_valuetypes() {
+ return mvt_range(MVT::FIRST_INTEGER_SCALABLE_VALUETYPE,
+ (MVT::SimpleValueType)(MVT::LAST_INTEGER_SCALABLE_VALUETYPE + 1));
+ }
+ static mvt_range fp_scalable_vector_valuetypes() {
+ return mvt_range(MVT::FIRST_FP_SCALABLE_VALUETYPE,
+ (MVT::SimpleValueType)(MVT::LAST_FP_SCALABLE_VALUETYPE + 1));
+ }
/// @}
};
diff --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h
index b204f731288..b404b4ca701 100644
--- a/llvm/include/llvm/CodeGen/ValueTypes.h
+++ b/llvm/include/llvm/CodeGen/ValueTypes.h
@@ -67,24 +67,41 @@ namespace llvm {
/// Returns the EVT that represents a vector NumElements in length, where
/// each element is of type VT.
- static EVT getVectorVT(LLVMContext &Context, EVT VT, unsigned NumElements) {
- MVT M = MVT::getVectorVT(VT.V, NumElements);
+ static EVT getVectorVT(LLVMContext &Context, EVT VT, unsigned NumElements,
+ bool IsScalable = false) {
+ MVT M = MVT::getVectorVT(VT.V, NumElements, IsScalable);
if (M.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE)
return M;
+
+ assert(!IsScalable && "We don't support extended scalable types yet");
return getExtendedVectorVT(Context, VT, NumElements);
}
+ /// Returns the EVT that represents a vector EC.Min elements in length,
+ /// where each element is of type VT.
+ static EVT getVectorVT(LLVMContext &Context, EVT VT, MVT::ElementCount EC) {
+ MVT M = MVT::getVectorVT(VT.V, EC);
+ if (M.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE)
+ return M;
+ assert (!EC.Scalable && "We don't support extended scalable types yet");
+ return getExtendedVectorVT(Context, VT, EC.Min);
+ }
+
/// Return a vector with the same number of elements as this vector, but
/// with the element type converted to an integer type with the same
/// bitwidth.
EVT changeVectorElementTypeToInteger() const {
- if (!isSimple())
+ if (!isSimple()) {
+ assert (!isScalableVector() &&
+ "We don't support extended scalable types yet");
return changeExtendedVectorElementTypeToInteger();
+ }
MVT EltTy = getSimpleVT().getVectorElementType();
unsigned BitWidth = EltTy.getSizeInBits();
MVT IntTy = MVT::getIntegerVT(BitWidth);
- MVT VecTy = MVT::getVectorVT(IntTy, getVectorNumElements());
- assert(VecTy.SimpleTy >= 0 &&
+ MVT VecTy = MVT::getVectorVT(IntTy, getVectorNumElements(),
+ isScalableVector());
+ assert(VecTy.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE &&
"Simple vector VT not representable by simple integer vector VT!");
return VecTy;
}
@@ -132,6 +149,17 @@ namespace llvm {
return isSimple() ? V.isVector() : isExtendedVector();
}
+ /// Return true if this is a vector type where the runtime
+ /// length is machine dependent
+ bool isScalableVector() const {
+ // FIXME: We don't support extended scalable types yet, because the
+ // matching IR type doesn't exist. Once it has been added, this can
+ // be changed to call isExtendedScalableVector.
+ if (!isSimple())
+ return false;
+ return V.isScalableVector();
+ }
+
/// Return true if this is a 16-bit vector type.
bool is16BitVector() const {
return isSimple() ? V.is16BitVector() : isExtended16BitVector();
@@ -247,6 +275,17 @@ namespace llvm {
return getExtendedVectorNumElements();
}
+ // Given a (possibly scalable) vector type, return the ElementCount
+ MVT::ElementCount getVectorElementCount() const {
+ assert((isVector()) && "Invalid vector type!");
+ if (isSimple())
+ return V.getVectorElementCount();
+
+ assert(!isScalableVector() &&
+ "We don't support extended scalable types yet");
+ return {getExtendedVectorNumElements(), false};
+ }
+
/// Return the size of the specified value type in bits.
unsigned getSizeInBits() const {
if (isSimple())
@@ -301,7 +340,7 @@ namespace llvm {
EVT widenIntegerVectorElementType(LLVMContext &Context) const {
EVT EltVT = getVectorElementType();
EltVT = EVT::getIntegerVT(Context, 2 * EltVT.getSizeInBits());
- return EVT::getVectorVT(Context, EltVT, getVectorNumElements());
+ return EVT::getVectorVT(Context, EltVT, getVectorElementCount());
}
// Return a VT for a vector type with the same element type but
@@ -309,9 +348,8 @@ namespace llvm {
// extended type.
EVT getHalfNumVectorElementsVT(LLVMContext &Context) const {
EVT EltVT = getVectorElementType();
- auto EltCnt = getVectorNumElements();
- assert(!(getVectorNumElements() & 1) &&
- "Splitting vector, but not in half!");
+ auto EltCnt = getVectorElementCount();
+ assert(!(EltCnt.Min & 1) && "Splitting vector, but not in half!");
return EVT::getVectorVT(Context, EltVT, EltCnt / 2);
}
@@ -327,7 +365,8 @@ namespace llvm {
if (!isPow2VectorType()) {
unsigned NElts = getVectorNumElements();
unsigned Pow2NElts = 1 << Log2_32_Ceil(NElts);
- return EVT::getVectorVT(Context, getVectorElementType(), Pow2NElts);
+ return EVT::getVectorVT(Context, getVectorElementType(), Pow2NElts,
+ isScalableVector());
}
else {
return *this;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index 0a2b680e1c6..154af46c944 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -925,9 +925,9 @@ SDValue DAGTypeLegalizer::BitConvertVectorToIntegerVector(SDValue Op) {
assert(Op.getValueType().isVector() && "Only applies to vectors!");
unsigned EltWidth = Op.getScalarValueSizeInBits();
EVT EltNVT = EVT::getIntegerVT(*DAG.getContext(), EltWidth);
- unsigned NumElts = Op.getValueType().getVectorNumElements();
+ auto EltCnt = Op.getValueType().getVectorElementCount();
return DAG.getNode(ISD::BITCAST, SDLoc(Op),
- EVT::getVectorVT(*DAG.getContext(), EltNVT, NumElts), Op);
+ EVT::getVectorVT(*DAG.getContext(), EltNVT, EltCnt), Op);
}
SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,
diff --git a/llvm/unittests/CodeGen/CMakeLists.txt b/llvm/unittests/CodeGen/CMakeLists.txt
index 240734dc6b1..e944f6c9e3b 100644
--- a/llvm/unittests/CodeGen/CMakeLists.txt
+++ b/llvm/unittests/CodeGen/CMakeLists.txt
@@ -9,6 +9,7 @@ set(CodeGenSources
DIEHashTest.cpp
LowLevelTypeTest.cpp
MachineInstrBundleIteratorTest.cpp
+ ScalableVectorMVTsTest.cpp
)
add_llvm_unittest(CodeGenTests
diff --git a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
new file mode 100644
index 00000000000..a22c87200ba
--- /dev/null
+++ b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
@@ -0,0 +1,88 @@
+//===-------- llvm/unittest/CodeGen/ScalableVectorMVTsTest.cpp ------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/MachineValueType.h"
+#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/IR/LLVMContext.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+TEST(ScalableVectorMVTsTest, IntegerMVTs) {
+ for (auto VecTy : MVT::integer_scalable_vector_valuetypes()) {
+ ASSERT_TRUE(VecTy.isValid());
+ ASSERT_TRUE(VecTy.isInteger());
+ ASSERT_TRUE(VecTy.isVector());
+ ASSERT_TRUE(VecTy.isScalableVector());
+ ASSERT_TRUE(VecTy.getScalarType().isValid());
+
+ ASSERT_FALSE(VecTy.isFloatingPoint());
+ }
+}
+
+TEST(ScalableVectorMVTsTest, FloatMVTs) {
+ for (auto VecTy : MVT::fp_scalable_vector_valuetypes()) {
+ ASSERT_TRUE(VecTy.isValid());
+ ASSERT_TRUE(VecTy.isFloatingPoint());
+ ASSERT_TRUE(VecTy.isVector());
+ ASSERT_TRUE(VecTy.isScalableVector());
+ ASSERT_TRUE(VecTy.getScalarType().isValid());
+
+ ASSERT_FALSE(VecTy.isInteger());
+ }
+}
+
+TEST(ScalableVectorMVTsTest, HelperFuncs) {
+ LLVMContext Ctx;
+
+ // Create with scalable flag
+ EVT Vnx4i32 = EVT::getVectorVT(Ctx, MVT::i32, 4, /*Scalable=*/true);
+ ASSERT_TRUE(Vnx4i32.isScalableVector());
+
+ // Create with separate MVT::ElementCount
+ auto EltCnt = MVT::ElementCount(2, true);
+ EVT Vnx2i32 = EVT::getVectorVT(Ctx, MVT::i32, EltCnt);
+ ASSERT_TRUE(Vnx2i32.isScalableVector());
+
+ // Create with inline MVT::ElementCount
+ EVT Vnx2i64 = EVT::getVectorVT(Ctx, MVT::i64, {2, true});
+ ASSERT_TRUE(Vnx2i64.isScalableVector());
+
+ // Check that changing scalar types/element count works
+ EXPECT_EQ(Vnx2i32.widenIntegerVectorElementType(Ctx), Vnx2i64);
+ EXPECT_EQ(Vnx4i32.getHalfNumVectorElementsVT(Ctx), Vnx2i32);
+
+ // Check that overloaded '*' and '/' operators work
+ EXPECT_EQ(EVT::getVectorVT(Ctx, MVT::i64, EltCnt * 2), MVT::nxv4i64);
+ EXPECT_EQ(EVT::getVectorVT(Ctx, MVT::i64, EltCnt / 2), MVT::nxv1i64);
+
+ // Check that float->int conversion works
+ EVT Vnx2f64 = EVT::getVectorVT(Ctx, MVT::f64, {2, true});
+ EXPECT_EQ(Vnx2f64.changeTypeToInteger(), Vnx2i64);
+
+ // Check fields inside MVT::ElementCount
+ EltCnt = Vnx4i32.getVectorElementCount();
+ EXPECT_EQ(EltCnt.Min, 4);
+ ASSERT_TRUE(EltCnt.Scalable);
+
+ // Check that fixed-length vector types aren't scalable.
+ EVT V8i32 = EVT::getVectorVT(Ctx, MVT::i32, 8);
+ ASSERT_FALSE(V8i32.isScalableVector());
+ EVT V4f64 = EVT::getVectorVT(Ctx, MVT::f64, {4, false});
+ ASSERT_FALSE(V4f64.isScalableVector());
+
+ // Check that MVT::ElementCount works for fixed-length types.
+ EltCnt = V8i32.getVectorElementCount();
+ EXPECT_EQ(EltCnt.Min, 8);
+ ASSERT_FALSE(EltCnt.Scalable);
+}
+
+}
OpenPOWER on IntegriCloud