diff options
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r-- | llvm/lib/Transforms/Scalar/GuardWidening.cpp | 244 |
1 files changed, 244 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/llvm/lib/Transforms/Scalar/GuardWidening.cpp index 5ac4374038e..24be4f508cc 100644 --- a/llvm/lib/Transforms/Scalar/GuardWidening.cpp +++ b/llvm/lib/Transforms/Scalar/GuardWidening.cpp @@ -130,6 +130,55 @@ class GuardWideningImpl { bool widenCondCommon(Value *Cond0, Value *Cond1, Instruction *InsertPt, Value *&Result); + /// Represents a range check of the form \c Base + \c Offset u< \c Length, + /// with the constraint that \c Length is not negative. \c CheckInst is the + /// pre-existing instruction in the IR that computes the result of this range + /// check. + struct RangeCheck { + Value *Base; + ConstantInt *Offset; + Value *Length; + ICmpInst *CheckInst; + + RangeCheck() {} + + explicit RangeCheck(Value *Base, ConstantInt *Offset, Value *Length, + ICmpInst *CheckInst) + : Base(Base), Offset(Offset), Length(Length), CheckInst(CheckInst) {} + + void print(raw_ostream &OS, bool PrintTypes = false) { + OS << "Base: "; + Base->printAsOperand(OS, PrintTypes); + OS << " Offset: "; + Offset->printAsOperand(OS, PrintTypes); + OS << " Length: "; + Length->printAsOperand(OS, PrintTypes); + } + + LLVM_DUMP_METHOD void dump() { + print(dbgs()); + dbgs() << "\n"; + } + }; + + /// Parse \p CheckCond into a conjunction (logical-and) of range checks; and + /// append them to \p Checks. Returns true on success, may clobber \c Checks + /// on failure. + bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks) { + SmallPtrSet<Value *, 8> Visited; + return parseRangeChecks(CheckCond, Checks, Visited); + } + + bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks, + SmallPtrSetImpl<Value *> &Visited); + + /// Combine the checks in \p Checks into a smaller set of checks and append + /// them into \p CombinedChecks. Return true on success (i.e. all of checks + /// in \p Checks were combined into \p CombinedChecks). Clobbers \p Checks + /// and \p CombinedChecks on success and on failure. + bool combineRangeChecks(SmallVectorImpl<RangeCheck> &Checks, + SmallVectorImpl<RangeCheck> &CombinedChecks); + /// Can we compute the logical AND of \p Cond0 and \p Cond1 for the price of /// computing only one of the two expressions? bool isWideningCondProfitable(Value *Cond0, Value *Cond1) { @@ -386,6 +435,27 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, } } + { + SmallVector<GuardWideningImpl::RangeCheck, 4> Checks, CombinedChecks; + if (parseRangeChecks(Cond0, Checks) && parseRangeChecks(Cond1, Checks) && + combineRangeChecks(Checks, CombinedChecks)) { + if (InsertPt) { + Result = nullptr; + for (auto &RC : CombinedChecks) { + makeAvailableAt(RC.CheckInst, InsertPt); + if (Result) + Result = + BinaryOperator::CreateAnd(RC.CheckInst, Result, "", InsertPt); + else + Result = RC.CheckInst; + } + + Result->setName("wide.chk"); + } + return true; + } + } + // Base case -- just logical-and the two conditions together. if (InsertPt) { @@ -399,6 +469,180 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, return false; } +bool GuardWideningImpl::parseRangeChecks( + Value *CheckCond, SmallVectorImpl<GuardWideningImpl::RangeCheck> &Checks, + SmallPtrSetImpl<Value *> &Visited) { + if (!Visited.insert(CheckCond).second) + return true; + + using namespace llvm::PatternMatch; + + { + Value *AndLHS, *AndRHS; + if (match(CheckCond, m_And(m_Value(AndLHS), m_Value(AndRHS)))) + return parseRangeChecks(AndLHS, Checks) && + parseRangeChecks(AndRHS, Checks); + } + + auto *IC = dyn_cast<ICmpInst>(CheckCond); + if (!IC || !IC->getOperand(0)->getType()->isIntegerTy() || + (IC->getPredicate() != ICmpInst::ICMP_ULT && + IC->getPredicate() != ICmpInst::ICMP_UGT)) + return false; + + Value *CmpLHS = IC->getOperand(0), *CmpRHS = IC->getOperand(1); + if (IC->getPredicate() == ICmpInst::ICMP_UGT) + std::swap(CmpLHS, CmpRHS); + + auto &DL = IC->getModule()->getDataLayout(); + + GuardWideningImpl::RangeCheck Check; + Check.Base = CmpLHS; + Check.Offset = + cast<ConstantInt>(ConstantInt::getNullValue(CmpRHS->getType())); + Check.Length = CmpRHS; + Check.CheckInst = IC; + + if (!isKnownNonNegative(Check.Length, DL)) + return false; + + // What we have in \c Check now is a correct interpretation of \p CheckCond. + // Try to see if we can move some constant offsets into the \c Offset field. + + bool Changed; + + do { + Value *OpLHS; + ConstantInt *OpRHS; + Changed = false; + +#ifndef NDEBUG + auto *BaseInst = dyn_cast<Instruction>(Check.Base); + assert((!BaseInst || DT.isReachableFromEntry(BaseInst->getParent())) && + "Unreachable instruction?"); +#endif + + if (match(Check.Base, m_Add(m_Value(OpLHS), m_ConstantInt(OpRHS)))) { + Check.Base = OpLHS; + Check.Offset = + ConstantInt::get(Check.Offset->getContext(), + Check.Offset->getValue() + OpRHS->getValue()); + Changed = true; + } else if (match(Check.Base, m_Or(m_Value(OpLHS), m_ConstantInt(OpRHS)))) { + unsigned BitWidth = OpLHS->getType()->getScalarSizeInBits(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + computeKnownBits(OpLHS, KnownZero, KnownOne, DL); + if ((OpRHS->getValue() & KnownZero) == OpRHS->getValue()) { + Check.Base = OpLHS; + Check.Offset = + ConstantInt::get(Check.Offset->getContext(), + Check.Offset->getValue() + OpRHS->getValue()); + Changed = true; + } + } + } while (Changed); + + Checks.push_back(Check); + return true; +} + +bool GuardWideningImpl::combineRangeChecks( + SmallVectorImpl<GuardWideningImpl::RangeCheck> &Checks, + SmallVectorImpl<GuardWideningImpl::RangeCheck> &RangeChecksOut) { + unsigned OldCount = Checks.size(); + while (!Checks.empty()) { + Value *Base = Checks[0].Base; + Value *Length = Checks[0].Length; + auto ChecksStart = + remove_if(Checks, [&](GuardWideningImpl::RangeCheck &RC) { + return RC.Base == Base && RC.Length == Length; + }); + + unsigned CheckCount = std::distance(ChecksStart, Checks.end()); + assert(CheckCount != 0 && "We know we have at least one!"); + + if (CheckCount < 3) { + RangeChecksOut.insert(RangeChecksOut.end(), ChecksStart, Checks.end()); + Checks.erase(ChecksStart, Checks.end()); + continue; + } + + // CheckCount will typically be 3 here, but so far there has been no need to + // hard-code that fact. + + std::sort(ChecksStart, Checks.end(), + [&](GuardWideningImpl::RangeCheck &LHS, + GuardWideningImpl::RangeCheck &RHS) { + return LHS.Offset->getValue().slt(RHS.Offset->getValue()); + }); + + // Note: std::sort should not invalidate the ChecksStart iterator. + + ConstantInt *MinOffset = ChecksStart->Offset, + *MaxOffset = Checks.back().Offset; + + unsigned BitWidth = MaxOffset->getValue().getBitWidth(); + if ((MaxOffset->getValue() - MinOffset->getValue()) + .ugt(APInt::getSignedMinValue(BitWidth))) + return false; + + APInt MaxDiff = MaxOffset->getValue() - MinOffset->getValue(); + APInt HighOffset = MaxOffset->getValue(); + auto OffsetOK = [&](GuardWideningImpl::RangeCheck &RC) { + return (HighOffset - RC.Offset->getValue()).ult(MaxDiff); + }; + + if (MaxDiff.isMinValue() || + !std::all_of(std::next(ChecksStart), Checks.end(), OffsetOK)) + return false; + + // We have a series of f+1 checks as: + // + // I+k_0 u< L ... Chk_0 + // I_k_1 u< L ... Chk_1 + // ... + // I_k_f u< L ... Chk_(f+1) + // + // with forall i in [0,f): k_f-k_i u< k_f-k_0 ... Precond_0 + // k_f-k_0 u< INT_MIN+k_f ... Precond_1 + // k_f != k_0 ... Precond_2 + // + // Claim: + // Chk_0 AND Chk_(f+1) implies all the other checks + // + // Informal proof sketch: + // + // We will show that the integer range [I+k_0,I+k_f] does not unsigned-wrap + // (i.e. going from I+k_0 to I+k_f does not cross the -1,0 boundary) and + // thus I+k_f is the greatest unsigned value in that range. + // + // This combined with Ckh_(f+1) shows that everything in that range is u< L. + // Via Precond_0 we know that all of the indices in Chk_0 through Chk_(f+1) + // lie in [I+k_0,I+k_f], this proving our claim. + // + // To see that [I+k_0,I+k_f] is not a wrapping range, note that there are + // two possibilities: I+k_0 u< I+k_f or I+k_0 >u I+k_f (they can't be equal + // since k_0 != k_f). In the former case, [I+k_0,I+k_f] is not a wrapping + // range by definition, and the latter case is impossible: + // + // 0-----I+k_f---I+k_0----L---INT_MAX,INT_MIN------------------(-1) + // xxxxxx xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + // + // For Chk_0 to succeed, we'd have to have k_f-k_0 (the range highlighted + // with 'x' above) to be at least >u INT_MIN. + + RangeChecksOut.emplace_back(Base, MinOffset, Length, + ChecksStart->CheckInst); + RangeChecksOut.emplace_back(Base, MaxOffset, Length, + Checks.back().CheckInst); + + Checks.erase(ChecksStart, Checks.end()); + } + + assert(RangeChecksOut.size() <= OldCount && "We pessimized!"); + return RangeChecksOut.size() != OldCount; +} + PreservedAnalyses GuardWideningPass::run(Function &F, AnalysisManager<Function> &AM) { auto &DT = AM.getResult<DominatorTreeAnalysis>(F); |