-rw-r--r--  llvm/lib/Target/X86/X86ISelLowering.cpp         74
-rw-r--r--  llvm/lib/Target/X86/X86InstrAVX512.td           14
-rw-r--r--  llvm/test/CodeGen/X86/masked_gather_scatter.ll  39
3 files changed, 72 insertions, 55 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 16394f0edc6..ff7ea5d38f5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1390,6 +1390,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
       setOperationAction(ISD::ROTR, VT, Custom);
     }
 
+    // Custom legalize 2x32 to get a little better code.
+    setOperationAction(ISD::MSCATTER, MVT::v2f32, Custom);
+    setOperationAction(ISD::MSCATTER, MVT::v2i32, Custom);
+
     for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64,
                      MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 })
       setOperationAction(ISD::MSCATTER, VT, Custom);
@@ -24322,33 +24326,55 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget,
   SDValue Mask = N->getMask();
   SDValue Chain = N->getChain();
   SDValue BasePtr = N->getBasePtr();
-  MVT MemVT = N->getMemoryVT().getSimpleVT();
+
+  if (VT == MVT::v2f32) {
+    assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type");
+    // If the index is v2i64 and we have VLX we can use xmm for data and index.
+    if (Index.getValueType() == MVT::v2i64 && Subtarget.hasVLX()) {
+      Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src,
+                        DAG.getUNDEF(MVT::v2f32));
+      SDVTList VTs = DAG.getVTList(MVT::v2i1, MVT::Other);
+      SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
+      SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>(
+          VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand());
+      DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1));
+      return SDValue(NewScatter.getNode(), 1);
+    }
+    return SDValue();
+  }
+
+  if (VT == MVT::v2i32) {
+    assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type");
+    Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Src,
+                      DAG.getUNDEF(MVT::v2i32));
+    // If the index is v2i64 and we have VLX we can use xmm for data and index.
+    if (Index.getValueType() == MVT::v2i64 && Subtarget.hasVLX()) {
+      SDVTList VTs = DAG.getVTList(MVT::v2i1, MVT::Other);
+      SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
+      SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>(
+          VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand());
+      DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1));
+      return SDValue(NewScatter.getNode(), 1);
+    }
+    // Custom widen all the operands to avoid promotion.
+    EVT NewIndexVT = EVT::getVectorVT(
+        *DAG.getContext(), Index.getValueType().getVectorElementType(), 4);
+    Index = DAG.getNode(ISD::CONCAT_VECTORS, dl, NewIndexVT, Index,
+                        DAG.getUNDEF(Index.getValueType()));
+    Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask,
+                       DAG.getConstant(0, dl, MVT::v2i1));
+    SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
+    return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), N->getMemoryVT(), dl,
+                                Ops, N->getMemOperand());
+  }
+
   MVT IndexVT = Index.getSimpleValueType();
   MVT MaskVT = Mask.getSimpleValueType();
 
-  if (MemVT.getScalarSizeInBits() < VT.getScalarSizeInBits()) {
-    // The v2i32 value was promoted to v2i64.
-    // Now we "redo" the type legalizer's work and widen the original
-    // v2i32 value to v4i32. The original v2i32 is retrieved from v2i64
-    // with a shuffle.
-    assert((MemVT == MVT::v2i32 && VT == MVT::v2i64) &&
-           "Unexpected memory type");
-    int ShuffleMask[] = {0, 2, -1, -1};
-    Src = DAG.getVectorShuffle(MVT::v4i32, dl, DAG.getBitcast(MVT::v4i32, Src),
-                               DAG.getUNDEF(MVT::v4i32), ShuffleMask);
-    // Now we have 4 elements instead of 2.
-    // Expand the index.
-    MVT NewIndexVT = MVT::getVectorVT(IndexVT.getScalarType(), 4);
-    Index = ExtendToType(Index, NewIndexVT, DAG);
-
-    // Expand the mask with zeroes
-    // Mask may be <2 x i64> or <2 x i1> at this moment
-    assert((MaskVT == MVT::v2i1 || MaskVT == MVT::v2i64) &&
-           "Unexpected mask type");
-    MVT ExtMaskVT = MVT::getVectorVT(MaskVT.getScalarType(), 4);
-    Mask = ExtendToType(Mask, ExtMaskVT, DAG, true);
-    VT = MVT::v4i32;
-  }
+  // If the index is v2i32, we're being called by type legalization and we
+  // should just let the default handling take care of it.
+  if (IndexVT == MVT::v2i32)
+    return SDValue();
 
   unsigned NumElts = VT.getVectorNumElements();
   if (!Subtarget.hasVLX() && !VT.is512BitVector() &&
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 058defdd1e2..e5fab871ac7 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -8610,16 +8610,17 @@ defm VPGATHER : avx512_gather_q_pd<0x90, 0x91, avx512vl_i64_info, "vpgather", "Q
                 avx512_gather_d_ps<0x90, 0x91, avx512vl_i32_info, "vpgather", "D">;
 
 multiclass avx512_scatter<bits<8> opc, string OpcodeStr, X86VectorVTInfo _,
-                          X86MemOperand memop, PatFrag ScatterNode> {
+                          X86MemOperand memop, PatFrag ScatterNode,
+                          RegisterClass MaskRC = _.KRCWM> {
 
 let mayStore = 1, Constraints = "$mask = $mask_wb", ExeDomain = _.ExeDomain in
 
-  def mr : AVX5128I<opc, MRMDestMem, (outs _.KRCWM:$mask_wb),
-            (ins memop:$dst, _.KRCWM:$mask, _.RC:$src),
+  def mr : AVX5128I<opc, MRMDestMem, (outs MaskRC:$mask_wb),
+            (ins memop:$dst, MaskRC:$mask, _.RC:$src),
             !strconcat(OpcodeStr#_.Suffix,
             "\t{$src, ${dst} {${mask}}|${dst} {${mask}}, $src}"),
-            [(set _.KRCWM:$mask_wb, (ScatterNode (_.VT _.RC:$src),
-                                     _.KRCWM:$mask, vectoraddr:$dst))]>,
+            [(set MaskRC:$mask_wb, (ScatterNode (_.VT _.RC:$src),
+                                    MaskRC:$mask, vectoraddr:$dst))]>,
             EVEX, EVEX_K, EVEX_CD8<_.EltSize, CD8VT1>,
             Sched<[WriteStore]>;
 }
@@ -8656,7 +8657,8 @@ let Predicates = [HasVLX] in {
   defm NAME##D##SUFF##Z128: avx512_scatter<dopc, OpcodeStr##"d", _.info128,
                                           vx128xmem, mscatterv4i32>, EVEX_V128;
   defm NAME##Q##SUFF##Z128: avx512_scatter<qopc, OpcodeStr##"q", _.info128,
-                                          vx64xmem, mscatterv2i64>, EVEX_V128;
+                                          vx64xmem, mscatterv2i64, VK2WM>,
+                                          EVEX_V128;
 }
 }
 
diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
index 2558efebbea..723b6e95741 100644
--- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -1094,11 +1094,9 @@ define void @test20(<2 x float>%a1, <2 x float*> %ptr, <2 x i1> %mask) {
 ;
 ; SKX-LABEL: test20:
 ; SKX:       # %bb.0:
-; SKX-NEXT:    # kill: def %xmm1 killed %xmm1 def %ymm1
 ; SKX-NEXT:    vpsllq $63, %xmm2, %xmm2
 ; SKX-NEXT:    vptestmq %xmm2, %xmm2, %k1
-; SKX-NEXT:    vscatterqps %xmm0, (,%ymm1) {%k1}
-; SKX-NEXT:    vzeroupper
+; SKX-NEXT:    vscatterqps %xmm0, (,%xmm1) {%k1}
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: test20:
@@ -1119,45 +1117,41 @@ define void @test21(<2 x i32>%a1, <2 x i32*> %ptr, <2 x i1>%mask) {
 ; KNL_64-NEXT:    # kill: def %xmm1 killed %xmm1 def %zmm1
 ; KNL_64-NEXT:    vpsllq $63, %xmm2, %xmm2
 ; KNL_64-NEXT:    vptestmq %zmm2, %zmm2, %k0
+; KNL_64-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
 ; KNL_64-NEXT:    kshiftlw $14, %k0, %k0
 ; KNL_64-NEXT:    kshiftrw $14, %k0, %k1
-; KNL_64-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
 ; KNL_64-NEXT:    vpscatterqd %ymm0, (,%zmm1) {%k1}
 ; KNL_64-NEXT:    vzeroupper
 ; KNL_64-NEXT:    retq
 ;
 ; KNL_32-LABEL: test21:
 ; KNL_32:       # %bb.0:
-; KNL_32-NEXT:    vpsllq $32, %xmm1, %xmm1
-; KNL_32-NEXT:    vpsraq $32, %zmm1, %zmm1
 ; KNL_32-NEXT:    vpsllq $63, %xmm2, %xmm2
 ; KNL_32-NEXT:    vptestmq %zmm2, %zmm2, %k0
+; KNL_32-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; KNL_32-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
+; KNL_32-NEXT:    vpmovsxdq %ymm1, %zmm1
 ; KNL_32-NEXT:    kshiftlw $14, %k0, %k0
 ; KNL_32-NEXT:    kshiftrw $14, %k0, %k1
-; KNL_32-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
 ; KNL_32-NEXT:    vpscatterqd %ymm0, (,%zmm1) {%k1}
 ; KNL_32-NEXT:    vzeroupper
 ; KNL_32-NEXT:    retl
 ;
 ; SKX-LABEL: test21:
 ; SKX:       # %bb.0:
-; SKX-NEXT:    # kill: def %xmm1 killed %xmm1 def %ymm1
 ; SKX-NEXT:    vpsllq $63, %xmm2, %xmm2
 ; SKX-NEXT:    vptestmq %xmm2, %xmm2, %k1
 ; SKX-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; SKX-NEXT:    vpscatterqd %xmm0, (,%ymm1) {%k1}
-; SKX-NEXT:    vzeroupper
+; SKX-NEXT:    vpscatterqd %xmm0, (,%xmm1) {%k1}
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: test21:
 ; SKX_32:       # %bb.0:
-; SKX_32-NEXT:    vpsllq $32, %xmm1, %xmm1
-; SKX_32-NEXT:    vpsraq $32, %xmm1, %xmm1
 ; SKX_32-NEXT:    vpsllq $63, %xmm2, %xmm2
 ; SKX_32-NEXT:    vptestmq %xmm2, %xmm2, %k1
 ; SKX_32-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; SKX_32-NEXT:    vpscatterqd %xmm0, (,%ymm1) {%k1}
-; SKX_32-NEXT:    vzeroupper
+; SKX_32-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
+; SKX_32-NEXT:    vpscatterdd %xmm0, (,%xmm1) {%k1}
 ; SKX_32-NEXT:    retl
   call void @llvm.masked.scatter.v2i32.v2p0i32(<2 x i32> %a1, <2 x i32*> %ptr, i32 4, <2 x i1> %mask)
   ret void
@@ -1594,9 +1588,9 @@ define void @test28(<2 x i32>%a1, <2 x i32*> %ptr) {
 ;
 ; KNL_32-LABEL: test28:
 ; KNL_32:       # %bb.0:
-; KNL_32-NEXT:    vpsllq $32, %xmm1, %xmm1
-; KNL_32-NEXT:    vpsraq $32, %zmm1, %zmm1
 ; KNL_32-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; KNL_32-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
+; KNL_32-NEXT:    vpmovsxdq %ymm1, %zmm1
 ; KNL_32-NEXT:    movb $3, %al
 ; KNL_32-NEXT:    kmovw %eax, %k1
 ; KNL_32-NEXT:    vpscatterqd %ymm0, (,%zmm1) {%k1}
@@ -1605,23 +1599,18 @@ define void @test28(<2 x i32>%a1, <2 x i32*> %ptr) {
 ;
 ; SKX-LABEL: test28:
 ; SKX:       # %bb.0:
-; SKX-NEXT:    # kill: def %xmm1 killed %xmm1 def %ymm1
-; SKX-NEXT:    movb $3, %al
-; SKX-NEXT:    kmovw %eax, %k1
 ; SKX-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; SKX-NEXT:    vpscatterqd %xmm0, (,%ymm1) {%k1}
-; SKX-NEXT:    vzeroupper
+; SKX-NEXT:    kxnorw %k0, %k0, %k1
+; SKX-NEXT:    vpscatterqd %xmm0, (,%xmm1) {%k1}
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: test28:
 ; SKX_32:       # %bb.0:
-; SKX_32-NEXT:    vpsllq $32, %xmm1, %xmm1
-; SKX_32-NEXT:    vpsraq $32, %xmm1, %xmm1
 ; SKX_32-NEXT:    movb $3, %al
 ; SKX_32-NEXT:    kmovw %eax, %k1
 ; SKX_32-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; SKX_32-NEXT:    vpscatterqd %xmm0, (,%ymm1) {%k1}
-; SKX_32-NEXT:    vzeroupper
+; SKX_32-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
+; SKX_32-NEXT:    vpscatterdd %xmm0, (,%xmm1) {%k1}
 ; SKX_32-NEXT:    retl
   call void @llvm.masked.scatter.v2i32.v2p0i32(<2 x i32> %a1, <2 x i32*> %ptr, i32 4, <2 x i1> <i1 true, i1 true>)
   ret void
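For reference, a minimal IR sketch of the v2f32 fast path this patch adds, written in the style of test20/test21 above. The function name is hypothetical; the v2f32 intrinsic declaration follows the same name mangling the test file already uses for v2i32, so treat the exact spelling as illustrative.

; Hypothetical reduced test in the style of masked_gather_scatter.ll.
; With VLX, the v2i64 pointer vector and the v2f32 data can both stay
; in xmm registers, so the implicit ymm widening ("# kill") and the
; trailing vzeroupper seen in the old SKX output are no longer needed.
declare void @llvm.masked.scatter.v2f32.v2p0f32(<2 x float>, <2 x float*>, i32, <2 x i1>)

define void @scatter_v2f32(<2 x float> %a1, <2 x float*> %ptr, <2 x i1> %mask) {
  call void @llvm.masked.scatter.v2f32.v2p0f32(<2 x float> %a1, <2 x float*> %ptr, i32 4, <2 x i1> %mask)
  ret void
}

This mirrors the updated SKX check lines for test20: the scatter is emitted as vscatterqps %xmm0, (,%xmm1) {%k1} directly on xmm registers.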