diff options
Diffstat (limited to 'clang/lib/CodeGen/TargetInfo.cpp')
| -rw-r--r-- | clang/lib/CodeGen/TargetInfo.cpp | 36 |
1 files changed, 33 insertions, 3 deletions
diff --git a/clang/lib/CodeGen/TargetInfo.cpp b/clang/lib/CodeGen/TargetInfo.cpp index 4ef57786f44..707425ea813 100644 --- a/clang/lib/CodeGen/TargetInfo.cpp +++ b/clang/lib/CodeGen/TargetInfo.cpp @@ -5607,6 +5607,7 @@ private: uint64_t Members) const; ABIArgInfo coerceIllegalVector(QualType Ty) const; bool isIllegalVectorType(QualType Ty) const; + bool containsAnyFP16Vectors(QualType Ty) const; bool isHomogeneousAggregateBaseType(QualType Ty) const override; bool isHomogeneousAggregateSmallEnough(const Type *Ty, @@ -5806,9 +5807,7 @@ ABIArgInfo ARMABIInfo::classifyHomogeneousAggregate(QualType Ty, // Base can be a floating-point or a vector. if (const VectorType *VT = Base->getAs<VectorType>()) { // FP16 vectors should be converted to integer vectors - if (!getTarget().hasLegalHalfType() && - (VT->getElementType()->isFloat16Type() || - VT->getElementType()->isHalfType())) { + if (!getTarget().hasLegalHalfType() && containsAnyFP16Vectors(Ty)) { uint64_t Size = getContext().getTypeSize(VT); llvm::Type *NewVecTy = llvm::VectorType::get( llvm::Type::getInt32Ty(getVMContext()), Size / 32); @@ -6169,6 +6168,37 @@ bool ARMABIInfo::isIllegalVectorType(QualType Ty) const { return false; } +/// Return true if a type contains any 16-bit floating point vectors +bool ARMABIInfo::containsAnyFP16Vectors(QualType Ty) const { + if (const ConstantArrayType *AT = getContext().getAsConstantArrayType(Ty)) { + uint64_t NElements = AT->getSize().getZExtValue(); + if (NElements == 0) + return false; + return containsAnyFP16Vectors(AT->getElementType()); + } else if (const RecordType *RT = Ty->getAs<RecordType>()) { + const RecordDecl *RD = RT->getDecl(); + + // If this is a C++ record, check the bases first. + if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) + if (llvm::any_of(CXXRD->bases(), [this](const CXXBaseSpecifier &B) { + return containsAnyFP16Vectors(B.getType()); + })) + return true; + + if (llvm::any_of(RD->fields(), [this](FieldDecl *FD) { + return FD && containsAnyFP16Vectors(FD->getType()); + })) + return true; + + return false; + } else { + if (const VectorType *VT = Ty->getAs<VectorType>()) + return (VT->getElementType()->isFloat16Type() || + VT->getElementType()->isHalfType()); + return false; + } +} + bool ARMABIInfo::isLegalVectorTypeForSwift(CharUnits vectorSize, llvm::Type *eltTy, unsigned numElts) const { |

