summary refs log tree commit diff stats
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r-- llvm/lib/Target/X86/X86ISelLowering.cpp 25
1 files changed, 14 insertions, 11 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 552c91703a7..f5c9971a3c5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -35859,11 +35859,12 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
SDLoc DL(Extract);
EVT MatchVT = Match.getValueType();
unsigned NumElts = MatchVT.getVectorNumElements();
+ unsigned MaxElts = Subtarget.hasInt256() ? 32 : 16;
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (ExtractVT == MVT::i1) {
// Special case for (pre-legalization) vXi1 reductions.
- if (NumElts > 32)
+ if (NumElts > 64 || !isPowerOf2_32(NumElts))
return SDValue();
if (TLI.isTypeLegal(MatchVT)) {
// If this is a legal AVX512 predicate type then we can just bitcast.
@@ -35871,18 +35872,18 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
Movmsk = DAG.getBitcast(MovmskVT, Match);
} else {
// Use combineBitcastvxi1 to create the MOVMSK.
- if (NumElts == 32 && !Subtarget.hasInt256()) {
+ while (NumElts > MaxElts) {
SDValue Lo, Hi;
std::tie(Lo, Hi) = DAG.SplitVector(Match, DL);
Match = DAG.getNode(BinOp, DL, Lo.getValueType(), Lo, Hi);
- NumElts = 16;
+ NumElts /= 2;
}
EVT MovmskVT = EVT::getIntegerVT(*DAG.getContext(), NumElts);
Movmsk = combineBitcastvxi1(DAG, MovmskVT, Match, DL, Subtarget);
}
if (!Movmsk)
return SDValue();
- Movmsk = DAG.getZExtOrTrunc(Movmsk, DL, MVT::i32);
+ Movmsk = DAG.getZExtOrTrunc(Movmsk, DL, NumElts > 32 ? MVT::i64 : MVT::i32);
} else {
// Bail with AVX512VL (which uses predicate registers).
if (Subtarget.hasVLX())
@@ -35923,13 +35924,15 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
Movmsk = getPMOVMSKB(DL, BitcastLogicOp, DAG, Subtarget);
NumElts = MaskSrcVT.getVectorNumElements();
}
- assert(NumElts <= 32 && "Not expecting more than 32 elements");
+ assert((NumElts <= 32 || NumElts == 64) &&
+ "Not expecting more than 64 elements");
+ MVT CmpVT = NumElts == 64 ? MVT::i64 : MVT::i32;
if (BinOp == ISD::XOR) {
// parity -> (AND (CTPOP(MOVMSK X)), 1)
- SDValue Mask = DAG.getConstant(1, DL, MVT::i32);
- SDValue Result = DAG.getNode(ISD::CTPOP, DL, MVT::i32, Movmsk);
- Result = DAG.getNode(ISD::AND, DL, MVT::i32, Result, Mask);
+ SDValue Mask = DAG.getConstant(1, DL, CmpVT);
+ SDValue Result = DAG.getNode(ISD::CTPOP, DL, CmpVT, Movmsk);
+ Result = DAG.getNode(ISD::AND, DL, CmpVT, Result, Mask);
return DAG.getZExtOrTrunc(Result, DL, ExtractVT);
}
@@ -35937,18 +35940,18 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
ISD::CondCode CondCode;
if (BinOp == ISD::OR) {
// any_of -> MOVMSK != 0
- CmpC = DAG.getConstant(0, DL, MVT::i32);
+ CmpC = DAG.getConstant(0, DL, CmpVT);
CondCode = ISD::CondCode::SETNE;
} else {
// all_of -> MOVMSK == ((1 << NumElts) - 1)
- CmpC = DAG.getConstant((1ULL << NumElts) - 1, DL, MVT::i32);
+ CmpC = DAG.getConstant((1ULL << NumElts) - 1, DL, CmpVT);
CondCode = ISD::CondCode::SETEQ;
}
// The setcc produces an i8 of 0/1, so extend that to the result width and
// negate to get the final 0/-1 mask value.
EVT SetccVT =
- TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i32);
+ TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), CmpVT);
SDValue Setcc = DAG.getSetCC(DL, SetccVT, Movmsk, CmpC, CondCode);
SDValue Zext = DAG.getZExtOrTrunc(Setcc, DL, ExtractVT);
SDValue Zero = DAG.getConstant(0, DL, ExtractVT);
OpenPOWER on IntegriCloud