summaryrefslogtreecommitdiffstats
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Analysis/ScalarEvolution.cpp57
-rw-r--r--llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp28
-rw-r--r--llvm/lib/Transforms/Utils/LoopUnroll.cpp13
3 files changed, 65 insertions, 33 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 8c6ddffb87b..f03051ddb4a 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5424,6 +5424,10 @@ const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
return getBackedgeTakenInfo(L).getMax(this);
}
+bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
+ return getBackedgeTakenInfo(L).isMaxOrZero(this);
+}
+
/// Push PHI nodes in the header of the given loop onto the given Worklist.
static void
PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
@@ -5656,6 +5660,13 @@ ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const {
return getMax();
}
+bool ScalarEvolution::BackedgeTakenInfo::isMaxOrZero(ScalarEvolution *SE) const {
+ auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
+ return !ENT.hasAlwaysTruePredicate();
+ };
+ return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
+}
+
bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
ScalarEvolution *SE) const {
if (getMax() && getMax() != SE->getCouldNotCompute() &&
@@ -5675,8 +5686,8 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
SmallVectorImpl<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo>
&&ExitCounts,
- bool Complete, const SCEV *MaxCount)
- : MaxAndComplete(MaxCount, Complete) {
+ bool Complete, const SCEV *MaxCount, bool MaxOrZero)
+ : MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) {
typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo;
ExitNotTaken.reserve(ExitCounts.size());
std::transform(
@@ -5714,6 +5725,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
const SCEV *MustExitMaxBECount = nullptr;
const SCEV *MayExitMaxBECount = nullptr;
+ bool MustExitMaxOrZero = false;
// Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
// and compute maxBECount.
@@ -5746,9 +5758,10 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
// computable EL.MaxNotTaken.
if (EL.MaxNotTaken != getCouldNotCompute() && Latch &&
DT.dominates(ExitBB, Latch)) {
- if (!MustExitMaxBECount)
+ if (!MustExitMaxBECount) {
MustExitMaxBECount = EL.MaxNotTaken;
- else {
+ MustExitMaxOrZero = EL.MaxOrZero;
+ } else {
MustExitMaxBECount =
getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken);
}
@@ -5763,8 +5776,11 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
}
const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
(MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
+ // The loop backedge will be taken the maximum or zero times if there's
+ // a single exit that must be taken the maximum or zero times.
+ bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
- MaxBECount);
+ MaxBECount, MaxOrZero);
}
ScalarEvolution::ExitLimit
@@ -5901,7 +5917,8 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
!isa<SCEVCouldNotCompute>(BECount))
MaxBECount = BECount;
- return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates});
+ return ExitLimit(BECount, MaxBECount, false,
+ {&EL0.Predicates, &EL1.Predicates});
}
if (BO->getOpcode() == Instruction::Or) {
// Recurse on the operands of the or.
@@ -5940,7 +5957,8 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
BECount = EL0.ExactNotTaken;
}
- return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates});
+ return ExitLimit(BECount, MaxBECount, false,
+ {&EL0.Predicates, &EL1.Predicates});
}
}
@@ -6325,7 +6343,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
unsigned BitWidth = getTypeSizeInBits(RHS->getType());
const SCEV *UpperBound =
getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
- return ExitLimit(getCouldNotCompute(), UpperBound);
+ return ExitLimit(getCouldNotCompute(), UpperBound, false);
}
return getCouldNotCompute();
@@ -7121,7 +7139,8 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// should not accept a root of 2.
const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
if (Val->isZero())
- return ExitLimit(R1, R1, Predicates); // We found a quadratic root!
+ // We found a quadratic root!
+ return ExitLimit(R1, R1, false, Predicates);
}
}
return getCouldNotCompute();
@@ -7178,7 +7197,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
else
MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
: -CR.getUnsignedMin());
- return ExitLimit(Distance, MaxBECount, Predicates);
+ return ExitLimit(Distance, MaxBECount, false, Predicates);
}
// As a special case, handle the instance where Step is a positive power of
@@ -7233,7 +7252,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
const SCEV *Limit =
getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy);
- return ExitLimit(Limit, Limit, Predicates);
+ return ExitLimit(Limit, Limit, false, Predicates);
}
}
@@ -7246,14 +7265,14 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
loopHasNoAbnormalExits(AddRec->getLoop())) {
const SCEV *Exact =
getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
- return ExitLimit(Exact, Exact, Predicates);
+ return ExitLimit(Exact, Exact, false, Predicates);
}
// Then, try to solve the above equation provided that Start is constant.
if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) {
const SCEV *E = SolveLinEquationWithOverflow(
StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this);
- return ExitLimit(E, E, Predicates);
+ return ExitLimit(E, E, false, Predicates);
}
return getCouldNotCompute();
}
@@ -8695,14 +8714,16 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
}
const SCEV *MaxBECount;
+ bool MaxOrZero = false;
if (isa<SCEVConstant>(BECount))
MaxBECount = BECount;
- else if (isa<SCEVConstant>(BECountIfBackedgeTaken))
+ else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
// If we know exactly how many times the backedge will be taken if it's
// taken at least once, then the backedge count will either be that or
// zero.
MaxBECount = BECountIfBackedgeTaken;
- else {
+ MaxOrZero = true;
+ } else {
// Calculate the maximum backedge count based on the range of values
// permitted by Start, End, and Stride.
APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin()
@@ -8739,7 +8760,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount;
- return ExitLimit(BECount, MaxBECount, Predicates);
+ return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);
}
ScalarEvolution::ExitLimit
@@ -8816,7 +8837,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount;
- return ExitLimit(BECount, MaxBECount, Predicates);
+ return ExitLimit(BECount, MaxBECount, false, Predicates);
}
const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
@@ -9598,6 +9619,8 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
+ if (SE->isBackedgeTakenCountMaxOrZero(L))
+ OS << ", actual taken count either this or zero.";
} else {
OS << "Unpredictable max backedge-taken count. ";
}
diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
index e4b181b2c7c..b81cf842af8 100644
--- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
@@ -1000,14 +1000,22 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
if (Convergent)
UP.AllowRemainder = false;
- // Try to find the trip count upper bound if it is allowed and we cannot find
- // exact trip count.
- if (UP.UpperBound) {
- if (!TripCount) {
- MaxTripCount = SE->getSmallConstantMaxTripCount(L);
- // Only unroll with small upper bound.
- if (MaxTripCount > UnrollMaxUpperBound)
- MaxTripCount = 0;
+ // Try to find the trip count upper bound if we cannot find the exact trip
+ // count.
+ bool MaxOrZero = false;
+ if (!TripCount) {
+ MaxTripCount = SE->getSmallConstantMaxTripCount(L);
+ MaxOrZero = SE->isBackedgeTakenCountMaxOrZero(L);
+ // We can unroll by the upper bound amount if it's generally allowed or if
+ // we know that the loop is executed either the upper bound or zero times.
+ // (MaxOrZero unrolling keeps only the first loop test, so the number of
+ // loop tests remains the same compared to the non-unrolled version, whereas
+ // the generic upper bound unrolling keeps all but the last loop test so the
+ // number of loop tests goes up which may end up being worse on targets with
+ // constriained branch predictor resources so is controlled by an option.)
+ // In addition we only unroll small upper bounds.
+ if (!(UP.UpperBound || MaxOrZero) || MaxTripCount > UnrollMaxUpperBound) {
+ MaxTripCount = 0;
}
}
@@ -1025,8 +1033,8 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
// Unroll the loop.
if (!UnrollLoop(L, UP.Count, TripCount, UP.Force, UP.Runtime,
- UP.AllowExpensiveTripCount, UseUpperBound, TripMultiple, LI,
- SE, &DT, &AC, &ORE, PreserveLCSSA))
+ UP.AllowExpensiveTripCount, UseUpperBound, MaxOrZero,
+ TripMultiple, LI, SE, &DT, &AC, &ORE, PreserveLCSSA))
return false;
// If loop has an unroll count pragma or unrolled by explicitly set count
diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp
index 847224753f4..912f9e0abe7 100644
--- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp
@@ -189,7 +189,8 @@ static bool needToInsertPhisForLCSSA(Loop *L, std::vector<BasicBlock *> Blocks,
///
/// PreserveCondBr indicates whether the conditional branch of the LatchBlock
/// needs to be preserved. It is needed when we use trip count upper bound to
-/// fully unroll the loop.
+/// fully unroll the loop. If PreserveOnlyFirst is also set then only the first
+/// conditional branch needs to be preserved.
///
/// Similarly, TripMultiple divides the number of times that the LatchBlock may
/// execute without exiting the loop.
@@ -207,10 +208,10 @@ static bool needToInsertPhisForLCSSA(Loop *L, std::vector<BasicBlock *> Blocks,
/// DominatorTree if they are non-null.
bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
bool AllowRuntime, bool AllowExpensiveTripCount,
- bool PreserveCondBr, unsigned TripMultiple, LoopInfo *LI,
- ScalarEvolution *SE, DominatorTree *DT,
- AssumptionCache *AC, OptimizationRemarkEmitter *ORE,
- bool PreserveLCSSA) {
+ bool PreserveCondBr, bool PreserveOnlyFirst,
+ unsigned TripMultiple, LoopInfo *LI, ScalarEvolution *SE,
+ DominatorTree *DT, AssumptionCache *AC,
+ OptimizationRemarkEmitter *ORE, bool PreserveLCSSA) {
BasicBlock *Preheader = L->getLoopPreheader();
if (!Preheader) {
DEBUG(dbgs() << " Can't unroll; loop preheader-insertion failed.\n");
@@ -550,7 +551,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
assert(NeedConditional &&
"NeedCondition cannot be modified by both complete "
"unrolling and runtime unrolling");
- NeedConditional = (PreserveCondBr && j);
+ NeedConditional = (PreserveCondBr && j && !(PreserveOnlyFirst && i != 0));
} else if (j != BreakoutTrip && (TripMultiple == 0 || j % TripMultiple != 0)) {
// If we know the trip count or a multiple of it, we can safely use an
// unconditional branch for some iterations.
OpenPOWER on IntegriCloud