diff options
-rw-r--r-- | llvm/lib/Target/X86/X86ISelDAGToDAG.cpp | 11 | ||||
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 12 | ||||
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.h | 44 | ||||
-rw-r--r-- | llvm/lib/Target/X86/X86InstrFragmentsSIMD.td | 31 |
4 files changed, 59 insertions, 39 deletions
diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp index 71ae97d7e92..504482a5e2a 100644 --- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp +++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp @@ -1522,14 +1522,9 @@ bool X86DAGToDAGISel::selectVectorAddr(SDNode *Parent, SDValue N, SDValue &Base, SDValue &Scale, SDValue &Index, SDValue &Disp, SDValue &Segment) { X86ISelAddressMode AM; - if (auto Mgs = dyn_cast<MaskedGatherScatterSDNode>(Parent)) { - AM.IndexReg = Mgs->getIndex(); - AM.Scale = Mgs->getValue().getScalarValueSizeInBits() / 8; - } else { - auto X86Gather = cast<X86MaskedGatherSDNode>(Parent); - AM.IndexReg = X86Gather->getIndex(); - AM.Scale = X86Gather->getValue().getScalarValueSizeInBits() / 8; - } + auto *Mgs = cast<X86MaskedGatherScatterSDNode>(Parent); + AM.IndexReg = Mgs->getIndex(); + AM.Scale = Mgs->getValue().getScalarValueSizeInBits() / 8; unsigned AddrSpace = cast<MemSDNode>(Parent)->getPointerInfo().getAddrSpace(); // AddrSpace 256 -> GS, 257 -> FS, 258 -> SS. diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index b3afe081834..d6436eeac68 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -24112,19 +24112,12 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget, assert(Subtarget.hasAVX512() && "MGATHER/MSCATTER are supported on AVX-512 arch only"); - // X86 scatter kills mask register, so its type should be added to - // the list of return values. - // If the "scatter" has 2 return values, it is already handled. - if (Op.getNode()->getNumValues() == 2) - return Op; - MaskedScatterSDNode *N = cast<MaskedScatterSDNode>(Op.getNode()); SDValue Src = N->getValue(); MVT VT = Src.getSimpleValueType(); assert(VT.getScalarSizeInBits() >= 32 && "Unsupported scatter op"); SDLoc dl(Op); - SDValue NewScatter; SDValue Index = N->getIndex(); SDValue Mask = N->getMask(); SDValue Chain = N->getChain(); @@ -24195,8 +24188,8 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget, // The mask is killed by scatter, add it to the values SDVTList VTs = DAG.getVTList(BitMaskVT, MVT::Other); SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index}; - NewScatter = DAG.getMaskedScatter(VTs, N->getMemoryVT(), dl, Ops, - N->getMemOperand()); + SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>( + VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand()); DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1)); return SDValue(NewScatter.getNode(), 1); } @@ -25261,6 +25254,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::CVTS2UI_RND: return "X86ISD::CVTS2UI_RND"; case X86ISD::LWPINS: return "X86ISD::LWPINS"; case X86ISD::MGATHER: return "X86ISD::MGATHER"; + case X86ISD::MSCATTER: return "X86ISD::MSCATTER"; case X86ISD::VPDPBUSD: return "X86ISD::VPDPBUSD"; case X86ISD::VPDPBUSDS: return "X86ISD::VPDPBUSDS"; case X86ISD::VPDPWSSD: return "X86ISD::VPDPWSSD"; diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h index b79addfe198..fc8519bb973 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.h +++ b/llvm/lib/Target/X86/X86ISelLowering.h @@ -637,8 +637,8 @@ namespace llvm { // Vector truncating masked store with unsigned/signed saturation VMTRUNCSTOREUS, VMTRUNCSTORES, - // X86 specific gather - MGATHER + // X86 specific gather and scatter + MGATHER, MSCATTER, // WARNING: Do not add anything in the end unless you want the node to // have memop! In fact, starting from FIRST_TARGET_MEMORY_OPCODE all @@ -1423,16 +1423,15 @@ namespace llvm { } }; - // X86 specific Gather node. - // The class has the same order of operands as MaskedGatherSDNode for + // X86 specific Gather/Scatter nodes. + // The class has the same order of operands as MaskedGatherScatterSDNode for // convenience. - class X86MaskedGatherSDNode : public MemSDNode { + class X86MaskedGatherScatterSDNode : public MemSDNode { public: - X86MaskedGatherSDNode(unsigned Order, - const DebugLoc &dl, SDVTList VTs, EVT MemVT, - MachineMemOperand *MMO) - : MemSDNode(X86ISD::MGATHER, Order, dl, VTs, MemVT, MMO) - {} + X86MaskedGatherScatterSDNode(unsigned Opc, unsigned Order, + const DebugLoc &dl, SDVTList VTs, EVT MemVT, + MachineMemOperand *MMO) + : MemSDNode(Opc, Order, dl, VTs, MemVT, MMO) {} const SDValue &getBasePtr() const { return getOperand(3); } const SDValue &getIndex() const { return getOperand(4); } @@ -1440,10 +1439,35 @@ namespace llvm { const SDValue &getValue() const { return getOperand(1); } static bool classof(const SDNode *N) { + return N->getOpcode() == X86ISD::MGATHER || + N->getOpcode() == X86ISD::MSCATTER; + } + }; + + class X86MaskedGatherSDNode : public X86MaskedGatherScatterSDNode { + public: + X86MaskedGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT MemVT, MachineMemOperand *MMO) + : X86MaskedGatherScatterSDNode(X86ISD::MGATHER, Order, dl, VTs, MemVT, + MMO) {} + + static bool classof(const SDNode *N) { return N->getOpcode() == X86ISD::MGATHER; } }; + class X86MaskedScatterSDNode : public X86MaskedGatherScatterSDNode { + public: + X86MaskedScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT MemVT, MachineMemOperand *MMO) + : X86MaskedGatherScatterSDNode(X86ISD::MSCATTER, Order, dl, VTs, MemVT, + MMO) {} + + static bool classof(const SDNode *N) { + return N->getOpcode() == X86ISD::MSCATTER; + } + }; + /// Generate unpacklo/unpackhi shuffle mask. template <typename T = int> void createUnpackShuffleMask(MVT VT, SmallVectorImpl<T> &Mask, bool Lo, diff --git a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td index c38b2c730c7..2eb735abd69 100644 --- a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td +++ b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td @@ -781,6 +781,13 @@ def X86masked_gather : SDNode<"X86ISD::MGATHER", SDTCisPtrTy<4>]>, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; +def X86masked_scatter : SDNode<"X86ISD::MSCATTER", + SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisVec<1>, + SDTCisSameAs<0, 2>, + SDTCVecEltisVT<0, i1>, + SDTCisPtrTy<3>]>, + [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; + def mgatherv4i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3), (X86masked_gather node:$src1, node:$src2, node:$src3) , [{ X86MaskedGatherSDNode *Mgt = cast<X86MaskedGatherSDNode>(N); @@ -815,37 +822,37 @@ def mgatherv16i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3), }]>; def mscatterv2i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3), - (masked_scatter node:$src1, node:$src2, node:$src3) , [{ - MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N); + (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{ + X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N); return Sc->getIndex().getValueType() == MVT::v2i64; }]>; def mscatterv4i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3), - (masked_scatter node:$src1, node:$src2, node:$src3) , [{ - MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N); + (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{ + X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N); return Sc->getIndex().getValueType() == MVT::v4i32; }]>; def mscatterv4i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3), - (masked_scatter node:$src1, node:$src2, node:$src3) , [{ - MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N); + (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{ + X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N); return Sc->getIndex().getValueType() == MVT::v4i64; }]>; def mscatterv8i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3), - (masked_scatter node:$src1, node:$src2, node:$src3) , [{ - MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N); + (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{ + X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N); return Sc->getIndex().getValueType() == MVT::v8i32; }]>; def mscatterv8i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3), - (masked_scatter node:$src1, node:$src2, node:$src3) , [{ - MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N); + (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{ + X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N); return Sc->getIndex().getValueType() == MVT::v8i64; }]>; def mscatterv16i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3), - (masked_scatter node:$src1, node:$src2, node:$src3) , [{ - MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N); + (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{ + X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N); return Sc->getIndex().getValueType() == MVT::v16i32; }]>; |