diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/AArch64/AArch64Subtarget.h | 3 | ||||
| -rw-r--r-- | llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 106 | ||||
| -rw-r--r-- | llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h | 3 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 10 |
4 files changed, 112 insertions, 10 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h index 5b9bee6e41b..2c66221886c 100644 --- a/llvm/lib/Target/AArch64/AArch64Subtarget.h +++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h @@ -106,6 +106,7 @@ protected: unsigned PrefFunctionAlignment = 0; unsigned PrefLoopAlignment = 0; unsigned MaxJumpTableSize = 0; + unsigned WideningBaseCost = 0; // ReserveX18 - X18 is not available as a general purpose register. bool ReserveX18; @@ -228,6 +229,8 @@ public: unsigned getMaximumJumpTableSize() const { return MaxJumpTableSize; } + unsigned getWideningBaseCost() const { return WideningBaseCost; } + /// CPU has TBI (top byte of addresses is ignored during HW address /// translation) and OS enables it. bool supportsAddressTopByteIgnored() const; diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 4d59da0c646..7c6f55c06bc 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -176,11 +176,95 @@ AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) { return TTI::PSK_Software; } +bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode, + ArrayRef<const Value *> Args) { + + // A helper that returns a vector type from the given type. The number of + // elements in type Ty determine the vector width. + auto toVectorTy = [&](Type *ArgTy) { + return VectorType::get(ArgTy->getScalarType(), + DstTy->getVectorNumElements()); + }; + + // Exit early if DstTy is not a vector type whose elements are at least + // 16-bits wide. + if (!DstTy->isVectorTy() || DstTy->getScalarSizeInBits() < 16) + return false; + + // Determine if the operation has a widening variant. We consider both the + // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the + // instructions. + // + // TODO: Add additional widening operations (e.g., mul, shl, etc.) once we + // verify that their extending operands are eliminated during code + // generation. + switch (Opcode) { + case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2). + case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2). + break; + default: + return false; + } + + // To be a widening instruction (either the "wide" or "long" versions), the + // second operand must be a sign- or zero extend having a single user. We + // only consider extends having a single user because they may otherwise not + // be eliminated. + if (Args.size() != 2 || + (!isa<SExtInst>(Args[1]) && !isa<ZExtInst>(Args[1])) || + !Args[1]->hasOneUse()) + return false; + auto *Extend = cast<CastInst>(Args[1]); + + // Legalize the destination type and ensure it can be used in a widening + // operation. + auto DstTyL = TLI->getTypeLegalizationCost(DL, DstTy); + unsigned DstElTySize = DstTyL.second.getScalarSizeInBits(); + if (!DstTyL.second.isVector() || DstElTySize != DstTy->getScalarSizeInBits()) + return false; + + // Legalize the source type and ensure it can be used in a widening + // operation. + Type *SrcTy = toVectorTy(Extend->getSrcTy()); + auto SrcTyL = TLI->getTypeLegalizationCost(DL, SrcTy); + unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits(); + if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits()) + return false; + + // Get the total number of vector elements in the legalized types. + unsigned NumDstEls = DstTyL.first * DstTyL.second.getVectorNumElements(); + unsigned NumSrcEls = SrcTyL.first * SrcTyL.second.getVectorNumElements(); + + // Return true if the legalized types have the same number of vector elements + // and the destination element type size is twice that of the source type. + return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize; +} + int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, const Instruction *I) { int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); + // If the cast is observable, and it is used by a widening instruction (e.g., + // uaddl, saddw, etc.), it may be free. + if (I && I->hasOneUse()) { + auto *SingleUser = cast<Instruction>(*I->user_begin()); + SmallVector<const Value *, 4> Operands(SingleUser->operand_values()); + if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands)) { + // If the cast is the second operand, it is free. We will generate either + // a "wide" or "long" version of the widening instruction. + if (I == SingleUser->getOperand(1)) + return 0; + // If the cast is not the second operand, it will be free if it looks the + // same as the second operand. In this case, we will generate a "long" + // version of the widening instruction. + if (auto *Cast = dyn_cast<CastInst>(SingleUser->getOperand(1))) + if (I->getOpcode() == Cast->getOpcode() && + cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy()) + return 0; + } + } + EVT SrcTy = TLI->getValueType(DL, Src); EVT DstTy = TLI->getValueType(DL, Dst); @@ -379,6 +463,16 @@ int AArch64TTIImpl::getArithmeticInstrCost( // Legalize the type. std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty); + // If the instruction is a widening instruction (e.g., uaddl, saddw, etc.), + // add in the widening overhead specified by the sub-target. Since the + // extends feeding widening instructions are performed automatically, they + // aren't present in the generated code and have a zero cost. By adding a + // widening overhead here, we attach the total cost of the combined operation + // to the widening instruction. + int Cost = 0; + if (isWideningInstruction(Ty, Opcode, Args)) + Cost += ST->getWideningBaseCost(); + int ISD = TLI->InstructionOpcodeToISD(Opcode); if (ISD == ISD::SDIV && @@ -388,9 +482,9 @@ int AArch64TTIImpl::getArithmeticInstrCost( // normally expanded to the sequence ADD + CMP + SELECT + SRA. // The OperandValue properties many not be same as that of previous // operation; conservatively assume OP_None. - int Cost = getArithmeticInstrCost(Instruction::Add, Ty, Opd1Info, Opd2Info, - TargetTransformInfo::OP_None, - TargetTransformInfo::OP_None); + Cost += getArithmeticInstrCost(Instruction::Add, Ty, Opd1Info, Opd2Info, + TargetTransformInfo::OP_None, + TargetTransformInfo::OP_None); Cost += getArithmeticInstrCost(Instruction::Sub, Ty, Opd1Info, Opd2Info, TargetTransformInfo::OP_None, TargetTransformInfo::OP_None); @@ -405,8 +499,8 @@ int AArch64TTIImpl::getArithmeticInstrCost( switch (ISD) { default: - return BaseT::getArithmeticInstrCost(Opcode, Ty, Opd1Info, Opd2Info, - Opd1PropInfo, Opd2PropInfo); + return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, Opd1Info, Opd2Info, + Opd1PropInfo, Opd2PropInfo); case ISD::ADD: case ISD::MUL: case ISD::XOR: @@ -414,7 +508,7 @@ int AArch64TTIImpl::getArithmeticInstrCost( case ISD::AND: // These nodes are marked as 'custom' for combining purposes only. // We know that they are legal. See LowerAdd in ISelLowering. - return 1 * LT.first; + return (Cost + 1) * LT.first; } } diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index e37c003e064..a8cf2ccf106 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -43,6 +43,9 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> { VECTOR_LDST_FOUR_ELEMENTS }; + bool isWideningInstruction(Type *Ty, unsigned Opcode, + ArrayRef<const Value *> Args); + public: explicit AArch64TTIImpl(const AArch64TargetMachine *TM, const Function &F) : BaseT(TM, F.getParent()->getDataLayout()), ST(TM->getSubtargetImpl(F)), diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 99084444bdd..960019f41df 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -1819,11 +1819,13 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { CInt->getValue().isPowerOf2()) Op2VP = TargetTransformInfo::OP_PowerOf2; - int ScalarCost = VecTy->getNumElements() * - TTI->getArithmeticInstrCost(Opcode, ScalarTy, Op1VK, - Op2VK, Op1VP, Op2VP); + SmallVector<const Value *, 4> Operands(VL0->operand_values()); + int ScalarCost = + VecTy->getNumElements() * + TTI->getArithmeticInstrCost(Opcode, ScalarTy, Op1VK, Op2VK, Op1VP, + Op2VP, Operands); int VecCost = TTI->getArithmeticInstrCost(Opcode, VecTy, Op1VK, Op2VK, - Op1VP, Op2VP); + Op1VP, Op2VP, Operands); return VecCost - ScalarCost; } case Instruction::GetElementPtr: { |

