diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 187 |
1 files changed, 171 insertions, 16 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index c867eb6c44d..659a1e5e1e7 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -137,6 +137,11 @@ static cl::opt<unsigned> MaxSCEVCompareDepth( cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32)); +static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth( + "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, + cl::desc("Maximum depth of recursive SCEV operations implication analysis"), + cl::init(2)); + static cl::opt<unsigned> MaxValueCompareDepth( "scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), @@ -3418,6 +3423,10 @@ Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { return getDataLayout().getIntPtrType(Ty); } +Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const { + return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2; +} + const SCEV *ScalarEvolution::getCouldNotCompute() { return CouldNotCompute.get(); } @@ -8559,19 +8568,161 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, llvm_unreachable("covered switch fell through?!"); } +bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, + const SCEV *FoundRHS, + unsigned Depth) { + assert(getTypeSizeInBits(LHS->getType()) == + getTypeSizeInBits(RHS->getType()) && + "LHS and RHS have different sizes?"); + assert(getTypeSizeInBits(FoundLHS->getType()) == + getTypeSizeInBits(FoundRHS->getType()) && + "FoundLHS and FoundRHS have different sizes?"); + // We want to avoid hurting the compile time with analysis of too big trees. + if (Depth > MaxSCEVOperationsImplicationDepth) + return false; + // We only want to work with ICMP_SGT comparison so far. + // TODO: Extend to ICMP_UGT? + if (Pred == ICmpInst::ICMP_SLT) { + Pred = ICmpInst::ICMP_SGT; + std::swap(LHS, RHS); + std::swap(FoundLHS, FoundRHS); + } + if (Pred != ICmpInst::ICMP_SGT) + return false; + + auto GetOpFromSExt = [&](const SCEV *S) { + if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S)) + return Ext->getOperand(); + // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off + // the constant in some cases. + return S; + }; + + // Acquire values from extensions. + auto *OrigFoundLHS = FoundLHS; + LHS = GetOpFromSExt(LHS); + FoundLHS = GetOpFromSExt(FoundLHS); + + // Is the SGT predicate can be proved trivially or using the found context. + auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) { + return isKnownViaSimpleReasoning(ICmpInst::ICMP_SGT, S1, S2) || + isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS, + FoundRHS, Depth + 1); + }; + + if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) { + // We want to avoid creation of any new non-constant SCEV. Since we are + // going to compare the operands to RHS, we should be certain that we don't + // need any size extensions for this. So let's decline all cases when the + // sizes of types of LHS and RHS do not match. + // TODO: Maybe try to get RHS from sext to catch more cases? + if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType())) + return false; + + // Should not overflow. + if (!LHSAddExpr->hasNoSignedWrap()) + return false; + + auto *LL = LHSAddExpr->getOperand(0); + auto *LR = LHSAddExpr->getOperand(1); + auto *MinusOne = getNegativeSCEV(getOne(RHS->getType())); + + // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context. + auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) { + return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS); + }; + // Try to prove the following rule: + // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS). + // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS). + if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL)) + return true; + } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) { + Value *LL, *LR; + // FIXME: Once we have SDiv implemented, we can get rid of this matching. + using namespace llvm::PatternMatch; + if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) { + // Rules for division. + // We are going to perform some comparisons with Denominator and its + // derivative expressions. In general case, creating a SCEV for it may + // lead to a complex analysis of the entire graph, and in particular it + // can request trip count recalculation for the same loop. This would + // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid + // this, we only want to create SCEVs that are constants in this section. + // So we bail if Denominator is not a constant. + if (!isa<ConstantInt>(LR)) + return false; + + auto *Denominator = cast<SCEVConstant>(getSCEV(LR)); + + // We want to make sure that LHS = FoundLHS / Denominator. If it is so, + // then a SCEV for the numerator already exists and matches with FoundLHS. + auto *Numerator = getExistingSCEV(LL); + if (!Numerator || Numerator->getType() != FoundLHS->getType()) + return false; + + // Make sure that the numerator matches with FoundLHS and the denominator + // is positive. + if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator)) + return false; + + auto *DTy = Denominator->getType(); + auto *FRHSTy = FoundRHS->getType(); + if (DTy->isPointerTy() != FRHSTy->isPointerTy()) + // One of types is a pointer and another one is not. We cannot extend + // them properly to a wider type, so let us just reject this case. + // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help + // to avoid this check. + return false; + + // Given that: + // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0. + auto *WTy = getWiderType(DTy, FRHSTy); + auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy); + auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy); + + // Try to prove the following rule: + // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS). + // For example, given that FoundLHS > 2. It means that FoundLHS is at + // least 3. If we divide it by Denominator < 4, we will have at least 1. + auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2)); + if (isKnownNonPositive(RHS) && + IsSGTViaContext(FoundRHSExt, DenomMinusTwo)) + return true; + + // Try to prove the following rule: + // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS). + // For example, given that FoundLHS > -3. Then FoundLHS is at least -2. + // If we divide it by Denominator > 2, then: + // 1. If FoundLHS is negative, then the result is 0. + // 2. If FoundLHS is non-negative, then the result is non-negative. + // Anyways, the result is non-negative. + auto *MinusOne = getNegativeSCEV(getOne(WTy)); + auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt); + if (isKnownNegative(RHS) && + IsSGTViaContext(FoundRHSExt, NegDenomMinusOne)) + return true; + } + } + + return false; +} + +bool +ScalarEvolution::isKnownViaSimpleReasoning(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) || + IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || + IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) || + isKnownPredicateViaNoOverflow(Pred, LHS, RHS); +} + bool ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { - auto IsKnownPredicateFull = - [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { - return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) || - IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || - IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) || - isKnownPredicateViaNoOverflow(Pred, LHS, RHS); - }; - switch (Pred) { default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); case ICmpInst::ICMP_EQ: @@ -8581,30 +8732,34 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, break; case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: - if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: - if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS)) return true; break; } + // Maybe it can be proved via operations? + if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS)) + return true; + return false; } |