diff options
Diffstat (limited to 'llvm/lib/CodeGen')
| -rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 60101d20116..81a54d90ed8 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -9345,6 +9345,35 @@ static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner, return SDValue(N, 0); // Return N so it doesn't get rechecked! } +static SDValue tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, + const TargetLowering &TLI, EVT VT, + SDNode *N, SDValue N0, + ISD::LoadExtType ExtLoadType, + ISD::NodeType ExtOpc) { + if (!N0.hasOneUse()) + return SDValue(); + + MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0); + if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD) + return SDValue(); + + if (!TLI.isLoadExtLegal(ExtLoadType, VT, Ld->getValueType(0))) + return SDValue(); + + if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0))) + return SDValue(); + + SDLoc dl(Ld); + SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru()); + SDValue NewLoad = DAG.getMaskedLoad(VT, dl, Ld->getChain(), + Ld->getBasePtr(), Ld->getMask(), + PassThru, Ld->getMemoryVT(), + Ld->getMemOperand(), ExtLoadType, + Ld->isExpandingLoad()); + DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1)); + return NewLoad; +} + static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG, bool LegalOperations) { assert((N->getOpcode() == ISD::SIGN_EXTEND || @@ -9445,6 +9474,11 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { ISD::SEXTLOAD, ISD::SIGN_EXTEND)) return foldedExt; + if (SDValue foldedExt = + tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::SEXTLOAD, + ISD::SIGN_EXTEND)) + return foldedExt; + // fold (sext (load x)) to multiple smaller sextloads. // Only on illegal but splittable vectors. if (SDValue ExtLoad = CombineExtLoad(N)) @@ -9733,6 +9767,11 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { ISD::ZEXTLOAD, ISD::ZERO_EXTEND)) return foldedExt; + if (SDValue foldedExt = + tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::ZEXTLOAD, + ISD::ZERO_EXTEND)) + return foldedExt; + // fold (zext (load x)) to multiple smaller zextloads. // Only on illegal but splittable vectors. if (SDValue ExtLoad = CombineExtLoad(N)) |

