Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp')
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp | 235
1 file changed, 189 insertions(+), 46 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 4521640e394..edcc0156e88 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -26,22 +26,20 @@
 // i64 and larger types when i64 is legal and the value has few bits set. It
 // would be good to enhance isel to emit a loop for ctpop in this case.
 //
-// We should enhance the memset/memcpy recognition to handle multiple stores in
-// the loop. This would handle things like:
-//   void foo(_Complex float *P)
-//     for (i) { __real__(*P) = 0;  __imag__(*P) = 0; }
-//
 // This could recognize common matrix multiplies and dot product idioms and
 // replace them with calls to BLAS (if linked in??).
 //
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
 #include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
@@ -108,7 +106,9 @@ public:
 
 private:
   typedef SmallVector<StoreInst *, 8> StoreList;
-  StoreList StoreRefsForMemset;
+  typedef MapVector<Value *, StoreList> StoreListMap;
+  StoreListMap StoreRefsForMemset;
+  StoreListMap StoreRefsForMemsetPattern;
   StoreList StoreRefsForMemcpy;
   bool HasMemset;
   bool HasMemsetPattern;
@@ -122,14 +122,18 @@ private:
                       SmallVectorImpl<BasicBlock *> &ExitBlocks);
 
   void collectStores(BasicBlock *BB);
-  bool isLegalStore(StoreInst *SI, bool &ForMemset, bool &ForMemcpy);
-  bool processLoopStore(StoreInst *SI, const SCEV *BECount);
+  bool isLegalStore(StoreInst *SI, bool &ForMemset, bool &ForMemsetPattern,
+                    bool &ForMemcpy);
+  bool processLoopStores(SmallVectorImpl<StoreInst *> &SL, const SCEV *BECount,
+                         bool ForMemset);
   bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount);
 
   bool processLoopStridedStore(Value *DestPtr, unsigned StoreSize,
                                unsigned StoreAlignment, Value *StoredVal,
-                               Instruction *TheStore, const SCEVAddRecExpr *Ev,
-                               const SCEV *BECount, bool NegStride);
+                               Instruction *TheStore,
+                               SmallPtrSetImpl<Instruction *> &Stores,
+                               const SCEVAddRecExpr *Ev, const SCEV *BECount,
+                               bool NegStride);
   bool processLoopStoreOfLoopLoad(StoreInst *SI, const SCEV *BECount);
 
   /// @}
@@ -305,7 +309,7 @@ static Constant *getMemSetPatternValue(Value *V, const DataLayout *DL) {
 }
 
 bool LoopIdiomRecognize::isLegalStore(StoreInst *SI, bool &ForMemset,
-                                      bool &ForMemcpy) {
+                                      bool &ForMemsetPattern, bool &ForMemcpy) {
   // Don't touch volatile stores.
   if (!SI->isSimple())
     return false;
@@ -353,7 +357,7 @@ bool LoopIdiomRecognize::isLegalStore(StoreInst *SI, bool &ForMemset,
       StorePtr->getType()->getPointerAddressSpace() == 0 &&
       (PatternValue = getMemSetPatternValue(StoredVal, DL))) {
     // It looks like we can use PatternValue!
-    ForMemset = true;
+    ForMemsetPattern = true;
     return true;
   }
 
@@ -393,6 +397,7 @@ bool LoopIdiomRecognize::isLegalStore(StoreInst *SI, bool &ForMemset,
 
 void LoopIdiomRecognize::collectStores(BasicBlock *BB) {
   StoreRefsForMemset.clear();
+  StoreRefsForMemsetPattern.clear();
   StoreRefsForMemcpy.clear();
   for (Instruction &I : *BB) {
     StoreInst *SI = dyn_cast<StoreInst>(&I);
@@ -400,15 +405,22 @@ void LoopIdiomRecognize::collectStores(BasicBlock *BB) {
       continue;
 
     bool ForMemset = false;
+    bool ForMemsetPattern = false;
     bool ForMemcpy = false;
     // Make sure this is a strided store with a constant stride.
-    if (!isLegalStore(SI, ForMemset, ForMemcpy))
+    if (!isLegalStore(SI, ForMemset, ForMemsetPattern, ForMemcpy))
       continue;
 
     // Save the store locations.
-    if (ForMemset)
-      StoreRefsForMemset.push_back(SI);
-    else if (ForMemcpy)
+    if (ForMemset) {
+      // Find the base pointer.
+      Value *Ptr = GetUnderlyingObject(SI->getPointerOperand(), *DL);
+      StoreRefsForMemset[Ptr].push_back(SI);
+    } else if (ForMemsetPattern) {
+      // Find the base pointer.
+      Value *Ptr = GetUnderlyingObject(SI->getPointerOperand(), *DL);
+      StoreRefsForMemsetPattern[Ptr].push_back(SI);
+    } else if (ForMemcpy)
       StoreRefsForMemcpy.push_back(SI);
  }
 }
@@ -430,9 +442,14 @@ bool LoopIdiomRecognize::runOnLoopBlock(
   // Look for store instructions, which may be optimized to memset/memcpy.
   collectStores(BB);
 
-  // Look for a single store which can be optimized into a memset.
-  for (auto &SI : StoreRefsForMemset)
-    MadeChange |= processLoopStore(SI, BECount);
+  // Look for a single store or sets of stores with a common base, which can be
+  // optimized into a memset (memset_pattern). The latter most commonly happens
+  // with structs and hand-unrolled loops.
+  for (auto &SL : StoreRefsForMemset)
+    MadeChange |= processLoopStores(SL.second, BECount, true);
+
+  for (auto &SL : StoreRefsForMemsetPattern)
+    MadeChange |= processLoopStores(SL.second, BECount, false);
 
   // Optimize the store into a memcpy, if it feeds a similarly strided load.
   for (auto &SI : StoreRefsForMemcpy)
@@ -458,26 +475,144 @@ bool LoopIdiomRecognize::runOnLoopBlock(
   return MadeChange;
 }
 
-/// processLoopStore - See if this store can be promoted to a memset.
-bool LoopIdiomRecognize::processLoopStore(StoreInst *SI, const SCEV *BECount) {
-  assert(SI->isSimple() && "Expected only non-volatile stores.");
+/// processLoopStores - See if these stores can be promoted to a memset.
+bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL,
+                                           const SCEV *BECount,
+                                           bool ForMemset) {
+  // Try to find consecutive stores that can be transformed into memsets.
+  SetVector<StoreInst *> Heads, Tails;
+  SmallDenseMap<StoreInst *, StoreInst *> ConsecutiveChain;
+
+  // Do a quadratic search on all of the given stores and find
+  // all of the pairs of stores that follow each other.
+  SmallVector<unsigned, 16> IndexQueue;
+  for (unsigned i = 0, e = SL.size(); i < e; ++i) {
+    assert(SL[i]->isSimple() && "Expected only non-volatile stores.");
+
+    Value *FirstStoredVal = SL[i]->getValueOperand();
+    Value *FirstStorePtr = SL[i]->getPointerOperand();
+    const SCEVAddRecExpr *FirstStoreEv =
+        cast<SCEVAddRecExpr>(SE->getSCEV(FirstStorePtr));
+    unsigned FirstStride = getStoreStride(FirstStoreEv);
+    unsigned FirstStoreSize = getStoreSizeInBytes(SL[i], DL);
+
+    // See if we can optimize just this store in isolation.
+    if (FirstStride == FirstStoreSize || FirstStride == -FirstStoreSize) {
+      Heads.insert(SL[i]);
+      continue;
+    }
 
-  Value *StoredVal = SI->getValueOperand();
-  Value *StorePtr = SI->getPointerOperand();
+    Value *FirstSplatValue = nullptr;
+    Constant *FirstPatternValue = nullptr;
 
-  // Check to see if the stride matches the size of the store. If so, then we
-  // know that every byte is touched in the loop.
-  const SCEVAddRecExpr *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
-  unsigned Stride = getStoreStride(StoreEv);
-  unsigned StoreSize = getStoreSizeInBytes(SI, DL);
-  if (StoreSize != Stride && StoreSize != -Stride)
-    return false;
+    if (ForMemset)
+      FirstSplatValue = isBytewiseValue(FirstStoredVal);
+    else
+      FirstPatternValue = getMemSetPatternValue(FirstStoredVal, DL);
+
+    assert((FirstSplatValue || FirstPatternValue) &&
+           "Expected either splat value or pattern value.");
+
+    IndexQueue.clear();
+    // If a store has multiple consecutive store candidates, search the Stores
+    // array according to the sequence: from i+1 to e, then from i-1 to 0.
+    // This is because pairing with the immediately succeeding or preceding
+    // candidate usually creates the best chance to find a memset opportunity.
+    unsigned j = 0;
+    for (j = i + 1; j < e; ++j)
+      IndexQueue.push_back(j);
+    for (j = i; j > 0; --j)
+      IndexQueue.push_back(j - 1);
+
+    for (auto &k : IndexQueue) {
+      assert(SL[k]->isSimple() && "Expected only non-volatile stores.");
+      Value *SecondStorePtr = SL[k]->getPointerOperand();
+      const SCEVAddRecExpr *SecondStoreEv =
+          cast<SCEVAddRecExpr>(SE->getSCEV(SecondStorePtr));
+      unsigned SecondStride = getStoreStride(SecondStoreEv);
+
+      if (FirstStride != SecondStride)
+        continue;
 
-  bool NegStride = StoreSize == -Stride;
+      Value *SecondStoredVal = SL[k]->getValueOperand();
+      Value *SecondSplatValue = nullptr;
+      Constant *SecondPatternValue = nullptr;
+
+      if (ForMemset)
+        SecondSplatValue = isBytewiseValue(SecondStoredVal);
+      else
+        SecondPatternValue = getMemSetPatternValue(SecondStoredVal, DL);
+
+      assert((SecondSplatValue || SecondPatternValue) &&
+             "Expected either splat value or pattern value.");
+
+      if (isConsecutiveAccess(SL[i], SL[k], *DL, *SE, false)) {
+        if (ForMemset) {
+          if (FirstSplatValue != SecondSplatValue)
+            continue;
+        } else {
+          if (FirstPatternValue != SecondPatternValue)
+            continue;
+        }
+        Tails.insert(SL[k]);
+        Heads.insert(SL[i]);
+        ConsecutiveChain[SL[i]] = SL[k];
+        break;
+      }
+    }
+  }
+
+  // We may run into multiple chains that merge into a single chain. We mark the
+  // stores that we transformed so that we don't visit the same store twice.
+  SmallPtrSet<Value *, 16> TransformedStores;
+  bool Changed = false;
+
+  // For stores that start but don't end a link in the chain:
+  for (SetVector<StoreInst *>::iterator it = Heads.begin(), e = Heads.end();
+       it != e; ++it) {
+    if (Tails.count(*it))
+      continue;
+
+    // We found a store instruction that starts a chain. Now follow the chain
+    // and try to transform it.
+    SmallPtrSet<Instruction *, 8> AdjacentStores;
+    StoreInst *I = *it;
+
+    StoreInst *HeadStore = I;
+    unsigned StoreSize = 0;
+
+    // Collect the chain into a list.
+    while (Tails.count(I) || Heads.count(I)) {
+      if (TransformedStores.count(I))
+        break;
+      AdjacentStores.insert(I);
+
+      StoreSize += getStoreSizeInBytes(I, DL);
+      // Move to the next value in the chain.
+      I = ConsecutiveChain[I];
+    }
+
+    Value *StoredVal = HeadStore->getValueOperand();
+    Value *StorePtr = HeadStore->getPointerOperand();
+    const SCEVAddRecExpr *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
+    unsigned Stride = getStoreStride(StoreEv);
+
+    // Check to see if the stride matches the size of the stores. If so, then
+    // we know that every byte is touched in the loop.
+    if (StoreSize != Stride && StoreSize != -Stride)
+      continue;
+
+    bool NegStride = StoreSize == -Stride;
+
+    if (processLoopStridedStore(StorePtr, StoreSize, HeadStore->getAlignment(),
+                                StoredVal, HeadStore, AdjacentStores, StoreEv,
+                                BECount, NegStride)) {
+      TransformedStores.insert(AdjacentStores.begin(), AdjacentStores.end());
+      Changed = true;
+    }
+  }
 
-  // See if we can optimize just this store in isolation.
-  return processLoopStridedStore(StorePtr, StoreSize, SI->getAlignment(),
-                                 StoredVal, SI, StoreEv, BECount, NegStride);
+  return Changed;
 }
 
 /// processLoopMemSet - See if this memset can be promoted to a large memset.
@@ -520,18 +655,21 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
   if (!SplatValue || !CurLoop->isLoopInvariant(SplatValue))
     return false;
 
+  SmallPtrSet<Instruction *, 1> MSIs;
+  MSIs.insert(MSI);
   return processLoopStridedStore(Pointer, (unsigned)SizeInBytes,
-                                 MSI->getAlignment(), SplatValue, MSI, Ev,
+                                 MSI->getAlignment(), SplatValue, MSI, MSIs, Ev,
                                  BECount, /*NegStride=*/false);
 }
 
 /// mayLoopAccessLocation - Return true if the specified loop might access the
 /// specified pointer location, which is a loop-strided access. The 'Access'
 /// argument specifies what the verboten forms of access are (read or write).
-static bool mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
-                                  const SCEV *BECount, unsigned StoreSize,
-                                  AliasAnalysis &AA,
-                                  Instruction *IgnoredStore) {
+static bool
+mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
+                      const SCEV *BECount, unsigned StoreSize,
+                      AliasAnalysis &AA,
+                      SmallPtrSetImpl<Instruction *> &IgnoredStores) {
   // Get the location that may be stored across the loop. Since the access is
   // strided positively through memory, we say that the modified location starts
   // at the pointer and has infinite size.
@@ -551,7 +689,8 @@ static bool mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
   for (Loop::block_iterator BI = L->block_begin(), E = L->block_end(); BI != E;
        ++BI)
     for (BasicBlock::iterator I = (*BI)->begin(), E = (*BI)->end(); I != E; ++I)
-      if (&*I != IgnoredStore && (AA.getModRefInfo(&*I, StoreLoc) & Access))
+      if (IgnoredStores.count(&*I) == 0 &&
+          (AA.getModRefInfo(&*I, StoreLoc) & Access))
         return true;
 
   return false;
@@ -574,7 +713,8 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
 /// transform this into a memset or memset_pattern in the loop preheader, do so.
 bool LoopIdiomRecognize::processLoopStridedStore(
     Value *DestPtr, unsigned StoreSize, unsigned StoreAlignment,
-    Value *StoredVal, Instruction *TheStore, const SCEVAddRecExpr *Ev,
+    Value *StoredVal, Instruction *TheStore,
+    SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev,
     const SCEV *BECount, bool NegStride) {
   Value *SplatValue = isBytewiseValue(StoredVal);
   Constant *PatternValue = nullptr;
@@ -609,7 +749,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
   Value *BasePtr =
       Expander.expandCodeFor(Start, DestInt8PtrTy, Preheader->getTerminator());
   if (mayLoopAccessLocation(BasePtr, MRI_ModRef, CurLoop, BECount, StoreSize,
-                            *AA, TheStore)) {
+                            *AA, Stores)) {
     Expander.clear();
     // If we generated new code for the base pointer, clean up.
     RecursivelyDeleteTriviallyDeadInstructions(BasePtr, TLI);
@@ -662,7 +802,8 @@ bool LoopIdiomRecognize::processLoopStridedStore(
 
   // Okay, the memset has been formed. Zap the original store and anything that
   // feeds into it.
-  deleteDeadInstruction(TheStore, TLI);
+  for (auto *I : Stores)
+    deleteDeadInstruction(I, TLI);
   ++NumMemSet;
   return true;
 }
@@ -714,8 +855,10 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,
   Value *StoreBasePtr = Expander.expandCodeFor(
       StrStart, Builder.getInt8PtrTy(StrAS), Preheader->getTerminator());
 
+  SmallPtrSet<Instruction *, 1> Stores;
+  Stores.insert(SI);
   if (mayLoopAccessLocation(StoreBasePtr, MRI_ModRef, CurLoop, BECount,
-                            StoreSize, *AA, SI)) {
+                            StoreSize, *AA, Stores)) {
     Expander.clear();
     // If we generated new code for the base pointer, clean up.
     RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI);
@@ -735,7 +878,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,
       LdStart, Builder.getInt8PtrTy(LdAS), Preheader->getTerminator());
 
   if (mayLoopAccessLocation(LoadBasePtr, MRI_Mod, CurLoop, BECount, StoreSize,
-                            *AA, SI)) {
+                            *AA, Stores)) {
     Expander.clear();
     // If we generated new code for the base pointer, clean up.
     RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr, TLI);
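
What this change recognizes, at the source level: stores that are individually too
narrow for the old single-store check (stride larger than the store size) but that
together cover every byte of the stride. This is exactly the struct/hand-unrolled
shape described by the TODO removed from the file header. A minimal C++ sketch
(the function and struct names are hypothetical, not from the patch):

#include <cstddef>

struct Point { float x, y; }; // two adjacent 4-byte fields

// Each store below has an 8-byte stride but writes only 4 bytes, so the old
// processLoopStore rejected both in isolation. processLoopStores() can now
// chain them via isConsecutiveAccess() and treat the pair as one 8-byte-per-
// iteration store whose size matches the stride.
void zeroPoints(Point *P, std::size_t N) {
  for (std::size_t I = 0; I != N; ++I) {
    P[I].x = 0.0f; // head of the consecutive-store chain
    P[I].y = 0.0f; // tail of the consecutive-store chain
  }
  // Conceptually becomes: memset(P, 0, N * sizeof(Point));
}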
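The pairing logic added in processLoopStores() is a quadratic search: for each
store it scans the other candidates, nearest indices first, for one that starts
exactly where it ends, records the result in Heads/Tails/ConsecutiveChain, and
then walks each chain from its head. A self-contained model of that search, with
stores reduced to plain (offset, size) pairs under the assumption of a common
stride (the real code compares SCEV add-recurrences instead):

#include <cstdio>
#include <map>
#include <set>
#include <vector>

struct Store { long Offset; unsigned Size; };

int main() {
  std::vector<Store> SL = {{0, 4}, {4, 4}, {8, 4}}; // three adjacent stores
  std::set<unsigned> Heads, Tails;
  std::map<unsigned, unsigned> ConsecutiveChain;

  for (unsigned i = 0, e = SL.size(); i < e; ++i) {
    // Mirror the IndexQueue ordering: indices i+1..e-1 first, then i-1..0.
    std::vector<unsigned> IndexQueue;
    for (unsigned j = i + 1; j < e; ++j) IndexQueue.push_back(j);
    for (unsigned j = i; j > 0; --j) IndexQueue.push_back(j - 1);
    for (unsigned k : IndexQueue) {
      // "Consecutive" means the second store begins where the first ends.
      if (SL[k].Offset == SL[i].Offset + (long)SL[i].Size) {
        Tails.insert(k);
        Heads.insert(i);
        ConsecutiveChain[i] = k;
        break;
      }
    }
  }

  // Walk each chain from a head that is not also a tail, summing the sizes.
  for (unsigned i : Heads) {
    if (Tails.count(i)) continue;
    unsigned Total = 0;
    for (unsigned Cur = i;;) {
      Total += SL[Cur].Size;
      auto It = ConsecutiveChain.find(Cur);
      if (It == ConsecutiveChain.end()) break;
      Cur = It->second;
    }
    // Prints: chain at offset 0 covers 12 bytes
    std::printf("chain at offset %ld covers %u bytes\n", SL[i].Offset, Total);
  }
}

The walk then compares the accumulated size against the stride, just as the new
code checks StoreSize against Stride before calling processLoopStridedStore().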
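The separate StoreRefsForMemsetPattern bucket exists because a stored value that
is not a repeated byte cannot feed plain memset; when the target's runtime
provides memset_pattern16 (the HasMemsetPattern flag, notably true on Darwin),
getMemSetPatternValue() widens the value to a 16-byte pattern instead. A rough
before/after sketch, using the real libc signature but a hypothetical loop, and
compiling only where that function exists:

#include <cstddef>
#include <string.h> // Apple platforms declare memset_pattern16() here

// 1.0 is not a byte splat, so isBytewiseValue() fails and the store is filed
// under StoreRefsForMemsetPattern rather than StoreRefsForMemset.
void setOnes(double *P, std::size_t N) {
  for (std::size_t I = 0; I != N; ++I)
    P[I] = 1.0;
}

// Roughly what processLoopStridedStore() emits in the loop preheader:
void setOnesLowered(double *P, std::size_t N) {
  const double Pattern[2] = {1.0, 1.0}; // 8-byte value widened to 16 bytes
  memset_pattern16(P, Pattern, N * sizeof(double));
}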