diff options
| author | Simon Pilgrim <llvm-dev@redking.me.uk> | 2019-09-08 11:46:21 +0000 |
|---|---|---|
| committer | Simon Pilgrim <llvm-dev@redking.me.uk> | 2019-09-08 11:46:21 +0000 |
| commit | 3262084384c74c6dd8bf5908a2014081ec003e1d (patch) | |
| tree | c82d7e6b334b19fba863c2cdd2b72ec9bfc2f2e3 /llvm/lib | |
| parent | acf81f4210cdc769c25ea41c8acf77666190767e (diff) | |
| download | bcm5719-llvm-3262084384c74c6dd8bf5908a2014081ec003e1d.tar.gz bcm5719-llvm-3262084384c74c6dd8bf5908a2014081ec003e1d.zip | |
[X86][SSE] Add support for <64 x i1> bool reduction
This generalizes the existing <32 x i1> pre-AVX2 split code to support reductions from <64 x i1> as well, we can probably generalize to any larger pow2 case in the future if the (unlikely) need ever arises.
We still need to tweak combineBitcastvxi1 to improve AVX512F codegen as its assumes vXi1 types should be handled on the mask registers even when they aren't legal.
Differential Revision: https://reviews.llvm.org/D67070
llvm-svn: 371328
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 25 |
1 files changed, 14 insertions, 11 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 552c91703a7..f5c9971a3c5 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -35859,11 +35859,12 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, SDLoc DL(Extract); EVT MatchVT = Match.getValueType(); unsigned NumElts = MatchVT.getVectorNumElements(); + unsigned MaxElts = Subtarget.hasInt256() ? 32 : 16; const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (ExtractVT == MVT::i1) { // Special case for (pre-legalization) vXi1 reductions. - if (NumElts > 32) + if (NumElts > 64 || !isPowerOf2_32(NumElts)) return SDValue(); if (TLI.isTypeLegal(MatchVT)) { // If this is a legal AVX512 predicate type then we can just bitcast. @@ -35871,18 +35872,18 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, Movmsk = DAG.getBitcast(MovmskVT, Match); } else { // Use combineBitcastvxi1 to create the MOVMSK. - if (NumElts == 32 && !Subtarget.hasInt256()) { + while (NumElts > MaxElts) { SDValue Lo, Hi; std::tie(Lo, Hi) = DAG.SplitVector(Match, DL); Match = DAG.getNode(BinOp, DL, Lo.getValueType(), Lo, Hi); - NumElts = 16; + NumElts /= 2; } EVT MovmskVT = EVT::getIntegerVT(*DAG.getContext(), NumElts); Movmsk = combineBitcastvxi1(DAG, MovmskVT, Match, DL, Subtarget); } if (!Movmsk) return SDValue(); - Movmsk = DAG.getZExtOrTrunc(Movmsk, DL, MVT::i32); + Movmsk = DAG.getZExtOrTrunc(Movmsk, DL, NumElts > 32 ? MVT::i64 : MVT::i32); } else { // Bail with AVX512VL (which uses predicate registers). if (Subtarget.hasVLX()) @@ -35923,13 +35924,15 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, Movmsk = getPMOVMSKB(DL, BitcastLogicOp, DAG, Subtarget); NumElts = MaskSrcVT.getVectorNumElements(); } - assert(NumElts <= 32 && "Not expecting more than 32 elements"); + assert((NumElts <= 32 || NumElts == 64) && + "Not expecting more than 64 elements"); + MVT CmpVT = NumElts == 64 ? MVT::i64 : MVT::i32; if (BinOp == ISD::XOR) { // parity -> (AND (CTPOP(MOVMSK X)), 1) - SDValue Mask = DAG.getConstant(1, DL, MVT::i32); - SDValue Result = DAG.getNode(ISD::CTPOP, DL, MVT::i32, Movmsk); - Result = DAG.getNode(ISD::AND, DL, MVT::i32, Result, Mask); + SDValue Mask = DAG.getConstant(1, DL, CmpVT); + SDValue Result = DAG.getNode(ISD::CTPOP, DL, CmpVT, Movmsk); + Result = DAG.getNode(ISD::AND, DL, CmpVT, Result, Mask); return DAG.getZExtOrTrunc(Result, DL, ExtractVT); } @@ -35937,18 +35940,18 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, ISD::CondCode CondCode; if (BinOp == ISD::OR) { // any_of -> MOVMSK != 0 - CmpC = DAG.getConstant(0, DL, MVT::i32); + CmpC = DAG.getConstant(0, DL, CmpVT); CondCode = ISD::CondCode::SETNE; } else { // all_of -> MOVMSK == ((1 << NumElts) - 1) - CmpC = DAG.getConstant((1ULL << NumElts) - 1, DL, MVT::i32); + CmpC = DAG.getConstant((1ULL << NumElts) - 1, DL, CmpVT); CondCode = ISD::CondCode::SETEQ; } // The setcc produces an i8 of 0/1, so extend that to the result width and // negate to get the final 0/-1 mask value. EVT SetccVT = - TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i32); + TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), CmpVT); SDValue Setcc = DAG.getSetCC(DL, SetccVT, Movmsk, CmpC, CondCode); SDValue Zext = DAG.getZExtOrTrunc(Setcc, DL, ExtractVT); SDValue Zero = DAG.getConstant(0, DL, ExtractVT); |

