diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Analysis/TargetTransformInfo.cpp | 84 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 27 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h | 2 |
3 files changed, 103 insertions, 10 deletions
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 9843eed4c5a..6ac4fbe2dc2 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -683,6 +683,66 @@ static bool isAlternateVectorMask(ArrayRef<int> Mask) { return isAlternate; } +static bool isTransposeVectorMask(ArrayRef<int> Mask) { + // Transpose vector masks transpose a 2xn matrix. They read corresponding + // even- or odd-numbered vector elements from two n-dimensional source + // vectors and write each result into consecutive elements of an + // n-dimensional destination vector. Two shuffles are necessary to complete + // the transpose, one for the even elements and another for the odd elements. + // This description closely follows how the TRN1 and TRN2 AArch64 + // instructions operate. + // + // For example, a simple 2x2 matrix can be transposed with: + // + // ; Original matrix + // m0 = <a, b> + // m1 = <c, d> + // + // ; Transposed matrix + // t0 = <a, c> = shufflevector m0, m1, <0, 2> + // t1 = <b, d> = shufflevector m0, m1, <1, 3> + // + // For matrices having greater than n columns, the resulting nx2 transposed + // matrix is stored in two result vectors such that one vector contains + // interleaved elements from all the even-numbered rows and the other vector + // contains interleaved elements from all the odd-numbered rows. For example, + // a 2x4 matrix can be transposed with: + // + // ; Original matrix + // m0 = <a, b, c, d> + // m1 = <e, f, g, h> + // + // ; Transposed matrix + // t0 = <a, e, c, g> = shufflevector m0, m1 <0, 4, 2, 6> + // t1 = <b, f, d, h> = shufflevector m0, m1 <1, 5, 3, 7> + // + // The above explanation places limitations on what valid transpose masks can + // look like. These limitations are defined by the checks below. + // + // 1. The number of elements in the mask must be a power of two. + if (!isPowerOf2_32(Mask.size())) + return false; + + // 2. The first element of the mask must be either a zero (for the + // even-numbered vector elements) or a one (for the odd-numbered vector + // elements). + if (Mask[0] != 0 && Mask[0] != 1) + return false; + + // 3. The difference between the first two elements must be equal to the + // number of elements in the mask. + if (Mask[1] - Mask[0] != (int)Mask.size()) + return false; + + // 4. The difference between consecutive even-numbered and odd-numbered + // elements must be equal to two. + for (int I = 2; I < (int)Mask.size(); ++I) + if (Mask[I] - Mask[I - 2] != 2) + return false; + + return true; +} + static TargetTransformInfo::OperandValueKind getOperandInfo(Value *V) { TargetTransformInfo::OperandValueKind OpInfo = TargetTransformInfo::OK_AnyValue; @@ -1139,22 +1199,26 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const { if (NumVecElems == Mask.size()) { if (isReverseVectorMask(Mask)) - return getShuffleCost(TargetTransformInfo::SK_Reverse, VecTypOp0, - 0, nullptr); + return TTIImpl->getShuffleCost(TargetTransformInfo::SK_Reverse, + VecTypOp0, 0, nullptr); if (isAlternateVectorMask(Mask)) - return getShuffleCost(TargetTransformInfo::SK_Alternate, - VecTypOp0, 0, nullptr); + return TTIImpl->getShuffleCost(TargetTransformInfo::SK_Alternate, + VecTypOp0, 0, nullptr); + + if (isTransposeVectorMask(Mask)) + return TTIImpl->getShuffleCost(TargetTransformInfo::SK_Transpose, + VecTypOp0, 0, nullptr); if (isZeroEltBroadcastVectorMask(Mask)) - return getShuffleCost(TargetTransformInfo::SK_Broadcast, - VecTypOp0, 0, nullptr); + return TTIImpl->getShuffleCost(TargetTransformInfo::SK_Broadcast, + VecTypOp0, 0, nullptr); if (isSingleSourceVectorMask(Mask)) - return getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, - VecTypOp0, 0, nullptr); + return TTIImpl->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, + VecTypOp0, 0, nullptr); - return getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, - VecTypOp0, 0, nullptr); + return TTIImpl->getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, + VecTypOp0, 0, nullptr); } return -1; diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index a626323635c..337db546658 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -912,3 +912,30 @@ int AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy, return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwiseForm); } + +int AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, Type *Tp, int Index, + Type *SubTp) { + + // Transpose shuffle kinds can be performed with 'trn1/trn2' and 'zip1/zip2' + // instructions. + if (Kind == TTI::SK_Transpose) { + static const CostTblEntry TransposeTbl[] = { + {ISD::VECTOR_SHUFFLE, MVT::v8i8, 1}, + {ISD::VECTOR_SHUFFLE, MVT::v16i8, 1}, + {ISD::VECTOR_SHUFFLE, MVT::v4i16, 1}, + {ISD::VECTOR_SHUFFLE, MVT::v8i16, 1}, + {ISD::VECTOR_SHUFFLE, MVT::v2i32, 1}, + {ISD::VECTOR_SHUFFLE, MVT::v4i32, 1}, + {ISD::VECTOR_SHUFFLE, MVT::v2i64, 1}, + {ISD::VECTOR_SHUFFLE, MVT::v2f32, 1}, + {ISD::VECTOR_SHUFFLE, MVT::v4f32, 1}, + {ISD::VECTOR_SHUFFLE, MVT::v2f64, 1}, + }; + std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp); + if (const auto *Entry = + CostTableLookup(TransposeTbl, ISD::VECTOR_SHUFFLE, LT.second)) + return LT.first * Entry->Cost; + } + + return BaseT::getShuffleCost(Kind, Tp, Index, SubTp); +} diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index e71eb515cc2..c056a7d2428 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -169,6 +169,8 @@ public: int getArithmeticReductionCost(unsigned Opcode, Type *Ty, bool IsPairwiseForm); + + int getShuffleCost(TTI::ShuffleKind Kind, Type *Tp, int Index, Type *SubTp); /// @} }; |