diff options
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 93 |
1 files changed, 58 insertions, 35 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 5741b80d1a7..521fc3cd37b 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -43381,26 +43381,36 @@ static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG, static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI) { SDLoc DL(N); + auto *GorS = cast<MaskedGatherScatterSDNode>(N); + SDValue Chain = GorS->getChain(); + SDValue Index = GorS->getIndex(); + SDValue Mask = GorS->getMask(); + SDValue Base = GorS->getBasePtr(); + SDValue Scale = GorS->getScale(); if (DCI.isBeforeLegalizeOps()) { - SDValue Index = N->getOperand(4); // Remove any sign extends from 32 or smaller to larger than 32. // Only do this before LegalizeOps in case we need the sign extend for // legalization. - if (Index.getOpcode() == ISD::SIGN_EXTEND) { - if (Index.getScalarValueSizeInBits() > 32 && - Index.getOperand(0).getScalarValueSizeInBits() <= 32) { - SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); - NewOps[4] = Index.getOperand(0); - SDNode *Res = DAG.UpdateNodeOperands(N, NewOps); - if (Res == N) { - // The original sign extend has less users, add back to worklist in - // case it needs to be removed - DCI.AddToWorklist(Index.getNode()); - DCI.AddToWorklist(N); - } - return SDValue(Res, 0); - } + if (Index.getOpcode() == ISD::SIGN_EXTEND && + Index.getScalarValueSizeInBits() > 32 && + Index.getOperand(0).getScalarValueSizeInBits() <= 32) { + Index = Index.getOperand(0); + if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) { + SDValue Ops[] = { Chain, Gather->getPassThru(), + Mask, Base, Index, Scale } ; + return DAG.getMaskedGather(Gather->getVTList(), + Gather->getMemoryVT(), DL, Ops, + Gather->getMemOperand(), + Gather->getIndexType()); + } + auto *Scatter = cast<MaskedScatterSDNode>(GorS); + SDValue Ops[] = { Chain, Scatter->getValue(), + Mask, Base, Index, Scale }; + return DAG.getMaskedScatter(Scatter->getVTList(), + Scatter->getMemoryVT(), DL, + Ops, Scatter->getMemOperand(), + Scatter->getIndexType()); } // Make sure the index is either i32 or i64 @@ -43410,36 +43420,49 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT, Index.getValueType().getVectorNumElements()); Index = DAG.getSExtOrTrunc(Index, DL, IndexVT); - SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); - NewOps[4] = Index; - SDNode *Res = DAG.UpdateNodeOperands(N, NewOps); - if (Res == N) - DCI.AddToWorklist(N); - return SDValue(Res, 0); + if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) { + SDValue Ops[] = { Chain, Gather->getPassThru(), + Mask, Base, Index, Scale } ; + return DAG.getMaskedGather(Gather->getVTList(), + Gather->getMemoryVT(), DL, Ops, + Gather->getMemOperand(), + Gather->getIndexType()); + } + auto *Scatter = cast<MaskedScatterSDNode>(GorS); + SDValue Ops[] = { Chain, Scatter->getValue(), + Mask, Base, Index, Scale }; + return DAG.getMaskedScatter(Scatter->getVTList(), + Scatter->getMemoryVT(), DL, + Ops, Scatter->getMemOperand(), + Scatter->getIndexType()); } // Try to remove zero extends from 32->64 if we know the sign bit of // the input is zero. if (Index.getOpcode() == ISD::ZERO_EXTEND && Index.getScalarValueSizeInBits() == 64 && - Index.getOperand(0).getScalarValueSizeInBits() == 32) { - if (DAG.SignBitIsZero(Index.getOperand(0))) { - SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); - NewOps[4] = Index.getOperand(0); - SDNode *Res = DAG.UpdateNodeOperands(N, NewOps); - if (Res == N) { - // The original sign extend has less users, add back to worklist in - // case it needs to be removed - DCI.AddToWorklist(Index.getNode()); - DCI.AddToWorklist(N); - } - return SDValue(Res, 0); - } + Index.getOperand(0).getScalarValueSizeInBits() == 32 && + DAG.SignBitIsZero(Index.getOperand(0))) { + Index = Index.getOperand(0); + if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) { + SDValue Ops[] = { Chain, Gather->getPassThru(), + Mask, Base, Index, Scale } ; + return DAG.getMaskedGather(Gather->getVTList(), + Gather->getMemoryVT(), DL, Ops, + Gather->getMemOperand(), + Gather->getIndexType()); + } + auto *Scatter = cast<MaskedScatterSDNode>(GorS); + SDValue Ops[] = { Chain, Scatter->getValue(), + Mask, Base, Index, Scale }; + return DAG.getMaskedScatter(Scatter->getVTList(), + Scatter->getMemoryVT(), DL, + Ops, Scatter->getMemOperand(), + Scatter->getIndexType()); } } // With vector masks we only demand the upper bit of the mask. - SDValue Mask = cast<MaskedGatherScatterSDNode>(N)->getMask(); if (Mask.getScalarValueSizeInBits() != 1) { const TargetLowering &TLI = DAG.getTargetLoweringInfo(); APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits())); |