diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 74 |
1 files changed, 50 insertions, 24 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 16394f0edc6..ff7ea5d38f5 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -1390,6 +1390,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::ROTR, VT, Custom); } + // Custom legalize 2x32 to get a little better code. + setOperationAction(ISD::MSCATTER, MVT::v2f32, Custom); + setOperationAction(ISD::MSCATTER, MVT::v2i32, Custom); + for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64, MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 }) setOperationAction(ISD::MSCATTER, VT, Custom); @@ -24322,33 +24326,55 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget, SDValue Mask = N->getMask(); SDValue Chain = N->getChain(); SDValue BasePtr = N->getBasePtr(); - MVT MemVT = N->getMemoryVT().getSimpleVT(); + + if (VT == MVT::v2f32) { + assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type"); + // If the index is v2i64 and we have VLX we can use xmm for data and index. + if (Index.getValueType() == MVT::v2i64 && Subtarget.hasVLX()) { + Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src, + DAG.getUNDEF(MVT::v2f32)); + SDVTList VTs = DAG.getVTList(MVT::v2i1, MVT::Other); + SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale}; + SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>( + VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand()); + DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1)); + return SDValue(NewScatter.getNode(), 1); + } + return SDValue(); + } + + if (VT == MVT::v2i32) { + assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type"); + Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Src, + DAG.getUNDEF(MVT::v2i32)); + // If the index is v2i64 and we have VLX we can use xmm for data and index. + if (Index.getValueType() == MVT::v2i64 && Subtarget.hasVLX()) { + SDVTList VTs = DAG.getVTList(MVT::v2i1, MVT::Other); + SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale}; + SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>( + VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand()); + DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1)); + return SDValue(NewScatter.getNode(), 1); + } + // Custom widen all the operands to avoid promotion. + EVT NewIndexVT = EVT::getVectorVT( + *DAG.getContext(), Index.getValueType().getVectorElementType(), 4); + Index = DAG.getNode(ISD::CONCAT_VECTORS, dl, NewIndexVT, Index, + DAG.getUNDEF(Index.getValueType())); + Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask, + DAG.getConstant(0, dl, MVT::v2i1)); + SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale}; + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), N->getMemoryVT(), dl, + Ops, N->getMemOperand()); + } + MVT IndexVT = Index.getSimpleValueType(); MVT MaskVT = Mask.getSimpleValueType(); - if (MemVT.getScalarSizeInBits() < VT.getScalarSizeInBits()) { - // The v2i32 value was promoted to v2i64. - // Now we "redo" the type legalizer's work and widen the original - // v2i32 value to v4i32. The original v2i32 is retrieved from v2i64 - // with a shuffle. - assert((MemVT == MVT::v2i32 && VT == MVT::v2i64) && - "Unexpected memory type"); - int ShuffleMask[] = {0, 2, -1, -1}; - Src = DAG.getVectorShuffle(MVT::v4i32, dl, DAG.getBitcast(MVT::v4i32, Src), - DAG.getUNDEF(MVT::v4i32), ShuffleMask); - // Now we have 4 elements instead of 2. - // Expand the index. - MVT NewIndexVT = MVT::getVectorVT(IndexVT.getScalarType(), 4); - Index = ExtendToType(Index, NewIndexVT, DAG); - - // Expand the mask with zeroes - // Mask may be <2 x i64> or <2 x i1> at this moment - assert((MaskVT == MVT::v2i1 || MaskVT == MVT::v2i64) && - "Unexpected mask type"); - MVT ExtMaskVT = MVT::getVectorVT(MaskVT.getScalarType(), 4); - Mask = ExtendToType(Mask, ExtMaskVT, DAG, true); - VT = MVT::v4i32; - } + // If the index is v2i32, we're being called by type legalization and we + // should just let the default handling take care of it. + if (IndexVT == MVT::v2i32) + return SDValue(); unsigned NumElts = VT.getVectorNumElements(); if (!Subtarget.hasVLX() && !VT.is512BitVector() && |