diff options
author | Simon Pilgrim <llvm-dev@redking.me.uk> | 2017-02-02 11:52:33 +0000 |
---|---|---|
committer | Simon Pilgrim <llvm-dev@redking.me.uk> | 2017-02-02 11:52:33 +0000 |
commit | 20ab6b875ad6cdc831351818775f73a91e56db09 (patch) | |
tree | 7f703e9244c575eda3f6b23c77542c251a597e5d /llvm/lib/Target/X86/X86ISelLowering.cpp | |
parent | 28912c09b20909ae5ffde3ab3cdad2e3c65ae657 (diff) | |
download | bcm5719-llvm-20ab6b875ad6cdc831351818775f73a91e56db09.tar.gz bcm5719-llvm-20ab6b875ad6cdc831351818775f73a91e56db09.zip |
[X86][SSE] Use MOVMSK for all_of/any_of reduction patterns
This is a first attempt at using the MOVMSK instructions to replace all_of/any_of reduction patterns (i.e. an and/or + shuffle chain).
So far this only matches patterns where we are reducing an all/none bits source vector (i.e. a comparison result) but we should be able to expand on this in conjunction with improvements to 'bool vector' handling both in the x86 backend as well as the vectorizers etc.
Differential Revision: https://reviews.llvm.org/D28810
llvm-svn: 293880
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index a9c12047662..bc4c2a842a2 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -28627,6 +28627,85 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0, return DAG.getNode(X86ISD::PSADBW, DL, SadVT, SadOp0, SadOp1); } +// Attempt to replace an all_of/any_of style horizontal reduction with a MOVMSK. +static SDValue combineHorizontalPredicateResult(SDNode *Extract, + SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + // Bail without SSE2 or with AVX512VL (which uses predicate registers). + if (!Subtarget.hasSSE2() || Subtarget.hasVLX()) + return SDValue(); + + EVT ExtractVT = Extract->getValueType(0); + unsigned BitWidth = ExtractVT.getSizeInBits(); + if (ExtractVT != MVT::i64 && ExtractVT != MVT::i32 && ExtractVT != MVT::i16 && + ExtractVT != MVT::i8) + return SDValue(); + + // Check for OR(any_of) and AND(all_of) horizontal reduction patterns. + for (ISD::NodeType Op : {ISD::OR, ISD::AND}) { + SDValue Match = matchBinOpReduction(Extract, Op); + if (!Match) + continue; + + // EXTRACT_VECTOR_ELT can require implicit extension of the vector element + // which we can't support here for now. + if (Match.getScalarValueSizeInBits() != BitWidth) + continue; + + // We require AVX2 for PMOVMSKB for v16i16/v32i8; + unsigned MatchSizeInBits = Match.getValueSizeInBits(); + if (!(MatchSizeInBits == 128 || + (MatchSizeInBits == 256 && + ((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2())))) + return SDValue(); + + // Don't bother performing this for 2-element vectors. + if (Match.getValueType().getVectorNumElements() <= 2) + return SDValue(); + + // Check that we are extracting a reduction of all sign bits. + if (DAG.ComputeNumSignBits(Match) != BitWidth) + return SDValue(); + + // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB. + MVT MaskVT; + if (64 == BitWidth || 32 == BitWidth) + MaskVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth), + MatchSizeInBits / BitWidth); + else + MaskVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8); + + APInt CompareBits; + ISD::CondCode CondCode; + if (Op == ISD::OR) { + // any_of -> MOVMSK != 0 + CompareBits = APInt::getNullValue(32); + CondCode = ISD::CondCode::SETNE; + } else { + // all_of -> MOVMSK == ((1 << NumElts) - 1) + CompareBits = APInt::getLowBitsSet(32, MaskVT.getVectorNumElements()); + CondCode = ISD::CondCode::SETEQ; + } + + // Perform the select as i32/i64 and then truncate to avoid partial register + // stalls. + unsigned ResWidth = std::max(BitWidth, 32u); + APInt ResOnes = APInt::getAllOnesValue(ResWidth); + APInt ResZero = APInt::getNullValue(ResWidth); + EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth); + + SDLoc DL(Extract); + SDValue Res = DAG.getBitcast(MaskVT, Match); + Res = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Res); + Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32), + DAG.getConstant(ResOnes, DL, ResVT), + DAG.getConstant(ResZero, DL, ResVT), CondCode); + return DAG.getSExtOrTrunc(Res, DL, ExtractVT); + } + + return SDValue(); +} + static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, const X86Subtarget &Subtarget) { // PSADBW is only supported on SSE2 and up. @@ -28738,6 +28817,10 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget)) return SAD; + // Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK. + if (SDValue Cmp = combineHorizontalPredicateResult(N, DAG, Subtarget)) + return Cmp; + // Only operate on vectors of 4 elements, where the alternative shuffling // gets to be more expensive. if (InputVector.getValueType() != MVT::v4i32) |