diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 35 |
1 files changed, 24 insertions, 11 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 784bf6d58c5..915046048ff 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -33429,8 +33429,19 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, if (Src.getOpcode() == ISD::SCALAR_TO_VECTOR) return DAG.getNode(X86ISD::VBROADCAST, DL, VT, Src.getOperand(0)); + // Share broadcast with the longest vector and extract low subvector (free). + for (SDNode *User : Src->uses()) + if (User != N.getNode() && + (User->getOpcode() == X86ISD::VBROADCAST || + User->getOpcode() == X86ISD::VBROADCAST_LOAD) && + User->getValueSizeInBits(0) > VT.getSizeInBits()) { + return extractSubVector(SDValue(User, 0), 0, DAG, DL, + VT.getSizeInBits()); + } + // vbroadcast(scalarload X) -> vbroadcast_load X - if (!SrcVT.isVector() && Src.hasOneUse() && + // For float loads, extract other uses of the scalar from the broadcast. + if (!SrcVT.isVector() && (Src.hasOneUse() || VT.isFloatingPoint()) && ISD::isNormalLoad(Src.getNode())) { LoadSDNode *LN = cast<LoadSDNode>(Src); SDVTList Tys = DAG.getVTList(VT, MVT::Other); @@ -33438,17 +33449,19 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, SDValue BcastLd = DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, DL, Tys, Ops, LN->getMemoryVT(), LN->getMemOperand()); - DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BcastLd.getValue(1)); - return BcastLd; - } - - // Share broadcast with the longest vector and extract low subvector (free). - for (SDNode *User : Src->uses()) - if (User != N.getNode() && User->getOpcode() == X86ISD::VBROADCAST && - User->getValueSizeInBits(0) > VT.getSizeInBits()) { - return extractSubVector(SDValue(User, 0), 0, DAG, DL, - VT.getSizeInBits()); + // If the load value is used only by N, replace it via CombineTo N. + bool NoReplaceExtract = Src.hasOneUse(); + DCI.CombineTo(N.getNode(), BcastLd); + if (NoReplaceExtract) { + DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BcastLd.getValue(1)); + DCI.recursivelyDeleteUnusedNodes(LN); + } else { + SDValue Scl = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SrcVT, BcastLd, + DAG.getIntPtrConstant(0, DL)); + DCI.CombineTo(LN, Scl, BcastLd.getValue(1)); } + return N; // Return N so it doesn't get rechecked! + } return SDValue(); } |