diff options
Diffstat (limited to 'llvm/lib/Analysis/LoopAccessAnalysis.cpp')
-rw-r--r-- | llvm/lib/Analysis/LoopAccessAnalysis.cpp | 253 |
1 files changed, 215 insertions, 38 deletions
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index b11cd7e84a6..65a258698e4 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -48,6 +48,13 @@ static cl::opt<unsigned, true> RuntimeMemoryCheckThreshold( cl::location(VectorizerParams::RuntimeMemoryCheckThreshold), cl::init(8)); unsigned VectorizerParams::RuntimeMemoryCheckThreshold; +/// \brief The maximum iterations used to merge memory checks +static cl::opt<unsigned> MemoryCheckMergeThreshold( + "memory-check-merge-threshold", cl::Hidden, + cl::desc("Maximum number of comparisons done when trying to merge " + "runtime memory checks. (default = 100)"), + cl::init(100)); + /// Maximum SIMD width. const unsigned VectorizerParams::MaxVectorWidth = 64; @@ -113,8 +120,8 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE, } void LoopAccessInfo::RuntimePointerCheck::insert( - ScalarEvolution *SE, Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, - unsigned ASId, const ValueToValueMap &Strides) { + Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, unsigned ASId, + const ValueToValueMap &Strides) { // Get the stride replaced scev. const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr); const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc); @@ -127,6 +134,136 @@ void LoopAccessInfo::RuntimePointerCheck::insert( IsWritePtr.push_back(WritePtr); DependencySetId.push_back(DepSetId); AliasSetId.push_back(ASId); + Exprs.push_back(Sc); +} + +bool LoopAccessInfo::RuntimePointerCheck::needsChecking( + const CheckingPtrGroup &M, const CheckingPtrGroup &N, + const SmallVectorImpl<int> *PtrPartition) const { + for (unsigned I = 0, EI = M.Members.size(); EI != I; ++I) + for (unsigned J = 0, EJ = N.Members.size(); EJ != J; ++J) + if (needsChecking(M.Members[I], N.Members[J], PtrPartition)) + return true; + return false; +} + +/// Compare \p I and \p J and return the minimum. 
+/// Return nullptr in case we couldn't find an answer. +static const SCEV *getMinFromExprs(const SCEV *I, const SCEV *J, + ScalarEvolution *SE) { + const SCEV *Diff = SE->getMinusSCEV(J, I); + const SCEVConstant *C = dyn_cast<const SCEVConstant>(Diff); + + if (!C) + return nullptr; + if (C->getValue()->isNegative()) + return J; + return I; +} + +bool LoopAccessInfo::RuntimePointerCheck::CheckingPtrGroup::addPointer( + unsigned Index) { + // Compare the starts and ends with the known minimum and maximum + // of this set. We need to know how we compare against the min/max + // of the set in order to be able to emit memchecks. + const SCEV *Min0 = getMinFromExprs(RtCheck.Starts[Index], Low, RtCheck.SE); + if (!Min0) + return false; + + const SCEV *Min1 = getMinFromExprs(RtCheck.Ends[Index], High, RtCheck.SE); + if (!Min1) + return false; + + // Update the low bound expression if we've found a new min value. + if (Min0 == RtCheck.Starts[Index]) + Low = RtCheck.Starts[Index]; + + // Update the high bound expression if we've found a new max value. + if (Min1 != RtCheck.Ends[Index]) + High = RtCheck.Ends[Index]; + + Members.push_back(Index); + return true; +} + +void LoopAccessInfo::RuntimePointerCheck::groupChecks( + MemoryDepChecker::DepCandidates &DepCands, + bool UseDependencies) { + // We build the groups from dependency candidates equivalence classes + // because: + // - We know that pointers in the same equivalence class share + // the same underlying object and therefore there is a chance + // that we can compare pointers + // - We wouldn't be able to merge two pointers for which we need + // to emit a memcheck. The classes in DepCands are already + // conveniently built such that no two pointers in the same + // class need checking against each other. 
+ + // We use the following (greedy) algorithm to construct the groups + // For every pointer in the equivalence class: + // For each existing group: + // - if the difference between this pointer and the min/max bounds + // of the group is a constant, then make the pointer part of the + // group and update the min/max bounds of that group as required. + + CheckingGroups.clear(); + + // If we don't have the dependency partitions, construct a new + // checking pointer group for each pointer. + if (!UseDependencies) { + for (unsigned I = 0; I < Pointers.size(); ++I) + CheckingGroups.push_back(CheckingPtrGroup(I, *this)); + return; + } + + unsigned TotalComparisons = 0; + + DenseMap<Value *, unsigned> PositionMap; + for (unsigned Pointer = 0; Pointer < Pointers.size(); ++Pointer) + PositionMap[Pointers[Pointer]] = Pointer; + + // Go through all equivalence classes, get the "pointer check groups" + // and add them to the overall solution. + for (auto DI = DepCands.begin(), DE = DepCands.end(); DI != DE; ++DI) { + if (!DI->isLeader()) + continue; + + SmallVector<CheckingPtrGroup, 2> Groups; + + for (auto MI = DepCands.member_begin(DI), ME = DepCands.member_end(); + MI != ME; ++MI) { + unsigned Pointer = PositionMap[MI->getPointer()]; + bool Merged = false; + + // Go through all the existing sets and see if we can find one + // which can include this pointer. + for (CheckingPtrGroup &Group : Groups) { + // Don't perform more than a certain amount of comparisons. + // This should limit the cost of grouping the pointers to something + // reasonable. If we do end up hitting this threshold, the algorithm + // will create separate groups for all remaining pointers. + if (TotalComparisons > MemoryCheckMergeThreshold) + break; + + TotalComparisons++; + + if (Group.addPointer(Pointer)) { + Merged = true; + break; + } + } + + if (!Merged) + // We couldn't add this pointer to any existing set or the threshold + // for the number of comparisons has been reached. 
Create a new group + // to hold the current pointer. + Groups.push_back(CheckingPtrGroup(Pointer, *this)); + } + + // We've computed the grouped checks for this partition. + // Save the results and continue with the next one. + std::copy(Groups.begin(), Groups.end(), std::back_inserter(CheckingGroups)); + } } bool LoopAccessInfo::RuntimePointerCheck::needsChecking( @@ -156,42 +293,71 @@ bool LoopAccessInfo::RuntimePointerCheck::needsChecking( void LoopAccessInfo::RuntimePointerCheck::print( raw_ostream &OS, unsigned Depth, const SmallVectorImpl<int> *PtrPartition) const { - unsigned NumPointers = Pointers.size(); - if (NumPointers == 0) - return; OS.indent(Depth) << "Run-time memory checks:\n"; + unsigned N = 0; - for (unsigned I = 0; I < NumPointers; ++I) - for (unsigned J = I + 1; J < NumPointers; ++J) - if (needsChecking(I, J, PtrPartition)) { - OS.indent(Depth) << N++ << ":\n"; - OS.indent(Depth + 2) << *Pointers[I]; - if (PtrPartition) - OS << " (Partition: " << (*PtrPartition)[I] << ")"; - OS << "\n"; - OS.indent(Depth + 2) << *Pointers[J]; - if (PtrPartition) - OS << " (Partition: " << (*PtrPartition)[J] << ")"; - OS << "\n"; + for (unsigned I = 0; I < CheckingGroups.size(); ++I) + for (unsigned J = I + 1; J < CheckingGroups.size(); ++J) + if (needsChecking(CheckingGroups[I], CheckingGroups[J], PtrPartition)) { + OS.indent(Depth) << "Check " << N++ << ":\n"; + OS.indent(Depth + 2) << "Comparing group " << I << ":\n"; + + for (unsigned K = 0; K < CheckingGroups[I].Members.size(); ++K) { + OS.indent(Depth + 2) << *Pointers[CheckingGroups[I].Members[K]] + << "\n"; + if (PtrPartition) + OS << " (Partition: " + << (*PtrPartition)[CheckingGroups[I].Members[K]] << ")" + << "\n"; + } + + OS.indent(Depth + 2) << "Against group " << J << ":\n"; + + for (unsigned K = 0; K < CheckingGroups[J].Members.size(); ++K) { + OS.indent(Depth + 2) << *Pointers[CheckingGroups[J].Members[K]] + << "\n"; + if (PtrPartition) + OS << " (Partition: " + << 
(*PtrPartition)[CheckingGroups[J].Members[K]] << ")" + << "\n"; + } } + + OS.indent(Depth) << "Grouped accesses:\n"; + for (unsigned I = 0; I < CheckingGroups.size(); ++I) { + OS.indent(Depth + 2) << "Group " << I << ":\n"; + OS.indent(Depth + 4) << "(Low: " << *CheckingGroups[I].Low + << " High: " << *CheckingGroups[I].High << ")\n"; + for (unsigned J = 0; J < CheckingGroups[I].Members.size(); ++J) { + OS.indent(Depth + 6) << "Member: " << *Exprs[CheckingGroups[I].Members[J]] + << "\n"; + } + } } unsigned LoopAccessInfo::RuntimePointerCheck::getNumberOfChecks( const SmallVectorImpl<int> *PtrPartition) const { - unsigned NumPointers = Pointers.size(); + + unsigned NumPartitions = CheckingGroups.size(); unsigned CheckCount = 0; - for (unsigned I = 0; I < NumPointers; ++I) - for (unsigned J = I + 1; J < NumPointers; ++J) - if (needsChecking(I, J, PtrPartition)) + for (unsigned I = 0; I < NumPartitions; ++I) + for (unsigned J = I + 1; J < NumPartitions; ++J) + if (needsChecking(CheckingGroups[I], CheckingGroups[J], PtrPartition)) CheckCount++; return CheckCount; } bool LoopAccessInfo::RuntimePointerCheck::needsAnyChecking( const SmallVectorImpl<int> *PtrPartition) const { - return getNumberOfChecks(PtrPartition) != 0; + unsigned NumPointers = Pointers.size(); + + for (unsigned I = 0; I < NumPointers; ++I) + for (unsigned J = I + 1; J < NumPointers; ++J) + if (needsChecking(I, J, PtrPartition)) + return true; + return false; } namespace { @@ -341,7 +507,7 @@ bool AccessAnalysis::canCheckPtrAtRT( // Each access has its own dependence set. 
DepId = RunningDepId++; - RtCheck.insert(SE, TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap); + RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap); DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n'); } else { @@ -387,6 +553,9 @@ bool AccessAnalysis::canCheckPtrAtRT( } } + if (NeedRTCheck && CanDoRT) + RtCheck.groupChecks(DepCands, IsDepCheckNeeded); + return CanDoRT; } @@ -1360,32 +1529,35 @@ std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeCheck( if (!PtrRtCheck.Need) return std::make_pair(nullptr, nullptr); - unsigned NumPointers = PtrRtCheck.Pointers.size(); - SmallVector<TrackingVH<Value> , 2> Starts; - SmallVector<TrackingVH<Value> , 2> Ends; + SmallVector<TrackingVH<Value>, 2> Starts; + SmallVector<TrackingVH<Value>, 2> Ends; LLVMContext &Ctx = Loc->getContext(); SCEVExpander Exp(*SE, DL, "induction"); Instruction *FirstInst = nullptr; - for (unsigned i = 0; i < NumPointers; ++i) { - Value *Ptr = PtrRtCheck.Pointers[i]; + for (unsigned i = 0; i < PtrRtCheck.CheckingGroups.size(); ++i) { + const RuntimePointerCheck::CheckingPtrGroup &CG = + PtrRtCheck.CheckingGroups[i]; + Value *Ptr = PtrRtCheck.Pointers[CG.Members[0]]; const SCEV *Sc = SE->getSCEV(Ptr); if (SE->isLoopInvariant(Sc, TheLoop)) { - DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" << - *Ptr <<"\n"); + DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" << *Ptr + << "\n"); Starts.push_back(Ptr); Ends.push_back(Ptr); } else { - DEBUG(dbgs() << "LAA: Adding RT check for range:" << *Ptr << '\n'); unsigned AS = Ptr->getType()->getPointerAddressSpace(); // Use this type for pointer arithmetic. 
Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS); + Value *Start = nullptr, *End = nullptr; - Value *Start = Exp.expandCodeFor(PtrRtCheck.Starts[i], PtrArithTy, Loc); - Value *End = Exp.expandCodeFor(PtrRtCheck.Ends[i], PtrArithTy, Loc); + DEBUG(dbgs() << "LAA: Adding RT check for range:\n"); + Start = Exp.expandCodeFor(CG.Low, PtrArithTy, Loc); + End = Exp.expandCodeFor(CG.High, PtrArithTy, Loc); + DEBUG(dbgs() << "Start: " << *CG.Low << " End: " << *CG.High << "\n"); Starts.push_back(Start); Ends.push_back(End); } @@ -1394,9 +1566,14 @@ std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeCheck( IRBuilder<> ChkBuilder(Loc); // Our instructions might fold to a constant. Value *MemoryRuntimeCheck = nullptr; - for (unsigned i = 0; i < NumPointers; ++i) { - for (unsigned j = i+1; j < NumPointers; ++j) { - if (!PtrRtCheck.needsChecking(i, j, PtrPartition)) + for (unsigned i = 0; i < PtrRtCheck.CheckingGroups.size(); ++i) { + for (unsigned j = i + 1; j < PtrRtCheck.CheckingGroups.size(); ++j) { + const RuntimePointerCheck::CheckingPtrGroup &CGI = + PtrRtCheck.CheckingGroups[i]; + const RuntimePointerCheck::CheckingPtrGroup &CGJ = + PtrRtCheck.CheckingGroups[j]; + + if (!PtrRtCheck.needsChecking(CGI, CGJ, PtrPartition)) continue; unsigned AS0 = Starts[i]->getType()->getPointerAddressSpace(); @@ -1447,8 +1624,8 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE, const TargetLibraryInfo *TLI, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, const ValueToValueMap &Strides) - : DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL), - TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), + : PtrRtCheck(SE), DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL), TLI(TLI), + AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1U), CanVecMem(false), StoreToLoopInvariantAddress(false) { if (canAnalyzeLoop()) |