diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 113 |
1 files changed, 38 insertions, 75 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index d93ec99ed89..44815757515 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -24385,47 +24385,32 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget, } MVT IndexVT = Index.getSimpleValueType(); + MVT MaskVT = Mask.getSimpleValueType(); // If the index is v2i32, we're being called by type legalization and we // should just let the default handling take care of it. if (IndexVT == MVT::v2i32) return SDValue(); - unsigned NumElts = VT.getVectorNumElements(); + // If we don't have VLX and neither the passthru or index is 512-bits, we + // need to widen until one is. if (!Subtarget.hasVLX() && !VT.is512BitVector() && !Index.getSimpleValueType().is512BitVector()) { - // AVX512F supports only 512-bit vectors. Or data or index should - // be 512 bit wide. If now the both index and data are 256-bit, but - // the vector contains 8 elements, we just sign-extend the index - if (IndexVT == MVT::v8i32) - // Just extend index - Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index); - else { - // The minimal number of elts in scatter is 8 - NumElts = 8; - // Index - MVT NewIndexVT = MVT::getVectorVT(IndexVT.getScalarType(), NumElts); - // Use original index here, do not modify the index twice - Index = ExtendToType(N->getIndex(), NewIndexVT, DAG); - if (IndexVT.getScalarType() == MVT::i32) - Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index); - - // Mask - // At this point we have promoted mask operand - assert(Mask.getValueType().getScalarType() == MVT::i1 && - "unexpected mask type"); - MVT ExtMaskVT = MVT::getVectorVT(MVT::i1, NumElts); - // Use the original mask here, do not modify the mask twice - Mask = ExtendToType(N->getMask(), ExtMaskVT, DAG, true); - - // The value that should be stored - MVT NewVT = MVT::getVectorVT(VT.getScalarType(), NumElts); - Src = ExtendToType(Src, NewVT, DAG); - } - } - - // The mask is killed by scatter, add it to the values - SDVTList VTs = DAG.getVTList(Mask.getValueType(), MVT::Other); + // Determine how much we need to widen by to get a 512-bit type. + unsigned Factor = std::min(512/VT.getSizeInBits(), + 512/IndexVT.getSizeInBits()); + unsigned NumElts = VT.getVectorNumElements() * Factor; + + VT = MVT::getVectorVT(VT.getVectorElementType(), NumElts); + IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(), NumElts); + MaskVT = MVT::getVectorVT(MVT::i1, NumElts); + + Src = ExtendToType(Src, VT, DAG); + Index = ExtendToType(Index, IndexVT, DAG); + Mask = ExtendToType(Mask, MaskVT, DAG, true); + } + + SDVTList VTs = DAG.getVTList(MaskVT, MVT::Other); SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale}; SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>( VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand()); @@ -24532,68 +24517,46 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget, MaskedGatherSDNode *N = cast<MaskedGatherSDNode>(Op.getNode()); SDLoc dl(Op); MVT VT = Op.getSimpleValueType(); - SDValue Scale = N->getScale(); SDValue Index = N->getIndex(); SDValue Mask = N->getMask(); SDValue Src0 = N->getValue(); MVT IndexVT = Index.getSimpleValueType(); MVT MaskVT = Mask.getSimpleValueType(); - unsigned NumElts = VT.getVectorNumElements(); assert(VT.getScalarSizeInBits() >= 32 && "Unsupported gather op"); // If the index is v2i32, we're being called by type legalization. if (IndexVT == MVT::v2i32) return SDValue(); + // If we don't have VLX and neither the passthru or index is 512-bits, we + // need to widen until one is. + MVT OrigVT = VT; if (Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() && - !Index.getSimpleValueType().is512BitVector()) { - // AVX512F supports only 512-bit vectors. Or data or index should - // be 512 bit wide. If now the both index and data are 256-bit, but - // the vector contains 8 elements, we just sign-extend the index - if (NumElts == 8) { - Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index); - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, - Scale }; - SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( - DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(), - N->getMemOperand()); - return DAG.getMergeValues({NewGather, NewGather.getValue(2)}, dl); - } - - // Minimal number of elements in Gather - NumElts = 8; - // Index - MVT NewIndexVT = MVT::getVectorVT(IndexVT.getScalarType(), NumElts); - Index = ExtendToType(Index, NewIndexVT, DAG); - if (IndexVT.getScalarType() == MVT::i32) - Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index); - - // Mask - assert(MaskVT.getScalarType() == MVT::i1 && "unexpected mask type"); - MaskVT = MVT::getVectorVT(MVT::i1, NumElts); - Mask = ExtendToType(Mask, MaskVT, DAG, true); + !IndexVT.is512BitVector()) { + // Determine how much we need to widen by to get a 512-bit type. + unsigned Factor = std::min(512/VT.getSizeInBits(), + 512/IndexVT.getSizeInBits()); + + unsigned NumElts = VT.getVectorNumElements() * Factor; - // The pass-through value - MVT NewVT = MVT::getVectorVT(VT.getScalarType(), NumElts); - Src0 = ExtendToType(Src0, NewVT, DAG); + VT = MVT::getVectorVT(VT.getVectorElementType(), NumElts); + IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(), NumElts); + MaskVT = MVT::getVectorVT(MVT::i1, NumElts); - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale }; - SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( - DAG.getVTList(NewVT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(), - N->getMemOperand()); - SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, - NewGather.getValue(0), - DAG.getIntPtrConstant(0, dl)); - SDValue RetOps[] = {Extract, NewGather.getValue(2)}; - return DAG.getMergeValues(RetOps, dl); + Src0 = ExtendToType(Src0, VT, DAG); + Index = ExtendToType(Index, IndexVT, DAG); + Mask = ExtendToType(Mask, MaskVT, DAG, true); } - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale }; + SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, + N->getScale() }; SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(), N->getMemOperand()); - return DAG.getMergeValues({NewGather, NewGather.getValue(2)}, dl); + SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OrigVT, + NewGather, DAG.getIntPtrConstant(0, dl)); + return DAG.getMergeValues({Extract, NewGather.getValue(2)}, dl); } SDValue X86TargetLowering::LowerGC_TRANSITION_START(SDValue Op, |

