summary | refs | log | tree | commit | diff | stats
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r-- llvm/lib/Analysis/TargetTransformInfo.cpp 8
-rw-r--r-- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp 55
-rw-r--r-- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h 3
-rw-r--r-- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp 7
4 files changed, 70 insertions, 3 deletions
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index c8f116a1d33..c5793add30b 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -267,6 +267,14 @@ int TargetTransformInfo::getCastInstrCost(unsigned Opcode, Type *Dst,
return Cost;
}
+int TargetTransformInfo::getExtractWithExtendCost(
+    unsigned Opcode, Type *Dst, VectorType *VecTy, unsigned Index) const {
+  // Forward to TTIImpl and verify that the returned cost is non-negative.
+  int Result = TTIImpl->getExtractWithExtendCost(Opcode, Dst, VecTy, Index);
+  assert(Result >= 0 && "TTI should not produce negative costs!");
+  return Result;
+}
+
int TargetTransformInfo::getCFInstrCost(unsigned Opcode) const {
int Cost = TTIImpl->getCFInstrCost(Opcode);
assert(Cost >= 0 && "TTI should not produce negative costs!");
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 87f96f80040..8e832ffb6a3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -291,6 +291,61 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) {
return BaseT::getCastInstrCost(Opcode, Dst, Src);
}
+int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
+                                             VectorType *VecTy,
+                                             unsigned Index) {
+  // Cost of extracting lane Index of VecTy and sign- or zero-extending it to
+  // Dst. The extend is treated as free when the smov/umov performs it.
+
+  // Make sure we were given a valid extend opcode.
+  assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
+         "Invalid opcode");
+
+  // We are extending an element we extract from a vector, so the source type
+  // of the extend is the element type of the vector.
+  auto *Src = VecTy->getElementType();
+
+  // Sign- and zero-extends are for integer types only.
+  assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
+
+  // Cost of the extract itself; any cost for the extend is added below.
+  auto Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy, Index);
+
+  // Legalize the types.
+  auto VecLT = TLI->getTypeLegalizationCost(DL, VecTy);
+  auto DstVT = TLI->getValueType(DL, Dst);
+  auto SrcVT = TLI->getValueType(DL, Src);
+
+  // If the legalized type is still a vector and the destination type is legal,
+  // the extension may be free. If not, fall back to the default extend cost.
+  if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
+    return Cost + getCastInstrCost(Opcode, Dst, Src);
+
+  // The destination type should be larger than the element type. If not, get
+  // the default cost for the extend.
+  if (DstVT.getSizeInBits() < SrcVT.getSizeInBits())
+    return Cost + getCastInstrCost(Opcode, Dst, Src);
+
+  switch (Opcode) {
+  default:
+    llvm_unreachable("Opcode should be either SExt or ZExt");
+
+  // For sign-extends, we only need a smov, which performs the extension
+  // automatically.
+  case Instruction::SExt:
+    return Cost;
+
+  // For zero-extends, the extend is performed automatically by a umov unless
+  // the destination type is i64 and the element type is i8 or i16.
+  case Instruction::ZExt:
+    if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
+      return Cost;
+  }
+
+  // If we are unable to perform the extend for free, get the default cost.
+  return Cost + getCastInstrCost(Opcode, Dst, Src);
+}
+
int AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
unsigned Index) {
assert(Val->isVectorTy() && "This must be a vector type");
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 93a84b7a992..4f2e8310d76 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -99,6 +99,9 @@ public:
int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src);
+ int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy,
+ unsigned Index);
+
int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index);
int getArithmeticInstrCost(
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f1bfbc2ba84..b92a97556aa 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1830,11 +1830,12 @@ int BoUpSLP::getTreeCost() {
if (MinBWs.count(ScalarRoot)) {
auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot]);
VecTy = VectorType::get(MinTy, BundleWidth);
+ ExtractCost += TTI->getExtractWithExtendCost(
+ Instruction::SExt, EU.Scalar->getType(), VecTy, EU.Lane);
+ } else {
ExtractCost +=
- TTI->getCastInstrCost(Instruction::SExt, EU.Scalar->getType(), MinTy);
+ TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane);
}
- ExtractCost +=
- TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane);
}
int SpillCost = getSpillCost();
OpenPOWER on IntegriCloud