diff options
-rw-r--r-- | llvm/include/llvm/CodeGen/SelectionDAG.h | 7 | ||||
-rw-r--r-- | llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 9 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 12 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 23 | ||||
-rw-r--r-- | llvm/lib/Target/X86/X86InstrAVX512.td | 21 | ||||
-rw-r--r-- | llvm/lib/Target/X86/X86InstrFragmentsSIMD.td | 32 | ||||
-rw-r--r-- | llvm/lib/Target/X86/X86InstrSSE.td | 6 | ||||
-rw-r--r-- | llvm/test/CodeGen/X86/avx512vl-intrinsics.ll | 31 |
9 files changed, 97 insertions, 46 deletions
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index d934ddb8366..9f45cc82089 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -965,11 +965,12 @@ public: SDValue getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr, SDValue Mask, SDValue Src0, EVT MemVT, - MachineMemOperand *MMO, ISD::LoadExtType); + MachineMemOperand *MMO, ISD::LoadExtType, + bool IsExpanding = false); SDValue getMaskedStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr, SDValue Mask, EVT MemVT, - MachineMemOperand *MMO, bool IsTrunc, - bool isCompressing = false); + MachineMemOperand *MMO, bool IsTruncating = false, + bool IsCompressing = false); SDValue getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl, ArrayRef<SDValue> Ops, MachineMemOperand *MMO); SDValue getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl, diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index fb5a01fade8..1d14d1228ce 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -444,6 +444,7 @@ protected: uint16_t : NumLSBaseSDNodeBits; uint16_t ExtTy : 2; // enum ISD::LoadExtType + uint16_t IsExpanding : 1; }; class StoreSDNodeBitfields { @@ -473,7 +474,7 @@ protected: static_assert(sizeof(ConstantSDNodeBitfields) <= 2, "field too wide"); static_assert(sizeof(MemSDNodeBitfields) <= 2, "field too wide"); static_assert(sizeof(LSBaseSDNodeBitfields) <= 2, "field too wide"); - static_assert(sizeof(LoadSDNodeBitfields) <= 2, "field too wide"); + static_assert(sizeof(LoadSDNodeBitfields) <= 4, "field too wide"); static_assert(sizeof(StoreSDNodeBitfields) <= 2, "field too wide"); private: @@ -1939,9 +1940,11 @@ class MaskedLoadSDNode : public MaskedLoadStoreSDNode { public: friend class SelectionDAG; MaskedLoadSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, - ISD::LoadExtType ETy, EVT MemVT, MachineMemOperand *MMO) + ISD::LoadExtType ETy, bool IsExpanding, EVT MemVT, + MachineMemOperand *MMO) : MaskedLoadStoreSDNode(ISD::MLOAD, Order, dl, VTs, MemVT, MMO) { LoadSDNodeBits.ExtTy = ETy; + LoadSDNodeBits.IsExpanding = IsExpanding; } ISD::LoadExtType getExtensionType() const { @@ -1952,6 +1955,8 @@ public: static bool classof(const SDNode *N) { return N->getOpcode() == ISD::MLOAD; } + + bool isExpandingLoad() const { return LoadSDNodeBits.IsExpanding; } }; /// This class is used to represent an MSTORE node diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index a915fe161a3..3671422be2c 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -5347,7 +5347,7 @@ SDValue SelectionDAG::getIndexedStore(SDValue OrigStore, const SDLoc &dl, SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr, SDValue Mask, SDValue Src0, EVT MemVT, MachineMemOperand *MMO, - ISD::LoadExtType ExtTy) { + ISD::LoadExtType ExtTy, bool isExpanding) { SDVTList VTs = getVTList(VT, MVT::Other); SDValue Ops[] = { Chain, Ptr, Mask, Src0 }; @@ -5355,7 +5355,7 @@ SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain, AddNodeIDNode(ID, ISD::MLOAD, VTs, Ops); ID.AddInteger(VT.getRawBits()); ID.AddInteger(getSyntheticNodeSubclassData<MaskedLoadSDNode>( - dl.getIROrder(), VTs, ExtTy, MemVT, MMO)); + dl.getIROrder(), VTs, ExtTy, isExpanding, MemVT, MMO)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { @@ -5363,7 +5363,7 @@ SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain, return SDValue(E, 0); } auto *N = newSDNode<MaskedLoadSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs, - ExtTy, MemVT, MMO); + ExtTy, isExpanding, MemVT, MMO); createOperands(N, Ops); CSEMap.InsertNode(N, IP); @@ -5374,7 +5374,7 @@ SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr, SDValue Mask, EVT MemVT, MachineMemOperand *MMO, - bool isTrunc, bool isCompress) { + bool IsTruncating, bool IsCompressing) { assert(Chain.getValueType() == MVT::Other && "Invalid chain type"); EVT VT = Val.getValueType(); @@ -5384,7 +5384,7 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl, AddNodeIDNode(ID, ISD::MSTORE, VTs, Ops); ID.AddInteger(VT.getRawBits()); ID.AddInteger(getSyntheticNodeSubclassData<MaskedStoreSDNode>( - dl.getIROrder(), VTs, isTrunc, isCompress, MemVT, MMO)); + dl.getIROrder(), VTs, IsTruncating, IsCompressing, MemVT, MMO)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { @@ -5392,7 +5392,7 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl, return SDValue(E, 0); } auto *N = newSDNode<MaskedStoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs, - isTrunc, isCompress, MemVT, MMO); + IsTruncating, IsCompressing, MemVT, MMO); createOperands(N, Ops); CSEMap.InsertNode(N, IP); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 4ae94e0befa..2aaab4b0d87 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -3821,7 +3821,7 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) { Alignment, AAInfo, Ranges); SDValue Load = DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Mask, Src0, VT, MMO, - ISD::NON_EXTLOAD); + ISD::NON_EXTLOAD, false); if (AddToChain) { SDValue OutChain = Load.getValue(1); DAG.setRoot(OutChain); diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 6fbd9dcfd32..32c4ffe585a 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -18854,7 +18854,8 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); return DAG.getMaskedStore(Chain, dl, DataToCompress, Addr, VMask, VT, - MemIntr->getMemOperand(), false, true); + MemIntr->getMemOperand(), + false /* truncating */, true /* compressing */); } case TRUNCATE_TO_MEM_VI8: case TRUNCATE_TO_MEM_VI16: @@ -18877,7 +18878,7 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); return DAG.getMaskedStore(Chain, dl, DataToTruncate, Addr, VMask, VT, - MemIntr->getMemOperand(), true); + MemIntr->getMemOperand(), true /* truncating */); } case EXPAND_FROM_MEM: { SDValue Mask = Op.getOperand(4); @@ -18889,16 +18890,16 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, MemIntrinsicSDNode *MemIntr = dyn_cast<MemIntrinsicSDNode>(Op); assert(MemIntr && "Expected MemIntrinsicSDNode!"); - SDValue DataToExpand = DAG.getLoad(VT, dl, Chain, Addr, - MemIntr->getMemOperand()); + if (isAllOnesConstant(Mask)) // Return a regular (unmasked) vector load. + return DAG.getLoad(VT, dl, Chain, Addr, MemIntr->getMemOperand()); + if (X86::isZeroNode(Mask)) + return DAG.getUNDEF(VT); - if (isAllOnesConstant(Mask)) // return just a load - return DataToExpand; - - SDValue Results[] = { - getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, DataToExpand), - Mask, PassThru, Subtarget, DAG), Chain}; - return DAG.getMergeValues(Results, dl); + MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements()); + SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); + return DAG.getMaskedLoad(VT, dl, Chain, Addr, VMask, PassThru, VT, + MemIntr->getMemOperand(), ISD::NON_EXTLOAD, + true /* expanding */); } } } diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td index d0ef44c92ba..ea1c5b09bc8 100644 --- a/llvm/lib/Target/X86/X86InstrAVX512.td +++ b/llvm/lib/Target/X86/X86InstrAVX512.td @@ -7552,13 +7552,28 @@ multiclass expand_by_vec_width<bits<8> opc, X86VectorVTInfo _, AVX5128IBase, EVEX_CD8<_.EltSize, CD8VT1>; } +multiclass expand_by_vec_width_lowering<X86VectorVTInfo _ > { + + def : Pat<(_.VT (X86mExpandingLoad addr:$src, _.KRCWM:$mask, undef)), + (!cast<Instruction>(NAME#_.ZSuffix##rmkz) + _.KRCWM:$mask, addr:$src)>; + + def : Pat<(_.VT (X86mExpandingLoad addr:$src, _.KRCWM:$mask, + (_.VT _.RC:$src0))), + (!cast<Instruction>(NAME#_.ZSuffix##rmk) + _.RC:$src0, _.KRCWM:$mask, addr:$src)>; +} + multiclass expand_by_elt_width<bits<8> opc, string OpcodeStr, AVX512VLVectorVTInfo VTInfo> { - defm Z : expand_by_vec_width<opc, VTInfo.info512, OpcodeStr>, EVEX_V512; + defm Z : expand_by_vec_width<opc, VTInfo.info512, OpcodeStr>, + expand_by_vec_width_lowering<VTInfo.info512>, EVEX_V512; let Predicates = [HasVLX] in { - defm Z256 : expand_by_vec_width<opc, VTInfo.info256, OpcodeStr>, EVEX_V256; - defm Z128 : expand_by_vec_width<opc, VTInfo.info128, OpcodeStr>, EVEX_V128; + defm Z256 : expand_by_vec_width<opc, VTInfo.info256, OpcodeStr>, + expand_by_vec_width_lowering<VTInfo.info256>, EVEX_V256; + defm Z128 : expand_by_vec_width<opc, VTInfo.info128, OpcodeStr>, + expand_by_vec_width_lowering<VTInfo.info128>, EVEX_V128; } } diff --git a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td index 6eb5c9a45f7..f1b9475600f 100644 --- a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td +++ b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td @@ -919,30 +919,36 @@ def vinsert256_insert : PatFrag<(ops node:$bigvec, node:$smallvec, return X86::isVINSERT256Index(N); }], INSERT_get_vinsert256_imm>; -def masked_load_aligned128 : PatFrag<(ops node:$src1, node:$src2, node:$src3), +def X86mload : PatFrag<(ops node:$src1, node:$src2, node:$src3), (masked_load node:$src1, node:$src2, node:$src3), [{ - if (auto *Load = dyn_cast<MaskedLoadSDNode>(N)) - return Load->getAlignment() >= 16; - return false; + return !cast<MaskedLoadSDNode>(N)->isExpandingLoad() && + cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD; +}]>; + +def masked_load_aligned128 : PatFrag<(ops node:$src1, node:$src2, node:$src3), + (X86mload node:$src1, node:$src2, node:$src3), [{ + return cast<MaskedLoadSDNode>(N)->getAlignment() >= 16; }]>; def masked_load_aligned256 : PatFrag<(ops node:$src1, node:$src2, node:$src3), - (masked_load node:$src1, node:$src2, node:$src3), [{ - if (auto *Load = dyn_cast<MaskedLoadSDNode>(N)) - return Load->getAlignment() >= 32; - return false; + (X86mload node:$src1, node:$src2, node:$src3), [{ + return cast<MaskedLoadSDNode>(N)->getAlignment() >= 32; }]>; def masked_load_aligned512 : PatFrag<(ops node:$src1, node:$src2, node:$src3), - (masked_load node:$src1, node:$src2, node:$src3), [{ - if (auto *Load = dyn_cast<MaskedLoadSDNode>(N)) - return Load->getAlignment() >= 64; - return false; + (X86mload node:$src1, node:$src2, node:$src3), [{ + return cast<MaskedLoadSDNode>(N)->getAlignment() >= 64; }]>; def masked_load_unaligned : PatFrag<(ops node:$src1, node:$src2, node:$src3), (masked_load node:$src1, node:$src2, node:$src3), [{ - return isa<MaskedLoadSDNode>(N); + return !cast<MaskedLoadSDNode>(N)->isExpandingLoad() && + cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD; +}]>; + +def X86mExpandingLoad : PatFrag<(ops node:$src1, node:$src2, node:$src3), + (masked_load node:$src1, node:$src2, node:$src3), [{ + return cast<MaskedLoadSDNode>(N)->isExpandingLoad(); }]>; // Masked store fragments. diff --git a/llvm/lib/Target/X86/X86InstrSSE.td b/llvm/lib/Target/X86/X86InstrSSE.td index 8db144af9b0..b298a2b8812 100644 --- a/llvm/lib/Target/X86/X86InstrSSE.td +++ b/llvm/lib/Target/X86/X86InstrSSE.td @@ -8622,12 +8622,12 @@ multiclass maskmov_lowering<string InstrStr, RegisterClass RC, ValueType VT, def: Pat<(X86mstore addr:$ptr, (MaskVT RC:$mask), (VT RC:$src)), (!cast<Instruction>(InstrStr#"mr") addr:$ptr, RC:$mask, RC:$src)>; // masked load - def: Pat<(VT (masked_load addr:$ptr, (MaskVT RC:$mask), undef)), + def: Pat<(VT (X86mload addr:$ptr, (MaskVT RC:$mask), undef)), (!cast<Instruction>(InstrStr#"rm") RC:$mask, addr:$ptr)>; - def: Pat<(VT (masked_load addr:$ptr, (MaskVT RC:$mask), + def: Pat<(VT (X86mload addr:$ptr, (MaskVT RC:$mask), (VT (bitconvert (ZeroVT immAllZerosV))))), (!cast<Instruction>(InstrStr#"rm") RC:$mask, addr:$ptr)>; - def: Pat<(VT (masked_load addr:$ptr, (MaskVT RC:$mask), (VT RC:$src0))), + def: Pat<(VT (X86mload addr:$ptr, (MaskVT RC:$mask), (VT RC:$src0))), (!cast<Instruction>(BlendStr#"rr") RC:$src0, (!cast<Instruction>(InstrStr#"rm") RC:$mask, addr:$ptr), diff --git a/llvm/test/CodeGen/X86/avx512vl-intrinsics.ll b/llvm/test/CodeGen/X86/avx512vl-intrinsics.ll index 0ee213d8ba2..5e4e8fd529b 100644 --- a/llvm/test/CodeGen/X86/avx512vl-intrinsics.ll +++ b/llvm/test/CodeGen/X86/avx512vl-intrinsics.ll @@ -1042,6 +1042,29 @@ define <4 x i32> @expand10(<4 x i32> %data, i8 %mask) { declare <4 x i32> @llvm.x86.avx512.mask.expand.d.128(<4 x i32> %data, <4 x i32> %src0, i8 %mask) +define <8 x i64> @expand11(i8* %addr) { +; CHECK-LABEL: expand11: +; CHECK: ## BB#0: +; CHECK-NEXT: vmovups (%rdi), %zmm0 ## encoding: [0x62,0xf1,0x7c,0x48,0x10,0x07] +; CHECK-NEXT: retq ## encoding: [0xc3] + %res = call <8 x i64> @llvm.x86.avx512.mask.expand.load.q.512(i8* %addr, <8 x i64> undef, i8 -1) + ret <8 x i64> %res +} + +define <8 x i64> @expand12(i8* %addr, i8 %mask) { +; CHECK-LABEL: expand12: +; CHECK: ## BB#0: +; CHECK-NEXT: kmovw %esi, %k1 ## encoding: [0xc5,0xf8,0x92,0xce] +; CHECK-NEXT: vpexpandq (%rdi), %zmm0 {%k1} {z} ## encoding: [0x62,0xf2,0xfd,0xc9,0x89,0x07] +; CHECK-NEXT: retq ## encoding: [0xc3] + %laddr = bitcast i8* %addr to <8 x i64>* + %data = load <8 x i64>, <8 x i64>* %laddr, align 1 + %res = call <8 x i64> @llvm.x86.avx512.mask.expand.q.512(<8 x i64> %data, <8 x i64>zeroinitializer, i8 %mask) + ret <8 x i64> %res +} + +declare <8 x i64> @llvm.x86.avx512.mask.expand.q.512(<8 x i64> , <8 x i64>, i8) + define < 2 x i64> @test_mask_mul_epi32_rr_128(< 4 x i32> %a, < 4 x i32> %b) { ; CHECK-LABEL: test_mask_mul_epi32_rr_128: ; CHECK: ## BB#0: @@ -5250,9 +5273,9 @@ define <8 x i32>@test_int_x86_avx512_mask_psrav8_si_const() { ; CHECK: ## BB#0: ; CHECK-NEXT: vmovdqa32 {{.*#+}} ymm0 = [2,9,4294967284,23,4294967270,37,4294967256,51] ; CHECK-NEXT: ## encoding: [0x62,0xf1,0x7d,0x28,0x6f,0x05,A,A,A,A] -; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI309_0-4, kind: reloc_riprel_4byte +; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI311_0-4, kind: reloc_riprel_4byte ; CHECK-NEXT: vpsravd {{.*}}(%rip), %ymm0, %ymm0 ## encoding: [0x62,0xf2,0x7d,0x28,0x46,0x05,A,A,A,A] -; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI309_1-4, kind: reloc_riprel_4byte +; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI311_1-4, kind: reloc_riprel_4byte ; CHECK-NEXT: retq ## encoding: [0xc3] %res = call <8 x i32> @llvm.x86.avx512.mask.psrav8.si(<8 x i32> <i32 2, i32 9, i32 -12, i32 23, i32 -26, i32 37, i32 -40, i32 51>, <8 x i32> <i32 1, i32 18, i32 35, i32 52, i32 69, i32 15, i32 32, i32 49>, <8 x i32> zeroinitializer, i8 -1) ret <8 x i32> %res @@ -5283,9 +5306,9 @@ define <2 x i64>@test_int_x86_avx512_mask_psrav_q_128_const(i8 %x3) { ; CHECK: ## BB#0: ; CHECK-NEXT: vmovdqa64 {{.*#+}} xmm0 = [2,18446744073709551607] ; CHECK-NEXT: ## encoding: [0x62,0xf1,0xfd,0x08,0x6f,0x05,A,A,A,A] -; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI311_0-4, kind: reloc_riprel_4byte +; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI313_0-4, kind: reloc_riprel_4byte ; CHECK-NEXT: vpsravq {{.*}}(%rip), %xmm0, %xmm0 ## encoding: [0x62,0xf2,0xfd,0x08,0x46,0x05,A,A,A,A] -; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI311_1-4, kind: reloc_riprel_4byte +; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI313_1-4, kind: reloc_riprel_4byte ; CHECK-NEXT: retq ## encoding: [0xc3] %res = call <2 x i64> @llvm.x86.avx512.mask.psrav.q.128(<2 x i64> <i64 2, i64 -9>, <2 x i64> <i64 1, i64 90>, <2 x i64> zeroinitializer, i8 -1) ret <2 x i64> %res |