diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 42 |
1 file changed, 42 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 1c26f5789c2..03c1c5dd453 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -26791,9 +26791,51 @@ static SDValue foldXorTruncShiftIntoCmp(SDNode *N, SelectionDAG &DAG) {
   return Cond;
 }
 
+/// Turn vector tests of the sign bit in the form of:
+///   xor (sra X, elt_size(X)-1), -1
+/// into:
+///   pcmpgt X, -1
+/// Returns an empty SDValue when \p N does not match or the type/subtarget is
+/// unsupported. This should be called before type legalization because the
+/// pattern may not persist after that.
+static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
+                                         const X86Subtarget &Subtarget) {
+  EVT VT = N->getValueType(0);
+  // TODO: AVX2 can handle 256-bit integer vectors.
+  if (!((Subtarget.hasSSE2() &&
+         (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32)) ||
+        (Subtarget.hasSSE42() && VT == MVT::v2i64)))
+    return SDValue();
+
+  // The xor's first operand must be a single-use arithmetic shift right (SRA)
+  // and its second operand all-ones, i.e. the xor must be a 'not' operation.
+  SDValue Shift = N->getOperand(0);
+  SDValue Ones = N->getOperand(1);
+  if (Shift.getOpcode() != ISD::SRA || !Shift.hasOneUse() ||
+      !ISD::isBuildVectorAllOnes(Ones.getNode()))
+    return SDValue();
+
+  // The shift must smear the sign bit across each element: splat of bits-1.
+  auto *ShiftBV = dyn_cast<BuildVectorSDNode>(Shift.getOperand(1));
+  if (!ShiftBV)
+    return SDValue();
+
+  EVT ShiftEltTy = Shift.getValueType().getVectorElementType();
+  auto *ShiftAmt = ShiftBV->getConstantSplatNode();
+  if (!ShiftAmt || ShiftAmt->getZExtValue() != ShiftEltTy.getSizeInBits() - 1)
+    return SDValue();
+
+  // Create a greater-than comparison against -1. We don't use the more obvious
+  // greater-than-or-equal-to-zero because SSE/AVX don't have that instruction.
+  return DAG.getNode(X86ISD::PCMPGT, SDLoc(N), VT, Shift.getOperand(0), Ones);
+}
+
 static SDValue PerformXorCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const X86Subtarget &Subtarget) {
+  if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget))
+    return Cmp;
+
   if (DCI.isBeforeLegalizeOps())
     return SDValue(); |