summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Transforms/Scalar/GuardWidening.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Scalar/GuardWidening.cpp')
-rw-r--r--llvm/lib/Transforms/Scalar/GuardWidening.cpp244
1 files changed, 244 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/llvm/lib/Transforms/Scalar/GuardWidening.cpp
index 5ac4374038e..24be4f508cc 100644
--- a/llvm/lib/Transforms/Scalar/GuardWidening.cpp
+++ b/llvm/lib/Transforms/Scalar/GuardWidening.cpp
@@ -130,6 +130,55 @@ class GuardWideningImpl {
bool widenCondCommon(Value *Cond0, Value *Cond1, Instruction *InsertPt,
Value *&Result);
+ /// Represents a range check of the form \c Base + \c Offset u< \c Length,
+ /// with the constraint that \c Length is not negative. \c CheckInst is the
+ /// pre-existing instruction in the IR that computes the result of this range
+ /// check.
+ struct RangeCheck {
+ Value *Base;
+ ConstantInt *Offset;
+ Value *Length;
+ ICmpInst *CheckInst;
+
+ RangeCheck() {}
+
+ explicit RangeCheck(Value *Base, ConstantInt *Offset, Value *Length,
+ ICmpInst *CheckInst)
+ : Base(Base), Offset(Offset), Length(Length), CheckInst(CheckInst) {}
+
+ void print(raw_ostream &OS, bool PrintTypes = false) {
+ OS << "Base: ";
+ Base->printAsOperand(OS, PrintTypes);
+ OS << " Offset: ";
+ Offset->printAsOperand(OS, PrintTypes);
+ OS << " Length: ";
+ Length->printAsOperand(OS, PrintTypes);
+ }
+
+ LLVM_DUMP_METHOD void dump() {
+ print(dbgs());
+ dbgs() << "\n";
+ }
+ };
+
+ /// Parse \p CheckCond into a conjunction (logical-and) of range checks; and
+ /// append them to \p Checks. Returns true on success, may clobber \c Checks
+ /// on failure.
+ bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks) {
+ SmallPtrSet<Value *, 8> Visited;
+ return parseRangeChecks(CheckCond, Checks, Visited);
+ }
+
+ bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks,
+ SmallPtrSetImpl<Value *> &Visited);
+
+ /// Combine the checks in \p Checks into a smaller set of checks and append
+ /// them into \p CombinedChecks. Return true on success (i.e. all of checks
+ /// in \p Checks were combined into \p CombinedChecks). Clobbers \p Checks
+ /// and \p CombinedChecks on success and on failure.
+ bool combineRangeChecks(SmallVectorImpl<RangeCheck> &Checks,
+ SmallVectorImpl<RangeCheck> &CombinedChecks);
+
/// Can we compute the logical AND of \p Cond0 and \p Cond1 for the price of
/// computing only one of the two expressions?
bool isWideningCondProfitable(Value *Cond0, Value *Cond1) {
@@ -386,6 +435,27 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1,
}
}
+ {
+ SmallVector<GuardWideningImpl::RangeCheck, 4> Checks, CombinedChecks;
+ if (parseRangeChecks(Cond0, Checks) && parseRangeChecks(Cond1, Checks) &&
+ combineRangeChecks(Checks, CombinedChecks)) {
+ if (InsertPt) {
+ Result = nullptr;
+ for (auto &RC : CombinedChecks) {
+ makeAvailableAt(RC.CheckInst, InsertPt);
+ if (Result)
+ Result =
+ BinaryOperator::CreateAnd(RC.CheckInst, Result, "", InsertPt);
+ else
+ Result = RC.CheckInst;
+ }
+
+ Result->setName("wide.chk");
+ }
+ return true;
+ }
+ }
+
// Base case -- just logical-and the two conditions together.
if (InsertPt) {
@@ -399,6 +469,180 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1,
return false;
}
+bool GuardWideningImpl::parseRangeChecks(
+ Value *CheckCond, SmallVectorImpl<GuardWideningImpl::RangeCheck> &Checks,
+ SmallPtrSetImpl<Value *> &Visited) {
+ if (!Visited.insert(CheckCond).second)
+ return true;
+
+ using namespace llvm::PatternMatch;
+
+ {
+ Value *AndLHS, *AndRHS;
+ if (match(CheckCond, m_And(m_Value(AndLHS), m_Value(AndRHS))))
+ return parseRangeChecks(AndLHS, Checks) &&
+ parseRangeChecks(AndRHS, Checks);
+ }
+
+ auto *IC = dyn_cast<ICmpInst>(CheckCond);
+ if (!IC || !IC->getOperand(0)->getType()->isIntegerTy() ||
+ (IC->getPredicate() != ICmpInst::ICMP_ULT &&
+ IC->getPredicate() != ICmpInst::ICMP_UGT))
+ return false;
+
+ Value *CmpLHS = IC->getOperand(0), *CmpRHS = IC->getOperand(1);
+ if (IC->getPredicate() == ICmpInst::ICMP_UGT)
+ std::swap(CmpLHS, CmpRHS);
+
+ auto &DL = IC->getModule()->getDataLayout();
+
+ GuardWideningImpl::RangeCheck Check;
+ Check.Base = CmpLHS;
+ Check.Offset =
+ cast<ConstantInt>(ConstantInt::getNullValue(CmpRHS->getType()));
+ Check.Length = CmpRHS;
+ Check.CheckInst = IC;
+
+ if (!isKnownNonNegative(Check.Length, DL))
+ return false;
+
+ // What we have in \c Check now is a correct interpretation of \p CheckCond.
+ // Try to see if we can move some constant offsets into the \c Offset field.
+
+ bool Changed;
+
+ do {
+ Value *OpLHS;
+ ConstantInt *OpRHS;
+ Changed = false;
+
+#ifndef NDEBUG
+ auto *BaseInst = dyn_cast<Instruction>(Check.Base);
+ assert((!BaseInst || DT.isReachableFromEntry(BaseInst->getParent())) &&
+ "Unreachable instruction?");
+#endif
+
+ if (match(Check.Base, m_Add(m_Value(OpLHS), m_ConstantInt(OpRHS)))) {
+ Check.Base = OpLHS;
+ Check.Offset =
+ ConstantInt::get(Check.Offset->getContext(),
+ Check.Offset->getValue() + OpRHS->getValue());
+ Changed = true;
+ } else if (match(Check.Base, m_Or(m_Value(OpLHS), m_ConstantInt(OpRHS)))) {
+ unsigned BitWidth = OpLHS->getType()->getScalarSizeInBits();
+ APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
+ computeKnownBits(OpLHS, KnownZero, KnownOne, DL);
+ if ((OpRHS->getValue() & KnownZero) == OpRHS->getValue()) {
+ Check.Base = OpLHS;
+ Check.Offset =
+ ConstantInt::get(Check.Offset->getContext(),
+ Check.Offset->getValue() + OpRHS->getValue());
+ Changed = true;
+ }
+ }
+ } while (Changed);
+
+ Checks.push_back(Check);
+ return true;
+}
+
+bool GuardWideningImpl::combineRangeChecks(
+ SmallVectorImpl<GuardWideningImpl::RangeCheck> &Checks,
+ SmallVectorImpl<GuardWideningImpl::RangeCheck> &RangeChecksOut) {
+ unsigned OldCount = Checks.size();
+ while (!Checks.empty()) {
+ Value *Base = Checks[0].Base;
+ Value *Length = Checks[0].Length;
+ auto ChecksStart =
+ remove_if(Checks, [&](GuardWideningImpl::RangeCheck &RC) {
+ return RC.Base == Base && RC.Length == Length;
+ });
+
+ unsigned CheckCount = std::distance(ChecksStart, Checks.end());
+ assert(CheckCount != 0 && "We know we have at least one!");
+
+ if (CheckCount < 3) {
+ RangeChecksOut.insert(RangeChecksOut.end(), ChecksStart, Checks.end());
+ Checks.erase(ChecksStart, Checks.end());
+ continue;
+ }
+
+ // CheckCount will typically be 3 here, but so far there has been no need to
+ // hard-code that fact.
+
+ std::sort(ChecksStart, Checks.end(),
+ [&](GuardWideningImpl::RangeCheck &LHS,
+ GuardWideningImpl::RangeCheck &RHS) {
+ return LHS.Offset->getValue().slt(RHS.Offset->getValue());
+ });
+
+ // Note: std::sort should not invalidate the ChecksStart iterator.
+
+ ConstantInt *MinOffset = ChecksStart->Offset,
+ *MaxOffset = Checks.back().Offset;
+
+ unsigned BitWidth = MaxOffset->getValue().getBitWidth();
+ if ((MaxOffset->getValue() - MinOffset->getValue())
+ .ugt(APInt::getSignedMinValue(BitWidth)))
+ return false;
+
+ APInt MaxDiff = MaxOffset->getValue() - MinOffset->getValue();
+ APInt HighOffset = MaxOffset->getValue();
+ auto OffsetOK = [&](GuardWideningImpl::RangeCheck &RC) {
+ return (HighOffset - RC.Offset->getValue()).ult(MaxDiff);
+ };
+
+ if (MaxDiff.isMinValue() ||
+ !std::all_of(std::next(ChecksStart), Checks.end(), OffsetOK))
+ return false;
+
+ // We have a series of f+1 checks as:
+ //
+ // I+k_0 u< L ... Chk_0
+ // I_k_1 u< L ... Chk_1
+ // ...
+ // I_k_f u< L ... Chk_(f+1)
+ //
+ // with forall i in [0,f): k_f-k_i u< k_f-k_0 ... Precond_0
+ // k_f-k_0 u< INT_MIN+k_f ... Precond_1
+ // k_f != k_0 ... Precond_2
+ //
+ // Claim:
+ // Chk_0 AND Chk_(f+1) implies all the other checks
+ //
+ // Informal proof sketch:
+ //
+ // We will show that the integer range [I+k_0,I+k_f] does not unsigned-wrap
+ // (i.e. going from I+k_0 to I+k_f does not cross the -1,0 boundary) and
+ // thus I+k_f is the greatest unsigned value in that range.
+ //
+ // This combined with Ckh_(f+1) shows that everything in that range is u< L.
+ // Via Precond_0 we know that all of the indices in Chk_0 through Chk_(f+1)
+ // lie in [I+k_0,I+k_f], this proving our claim.
+ //
+ // To see that [I+k_0,I+k_f] is not a wrapping range, note that there are
+ // two possibilities: I+k_0 u< I+k_f or I+k_0 >u I+k_f (they can't be equal
+ // since k_0 != k_f). In the former case, [I+k_0,I+k_f] is not a wrapping
+ // range by definition, and the latter case is impossible:
+ //
+ // 0-----I+k_f---I+k_0----L---INT_MAX,INT_MIN------------------(-1)
+ // xxxxxx xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
+ //
+ // For Chk_0 to succeed, we'd have to have k_f-k_0 (the range highlighted
+ // with 'x' above) to be at least >u INT_MIN.
+
+ RangeChecksOut.emplace_back(Base, MinOffset, Length,
+ ChecksStart->CheckInst);
+ RangeChecksOut.emplace_back(Base, MaxOffset, Length,
+ Checks.back().CheckInst);
+
+ Checks.erase(ChecksStart, Checks.end());
+ }
+
+ assert(RangeChecksOut.size() <= OldCount && "We pessimized!");
+ return RangeChecksOut.size() != OldCount;
+}
+
PreservedAnalyses GuardWideningPass::run(Function &F,
AnalysisManager<Function> &AM) {
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
OpenPOWER on IntegriCloud