-rw-r--r--  llvm/lib/Target/X86/X86ISelLowering.cpp  44
-rw-r--r--  llvm/test/CodeGen/X86/avx512-vec-cmp.ll    6
2 files changed, 41 insertions, 9 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 77a7489a2b5..5741b80d1a7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -43364,9 +43364,22 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ // With vector masks we only demand the upper bit of the mask.
+ SDValue Mask = cast<X86MaskedGatherScatterSDNode>(N)->getMask();
+ if (Mask.getScalarValueSizeInBits() != 1) {
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
+ if (TLI.SimplifyDemandedBits(Mask, DemandedMask, DCI))
+ return SDValue(N, 0);
+ }
+
+ return SDValue();
+}
+
static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
- TargetLowering::DAGCombinerInfo &DCI,
- const X86Subtarget &Subtarget) {
+ TargetLowering::DAGCombinerInfo &DCI) {
SDLoc DL(N);
if (DCI.isBeforeLegalizeOps()) {
@@ -43426,7 +43439,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
}
// With vector masks we only demand the upper bit of the mask.
- SDValue Mask = N->getOperand(2);
+ SDValue Mask = cast<MaskedGatherScatterSDNode>(N)->getMask();
if (Mask.getScalarValueSizeInBits() != 1) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
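
A minimal standalone sketch (not LLVM code; the helper name is made up) of the reasoning shared by both mask combines above: an X86 masked gather/scatter treats a lane as active iff the sign bit of its mask element is set, so SimplifyDemandedBits can request only that bit and let whatever produces the low bits fold away.

// Standalone illustration (not LLVM code): for an X86 masked
// gather/scatter, a lane is active iff the sign bit of its mask
// element is set; the remaining bits are never read.
#include <cassert>
#include <cstdint>

// Hypothetical helper modelling how one 32-bit mask element is read.
static bool laneActive(uint32_t maskElt) {
  return (maskElt >> 31) != 0; // only the sign bit is demanded
}

int main() {
  // Masks that differ only in their low 31 bits select the same lanes,
  // which is why the combine may simplify whatever computes those bits.
  assert(laneActive(0x80000000u) == laneActive(0xFFFFFFFFu)); // both active
  assert(laneActive(0x00000000u) == laneActive(0x7FFFFFFFu)); // both inactive
  return 0;
}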
@@ -44465,6 +44478,27 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
HADDBuilder);
}
+ // If vectors of i1 are legal, turn (add (zext (vXi1 X)), Y) into
+ // (sub Y, (sext (vXi1 X))).
+ // FIXME: We have the (sub Y, (zext (vXi1 X))) -> (add (sext (vXi1 X)), Y) in
+ // generic DAG combine without a legal type check, but adding this there
+ // caused regressions.
+ if (Subtarget.hasAVX512() && VT.isVector()) {
+ if (Op0.getOpcode() == ISD::ZERO_EXTEND &&
+ Op0.getOperand(0).getValueType().getVectorElementType() == MVT::i1) {
+ SDLoc DL(N);
+ SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op0.getOperand(0));
+ return DAG.getNode(ISD::SUB, DL, VT, Op1, SExt);
+ }
+
+ if (Op1.getOpcode() == ISD::ZERO_EXTEND &&
+ Op1.getOperand(0).getValueType().getVectorElementType() == MVT::i1) {
+ SDLoc DL(N);
+ SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op1.getOperand(0));
+ return DAG.getNode(ISD::SUB, DL, VT, Op0, SExt);
+ }
+ }
+
return combineAddOrSubToADCOrSBB(N, DAG);
}
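
A minimal standalone check (illustrative, not part of the patch) of the identity the new combine relies on: zero-extending an i1 yields 0 or 1 while sign-extending it yields 0 or -1, so (add (zext b), Y) and (sub Y, (sext b)) agree lane by lane.

// Standalone check of the identity behind the combine:
//   add (zext (vXi1 X)), Y  ==  sub Y, (sext (vXi1 X))
// because zext of an i1 is 0 or 1 while sext of an i1 is 0 or -1.
#include <cassert>
#include <cstdint>

int main() {
  const int32_t y = 12345;
  for (bool b : {false, true}) {
    int32_t zext = b ? 1 : 0;  // zero-extended i1 lane
    int32_t sext = b ? -1 : 0; // sign-extended i1 lane
    assert(y + zext == y - sext);
  }
  return 0;
}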
@@ -45355,9 +45389,9 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case X86ISD::FMSUBADD: return combineFMADDSUB(N, DAG, Subtarget);
case X86ISD::MOVMSK: return combineMOVMSK(N, DAG, DCI, Subtarget);
case X86ISD::MGATHER:
- case X86ISD::MSCATTER:
+ case X86ISD::MSCATTER: return combineX86GatherScatter(N, DAG, DCI);
case ISD::MGATHER:
- case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI, Subtarget);
+ case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI);
case X86ISD::PCMPEQ:
case X86ISD::PCMPGT: return combineVectorCompare(N, DAG, Subtarget);
case X86ISD::PMULDQ:
diff --git a/llvm/test/CodeGen/X86/avx512-vec-cmp.ll b/llvm/test/CodeGen/X86/avx512-vec-cmp.ll
index b5fcc750855..88910fa1749 100644
--- a/llvm/test/CodeGen/X86/avx512-vec-cmp.ll
+++ b/llvm/test/CodeGen/X86/avx512-vec-cmp.ll
@@ -1414,8 +1414,7 @@ define <4 x i32> @zext_bool_logic(<4 x i64> %cond1, <4 x i64> %cond2, <4 x i32>
; AVX512-NEXT: vptestnmq %zmm1, %zmm1, %k1 ## encoding: [0x62,0xf2,0xf6,0x48,0x27,0xc9]
; AVX512-NEXT: korw %k1, %k0, %k1 ## encoding: [0xc5,0xfc,0x45,0xc9]
; AVX512-NEXT: vpternlogd $255, %zmm0, %zmm0, %zmm0 {%k1} {z} ## encoding: [0x62,0xf3,0x7d,0xc9,0x25,0xc0,0xff]
-; AVX512-NEXT: vpsrld $31, %xmm0, %xmm0 ## encoding: [0xc5,0xf9,0x72,0xd0,0x1f]
-; AVX512-NEXT: vpaddd %xmm2, %xmm0, %xmm0 ## encoding: [0xc5,0xf9,0xfe,0xc2]
+; AVX512-NEXT: vpsubd %xmm0, %xmm2, %xmm0 ## encoding: [0xc5,0xe9,0xfa,0xc0]
; AVX512-NEXT: vzeroupper ## encoding: [0xc5,0xf8,0x77]
; AVX512-NEXT: retq ## encoding: [0xc3]
;
@@ -1425,8 +1424,7 @@ define <4 x i32> @zext_bool_logic(<4 x i64> %cond1, <4 x i64> %cond2, <4 x i32>
; SKX-NEXT: vptestnmq %ymm1, %ymm1, %k1 ## encoding: [0x62,0xf2,0xf6,0x28,0x27,0xc9]
; SKX-NEXT: korw %k1, %k0, %k0 ## encoding: [0xc5,0xfc,0x45,0xc1]
; SKX-NEXT: vpmovm2d %k0, %xmm0 ## encoding: [0x62,0xf2,0x7e,0x08,0x38,0xc0]
-; SKX-NEXT: vpsrld $31, %xmm0, %xmm0 ## EVEX TO VEX Compression encoding: [0xc5,0xf9,0x72,0xd0,0x1f]
-; SKX-NEXT: vpaddd %xmm2, %xmm0, %xmm0 ## EVEX TO VEX Compression encoding: [0xc5,0xf9,0xfe,0xc2]
+; SKX-NEXT: vpsubd %xmm0, %xmm2, %xmm0 ## EVEX TO VEX Compression encoding: [0xc5,0xe9,0xfa,0xc0]
; SKX-NEXT: vzeroupper ## encoding: [0xc5,0xf8,0x77]
; SKX-NEXT: retq ## encoding: [0xc3]
%a = icmp eq <4 x i64> %cond1, zeroinitializer
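
A per-lane sketch (assumed values, not generated output) of why the test change above is sound: the masked, zeroing vpternlogd leaves each active lane as all ones and each inactive lane as zero, so the old vpsrld $31 plus vpaddd sequence and the new single vpsubd compute the same 32-bit result.

// Per-lane sketch: m is the value a masked, zeroing vpternlogd leaves in
// one 32-bit lane (all ones when the mask bit is set, zero otherwise).
// Old sequence: vpsrld $31 then vpaddd; new sequence: a single vpsubd.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t y = 7;
  for (uint32_t m : {0x00000000u, 0xFFFFFFFFu}) {
    uint32_t oldSeq = y + (m >> 31); // vpsrld $31, then vpaddd
    uint32_t newSeq = y - m;         // vpsubd
    assert(oldSeq == newSeq);
  }
  return 0;
}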