summaryrefslogtreecommitdiffstats
path: root/llvm/lib
diff options
context:
space:
mode:
authorSimon Pilgrim <llvm-dev@redking.me.uk>2019-09-08 11:46:21 +0000
committerSimon Pilgrim <llvm-dev@redking.me.uk>2019-09-08 11:46:21 +0000
commit3262084384c74c6dd8bf5908a2014081ec003e1d (patch)
treec82d7e6b334b19fba863c2cdd2b72ec9bfc2f2e3 /llvm/lib
parentacf81f4210cdc769c25ea41c8acf77666190767e (diff)
downloadbcm5719-llvm-3262084384c74c6dd8bf5908a2014081ec003e1d.tar.gz
bcm5719-llvm-3262084384c74c6dd8bf5908a2014081ec003e1d.zip
[X86][SSE] Add support for <64 x i1> bool reduction
This generalizes the existing <32 x i1> pre-AVX2 split code to support reductions from <64 x i1> as well, we can probably generalize to any larger pow2 case in the future if the (unlikely) need ever arises. We still need to tweak combineBitcastvxi1 to improve AVX512F codegen as its assumes vXi1 types should be handled on the mask registers even when they aren't legal. Differential Revision: https://reviews.llvm.org/D67070 llvm-svn: 371328
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Target/X86/X86ISelLowering.cpp25
1 files changed, 14 insertions, 11 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 552c91703a7..f5c9971a3c5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -35859,11 +35859,12 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
SDLoc DL(Extract);
EVT MatchVT = Match.getValueType();
unsigned NumElts = MatchVT.getVectorNumElements();
+ unsigned MaxElts = Subtarget.hasInt256() ? 32 : 16;
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (ExtractVT == MVT::i1) {
// Special case for (pre-legalization) vXi1 reductions.
- if (NumElts > 32)
+ if (NumElts > 64 || !isPowerOf2_32(NumElts))
return SDValue();
if (TLI.isTypeLegal(MatchVT)) {
// If this is a legal AVX512 predicate type then we can just bitcast.
@@ -35871,18 +35872,18 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
Movmsk = DAG.getBitcast(MovmskVT, Match);
} else {
// Use combineBitcastvxi1 to create the MOVMSK.
- if (NumElts == 32 && !Subtarget.hasInt256()) {
+ while (NumElts > MaxElts) {
SDValue Lo, Hi;
std::tie(Lo, Hi) = DAG.SplitVector(Match, DL);
Match = DAG.getNode(BinOp, DL, Lo.getValueType(), Lo, Hi);
- NumElts = 16;
+ NumElts /= 2;
}
EVT MovmskVT = EVT::getIntegerVT(*DAG.getContext(), NumElts);
Movmsk = combineBitcastvxi1(DAG, MovmskVT, Match, DL, Subtarget);
}
if (!Movmsk)
return SDValue();
- Movmsk = DAG.getZExtOrTrunc(Movmsk, DL, MVT::i32);
+ Movmsk = DAG.getZExtOrTrunc(Movmsk, DL, NumElts > 32 ? MVT::i64 : MVT::i32);
} else {
// Bail with AVX512VL (which uses predicate registers).
if (Subtarget.hasVLX())
@@ -35923,13 +35924,15 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
Movmsk = getPMOVMSKB(DL, BitcastLogicOp, DAG, Subtarget);
NumElts = MaskSrcVT.getVectorNumElements();
}
- assert(NumElts <= 32 && "Not expecting more than 32 elements");
+ assert((NumElts <= 32 || NumElts == 64) &&
+ "Not expecting more than 64 elements");
+ MVT CmpVT = NumElts == 64 ? MVT::i64 : MVT::i32;
if (BinOp == ISD::XOR) {
// parity -> (AND (CTPOP(MOVMSK X)), 1)
- SDValue Mask = DAG.getConstant(1, DL, MVT::i32);
- SDValue Result = DAG.getNode(ISD::CTPOP, DL, MVT::i32, Movmsk);
- Result = DAG.getNode(ISD::AND, DL, MVT::i32, Result, Mask);
+ SDValue Mask = DAG.getConstant(1, DL, CmpVT);
+ SDValue Result = DAG.getNode(ISD::CTPOP, DL, CmpVT, Movmsk);
+ Result = DAG.getNode(ISD::AND, DL, CmpVT, Result, Mask);
return DAG.getZExtOrTrunc(Result, DL, ExtractVT);
}
@@ -35937,18 +35940,18 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
ISD::CondCode CondCode;
if (BinOp == ISD::OR) {
// any_of -> MOVMSK != 0
- CmpC = DAG.getConstant(0, DL, MVT::i32);
+ CmpC = DAG.getConstant(0, DL, CmpVT);
CondCode = ISD::CondCode::SETNE;
} else {
// all_of -> MOVMSK == ((1 << NumElts) - 1)
- CmpC = DAG.getConstant((1ULL << NumElts) - 1, DL, MVT::i32);
+ CmpC = DAG.getConstant((1ULL << NumElts) - 1, DL, CmpVT);
CondCode = ISD::CondCode::SETEQ;
}
// The setcc produces an i8 of 0/1, so extend that to the result width and
// negate to get the final 0/-1 mask value.
EVT SetccVT =
- TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i32);
+ TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), CmpVT);
SDValue Setcc = DAG.getSetCC(DL, SetccVT, Movmsk, CmpC, CondCode);
SDValue Zext = DAG.getZExtOrTrunc(Setcc, DL, ExtractVT);
SDValue Zero = DAG.getConstant(0, DL, ExtractVT);
OpenPOWER on IntegriCloud