Diffstat (limited to 'llvm/lib/Target')
 llvm/lib/Target/X86/X86ISelDAGToDAG.cpp    |   2
 llvm/lib/Target/X86/X86ISelLowering.cpp    | 125
 llvm/lib/Target/X86/X86InstrAVX512.td      | 100
 llvm/lib/Target/X86/X86InstrVecCompiler.td |  25
 4 files changed, 123 insertions(+), 129 deletions(-)
diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index 660c1eff3c4..775cc79c653 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -460,7 +460,7 @@ static bool isLegalMaskCompare(SDNode *N, const X86Subtarget *Subtarget) {
     // this happens we will use 512-bit operations and the mask will not be
     // zero extended.
     EVT OpVT = N->getOperand(0).getValueType();
-    if (OpVT == MVT::v8i32 || OpVT == MVT::v8f32)
+    if (OpVT.is256BitVector() || OpVT.is128BitVector())
      return Subtarget->hasVLX();
 
    return true;
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index dab97501b85..56bdb3583a9 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1144,6 +1144,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     addRegisterClass(MVT::v8f64, &X86::VR512RegClass);
 
     addRegisterClass(MVT::v1i1, &X86::VK1RegClass);
+    addRegisterClass(MVT::v2i1, &X86::VK2RegClass);
+    addRegisterClass(MVT::v4i1, &X86::VK4RegClass);
     addRegisterClass(MVT::v8i1, &X86::VK8RegClass);
     addRegisterClass(MVT::v16i1, &X86::VK16RegClass);
 
@@ -1171,15 +1173,14 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
       setOperationAction(ISD::FP_TO_UINT, MVT::v2i1, Custom);
     }
 
-    // Extends of v16i1/v8i1 to 128-bit vectors.
-    setOperationAction(ISD::SIGN_EXTEND, MVT::v16i8, Custom);
-    setOperationAction(ISD::ZERO_EXTEND, MVT::v16i8, Custom);
-    setOperationAction(ISD::ANY_EXTEND, MVT::v16i8, Custom);
-    setOperationAction(ISD::SIGN_EXTEND, MVT::v8i16, Custom);
-    setOperationAction(ISD::ZERO_EXTEND, MVT::v8i16, Custom);
-    setOperationAction(ISD::ANY_EXTEND, MVT::v8i16, Custom);
+    // Extends of v16i1/v8i1/v4i1/v2i1 to 128-bit vectors.
+    for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
+      setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
+      setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
+      setOperationAction(ISD::ANY_EXTEND, VT, Custom);
+    }
 
-    for (auto VT : { MVT::v8i1, MVT::v16i1 }) {
+    for (auto VT : { MVT::v2i1, MVT::v4i1, MVT::v8i1, MVT::v16i1 }) {
       setOperationAction(ISD::ADD, VT, Custom);
       setOperationAction(ISD::SUB, VT, Custom);
       setOperationAction(ISD::MUL, VT, Custom);
@@ -1195,9 +1196,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     }
 
     setOperationAction(ISD::CONCAT_VECTORS, MVT::v16i1, Custom);
+    setOperationAction(ISD::CONCAT_VECTORS, MVT::v8i1, Custom);
+    setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i1, Custom);
+    setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v4i1, Custom);
     setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v8i1, Custom);
     setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v16i1, Custom);
-    for (auto VT : { MVT::v1i1, MVT::v8i1 })
+    for (auto VT : { MVT::v1i1, MVT::v2i1, MVT::v4i1, MVT::v8i1 })
       setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
 
     for (MVT VT : MVT::fp_vector_valuetypes())
@@ -1528,41 +1532,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
   }
 
   if (!Subtarget.useSoftFloat() && Subtarget.hasVLX()) {
-    addRegisterClass(MVT::v4i1, &X86::VK4RegClass);
-    addRegisterClass(MVT::v2i1, &X86::VK2RegClass);
-
-    for (auto VT : { MVT::v2i1, MVT::v4i1 }) {
-      setOperationAction(ISD::ADD, VT, Custom);
-      setOperationAction(ISD::SUB, VT, Custom);
-      setOperationAction(ISD::MUL, VT, Custom);
-      setOperationAction(ISD::VSELECT, VT, Expand);
-
-      setOperationAction(ISD::TRUNCATE, VT, Custom);
-      setOperationAction(ISD::SETCC, VT, Custom);
-      setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
-      setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
-      setOperationAction(ISD::SELECT, VT, Custom);
-      setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
-      setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
-    }
-
-    // TODO: v8i1 concat should be legal without VLX to support concats of
-    // v1i1, but we won't legalize it correctly currently without introducing
-    // a v4i1 concat in the middle.
-    setOperationAction(ISD::CONCAT_VECTORS, MVT::v8i1, Custom);
-    setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i1, Custom);
-    setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v4i1, Custom);
-    for (auto VT : { MVT::v2i1, MVT::v4i1 })
-      setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
-
-    // Extends from v2i1/v4i1 masks to 128-bit vectors.
-    setOperationAction(ISD::ZERO_EXTEND, MVT::v4i32, Custom);
-    setOperationAction(ISD::ZERO_EXTEND, MVT::v2i64, Custom);
-    setOperationAction(ISD::SIGN_EXTEND, MVT::v4i32, Custom);
-    setOperationAction(ISD::SIGN_EXTEND, MVT::v2i64, Custom);
-    setOperationAction(ISD::ANY_EXTEND, MVT::v4i32, Custom);
-    setOperationAction(ISD::ANY_EXTEND, MVT::v2i64, Custom);
-
     setTruncStoreAction(MVT::v4i64, MVT::v4i8, Legal);
     setTruncStoreAction(MVT::v4i64, MVT::v4i16, Legal);
     setTruncStoreAction(MVT::v4i64, MVT::v4i32, Legal);
@@ -4945,8 +4914,6 @@ static SDValue getZeroVector(MVT VT, const X86Subtarget &Subtarget,
   } else if (VT.getVectorElementType() == MVT::i1) {
     assert((Subtarget.hasBWI() || VT.getVectorNumElements() <= 16) &&
            "Unexpected vector type");
-    assert((Subtarget.hasVLX() || VT.getVectorNumElements() >= 8) &&
-           "Unexpected vector type");
     Vec = DAG.getConstant(0, dl, VT);
   } else {
     unsigned Num32BitElts = VT.getSizeInBits() / 32;
@@ -17779,6 +17746,19 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
   assert(EltVT == MVT::f32 || EltVT == MVT::f64);
 #endif
 
+  // Custom widen MVT::v2f32 to prevent the default widening
+  // from getting a result type of v4i32, extracting it to v2i32 and then
+  // trying to sign extend that to v2i1.
+  if (VT == MVT::v2i1 && Op1.getValueType() == MVT::v2f32) {
+    Op0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Op0,
+                      DAG.getUNDEF(MVT::v2f32));
+    Op1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Op1,
+                      DAG.getUNDEF(MVT::v2f32));
+    SDValue NewOp = DAG.getNode(ISD::SETCC, dl, MVT::v4i1, Op0, Op1, CC);
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i1, NewOp,
+                       DAG.getIntPtrConstant(0, dl));
+  }
+
   unsigned Opc;
   if (Subtarget.hasAVX512() && VT.getVectorElementType() == MVT::i1) {
     assert(VT.getVectorNumElements() <= 16);
@@ -24417,8 +24397,8 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget,
 
       // Mask
       // At this point we have promoted mask operand
-      assert(MaskVT.getScalarSizeInBits() >= 32 && "unexpected mask type");
-      MVT ExtMaskVT = MVT::getVectorVT(MaskVT.getScalarType(), NumElts);
+      assert(MaskVT.getScalarType() == MVT::i1 && "unexpected mask type");
+      MVT ExtMaskVT = MVT::getVectorVT(MVT::i1, NumElts);
       // Use the original mask here, do not modify the mask twice
       Mask = ExtendToType(N->getMask(), ExtMaskVT, DAG, true);
 
@@ -24427,12 +24407,9 @@
       Src = ExtendToType(Src, NewVT, DAG);
     }
   }
 
-  // If the mask is "wide" at this point - truncate it to i1 vector
-  MVT BitMaskVT = MVT::getVectorVT(MVT::i1, NumElts);
-  Mask = DAG.getNode(ISD::TRUNCATE, dl, BitMaskVT, Mask);
   // The mask is killed by scatter, add it to the values
-  SDVTList VTs = DAG.getVTList(BitMaskVT, MVT::Other);
+  SDVTList VTs = DAG.getVTList(Mask.getValueType(), MVT::Other);
   SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index};
   SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>(
       VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand());
@@ -24455,11 +24432,6 @@ static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget,
   assert((!N->isExpandingLoad() || ScalarVT.getSizeInBits() >= 32) &&
          "Expanding masked load is supported for 32 and 64-bit types only!");
 
-  // 4x32, 4x64 and 2x64 vectors of non-expanding loads are legal regardless of
-  // VLX. These types for exp-loads are handled here.
-  if (!N->isExpandingLoad() && VT.getVectorNumElements() <= 4)
-    return Op;
-
   assert(Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() &&
          "Cannot lower masked load op.");
 
@@ -24476,16 +24448,12 @@
   Src0 = ExtendToType(Src0, WideDataVT, DAG);
 
   // Mask element has to be i1.
-  MVT MaskEltTy = Mask.getSimpleValueType().getScalarType();
-  assert((MaskEltTy == MVT::i1 || VT.getVectorNumElements() <= 4) &&
-         "We handle 4x32, 4x64 and 2x64 vectors only in this case");
+  assert(Mask.getSimpleValueType().getScalarType() == MVT::i1 &&
+         "Unexpected mask type");
 
-  MVT WideMaskVT = MVT::getVectorVT(MaskEltTy, NumEltsInWideVec);
+  MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
   Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
-  if (MaskEltTy != MVT::i1)
-    Mask = DAG.getNode(ISD::TRUNCATE, dl,
-                       MVT::getVectorVT(MVT::i1, NumEltsInWideVec), Mask);
 
   SDValue NewLoad = DAG.getMaskedLoad(WideDataVT, dl, N->getChain(),
                                       N->getBasePtr(), Mask, Src0,
                                       N->getMemoryVT(), N->getMemOperand(),
@@ -24514,10 +24482,6 @@ static SDValue LowerMSTORE(SDValue Op, const X86Subtarget &Subtarget,
   assert((!N->isCompressingStore() || ScalarVT.getSizeInBits() >= 32) &&
         "Expanding masked load is supported for 32 and 64-bit types only!");
 
-  // 4x32 and 2x64 vectors of non-compressing stores are legal regardless to VLX.
-  if (!N->isCompressingStore() && VT.getVectorNumElements() <= 4)
-    return Op;
-
   assert(Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() &&
          "Cannot lower masked store op.");
 
@@ -24532,17 +24496,13 @@
   MVT WideDataVT = MVT::getVectorVT(ScalarVT, NumEltsInWideVec);
 
   // Mask element has to be i1.
-  MVT MaskEltTy = Mask.getSimpleValueType().getScalarType();
-  assert((MaskEltTy == MVT::i1 || VT.getVectorNumElements() <= 4) &&
-         "We handle 4x32, 4x64 and 2x64 vectors only in this case");
+  assert(Mask.getSimpleValueType().getScalarType() == MVT::i1 &&
+         "Unexpected mask type");
 
-  MVT WideMaskVT = MVT::getVectorVT(MaskEltTy, NumEltsInWideVec);
+  MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
   DataToStore = ExtendToType(DataToStore, WideDataVT, DAG);
   Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
-  if (MaskEltTy != MVT::i1)
-    Mask = DAG.getNode(ISD::TRUNCATE, dl,
-                       MVT::getVectorVT(MVT::i1, NumEltsInWideVec), Mask);
 
   return DAG.getMaskedStore(N->getChain(), dl, DataToStore, N->getBasePtr(),
                             Mask, N->getMemoryVT(), N->getMemOperand(),
                             N->isTruncatingStore(), N->isCompressingStore());
@@ -24592,12 +24552,9 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget,
     Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index);
 
    // Mask
-    MVT MaskBitVT = MVT::getVectorVT(MVT::i1, NumElts);
-    // At this point we have promoted mask operand
-    assert(MaskVT.getScalarSizeInBits() >= 32 && "unexpected mask type");
-    MVT ExtMaskVT = MVT::getVectorVT(MaskVT.getScalarType(), NumElts);
-    Mask = ExtendToType(Mask, ExtMaskVT, DAG, true);
-    Mask = DAG.getNode(ISD::TRUNCATE, dl, MaskBitVT, Mask);
+    assert(MaskVT.getScalarType() == MVT::i1 && "unexpected mask type");
+    MaskVT = MVT::getVectorVT(MVT::i1, NumElts);
+    Mask = ExtendToType(Mask, MaskVT, DAG, true);
 
     // The pass-through value
     MVT NewVT = MVT::getVectorVT(VT.getScalarType(), NumElts);
@@ -24605,7 +24562,7 @@
     SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
     SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
-        DAG.getVTList(NewVT, MaskBitVT, MVT::Other), Ops, dl, N->getMemoryVT(),
+        DAG.getVTList(NewVT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(),
         N->getMemOperand());
     SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT,
                                   NewGather.getValue(0),
@@ -30447,7 +30404,7 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
   // If this is a bitcast between a MVT::v4i1/v2i1 and an illegal integer
   // type, widen both sides to avoid a trip through memory.
   if ((VT == MVT::v4i1 || VT == MVT::v2i1) && SrcVT.isScalarInteger() &&
-      Subtarget.hasVLX()) {
+      Subtarget.hasAVX512()) {
     SDLoc dl(N);
     N0 = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i8, N0);
     N0 = DAG.getBitcast(MVT::v8i1, N0);
@@ -30458,7 +30415,7 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
   // If this is a bitcast between a MVT::v4i1/v2i1 and an illegal integer
   // type, widen both sides to avoid a trip through memory.
   if ((SrcVT == MVT::v4i1 || SrcVT == MVT::v2i1) && VT.isScalarInteger() &&
-      Subtarget.hasVLX()) {
+      Subtarget.hasAVX512()) {
     SDLoc dl(N);
     unsigned NumConcats = 8 / SrcVT.getVectorNumElements();
     SmallVector<SDValue, 4> Ops(NumConcats, DAG.getUNDEF(SrcVT));
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index a65e033b572..caf5091dac6 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -2962,46 +2962,77 @@ multiclass avx512_mask_shiftop_w<bits<8> opc1, bits<8> opc2, string OpcodeStr,
 defm KSHIFTL : avx512_mask_shiftop_w<0x32, 0x33, "kshiftl", X86kshiftl, SSE_PSHUF>;
 defm KSHIFTR : avx512_mask_shiftop_w<0x30, 0x31, "kshiftr", X86kshiftr, SSE_PSHUF>;
 
-multiclass axv512_icmp_packed_no_vlx_lowering<SDNode OpNode, string InstStr> {
-def : Pat<(v8i1 (OpNode (v8i32 VR256X:$src1), (v8i32 VR256X:$src2))),
-          (COPY_TO_REGCLASS (!cast<Instruction>(InstStr##Zrr)
-            (v16i32 (INSERT_SUBREG (IMPLICIT_DEF), VR256X:$src1, sub_ymm)),
-            (v16i32 (INSERT_SUBREG (IMPLICIT_DEF), VR256X:$src2, sub_ymm))), VK8)>;
-
-def : Pat<(v8i1 (and VK8:$mask,
-                     (OpNode (v8i32 VR256X:$src1), (v8i32 VR256X:$src2)))),
+multiclass axv512_icmp_packed_no_vlx_lowering<SDNode OpNode, string InstStr,
+                                              X86VectorVTInfo Narrow,
+                                              X86VectorVTInfo Wide> {
+def : Pat<(Narrow.KVT (OpNode (Narrow.VT Narrow.RC:$src1),
+                              (Narrow.VT Narrow.RC:$src2))),
+          (COPY_TO_REGCLASS
+           (!cast<Instruction>(InstStr##Zrr)
+            (Wide.VT (INSERT_SUBREG (IMPLICIT_DEF), Narrow.RC:$src1, Narrow.SubRegIdx)),
+            (Wide.VT (INSERT_SUBREG (IMPLICIT_DEF), Narrow.RC:$src2, Narrow.SubRegIdx))),
+           Narrow.KRC)>;
+
+def : Pat<(Narrow.KVT (and Narrow.KRC:$mask,
+                           (OpNode (Narrow.VT Narrow.RC:$src1),
+                                   (Narrow.VT Narrow.RC:$src2)))),
           (COPY_TO_REGCLASS (!cast<Instruction>(InstStr##Zrrk)
-                             (COPY_TO_REGCLASS VK8:$mask, VK16),
-            (v16i32 (INSERT_SUBREG (IMPLICIT_DEF), VR256X:$src1, sub_ymm)),
-            (v16i32 (INSERT_SUBREG (IMPLICIT_DEF), VR256X:$src2, sub_ymm))),
-          VK8)>;
+                             (COPY_TO_REGCLASS Narrow.KRC:$mask, Wide.KRC),
+            (Wide.VT (INSERT_SUBREG (IMPLICIT_DEF), Narrow.RC:$src1, Narrow.SubRegIdx)),
+            (Wide.VT (INSERT_SUBREG (IMPLICIT_DEF), Narrow.RC:$src2, Narrow.SubRegIdx))),
+          Narrow.KRC)>;
 }
 
 multiclass axv512_icmp_packed_cc_no_vlx_lowering<SDNode OpNode, string InstStr,
-                                                 AVX512VLVectorVTInfo _> {
-def : Pat<(v8i1 (OpNode (_.info256.VT VR256X:$src1), (_.info256.VT VR256X:$src2), imm:$cc)),
-          (COPY_TO_REGCLASS (!cast<Instruction>(InstStr##Zrri)
-            (_.info512.VT (INSERT_SUBREG (IMPLICIT_DEF), VR256X:$src1, sub_ymm)),
-            (_.info512.VT (INSERT_SUBREG (IMPLICIT_DEF), VR256X:$src2, sub_ymm)),
-            imm:$cc), VK8)>;
-
-def : Pat<(v8i1 (and VK8:$mask, (OpNode (_.info256.VT VR256X:$src1),
-                                        (_.info256.VT VR256X:$src2), imm:$cc))),
-          (COPY_TO_REGCLASS (!cast<Instruction>(InstStr##Zrrik)
-                             (COPY_TO_REGCLASS VK8:$mask, VK16),
-            (_.info512.VT (INSERT_SUBREG (IMPLICIT_DEF), VR256X:$src1, sub_ymm)),
-            (_.info512.VT (INSERT_SUBREG (IMPLICIT_DEF), VR256X:$src2, sub_ymm)),
-            imm:$cc), VK8)>;
+                                                 X86VectorVTInfo Narrow,
+                                                 X86VectorVTInfo Wide> {
+def : Pat<(Narrow.KVT (OpNode (Narrow.VT Narrow.RC:$src1),
+                              (Narrow.VT Narrow.RC:$src2), imm:$cc)),
+          (COPY_TO_REGCLASS
+           (!cast<Instruction>(InstStr##Zrri)
+            (Wide.VT (INSERT_SUBREG (IMPLICIT_DEF), Narrow.RC:$src1, Narrow.SubRegIdx)),
+            (Wide.VT (INSERT_SUBREG (IMPLICIT_DEF), Narrow.RC:$src2, Narrow.SubRegIdx)),
+            imm:$cc), Narrow.KRC)>;
+
+def : Pat<(Narrow.KVT (and Narrow.KRC:$mask,
+                       (OpNode (Narrow.VT Narrow.RC:$src1),
+                               (Narrow.VT Narrow.RC:$src2), imm:$cc))),
+          (COPY_TO_REGCLASS
+           (!cast<Instruction>(InstStr##Zrrik)
+            (COPY_TO_REGCLASS Narrow.KRC:$mask, Wide.KRC),
+            (Wide.VT (INSERT_SUBREG (IMPLICIT_DEF), Narrow.RC:$src1, Narrow.SubRegIdx)),
+            (Wide.VT (INSERT_SUBREG (IMPLICIT_DEF), Narrow.RC:$src2, Narrow.SubRegIdx)),
+            imm:$cc), Narrow.KRC)>;
 }
 
 let Predicates = [HasAVX512, NoVLX] in {
-  defm : axv512_icmp_packed_no_vlx_lowering<X86pcmpgtm, "VPCMPGTD">;
-  defm : axv512_icmp_packed_no_vlx_lowering<X86pcmpeqm, "VPCMPEQD">;
+  defm : axv512_icmp_packed_no_vlx_lowering<X86pcmpgtm, "VPCMPGTD", v8i32x_info, v16i32_info>;
+  defm : axv512_icmp_packed_no_vlx_lowering<X86pcmpeqm, "VPCMPEQD", v8i32x_info, v16i32_info>;
+
+  defm : axv512_icmp_packed_no_vlx_lowering<X86pcmpgtm, "VPCMPGTD", v4i32x_info, v16i32_info>;
+  defm : axv512_icmp_packed_no_vlx_lowering<X86pcmpeqm, "VPCMPEQD", v4i32x_info, v16i32_info>;
+
+  defm : axv512_icmp_packed_no_vlx_lowering<X86pcmpgtm, "VPCMPGTQ", v4i64x_info, v8i64_info>;
+  defm : axv512_icmp_packed_no_vlx_lowering<X86pcmpeqm, "VPCMPEQQ", v4i64x_info, v8i64_info>;
 
-  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpm, "VCMPPS", avx512vl_f32_info>;
-  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpm, "VPCMPD", avx512vl_i32_info>;
-  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpmu, "VPCMPUD", avx512vl_i32_info>;
+  defm : axv512_icmp_packed_no_vlx_lowering<X86pcmpgtm, "VPCMPGTQ", v2i64x_info, v8i64_info>;
+  defm : axv512_icmp_packed_no_vlx_lowering<X86pcmpeqm, "VPCMPEQQ", v2i64x_info, v8i64_info>;
+
+  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpm, "VCMPPS", v8f32x_info, v16f32_info>;
+  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpm, "VPCMPD", v8i32x_info, v16i32_info>;
+  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpmu, "VPCMPUD", v8i32x_info, v16i32_info>;
+
+  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpm, "VCMPPS", v4f32x_info, v16f32_info>;
+  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpm, "VPCMPD", v4i32x_info, v16i32_info>;
+  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpmu, "VPCMPUD", v4i32x_info, v16i32_info>;
+
+  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpm, "VCMPPD", v4f64x_info, v8f64_info>;
+  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpm, "VPCMPQ", v4i64x_info, v8i64_info>;
+  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpmu, "VPCMPUQ", v4i64x_info, v8i64_info>;
+
+  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpm, "VCMPPD", v2f64x_info, v8f64_info>;
+  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpm, "VPCMPQ", v2i64x_info, v8i64_info>;
+  defm : axv512_icmp_packed_cc_no_vlx_lowering<X86cmpmu, "VPCMPUQ", v2i64x_info, v8i64_info>;
 }
 
 // Mask setting all 0s or 1s
@@ -3376,8 +3407,15 @@ multiclass mask_move_lowering<string InstrStr, X86VectorVTInfo Narrow,
 // Patterns for handling v8i1 selects of 256-bit vectors when VLX isn't
 // available. Use a 512-bit operation and extract.
 let Predicates = [HasAVX512, NoVLX] in {
+  defm : mask_move_lowering<"VMOVAPSZ", v4f32x_info, v16f32_info>;
+  defm : mask_move_lowering<"VMOVDQA32Z", v4i32x_info, v16i32_info>;
   defm : mask_move_lowering<"VMOVAPSZ", v8f32x_info, v16f32_info>;
   defm : mask_move_lowering<"VMOVDQA32Z", v8i32x_info, v16i32_info>;
+
+  defm : mask_move_lowering<"VMOVAPDZ", v2f64x_info, v8f64_info>;
+  defm : mask_move_lowering<"VMOVDQA64Z", v2i64x_info, v8i64_info>;
+  defm : mask_move_lowering<"VMOVAPDZ", v4f64x_info, v8f64_info>;
+  defm : mask_move_lowering<"VMOVDQA64Z", v4i64x_info, v8i64_info>;
 }
 
 let Predicates = [HasAVX512] in {
diff --git a/llvm/lib/Target/X86/X86InstrVecCompiler.td b/llvm/lib/Target/X86/X86InstrVecCompiler.td
index c1cb4dcb16b..ed3e83f7848 100644
--- a/llvm/lib/Target/X86/X86InstrVecCompiler.td
+++ b/llvm/lib/Target/X86/X86InstrVecCompiler.td
@@ -495,6 +495,18 @@ let Predicates = [HasBWI, HasVLX] in {
 
 // If the bits are not zero we have to fall back to explicitly zeroing by
 // using shifts.
+let Predicates = [HasAVX512] in {
+  def : Pat<(v16i1 (insert_subvector (v16i1 immAllZerosV),
+                                     (v2i1 VK2:$mask), (iPTR 0))),
+            (KSHIFTRWri (KSHIFTLWri (COPY_TO_REGCLASS VK2:$mask, VK16),
+                                    (i8 14)), (i8 14))>;
+
+  def : Pat<(v16i1 (insert_subvector (v16i1 immAllZerosV),
+                                     (v4i1 VK4:$mask), (iPTR 0))),
+            (KSHIFTRWri (KSHIFTLWri (COPY_TO_REGCLASS VK4:$mask, VK16),
+                                    (i8 12)), (i8 12))>;
+}
+
 let Predicates = [HasAVX512, NoDQI] in {
   def : Pat<(v16i1 (insert_subvector (v16i1 immAllZerosV),
                                      (v8i1 VK8:$mask), (iPTR 0))),
@@ -506,9 +518,7 @@ let Predicates = [HasDQI] in {
   def : Pat<(v16i1 (insert_subvector (v16i1 immAllZerosV),
                                      (v8i1 VK8:$mask), (iPTR 0))),
             (COPY_TO_REGCLASS (KMOVBkk VK8:$mask), VK16)>;
-}
 
-let Predicates = [HasVLX, HasDQI] in {
   def : Pat<(v8i1 (insert_subvector (v8i1 immAllZerosV),
                                     (v2i1 VK2:$mask), (iPTR 0))),
             (KSHIFTRBri (KSHIFTLBri (COPY_TO_REGCLASS VK2:$mask, VK8),
@@ -519,17 +529,6 @@ let Predicates = [HasVLX, HasDQI] in {
                                     (i8 4)), (i8 4))>;
 }
 
-let Predicates = [HasVLX] in {
-  def : Pat<(v16i1 (insert_subvector (v16i1 immAllZerosV),
-                                     (v2i1 VK2:$mask), (iPTR 0))),
-            (KSHIFTRWri (KSHIFTLWri (COPY_TO_REGCLASS VK2:$mask, VK16),
-                                    (i8 14)), (i8 14))>;
-  def : Pat<(v16i1 (insert_subvector (v16i1 immAllZerosV),
-                                     (v4i1 VK4:$mask), (iPTR 0))),
-            (KSHIFTRWri (KSHIFTLWri (COPY_TO_REGCLASS VK4:$mask, VK16),
-                                    (i8 12)), (i8 12))>;
-}
-
 let Predicates = [HasBWI] in {
   def : Pat<(v32i1 (insert_subvector (v32i1 immAllZerosV),
                                      (v16i1 VK16:$mask), (iPTR 0))),
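A note on the widening strategy the axv512_icmp_packed_no_vlx_lowering patterns encode: with AVX512F but no AVX512VL, only the 512-bit forms of the compare instructions can write a mask register, so a 128-bit or 256-bit compare is performed by inserting the operands into zmm registers (upper lanes left undefined, matching INSERT_SUBREG of IMPLICIT_DEF) and keeping only the low bits of the resulting mask. A minimal intrinsics sketch of the same idea, assuming only AVX512F (compile with -mavx512f); the helper name is hypothetical and not part of this patch:

    #include <immintrin.h>

    // Widen 256-bit operands to 512 bits (upper lanes undefined), compare
    // with the zmm form of VPCMPEQD, and keep only the low 8 bits of the
    // k-register, which correspond to the original ymm lanes.
    static inline __mmask8 cmpeq_epi32_ymm_no_vlx(__m256i a, __m256i b) {
      __m512i wa = _mm512_castsi256_si512(a);
      __m512i wb = _mm512_castsi256_si512(b);
      __mmask16 k = _mm512_cmpeq_epi32_mask(wa, wb);
      return (__mmask8)k;
    }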

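Similarly, the KSHIFTLWri/KSHIFTRWri pairs added to X86InstrVecCompiler.td zero-extend a narrow mask into a wider all-zero mask: shifting a v2i1 mask up by 14 and back down by 14 within a 16-bit k-register clears bits [15:2] (the shift amount is 12 for v4i1). A scalar model of the idiom, for illustration only; the function name is made up:

    #include <cstdint>

    // Zero-extend the low 2 mask bits into a 16-bit mask: shift the live
    // bits to the top, then shift them back, leaving bits [15:2] zero.
    static inline uint16_t zext_mask2_into_mask16(uint16_t k) {
      return (uint16_t)((uint16_t)(k << 14) >> 14); // same as k & 0x3
    }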
