diff options
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 87f96f80040..8e832ffb6a3 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -291,6 +291,61 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) { return BaseT::getCastInstrCost(Opcode, Dst, Src); } +int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst, + VectorType *VecTy, + unsigned Index) { + + // Make sure we were given a valid extend opcode. + assert(Opcode == Instruction::SExt || + Opcode == Instruction::ZExt && "Invalid opcode"); + + // We are extending an element we extract from a vector, so the source type + // of the extend is the element type of the vector. + auto *Src = VecTy->getElementType(); + + // Sign- and zero-extends are for integer types only. + assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type"); + + // Get the cost for the extract. We compute the cost (if any) for the extend + // below. + auto Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy, Index); + + // Legalize the types. + auto VecLT = TLI->getTypeLegalizationCost(DL, VecTy); + auto DstVT = TLI->getValueType(DL, Dst); + auto SrcVT = TLI->getValueType(DL, Src); + + // If the resulting type is still a vector and the destination type is legal, + // we may get the extension for free. If not, get the default cost for the + // extend. + if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT)) + return Cost + getCastInstrCost(Opcode, Dst, Src); + + // The destination type should be larger than the element type. If not, get + // the default cost for the extend. + if (DstVT.getSizeInBits() < SrcVT.getSizeInBits()) + return Cost + getCastInstrCost(Opcode, Dst, Src); + + switch (Opcode) { + default: + llvm_unreachable("Opcode should be either SExt or ZExt"); + + // For sign-extends, we only need a smov, which performs the extension + // automatically. + case Instruction::SExt: + return Cost; + + // For zero-extends, the extend is performed automatically by a umov unless + // the destination type is i64 and the element type is i8 or i16. + case Instruction::ZExt: + if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u) + return Cost; + } + + // If we are unable to perform the extend for free, get the default cost. + return Cost + getCastInstrCost(Opcode, Dst, Src); +} + int AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index) { assert(Val->isVectorTy() && "This must be a vector type"); |