diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a9c12047662..bc4c2a842a2 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -28627,6 +28627,85 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
   return DAG.getNode(X86ISD::PSADBW, DL, SadVT, SadOp0, SadOp1);
 }
 
+// Try to turn an all_of/any_of style horizontal reduction that feeds Extract
+// into a MOVMSK followed by a scalar compare-and-select. Returns an empty
+// SDValue when the combine does not apply.
+static SDValue combineHorizontalPredicateResult(SDNode *Extract,
+                                                SelectionDAG &DAG,
+                                                const X86Subtarget &Subtarget) {
+  // Requires SSE2; AVX512VL is skipped as it uses predicate registers.
+  if (!(Subtarget.hasSSE2() && !Subtarget.hasVLX()))
+    return SDValue();
+
+  // Only i8/i16/i32/i64 scalar results are handled.
+  EVT ExtractVT = Extract->getValueType(0);
+  unsigned BitWidth = ExtractVT.getSizeInBits();
+  const bool IsHandledScalar = ExtractVT == MVT::i8 || ExtractVT == MVT::i16 ||
+                               ExtractVT == MVT::i32 || ExtractVT == MVT::i64;
+  if (!IsHandledScalar)
+    return SDValue();
+
+  // Try the OR (any_of) reduction first, then the AND (all_of) reduction.
+  for (ISD::NodeType ReduceOp : {ISD::OR, ISD::AND}) {
+    // Skip this opcode if nothing matched, or if the matched element width
+    // differs from the extract width: EXTRACT_VECTOR_ELT would then imply an
+    // extension of the vector element, which is not supported here for now.
+    SDValue Match = matchBinOpReduction(Extract, ReduceOp);
+    if (!Match || Match.getScalarValueSizeInBits() != BitWidth)
+      continue;
+
+    // 128-bit vectors are always fine. 256-bit vectors need AVX2 when PMOVMSKB
+    // would be used (v16i16/v32i8); AVX suffices for 32/64-bit elements.
+    const unsigned VecSizeInBits = Match.getValueSizeInBits();
+    const bool FitsMovMsk =
+        VecSizeInBits == 128 ||
+        (VecSizeInBits == 256 &&
+         ((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2()));
+    if (!FitsMovMsk)
+      return SDValue();
+
+    // Give up on 2-element vectors, and on inputs that are not a reduction of
+    // all sign bits.
+    if (Match.getValueType().getVectorNumElements() <= 2 ||
+        DAG.ComputeNumSignBits(Match) != BitWidth)
+      return SDValue();
+
+    // Pick the MOVMSK operand type: an FP vector for 32/64-bit elements
+    // (MOVMSKPS/MOVMSKPD), otherwise a byte vector (PMOVMSKB).
+    const bool UseFPMask = BitWidth == 32 || BitWidth == 64;
+    const MVT MaskVT =
+        UseFPMask ? MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth),
+                                     VecSizeInBits / BitWidth)
+                  : MVT::getVectorVT(MVT::i8, VecSizeInBits / 8);
+
+    // any_of -> MOVMSK != 0
+    // all_of -> MOVMSK == ((1 << NumElts) - 1)
+    const bool IsAnyOf = ReduceOp == ISD::OR;
+    const APInt CompareBits =
+        IsAnyOf ? APInt::getNullValue(32)
+                : APInt::getLowBitsSet(32, MaskVT.getVectorNumElements());
+    const ISD::CondCode CondCode =
+        IsAnyOf ? ISD::CondCode::SETNE : ISD::CondCode::SETEQ;
+
+    // Perform the select in at least 32 bits and resize to the extract type
+    // afterwards, to avoid partial register stalls.
+    const unsigned ResWidth = std::max(BitWidth, 32u);
+    const EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth);
+
+    SDLoc DL(Extract);
+    SDValue Mask = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32,
+                               DAG.getBitcast(MaskVT, Match));
+    SDValue Expected = DAG.getConstant(CompareBits, DL, MVT::i32);
+    SDValue Ones =
+        DAG.getConstant(APInt::getAllOnesValue(ResWidth), DL, ResVT);
+    SDValue Zero = DAG.getConstant(APInt::getNullValue(ResWidth), DL, ResVT);
+    SDValue Res = DAG.getSelectCC(DL, Mask, Expected, Ones, Zero, CondCode);
+    return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
+  }
+
+  return SDValue();
+}
+
 static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
                                       const X86Subtarget &Subtarget) {
   // PSADBW is only supported on SSE2 and up.
@@ -28738,6 +28817,10 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
   if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget))
     return SAD;
 
+  // Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK.
+  if (SDValue Cmp = combineHorizontalPredicateResult(N, DAG, Subtarget))
+    return Cmp;
+
   // Only operate on vectors of 4 elements, where the alternative shuffling
   // gets to be more expensive.
   if (InputVector.getValueType() != MVT::v4i32)