diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 160 |
1 files changed, 160 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 201429fe754..7347111728e 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -303,6 +303,8 @@ namespace { SDValue visitEXTRACT_SUBVECTOR(SDNode *N); SDValue visitVECTOR_SHUFFLE(SDNode *N); SDValue visitINSERT_SUBVECTOR(SDNode *N); + SDValue visitMLOAD(SDNode *N); + SDValue visitMSTORE(SDNode *N); SDValue XformToShuffleWithZero(SDNode *N); SDValue ReassociateOps(unsigned Opc, SDLoc DL, SDValue LHS, SDValue RHS); @@ -1351,6 +1353,8 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::EXTRACT_SUBVECTOR: return visitEXTRACT_SUBVECTOR(N); case ISD::VECTOR_SHUFFLE: return visitVECTOR_SHUFFLE(N); case ISD::INSERT_SUBVECTOR: return visitINSERT_SUBVECTOR(N); + case ISD::MLOAD: return visitMLOAD(N); + case ISD::MSTORE: return visitMSTORE(N); } return SDValue(); } @@ -4771,6 +4775,162 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) { TopHalf->isNullValue() ? RHS->getOperand(1) : LHS->getOperand(1)); } +SDValue DAGCombiner::visitMSTORE(SDNode *N) { + + if (Level >= AfterLegalizeTypes) + return SDValue(); + + MaskedStoreSDNode *MST = dyn_cast<MaskedStoreSDNode>(N); + SDValue Mask = MST->getMask(); + SDValue Data = MST->getData(); + SDLoc DL(N); + + // If the MSTORE data type requires splitting and the mask is provided by a + // SETCC, then split both nodes and its operands before legalization. This + // prevents the type legalizer from unrolling SETCC into scalar comparisons + // and enables future optimizations (e.g. min/max pattern matching on X86). + if (Mask.getOpcode() == ISD::SETCC) { + + // Check if any splitting is required. + if (TLI.getTypeAction(*DAG.getContext(), Data.getValueType()) != + TargetLowering::TypeSplitVector) + return SDValue(); + + SDValue MaskLo, MaskHi, Lo, Hi; + std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG); + + EVT LoVT, HiVT; + std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(MST->getValueType(0)); + + SDValue Chain = MST->getChain(); + SDValue Ptr = MST->getBasePtr(); + + EVT MemoryVT = MST->getMemoryVT(); + unsigned Alignment = MST->getOriginalAlignment(); + + // if Alignment is equal to the vector size, + // take the half of it for the second part + unsigned SecondHalfAlignment = + (Alignment == Data->getValueType(0).getSizeInBits()/8) ? + Alignment/2 : Alignment; + + EVT LoMemVT, HiMemVT; + std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT); + + SDValue DataLo, DataHi; + std::tie(DataLo, DataHi) = DAG.SplitVector(Data, DL); + + MachineMemOperand *MMO = DAG.getMachineFunction(). + getMachineMemOperand(MST->getPointerInfo(), + MachineMemOperand::MOStore, LoMemVT.getStoreSize(), + Alignment, MST->getAAInfo(), MST->getRanges()); + + Lo = DAG.getMaskedStore(Chain, DL, DataLo, Ptr, MaskLo, MMO); + + unsigned IncrementSize = LoMemVT.getSizeInBits()/8; + Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, + DAG.getConstant(IncrementSize, Ptr.getValueType())); + + MMO = DAG.getMachineFunction(). + getMachineMemOperand(MST->getPointerInfo(), + MachineMemOperand::MOStore, HiMemVT.getStoreSize(), + SecondHalfAlignment, MST->getAAInfo(), + MST->getRanges()); + + Hi = DAG.getMaskedStore(Chain, DL, DataHi, Ptr, MaskHi, MMO); + + AddToWorklist(Lo.getNode()); + AddToWorklist(Hi.getNode()); + + return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Lo, Hi); + } + return SDValue(); +} + +SDValue DAGCombiner::visitMLOAD(SDNode *N) { + + if (Level >= AfterLegalizeTypes) + return SDValue(); + + MaskedLoadSDNode *MLD = dyn_cast<MaskedLoadSDNode>(N); + SDValue Mask = MLD->getMask(); + SDLoc DL(N); + + // If the MLOAD result requires splitting and the mask is provided by a + // SETCC, then split both nodes and its operands before legalization. This + // prevents the type legalizer from unrolling SETCC into scalar comparisons + // and enables future optimizations (e.g. min/max pattern matching on X86). + + if (Mask.getOpcode() == ISD::SETCC) { + EVT VT = N->getValueType(0); + + // Check if any splitting is required. + if (TLI.getTypeAction(*DAG.getContext(), VT) != + TargetLowering::TypeSplitVector) + return SDValue(); + + SDValue MaskLo, MaskHi, Lo, Hi; + std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG); + + SDValue Src0 = MLD->getSrc0(); + SDValue Src0Lo, Src0Hi; + std::tie(Src0Lo, Src0Hi) = DAG.SplitVector(Src0, DL); + + EVT LoVT, HiVT; + std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(MLD->getValueType(0)); + + SDValue Chain = MLD->getChain(); + SDValue Ptr = MLD->getBasePtr(); + EVT MemoryVT = MLD->getMemoryVT(); + unsigned Alignment = MLD->getOriginalAlignment(); + + // if Alignment is equal to the vector size, + // take the half of it for the second part + unsigned SecondHalfAlignment = + (Alignment == MLD->getValueType(0).getSizeInBits()/8) ? + Alignment/2 : Alignment; + + EVT LoMemVT, HiMemVT; + std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT); + + MachineMemOperand *MMO = DAG.getMachineFunction(). + getMachineMemOperand(MLD->getPointerInfo(), + MachineMemOperand::MOLoad, LoMemVT.getStoreSize(), + Alignment, MLD->getAAInfo(), MLD->getRanges()); + + Lo = DAG.getMaskedLoad(LoVT, DL, Chain, Ptr, MaskLo, Src0Lo, MMO); + + unsigned IncrementSize = LoMemVT.getSizeInBits()/8; + Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, + DAG.getConstant(IncrementSize, Ptr.getValueType())); + + MMO = DAG.getMachineFunction(). + getMachineMemOperand(MLD->getPointerInfo(), + MachineMemOperand::MOLoad, HiMemVT.getStoreSize(), + SecondHalfAlignment, MLD->getAAInfo(), MLD->getRanges()); + + Hi = DAG.getMaskedLoad(HiVT, DL, Chain, Ptr, MaskHi, Src0Hi, MMO); + + AddToWorklist(Lo.getNode()); + AddToWorklist(Hi.getNode()); + + // Build a factor node to remember that this load is independent of the + // other one. + Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Lo.getValue(1), + Hi.getValue(1)); + + // Legalized the chain result - switch anything that used the old chain to + // use the new one. + DAG.ReplaceAllUsesOfValueWith(SDValue(MLD, 1), Chain); + + SDValue LoadRes = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi); + + SDValue RetOps[] = { LoadRes, Chain }; + return DAG.getMergeValues(RetOps, DL); + } + return SDValue(); +} + SDValue DAGCombiner::visitVSELECT(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); |