diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 50 |
1 files changed, 32 insertions, 18 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 49ca4c6773c..d4179993f02 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -4810,6 +4810,18 @@ bool X86TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, Info.flags |= MachineMemOperand::MOStore; break; } + case GATHER: + case GATHER_AVX2: { + Info.ptrVal = nullptr; + MVT DataVT = MVT::getVT(I.getType()); + MVT IndexVT = MVT::getVT(I.getArgOperand(2)->getType()); + unsigned NumElts = std::min(DataVT.getVectorNumElements(), + IndexVT.getVectorNumElements()); + Info.memVT = MVT::getVectorVT(DataVT.getVectorElementType(), NumElts); + Info.align = 1; + Info.flags |= MachineMemOperand::MOLoad; + break; + } default: return false; } @@ -22376,25 +22388,26 @@ static SDValue getAVX2GatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG, if (!C) return SDValue(); SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl, MVT::i8); - EVT MaskVT = Mask.getValueType(); + EVT MaskVT = Mask.getValueType().changeVectorElementTypeToInteger(); SDVTList VTs = DAG.getVTList(Op.getValueType(), MaskVT, MVT::Other); - SDValue Disp = DAG.getTargetConstant(0, dl, MVT::i32); - SDValue Segment = DAG.getRegister(0, MVT::i32); // If source is undef or we know it won't be used, use a zero vector // to break register dependency. // TODO: use undef instead and let BreakFalseDeps deal with it? if (Src.isUndef() || ISD::isBuildVectorAllOnes(Mask.getNode())) Src = getZeroVector(Op.getSimpleValueType(), Subtarget, DAG, dl); - SDValue Ops[] = {Src, Base, Scale, Index, Disp, Segment, Mask, Chain}; - SDNode *Res = DAG.getMachineNode(Opc, dl, VTs, Ops); - SDValue RetOps[] = { SDValue(Res, 0), SDValue(Res, 2) }; - return DAG.getMergeValues(RetOps, dl); + + MemIntrinsicSDNode *MemIntr = cast<MemIntrinsicSDNode>(Op); + + SDValue Ops[] = {Chain, Src, Mask, Base, Index, Scale }; + SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( + VTs, Ops, dl, MemIntr->getMemoryVT(), MemIntr->getMemOperand()); + return DAG.getMergeValues({ Res, Res.getValue(2) }, dl); } -static SDValue getGatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG, - SDValue Src, SDValue Mask, SDValue Base, - SDValue Index, SDValue ScaleOp, SDValue Chain, - const X86Subtarget &Subtarget) { +static SDValue getGatherNode(SDValue Op, SelectionDAG &DAG, + SDValue Src, SDValue Mask, SDValue Base, + SDValue Index, SDValue ScaleOp, SDValue Chain, + const X86Subtarget &Subtarget) { MVT VT = Op.getSimpleValueType(); SDLoc dl(Op); auto *C = dyn_cast<ConstantSDNode>(ScaleOp); @@ -22412,17 +22425,18 @@ static SDValue getGatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG, Mask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); SDVTList VTs = DAG.getVTList(Op.getValueType(), MaskVT, MVT::Other); - SDValue Disp = DAG.getTargetConstant(0, dl, MVT::i32); - SDValue Segment = DAG.getRegister(0, MVT::i32); // If source is undef or we know it won't be used, use a zero vector // to break register dependency. // TODO: use undef instead and let BreakFalseDeps deal with it? if (Src.isUndef() || ISD::isBuildVectorAllOnes(Mask.getNode())) Src = getZeroVector(Op.getSimpleValueType(), Subtarget, DAG, dl); - SDValue Ops[] = {Src, Mask, Base, Scale, Index, Disp, Segment, Chain}; - SDNode *Res = DAG.getMachineNode(Opc, dl, VTs, Ops); - SDValue RetOps[] = { SDValue(Res, 0), SDValue(Res, 2) }; - return DAG.getMergeValues(RetOps, dl); + + MemIntrinsicSDNode *MemIntr = cast<MemIntrinsicSDNode>(Op); + + SDValue Ops[] = {Chain, Src, Mask, Base, Index, Scale }; + SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( + VTs, Ops, dl, MemIntr->getMemoryVT(), MemIntr->getMemOperand()); + return DAG.getMergeValues({ Res, Res.getValue(2) }, dl); } static SDValue getScatterNode(unsigned Opc, SDValue Op, SelectionDAG &DAG, @@ -22787,7 +22801,7 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, SDValue Index = Op.getOperand(4); SDValue Mask = Op.getOperand(5); SDValue Scale = Op.getOperand(6); - return getGatherNode(IntrData->Opc0, Op, DAG, Src, Mask, Base, Index, Scale, + return getGatherNode(Op, DAG, Src, Mask, Base, Index, Scale, Chain, Subtarget); } case SCATTER: { |