Diffstat (limited to 'llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp | 134
1 file changed, 97 insertions, 37 deletions
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 6d162bede1d..afef7ece03b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -64,6 +64,14 @@ EVT AMDGPUTargetLowering::getEquivalentLoadRegType(LLVMContext &Ctx, EVT VT) {
   return EVT::getVectorVT(Ctx, MVT::i32, StoreSize / 32);
 }
 
+EVT AMDGPUTargetLowering::getEquivalentBitType(LLVMContext &Ctx, EVT VT) {
+  unsigned StoreSize = VT.getStoreSizeInBits();
+  if (StoreSize <= 32)
+    return EVT::getIntegerVT(Ctx, StoreSize);
+
+  return EVT::getVectorVT(Ctx, MVT::i32, StoreSize / 32);
+}
+
 AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
                                            const AMDGPUSubtarget &STI)
     : TargetLowering(TM), Subtarget(&STI) {
@@ -535,15 +543,17 @@ bool AMDGPUTargetLowering::shouldReduceLoadWidth(SDNode *N,
 
 bool AMDGPUTargetLowering::isLoadBitCastBeneficial(EVT LoadTy,
                                                    EVT CastTy) const {
-  if (LoadTy.getSizeInBits() != CastTy.getSizeInBits())
-    return true;
-
-  unsigned LScalarSize = LoadTy.getScalarType().getSizeInBits();
-  unsigned CastScalarSize = CastTy.getScalarType().getSizeInBits();
+  assert(LoadTy.getSizeInBits() == CastTy.getSizeInBits());
+
+  if (LoadTy.getScalarType() == MVT::i32)
+    return false;
+
+  unsigned LScalarSize = LoadTy.getScalarSizeInBits();
+  unsigned CastScalarSize = CastTy.getScalarSizeInBits();
 
-  return ((LScalarSize <= CastScalarSize) ||
-          (CastScalarSize >= 32) ||
-          (LScalarSize < 32));
+  return (LScalarSize < CastScalarSize) ||
+         (CastScalarSize >= 32);
 }
 
 // SI+ has instructions for cttz / ctlz for 32-bit values. This is probably also
@@ -2161,56 +2171,105 @@ static SDValue constantFoldBFE(SelectionDAG &DAG, IntTy Src0, uint32_t Offset,
   return DAG.getConstant(Src0 >> Offset, DL, MVT::i32);
 }
 
-static bool usesAllNormalStores(SDNode *LoadVal) {
-  for (SDNode::use_iterator I = LoadVal->use_begin(); !I.atEnd(); ++I) {
-    if (!ISD::isNormalStore(*I))
+static bool hasVolatileUser(SDNode *Val) {
+  for (SDNode *U : Val->uses()) {
+    if (MemSDNode *M = dyn_cast<MemSDNode>(U)) {
+      if (M->isVolatile())
+        return true;
+    }
+  }
+
+  return false;
+}
+
+bool AMDGPUTargetLowering::shouldCombineMemoryType(const MemSDNode *M) const {
+  EVT VT = M->getMemoryVT();
+
+  // i32 vectors are the canonical memory type.
+  if (VT.getScalarType() == MVT::i32 || isTypeLegal(VT))
+    return false;
+
+  if (!VT.isByteSized())
+    return false;
+
+  unsigned Size = VT.getStoreSize();
+
+  if ((Size == 1 || Size == 2 || Size == 4) && !VT.isVector())
+    return false;
+
+  if (Size == 3 || (Size > 4 && (Size % 4 != 0)))
+    return false;
+
+  unsigned Align = M->getAlignment();
+  if (Align < Size) {
+    bool IsFast;
+    if (!allowsMisalignedMemoryAccesses(VT, M->getAddressSpace(),
+                                        Align, &IsFast) || !IsFast) {
       return false;
+    }
   }
 
   return true;
 }
 
-// If we have a copy of an illegal type, replace it with a load / store of an
-// equivalently sized legal type. This avoids intermediate bit pack / unpack
-// instructions emitted when handling extloads and truncstores. Ideally we could
-// recognize the pack / unpack pattern to eliminate it.
+// Replace load of an illegal type with a load of a bitcast to a friendlier
+// type.
+SDValue AMDGPUTargetLowering::performLoadCombine(SDNode *N,
+                                                 DAGCombinerInfo &DCI) const {
+  if (!DCI.isBeforeLegalize())
+    return SDValue();
+
+  LoadSDNode *LN = cast<LoadSDNode>(N);
+  if (LN->isVolatile() || !ISD::isNormalLoad(LN) || hasVolatileUser(LN))
+    return SDValue();
+
+  if (!shouldCombineMemoryType(LN))
+    return SDValue();
+
+  SDLoc SL(N);
+  SelectionDAG &DAG = DCI.DAG;
+  EVT VT = LN->getMemoryVT();
+  EVT NewVT = getEquivalentMemType(*DAG.getContext(), VT);
+
+  SDValue NewLoad
+    = DAG.getLoad(NewVT, SL, LN->getChain(),
+                  LN->getBasePtr(), LN->getMemOperand());
+
+  SDValue BC = DAG.getNode(ISD::BITCAST, SL, VT, NewLoad);
+  DCI.CombineTo(N, BC, NewLoad.getValue(1));
+  return SDValue(N, 0);
+}
+
+// Replace store of an illegal type with a store of a bitcast to a friendlier
+// type.
 SDValue AMDGPUTargetLowering::performStoreCombine(SDNode *N,
                                                   DAGCombinerInfo &DCI) const {
   if (!DCI.isBeforeLegalize())
     return SDValue();
 
   StoreSDNode *SN = cast<StoreSDNode>(N);
-  SDValue Value = SN->getValue();
-  EVT VT = Value.getValueType();
-
-  if (isTypeLegal(VT) || SN->isVolatile() ||
-      !ISD::isNormalLoad(Value.getNode()) || VT.getSizeInBits() < 8)
+  if (SN->isVolatile() || !ISD::isNormalStore(SN))
     return SDValue();
 
-  LoadSDNode *LoadVal = cast<LoadSDNode>(Value);
-  if (LoadVal->isVolatile() || !usesAllNormalStores(LoadVal))
+  if (!shouldCombineMemoryType(SN))
     return SDValue();
 
-  EVT MemVT = LoadVal->getMemoryVT();
-  if (!MemVT.isRound())
-    return SDValue();
+  SDValue Val = SN->getValue();
+  EVT VT = SN->getMemoryVT();
 
   SDLoc SL(N);
   SelectionDAG &DAG = DCI.DAG;
-  EVT LoadVT = getEquivalentMemType(*DAG.getContext(), MemVT);
+  EVT NewVT = getEquivalentMemType(*DAG.getContext(), VT);
 
-  SDValue NewLoad = DAG.getLoad(ISD::UNINDEXED, ISD::NON_EXTLOAD,
-                                LoadVT, SL,
-                                LoadVal->getChain(),
-                                LoadVal->getBasePtr(),
-                                LoadVal->getOffset(),
-                                LoadVT,
-                                LoadVal->getMemOperand());
-
-  SDValue CastLoad = DAG.getNode(ISD::BITCAST, SL, VT, NewLoad.getValue(0));
-  DCI.CombineTo(LoadVal, CastLoad, NewLoad.getValue(1), false);
+  bool OtherUses = !Val.hasOneUse();
+  SDValue CastVal = DAG.getNode(ISD::BITCAST, SL, NewVT, Val);
+  if (OtherUses) {
+    SDValue CastBack = DAG.getNode(ISD::BITCAST, SL, VT, CastVal);
+    DAG.ReplaceAllUsesOfValueWith(Val, CastBack);
+  }
 
-  return DAG.getStore(SN->getChain(), SL, NewLoad,
+  return DAG.getStore(SN->getChain(), SL, CastVal,
                       SN->getBasePtr(), SN->getMemOperand());
 }
 
@@ -2645,7 +2704,8 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
 
     break;
   }
-
+  case ISD::LOAD:
+    return performLoadCombine(N, DCI);
   case ISD::STORE:
     return performStoreCombine(N, DCI);
   }
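
The core of the patch is the mapping to an "equivalent" i32-based type. That mapping is easy to check in isolation; below is a minimal standalone sketch (it assumes an LLVM development tree to build against, and equivalentBitType plus the main driver are illustrative copies written for this page, not code from the patch):

#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

// Illustrative copy of the patch's getEquivalentBitType: a type of 32 bits
// or fewer maps to a single integer of its store size; a larger type maps
// to a vector of i32 covering the same number of bits.
static EVT equivalentBitType(LLVMContext &Ctx, EVT VT) {
  unsigned StoreSize = VT.getStoreSizeInBits();
  if (StoreSize <= 32)
    return EVT::getIntegerVT(Ctx, StoreSize);
  return EVT::getVectorVT(Ctx, MVT::i32, StoreSize / 32);
}

int main() {
  LLVMContext Ctx;
  // Expected output: v4i8 -> i32, v4i16 -> v2i32, v2i64 -> v4i32.
  for (MVT VT : {MVT::v4i8, MVT::v4i16, MVT::v2i64})
    outs() << EVT(VT).getEVTString() << " -> "
           << equivalentBitType(Ctx, VT).getEVTString() << '\n';
}

Sub-dword types such as v4i8 collapse to a single i32, and wider types become i32 vectors, matching the "i32 vectors are the canonical memory type" comment in shouldCombineMemoryType.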
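
The size gating in shouldCombineMemoryType is likewise clearest with concrete numbers. Here is a dependency-free sketch of just that predicate (sizeIsCombinable is a hypothetical name for this page; the real function additionally rejects legal and non-byte-sized types and consults allowsMisalignedMemoryAccesses when the access is underaligned):

#include <cstdio>

// Mirrors the patch's store-size checks: scalar accesses of 1, 2, or 4
// bytes are already fine as-is; 3-byte accesses, and sizes above 4 bytes
// that are not a multiple of 4, cannot be covered exactly by i32 elements;
// everything else is a candidate for the i32-based rewrite.
static bool sizeIsCombinable(unsigned SizeBytes, bool IsVector) {
  if ((SizeBytes == 1 || SizeBytes == 2 || SizeBytes == 4) && !IsVector)
    return false;
  if (SizeBytes == 3 || (SizeBytes > 4 && SizeBytes % 4 != 0))
    return false;
  return true;
}

int main() {
  // Expected: 3 -> skip, 4 -> combine (e.g. v4i8), 6 -> skip (e.g. v3i16),
  // 16 -> combine (e.g. v2i64 becoming v4i32).
  for (unsigned S : {3u, 4u, 6u, 16u})
    std::printf("%2u-byte vector access: %s\n", S,
                sizeIsCombinable(S, /*IsVector=*/true) ? "combine" : "skip");
}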