diff options
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 4 | ||||
-rw-r--r-- | llvm/lib/Target/X86/X86InstrAVX512.td | 12 |
2 files changed, 14 insertions, 2 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 99893be4e60..9e7a41c752a 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -19697,7 +19697,7 @@ static SDValue LowerSIGN_EXTEND_Mask(SDValue Op, if (!Subtarget.hasBWI() && VTElt.getSizeInBits() <= 16) { // If v16i32 is to be avoided, we'll need to split and concatenate. if (NumElts == 16 && !Subtarget.canExtendTo512DQ()) - return SplitAndExtendv16i1(ISD::SIGN_EXTEND, VT, In, dl, DAG); + return SplitAndExtendv16i1(Op.getOpcode(), VT, In, dl, DAG); ExtVT = MVT::getVectorVT(MVT::i32, NumElts); } @@ -19716,7 +19716,7 @@ static SDValue LowerSIGN_EXTEND_Mask(SDValue Op, MVT WideEltVT = WideVT.getVectorElementType(); if ((Subtarget.hasDQI() && WideEltVT.getSizeInBits() >= 32) || (Subtarget.hasBWI() && WideEltVT.getSizeInBits() <= 16)) { - V = DAG.getNode(ISD::SIGN_EXTEND, dl, WideVT, In); + V = DAG.getNode(Op.getOpcode(), dl, WideVT, In); } else { SDValue NegOne = getOnesVector(WideVT, DAG, dl); SDValue Zero = getZeroVector(WideVT, Subtarget, DAG, dl); diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td index ec314f329fd..f8ade37f8df 100644 --- a/llvm/lib/Target/X86/X86InstrAVX512.td +++ b/llvm/lib/Target/X86/X86InstrAVX512.td @@ -9958,6 +9958,10 @@ def rr : AVX512XS8I<opc, MRMSrcReg, (outs Vec.RC:$dst), (ins Vec.KRC:$src), !strconcat(OpcodeStr##Vec.Suffix, "\t{$src, $dst|$dst, $src}"), [(set Vec.RC:$dst, (Vec.VT (sext Vec.KRC:$src)))]>, EVEX, Sched<[WriteMove]>; // TODO - WriteVecTrunc? + +// Also need a pattern for anyextend. +def : Pat<(Vec.VT (anyext Vec.KRC:$src)), + (!cast<Instruction>(NAME#"rr") Vec.KRC:$src)>; } multiclass cvt_mask_by_elt_width<bits<8> opc, AVX512VLVectorVTInfo VTInfo, @@ -10031,11 +10035,19 @@ let Predicates = [HasDQI, NoBWI] in { (VPMOVDBZrr (v16i32 (VPMOVM2DZrr VK16:$src)))>; def : Pat<(v16i16 (sext (v16i1 VK16:$src))), (VPMOVDWZrr (v16i32 (VPMOVM2DZrr VK16:$src)))>; + + def : Pat<(v16i8 (anyext (v16i1 VK16:$src))), + (VPMOVDBZrr (v16i32 (VPMOVM2DZrr VK16:$src)))>; + def : Pat<(v16i16 (anyext (v16i1 VK16:$src))), + (VPMOVDWZrr (v16i32 (VPMOVM2DZrr VK16:$src)))>; } let Predicates = [HasDQI, NoBWI, HasVLX] in { def : Pat<(v8i16 (sext (v8i1 VK8:$src))), (VPMOVDWZ256rr (v8i32 (VPMOVM2DZ256rr VK8:$src)))>; + + def : Pat<(v8i16 (anyext (v8i1 VK8:$src))), + (VPMOVDWZ256rr (v8i32 (VPMOVM2DZ256rr VK8:$src)))>; } //===----------------------------------------------------------------------===// |