diff options
| author | Simon Pilgrim <llvm-dev@redking.me.uk> | 2019-04-12 14:22:57 +0000 |
|---|---|---|
| committer | Simon Pilgrim <llvm-dev@redking.me.uk> | 2019-04-12 14:22:57 +0000 |
| commit | 6c8f4ada360d8d289dbba7c80eb8a1fae991d7d0 (patch) | |
| tree | 17c76318831db3bc81e2e00602fff698336ea0ca /llvm/lib | |
| parent | 1e39fc1faa5d8e27877eff45b2b9839c94e5d12e (diff) | |
| download | bcm5719-llvm-6c8f4ada360d8d289dbba7c80eb8a1fae991d7d0.tar.gz bcm5719-llvm-6c8f4ada360d8d289dbba7c80eb8a1fae991d7d0.zip | |
[X86][SSE] Recognise vXi1 boolean anyof/allof reduction patterns
Currently combineHorizontalPredicateResult only handles anyof/allof reduction patterns on legal types. These patterns can be tricky to match because type legalization of bools can introduce bitcasts, truncations, and extensions.
This patch extends combineHorizontalPredicateResult to also recognise vXi1 bool reductions, using the existing combineBitcastvxi1 helper to create the MOVMSK whose sign-mask result is then compared.
This keeps the reduction costs added in D60403 accurate, since those costs assume that a MOVMSK is generated.
Differential Revision: https://reviews.llvm.org/D60610
llvm-svn: 358286
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 89 |
1 files changed, 56 insertions, 33 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index c967329d0dd..c63ea39292f 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -34261,7 +34261,7 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, EVT ExtractVT = Extract->getValueType(0); unsigned BitWidth = ExtractVT.getSizeInBits(); if (ExtractVT != MVT::i64 && ExtractVT != MVT::i32 && ExtractVT != MVT::i16 && - ExtractVT != MVT::i8) + ExtractVT != MVT::i8 && ExtractVT != MVT::i1) return SDValue(); // Check for OR(any_of) and AND(all_of) horizontal reduction patterns. @@ -34275,36 +34275,63 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, if (Match.getScalarValueSizeInBits() != BitWidth) return SDValue(); - unsigned MatchSizeInBits = Match.getValueSizeInBits(); - if (!(MatchSizeInBits == 128 || (MatchSizeInBits == 256 && Subtarget.hasAVX()))) - return SDValue(); + SDValue Movmsk; + SDLoc DL(Extract); + unsigned NumElts = Match.getValueType().getVectorNumElements(); - // Make sure this isn't a vector of 1 element. The perf win from using MOVMSK - // diminishes with less elements in the reduction, but it is generally better - // to get the comparison over to the GPRs as soon as possible to reduce the - // number of vector ops. - if (Match.getValueType().getVectorNumElements() < 2) - return SDValue(); + if (ExtractVT == MVT::i1) { + // Special case for (pre-legalization) vXi1 reductions. + // Use combineBitcastvxi1 to create the MOVMSK. 
+ if (NumElts > 32) + return SDValue(); + if (NumElts == 32 && !Subtarget.hasInt256()) { + SDValue Lo, Hi; + std::tie(Lo, Hi) = DAG.SplitVector(Match, DL); + Match = DAG.getNode(BinOp, DL, Lo.getValueType(), Lo, Hi); + NumElts = 16; + } + EVT MovmskVT = EVT::getIntegerVT(*DAG.getContext(), NumElts); + Movmsk = combineBitcastvxi1(DAG, MovmskVT, Match, DL, Subtarget); + if (!Movmsk) + return SDValue(); + Movmsk = DAG.getZExtOrTrunc(Movmsk, DL, MVT::i32); + } else { + unsigned MatchSizeInBits = Match.getValueSizeInBits(); + if (!(MatchSizeInBits == 128 || + (MatchSizeInBits == 256 && Subtarget.hasAVX()))) + return SDValue(); - // Check that we are extracting a reduction of all sign bits. - if (DAG.ComputeNumSignBits(Match) != BitWidth) - return SDValue(); + // Make sure this isn't a vector of 1 element. The perf win from using + // MOVMSK diminishes with less elements in the reduction, but it is + // generally better to get the comparison over to the GPRs as soon as + // possible to reduce the number of vector ops. + if (Match.getValueType().getVectorNumElements() < 2) + return SDValue(); - SDLoc DL(Extract); - if (MatchSizeInBits == 256 && BitWidth < 32 && !Subtarget.hasInt256()) { - SDValue Lo, Hi; - std::tie(Lo, Hi) = DAG.SplitVector(Match, DL); - Match = DAG.getNode(BinOp, DL, Lo.getValueType(), Lo, Hi); - MatchSizeInBits = Match.getValueSizeInBits(); - } + // Check that we are extracting a reduction of all sign bits. + if (DAG.ComputeNumSignBits(Match) != BitWidth) + return SDValue(); - // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB. 
- MVT MaskSrcVT; - if (64 == BitWidth || 32 == BitWidth) - MaskSrcVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth), - MatchSizeInBits / BitWidth); - else - MaskSrcVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8); + if (MatchSizeInBits == 256 && BitWidth < 32 && !Subtarget.hasInt256()) { + SDValue Lo, Hi; + std::tie(Lo, Hi) = DAG.SplitVector(Match, DL); + Match = DAG.getNode(BinOp, DL, Lo.getValueType(), Lo, Hi); + MatchSizeInBits = Match.getValueSizeInBits(); + } + + // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB. + MVT MaskSrcVT; + if (64 == BitWidth || 32 == BitWidth) + MaskSrcVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth), + MatchSizeInBits / BitWidth); + else + MaskSrcVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8); + + SDValue BitcastLogicOp = DAG.getBitcast(MaskSrcVT, Match); + Movmsk = getPMOVMSKB(DL, BitcastLogicOp, DAG, Subtarget); + NumElts = MaskSrcVT.getVectorNumElements(); + } + assert(NumElts <= 32 && "Not expecting more than 32 elements"); SDValue CmpC; ISD::CondCode CondCode; @@ -34314,8 +34341,6 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, CondCode = ISD::CondCode::SETNE; } else { // all_of -> MOVMSK == ((1 << NumElts) - 1) - uint64_t NumElts = MaskSrcVT.getVectorNumElements(); - assert(NumElts <= 32 && "Not expecting more than 32 elements"); CmpC = DAG.getConstant((1ULL << NumElts) - 1, DL, MVT::i32); CondCode = ISD::CondCode::SETEQ; } @@ -34323,10 +34348,8 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, // The setcc produces an i8 of 0/1, so extend that to the result width and // negate to get the final 0/-1 mask value. 
const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - SDValue BitcastLogicOp = DAG.getBitcast(MaskSrcVT, Match); - SDValue Movmsk = getPMOVMSKB(DL, BitcastLogicOp, DAG, Subtarget); - EVT SetccVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), - MVT::i32); + EVT SetccVT = + TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i32); SDValue Setcc = DAG.getSetCC(DL, SetccVT, Movmsk, CmpC, CondCode); SDValue Zext = DAG.getZExtOrTrunc(Setcc, DL, ExtractVT); SDValue Zero = DAG.getConstant(0, DL, ExtractVT); |

