-rw-r--r--  llvm/lib/Target/X86/X86ISelLowering.cpp        | 74
-rw-r--r--  llvm/lib/Target/X86/X86InstrAVX512.td          | 14
-rw-r--r--  llvm/test/CodeGen/X86/masked_gather_scatter.ll | 39
3 files changed, 72 insertions(+), 55 deletions(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 16394f0edc6..ff7ea5d38f5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1390,6 +1390,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::ROTR, VT, Custom);
}
+ // Custom legalize 2x32 to get a little better code.
+ setOperationAction(ISD::MSCATTER, MVT::v2f32, Custom);
+ setOperationAction(ISD::MSCATTER, MVT::v2i32, Custom);
+
for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64,
MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 })
setOperationAction(ISD::MSCATTER, VT, Custom);
@@ -24322,33 +24326,55 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget,
SDValue Mask = N->getMask();
SDValue Chain = N->getChain();
SDValue BasePtr = N->getBasePtr();
- MVT MemVT = N->getMemoryVT().getSimpleVT();
+
+ if (VT == MVT::v2f32) {
+ assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type");
+ // If the index is v2i64 and we have VLX we can use xmm for data and index.
+ if (Index.getValueType() == MVT::v2i64 && Subtarget.hasVLX()) {
+ Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src,
+ DAG.getUNDEF(MVT::v2f32));
+ SDVTList VTs = DAG.getVTList(MVT::v2i1, MVT::Other);
+ SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
+ SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>(
+ VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand());
+ DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1));
+ return SDValue(NewScatter.getNode(), 1);
+ }
+ return SDValue();
+ }
+
+ if (VT == MVT::v2i32) {
+ assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type");
+ Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Src,
+ DAG.getUNDEF(MVT::v2i32));
+ // If the index is v2i64 and we have VLX we can use xmm for data and index.
+ if (Index.getValueType() == MVT::v2i64 && Subtarget.hasVLX()) {
+ SDVTList VTs = DAG.getVTList(MVT::v2i1, MVT::Other);
+ SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
+ SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>(
+ VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand());
+ DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1));
+ return SDValue(NewScatter.getNode(), 1);
+ }
+ // Custom widen all the operands to avoid promotion.
+ EVT NewIndexVT = EVT::getVectorVT(
+ *DAG.getContext(), Index.getValueType().getVectorElementType(), 4);
+ Index = DAG.getNode(ISD::CONCAT_VECTORS, dl, NewIndexVT, Index,
+ DAG.getUNDEF(Index.getValueType()));
+ Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask,
+ DAG.getConstant(0, dl, MVT::v2i1));
+ SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
+ return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), N->getMemoryVT(), dl,
+ Ops, N->getMemOperand());
+ }
+
MVT IndexVT = Index.getSimpleValueType();
MVT MaskVT = Mask.getSimpleValueType();
- if (MemVT.getScalarSizeInBits() < VT.getScalarSizeInBits()) {
- // The v2i32 value was promoted to v2i64.
- // Now we "redo" the type legalizer's work and widen the original
- // v2i32 value to v4i32. The original v2i32 is retrieved from v2i64
- // with a shuffle.
- assert((MemVT == MVT::v2i32 && VT == MVT::v2i64) &&
- "Unexpected memory type");
- int ShuffleMask[] = {0, 2, -1, -1};
- Src = DAG.getVectorShuffle(MVT::v4i32, dl, DAG.getBitcast(MVT::v4i32, Src),
- DAG.getUNDEF(MVT::v4i32), ShuffleMask);
- // Now we have 4 elements instead of 2.
- // Expand the index.
- MVT NewIndexVT = MVT::getVectorVT(IndexVT.getScalarType(), 4);
- Index = ExtendToType(Index, NewIndexVT, DAG);
-
- // Expand the mask with zeroes
- // Mask may be <2 x i64> or <2 x i1> at this moment
- assert((MaskVT == MVT::v2i1 || MaskVT == MVT::v2i64) &&
- "Unexpected mask type");
- MVT ExtMaskVT = MVT::getVectorVT(MaskVT.getScalarType(), 4);
- Mask = ExtendToType(Mask, ExtMaskVT, DAG, true);
- VT = MVT::v4i32;
- }
+ // If the index is v2i32, we're being called by type legalization and we
+ // should just let the default handling take care of it.
+ if (IndexVT == MVT::v2i32)
+ return SDValue();
unsigned NumElts = VT.getVectorNumElements();
if (!Subtarget.hasVLX() && !VT.is512BitVector() &&
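
For reference, a minimal IR sketch (function name hypothetical, modeled on test20 in the test diff below) of the v2f32 scatter that the new MVT::v2f32 path handles: with VLX and a v2i64 index, the data is widened to v4f32 and a single xmm-indexed vscatterqps is emitted; without VLX the lowering returns SDValue() and the default handling applies.

    declare void @llvm.masked.scatter.v2f32.v2p0f32(<2 x float>, <2 x float*>, i32, <2 x i1>)

    define void @scatter_v2f32(<2 x float> %val, <2 x float*> %ptrs, <2 x i1> %mask) {
      ; On SKX this now selects: vscatterqps %xmm0, (,%xmm1) {%k1}
      call void @llvm.masked.scatter.v2f32.v2p0f32(<2 x float> %val, <2 x float*> %ptrs, i32 4, <2 x i1> %mask)
      ret void
    }
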
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 058defdd1e2..e5fab871ac7 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -8610,16 +8610,17 @@ defm VPGATHER : avx512_gather_q_pd<0x90, 0x91, avx512vl_i64_info, "vpgather", "Q
avx512_gather_d_ps<0x90, 0x91, avx512vl_i32_info, "vpgather", "D">;
multiclass avx512_scatter<bits<8> opc, string OpcodeStr, X86VectorVTInfo _,
- X86MemOperand memop, PatFrag ScatterNode> {
+ X86MemOperand memop, PatFrag ScatterNode,
+ RegisterClass MaskRC = _.KRCWM> {
let mayStore = 1, Constraints = "$mask = $mask_wb", ExeDomain = _.ExeDomain in
- def mr : AVX5128I<opc, MRMDestMem, (outs _.KRCWM:$mask_wb),
- (ins memop:$dst, _.KRCWM:$mask, _.RC:$src),
+ def mr : AVX5128I<opc, MRMDestMem, (outs MaskRC:$mask_wb),
+ (ins memop:$dst, MaskRC:$mask, _.RC:$src),
!strconcat(OpcodeStr#_.Suffix,
"\t{$src, ${dst} {${mask}}|${dst} {${mask}}, $src}"),
- [(set _.KRCWM:$mask_wb, (ScatterNode (_.VT _.RC:$src),
- _.KRCWM:$mask, vectoraddr:$dst))]>,
+ [(set MaskRC:$mask_wb, (ScatterNode (_.VT _.RC:$src),
+ MaskRC:$mask, vectoraddr:$dst))]>,
EVEX, EVEX_K, EVEX_CD8<_.EltSize, CD8VT1>,
Sched<[WriteStore]>;
}
@@ -8656,7 +8657,8 @@ let Predicates = [HasVLX] in {
defm NAME##D##SUFF##Z128: avx512_scatter<dopc, OpcodeStr##"d", _.info128,
vx128xmem, mscatterv4i32>, EVEX_V128;
defm NAME##Q##SUFF##Z128: avx512_scatter<qopc, OpcodeStr##"q", _.info128,
- vx64xmem, mscatterv2i64>, EVEX_V128;
+ vx64xmem, mscatterv2i64, VK2WM>,
+ EVEX_V128;
}
}
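
The VK2WM parameter added above pairs with the new v2i32 lowering: a sketch (function name hypothetical; the intrinsic is the one declared by the test below) of the case the Z128 Q-form scatter now matches, with its mask in a two-bit register class rather than the full-width _.KRCWM.

    declare void @llvm.masked.scatter.v2i32.v2p0i32(<2 x i32>, <2 x i32*>, i32, <2 x i1>)

    define void @scatter_v2i32(<2 x i32> %val, <2 x i32*> %ptrs, <2 x i1> %mask) {
      ; With AVX512VL the v2i64 index hits the mscatterv2i64 pattern and
      ; selects vpscatterqd with the mask operand in a VK2WM register.
      call void @llvm.masked.scatter.v2i32.v2p0i32(<2 x i32> %val, <2 x i32*> %ptrs, i32 4, <2 x i1> %mask)
      ret void
    }
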
diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
index 2558efebbea..723b6e95741 100644
--- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -1094,11 +1094,9 @@ define void @test20(<2 x float>%a1, <2 x float*> %ptr, <2 x i1> %mask) {
;
; SKX-LABEL: test20:
; SKX: # %bb.0:
-; SKX-NEXT: # kill: def %xmm1 killed %xmm1 def %ymm1
; SKX-NEXT: vpsllq $63, %xmm2, %xmm2
; SKX-NEXT: vptestmq %xmm2, %xmm2, %k1
-; SKX-NEXT: vscatterqps %xmm0, (,%ymm1) {%k1}
-; SKX-NEXT: vzeroupper
+; SKX-NEXT: vscatterqps %xmm0, (,%xmm1) {%k1}
; SKX-NEXT: retq
;
; SKX_32-LABEL: test20:
@@ -1119,45 +1117,41 @@ define void @test21(<2 x i32>%a1, <2 x i32*> %ptr, <2 x i1>%mask) {
; KNL_64-NEXT: # kill: def %xmm1 killed %xmm1 def %zmm1
; KNL_64-NEXT: vpsllq $63, %xmm2, %xmm2
; KNL_64-NEXT: vptestmq %zmm2, %zmm2, %k0
+; KNL_64-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
; KNL_64-NEXT: kshiftlw $14, %k0, %k0
; KNL_64-NEXT: kshiftrw $14, %k0, %k1
-; KNL_64-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
; KNL_64-NEXT: vpscatterqd %ymm0, (,%zmm1) {%k1}
; KNL_64-NEXT: vzeroupper
; KNL_64-NEXT: retq
;
; KNL_32-LABEL: test21:
; KNL_32: # %bb.0:
-; KNL_32-NEXT: vpsllq $32, %xmm1, %xmm1
-; KNL_32-NEXT: vpsraq $32, %zmm1, %zmm1
; KNL_32-NEXT: vpsllq $63, %xmm2, %xmm2
; KNL_32-NEXT: vptestmq %zmm2, %zmm2, %k0
+; KNL_32-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; KNL_32-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
+; KNL_32-NEXT: vpmovsxdq %ymm1, %zmm1
; KNL_32-NEXT: kshiftlw $14, %k0, %k0
; KNL_32-NEXT: kshiftrw $14, %k0, %k1
-; KNL_32-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
; KNL_32-NEXT: vpscatterqd %ymm0, (,%zmm1) {%k1}
; KNL_32-NEXT: vzeroupper
; KNL_32-NEXT: retl
;
; SKX-LABEL: test21:
; SKX: # %bb.0:
-; SKX-NEXT: # kill: def %xmm1 killed %xmm1 def %ymm1
; SKX-NEXT: vpsllq $63, %xmm2, %xmm2
; SKX-NEXT: vptestmq %xmm2, %xmm2, %k1
; SKX-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; SKX-NEXT: vpscatterqd %xmm0, (,%ymm1) {%k1}
-; SKX-NEXT: vzeroupper
+; SKX-NEXT: vpscatterqd %xmm0, (,%xmm1) {%k1}
; SKX-NEXT: retq
;
; SKX_32-LABEL: test21:
; SKX_32: # %bb.0:
-; SKX_32-NEXT: vpsllq $32, %xmm1, %xmm1
-; SKX_32-NEXT: vpsraq $32, %xmm1, %xmm1
; SKX_32-NEXT: vpsllq $63, %xmm2, %xmm2
; SKX_32-NEXT: vptestmq %xmm2, %xmm2, %k1
; SKX_32-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; SKX_32-NEXT: vpscatterqd %xmm0, (,%ymm1) {%k1}
-; SKX_32-NEXT: vzeroupper
+; SKX_32-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
+; SKX_32-NEXT: vpscatterdd %xmm0, (,%xmm1) {%k1}
; SKX_32-NEXT: retl
call void @llvm.masked.scatter.v2i32.v2p0i32(<2 x i32> %a1, <2 x i32*> %ptr, i32 4, <2 x i1> %mask)
ret void
@@ -1594,9 +1588,9 @@ define void @test28(<2 x i32>%a1, <2 x i32*> %ptr) {
;
; KNL_32-LABEL: test28:
; KNL_32: # %bb.0:
-; KNL_32-NEXT: vpsllq $32, %xmm1, %xmm1
-; KNL_32-NEXT: vpsraq $32, %zmm1, %zmm1
; KNL_32-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; KNL_32-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
+; KNL_32-NEXT: vpmovsxdq %ymm1, %zmm1
; KNL_32-NEXT: movb $3, %al
; KNL_32-NEXT: kmovw %eax, %k1
; KNL_32-NEXT: vpscatterqd %ymm0, (,%zmm1) {%k1}
@@ -1605,23 +1599,18 @@ define void @test28(<2 x i32>%a1, <2 x i32*> %ptr) {
;
; SKX-LABEL: test28:
; SKX: # %bb.0:
-; SKX-NEXT: # kill: def %xmm1 killed %xmm1 def %ymm1
-; SKX-NEXT: movb $3, %al
-; SKX-NEXT: kmovw %eax, %k1
; SKX-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; SKX-NEXT: vpscatterqd %xmm0, (,%ymm1) {%k1}
-; SKX-NEXT: vzeroupper
+; SKX-NEXT: kxnorw %k0, %k0, %k1
+; SKX-NEXT: vpscatterqd %xmm0, (,%xmm1) {%k1}
; SKX-NEXT: retq
;
; SKX_32-LABEL: test28:
; SKX_32: # %bb.0:
-; SKX_32-NEXT: vpsllq $32, %xmm1, %xmm1
-; SKX_32-NEXT: vpsraq $32, %xmm1, %xmm1
; SKX_32-NEXT: movb $3, %al
; SKX_32-NEXT: kmovw %eax, %k1
; SKX_32-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; SKX_32-NEXT: vpscatterqd %xmm0, (,%ymm1) {%k1}
-; SKX_32-NEXT: vzeroupper
+; SKX_32-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
+; SKX_32-NEXT: vpscatterdd %xmm0, (,%xmm1) {%k1}
; SKX_32-NEXT: retl
call void @llvm.masked.scatter.v2i32.v2p0i32(<2 x i32> %a1, <2 x i32*> %ptr, i32 4, <2 x i1> <i1 true, i1 true>)
ret void
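
Finally, the widening path ("Custom widen all the operands to avoid promotion") is what the SKX_32 checks above exercise: with 32-bit pointers the index is v2i32, so data, index, and mask are concatenated out to four elements and a dword-indexed scatter is used. A sketch under those assumptions (function name hypothetical; the call itself matches test28 above):

    declare void @llvm.masked.scatter.v2i32.v2p0i32(<2 x i32>, <2 x i32*>, i32, <2 x i1>)

    define void @scatter_v2i32_allones(<2 x i32> %val, <2 x i32*> %ptrs) {
      ; On a 32-bit target with AVX512VL this now selects
      ; vpscatterdd %xmm0, (,%xmm1) {%k1}, the two-bit constant mask
      ; materialized as $3 in a k-register (movb $3 / kmovw).
      call void @llvm.masked.scatter.v2i32.v2p0i32(<2 x i32> %val, <2 x i32*> %ptrs, i32 4, <2 x i1> <i1 true, i1 true>)
      ret void
    }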