diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 25 |
1 files changed, 14 insertions, 11 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 552c91703a7..f5c9971a3c5 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -35859,11 +35859,12 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, SDLoc DL(Extract); EVT MatchVT = Match.getValueType(); unsigned NumElts = MatchVT.getVectorNumElements(); + unsigned MaxElts = Subtarget.hasInt256() ? 32 : 16; const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (ExtractVT == MVT::i1) { // Special case for (pre-legalization) vXi1 reductions. - if (NumElts > 32) + if (NumElts > 64 || !isPowerOf2_32(NumElts)) return SDValue(); if (TLI.isTypeLegal(MatchVT)) { // If this is a legal AVX512 predicate type then we can just bitcast. @@ -35871,18 +35872,18 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, Movmsk = DAG.getBitcast(MovmskVT, Match); } else { // Use combineBitcastvxi1 to create the MOVMSK. - if (NumElts == 32 && !Subtarget.hasInt256()) { + while (NumElts > MaxElts) { SDValue Lo, Hi; std::tie(Lo, Hi) = DAG.SplitVector(Match, DL); Match = DAG.getNode(BinOp, DL, Lo.getValueType(), Lo, Hi); - NumElts = 16; + NumElts /= 2; } EVT MovmskVT = EVT::getIntegerVT(*DAG.getContext(), NumElts); Movmsk = combineBitcastvxi1(DAG, MovmskVT, Match, DL, Subtarget); } if (!Movmsk) return SDValue(); - Movmsk = DAG.getZExtOrTrunc(Movmsk, DL, MVT::i32); + Movmsk = DAG.getZExtOrTrunc(Movmsk, DL, NumElts > 32 ? MVT::i64 : MVT::i32); } else { // Bail with AVX512VL (which uses predicate registers). if (Subtarget.hasVLX()) @@ -35923,13 +35924,15 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, Movmsk = getPMOVMSKB(DL, BitcastLogicOp, DAG, Subtarget); NumElts = MaskSrcVT.getVectorNumElements(); } - assert(NumElts <= 32 && "Not expecting more than 32 elements"); + assert((NumElts <= 32 || NumElts == 64) && + "Not expecting more than 64 elements"); + MVT CmpVT = NumElts == 64 ? MVT::i64 : MVT::i32; if (BinOp == ISD::XOR) { // parity -> (AND (CTPOP(MOVMSK X)), 1) - SDValue Mask = DAG.getConstant(1, DL, MVT::i32); - SDValue Result = DAG.getNode(ISD::CTPOP, DL, MVT::i32, Movmsk); - Result = DAG.getNode(ISD::AND, DL, MVT::i32, Result, Mask); + SDValue Mask = DAG.getConstant(1, DL, CmpVT); + SDValue Result = DAG.getNode(ISD::CTPOP, DL, CmpVT, Movmsk); + Result = DAG.getNode(ISD::AND, DL, CmpVT, Result, Mask); return DAG.getZExtOrTrunc(Result, DL, ExtractVT); } @@ -35937,18 +35940,18 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, ISD::CondCode CondCode; if (BinOp == ISD::OR) { // any_of -> MOVMSK != 0 - CmpC = DAG.getConstant(0, DL, MVT::i32); + CmpC = DAG.getConstant(0, DL, CmpVT); CondCode = ISD::CondCode::SETNE; } else { // all_of -> MOVMSK == ((1 << NumElts) - 1) - CmpC = DAG.getConstant((1ULL << NumElts) - 1, DL, MVT::i32); + CmpC = DAG.getConstant((1ULL << NumElts) - 1, DL, CmpVT); CondCode = ISD::CondCode::SETEQ; } // The setcc produces an i8 of 0/1, so extend that to the result width and // negate to get the final 0/-1 mask value. EVT SetccVT = - TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i32); + TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), CmpVT); SDValue Setcc = DAG.getSetCC(DL, SetccVT, Movmsk, CmpC, CondCode); SDValue Zext = DAG.getZExtOrTrunc(Setcc, DL, ExtractVT); SDValue Zero = DAG.getConstant(0, DL, ExtractVT); |

