Diffstat (limited to 'llvm/lib/CodeGen')
| -rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 145 |
1 file changed, 134 insertions, 11 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 29adcad22e1..eca5d8369eb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -250,6 +250,11 @@ namespace {
     SDValue SplitIndexingFromLoad(LoadSDNode *LD);
     bool SliceUpLoad(SDNode *N);
 
+    // Scalars have size 0 to distinguish from singleton vectors.
+    SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
+    bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
+    bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);
+
     /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
     /// load.
     ///
@@ -12762,6 +12767,133 @@ SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
   return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc);
 }
 
+static inline int numVectorEltsOrZero(EVT T) {
+  return T.isVector() ? T.getVectorNumElements() : 0;
+}
+
+bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
+  Val = ST->getValue();
+  EVT STType = Val.getValueType();
+  EVT STMemType = ST->getMemoryVT();
+  if (STType == STMemType)
+    return true;
+  if (isTypeLegal(STMemType))
+    return false; // fail.
+  if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
+      TLI.isOperationLegal(ISD::FTRUNC, STMemType)) {
+    Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val);
+    return true;
+  }
+  if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) &&
+      STType.isInteger() && STMemType.isInteger()) {
+    Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val);
+    return true;
+  }
+  if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
+    Val = DAG.getBitcast(STMemType, Val);
+    return true;
+  }
+  return false; // fail.
+}
+
+bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
+  EVT LDMemType = LD->getMemoryVT();
+  EVT LDType = LD->getValueType(0);
+  assert(Val.getValueType() == LDMemType &&
+         "Attempting to extend value of non-matching type");
+  if (LDType == LDMemType)
+    return true;
+  if (LDMemType.isInteger() && LDType.isInteger()) {
+    switch (LD->getExtensionType()) {
+    case ISD::NON_EXTLOAD:
+      Val = DAG.getBitcast(LDType, Val);
+      return true;
+    case ISD::EXTLOAD:
+      Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val);
+      return true;
+    case ISD::SEXTLOAD:
+      Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val);
+      return true;
+    case ISD::ZEXTLOAD:
+      Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val);
+      return true;
+    }
+  }
+  return false;
+}
+
+SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
+  if (OptLevel == CodeGenOpt::None || LD->isVolatile())
+    return SDValue();
+  SDValue Chain = LD->getOperand(0);
+  StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain.getNode());
+  if (!ST || ST->isVolatile())
+    return SDValue();
+
+  EVT LDType = LD->getValueType(0);
+  EVT LDMemType = LD->getMemoryVT();
+  EVT STMemType = ST->getMemoryVT();
+  EVT STType = ST->getValue().getValueType();
+
+  BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
+  BaseIndexOffset BasePtrST = BaseIndexOffset::match(ST, DAG);
+  int64_t Offset;
+
+  bool STCoversLD =
+      BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset) && (Offset >= 0) &&
+      (Offset * 8 <= LDMemType.getSizeInBits()) &&
+      (Offset * 8 + LDMemType.getSizeInBits() <= STMemType.getSizeInBits());
+
+  if (!STCoversLD)
+    return SDValue();
+
+  // Memory as copy space (potentially masked).
+  if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
+    // Simple case: Direct non-truncating forwarding
+    if (LDType.getSizeInBits() == LDMemType.getSizeInBits())
+      return CombineTo(LD, ST->getValue(), Chain);
+    // Can we model the truncate and extension with an and mask?
+    if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
+        !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
+      // Mask to size of LDMemType
+      auto Mask =
+          DAG.getConstant(APInt::getLowBitsSet(STType.getSizeInBits(),
+                                               STMemType.getSizeInBits()),
+                          SDLoc(ST), STType);
+      auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask);
+      return CombineTo(LD, Val, Chain);
+    }
+  }
+
+  // TODO: Deal with nonzero offset.
+  if (LD->getBasePtr().isUndef() || Offset != 0)
+    return SDValue();
+  // Model necessary truncations / extensions.
+  SDValue Val;
+  // Truncate Value To Stored Memory Size.
+  do {
+    if (!getTruncatedStoreValue(ST, Val))
+      continue;
+    if (!isTypeLegal(LDMemType))
+      continue;
+    if (STMemType != LDMemType) {
+      if (numVectorEltsOrZero(STMemType) == numVectorEltsOrZero(LDMemType) &&
+          STMemType.isInteger() && LDMemType.isInteger())
+        Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
+      else
+        continue;
+    }
+    if (!extendLoadedValueToExtension(LD, Val))
+      continue;
+    return CombineTo(LD, Val, Chain);
+  } while (false);
+
+  // On failure, cleanup dead nodes we may have created.
+  if (Val->use_empty())
+    deleteAndRecombine(Val.getNode());
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitLOAD(SDNode *N) {
   LoadSDNode *LD = cast<LoadSDNode>(N);
   SDValue Chain = LD->getChain();
@@ -12828,17 +12960,8 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) {
 
   // If this load is directly stored, replace the load value with the stored
   // value.
-  // TODO: Handle store large -> read small portion.
-  // TODO: Handle TRUNCSTORE/LOADEXT
-  if (OptLevel != CodeGenOpt::None &&
-      ISD::isNormalLoad(N) && !LD->isVolatile()) {
-    if (ISD::isNON_TRUNCStore(Chain.getNode())) {
-      StoreSDNode *PrevST = cast<StoreSDNode>(Chain);
-      if (PrevST->getBasePtr() == Ptr &&
-          PrevST->getValue().getValueType() == N->getValueType(0))
-        return CombineTo(N, PrevST->getOperand(1), Chain);
-    }
-  }
+  if (auto V = ForwardStoreValueToDirectLoad(LD))
+    return V;
 
   // Try to infer better alignment information than the load already has.
   if (OptLevel != CodeGenOpt::None && LD->isUnindexed()) {
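
The most interesting new path above is the masked-forwarding case: when a truncating store and an extending load of the same address agree on the in-memory type, the combine replaces the load with an AND of the stored value against APInt::getLowBitsSet(STType bits, STMemType bits), removing the memory round trip entirely (sign-extending loads are excluded because an AND cannot reproduce a sign extension). Below is a minimal standalone sketch of that equivalence in plain C++; it uses no LLVM APIs, and the helper names (truncStoreThenZextLoad, forwardWithMask) are invented for illustration only.

#include <cassert>
#include <cstdint>
#include <cstring>

// Models a truncating i32->i8 store followed by a zero-extending i8->i32 load
// of the same location, i.e. the memory round trip the combine removes.
static uint32_t truncStoreThenZextLoad(uint32_t StoredVal) {
  unsigned char Mem[1];
  Mem[0] = static_cast<unsigned char>(StoredVal); // truncating store: low 8 bits
  uint8_t Narrow;
  std::memcpy(&Narrow, Mem, sizeof(Narrow));      // the i8 load
  return static_cast<uint32_t>(Narrow);           // zero-extend back to i32
}

// Models the combine's replacement: forward the stored value through an AND
// with the low-bits mask (the plain-integer analogue of getLowBitsSet(32, 8)).
static uint32_t forwardWithMask(uint32_t StoredVal) {
  const uint32_t LowByteMask = (1u << 8) - 1;
  return StoredVal & LowByteMask;
}

int main() {
  for (uint32_t V : {0u, 0x7Fu, 0x80u, 0x12345678u, 0xFFFFFFFFu})
    assert(truncStoreThenZextLoad(V) == forwardWithMask(V));
  return 0;
}

For a non-truncating store feeding a plain load of the same type, no mask is needed and the stored SDValue is forwarded directly, which is the first CombineTo call in the hunk above.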

