-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 44 |
-rw-r--r-- | llvm/test/CodeGen/X86/avx512-vec-cmp.ll |  6 |
2 files changed, 41 insertions, 9 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 77a7489a2b5..5741b80d1a7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -43364,9 +43364,22 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
+                                       TargetLowering::DAGCombinerInfo &DCI) {
+  // With vector masks we only demand the upper bit of the mask.
+  SDValue Mask = cast<X86MaskedGatherScatterSDNode>(N)->getMask();
+  if (Mask.getScalarValueSizeInBits() != 1) {
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+    APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
+    if (TLI.SimplifyDemandedBits(Mask, DemandedMask, DCI))
+      return SDValue(N, 0);
+  }
+
+  return SDValue();
+}
+
 static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
-                                    TargetLowering::DAGCombinerInfo &DCI,
-                                    const X86Subtarget &Subtarget) {
+                                    TargetLowering::DAGCombinerInfo &DCI) {
   SDLoc DL(N);
 
   if (DCI.isBeforeLegalizeOps()) {
@@ -43426,7 +43439,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
   }
 
   // With vector masks we only demand the upper bit of the mask.
-  SDValue Mask = N->getOperand(2);
+  SDValue Mask = cast<MaskedGatherScatterSDNode>(N)->getMask();
   if (Mask.getScalarValueSizeInBits() != 1) {
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
@@ -44465,6 +44478,27 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
                           HADDBuilder);
   }
 
+  // If vectors of i1 are legal, turn (add (zext (vXi1 X)), Y) into
+  // (sub Y, (sext (vXi1 X))).
+  // FIXME: We have the (sub Y, (zext (vXi1 X))) -> (add (sext (vXi1 X)), Y) in
+  // generic DAG combine without a legal type check, but adding this there
+  // caused regressions.
+  if (Subtarget.hasAVX512() && VT.isVector()) {
+    if (Op0.getOpcode() == ISD::ZERO_EXTEND &&
+        Op0.getOperand(0).getValueType().getVectorElementType() == MVT::i1) {
+      SDLoc DL(N);
+      SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op0.getOperand(0));
+      return DAG.getNode(ISD::SUB, DL, VT, Op1, SExt);
+    }
+
+    if (Op1.getOpcode() == ISD::ZERO_EXTEND &&
+        Op1.getOperand(0).getValueType().getVectorElementType() == MVT::i1) {
+      SDLoc DL(N);
+      SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op1.getOperand(0));
+      return DAG.getNode(ISD::SUB, DL, VT, Op0, SExt);
+    }
+  }
+
   return combineAddOrSubToADCOrSBB(N, DAG);
 }
 
@@ -45355,9 +45389,9 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case X86ISD::FMSUBADD:    return combineFMADDSUB(N, DAG, Subtarget);
   case X86ISD::MOVMSK:      return combineMOVMSK(N, DAG, DCI, Subtarget);
   case X86ISD::MGATHER:
-  case X86ISD::MSCATTER:
+  case X86ISD::MSCATTER:    return combineX86GatherScatter(N, DAG, DCI);
   case ISD::MGATHER:
-  case ISD::MSCATTER:       return combineGatherScatter(N, DAG, DCI, Subtarget);
+  case ISD::MSCATTER:       return combineGatherScatter(N, DAG, DCI);
   case X86ISD::PCMPEQ:
   case X86ISD::PCMPGT:      return combineVectorCompare(N, DAG, Subtarget);
   case X86ISD::PMULDQ:
diff --git a/llvm/test/CodeGen/X86/avx512-vec-cmp.ll b/llvm/test/CodeGen/X86/avx512-vec-cmp.ll
index b5fcc750855..88910fa1749 100644
--- a/llvm/test/CodeGen/X86/avx512-vec-cmp.ll
+++ b/llvm/test/CodeGen/X86/avx512-vec-cmp.ll
@@ -1414,8 +1414,7 @@ define <4 x i32> @zext_bool_logic(<4 x i64> %cond1, <4 x i64> %cond2, <4 x i32>
 ; AVX512-NEXT:    vptestnmq %zmm1, %zmm1, %k1 ## encoding: [0x62,0xf2,0xf6,0x48,0x27,0xc9]
 ; AVX512-NEXT:    korw %k1, %k0, %k1 ## encoding: [0xc5,0xfc,0x45,0xc9]
 ; AVX512-NEXT:    vpternlogd $255, %zmm0, %zmm0, %zmm0 {%k1} {z} ## encoding: [0x62,0xf3,0x7d,0xc9,0x25,0xc0,0xff]
-; AVX512-NEXT:    vpsrld $31, %xmm0, %xmm0 ## encoding: [0xc5,0xf9,0x72,0xd0,0x1f]
-; AVX512-NEXT:    vpaddd %xmm2, %xmm0, %xmm0 ## encoding: [0xc5,0xf9,0xfe,0xc2]
+; AVX512-NEXT:    vpsubd %xmm0, %xmm2, %xmm0 ## encoding: [0xc5,0xe9,0xfa,0xc0]
 ; AVX512-NEXT:    vzeroupper ## encoding: [0xc5,0xf8,0x77]
 ; AVX512-NEXT:    retq ## encoding: [0xc3]
 ;
@@ -1425,8 +1424,7 @@ define <4 x i32> @zext_bool_logic(<4 x i64> %cond1, <4 x i64> %cond2, <4 x i32>
 ; SKX-NEXT:    vptestnmq %ymm1, %ymm1, %k1 ## encoding: [0x62,0xf2,0xf6,0x28,0x27,0xc9]
 ; SKX-NEXT:    korw %k1, %k0, %k0 ## encoding: [0xc5,0xfc,0x45,0xc1]
 ; SKX-NEXT:    vpmovm2d %k0, %xmm0 ## encoding: [0x62,0xf2,0x7e,0x08,0x38,0xc0]
-; SKX-NEXT:    vpsrld $31, %xmm0, %xmm0 ## EVEX TO VEX Compression encoding: [0xc5,0xf9,0x72,0xd0,0x1f]
-; SKX-NEXT:    vpaddd %xmm2, %xmm0, %xmm0 ## EVEX TO VEX Compression encoding: [0xc5,0xf9,0xfe,0xc2]
+; SKX-NEXT:    vpsubd %xmm0, %xmm2, %xmm0 ## EVEX TO VEX Compression encoding: [0xc5,0xe9,0xfa,0xc0]
 ; SKX-NEXT:    vzeroupper ## encoding: [0xc5,0xf8,0x77]
 ; SKX-NEXT:    retq ## encoding: [0xc3]
   %a = icmp eq <4 x i64> %cond1, zeroinitializer
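
Note on the gather/scatter hunks: they demand only APInt::getSignMask(...) of each mask element because x86 vector-mask gathers and scatters consult just the most significant bit of every element. A minimal sketch of that fact using the AVX2 gather intrinsic (illustration only, not part of the patch; assumes an AVX2-capable host and compilation with -mavx2, and the names Table/Idx/Src are hypothetical):

#include <immintrin.h>
#include <cassert>
#include <cstdint>

int main() {
  alignas(32) const int32_t Table[8] = {10, 11, 12, 13, 14, 15, 16, 17};
  const __m256i Src = _mm256_set1_epi32(-99);  // fallback value for unset lanes
  const __m256i Idx = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);

  // All 32 bits set vs. only the sign bit set: the gather sees the same mask,
  // because only the MSB of each mask element is consulted.
  const __m256i AllOnes = _mm256_set1_epi32(-1);
  const __m256i SignBitOnly = _mm256_set1_epi32(INT32_MIN);

  __m256i A = _mm256_mask_i32gather_epi32(Src, Table, Idx, AllOnes, 4);
  __m256i B = _mm256_mask_i32gather_epi32(Src, Table, Idx, SignBitOnly, 4);

  alignas(32) int32_t ResA[8], ResB[8];
  _mm256_store_si256(reinterpret_cast<__m256i *>(ResA), A);
  _mm256_store_si256(reinterpret_cast<__m256i *>(ResB), B);
  for (int i = 0; i < 8; ++i)
    assert(ResA[i] == ResB[i] && ResA[i] == Table[i]);  // every lane gathered
  return 0;
}

Since only the sign bit matters, SimplifyDemandedBits is free to discard whatever computed the low bits of each mask element, which is why the combine passes exactly APInt::getSignMask.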
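Note on the combineAdd hunk: it rests on a lane-wise identity. For a boolean b, zext(b) is 0 or 1 while sext(b) is 0 or -1, so zext(b) == -sext(b) and therefore (add (zext b), Y) == (sub Y, (sext b)). A minimal standalone C++ sketch of the identity (illustration only, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  // For each i1 lane value b: zext gives 0/1, sext gives 0/-1.
  for (int b = 0; b <= 1; ++b) {
    int32_t Zext = static_cast<int32_t>(b);   // zext i1 -> i32: 0 or 1
    int32_t Sext = -static_cast<int32_t>(b);  // sext i1 -> i32: 0 or -1
    const int32_t Ys[] = {-7, 0, 42};
    for (int32_t Y : Ys)
      assert(Y + Zext == Y - Sext);  // (add (zext b), Y) == (sub Y, (sext b))
  }
  return 0;
}

That is the saving visible in the test diff: with AVX-512 the compare result is already materialized in sign-extended (all-ones) form by vpternlogd/vpmovm2d, so the vpsrld $31 + vpaddd pair collapses into a single vpsubd.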