diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 142 |
1 files changed, 142 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a5f558fb79c..81e4a813367 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -307,6 +307,8 @@ namespace {
     SDValue visitINSERT_SUBVECTOR(SDNode *N);
     SDValue visitMLOAD(SDNode *N);
     SDValue visitMSTORE(SDNode *N);
+    SDValue visitMGATHER(SDNode *N);
+    SDValue visitMSCATTER(SDNode *N);
     SDValue visitFP_TO_FP16(SDNode *N);
 
     SDValue visitFADDForFMACombine(SDNode *N);
@@ -1382,7 +1384,9 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::VECTOR_SHUFFLE:     return visitVECTOR_SHUFFLE(N);
   case ISD::SCALAR_TO_VECTOR:   return visitSCALAR_TO_VECTOR(N);
   case ISD::INSERT_SUBVECTOR:   return visitINSERT_SUBVECTOR(N);
+  case ISD::MGATHER:            return visitMGATHER(N);
   case ISD::MLOAD:              return visitMLOAD(N);
+  case ISD::MSCATTER:           return visitMSCATTER(N);
   case ISD::MSTORE:             return visitMSTORE(N);
   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
   }
@@ -5073,6 +5077,67 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
       TopHalf->isNullValue() ? RHS->getOperand(1) : LHS->getOperand(1));
 }
 
+/// Before type legalization: if the scattered data needs splitting and the
+/// mask is a SETCC, replace the scatter with low/high scatters joined by a
+/// TokenFactor. Returns an empty SDValue when the combine does not apply.
+SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
+  if (Level >= AfterLegalizeTypes)
+    return SDValue();
+
+  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
+  SDValue Mask = MSC->getMask();
+  SDValue Data = MSC->getValue();
+  SDLoc DL(N);
+
+  // If the MSCATTER data type requires splitting and the mask is provided by a
+  // SETCC, then split both nodes and their operands before legalization. This
+  // prevents the type legalizer from unrolling SETCC into scalar comparisons
+  // and enables future optimizations (e.g. min/max pattern matching on X86).
+  if (Mask.getOpcode() != ISD::SETCC)
+    return SDValue();
+
+  // Check if any splitting is required.
+  if (TLI.getTypeAction(*DAG.getContext(), Data.getValueType()) !=
+      TargetLowering::TypeSplitVector)
+    return SDValue();
+
+  SDValue MaskLo, MaskHi, Lo, Hi;
+  std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG);
+
+  SDValue Chain = MSC->getChain();
+  EVT MemoryVT = MSC->getMemoryVT();
+  unsigned Alignment = MSC->getOriginalAlignment();
+
+  EVT LoMemVT, HiMemVT;
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
+  SDValue DataLo, DataHi;
+  std::tie(DataLo, DataHi) = DAG.SplitVector(Data, DL);
+
+  SDValue BasePtr = MSC->getBasePtr();
+  SDValue IndexLo, IndexHi;
+  std::tie(IndexLo, IndexHi) = DAG.SplitVector(MSC->getIndex(), DL);
+
+  // Both halves share one memory operand sized for the low half.
+  MachineMemOperand *MMO = DAG.getMachineFunction().
+    getMachineMemOperand(MSC->getPointerInfo(),
+                         MachineMemOperand::MOStore, LoMemVT.getStoreSize(),
+                         Alignment, MSC->getAAInfo(), MSC->getRanges());
+
+  SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo };
+  Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(),
+                            DL, OpsLo, MMO);
+
+  SDValue OpsHi[] = {Chain, DataHi, MaskHi, BasePtr, IndexHi};
+  Hi = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(),
+                            DL, OpsHi, MMO);
+
+  AddToWorklist(Lo.getNode());
+  AddToWorklist(Hi.getNode());
+
+  return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Lo, Hi);
+}
+
 SDValue DAGCombiner::visitMSTORE(SDNode *N) {
 
   if (Level >= AfterLegalizeTypes)
@@ -5147,6 +5212,83 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) {
   return SDValue();
 }
 
+/// Before type legalization: if the gather result needs splitting and the mask
+/// is a SETCC, replace the gather with low/high gathers whose results are
+/// concatenated. Returns an empty SDValue when the combine does not apply.
+SDValue DAGCombiner::visitMGATHER(SDNode *N) {
+  if (Level >= AfterLegalizeTypes)
+    return SDValue();
+
+  MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
+  SDValue Mask = MGT->getMask();
+  SDLoc DL(N);
+
+  // If the MGATHER result requires splitting and the mask is provided by a
+  // SETCC, then split both nodes and their operands before legalization. This
+  // prevents the type legalizer from unrolling SETCC into scalar comparisons
+  // and enables future optimizations (e.g. min/max pattern matching on X86).
+  if (Mask.getOpcode() != ISD::SETCC)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+
+  // Check if any splitting is required.
+  if (TLI.getTypeAction(*DAG.getContext(), VT) !=
+      TargetLowering::TypeSplitVector)
+    return SDValue();
+
+  SDValue MaskLo, MaskHi, Lo, Hi;
+  std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG);
+
+  SDValue Src0 = MGT->getValue();
+  SDValue Src0Lo, Src0Hi;
+  std::tie(Src0Lo, Src0Hi) = DAG.SplitVector(Src0, DL);
+
+  EVT LoVT, HiVT;
+  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
+
+  SDValue Chain = MGT->getChain();
+  EVT MemoryVT = MGT->getMemoryVT();
+  unsigned Alignment = MGT->getOriginalAlignment();
+
+  EVT LoMemVT, HiMemVT;
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
+  SDValue BasePtr = MGT->getBasePtr();
+  SDValue IndexLo, IndexHi;
+  std::tie(IndexLo, IndexHi) = DAG.SplitVector(MGT->getIndex(), DL);
+
+  // Both halves share one memory operand sized for the low half.
+  MachineMemOperand *MMO = DAG.getMachineFunction().
+    getMachineMemOperand(MGT->getPointerInfo(),
+                         MachineMemOperand::MOLoad, LoMemVT.getStoreSize(),
+                         Alignment, MGT->getAAInfo(), MGT->getRanges());
+
+  SDValue OpsLo[] = { Chain, Src0Lo, MaskLo, BasePtr, IndexLo };
+  Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, DL, OpsLo,
+                           MMO);
+
+  SDValue OpsHi[] = {Chain, Src0Hi, MaskHi, BasePtr, IndexHi};
+  Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, DL, OpsHi,
+                           MMO);
+
+  AddToWorklist(Lo.getNode());
+  AddToWorklist(Hi.getNode());
+
+  // Build a factor node to remember that the two gathers are independent of
+  // each other.
+  Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Lo.getValue(1),
+                      Hi.getValue(1));
+
+  // Legalized the chain result - switch anything that used the old chain to
+  // use the new one.
+  DAG.ReplaceAllUsesOfValueWith(SDValue(MGT, 1), Chain);
+
+  SDValue GatherRes = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
+  SDValue RetOps[] = { GatherRes, Chain };
+  return DAG.getMergeValues(RetOps, DL);
+}
+
 SDValue DAGCombiner::visitMLOAD(SDNode *N) {
 
   if (Level >= AfterLegalizeTypes)