diff options
Diffstat (limited to 'llvm/lib')
38 files changed, 1120 insertions, 500 deletions
diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp index bcb005ca689..9cfd02c0218 100644 --- a/llvm/lib/Analysis/BasicAliasAnalysis.cpp +++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -194,7 +195,9 @@ namespace { /// represented in the result. static Value *GetLinearExpression(Value *V, APInt &Scale, APInt &Offset, ExtensionKind &Extension, - const DataLayout &DL, unsigned Depth) { + const DataLayout &DL, unsigned Depth, + AssumptionTracker *AT, + DominatorTree *DT) { assert(V->getType()->isIntegerTy() && "Not an integer value"); // Limit our recursion depth. @@ -211,23 +214,24 @@ static Value *GetLinearExpression(Value *V, APInt &Scale, APInt &Offset, case Instruction::Or: // X|C == X+C if all the bits in C are unset in X. Otherwise we can't // analyze it. - if (!MaskedValueIsZero(BOp->getOperand(0), RHSC->getValue(), &DL)) + if (!MaskedValueIsZero(BOp->getOperand(0), RHSC->getValue(), &DL, 0, + AT, BOp, DT)) break; // FALL THROUGH. case Instruction::Add: V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, Extension, - DL, Depth+1); + DL, Depth+1, AT, DT); Offset += RHSC->getValue(); return V; case Instruction::Mul: V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, Extension, - DL, Depth+1); + DL, Depth+1, AT, DT); Offset *= RHSC->getValue(); Scale *= RHSC->getValue(); return V; case Instruction::Shl: V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, Extension, - DL, Depth+1); + DL, Depth+1, AT, DT); Offset <<= RHSC->getValue().getLimitedValue(); Scale <<= RHSC->getValue().getLimitedValue(); return V; @@ -248,7 +252,7 @@ static Value *GetLinearExpression(Value *V, APInt &Scale, APInt &Offset, Extension = isa<SExtInst>(V) ? EK_SignExt : EK_ZeroExt; Value *Result = GetLinearExpression(CastOp, Scale, Offset, Extension, - DL, Depth+1); + DL, Depth+1, AT, DT); Scale = Scale.zext(OldWidth); Offset = Offset.zext(OldWidth); @@ -278,7 +282,8 @@ static Value *GetLinearExpression(Value *V, APInt &Scale, APInt &Offset, static const Value * DecomposeGEPExpression(const Value *V, int64_t &BaseOffs, SmallVectorImpl<VariableGEPIndex> &VarIndices, - bool &MaxLookupReached, const DataLayout *DL) { + bool &MaxLookupReached, const DataLayout *DL, + AssumptionTracker *AT, DominatorTree *DT) { // Limit recursion depth to limit compile time in crazy cases. unsigned MaxLookup = MaxLookupSearchDepth; MaxLookupReached = false; @@ -309,7 +314,10 @@ DecomposeGEPExpression(const Value *V, int64_t &BaseOffs, // If it's not a GEP, hand it off to SimplifyInstruction to see if it // can come up with something. This matches what GetUnderlyingObject does. if (const Instruction *I = dyn_cast<Instruction>(V)) - // TODO: Get a DominatorTree and use it here. + // TODO: Get a DominatorTree and AssumptionTracker and use them here + // (these are both now available in this function, but this should be + // updated when GetUnderlyingObject is updated). TLI should be + // provided also. if (const Value *Simplified = SimplifyInstruction(const_cast<Instruction *>(I), DL)) { V = Simplified; @@ -368,7 +376,7 @@ DecomposeGEPExpression(const Value *V, int64_t &BaseOffs, // Use GetLinearExpression to decompose the index into a C1*V+C2 form. APInt IndexScale(Width, 0), IndexOffset(Width, 0); Index = GetLinearExpression(Index, IndexScale, IndexOffset, Extension, - *DL, 0); + *DL, 0, AT, DT); // The GEP index scale ("Scale") scales C1*V+C2, yielding (C1*V+C2)*Scale. // This gives us an aggregate computation of (C1*Scale)*V + C2*Scale. @@ -449,6 +457,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AliasAnalysis>(); + AU.addRequired<AssumptionTracker>(); AU.addRequired<TargetLibraryInfo>(); } @@ -571,6 +580,7 @@ char BasicAliasAnalysis::ID = 0; INITIALIZE_AG_PASS_BEGIN(BasicAliasAnalysis, AliasAnalysis, "basicaa", "Basic Alias Analysis (stateless AA impl)", false, true, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_AG_PASS_END(BasicAliasAnalysis, AliasAnalysis, "basicaa", "Basic Alias Analysis (stateless AA impl)", @@ -884,6 +894,11 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, bool GEP1MaxLookupReached; SmallVector<VariableGEPIndex, 4> GEP1VariableIndices; + AssumptionTracker *AT = &getAnalysis<AssumptionTracker>(); + DominatorTreeWrapperPass *DTWP = + getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr; + // If we have two gep instructions with must-alias or not-alias'ing base // pointers, figure out if the indexes to the GEP tell us anything about the // derived pointer. @@ -907,10 +922,10 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, SmallVector<VariableGEPIndex, 4> GEP2VariableIndices; const Value *GEP2BasePtr = DecomposeGEPExpression(GEP2, GEP2BaseOffset, GEP2VariableIndices, - GEP2MaxLookupReached, DL); + GEP2MaxLookupReached, DL, AT, DT); const Value *GEP1BasePtr = DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, - GEP1MaxLookupReached, DL); + GEP1MaxLookupReached, DL, AT, DT); // DecomposeGEPExpression and GetUnderlyingObject should return the // same result except when DecomposeGEPExpression has no DataLayout. if (GEP1BasePtr != UnderlyingV1 || GEP2BasePtr != UnderlyingV2) { @@ -939,14 +954,14 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, // about the relation of the resulting pointer. const Value *GEP1BasePtr = DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, - GEP1MaxLookupReached, DL); + GEP1MaxLookupReached, DL, AT, DT); int64_t GEP2BaseOffset; bool GEP2MaxLookupReached; SmallVector<VariableGEPIndex, 4> GEP2VariableIndices; const Value *GEP2BasePtr = DecomposeGEPExpression(GEP2, GEP2BaseOffset, GEP2VariableIndices, - GEP2MaxLookupReached, DL); + GEP2MaxLookupReached, DL, AT, DT); // DecomposeGEPExpression and GetUnderlyingObject should return the // same result except when DecomposeGEPExpression has no DataLayout. @@ -985,7 +1000,7 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, const Value *GEP1BasePtr = DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, - GEP1MaxLookupReached, DL); + GEP1MaxLookupReached, DL, AT, DT); // DecomposeGEPExpression and GetUnderlyingObject should return the // same result except when DecomposeGEPExpression has no DataLayout. diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 2642ffffd77..cd795fdb781 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -45,9 +45,13 @@ struct Query { const DataLayout *DL; const TargetLibraryInfo *TLI; const DominatorTree *DT; + AssumptionTracker *AT; + const Instruction *CxtI; Query(const DataLayout *DL, const TargetLibraryInfo *tli, - const DominatorTree *dt) : DL(DL), TLI(tli), DT(dt) {} + const DominatorTree *dt, AssumptionTracker *at = nullptr, + const Instruction *cxti = nullptr) + : DL(DL), TLI(tli), DT(dt), AT(at), CxtI(cxti) {} }; static Value *SimplifyAndInst(Value *, Value *, const Query &, unsigned); @@ -575,9 +579,10 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyAddInst(Op0, Op1, isNSW, isNUW, Query (DL, TLI, DT), - RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyAddInst(Op0, Op1, isNSW, isNUW, + Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } /// \brief Compute the base pointer and cumulative constant offsets for V. @@ -781,9 +786,10 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, Value *llvm::SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifySubInst(Op0, Op1, isNSW, isNUW, Query (DL, TLI, DT), - RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifySubInst(Op0, Op1, isNSW, isNUW, + Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } /// Given operands for an FAdd, see if we can fold the result. If not, this @@ -959,28 +965,37 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFAddInst(Op0, Op1, FMF, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyFAddInst(Op0, Op1, FMF, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFSubInst(Op0, Op1, FMF, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyFSubInst(Op0, Op1, FMF, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFMulInst(Op0, Op1, FMF, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyFMulInst(Op0, Op1, FMF, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyMulInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyMulInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// SimplifyDiv - Given operands for an SDiv or UDiv, see if we can @@ -1067,8 +1082,11 @@ static Value *SimplifySDivInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifySDivInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifySDivInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifySDivInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// SimplifyUDivInst - Given operands for a UDiv, see if we can @@ -1083,8 +1101,11 @@ static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyUDivInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyUDivInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } static Value *SimplifyFDivInst(Value *Op0, Value *Op1, const Query &Q, @@ -1102,8 +1123,11 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFDivInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyFDivInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// SimplifyRem - Given operands for an SRem or URem, see if we can @@ -1172,8 +1196,11 @@ static Value *SimplifySRemInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifySRemInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifySRemInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifySRemInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// SimplifyURemInst - Given operands for a URem, see if we can @@ -1188,8 +1215,11 @@ static Value *SimplifyURemInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyURemInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyURemInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } static Value *SimplifyFRemInst(Value *Op0, Value *Op1, const Query &, @@ -1207,8 +1237,11 @@ static Value *SimplifyFRemInst(Value *Op0, Value *Op1, const Query &, Value *llvm::SimplifyFRemInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFRemInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyFRemInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// isUndefShift - Returns true if a shift by \c Amount always yields undef. @@ -1296,8 +1329,9 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, Value *llvm::SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, Query (DL, TLI, DT), + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } @@ -1328,8 +1362,10 @@ static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, Value *llvm::SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyLShrInst(Op0, Op1, isExact, Query (DL, TLI, DT), + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyLShrInst(Op0, Op1, isExact, Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } @@ -1359,7 +1395,7 @@ static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, return X; // Arithmetic shifting an all-sign-bit value is a no-op. - unsigned NumSignBits = ComputeNumSignBits(Op0, Q.DL); + unsigned NumSignBits = ComputeNumSignBits(Op0, Q.DL, 0, Q.AT, Q.CxtI, Q.DT); if (NumSignBits == Op0->getType()->getScalarSizeInBits()) return Op0; @@ -1369,8 +1405,10 @@ static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, Value *llvm::SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyAShrInst(Op0, Op1, isExact, Query (DL, TLI, DT), + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyAShrInst(Op0, Op1, isExact, Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } @@ -1424,9 +1462,9 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, // A & (-A) = A if A is a power of two or zero. if (match(Op0, m_Neg(m_Specific(Op1))) || match(Op1, m_Neg(m_Specific(Op0)))) { - if (isKnownToBeAPowerOfTwo(Op0, /*OrZero*/true)) + if (isKnownToBeAPowerOfTwo(Op0, /*OrZero*/true, 0, Q.AT, Q.CxtI, Q.DT)) return Op0; - if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/true)) + if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/true, 0, Q.AT, Q.CxtI, Q.DT)) return Op1; } @@ -1464,8 +1502,10 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyAndInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyAndInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyAndInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// SimplifyOrInst - Given operands for an Or, see if we can @@ -1557,18 +1597,22 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, if ((C2->getValue() & (C2->getValue() + 1)) == 0 && // C2 == 0+1+ match(A, m_Add(m_Value(V1), m_Value(V2)))) { // Add commutes, try both ways. - if (V1 == B && MaskedValueIsZero(V2, C2->getValue())) + if (V1 == B && MaskedValueIsZero(V2, C2->getValue(), Q.DL, + 0, Q.AT, Q.CxtI, Q.DT)) return A; - if (V2 == B && MaskedValueIsZero(V1, C2->getValue())) + if (V2 == B && MaskedValueIsZero(V1, C2->getValue(), Q.DL, + 0, Q.AT, Q.CxtI, Q.DT)) return A; } // Or commutes, try both ways. if ((C1->getValue() & (C1->getValue() + 1)) == 0 && match(B, m_Add(m_Value(V1), m_Value(V2)))) { // Add commutes, try both ways. - if (V1 == A && MaskedValueIsZero(V2, C1->getValue())) + if (V1 == A && MaskedValueIsZero(V2, C1->getValue(), Q.DL, + 0, Q.AT, Q.CxtI, Q.DT)) return B; - if (V2 == A && MaskedValueIsZero(V1, C1->getValue())) + if (V2 == A && MaskedValueIsZero(V1, C1->getValue(), Q.DL, + 0, Q.AT, Q.CxtI, Q.DT)) return B; } } @@ -1585,8 +1629,10 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyOrInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyOrInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// SimplifyXorInst - Given operands for a Xor, see if we can @@ -1640,8 +1686,10 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyXorInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyXorInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyXorInst(Op0, Op1, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } static Type *GetCompareTy(Value *Op) { @@ -1895,40 +1943,46 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getTrue(ITy); case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_ULE: - if (isKnownNonZero(LHS, Q.DL)) + if (isKnownNonZero(LHS, Q.DL, 0, Q.AT, Q.CxtI, Q.DT)) return getFalse(ITy); break; case ICmpInst::ICMP_NE: case ICmpInst::ICMP_UGT: - if (isKnownNonZero(LHS, Q.DL)) + if (isKnownNonZero(LHS, Q.DL, 0, Q.AT, Q.CxtI, Q.DT)) return getTrue(ITy); break; case ICmpInst::ICMP_SLT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (LHSKnownNegative) return getTrue(ITy); if (LHSKnownNonNegative) return getFalse(ITy); break; case ICmpInst::ICMP_SLE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (LHSKnownNegative) return getTrue(ITy); - if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL)) + if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT)) return getFalse(ITy); break; case ICmpInst::ICMP_SGE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (LHSKnownNegative) return getFalse(ITy); if (LHSKnownNonNegative) return getTrue(ITy); break; case ICmpInst::ICMP_SGT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (LHSKnownNegative) return getFalse(ITy); - if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL)) + if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT)) return getTrue(ITy); break; } @@ -2224,10 +2278,12 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, uint32_t BitWidth = CI->getBitWidth(); APInt LHSKnownZero(BitWidth, 0); APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); APInt RHSKnownZero(BitWidth, 0); APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (((LHSKnownOne & RHSKnownZero) != 0) || ((LHSKnownZero & RHSKnownOne) != 0)) return (Pred == ICmpInst::ICMP_EQ) @@ -2329,7 +2385,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL); + ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (!KnownNonNegative) break; // fall-through @@ -2339,7 +2396,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getFalse(ITy); case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL); + ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (!KnownNonNegative) break; // fall-through @@ -2358,7 +2416,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL); + ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (!KnownNonNegative) break; // fall-through @@ -2368,7 +2427,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getTrue(ITy); case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL); + ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL, + 0, Q.AT, Q.CxtI, Q.DT); if (!KnownNonNegative) break; // fall-through @@ -2688,8 +2748,10 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Value *llvm::SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyICmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT), + const DominatorTree *DT, + AssumptionTracker *AT, + Instruction *CxtI) { + return ::SimplifyICmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } @@ -2785,8 +2847,10 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFCmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT), + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyFCmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } @@ -2824,9 +2888,11 @@ static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, Value *llvm::SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifySelectInst(Cond, TrueVal, FalseVal, Query (DL, TLI, DT), - RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifySelectInst(Cond, TrueVal, FalseVal, + Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } /// SimplifyGEPInst - Given operands for an GetElementPtrInst, see if we can @@ -2913,8 +2979,9 @@ static Value *SimplifyGEPInst(ArrayRef<Value *> Ops, const Query &Q, unsigned) { Value *llvm::SimplifyGEPInst(ArrayRef<Value *> Ops, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyGEPInst(Ops, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyGEPInst(Ops, Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } /// SimplifyInsertValueInst - Given operands for an InsertValueInst, see if we @@ -2950,8 +3017,11 @@ Value *llvm::SimplifyInsertValueInst(Value *Agg, Value *Val, ArrayRef<unsigned> Idxs, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyInsertValueInst(Agg, Val, Idxs, Query (DL, TLI, DT), + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyInsertValueInst(Agg, Val, Idxs, + Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } @@ -2998,8 +3068,11 @@ static Value *SimplifyTruncInst(Value *Op, Type *Ty, const Query &Q, unsigned) { Value *llvm::SimplifyTruncInst(Value *Op, Type *Ty, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyTruncInst(Op, Ty, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, + AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyTruncInst(Op, Ty, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } //=== Helper functions for higher up the class hierarchy. @@ -3071,8 +3144,10 @@ static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyBinOp(Opcode, LHS, RHS, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyBinOp(Opcode, LHS, RHS, Query (DL, TLI, DT, AT, CxtI), + RecursionLimit); } /// SimplifyCmpInst - Given operands for a CmpInst, see if we can @@ -3086,8 +3161,9 @@ static Value *SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, Value *llvm::SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyCmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT), + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyCmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT, AT, CxtI), RecursionLimit); } @@ -3162,23 +3238,26 @@ static Value *SimplifyCall(Value *V, IterTy ArgBegin, IterTy ArgEnd, Value *llvm::SimplifyCall(Value *V, User::op_iterator ArgBegin, User::op_iterator ArgEnd, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyCall(V, ArgBegin, ArgEnd, Query(DL, TLI, DT), + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyCall(V, ArgBegin, ArgEnd, Query(DL, TLI, DT, AT, CxtI), RecursionLimit); } Value *llvm::SimplifyCall(Value *V, ArrayRef<Value *> Args, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyCall(V, Args.begin(), Args.end(), Query(DL, TLI, DT), - RecursionLimit); + const DominatorTree *DT, AssumptionTracker *AT, + const Instruction *CxtI) { + return ::SimplifyCall(V, Args.begin(), Args.end(), + Query(DL, TLI, DT, AT, CxtI), RecursionLimit); } /// SimplifyInstruction - See if we can compute a simplified version of this /// instruction. If not, this returns null. Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { + const DominatorTree *DT, + AssumptionTracker *AT) { Value *Result; switch (I->getOpcode()) { @@ -3187,109 +3266,122 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout *DL, break; case Instruction::FAdd: Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT); + I->getFastMathFlags(), DL, TLI, DT, AT, I); break; case Instruction::Add: Result = SimplifyAddInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), cast<BinaryOperator>(I)->hasNoUnsignedWrap(), - DL, TLI, DT); + DL, TLI, DT, AT, I); break; case Instruction::FSub: Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT); + I->getFastMathFlags(), DL, TLI, DT, AT, I); break; case Instruction::Sub: Result = SimplifySubInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), cast<BinaryOperator>(I)->hasNoUnsignedWrap(), - DL, TLI, DT); + DL, TLI, DT, AT, I); break; case Instruction::FMul: Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT); + I->getFastMathFlags(), DL, TLI, DT, AT, I); break; case Instruction::Mul: - Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::SDiv: - Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::UDiv: - Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::FDiv: - Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::SRem: - Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::URem: - Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::FRem: - Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::Shl: Result = SimplifyShlInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), cast<BinaryOperator>(I)->hasNoUnsignedWrap(), - DL, TLI, DT); + DL, TLI, DT, AT, I); break; case Instruction::LShr: Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->isExact(), - DL, TLI, DT); + DL, TLI, DT, AT, I); break; case Instruction::AShr: Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->isExact(), - DL, TLI, DT); + DL, TLI, DT, AT, I); break; case Instruction::And: - Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::Or: - Result = SimplifyOrInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyOrInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, + AT, I); break; case Instruction::Xor: - Result = SimplifyXorInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyXorInst(I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::ICmp: Result = SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), - I->getOperand(0), I->getOperand(1), DL, TLI, DT); + I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::FCmp: Result = SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), - I->getOperand(0), I->getOperand(1), DL, TLI, DT); + I->getOperand(0), I->getOperand(1), + DL, TLI, DT, AT, I); break; case Instruction::Select: Result = SimplifySelectInst(I->getOperand(0), I->getOperand(1), - I->getOperand(2), DL, TLI, DT); + I->getOperand(2), DL, TLI, DT, AT, I); break; case Instruction::GetElementPtr: { SmallVector<Value*, 8> Ops(I->op_begin(), I->op_end()); - Result = SimplifyGEPInst(Ops, DL, TLI, DT); + Result = SimplifyGEPInst(Ops, DL, TLI, DT, AT, I); break; } case Instruction::InsertValue: { InsertValueInst *IV = cast<InsertValueInst>(I); Result = SimplifyInsertValueInst(IV->getAggregateOperand(), IV->getInsertedValueOperand(), - IV->getIndices(), DL, TLI, DT); + IV->getIndices(), DL, TLI, DT, AT, I); break; } case Instruction::PHI: - Result = SimplifyPHINode(cast<PHINode>(I), Query (DL, TLI, DT)); + Result = SimplifyPHINode(cast<PHINode>(I), Query (DL, TLI, DT, AT, I)); break; case Instruction::Call: { CallSite CS(cast<CallInst>(I)); Result = SimplifyCall(CS.getCalledValue(), CS.arg_begin(), CS.arg_end(), - DL, TLI, DT); + DL, TLI, DT, AT, I); break; } case Instruction::Trunc: - Result = SimplifyTruncInst(I->getOperand(0), I->getType(), DL, TLI, DT); + Result = SimplifyTruncInst(I->getOperand(0), I->getType(), DL, TLI, DT, + AT, I); break; } @@ -3313,7 +3405,8 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout *DL, static bool replaceAndRecursivelySimplifyImpl(Instruction *I, Value *SimpleV, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { + const DominatorTree *DT, + AssumptionTracker *AT) { bool Simplified = false; SmallSetVector<Instruction *, 8> Worklist; @@ -3340,7 +3433,7 @@ static bool replaceAndRecursivelySimplifyImpl(Instruction *I, Value *SimpleV, I = Worklist[Idx]; // See if this instruction simplifies. - SimpleV = SimplifyInstruction(I, DL, TLI, DT); + SimpleV = SimplifyInstruction(I, DL, TLI, DT, AT); if (!SimpleV) continue; @@ -3366,15 +3459,17 @@ static bool replaceAndRecursivelySimplifyImpl(Instruction *I, Value *SimpleV, bool llvm::recursivelySimplifyInstruction(Instruction *I, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return replaceAndRecursivelySimplifyImpl(I, nullptr, DL, TLI, DT); + const DominatorTree *DT, + AssumptionTracker *AT) { + return replaceAndRecursivelySimplifyImpl(I, nullptr, DL, TLI, DT, AT); } bool llvm::replaceAndRecursivelySimplify(Instruction *I, Value *SimpleV, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { + const DominatorTree *DT, + AssumptionTracker *AT) { assert(I != SimpleV && "replaceAndRecursivelySimplify(X,X) is not valid!"); assert(SimpleV && "Must provide a simplified value."); - return replaceAndRecursivelySimplifyImpl(I, SimpleV, DL, TLI, DT); + return replaceAndRecursivelySimplifyImpl(I, SimpleV, DL, TLI, DT, AT); } diff --git a/llvm/lib/Analysis/Lint.cpp b/llvm/lib/Analysis/Lint.cpp index eebe6b3e56b..48ea8885e31 100644 --- a/llvm/lib/Analysis/Lint.cpp +++ b/llvm/lib/Analysis/Lint.cpp @@ -37,6 +37,7 @@ #include "llvm/Analysis/Lint.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" @@ -101,6 +102,7 @@ namespace { public: Module *Mod; AliasAnalysis *AA; + AssumptionTracker *AT; DominatorTree *DT; const DataLayout *DL; TargetLibraryInfo *TLI; @@ -118,6 +120,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesAll(); AU.addRequired<AliasAnalysis>(); + AU.addRequired<AssumptionTracker>(); AU.addRequired<TargetLibraryInfo>(); AU.addRequired<DominatorTreeWrapperPass>(); } @@ -151,6 +154,7 @@ namespace { char Lint::ID = 0; INITIALIZE_PASS_BEGIN(Lint, "lint", "Statically lint-checks LLVM IR", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_AG_DEPENDENCY(AliasAnalysis) @@ -175,6 +179,7 @@ INITIALIZE_PASS_END(Lint, "lint", "Statically lint-checks LLVM IR", bool Lint::runOnFunction(Function &F) { Mod = F.getParent(); AA = &getAnalysis<AliasAnalysis>(); + AT = &getAnalysis<AssumptionTracker>(); DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; @@ -504,7 +509,8 @@ void Lint::visitShl(BinaryOperator &I) { "Undefined result: Shift count out of range", &I); } -static bool isZero(Value *V, const DataLayout *DL) { +static bool isZero(Value *V, const DataLayout *DL, DominatorTree *DT, + AssumptionTracker *AT) { // Assume undef could be zero. if (isa<UndefValue>(V)) return true; @@ -513,7 +519,8 @@ static bool isZero(Value *V, const DataLayout *DL) { if (!VecTy) { unsigned BitWidth = V->getType()->getIntegerBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(V, KnownZero, KnownOne, DL); + computeKnownBits(V, KnownZero, KnownOne, DL, + 0, AT, dyn_cast<Instruction>(V), DT); return KnownZero.isAllOnesValue(); } @@ -543,22 +550,22 @@ static bool isZero(Value *V, const DataLayout *DL) { } void Lint::visitSDiv(BinaryOperator &I) { - Assert1(!isZero(I.getOperand(1), DL), + Assert1(!isZero(I.getOperand(1), DL, DT, AT), "Undefined behavior: Division by zero", &I); } void Lint::visitUDiv(BinaryOperator &I) { - Assert1(!isZero(I.getOperand(1), DL), + Assert1(!isZero(I.getOperand(1), DL, DT, AT), "Undefined behavior: Division by zero", &I); } void Lint::visitSRem(BinaryOperator &I) { - Assert1(!isZero(I.getOperand(1), DL), + Assert1(!isZero(I.getOperand(1), DL, DT, AT), "Undefined behavior: Division by zero", &I); } void Lint::visitURem(BinaryOperator &I) { - Assert1(!isZero(I.getOperand(1), DL), + Assert1(!isZero(I.getOperand(1), DL, DT, AT), "Undefined behavior: Division by zero", &I); } @@ -678,7 +685,7 @@ Value *Lint::findValueImpl(Value *V, bool OffsetOk, // As a last resort, try SimplifyInstruction or constant folding. if (Instruction *Inst = dyn_cast<Instruction>(V)) { - if (Value *W = SimplifyInstruction(Inst, DL, TLI, DT)) + if (Value *W = SimplifyInstruction(Inst, DL, TLI, DT, AT)) return findValueImpl(W, OffsetOk, Visited); } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { if (Value *W = ConstantFoldConstantExpression(CE, DL, TLI)) diff --git a/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp b/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp index f7180aae69e..7582dd8a067 100644 --- a/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp +++ b/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/PHITransAddr.h" @@ -55,6 +56,7 @@ char MemoryDependenceAnalysis::ID = 0; // Register this pass... INITIALIZE_PASS_BEGIN(MemoryDependenceAnalysis, "memdep", "Memory Dependence Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_AG_DEPENDENCY(AliasAnalysis) INITIALIZE_PASS_END(MemoryDependenceAnalysis, "memdep", "Memory Dependence Analysis", false, true) @@ -83,11 +85,13 @@ void MemoryDependenceAnalysis::releaseMemory() { /// void MemoryDependenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); + AU.addRequired<AssumptionTracker>(); AU.addRequiredTransitive<AliasAnalysis>(); } bool MemoryDependenceAnalysis::runOnFunction(Function &) { AA = &getAnalysis<AliasAnalysis>(); + AT = &getAnalysis<AssumptionTracker>(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; DominatorTreeWrapperPass *DTWP = @@ -859,7 +863,7 @@ getNonLocalPointerDependency(const AliasAnalysis::Location &Loc, bool isLoad, "Can't get pointer deps of a non-pointer!"); Result.clear(); - PHITransAddr Address(const_cast<Value *>(Loc.Ptr), DL); + PHITransAddr Address(const_cast<Value *>(Loc.Ptr), DL, AT); // This is the set of blocks we've inspected, and the pointer we consider in // each block. Because of critical edges, we currently bail out if querying diff --git a/llvm/lib/Analysis/PHITransAddr.cpp b/llvm/lib/Analysis/PHITransAddr.cpp index bfe86425119..b3d060a1acd 100644 --- a/llvm/lib/Analysis/PHITransAddr.cpp +++ b/llvm/lib/Analysis/PHITransAddr.cpp @@ -228,7 +228,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, return GEP; // Simplify the GEP to handle 'gep x, 0' -> x etc. - if (Value *V = SimplifyGEPInst(GEPOps, DL, TLI, DT)) { + if (Value *V = SimplifyGEPInst(GEPOps, DL, TLI, DT, AT)) { for (unsigned i = 0, e = GEPOps.size(); i != e; ++i) RemoveInstInputs(GEPOps[i], InstInputs); @@ -283,7 +283,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, } // See if the add simplifies away. - if (Value *Res = SimplifyAddInst(LHS, RHS, isNSW, isNUW, DL, TLI, DT)) { + if (Value *Res = SimplifyAddInst(LHS, RHS, isNSW, isNUW, DL, TLI, DT, AT)) { // If we simplified the operands, the LHS is no longer an input, but Res // is. RemoveInstInputs(LHS, InstInputs); @@ -369,7 +369,7 @@ InsertPHITranslatedSubExpr(Value *InVal, BasicBlock *CurBB, SmallVectorImpl<Instruction*> &NewInsts) { // See if we have a version of this value already available and dominating // PredBB. If so, there is no need to insert a new instance of it. - PHITransAddr Tmp(InVal, DL); + PHITransAddr Tmp(InVal, DL, AT); if (!Tmp.PHITranslateValue(CurBB, PredBB, &DT)) return Tmp.getAddr(); diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index ff8fbfdd76e..00b485dedee 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -62,6 +62,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" @@ -113,6 +114,7 @@ VerifySCEV("verify-scev", INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution", "Scalar Evolution Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(LoopInfo) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) @@ -3258,7 +3260,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // PHI's incoming blocks are in a different loop, in which case doing so // risks breaking LCSSA form. Instcombine would normally zap these, but // it doesn't have DominatorTree information, so it may miss cases. - if (Value *V = SimplifyInstruction(PN, DL, TLI, DT)) + if (Value *V = SimplifyInstruction(PN, DL, TLI, DT, AT)) if (LI->replacementPreservesLCSSAForm(PN, V)) return getSCEV(V); @@ -3390,7 +3392,7 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { // For a SCEVUnknown, ask ValueTracking. unsigned BitWidth = getTypeSizeInBits(U->getType()); APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones); + computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AT, nullptr, DT); return Zeros.countTrailingOnes(); } @@ -3529,7 +3531,7 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) { if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { // For a SCEVUnknown, ask ValueTracking. APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones, DL); + computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AT, nullptr, DT); if (Ones == ~Zeros + 1) return setUnsignedRange(U, ConservativeResult); return setUnsignedRange(U, @@ -3681,7 +3683,7 @@ ScalarEvolution::getSignedRange(const SCEV *S) { // For a SCEVUnknown, ask ValueTracking. if (!U->getValue()->getType()->isIntegerTy() && !DL) return setSignedRange(U, ConservativeResult); - unsigned NS = ComputeNumSignBits(U->getValue(), DL); + unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AT, nullptr, DT); if (NS <= 1) return setSignedRange(U, ConservativeResult); return setSignedRange(U, ConservativeResult.intersectWith( @@ -3788,7 +3790,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { unsigned TZ = A.countTrailingZeros(); unsigned BitWidth = A.getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL); + computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL, + 0, AT, nullptr, DT); APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); @@ -7633,6 +7636,7 @@ ScalarEvolution::ScalarEvolution() bool ScalarEvolution::runOnFunction(Function &F) { this->F = &F; + AT = &getAnalysis<AssumptionTracker>(); LI = &getAnalysis<LoopInfo>(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; @@ -7673,6 +7677,7 @@ void ScalarEvolution::releaseMemory() { void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); + AU.addRequired<AssumptionTracker>(); AU.addRequiredTransitive<LoopInfo>(); AU.addRequiredTransitive<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfo>(); diff --git a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp index 968c619a48d..c6fa8f8e839 100644 --- a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp @@ -1711,7 +1711,7 @@ unsigned SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, // Fold constant phis. They may be congruent to other constant phis and // would confuse the logic below that expects proper IVs. - if (Value *V = SimplifyInstruction(Phi, SE.DL, SE.TLI, SE.DT)) { + if (Value *V = SimplifyInstruction(Phi, SE.DL, SE.TLI, SE.DT, SE.AT)) { Phi->replaceAllUsesWith(V); DeadInsts.push_back(Phi); ++NumElim; diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 6409b85b1a0..92db3772008 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" @@ -20,6 +21,7 @@ #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalVariable.h" @@ -46,17 +48,154 @@ static unsigned getBitWidth(Type *Ty, const DataLayout *TD) { return TD ? TD->getPointerTypeSizeInBits(Ty) : 0; } +// Many of these functions have internal versions that take an assumption +// exclusion set. This is because of the potential for mutual recursion to +// cause computeKnownBits to repeatedly visit the same assume intrinsic. The +// classic case of this is assume(x = y), which will attempt to determine +// bits in x from bits in y, which will attempt to determine bits in y from +// bits in x, etc. Regarding the mutual recursion, computeKnownBits can call +// isKnownNonZero, which calls computeKnownBits and ComputeSignBit and +// isKnownToBeAPowerOfTwo (all of which can call computeKnownBits), and so on. +typedef SmallPtrSet<const Value *, 8> ExclInvsSet; + +// Simplifying using an assume can only be done in a particular control-flow +// context (the context instruction provides that context). If an assume and +// the context instruction are not in the same block then the DT helps in +// figuring out if we can use it. +struct Query { + ExclInvsSet ExclInvs; + AssumptionTracker *AT; + const Instruction *CxtI; + const DominatorTree *DT; + + Query(AssumptionTracker *AT = nullptr, const Instruction *CxtI = nullptr, + const DominatorTree *DT = nullptr) + : AT(AT), CxtI(CxtI), DT(DT) {} + + Query(const Query &Q, const Value *NewExcl) + : ExclInvs(Q.ExclInvs), AT(Q.AT), CxtI(Q.CxtI), DT(Q.DT) { + ExclInvs.insert(NewExcl); + } +}; + +// Given the provided Value and, potentially, a context instruction, returned +// the preferred context instruction (if any). +static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) { + // If we've been provided with a context instruction, then use that (provided + // it has been inserted). + if (CxtI && CxtI->getParent()) + return CxtI; + + // If the value is really an already-inserted instruction, then use that. + CxtI = dyn_cast<Instruction>(V); + if (CxtI && CxtI->getParent()) + return CxtI; + + return nullptr; +} + +static void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, + const DataLayout *TD, unsigned Depth, + const Query &Q); + +void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, + const DataLayout *TD, unsigned Depth, + AssumptionTracker *AT, const Instruction *CxtI, + const DominatorTree *DT) { + ::computeKnownBits(V, KnownZero, KnownOne, TD, Depth, + Query(AT, safeCxtI(V, CxtI), DT)); +} + +static void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, + const DataLayout *TD, unsigned Depth, + const Query &Q); + +void llvm::ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, + const DataLayout *TD, unsigned Depth, + AssumptionTracker *AT, const Instruction *CxtI, + const DominatorTree *DT) { + ::ComputeSignBit(V, KnownZero, KnownOne, TD, Depth, + Query(AT, safeCxtI(V, CxtI), DT)); +} + +static bool isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, + const Query &Q); + +bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, + AssumptionTracker *AT, + const Instruction *CxtI, + const DominatorTree *DT) { + return ::isKnownToBeAPowerOfTwo(V, OrZero, Depth, + Query(AT, safeCxtI(V, CxtI), DT)); +} + +static bool isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth, + const Query &Q); + +bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth, + AssumptionTracker *AT, const Instruction *CxtI, + const DominatorTree *DT) { + return ::isKnownNonZero(V, TD, Depth, Query(AT, safeCxtI(V, CxtI), DT)); +} + +static bool MaskedValueIsZero(Value *V, const APInt &Mask, + const DataLayout *TD, unsigned Depth, + const Query &Q); + +bool llvm::MaskedValueIsZero(Value *V, const APInt &Mask, + const DataLayout *TD, unsigned Depth, + AssumptionTracker *AT, const Instruction *CxtI, + const DominatorTree *DT) { + return ::MaskedValueIsZero(V, Mask, TD, Depth, + Query(AT, safeCxtI(V, CxtI), DT)); +} + +static unsigned ComputeNumSignBits(Value *V, const DataLayout *TD, + unsigned Depth, const Query &Q); + +unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, + unsigned Depth, AssumptionTracker *AT, + const Instruction *CxtI, + const DominatorTree *DT) { + return ::ComputeNumSignBits(V, TD, Depth, Query(AT, safeCxtI(V, CxtI), DT)); +} + static void computeKnownBitsAddSub(bool Add, Value *Op0, Value *Op1, bool NSW, APInt &KnownZero, APInt &KnownOne, APInt &KnownZero2, APInt &KnownOne2, - const DataLayout *TD, unsigned Depth) { + const DataLayout *TD, unsigned Depth, + const Query &Q) { + if (!Add) { + if (ConstantInt *CLHS = dyn_cast<ConstantInt>(Op0)) { + // We know that the top bits of C-X are clear if X contains less bits + // than C (i.e. no wrap-around can happen). For example, 20-X is + // positive if we can prove that X is >= 0 and < 16. + if (!CLHS->getValue().isNegative()) { + unsigned BitWidth = KnownZero.getBitWidth(); + unsigned NLZ = (CLHS->getValue()+1).countLeadingZeros(); + // NLZ can't be BitWidth with no sign bit + APInt MaskV = APInt::getHighBitsSet(BitWidth, NLZ+1); + computeKnownBits(Op1, KnownZero2, KnownOne2, TD, Depth+1, Q); + + // If all of the MaskV bits are known to be zero, then we know the + // output top bits are zero, because we now know that the output is + // from [0-C]. + if ((KnownZero2 & MaskV) == MaskV) { + unsigned NLZ2 = CLHS->getValue().countLeadingZeros(); + // Top bits known zero. + KnownZero = APInt::getHighBitsSet(BitWidth, NLZ2); + } + } + } + } + unsigned BitWidth = KnownZero.getBitWidth(); // If an initial sequence of bits in the result is not needed, the // corresponding bits in the operands are not needed. APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); - llvm::computeKnownBits(Op0, LHSKnownZero, LHSKnownOne, TD, Depth+1); - llvm::computeKnownBits(Op1, KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(Op0, LHSKnownZero, LHSKnownOne, TD, Depth+1, Q); + computeKnownBits(Op1, KnownZero2, KnownOne2, TD, Depth+1, Q); // Carry in a 1 for a subtract, rather than a 0. APInt CarryIn(BitWidth, 0); @@ -104,10 +243,11 @@ static void computeKnownBitsAddSub(bool Add, Value *Op0, Value *Op1, bool NSW, static void computeKnownBitsMul(Value *Op0, Value *Op1, bool NSW, APInt &KnownZero, APInt &KnownOne, APInt &KnownZero2, APInt &KnownOne2, - const DataLayout *TD, unsigned Depth) { + const DataLayout *TD, unsigned Depth, + const Query &Q) { unsigned BitWidth = KnownZero.getBitWidth(); - computeKnownBits(Op1, KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(Op0, KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(Op1, KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(Op0, KnownZero2, KnownOne2, TD, Depth+1, Q); bool isKnownNegative = false; bool isKnownNonNegative = false; @@ -128,9 +268,9 @@ static void computeKnownBitsMul(Value *Op0, Value *Op1, bool NSW, // negative or zero. if (!isKnownNonNegative) isKnownNegative = (isKnownNegativeOp1 && isKnownNonNegativeOp0 && - isKnownNonZero(Op0, TD, Depth)) || + isKnownNonZero(Op0, TD, Depth, Q)) || (isKnownNegativeOp0 && isKnownNonNegativeOp1 && - isKnownNonZero(Op1, TD, Depth)); + isKnownNonZero(Op1, TD, Depth, Q)); } } @@ -182,6 +322,198 @@ void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges, KnownZero = APInt::getHighBitsSet(BitWidth, MinLeadingZeros); } +static bool isEphemeralValueOf(Instruction *I, const Value *E) { + SmallVector<const Value *, 16> WorkSet(1, I); + SmallPtrSet<const Value *, 32> Visited; + SmallPtrSet<const Value *, 16> EphValues; + + while (!WorkSet.empty()) { + const Value *V = WorkSet.pop_back_val(); + if (!Visited.insert(V)) + continue; + + // If all uses of this value are ephemeral, then so is this value. + bool FoundNEUse = false; + for (const User *I : V->users()) + if (!EphValues.count(I)) { + FoundNEUse = true; + break; + } + + if (!FoundNEUse) { + if (V == E) + return true; + + EphValues.insert(V); + if (const User *U = dyn_cast<User>(V)) + for (User::const_op_iterator J = U->op_begin(), JE = U->op_end(); + J != JE; ++J) { + if (isSafeToSpeculativelyExecute(*J)) + WorkSet.push_back(*J); + } + } + } + + return false; +} + +// Is this an intrinsic that cannot be speculated but also cannot trap? +static bool isAssumeLikeIntrinsic(const Instruction *I) { + if (const CallInst *CI = dyn_cast<CallInst>(I)) + if (Function *F = CI->getCalledFunction()) + switch (F->getIntrinsicID()) { + default: break; + // FIXME: This list is repeated from NoTTI::getIntrinsicCost. + case Intrinsic::assume: + case Intrinsic::dbg_declare: + case Intrinsic::dbg_value: + case Intrinsic::invariant_start: + case Intrinsic::invariant_end: + case Intrinsic::lifetime_start: + case Intrinsic::lifetime_end: + case Intrinsic::objectsize: + case Intrinsic::ptr_annotation: + case Intrinsic::var_annotation: + return true; + } + + return false; +} + +static bool isValidAssumeForContext(Value *V, const Query &Q, + const DataLayout *DL) { + Instruction *Inv = cast<Instruction>(V); + + // There are two restrictions on the use of an assume: + // 1. The assume must dominate the context (or the control flow must + // reach the assume whenever it reaches the context). + // 2. The context must not be in the assume's set of ephemeral values + // (otherwise we will use the assume to prove that the condition + // feeding the assume is trivially true, thus causing the removal of + // the assume). + + if (Q.DT) { + if (Q.DT->dominates(Inv, Q.CxtI)) { + return true; + } else if (Inv->getParent() == Q.CxtI->getParent()) { + // The context comes first, but they're both in the same block. Make sure + // there is nothing in between that might interrupt the control flow. + for (BasicBlock::const_iterator I = + std::next(BasicBlock::const_iterator(Q.CxtI)), + IE(Inv); I != IE; ++I) + if (!isSafeToSpeculativelyExecute(I, DL) && + !isAssumeLikeIntrinsic(I)) + return false; + + return !isEphemeralValueOf(Inv, Q.CxtI); + } + + return false; + } + + // When we don't have a DT, we do a limited search... + if (Inv->getParent() == Q.CxtI->getParent()->getSinglePredecessor()) { + return true; + } else if (Inv->getParent() == Q.CxtI->getParent()) { + // Search forward from the assume until we reach the context (or the end + // of the block); the common case is that the assume will come first. + for (BasicBlock::iterator I = std::next(BasicBlock::iterator(Inv)), + IE = Inv->getParent()->end(); I != IE; ++I) + if (I == Q.CxtI) + return true; + + // The context must come first... + for (BasicBlock::const_iterator I = + std::next(BasicBlock::const_iterator(Q.CxtI)), + IE(Inv); I != IE; ++I) + if (!isSafeToSpeculativelyExecute(I, DL) && + !isAssumeLikeIntrinsic(I)) + return false; + + return !isEphemeralValueOf(Inv, Q.CxtI); + } + + return false; +} + +bool llvm::isValidAssumeForContext(const Instruction *I, + const Instruction *CxtI, + const DataLayout *DL, + const DominatorTree *DT) { + return ::isValidAssumeForContext(const_cast<Instruction*>(I), + Query(nullptr, CxtI, DT), DL); +} + +template<typename LHS, typename RHS> +inline match_combine_or<CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>, + CmpClass_match<RHS, LHS, ICmpInst, ICmpInst::Predicate>> +m_c_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) { + return m_CombineOr(m_ICmp(Pred, L, R), m_ICmp(Pred, R, L)); +} + +template<typename LHS, typename RHS> +inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::And>, + BinaryOp_match<RHS, LHS, Instruction::And>> +m_c_And(const LHS &L, const RHS &R) { + return m_CombineOr(m_And(L, R), m_And(R, L)); +} + +static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, + APInt &KnownOne, + const DataLayout *DL, + unsigned Depth, const Query &Q) { + // Use of assumptions is context-sensitive. If we don't have a context, we + // cannot use them! + if (!Q.AT || !Q.CxtI) + return; + + unsigned BitWidth = KnownZero.getBitWidth(); + + Function *F = const_cast<Function*>(Q.CxtI->getParent()->getParent()); + for (auto &CI : Q.AT->assumptions(F)) { + CallInst *I = CI; + if (Q.ExclInvs.count(I)) + continue; + + if (match(I, m_Intrinsic<Intrinsic::assume>(m_Specific(V))) && + isValidAssumeForContext(I, Q, DL)) { + assert(BitWidth == 1 && "assume operand is not i1?"); + KnownZero.clearAllBits(); + KnownOne.setAllBits(); + return; + } + + Value *A, *B; + auto m_V = m_CombineOr(m_Specific(V), + m_CombineOr(m_PtrToInt(m_Specific(V)), + m_BitCast(m_Specific(V)))); + + CmpInst::Predicate Pred; + // assume(v = a) + if (match(I, m_Intrinsic<Intrinsic::assume>( + m_c_ICmp(Pred, m_V, m_Value(A)))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + KnownZero |= RHSKnownZero; + KnownOne |= RHSKnownOne; + // assume(v & b = a) + } else if (match(I, m_Intrinsic<Intrinsic::assume>( + m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A)))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + APInt MaskKnownZero(BitWidth, 0), MaskKnownOne(BitWidth, 0); + computeKnownBits(B, MaskKnownZero, MaskKnownOne, DL, Depth+1, Query(Q, I)); + + // For those bits in the mask that are known to be one, we can propagate + // known bits from the RHS to V. + KnownZero |= RHSKnownZero & MaskKnownOne; + KnownOne |= RHSKnownOne & MaskKnownOne; + } + } +} + /// Determine which bits of V are known to be either zero or one and return /// them in the KnownZero/KnownOne bit sets. /// @@ -197,8 +529,9 @@ void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges, /// where V is a vector, known zero, and known one values are the /// same width as the vector element, and the bit is set only if it is true /// for all of the elements in the vector. -void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, - const DataLayout *TD, unsigned Depth) { +void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, + const DataLayout *TD, unsigned Depth, + const Query &Q) { assert(V && "No Value?"); assert(Depth <= MaxDepth && "Limit Search Depth"); unsigned BitWidth = KnownZero.getBitWidth(); @@ -274,7 +607,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, if (GA->mayBeOverridden()) { KnownZero.clearAllBits(); KnownOne.clearAllBits(); } else { - computeKnownBits(GA->getAliasee(), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(GA->getAliasee(), KnownZero, KnownOne, TD, Depth+1, Q); } return; } @@ -291,6 +624,10 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, if (Align) KnownZero = APInt::getLowBitsSet(BitWidth, countTrailingZeros(Align)); + + // Don't give up yet... there might be an assumption that provides more + // information... + computeKnownBitsFromAssume(V, KnownZero, KnownOne, TD, Depth, Q); return; } @@ -300,6 +637,9 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, if (Depth == MaxDepth) return; // Limit search depth. + // Check whether a nearby assume intrinsic can determine some known bits. + computeKnownBitsFromAssume(V, KnownZero, KnownOne, TD, Depth, Q); + Operator *I = dyn_cast<Operator>(V); if (!I) return; @@ -312,8 +652,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; case Instruction::And: { // If either the LHS or the RHS are Zero, the result is zero. - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1, Q); // Output known-1 bits are only known if set in both the LHS & RHS. KnownOne &= KnownOne2; @@ -322,8 +662,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; } case Instruction::Or: { - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1, Q); // Output known-0 bits are only known if clear in both the LHS & RHS. KnownZero &= KnownZero2; @@ -332,8 +672,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; } case Instruction::Xor: { - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1, Q); // Output known-0 bits are known if clear or set in both the LHS & RHS. APInt KnownZeroOut = (KnownZero & KnownZero2) | (KnownOne & KnownOne2); @@ -345,19 +685,20 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, case Instruction::Mul: { bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, - KnownZero, KnownOne, KnownZero2, KnownOne2, TD, Depth); + KnownZero, KnownOne, KnownZero2, KnownOne2, TD, + Depth, Q); break; } case Instruction::UDiv: { // For the purposes of computing leading zeros we can conservatively // treat a udiv as a logical right shift by the power of 2 known to // be less than the denominator. - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1, Q); unsigned LeadZ = KnownZero2.countLeadingOnes(); KnownOne2.clearAllBits(); KnownZero2.clearAllBits(); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1, Q); unsigned RHSUnknownLeadingOnes = KnownOne2.countLeadingZeros(); if (RHSUnknownLeadingOnes != BitWidth) LeadZ = std::min(BitWidth, @@ -367,9 +708,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; } case Instruction::Select: - computeKnownBits(I->getOperand(2), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, - Depth+1); + computeKnownBits(I->getOperand(2), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1, Q); // Only known if known in both the LHS and RHS. KnownOne &= KnownOne2; @@ -405,7 +745,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, assert(SrcBitWidth && "SrcBitWidth can't be zero"); KnownZero = KnownZero.zextOrTrunc(SrcBitWidth); KnownOne = KnownOne.zextOrTrunc(SrcBitWidth); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); KnownZero = KnownZero.zextOrTrunc(BitWidth); KnownOne = KnownOne.zextOrTrunc(BitWidth); // Any top bits are known to be zero. @@ -419,7 +759,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // TODO: For now, not handling conversions like: // (bitcast i64 %x to <2 x i32>) !I->getType()->isVectorTy()) { - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); break; } break; @@ -430,7 +770,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, KnownZero = KnownZero.trunc(SrcBitWidth); KnownOne = KnownOne.trunc(SrcBitWidth); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); KnownZero = KnownZero.zext(BitWidth); KnownOne = KnownOne.zext(BitWidth); @@ -446,7 +786,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // (shl X, C1) & C2 == 0 iff (X & C2 >>u C1) == 0 if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); KnownZero <<= ShiftAmt; KnownOne <<= ShiftAmt; KnownZero |= APInt::getLowBitsSet(BitWidth, ShiftAmt); // low bits known 0 @@ -460,7 +800,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); // Unsigned shift right. - computeKnownBits(I->getOperand(0), KnownZero,KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero,KnownOne, TD, Depth+1, Q); KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); KnownOne = APIntOps::lshr(KnownOne, ShiftAmt); // high bits known zero. @@ -475,7 +815,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); // Signed shift right. - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); KnownOne = APIntOps::lshr(KnownOne, ShiftAmt); @@ -491,14 +831,14 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, KnownZero, KnownOne, KnownZero2, KnownOne2, TD, - Depth); + Depth, Q); break; } case Instruction::Add: { bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, KnownZero, KnownOne, KnownZero2, KnownOne2, TD, - Depth); + Depth, Q); break; } case Instruction::SRem: @@ -506,7 +846,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, APInt RA = Rem->getValue().abs(); if (RA.isPowerOf2()) { APInt LowBits = RA - 1; - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, + Depth+1, Q); // The low bits of the first operand are unchanged by the srem. KnownZero = KnownZero2 & LowBits; @@ -531,7 +872,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, if (KnownZero.isNonNegative()) { APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, TD, - Depth+1); + Depth+1, Q); // If it's known zero, our sign bit is also zero. if (LHSKnownZero.isNegative()) KnownZero.setBit(BitWidth - 1); @@ -544,7 +885,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, if (RA.isPowerOf2()) { APInt LowBits = (RA - 1); computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, - Depth+1); + Depth+1, Q); KnownZero |= ~LowBits; KnownOne &= LowBits; break; @@ -553,8 +894,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // Since the result is less than or equal to either operand, any leading // zero bits in either operand must also exist in the result. - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1, Q); unsigned Leaders = std::max(KnownZero.countLeadingOnes(), KnownZero2.countLeadingOnes()); @@ -578,7 +919,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // to determine if we can prove known low zero bits. APInt LocalKnownZero(BitWidth, 0), LocalKnownOne(BitWidth, 0); computeKnownBits(I->getOperand(0), LocalKnownZero, LocalKnownOne, TD, - Depth+1); + Depth+1, Q); unsigned TrailZ = LocalKnownZero.countTrailingOnes(); gep_type_iterator GTI = gep_type_begin(I); @@ -614,7 +955,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, unsigned GEPOpiBits = Index->getType()->getScalarSizeInBits(); uint64_t TypeSize = TD ? TD->getTypeAllocSize(IndexedTy) : 1; LocalKnownZero = LocalKnownOne = APInt(GEPOpiBits, 0); - computeKnownBits(Index, LocalKnownZero, LocalKnownOne, TD, Depth+1); + computeKnownBits(Index, LocalKnownZero, LocalKnownOne, TD, Depth+1, Q); TrailZ = std::min(TrailZ, unsigned(countTrailingZeros(TypeSize) + LocalKnownZero.countTrailingOnes())); @@ -656,11 +997,11 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; // Ok, we have a PHI of the form L op= R. Check for low // zero bits. - computeKnownBits(R, KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(R, KnownZero2, KnownOne2, TD, Depth+1, Q); // We need to take the minimum number of known bits APInt KnownZero3(KnownZero), KnownOne3(KnownOne); - computeKnownBits(L, KnownZero3, KnownOne3, TD, Depth+1); + computeKnownBits(L, KnownZero3, KnownOne3, TD, Depth+1, Q); KnownZero = APInt::getLowBitsSet(BitWidth, std::min(KnownZero2.countTrailingOnes(), @@ -692,7 +1033,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // Recurse, but cap the recursion to one level, because we don't // want to waste time spinning around in loops. computeKnownBits(P->getIncomingValue(i), KnownZero2, KnownOne2, TD, - MaxDepth-1); + MaxDepth-1, Q); KnownZero &= KnownZero2; KnownOne &= KnownOne2; // If all bits have been ruled out, there's no need to check @@ -744,19 +1085,19 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, case Intrinsic::sadd_with_overflow: computeKnownBitsAddSub(true, II->getArgOperand(0), II->getArgOperand(1), false, KnownZero, - KnownOne, KnownZero2, KnownOne2, TD, Depth); + KnownOne, KnownZero2, KnownOne2, TD, Depth, Q); break; case Intrinsic::usub_with_overflow: case Intrinsic::ssub_with_overflow: computeKnownBitsAddSub(false, II->getArgOperand(0), II->getArgOperand(1), false, KnownZero, - KnownOne, KnownZero2, KnownOne2, TD, Depth); + KnownOne, KnownZero2, KnownOne2, TD, Depth, Q); break; case Intrinsic::umul_with_overflow: case Intrinsic::smul_with_overflow: computeKnownBitsMul(II->getArgOperand(0), II->getArgOperand(1), false, KnownZero, KnownOne, - KnownZero2, KnownOne2, TD, Depth); + KnownZero2, KnownOne2, TD, Depth, Q); break; } } @@ -768,8 +1109,9 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, /// ComputeSignBit - Determine whether the sign bit is known to be zero or /// one. Convenience wrapper around computeKnownBits. -void llvm::ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, - const DataLayout *TD, unsigned Depth) { +void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, + const DataLayout *TD, unsigned Depth, + const Query &Q) { unsigned BitWidth = getBitWidth(V->getType(), TD); if (!BitWidth) { KnownZero = false; @@ -778,7 +1120,7 @@ void llvm::ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, } APInt ZeroBits(BitWidth, 0); APInt OneBits(BitWidth, 0); - computeKnownBits(V, ZeroBits, OneBits, TD, Depth); + computeKnownBits(V, ZeroBits, OneBits, TD, Depth, Q); KnownOne = OneBits[BitWidth - 1]; KnownZero = ZeroBits[BitWidth - 1]; } @@ -787,7 +1129,8 @@ void llvm::ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, /// bit set when defined. For vectors return true if every element is known to /// be a power of two when defined. Supports values with integer or pointer /// types and vectors of integers. -bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { +bool isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, + const Query &Q) { if (Constant *C = dyn_cast<Constant>(V)) { if (C->isNullValue()) return OrZero; @@ -814,19 +1157,20 @@ bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { // A shift of a power of two is a power of two or zero. if (OrZero && (match(V, m_Shl(m_Value(X), m_Value())) || match(V, m_Shr(m_Value(X), m_Value())))) - return isKnownToBeAPowerOfTwo(X, /*OrZero*/true, Depth); + return isKnownToBeAPowerOfTwo(X, /*OrZero*/true, Depth, Q); if (ZExtInst *ZI = dyn_cast<ZExtInst>(V)) - return isKnownToBeAPowerOfTwo(ZI->getOperand(0), OrZero, Depth); + return isKnownToBeAPowerOfTwo(ZI->getOperand(0), OrZero, Depth, Q); if (SelectInst *SI = dyn_cast<SelectInst>(V)) - return isKnownToBeAPowerOfTwo(SI->getTrueValue(), OrZero, Depth) && - isKnownToBeAPowerOfTwo(SI->getFalseValue(), OrZero, Depth); + return + isKnownToBeAPowerOfTwo(SI->getTrueValue(), OrZero, Depth, Q) && + isKnownToBeAPowerOfTwo(SI->getFalseValue(), OrZero, Depth, Q); if (OrZero && match(V, m_And(m_Value(X), m_Value(Y)))) { // A power of two and'd with anything is a power of two or zero. - if (isKnownToBeAPowerOfTwo(X, /*OrZero*/true, Depth) || - isKnownToBeAPowerOfTwo(Y, /*OrZero*/true, Depth)) + if (isKnownToBeAPowerOfTwo(X, /*OrZero*/true, Depth, Q) || + isKnownToBeAPowerOfTwo(Y, /*OrZero*/true, Depth, Q)) return true; // X & (-X) is always a power of two or zero. if (match(X, m_Neg(m_Specific(Y))) || match(Y, m_Neg(m_Specific(X)))) @@ -841,19 +1185,19 @@ bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { if (OrZero || VOBO->hasNoUnsignedWrap() || VOBO->hasNoSignedWrap()) { if (match(X, m_And(m_Specific(Y), m_Value())) || match(X, m_And(m_Value(), m_Specific(Y)))) - if (isKnownToBeAPowerOfTwo(Y, OrZero, Depth)) + if (isKnownToBeAPowerOfTwo(Y, OrZero, Depth, Q)) return true; if (match(Y, m_And(m_Specific(X), m_Value())) || match(Y, m_And(m_Value(), m_Specific(X)))) - if (isKnownToBeAPowerOfTwo(X, OrZero, Depth)) + if (isKnownToBeAPowerOfTwo(X, OrZero, Depth, Q)) return true; unsigned BitWidth = V->getType()->getScalarSizeInBits(); APInt LHSZeroBits(BitWidth, 0), LHSOneBits(BitWidth, 0); - computeKnownBits(X, LHSZeroBits, LHSOneBits, nullptr, Depth); + computeKnownBits(X, LHSZeroBits, LHSOneBits, nullptr, Depth, Q); APInt RHSZeroBits(BitWidth, 0), RHSOneBits(BitWidth, 0); - computeKnownBits(Y, RHSZeroBits, RHSOneBits, nullptr, Depth); + computeKnownBits(Y, RHSZeroBits, RHSOneBits, nullptr, Depth, Q); // If i8 V is a power of two or zero: // ZeroBits: 1 1 1 0 1 1 1 1 // ~ZeroBits: 0 0 0 1 0 0 0 0 @@ -870,7 +1214,8 @@ bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { // copying a sign bit (sdiv int_min, 2). if (match(V, m_Exact(m_LShr(m_Value(), m_Value()))) || match(V, m_Exact(m_UDiv(m_Value(), m_Value())))) { - return isKnownToBeAPowerOfTwo(cast<Operator>(V)->getOperand(0), OrZero, Depth); + return isKnownToBeAPowerOfTwo(cast<Operator>(V)->getOperand(0), OrZero, + Depth, Q); } return false; @@ -883,7 +1228,7 @@ bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { /// /// Currently this routine does not support vector GEPs. static bool isGEPKnownNonNull(GEPOperator *GEP, const DataLayout *DL, - unsigned Depth) { + unsigned Depth, const Query &Q) { if (!GEP->isInBounds() || GEP->getPointerAddressSpace() != 0) return false; @@ -892,7 +1237,7 @@ static bool isGEPKnownNonNull(GEPOperator *GEP, const DataLayout *DL, // If the base pointer is non-null, we cannot walk to a null address with an // inbounds GEP in address space zero. - if (isKnownNonZero(GEP->getPointerOperand(), DL, Depth)) + if (isKnownNonZero(GEP->getPointerOperand(), DL, Depth, Q)) return true; // Past this, if we don't have DataLayout, we can't do much. @@ -935,7 +1280,7 @@ static bool isGEPKnownNonNull(GEPOperator *GEP, const DataLayout *DL, if (Depth++ >= MaxDepth) continue; - if (isKnownNonZero(GTI.getOperand(), DL, Depth)) + if (isKnownNonZero(GTI.getOperand(), DL, Depth, Q)) return true; } @@ -946,7 +1291,8 @@ static bool isGEPKnownNonNull(GEPOperator *GEP, const DataLayout *DL, /// when defined. For vectors return true if every element is known to be /// non-zero when defined. Supports values with integer or pointer type and /// vectors of integers. -bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { +bool isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth, + const Query &Q) { if (Constant *C = dyn_cast<Constant>(V)) { if (C->isNullValue()) return false; @@ -966,7 +1312,7 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { if (isKnownNonNull(V)) return true; if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) - if (isGEPKnownNonNull(GEP, TD, Depth)) + if (isGEPKnownNonNull(GEP, TD, Depth, Q)) return true; } @@ -975,11 +1321,12 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { // X | Y != 0 if X != 0 or Y != 0. Value *X = nullptr, *Y = nullptr; if (match(V, m_Or(m_Value(X), m_Value(Y)))) - return isKnownNonZero(X, TD, Depth) || isKnownNonZero(Y, TD, Depth); + return isKnownNonZero(X, TD, Depth, Q) || + isKnownNonZero(Y, TD, Depth, Q); // ext X != 0 if X != 0. if (isa<SExtInst>(V) || isa<ZExtInst>(V)) - return isKnownNonZero(cast<Instruction>(V)->getOperand(0), TD, Depth); + return isKnownNonZero(cast<Instruction>(V)->getOperand(0), TD, Depth, Q); // shl X, Y != 0 if X is odd. Note that the value of the shift is undefined // if the lowest bit is shifted off the end. @@ -987,11 +1334,11 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { // shl nuw can't remove any non-zero bits. OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); if (BO->hasNoUnsignedWrap()) - return isKnownNonZero(X, TD, Depth); + return isKnownNonZero(X, TD, Depth, Q); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); - computeKnownBits(X, KnownZero, KnownOne, TD, Depth); + computeKnownBits(X, KnownZero, KnownOne, TD, Depth, Q); if (KnownOne[0]) return true; } @@ -1001,28 +1348,29 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { // shr exact can only shift out zero bits. PossiblyExactOperator *BO = cast<PossiblyExactOperator>(V); if (BO->isExact()) - return isKnownNonZero(X, TD, Depth); + return isKnownNonZero(X, TD, Depth, Q); bool XKnownNonNegative, XKnownNegative; - ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth); + ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth, Q); if (XKnownNegative) return true; } // div exact can only produce a zero if the dividend is zero. else if (match(V, m_Exact(m_IDiv(m_Value(X), m_Value())))) { - return isKnownNonZero(X, TD, Depth); + return isKnownNonZero(X, TD, Depth, Q); } // X + Y. else if (match(V, m_Add(m_Value(X), m_Value(Y)))) { bool XKnownNonNegative, XKnownNegative; bool YKnownNonNegative, YKnownNegative; - ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth); - ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, TD, Depth); + ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth, Q); + ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, TD, Depth, Q); // If X and Y are both non-negative (as signed values) then their sum is not // zero unless both X and Y are zero. if (XKnownNonNegative && YKnownNonNegative) - if (isKnownNonZero(X, TD, Depth) || isKnownNonZero(Y, TD, Depth)) + if (isKnownNonZero(X, TD, Depth, Q) || + isKnownNonZero(Y, TD, Depth, Q)) return true; // If X and Y are both negative (as signed values) then their sum is not @@ -1033,20 +1381,22 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { APInt Mask = APInt::getSignedMaxValue(BitWidth); // The sign bit of X is set. If some other bit is set then X is not equal // to INT_MIN. - computeKnownBits(X, KnownZero, KnownOne, TD, Depth); + computeKnownBits(X, KnownZero, KnownOne, TD, Depth, Q); if ((KnownOne & Mask) != 0) return true; // The sign bit of Y is set. If some other bit is set then Y is not equal // to INT_MIN. - computeKnownBits(Y, KnownZero, KnownOne, TD, Depth); + computeKnownBits(Y, KnownZero, KnownOne, TD, Depth, Q); if ((KnownOne & Mask) != 0) return true; } // The sum of a non-negative number and a power of two is not zero. - if (XKnownNonNegative && isKnownToBeAPowerOfTwo(Y, /*OrZero*/false, Depth)) + if (XKnownNonNegative && + isKnownToBeAPowerOfTwo(Y, /*OrZero*/false, Depth, Q)) return true; - if (YKnownNonNegative && isKnownToBeAPowerOfTwo(X, /*OrZero*/false, Depth)) + if (YKnownNonNegative && + isKnownToBeAPowerOfTwo(X, /*OrZero*/false, Depth, Q)) return true; } // X * Y. @@ -1055,20 +1405,21 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { // If X and Y are non-zero then so is X * Y as long as the multiplication // does not overflow. if ((BO->hasNoSignedWrap() || BO->hasNoUnsignedWrap()) && - isKnownNonZero(X, TD, Depth) && isKnownNonZero(Y, TD, Depth)) + isKnownNonZero(X, TD, Depth, Q) && + isKnownNonZero(Y, TD, Depth, Q)) return true; } // (C ? X : Y) != 0 if X != 0 and Y != 0. else if (SelectInst *SI = dyn_cast<SelectInst>(V)) { - if (isKnownNonZero(SI->getTrueValue(), TD, Depth) && - isKnownNonZero(SI->getFalseValue(), TD, Depth)) + if (isKnownNonZero(SI->getTrueValue(), TD, Depth, Q) && + isKnownNonZero(SI->getFalseValue(), TD, Depth, Q)) return true; } if (!BitWidth) return false; APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); - computeKnownBits(V, KnownZero, KnownOne, TD, Depth); + computeKnownBits(V, KnownZero, KnownOne, TD, Depth, Q); return KnownOne != 0; } @@ -1081,10 +1432,11 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { /// where V is a vector, the mask, known zero, and known one values are the /// same width as the vector element, and the bit is set only if it is true /// for all of the elements in the vector. -bool llvm::MaskedValueIsZero(Value *V, const APInt &Mask, - const DataLayout *TD, unsigned Depth) { +bool MaskedValueIsZero(Value *V, const APInt &Mask, + const DataLayout *TD, unsigned Depth, + const Query &Q) { APInt KnownZero(Mask.getBitWidth(), 0), KnownOne(Mask.getBitWidth(), 0); - computeKnownBits(V, KnownZero, KnownOne, TD, Depth); + computeKnownBits(V, KnownZero, KnownOne, TD, Depth, Q); return (KnownZero & Mask) == Mask; } @@ -1098,8 +1450,8 @@ bool llvm::MaskedValueIsZero(Value *V, const APInt &Mask, /// /// 'Op' must have a scalar integer type. /// -unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, - unsigned Depth) { +unsigned ComputeNumSignBits(Value *V, const DataLayout *TD, + unsigned Depth, const Query &Q) { assert((TD || V->getType()->isIntOrIntVectorTy()) && "ComputeNumSignBits requires a DataLayout object to operate " "on non-integer values!"); @@ -1120,10 +1472,10 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, default: break; case Instruction::SExt: Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits(); - return ComputeNumSignBits(U->getOperand(0), TD, Depth+1) + Tmp; + return ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q) + Tmp; case Instruction::AShr: { - Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); // ashr X, C -> adds C sign bits. Vectors too. const APInt *ShAmt; if (match(U->getOperand(1), m_APInt(ShAmt))) { @@ -1136,7 +1488,7 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, const APInt *ShAmt; if (match(U->getOperand(1), m_APInt(ShAmt))) { // shl destroys sign bits. - Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); Tmp2 = ShAmt->getZExtValue(); if (Tmp2 >= TyBits || // Bad shift. Tmp2 >= Tmp) break; // Shifted all sign bits out. @@ -1148,9 +1500,9 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, case Instruction::Or: case Instruction::Xor: // NOT is handled here. // Logical binary ops preserve the number of sign bits at the worst. - Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); if (Tmp != 1) { - Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1, Q); FirstAnswer = std::min(Tmp, Tmp2); // We computed what we know about the sign bits as our first // answer. Now proceed to the generic code that uses @@ -1159,22 +1511,22 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, break; case Instruction::Select: - Tmp = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(1), TD, Depth+1, Q); if (Tmp == 1) return 1; // Early out. - Tmp2 = ComputeNumSignBits(U->getOperand(2), TD, Depth+1); + Tmp2 = ComputeNumSignBits(U->getOperand(2), TD, Depth+1, Q); return std::min(Tmp, Tmp2); case Instruction::Add: // Add can have at most one carry bit. Thus we know that the output // is, at worst, one more bit than the inputs. - Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); if (Tmp == 1) return 1; // Early out. // Special case decrementing a value (ADD X, -1): if (ConstantInt *CRHS = dyn_cast<ConstantInt>(U->getOperand(1))) if (CRHS->isAllOnesValue()) { APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(U->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); // If the input is known to be 0 or 1, the output is 0/-1, which is all // sign bits set. @@ -1187,19 +1539,19 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, return Tmp; } - Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1, Q); if (Tmp2 == 1) return 1; return std::min(Tmp, Tmp2)-1; case Instruction::Sub: - Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1, Q); if (Tmp2 == 1) return 1; // Handle NEG. if (ConstantInt *CLHS = dyn_cast<ConstantInt>(U->getOperand(0))) if (CLHS->isNullValue()) { APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); - computeKnownBits(U->getOperand(1), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(U->getOperand(1), KnownZero, KnownOne, TD, Depth+1, Q); // If the input is known to be 0 or 1, the output is 0/-1, which is all // sign bits set. if ((KnownZero | APInt(TyBits, 1)).isAllOnesValue()) @@ -1215,7 +1567,7 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, // Sub can have at most one carry bit. Thus we know that the output // is, at worst, one more bit than the inputs. - Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); if (Tmp == 1) return 1; // Early out. return std::min(Tmp, Tmp2)-1; @@ -1226,11 +1578,12 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, // Take the minimum of all incoming values. This can't infinitely loop // because of our depth threshold. - Tmp = ComputeNumSignBits(PN->getIncomingValue(0), TD, Depth+1); + Tmp = ComputeNumSignBits(PN->getIncomingValue(0), TD, Depth+1, Q); for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i) { if (Tmp == 1) return Tmp; Tmp = std::min(Tmp, - ComputeNumSignBits(PN->getIncomingValue(i), TD, Depth+1)); + ComputeNumSignBits(PN->getIncomingValue(i), TD, + Depth+1, Q)); } return Tmp; } @@ -1245,7 +1598,7 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, // use this information. APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); APInt Mask; - computeKnownBits(V, KnownZero, KnownOne, TD, Depth); + computeKnownBits(V, KnownZero, KnownOne, TD, Depth, Q); if (KnownZero.isNegative()) { // sign bit is 0 Mask = KnownZero; @@ -1881,7 +2234,7 @@ llvm::GetUnderlyingObject(Value *V, const DataLayout *TD, unsigned MaxLookup) { } else { // See if InstructionSimplify knows any relevant tricks. if (Instruction *I = dyn_cast<Instruction>(V)) - // TODO: Acquire a DominatorTree and use it. + // TODO: Acquire a DominatorTree and AssumptionTracker and use them. if (Value *Simplified = SimplifyInstruction(I, TD, nullptr)) { V = Simplified; continue; diff --git a/llvm/lib/Transforms/InstCombine/InstCombine.h b/llvm/lib/Transforms/InstCombine/InstCombine.h index c56dc3c8684..0c3954f4c40 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombine.h +++ b/llvm/lib/Transforms/InstCombine/InstCombine.h @@ -27,6 +27,7 @@ namespace llvm { class CallSite; class DataLayout; +class DominatorTree; class TargetLibraryInfo; class DbgDeclareInst; class MemIntrinsic; @@ -97,6 +98,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner AssumptionTracker *AT; const DataLayout *DL; TargetLibraryInfo *TLI; + DominatorTree *DT; // not required bool MadeIRChange; LibCallSimplifier *Simplifier; bool MinimizeSize; @@ -126,6 +128,8 @@ public: AssumptionTracker *getAssumptionTracker() const { return AT; } const DataLayout *getDataLayout() const { return DL; } + + DominatorTree *getDominatorTree() const { return DT; } TargetLibraryInfo *getTargetLibraryInfo() const { return TLI; } @@ -159,7 +163,7 @@ public: Value *FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS); Value *FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS); Instruction *visitAnd(BinaryOperator &I); - Value *FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS); + Value *FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI); Value *FoldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS); Instruction *FoldOrWithConstants(BinaryOperator &I, Value *Op, Value *A, Value *B, Value *C); @@ -261,10 +265,10 @@ private: Instruction *transformZExtICmp(ICmpInst *ICI, Instruction &CI, bool DoXform = true); Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI); - bool WillNotOverflowSignedAdd(Value *LHS, Value *RHS); - bool WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS); - bool WillNotOverflowSignedSub(Value *LHS, Value *RHS); - bool WillNotOverflowUnsignedSub(Value *LHS, Value *RHS); + bool WillNotOverflowSignedAdd(Value *LHS, Value *RHS, Instruction *CxtI); + bool WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS, Instruction *CxtI); + bool WillNotOverflowSignedSub(Value *LHS, Value *RHS, Instruction *CxtI); + bool WillNotOverflowUnsignedSub(Value *LHS, Value *RHS, Instruction *CxtI); Value *EmitGEPOffset(User *GEP); Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN); Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask); @@ -333,16 +337,19 @@ public: } void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, - unsigned Depth = 0) const { - return llvm::computeKnownBits(V, KnownZero, KnownOne, DL, Depth); + unsigned Depth = 0, Instruction *CxtI = nullptr) const { + return llvm::computeKnownBits(V, KnownZero, KnownOne, DL, Depth, + AT, CxtI, DT); } bool MaskedValueIsZero(Value *V, const APInt &Mask, - unsigned Depth = 0) const { - return llvm::MaskedValueIsZero(V, Mask, DL, Depth); + unsigned Depth = 0, + Instruction *CxtI = nullptr) const { + return llvm::MaskedValueIsZero(V, Mask, DL, Depth, AT, CxtI, DT); } - unsigned ComputeNumSignBits(Value *Op, unsigned Depth = 0) const { - return llvm::ComputeNumSignBits(Op, DL, Depth); + unsigned ComputeNumSignBits(Value *Op, unsigned Depth = 0, + Instruction *CxtI = nullptr) const { + return llvm::ComputeNumSignBits(Op, DL, Depth, AT, CxtI, DT); } private: @@ -360,7 +367,8 @@ private: /// SimplifyDemandedUseBits - Attempts to replace V with a simpler value /// based on the demanded bits. Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt &KnownZero, - APInt &KnownOne, unsigned Depth); + APInt &KnownOne, unsigned Depth, + Instruction *CxtI = nullptr); bool SimplifyDemandedBits(Use &U, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne, unsigned Depth = 0); /// Helper routine of SimplifyDemandedUseBits. It tries to simplify demanded diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 600e09433cc..6287536b3e9 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -895,7 +895,8 @@ static bool checkRippleForAdd(const APInt &Op0KnownZero, /// This basically requires proving that the add in the original type would not /// overflow to change the sign bit or have a carry out. /// TODO: Handle this for Vectors. -bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) { +bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS, + Instruction *CxtI) { // There are different heuristics we can use for this. Here are some simple // ones. @@ -913,18 +914,19 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) { // // Since the carry into the most significant position is always equal to // the carry out of the addition, there is no signed overflow. - if (ComputeNumSignBits(LHS) > 1 && ComputeNumSignBits(RHS) > 1) + if (ComputeNumSignBits(LHS, 0, CxtI) > 1 && + ComputeNumSignBits(RHS, 0, CxtI) > 1) return true; if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) { int BitWidth = IT->getBitWidth(); APInt LHSKnownZero(BitWidth, 0); APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI); APInt RHSKnownZero(BitWidth, 0); APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI); // Addition of two 2's compliment numbers having opposite signs will never // overflow. @@ -943,13 +945,14 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) { /// WillNotOverflowUnsignedAdd - Return true if we can prove that: /// (zext (add LHS, RHS)) === (add (zext LHS), (zext RHS)) -bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) { +bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS, + Instruction *CxtI) { // There are different heuristics we can use for this. Here is a simple one. // If the sign bit of LHS and that of RHS are both zero, no unsigned wrap. bool LHSKnownNonNegative, LHSKnownNegative; bool RHSKnownNonNegative, RHSKnownNegative; - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0); - ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0, AT, CxtI, DT); + ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0, AT, CxtI, DT); if (LHSKnownNonNegative && RHSKnownNonNegative) return true; @@ -961,21 +964,23 @@ bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) { /// This basically requires proving that the add in the original type would not /// overflow to change the sign bit or have a carry out. /// TODO: Handle this for Vectors. -bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS) { +bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS, + Instruction *CxtI) { // If LHS and RHS each have at least two sign bits, the subtraction // cannot overflow. - if (ComputeNumSignBits(LHS) > 1 && ComputeNumSignBits(RHS) > 1) + if (ComputeNumSignBits(LHS, 0, CxtI) > 1 && + ComputeNumSignBits(RHS, 0, CxtI) > 1) return true; if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) { unsigned BitWidth = IT->getBitWidth(); APInt LHSKnownZero(BitWidth, 0); APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI); APInt RHSKnownZero(BitWidth, 0); APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI); // Subtraction of two 2's compliment numbers having identical signs will // never overflow. @@ -990,12 +995,13 @@ bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS) { /// \brief Return true if we can prove that: /// (sub LHS, RHS) === (sub nuw LHS, RHS) -bool InstCombiner::WillNotOverflowUnsignedSub(Value *LHS, Value *RHS) { +bool InstCombiner::WillNotOverflowUnsignedSub(Value *LHS, Value *RHS, + Instruction *CxtI) { // If the LHS is negative and the RHS is non-negative, no unsigned wrap. bool LHSKnownNonNegative, LHSKnownNegative; bool RHSKnownNonNegative, RHSKnownNegative; - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0); - ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0, AT, CxtI, DT); + ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0, AT, CxtI, DT); if (LHSKnownNegative && RHSKnownNonNegative) return true; @@ -1071,7 +1077,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return ReplaceInstUsesWith(I, V); if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL)) + I.hasNoUnsignedWrap(), DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // (A*B)+(A*C) -> A*(B+C) etc @@ -1110,7 +1116,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (ExtendAmt) { APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt); - if (!MaskedValueIsZero(XorLHS, Mask)) + if (!MaskedValueIsZero(XorLHS, Mask, 0, &I)) ExtendAmt = 0; } @@ -1126,7 +1132,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { IntegerType *IT = cast<IntegerType>(I.getType()); APInt LHSKnownOne(IT->getBitWidth(), 0); APInt LHSKnownZero(IT->getBitWidth(), 0); - computeKnownBits(XorLHS, LHSKnownZero, LHSKnownOne); + computeKnownBits(XorLHS, LHSKnownZero, LHSKnownOne, 0, &I); if ((XorRHS->getValue() | LHSKnownZero).isAllOnesValue()) return BinaryOperator::CreateSub(ConstantExpr::getAdd(XorRHS, CI), XorLHS); @@ -1179,11 +1185,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (IntegerType *IT = dyn_cast<IntegerType>(I.getType())) { APInt LHSKnownOne(IT->getBitWidth(), 0); APInt LHSKnownZero(IT->getBitWidth(), 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &I); if (LHSKnownZero != 0) { APInt RHSKnownOne(IT->getBitWidth(), 0); APInt RHSKnownZero(IT->getBitWidth(), 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &I); // No bits in common -> bitwise or. if ((LHSKnownZero|RHSKnownZero).isAllOnesValue()) @@ -1261,7 +1267,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); if (LHSConv->hasOneUse() && ConstantExpr::getSExt(CI, I.getType()) == RHSC && - WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) { + WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, &I)) { // Insert the new, smaller add. Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), CI, "addconv"); @@ -1277,7 +1283,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&& (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && WillNotOverflowSignedAdd(LHSConv->getOperand(0), - RHSConv->getOperand(0))) { + RHSConv->getOperand(0), &I)) { // Insert the new integer add. Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), RHSConv->getOperand(0), "addconv"); @@ -1325,11 +1331,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // TODO(jingyue): Consider WillNotOverflowSignedAdd and // WillNotOverflowUnsignedAdd to reduce the number of invocations of // computeKnownBits. - if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS)) { + if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS, &I)) { Changed = true; I.setHasNoSignedWrap(true); } - if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedAdd(LHS, RHS)) { + if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedAdd(LHS, RHS, &I)) { Changed = true; I.setHasNoUnsignedWrap(true); } @@ -1344,7 +1350,8 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL)) + if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL, + TLI, DT, AT)) return ReplaceInstUsesWith(I, V); if (isa<Constant>(RHS)) { @@ -1386,7 +1393,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType()); if (LHSConv->hasOneUse() && ConstantExpr::getSIToFP(CI, I.getType()) == CFP && - WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) { + WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, &I)) { // Insert the new integer add. Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), CI, "addconv"); @@ -1402,7 +1409,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&& (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && WillNotOverflowSignedAdd(LHSConv->getOperand(0), - RHSConv->getOperand(0))) { + RHSConv->getOperand(0), &I)) { // Insert the new integer add. Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), RHSConv->getOperand(0),"addconv"); @@ -1523,7 +1530,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return ReplaceInstUsesWith(I, V); if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL)) + I.hasNoUnsignedWrap(), DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // (A*B)-(A*C) -> A*(B-C) etc @@ -1673,11 +1680,11 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { } bool Changed = false; - if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1)) { + if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1, &I)) { Changed = true; I.setHasNoSignedWrap(true); } - if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedSub(Op0, Op1)) { + if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedSub(Op0, Op1, &I)) { Changed = true; I.setHasNoUnsignedWrap(true); } @@ -1691,7 +1698,8 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL)) + if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL, + TLI, DT, AT)) return ReplaceInstUsesWith(I, V); if (isa<Constant>(Op0)) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index ed6253532e0..0a9bb2d1087 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -355,7 +355,7 @@ Value *InstCombiner::FoldLogicalPlusAnd(Value *LHS, Value *RHS, if (isRunOfOnes(Mask, MB, ME)) { // begin/end bit of run, inclusive uint32_t BitWidth = cast<IntegerType>(RHS->getType())->getBitWidth(); APInt Mask(APInt::getLowBitsSet(BitWidth, MB-1)); - if (MaskedValueIsZero(RHS, Mask)) + if (MaskedValueIsZero(RHS, Mask, 0, &I)) break; } } @@ -1108,7 +1108,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyAndInst(Op0, Op1, DL)) + if (Value *V = SimplifyAndInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // (A|B)&(A|C) -> A|(B&C) etc @@ -1135,14 +1135,14 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (!Op0I->hasOneUse()) break; APInt NotAndRHS(~AndRHSMask); - if (MaskedValueIsZero(Op0LHS, NotAndRHS)) { + if (MaskedValueIsZero(Op0LHS, NotAndRHS, 0, &I)) { // Not masking anything out for the LHS, move to RHS. Value *NewRHS = Builder->CreateAnd(Op0RHS, AndRHS, Op0RHS->getName()+".masked"); return BinaryOperator::Create(Op0I->getOpcode(), Op0LHS, NewRHS); } if (!isa<Constant>(Op0RHS) && - MaskedValueIsZero(Op0RHS, NotAndRHS)) { + MaskedValueIsZero(Op0RHS, NotAndRHS, 0, &I)) { // Not masking anything out for the RHS, move to LHS. Value *NewLHS = Builder->CreateAnd(Op0LHS, AndRHS, Op0LHS->getName()+".masked"); @@ -1175,7 +1175,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { uint32_t Zeros = AndRHSMask.countLeadingZeros(); APInt Mask = APInt::getLowBitsSet(BitWidth, BitWidth - Zeros); - if (MaskedValueIsZero(Op0LHS, Mask)) { + if (MaskedValueIsZero(Op0LHS, Mask, 0, &I)) { Value *NewNeg = Builder->CreateNeg(Op0RHS); return BinaryOperator::CreateAnd(NewNeg, AndRHS); } @@ -1584,7 +1584,8 @@ static Instruction *MatchSelectFromAndOr(Value *A, Value *B, } /// FoldOrOfICmps - Fold (icmp)|(icmp) if possible. -Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) { +Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, + Instruction *CxtI) { ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) @@ -1604,13 +1605,15 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) { Value *Mask = nullptr; Value *Masked = nullptr; if (LAnd->getOperand(0) == RAnd->getOperand(0) && - isKnownToBeAPowerOfTwo(LAnd->getOperand(1)) && - isKnownToBeAPowerOfTwo(RAnd->getOperand(1))) { + isKnownToBeAPowerOfTwo(LAnd->getOperand(1), false, 0, AT, CxtI, DT) && + isKnownToBeAPowerOfTwo(RAnd->getOperand(1), false, 0, AT, CxtI, DT)) { Mask = Builder->CreateOr(LAnd->getOperand(1), RAnd->getOperand(1)); Masked = Builder->CreateAnd(LAnd->getOperand(0), Mask); } else if (LAnd->getOperand(1) == RAnd->getOperand(1) && - isKnownToBeAPowerOfTwo(LAnd->getOperand(0)) && - isKnownToBeAPowerOfTwo(RAnd->getOperand(0))) { + isKnownToBeAPowerOfTwo(LAnd->getOperand(0), + false, 0, AT, CxtI, DT) && + isKnownToBeAPowerOfTwo(RAnd->getOperand(0), + false, 0, AT, CxtI, DT)) { Mask = Builder->CreateOr(LAnd->getOperand(0), RAnd->getOperand(0)); Masked = Builder->CreateAnd(LAnd->getOperand(1), Mask); } @@ -2030,7 +2033,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyOrInst(Op0, Op1, DL)) + if (Value *V = SimplifyOrInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // (A&B)|(A&C) -> A&(B|C) etc @@ -2090,7 +2093,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // (X^C)|Y -> (X|Y)^C iff Y&C == 0 if (Op0->hasOneUse() && match(Op0, m_Xor(m_Value(A), m_ConstantInt(C1))) && - MaskedValueIsZero(Op1, C1->getValue())) { + MaskedValueIsZero(Op1, C1->getValue(), 0, &I)) { Value *NOr = Builder->CreateOr(A, Op1); NOr->takeName(Op0); return BinaryOperator::CreateXor(NOr, C1); @@ -2099,7 +2102,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // Y|(X^C) -> (X|Y)^C iff Y&C == 0 if (Op1->hasOneUse() && match(Op1, m_Xor(m_Value(A), m_ConstantInt(C1))) && - MaskedValueIsZero(Op0, C1->getValue())) { + MaskedValueIsZero(Op0, C1->getValue(), 0, &I)) { Value *NOr = Builder->CreateOr(A, Op0); NOr->takeName(Op0); return BinaryOperator::CreateXor(NOr, C1); @@ -2137,14 +2140,18 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2) // iff (C1&C2) == 0 and (N&~C1) == 0 if (match(A, m_Or(m_Value(V1), m_Value(V2))) && - ((V1 == B && MaskedValueIsZero(V2, ~C1->getValue())) || // (V|N) - (V2 == B && MaskedValueIsZero(V1, ~C1->getValue())))) // (N|V) + ((V1 == B && + MaskedValueIsZero(V2, ~C1->getValue(), 0, &I)) || // (V|N) + (V2 == B && + MaskedValueIsZero(V1, ~C1->getValue(), 0, &I)))) // (N|V) return BinaryOperator::CreateAnd(A, Builder->getInt(C1->getValue()|C2->getValue())); // Or commutes, try both ways. if (match(B, m_Or(m_Value(V1), m_Value(V2))) && - ((V1 == A && MaskedValueIsZero(V2, ~C2->getValue())) || // (V|N) - (V2 == A && MaskedValueIsZero(V1, ~C2->getValue())))) // (N|V) + ((V1 == A && + MaskedValueIsZero(V2, ~C2->getValue(), 0, &I)) || // (V|N) + (V2 == A && + MaskedValueIsZero(V1, ~C2->getValue(), 0, &I)))) // (N|V) return BinaryOperator::CreateAnd(B, Builder->getInt(C1->getValue()|C2->getValue())); @@ -2300,7 +2307,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) - if (Value *Res = FoldOrOfICmps(LHS, RHS)) + if (Value *Res = FoldOrOfICmps(LHS, RHS, &I)) return ReplaceInstUsesWith(I, Res); // (fcmp uno x, c) | (fcmp uno y, c) -> (fcmp uno x, y) @@ -2331,7 +2338,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // cast is otherwise not optimizable. This happens for vector sexts. if (ICmpInst *RHS = dyn_cast<ICmpInst>(Op1COp)) if (ICmpInst *LHS = dyn_cast<ICmpInst>(Op0COp)) - if (Value *Res = FoldOrOfICmps(LHS, RHS)) + if (Value *Res = FoldOrOfICmps(LHS, RHS, &I)) return CastInst::Create(Op0C->getOpcode(), Res, I.getType()); // If this is or(cast(fcmp), cast(fcmp)), try to fold this even if the @@ -2387,7 +2394,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyXorInst(Op0, Op1, DL)) + if (Value *V = SimplifyXorInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // (A&B)^(A&C) -> A&(B^C) etc @@ -2489,7 +2496,8 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } else if (Op0I->getOpcode() == Instruction::Or) { // (X|C1)^C2 -> X^(C1|C2) iff X&~C1 == 0 - if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue())) { + if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue(), + 0, &I)) { Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHS); // Anything in both C1 and C2 is known to be zero, remove it from // NewRHS. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 92f38fc19c8..30fc3c93383 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -58,8 +58,8 @@ static Type *reduceToSingleValueType(Type *T) { } Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { - unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL); - unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL); + unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, AT, MI, DT); + unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, AT, MI, DT); unsigned MinAlign = std::min(DstAlign, SrcAlign); unsigned CopyAlign = MI->getAlignment(); @@ -154,7 +154,7 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { } Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { - unsigned Alignment = getKnownAlignment(MI->getDest(), DL); + unsigned Alignment = getKnownAlignment(MI->getDest(), DL, AT, MI, DT); if (MI->getAlignment() < Alignment) { MI->setAlignment(ConstantInt::get(MI->getAlignmentType(), Alignment, false)); @@ -322,7 +322,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { uint32_t BitWidth = IT->getBitWidth(); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); - computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne); + computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne, 0, II); unsigned TrailingZeros = KnownOne.countTrailingZeros(); APInt Mask(APInt::getLowBitsSet(BitWidth, TrailingZeros)); if ((Mask & KnownZero) == Mask) @@ -340,7 +340,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { uint32_t BitWidth = IT->getBitWidth(); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); - computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne); + computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne, 0, II); unsigned LeadingZeros = KnownOne.countLeadingZeros(); APInt Mask(APInt::getHighBitsSet(BitWidth, LeadingZeros)); if ((Mask & KnownZero) == Mask) @@ -355,14 +355,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { uint32_t BitWidth = IT->getBitWidth(); APInt LHSKnownZero(BitWidth, 0); APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, II); bool LHSKnownNegative = LHSKnownOne[BitWidth - 1]; bool LHSKnownPositive = LHSKnownZero[BitWidth - 1]; if (LHSKnownNegative || LHSKnownPositive) { APInt RHSKnownZero(BitWidth, 0); APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, II); bool RHSKnownNegative = RHSKnownOne[BitWidth - 1]; bool RHSKnownPositive = RHSKnownZero[BitWidth - 1]; if (LHSKnownNegative && RHSKnownNegative) { @@ -426,7 +426,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // can prove that it will never overflow. if (II->getIntrinsicID() == Intrinsic::sadd_with_overflow) { Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); - if (WillNotOverflowSignedAdd(LHS, RHS)) { + if (WillNotOverflowSignedAdd(LHS, RHS, II)) { Value *Add = Builder->CreateNSWAdd(LHS, RHS); Add->takeName(&CI); Constant *V[] = {UndefValue::get(Add->getType()), Builder->getFalse()}; @@ -464,10 +464,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { APInt LHSKnownZero(BitWidth, 0); APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, II); APInt RHSKnownZero(BitWidth, 0); APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, II); // Get the largest possible values for each operand. APInt LHSMax = ~LHSKnownZero; @@ -521,7 +521,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::ppc_altivec_lvx: case Intrinsic::ppc_altivec_lvxl: // Turn PPC lvx -> load if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL) >= 16) { + if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, + DL, AT, II, DT) >= 16) { Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); return new LoadInst(Ptr); @@ -530,7 +531,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::ppc_altivec_stvx: case Intrinsic::ppc_altivec_stvxl: // Turn stvx -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL) >= 16) { + if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, + DL, AT, II, DT) >= 16) { Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); Value *Ptr = Builder->CreateBitCast(II->getArgOperand(1), OpPtrTy); @@ -541,7 +543,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_sse2_storeu_pd: case Intrinsic::x86_sse2_storeu_dq: // Turn X86 storeu -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL) >= 16) { + if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, + DL, AT, II, DT) >= 16) { Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(1)->getType()); Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), OpPtrTy); @@ -886,7 +889,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::arm_neon_vst2lane: case Intrinsic::arm_neon_vst3lane: case Intrinsic::arm_neon_vst4lane: { - unsigned MemAlign = getKnownAlignment(II->getArgOperand(0), DL); + unsigned MemAlign = getKnownAlignment(II->getArgOperand(0), DL, AT, II, DT); unsigned AlignArg = II->getNumArgOperands() - 1; ConstantInt *IntrAlign = dyn_cast<ConstantInt>(II->getArgOperand(AlignArg)); if (IntrAlign && IntrAlign->getZExtValue() < MemAlign) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index b9c3d0f6471..c16992ff626 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -335,7 +335,8 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { /// /// This function works on both vectors and scalars. /// -static bool CanEvaluateTruncated(Value *V, Type *Ty) { +static bool CanEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, + Instruction *CxtI) { // We can always evaluate constants in another type. if (isa<Constant>(V)) return true; @@ -364,8 +365,8 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { case Instruction::Or: case Instruction::Xor: // These operators can all arbitrarily be extended or truncated. - return CanEvaluateTruncated(I->getOperand(0), Ty) && - CanEvaluateTruncated(I->getOperand(1), Ty); + return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + CanEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); case Instruction::UDiv: case Instruction::URem: { @@ -374,10 +375,10 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { uint32_t BitWidth = Ty->getScalarSizeInBits(); if (BitWidth < OrigBitWidth) { APInt Mask = APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth); - if (MaskedValueIsZero(I->getOperand(0), Mask) && - MaskedValueIsZero(I->getOperand(1), Mask)) { - return CanEvaluateTruncated(I->getOperand(0), Ty) && - CanEvaluateTruncated(I->getOperand(1), Ty); + if (IC.MaskedValueIsZero(I->getOperand(0), Mask, 0, CxtI) && + IC.MaskedValueIsZero(I->getOperand(1), Mask, 0, CxtI)) { + return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + CanEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); } } break; @@ -388,7 +389,7 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { uint32_t BitWidth = Ty->getScalarSizeInBits(); if (CI->getLimitedValue(BitWidth) < BitWidth) - return CanEvaluateTruncated(I->getOperand(0), Ty); + return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); } break; case Instruction::LShr: @@ -398,10 +399,10 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (MaskedValueIsZero(I->getOperand(0), - APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth)) && + if (IC.MaskedValueIsZero(I->getOperand(0), + APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth), 0, CxtI) && CI->getLimitedValue(BitWidth) < BitWidth) { - return CanEvaluateTruncated(I->getOperand(0), Ty); + return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); } } break; @@ -415,8 +416,8 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { return true; case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); - return CanEvaluateTruncated(SI->getTrueValue(), Ty) && - CanEvaluateTruncated(SI->getFalseValue(), Ty); + return CanEvaluateTruncated(SI->getTrueValue(), Ty, IC, CxtI) && + CanEvaluateTruncated(SI->getFalseValue(), Ty, IC, CxtI); } case Instruction::PHI: { // We can change a phi if we can change all operands. Note that we never @@ -424,7 +425,7 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { // instructions with a single use. PHINode *PN = cast<PHINode>(I); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (!CanEvaluateTruncated(PN->getIncomingValue(i), Ty)) + if (!CanEvaluateTruncated(PN->getIncomingValue(i), Ty, IC, CxtI)) return false; return true; } @@ -453,7 +454,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // expression tree to something weird like i93 unless the source is also // strange. if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) && - CanEvaluateTruncated(Src, DestTy)) { + CanEvaluateTruncated(Src, DestTy, *this, &CI)) { // If this cast is a truncate, evaluting in a different type always // eliminates the cast, so it is always a win. @@ -553,7 +554,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, // If Op1C some other power of two, convert: uint32_t BitWidth = Op1C->getType()->getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(ICI->getOperand(0), KnownZero, KnownOne); + computeKnownBits(ICI->getOperand(0), KnownZero, KnownOne, 0, &CI); APInt KnownZeroMask(~KnownZero); if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? @@ -601,8 +602,8 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, APInt KnownZeroLHS(BitWidth, 0), KnownOneLHS(BitWidth, 0); APInt KnownZeroRHS(BitWidth, 0), KnownOneRHS(BitWidth, 0); - computeKnownBits(LHS, KnownZeroLHS, KnownOneLHS); - computeKnownBits(RHS, KnownZeroRHS, KnownOneRHS); + computeKnownBits(LHS, KnownZeroLHS, KnownOneLHS, 0, &CI); + computeKnownBits(RHS, KnownZeroRHS, KnownOneRHS, 0, &CI); if (KnownZeroLHS == KnownZeroRHS && KnownOneLHS == KnownOneRHS) { APInt KnownBits = KnownZeroLHS | KnownOneLHS; @@ -651,7 +652,8 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, /// clear the top bits anyway, doing this has no extra cost. /// /// This function works on both vectors and scalars. -static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { +static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, + InstCombiner &IC, Instruction *CxtI) { BitsToClear = 0; if (isa<Constant>(V)) return true; @@ -680,8 +682,8 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { case Instruction::Add: case Instruction::Sub: case Instruction::Mul: - if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear) || - !CanEvaluateZExtd(I->getOperand(1), Ty, Tmp)) + if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI) || + !CanEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI)) return false; // These can all be promoted if neither operand has 'bits to clear'. if (BitsToClear == 0 && Tmp == 0) @@ -695,8 +697,9 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // We use MaskedValueIsZero here for generality, but the case we care // about the most is constant RHS. unsigned VSize = V->getType()->getScalarSizeInBits(); - if (MaskedValueIsZero(I->getOperand(1), - APInt::getHighBitsSet(VSize, BitsToClear))) + if (IC.MaskedValueIsZero(I->getOperand(1), + APInt::getHighBitsSet(VSize, BitsToClear), + 0, CxtI)) return true; } @@ -707,7 +710,7 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // We can promote shl(x, cst) if we can promote x. Since shl overwrites the // upper bits we can reduce BitsToClear by the shift amount. if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) { - if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear)) + if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) return false; uint64_t ShiftAmt = Amt->getZExtValue(); BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0; @@ -718,7 +721,7 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // We can promote lshr(x, cst) if we can promote x. This requires the // ultimate 'and' to clear out the high zero bits we're clearing out though. if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) { - if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear)) + if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) return false; BitsToClear += Amt->getZExtValue(); if (BitsToClear > V->getType()->getScalarSizeInBits()) @@ -728,8 +731,8 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // Cannot promote variable LSHR. return false; case Instruction::Select: - if (!CanEvaluateZExtd(I->getOperand(1), Ty, Tmp) || - !CanEvaluateZExtd(I->getOperand(2), Ty, BitsToClear) || + if (!CanEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI) || + !CanEvaluateZExtd(I->getOperand(2), Ty, BitsToClear, IC, CxtI) || // TODO: If important, we could handle the case when the BitsToClear are // known zero in the disagreeing side. Tmp != BitsToClear) @@ -741,10 +744,10 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // get into trouble with cyclic PHIs here because we only consider // instructions with a single use. PHINode *PN = cast<PHINode>(I); - if (!CanEvaluateZExtd(PN->getIncomingValue(0), Ty, BitsToClear)) + if (!CanEvaluateZExtd(PN->getIncomingValue(0), Ty, BitsToClear, IC, CxtI)) return false; for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i) - if (!CanEvaluateZExtd(PN->getIncomingValue(i), Ty, Tmp) || + if (!CanEvaluateZExtd(PN->getIncomingValue(i), Ty, Tmp, IC, CxtI) || // TODO: If important, we could handle the case when the BitsToClear // are known zero in the disagreeing input. Tmp != BitsToClear) @@ -781,7 +784,7 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { // strange. unsigned BitsToClear; if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) && - CanEvaluateZExtd(Src, DestTy, BitsToClear)) { + CanEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) { assert(BitsToClear < SrcTy->getScalarSizeInBits() && "Unreasonable BitsToClear"); @@ -796,8 +799,10 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { // If the high bits are already filled with zeros, just replace this // cast with the result. - if (MaskedValueIsZero(Res, APInt::getHighBitsSet(DestBitSize, - DestBitSize-SrcBitsKept))) + if (MaskedValueIsZero(Res, + APInt::getHighBitsSet(DestBitSize, + DestBitSize-SrcBitsKept), + 0, &CI)) return ReplaceInstUsesWith(CI, Res); // We need to emit an AND to clear the high bits. @@ -921,7 +926,7 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { ICI->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){ unsigned BitWidth = Op1C->getType()->getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(Op0, KnownZero, KnownOne); + computeKnownBits(Op0, KnownZero, KnownOne, 0, &CI); APInt KnownZeroMask(~KnownZero); if (KnownZeroMask.isPowerOf2()) { @@ -1072,7 +1077,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { // If the high bits are already filled with sign bit, just replace this // cast with the result. - if (ComputeNumSignBits(Res) > DestBitSize - SrcBitSize) + if (ComputeNumSignBits(Res, 0, &CI) > DestBitSize - SrcBitSize) return ReplaceInstUsesWith(CI, Res); // We need to emit a shl + ashr to do the sign extend. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 7a2d175b58e..e1c72430d7e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1129,7 +1129,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(), SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits(); APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); - computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne); + computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne, 0, &ICI); // If all the high bits are known, we can do this xform. if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) { @@ -2033,8 +2033,8 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // sign-extended; check for that condition. For example, if CI2 is 2^31 and // the operands of the add are 64 bits wide, we need at least 33 sign bits. unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1; - if (IC.ComputeNumSignBits(A) < NeededSignBits || - IC.ComputeNumSignBits(B) < NeededSignBits) + if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits || + IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits) return nullptr; // In order to replace the original add with a narrower @@ -2442,7 +2442,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Changed = true; } - if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL)) + if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // comparing -val or val with non-zero is the same as just comparing val @@ -3222,7 +3222,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // and (A & ~B) != 0 --> (A & B) == 0 // if A is a power of 2. if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(A) && I.isEquality()) + match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(A, false, + 0, AT, &I, DT) && + I.isEquality()) return new ICmpInst(I.getInversePredicate(), Builder->CreateAnd(A, B), Op1); @@ -3612,7 +3614,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL)) + if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // Simplify 'fcmp pred X, X' diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 9e46041886d..4aafc2e2cad 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -268,7 +268,8 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { SmallVector<Instruction *, 4> ToDelete; if (MemTransferInst *Copy = isOnlyCopiedFromConstantGlobal(&AI, ToDelete)) { unsigned SourceAlign = getOrEnforceKnownAlignment(Copy->getSource(), - AI.getAlignment(), DL); + AI.getAlignment(), + DL, AT, &AI, DT); if (AI.getAlignment() <= SourceAlign) { DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); @@ -363,7 +364,8 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // Attempt to improve the alignment. if (DL) { unsigned KnownAlign = - getOrEnforceKnownAlignment(Op, DL->getPrefTypeAlignment(LI.getType()),DL); + getOrEnforceKnownAlignment(Op, DL->getPrefTypeAlignment(LI.getType()), + DL, AT, &LI, DT); unsigned LoadAlign = LI.getAlignment(); unsigned EffectiveLoadAlign = LoadAlign != 0 ? LoadAlign : DL->getABITypeAlignment(LI.getType()); @@ -601,7 +603,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { if (DL) { unsigned KnownAlign = getOrEnforceKnownAlignment(Ptr, DL->getPrefTypeAlignment(Val->getType()), - DL); + DL, AT, &SI, DT); unsigned StoreAlign = SI.getAlignment(); unsigned EffectiveStoreAlign = StoreAlign != 0 ? StoreAlign : DL->getABITypeAlignment(Val->getType()); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 3f86ddfd104..d2d94e82197 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -25,7 +25,8 @@ using namespace PatternMatch; /// simplifyValueKnownNonZero - The specific integer value is used in a context /// where it is known to be non-zero. If this allows us to simplify the /// computation, do so and return the new operand, otherwise return null. -static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) { +static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, + Instruction *CxtI) { // If V has multiple uses, then we would have to do more analysis to determine // if this is safe. For example, the use could be in dynamically unreached // code. @@ -39,7 +40,8 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) { if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(PowerOf2), m_Value(A))), m_Value(B))) && // The "1" can be any value known to be a power of 2. - isKnownToBeAPowerOfTwo(PowerOf2)) { + isKnownToBeAPowerOfTwo(PowerOf2, false, 0, IC.getAssumptionTracker(), + CxtI, IC.getDominatorTree())) { A = IC.Builder->CreateSub(A, B); return IC.Builder->CreateShl(PowerOf2, A); } @@ -47,10 +49,13 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) { // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it // inexact. Similarly for <<. if (BinaryOperator *I = dyn_cast<BinaryOperator>(V)) - if (I->isLogicalShift() && isKnownToBeAPowerOfTwo(I->getOperand(0))) { + if (I->isLogicalShift() && isKnownToBeAPowerOfTwo(I->getOperand(0), false, + 0, IC.getAssumptionTracker(), + CxtI, + IC.getDominatorTree())) { // We know that this is an exact/nuw shift and that the input is a // non-zero context as well. - if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC)) { + if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { I->setOperand(0, V2); MadeChange = true; } @@ -138,7 +143,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyMulInst(Op0, Op1, DL)) + if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); if (Value *V = SimplifyUsingDistributiveLaws(I)) @@ -292,9 +297,9 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { APInt Negative2(I.getType()->getPrimitiveSizeInBits(), (uint64_t)-2, true); Value *BoolCast = nullptr, *OtherOp = nullptr; - if (MaskedValueIsZero(Op0, Negative2)) + if (MaskedValueIsZero(Op0, Negative2, 0, &I)) BoolCast = Op0, OtherOp = Op1; - else if (MaskedValueIsZero(Op1, Negative2)) + else if (MaskedValueIsZero(Op1, Negative2, 0, &I)) BoolCast = Op1, OtherOp = Op0; if (BoolCast) { @@ -455,7 +460,8 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (isa<Constant>(Op0)) std::swap(Op0, Op1); - if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL)) + if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, + DT, AT)) return ReplaceInstUsesWith(I, V); bool AllowReassociate = I.hasUnsafeAlgebra(); @@ -699,7 +705,7 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // The RHS is known non-zero. - if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this)) { + if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, &I)) { I.setOperand(1, V); return &I; } @@ -952,7 +958,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyUDivInst(Op0, Op1, DL)) + if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // Handle the integer div common cases @@ -1014,7 +1020,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifySDivInst(Op0, Op1, DL)) + if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // Handle the integer div common cases @@ -1051,8 +1057,8 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { // unsigned inputs), turn this into a udiv. if (I.getType()->isIntegerTy()) { APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits())); - if (MaskedValueIsZero(Op0, Mask)) { - if (MaskedValueIsZero(Op1, Mask)) { + if (MaskedValueIsZero(Op0, Mask, 0, &I)) { + if (MaskedValueIsZero(Op1, Mask, 0, &I)) { // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set return BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); } @@ -1107,7 +1113,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFDivInst(Op0, Op1, DL)) + if (Value *V = SimplifyFDivInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); if (isa<Constant>(Op0)) @@ -1238,7 +1244,7 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // The RHS is known non-zero. - if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this)) { + if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, &I)) { I.setOperand(1, V); return &I; } @@ -1272,7 +1278,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyURemInst(Op0, Op1, DL)) + if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); if (Instruction *common = commonIRemTransforms(I)) @@ -1285,7 +1291,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { I.getType()); // X urem Y -> X and Y-1, where Y is a power of 2, - if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/true)) { + if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/true, 0, AT, &I, DT)) { Constant *N1 = Constant::getAllOnesValue(I.getType()); Value *Add = Builder->CreateAdd(Op1, N1); return BinaryOperator::CreateAnd(Op0, Add); @@ -1307,7 +1313,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifySRemInst(Op0, Op1, DL)) + if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // Handle the integer rem common cases @@ -1328,7 +1334,8 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { // unsigned inputs), turn this into a urem. if (I.getType()->isIntegerTy()) { APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits())); - if (MaskedValueIsZero(Op1, Mask) && MaskedValueIsZero(Op0, Mask)) { + if (MaskedValueIsZero(Op1, Mask, 0, &I) && + MaskedValueIsZero(Op0, Mask, 0, &I)) { // X srem Y -> X urem Y, iff X and Y don't have sign bit set return BinaryOperator::CreateURem(Op0, Op1, I.getName()); } @@ -1381,7 +1388,7 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFRemInst(Op0, Op1, DL)) + if (Value *V = SimplifyFRemInst(Op0, Op1, DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); // Handle cases involving: rem X, (select Cond, Y, Z) diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 52261caaccd..6983a90c2c1 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -788,7 +788,7 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // PHINode simplification // Instruction *InstCombiner::visitPHINode(PHINode &PN) { - if (Value *V = SimplifyInstruction(&PN, DL, TLI)) + if (Value *V = SimplifyInstruction(&PN, DL, TLI, DT, AT)) return ReplaceInstUsesWith(PN, V); // If all PHI operands are the same operation, pull them through the PHI, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 06c9e290c6e..ae0b5769c64 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -313,7 +313,9 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, /// replaced with RepOp. static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, const DataLayout *TD, - const TargetLibraryInfo *TLI) { + const TargetLibraryInfo *TLI, + DominatorTree *DT, + AssumptionTracker *AT) { // Trivial replacement. if (V == Op) return RepOp; @@ -334,10 +336,10 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, if (CmpInst *C = dyn_cast<CmpInst>(I)) { if (C->getOperand(0) == Op) return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), TD, - TLI); + TLI, DT, AT); if (C->getOperand(1) == Op) return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, TD, - TLI); + TLI, DT, AT); } // TODO: We could hand off more cases to instsimplify here. @@ -605,18 +607,26 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, // arms of the select. See if substituting this value into the arm and // simplifying the result yields the same value as the other arm. if (Pred == ICmpInst::ICMP_EQ) { - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI) == TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI) == TrueVal) + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, + DT, AT) == TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, + DT, AT) == TrueVal) return ReplaceInstUsesWith(SI, FalseVal); - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI) == FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI) == FalseVal) + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, + DT, AT) == FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, + DT, AT) == FalseVal) return ReplaceInstUsesWith(SI, FalseVal); } else if (Pred == ICmpInst::ICMP_NE) { - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI) == FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI) == FalseVal) + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, + DT, AT) == FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, + DT, AT) == FalseVal) return ReplaceInstUsesWith(SI, TrueVal); - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI) == TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI) == TrueVal) + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, + DT, AT) == TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, + DT, AT) == TrueVal) return ReplaceInstUsesWith(SI, TrueVal); } @@ -825,7 +835,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); - if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, DL)) + if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, DL, TLI, + DT, AT)) return ReplaceInstUsesWith(SI, V); if (SI.getType()->isIntegerTy(1)) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 3d0cc05f30c..afa907a7bc2 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -68,7 +68,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { /// this succeeds, the GetShiftedValue function will be called to produce the /// value. static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, - InstCombiner &IC) { + InstCombiner &IC, Instruction *CxtI) { // We can always evaluate constants shifted. if (isa<Constant>(V)) return true; @@ -111,8 +111,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, case Instruction::Or: case Instruction::Xor: // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted. - return CanEvaluateShifted(I->getOperand(0), NumBits, isLeftShift, IC) && - CanEvaluateShifted(I->getOperand(1), NumBits, isLeftShift, IC); + return CanEvaluateShifted(I->getOperand(0), NumBits, isLeftShift, IC, I) && + CanEvaluateShifted(I->getOperand(1), NumBits, isLeftShift, IC, I); case Instruction::Shl: { // We can often fold the shift into shifts-by-a-constant. @@ -131,8 +131,9 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // profitable unless we know the and'd out bits are already zero. if (CI->getZExtValue() > NumBits) { unsigned LowBits = TypeWidth - CI->getZExtValue(); - if (MaskedValueIsZero(I->getOperand(0), - APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits)) + if (IC.MaskedValueIsZero(I->getOperand(0), + APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits, + 0, CxtI)) return true; } @@ -155,8 +156,9 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // profitable unless we know the and'd out bits are already zero. if (CI->getValue().ult(TypeWidth) && CI->getZExtValue() > NumBits) { unsigned LowBits = CI->getZExtValue() - NumBits; - if (MaskedValueIsZero(I->getOperand(0), - APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits)) + if (IC.MaskedValueIsZero(I->getOperand(0), + APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits, + 0, CxtI)) return true; } @@ -164,8 +166,9 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, } case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); - return CanEvaluateShifted(SI->getTrueValue(), NumBits, isLeftShift, IC) && - CanEvaluateShifted(SI->getFalseValue(), NumBits, isLeftShift, IC); + return CanEvaluateShifted(SI->getTrueValue(), NumBits, isLeftShift, + IC, SI) && + CanEvaluateShifted(SI->getFalseValue(), NumBits, isLeftShift, IC, SI); } case Instruction::PHI: { // We can change a phi if we can change all operands. Note that we never @@ -173,7 +176,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // instructions with a single use. PHINode *PN = cast<PHINode>(I); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (!CanEvaluateShifted(PN->getIncomingValue(i), NumBits, isLeftShift,IC)) + if (!CanEvaluateShifted(PN->getIncomingValue(i), NumBits, isLeftShift, + IC, PN)) return false; return true; } @@ -329,7 +333,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // See if we can propagate this shift into the input, this covers the trivial // cast of lshr(shl(x,c1),c2) as well as other more complex cases. if (I.getOpcode() != Instruction::AShr && - CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this)) { + CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this, &I)) { DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); @@ -691,7 +695,7 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), - DL)) + DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); if (Instruction *V = commonShiftTransforms(I)) @@ -703,14 +707,15 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is a NUW shift. if (!I.hasNoUnsignedWrap() && MaskedValueIsZero(I.getOperand(0), - APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt))) { + APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), + 0, &I)) { I.setHasNoUnsignedWrap(); return &I; } // If the shifted out value is all signbits, this is a NSW shift. if (!I.hasNoSignedWrap() && - ComputeNumSignBits(I.getOperand(0)) > ShAmt) { + ComputeNumSignBits(I.getOperand(0), 0, &I) > ShAmt) { I.setHasNoSignedWrap(); return &I; } @@ -731,7 +736,7 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { return ReplaceInstUsesWith(I, V); if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), - I.isExact(), DL)) + I.isExact(), DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) @@ -760,7 +765,8 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && - MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt))){ + MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt), + 0, &I)){ I.setIsExact(); return &I; } @@ -774,7 +780,7 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { return ReplaceInstUsesWith(I, V); if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), - I.isExact(), DL)) + I.isExact(), DL, TLI, DT, AT)) return ReplaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) @@ -804,7 +810,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && - MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt))){ + MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt), + 0, &I)){ I.setIsExact(); return &I; } @@ -812,7 +819,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { // See if we can turn a signed shr into an unsigned shr. if (MaskedValueIsZero(Op0, - APInt::getSignBit(I.getType()->getScalarSizeInBits()))) + APInt::getSignBit(I.getType()->getScalarSizeInBits()), + 0, &I)) return BinaryOperator::CreateLShr(Op0, Op1); return nullptr; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index f0c96bdf768..249544a061f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -71,7 +71,7 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) { APInt DemandedMask(APInt::getAllOnesValue(BitWidth)); Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, - KnownZero, KnownOne, 0); + KnownZero, KnownOne, 0, &Inst); if (!V) return false; if (V == &Inst) return true; ReplaceInstUsesWith(Inst, V); @@ -85,7 +85,8 @@ bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne, unsigned Depth) { Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, - KnownZero, KnownOne, Depth); + KnownZero, KnownOne, Depth, + dyn_cast<Instruction>(U.getUser())); if (!NewVal) return false; U = NewVal; return true; @@ -115,7 +116,8 @@ bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask, /// in the context where the specified bits are demanded, but not for all users. Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne, - unsigned Depth) { + unsigned Depth, + Instruction *CxtI) { assert(V != nullptr && "Null pointer of Value???"); assert(Depth <= 6 && "Limit Search Depth"); uint32_t BitWidth = DemandedMask.getBitWidth(); @@ -158,7 +160,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Instruction *I = dyn_cast<Instruction>(V); if (!I) { - computeKnownBits(V, KnownZero, KnownOne, Depth); + computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI); return nullptr; // Only analyze instructions. } @@ -172,8 +174,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // this instruction has a simpler value in that context. if (I->getOpcode() == Instruction::And) { // If either the LHS or the RHS are Zero, the result is zero. - computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1); + computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1, + CxtI); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1, + CxtI); // If all of the demanded bits are known 1 on one side, return the other. // These bits cannot contribute to the result of the 'and' in this @@ -194,8 +198,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // only bits from X or Y are demanded. // If either the LHS or the RHS are One, the result is One. - computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1); + computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1, + CxtI); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1, + CxtI); // If all of the demanded bits are known zero on one side, return the // other. These bits cannot contribute to the result of the 'or' in this @@ -219,8 +225,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // We can simplify (X^Y) -> X or Y in the user's context if we know that // only bits from X or Y are demanded. - computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1); + computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1, + CxtI); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1, + CxtI); // If all of the demanded bits are known zero on one side, return the // other. @@ -231,7 +239,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } // Compute the KnownZero/KnownOne bits to simplify things downstream. - computeKnownBits(I, KnownZero, KnownOne, Depth); + computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI); return nullptr; } @@ -244,7 +252,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, switch (I->getOpcode()) { default: - computeKnownBits(I, KnownZero, KnownOne, Depth); + computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI); break; case Instruction::And: // If either the LHS or the RHS are Zero, the result is zero. @@ -595,7 +603,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Otherwise just hand the sub off to computeKnownBits to fill in // the known zeros and ones. - computeKnownBits(V, KnownZero, KnownOne, Depth); + computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI); // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known // zero. @@ -766,7 +774,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // remainder is zero. if (DemandedMask.isNegative() && KnownZero.isNonNegative()) { APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1, + CxtI); // If it's known zero, our sign bit is also zero. if (LHSKnownZero.isNegative()) KnownZero.setBit(KnownZero.getBitWidth() - 1); @@ -828,7 +837,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return nullptr; } } - computeKnownBits(V, KnownZero, KnownOne, Depth); + computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI); break; } diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 3ae9f0ddce8..e137f32db8b 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -46,6 +46,7 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" @@ -1317,7 +1318,7 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end()); - if (Value *V = SimplifyGEPInst(Ops, DL)) + if (Value *V = SimplifyGEPInst(Ops, DL, TLI, DT, AT)) return ReplaceInstUsesWith(GEP, V); Value *PtrOp = GEP.getOperand(0); @@ -2914,6 +2915,11 @@ bool InstCombiner::runOnFunction(Function &F) { DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; TLI = &getAnalysis<TargetLibraryInfo>(); + + DominatorTreeWrapperPass *DTWP = + getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + DT = DTWP ? &DTWP->getDomTree() : nullptr; + // Minimizing size? MinimizeSize = F.getAttributes().hasAttribute(AttributeSet::FunctionIndex, Attribute::MinSize); diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 082946229b3..075c0351534 100644 --- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -126,6 +126,7 @@ bool CorrelatedValuePropagation::processPHI(PHINode *P) { Changed = true; } + // FIXME: Provide DL, TLI, DT, AT to SimplifyInstruction. if (Value *V = SimplifyInstruction(P)) { P->replaceAllUsesWith(V); P->eraseFromParent(); diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp index 735f5c194cb..21ef34772d7 100644 --- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/Hashing.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" @@ -266,6 +267,7 @@ public: const DataLayout *DL; const TargetLibraryInfo *TLI; DominatorTree *DT; + AssumptionTracker *AT; typedef RecyclingAllocator<BumpPtrAllocator, ScopedHashTableVal<SimpleValue, Value*> > AllocatorTy; typedef ScopedHashTable<SimpleValue, Value*, DenseMapInfo<SimpleValue>, @@ -378,6 +380,7 @@ private: // This transformation requires dominator postdominator info void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfo>(); AU.setPreservesCFG(); @@ -393,6 +396,7 @@ FunctionPass *llvm::createEarlyCSEPass() { } INITIALIZE_PASS_BEGIN(EarlyCSE, "early-cse", "Early CSE", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_PASS_END(EarlyCSE, "early-cse", "Early CSE", false, false) @@ -433,7 +437,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // If the instruction can be simplified (e.g. X+0 = X) then replace it with // its simpler value. - if (Value *V = SimplifyInstruction(Inst, DL, TLI, DT)) { + if (Value *V = SimplifyInstruction(Inst, DL, TLI, DT, AT)) { DEBUG(dbgs() << "EarlyCSE Simplify: " << *Inst << " to: " << *V << '\n'); Inst->replaceAllUsesWith(V); Inst->eraseFromParent(); @@ -562,6 +566,7 @@ bool EarlyCSE::runOnFunction(Function &F) { DL = DLP ? &DLP->getDataLayout() : nullptr; TLI = &getAnalysis<TargetLibraryInfo>(); DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + AT = &getAnalysis<AssumptionTracker>(); // Tables that the pass uses when walking the domtree. ScopedHTType AVTable; diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp index dcbee7635de..7dba4e2d3ab 100644 --- a/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/llvm/lib/Transforms/Scalar/GVN.cpp @@ -24,6 +24,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -591,6 +592,7 @@ namespace { DominatorTree *DT; const DataLayout *DL; const TargetLibraryInfo *TLI; + AssumptionTracker *AT; SetVector<BasicBlock *> DeadBlocks; ValueTable VN; @@ -680,6 +682,7 @@ namespace { // This transformation requires dominator postdominator info void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfo>(); if (!NoLoads) @@ -728,6 +731,7 @@ FunctionPass *llvm::createGVNPass(bool NoLoads) { } INITIALIZE_PASS_BEGIN(GVN, "gvn", "Global Value Numbering", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(MemoryDependenceAnalysis) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) @@ -1617,7 +1621,7 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, // If all preds have a single successor, then we know it is safe to insert // the load on the pred (?!?), so we can insert code to materialize the // pointer if it is not available. - PHITransAddr Address(LI->getPointerOperand(), DL); + PHITransAddr Address(LI->getPointerOperand(), DL, AT); Value *LoadPtr = nullptr; LoadPtr = Address.PHITranslateWithInsertion(LoadBB, UnavailablePred, *DT, NewInsts); @@ -2210,7 +2214,7 @@ bool GVN::processInstruction(Instruction *I) { // to value numbering it. Value numbering often exposes redundancies, for // example if it determines that %y is equal to %x then the instruction // "%z = and i32 %x, %y" becomes "%z = and i32 %x, %x" which we now simplify. - if (Value *V = SimplifyInstruction(I, DL, TLI, DT)) { + if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AT)) { I->replaceAllUsesWith(V); if (MD && V->getType()->getScalarType()->isPointerTy()) MD->invalidateCachedPointerInfo(V); @@ -2330,6 +2334,7 @@ bool GVN::runOnFunction(Function& F) { DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; + AT = &getAnalysis<AssumptionTracker>(); TLI = &getAnalysis<TargetLibraryInfo>(); VN.setAliasAnalysis(&getAnalysis<AliasAnalysis>()); VN.setMemDep(MD); diff --git a/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp b/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp index ab1a9393c52..7c29b8cc07e 100644 --- a/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" @@ -41,6 +42,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addRequired<AssumptionTracker>(); AU.addRequired<LoopInfo>(); AU.addRequiredID(LoopSimplifyID); AU.addPreservedID(LoopSimplifyID); @@ -54,6 +56,7 @@ namespace { char LoopInstSimplify::ID = 0; INITIALIZE_PASS_BEGIN(LoopInstSimplify, "loop-instsimplify", "Simplify instructions in loops", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfo) @@ -76,6 +79,7 @@ bool LoopInstSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); const DataLayout *DL = DLP ? &DLP->getDataLayout() : nullptr; const TargetLibraryInfo *TLI = &getAnalysis<TargetLibraryInfo>(); + AssumptionTracker *AT = &getAnalysis<AssumptionTracker>(); SmallVector<BasicBlock*, 8> ExitBlocks; L->getUniqueExitBlocks(ExitBlocks); @@ -116,7 +120,7 @@ bool LoopInstSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { // Don't bother simplifying unused instructions. if (!I->use_empty()) { - Value *V = SimplifyInstruction(I, DL, TLI, DT); + Value *V = SimplifyInstruction(I, DL, TLI, DT, AT); if (V && LI->replacementPreservesLCSSAForm(I, V)) { // Mark all uses for resimplification next time round the loop. for (User *U : I->users()) diff --git a/llvm/lib/Transforms/Scalar/LoopRotation.cpp b/llvm/lib/Transforms/Scalar/LoopRotation.cpp index 166720435e6..ddb53926ff7 100644 --- a/llvm/lib/Transforms/Scalar/LoopRotation.cpp +++ b/llvm/lib/Transforms/Scalar/LoopRotation.cpp @@ -414,6 +414,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // With the operands remapped, see if the instruction constant folds or is // otherwise simplifyable. This commonly occurs because the entry from PHI // nodes allows icmps and other instructions to fold. + // FIXME: Provide DL, TLI, DT, AT to SimplifyInstruction. Value *V = SimplifyInstruction(C); if (V && LI->replacementPreservesLCSSAForm(C, V)) { // If so, then delete the temporary instruction and stick the folded value diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index c750ece41b4..9709dfcc1f6 100644 --- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" @@ -329,6 +330,7 @@ namespace { // This transformation requires dominator postdominator info void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addRequired<AssumptionTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<MemoryDependenceAnalysis>(); AU.addRequired<AliasAnalysis>(); @@ -361,6 +363,7 @@ FunctionPass *llvm::createMemCpyOptPass() { return new MemCpyOpt(); } INITIALIZE_PASS_BEGIN(MemCpyOpt, "memcpyopt", "MemCpy Optimization", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(MemoryDependenceAnalysis) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) @@ -977,8 +980,11 @@ bool MemCpyOpt::processByValArgument(CallSite CS, unsigned ArgNo) { // If it is greater than the memcpy, then we check to see if we can force the // source of the memcpy to the alignment we need. If we fail, we bail out. + AssumptionTracker *AT = &getAnalysis<AssumptionTracker>(); + DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); if (MDep->getAlignment() < ByValAlign && - getOrEnforceKnownAlignment(MDep->getSource(),ByValAlign, DL) < ByValAlign) + getOrEnforceKnownAlignment(MDep->getSource(),ByValAlign, + DL, AT, CS.getInstruction(), &DT) < ByValAlign) return false; // Verify that the copied-from memory doesn't change in between the memcpy and diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp index 32a61a0b9f2..ada44cd5d28 100644 --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -28,6 +28,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/PtrUseVisitor.h" #include "llvm/Analysis/ValueTracking.h" @@ -924,6 +925,7 @@ class SROA : public FunctionPass { LLVMContext *C; const DataLayout *DL; DominatorTree *DT; + AssumptionTracker *AT; /// \brief Worklist of alloca instructions to simplify. /// @@ -1003,6 +1005,7 @@ FunctionPass *llvm::createSROAPass(bool RequiresDomTree) { INITIALIZE_PASS_BEGIN(SROA, "sroa", "Scalar Replacement Of Aggregates", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(SROA, "sroa", "Scalar Replacement Of Aggregates", false, false) @@ -3551,7 +3554,7 @@ bool SROA::promoteAllocas(Function &F) { if (DT && !ForceSSAUpdater) { DEBUG(dbgs() << "Promoting allocas with mem2reg...\n"); - PromoteMemToReg(PromotableAllocas, *DT); + PromoteMemToReg(PromotableAllocas, *DT, nullptr, AT); PromotableAllocas.clear(); return true; } @@ -3633,6 +3636,7 @@ bool SROA::runOnFunction(Function &F) { DominatorTreeWrapperPass *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); DT = DTWP ? &DTWP->getDomTree() : nullptr; + AT = &getAnalysis<AssumptionTracker>(); BasicBlock &EntryBB = F.getEntryBlock(); for (BasicBlock::iterator I = EntryBB.begin(), E = std::prev(EntryBB.end()); @@ -3676,6 +3680,7 @@ bool SROA::runOnFunction(Function &F) { } void SROA::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<AssumptionTracker>(); if (RequiresDomTree) AU.addRequired<DominatorTreeWrapperPass>(); AU.setPreservesCFG(); diff --git a/llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp b/llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp index 579ce13e0b9..eb8d2a6f73f 100644 --- a/llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp +++ b/llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp @@ -23,6 +23,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CallSite.h" @@ -197,6 +198,7 @@ namespace { // getAnalysisUsage - This pass does not require any passes, but we know it // will not alter the CFG, so say so. void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.setPreservesCFG(); } @@ -214,6 +216,7 @@ namespace { // getAnalysisUsage - This pass does not require any passes, but we know it // will not alter the CFG, so say so. void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionTracker>(); AU.setPreservesCFG(); } }; @@ -225,12 +228,14 @@ char SROA_SSAUp::ID = 0; INITIALIZE_PASS_BEGIN(SROA_DT, "scalarrepl", "Scalar Replacement of Aggregates (DT)", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(SROA_DT, "scalarrepl", "Scalar Replacement of Aggregates (DT)", false, false) INITIALIZE_PASS_BEGIN(SROA_SSAUp, "scalarrepl-ssa", "Scalar Replacement of Aggregates (SSAUp)", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_END(SROA_SSAUp, "scalarrepl-ssa", "Scalar Replacement of Aggregates (SSAUp)", false, false) @@ -1412,6 +1417,7 @@ bool SROA::performPromotion(Function &F) { DominatorTree *DT = nullptr; if (HasDomTree) DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + AssumptionTracker *AT = &getAnalysis<AssumptionTracker>(); BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function DIBuilder DIB(*F.getParent()); @@ -1430,7 +1436,7 @@ bool SROA::performPromotion(Function &F) { if (Allocas.empty()) break; if (HasDomTree) - PromoteMemToReg(Allocas, *DT); + PromoteMemToReg(Allocas, *DT, nullptr, AT); else { SSAUpdater SSA; for (unsigned i = 0, e = Allocas.size(); i != e; ++i) { diff --git a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp index 5d5606ba47b..66adde0ea41 100644 --- a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -25,6 +25,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/CFG.h" @@ -50,6 +51,7 @@ struct CFGSimplifyPass : public FunctionPass { bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionTracker>(); AU.addRequired<TargetTransformInfo>(); } }; @@ -59,6 +61,7 @@ char CFGSimplifyPass::ID = 0; INITIALIZE_PASS_BEGIN(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false, false) INITIALIZE_AG_DEPENDENCY(TargetTransformInfo) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_END(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false, false) @@ -146,7 +149,8 @@ static bool mergeEmptyReturnBlocks(Function &F) { /// iterativelySimplifyCFG - Call SimplifyCFG on all the blocks in the function, /// iterating until no more changes are made. static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, - const DataLayout *DL) { + const DataLayout *DL, + AssumptionTracker *AT) { bool Changed = false; bool LocalChange = true; while (LocalChange) { @@ -155,7 +159,7 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, // Loop over all of the basic blocks and remove them if they are unneeded... // for (Function::iterator BBIt = F.begin(); BBIt != F.end(); ) { - if (SimplifyCFG(BBIt++, TTI, DL)) { + if (SimplifyCFG(BBIt++, TTI, DL, AT)) { LocalChange = true; ++NumSimpl; } @@ -172,12 +176,13 @@ bool CFGSimplifyPass::runOnFunction(Function &F) { if (skipOptnoneFunction(F)) return false; + AssumptionTracker *AT = &getAnalysis<AssumptionTracker>(); const TargetTransformInfo &TTI = getAnalysis<TargetTransformInfo>(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); const DataLayout *DL = DLP ? &DLP->getDataLayout() : nullptr; bool EverChanged = removeUnreachableBlocks(F); EverChanged |= mergeEmptyReturnBlocks(F); - EverChanged |= iterativelySimplifyCFG(F, TTI, DL); + EverChanged |= iterativelySimplifyCFG(F, TTI, DL, AT); // If neither pass changed anything, we're done. if (!EverChanged) return false; @@ -191,7 +196,7 @@ bool CFGSimplifyPass::runOnFunction(Function &F) { return true; do { - EverChanged = iterativelySimplifyCFG(F, TTI, DL); + EverChanged = iterativelySimplifyCFG(F, TTI, DL, AT); EverChanged |= removeUnreachableBlocks(F); } while (EverChanged); diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp index 8aaeacf10c2..e6f0d990ad1 100644 --- a/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -732,7 +732,7 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, // If the pointer is already known to be sufficiently aligned, or if we can // round it up to a larger alignment, then we don't need a temporary. if (getOrEnforceKnownAlignment(Arg, ByValAlignment, - IFI.DL) >= ByValAlignment) + IFI.DL, IFI.AT, TheCall) >= ByValAlignment) return Arg; // Otherwise, we have to make a memcpy to get a safe alignment. This is bad @@ -1358,7 +1358,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // the entries are the same or undef). If so, remove the PHI so it doesn't // block other optimizations. if (PHI) { - if (Value *V = SimplifyInstruction(PHI, IFI.DL)) { + if (Value *V = SimplifyInstruction(PHI, IFI.DL, nullptr, nullptr, IFI.AT)) { PHI->replaceAllUsesWith(V); PHI->eraseFromParent(); } diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index aa4d6a28363..e4ce699ecb4 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -939,13 +939,16 @@ static unsigned enforceKnownAlignment(Value *V, unsigned Align, /// and it is more than the alignment of the ultimate object, see if we can /// increase the alignment of the ultimate object, making this check succeed. unsigned llvm::getOrEnforceKnownAlignment(Value *V, unsigned PrefAlign, - const DataLayout *DL) { + const DataLayout *DL, + AssumptionTracker *AT, + const Instruction *CxtI, + const DominatorTree *DT) { assert(V->getType()->isPointerTy() && "getOrEnforceKnownAlignment expects a pointer!"); unsigned BitWidth = DL ? DL->getPointerTypeSizeInBits(V->getType()) : 64; APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(V, KnownZero, KnownOne, DL); + computeKnownBits(V, KnownZero, KnownOne, DL, 0, AT, CxtI, DT); unsigned TrailZ = KnownZero.countTrailingOnes(); // Avoid trouble with ridiculously large TrailZ values, such as diff --git a/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/llvm/lib/Transforms/Utils/LoopSimplify.cpp index 2decaa18799..c5a4adf1168 100644 --- a/llvm/lib/Transforms/Utils/LoopSimplify.cpp +++ b/llvm/lib/Transforms/Utils/LoopSimplify.cpp @@ -44,6 +44,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" @@ -208,11 +209,12 @@ static void addBlockAndPredsToSet(BasicBlock *InputBB, BasicBlock *StopBlock, /// \brief The first part of loop-nestification is to find a PHI node that tells /// us how to partition the loops. static PHINode *findPHIToPartitionLoops(Loop *L, AliasAnalysis *AA, - DominatorTree *DT) { + DominatorTree *DT, + AssumptionTracker *AT) { for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ) { PHINode *PN = cast<PHINode>(I); ++I; - if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT)) { + if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT, AT)) { // This is a degenerate PHI already, don't modify it! PN->replaceAllUsesWith(V); if (AA) AA->deleteValue(PN); @@ -251,7 +253,8 @@ static PHINode *findPHIToPartitionLoops(Loop *L, AliasAnalysis *AA, /// static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, AliasAnalysis *AA, DominatorTree *DT, - LoopInfo *LI, ScalarEvolution *SE, Pass *PP) { + LoopInfo *LI, ScalarEvolution *SE, Pass *PP, + AssumptionTracker *AT) { // Don't try to separate loops without a preheader. if (!Preheader) return nullptr; @@ -260,7 +263,7 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, assert(!L->getHeader()->isLandingPad() && "Can't insert backedge to landing pad"); - PHINode *PN = findPHIToPartitionLoops(L, AA, DT); + PHINode *PN = findPHIToPartitionLoops(L, AA, DT, AT); if (!PN) return nullptr; // No known way to partition. // Pull out all predecessors that have varying values in the loop. This @@ -474,7 +477,7 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, static bool simplifyOneLoop(Loop *L, SmallVectorImpl<Loop *> &Worklist, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, Pass *PP, - const DataLayout *DL) { + const DataLayout *DL, AssumptionTracker *AT) { bool Changed = false; ReprocessLoop: @@ -580,7 +583,8 @@ ReprocessLoop: // this for loops with a giant number of backedges, just factor them into a // common backedge instead. if (L->getNumBackEdges() < 8) { - if (Loop *OuterL = separateNestedLoop(L, Preheader, AA, DT, LI, SE, PP)) { + if (Loop *OuterL = separateNestedLoop(L, Preheader, AA, DT, LI, SE, + PP, AT)) { ++NumNested; // Enqueue the outer loop as it should be processed next in our // depth-first nest walk. @@ -610,7 +614,7 @@ ReprocessLoop: PHINode *PN; for (BasicBlock::iterator I = L->getHeader()->begin(); (PN = dyn_cast<PHINode>(I++)); ) - if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT)) { + if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT, AT)) { if (AA) AA->deleteValue(PN); if (SE) SE->forgetValue(PN); PN->replaceAllUsesWith(V); @@ -710,7 +714,7 @@ ReprocessLoop: bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, Pass *PP, AliasAnalysis *AA, ScalarEvolution *SE, - const DataLayout *DL) { + const DataLayout *DL, AssumptionTracker *AT) { bool Changed = false; // Worklist maintains our depth-first queue of loops in this nest to process. @@ -728,7 +732,7 @@ bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, Pass *PP, while (!Worklist.empty()) Changed |= simplifyOneLoop(Worklist.pop_back_val(), Worklist, AA, DT, LI, - SE, PP, DL); + SE, PP, DL, AT); return Changed; } @@ -747,10 +751,13 @@ namespace { LoopInfo *LI; ScalarEvolution *SE; const DataLayout *DL; + AssumptionTracker *AT; bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionTracker>(); + // We need loop information to identify the loops... AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); @@ -772,6 +779,7 @@ namespace { char LoopSimplify::ID = 0; INITIALIZE_PASS_BEGIN(LoopSimplify, "loop-simplify", "Canonicalize natural loops", true, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfo) INITIALIZE_PASS_END(LoopSimplify, "loop-simplify", @@ -792,10 +800,11 @@ bool LoopSimplify::runOnFunction(Function &F) { SE = getAnalysisIfAvailable<ScalarEvolution>(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; + AT = &getAnalysis<AssumptionTracker>(); // Simplify each loop nest in the function. for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) - Changed |= simplifyLoop(*I, DT, LI, this, AA, SE, DL); + Changed |= simplifyLoop(*I, DT, LI, this, AA, SE, DL, AT); return Changed; } diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp index 576298892f4..4326dc11616 100644 --- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -509,7 +509,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, DataLayoutPass *DLP = PP->getAnalysisIfAvailable<DataLayoutPass>(); const DataLayout *DL = DLP ? &DLP->getDataLayout() : nullptr; ScalarEvolution *SE = PP->getAnalysisIfAvailable<ScalarEvolution>(); - simplifyLoop(OuterL, DT, LI, PP, /*AliasAnalysis*/ nullptr, SE, DL); + simplifyLoop(OuterL, DT, LI, PP, /*AliasAnalysis*/ nullptr, SE, DL, AT); // LCSSA must be performed on the outermost affected loop. The unrolled // loop's last loop latch is guaranteed to be in the outermost loop after diff --git a/llvm/lib/Transforms/Utils/Mem2Reg.cpp b/llvm/lib/Transforms/Utils/Mem2Reg.cpp index 189caa7d145..477ee7af78f 100644 --- a/llvm/lib/Transforms/Utils/Mem2Reg.cpp +++ b/llvm/lib/Transforms/Utils/Mem2Reg.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -38,6 +39,7 @@ namespace { bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.setPreservesCFG(); // This is a cluster of orthogonal Transforms @@ -51,6 +53,7 @@ namespace { char PromotePass::ID = 0; INITIALIZE_PASS_BEGIN(PromotePass, "mem2reg", "Promote Memory to Register", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(PromotePass, "mem2reg", "Promote Memory to Register", false, false) @@ -63,6 +66,7 @@ bool PromotePass::runOnFunction(Function &F) { bool Changed = false; DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + AssumptionTracker *AT = &getAnalysis<AssumptionTracker>(); while (1) { Allocas.clear(); @@ -76,7 +80,7 @@ bool PromotePass::runOnFunction(Function &F) { if (Allocas.empty()) break; - PromoteMemToReg(Allocas, DT); + PromoteMemToReg(Allocas, DT, nullptr, AT); NumPromoted += Allocas.size(); Changed = true; } diff --git a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index ec48ab1144f..00a4c314a0e 100644 --- a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -238,6 +238,9 @@ struct PromoteMem2Reg { /// An AliasSetTracker object to update. If null, don't update it. AliasSetTracker *AST; + /// A cache of @llvm.assume intrinsics used by SimplifyInstruction. + AssumptionTracker *AT; + /// Reverse mapping of Allocas. DenseMap<AllocaInst *, unsigned> AllocaLookup; @@ -279,9 +282,9 @@ struct PromoteMem2Reg { public: PromoteMem2Reg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT, - AliasSetTracker *AST) + AliasSetTracker *AST, AssumptionTracker *AT) : Allocas(Allocas.begin(), Allocas.end()), DT(DT), - DIB(*DT.getRoot()->getParent()->getParent()), AST(AST) {} + DIB(*DT.getRoot()->getParent()->getParent()), AST(AST), AT(AT) {} void run(); @@ -685,7 +688,7 @@ void PromoteMem2Reg::run() { PHINode *PN = I->second; // If this PHI node merges one value and/or undefs, get the value. - if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, &DT)) { + if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, &DT, AT)) { if (AST && PN->getType()->isPointerTy()) AST->deleteValue(PN); PN->replaceAllUsesWith(V); @@ -1065,10 +1068,10 @@ NextIteration: } void llvm::PromoteMemToReg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT, - AliasSetTracker *AST) { + AliasSetTracker *AST, AssumptionTracker *AT) { // If there is nothing to do, bail out... if (Allocas.empty()) return; - PromoteMem2Reg(Allocas, DT, AST).run(); + PromoteMem2Reg(Allocas, DT, AST, AT).run(); } diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index cb56abe294e..dd4dff57c3b 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -94,6 +94,7 @@ namespace { class SimplifyCFGOpt { const TargetTransformInfo &TTI; const DataLayout *const DL; + AssumptionTracker *AT; Value *isValueEqualityComparison(TerminatorInst *TI); BasicBlock *GetValueEqualityComparisonCases(TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases); @@ -112,8 +113,9 @@ class SimplifyCFGOpt { bool SimplifyCondBranch(BranchInst *BI, IRBuilder <>&Builder); public: - SimplifyCFGOpt(const TargetTransformInfo &TTI, const DataLayout *DL) - : TTI(TTI), DL(DL) {} + SimplifyCFGOpt(const TargetTransformInfo &TTI, const DataLayout *DL, + AssumptionTracker *AT) + : TTI(TTI), DL(DL), AT(AT) {} bool run(BasicBlock *BB); }; } @@ -2657,7 +2659,7 @@ static bool SimplifyIndirectBrOnSelect(IndirectBrInst *IBI, SelectInst *SI) { /// the PHI, merging the third icmp into the switch. static bool TryToSimplifyUncondBranchWithICmpInIt( ICmpInst *ICI, IRBuilder<> &Builder, const TargetTransformInfo &TTI, - const DataLayout *DL) { + const DataLayout *DL, AssumptionTracker *AT) { BasicBlock *BB = ICI->getParent(); // If the block has any PHIs in it or the icmp has multiple uses, it is too @@ -2690,7 +2692,7 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( ICI->eraseFromParent(); } // BB is now empty, so it is likely to simplify away. - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; } // Ok, the block is reachable from the default dest. If the constant we're @@ -2706,7 +2708,7 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( ICI->replaceAllUsesWith(V); ICI->eraseFromParent(); // BB is now empty, so it is likely to simplify away. - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; } // The use of the icmp has to be in the 'end' block, by the only PHI node in @@ -3216,11 +3218,12 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) { /// EliminateDeadSwitchCases - Compute masked bits for the condition of a switch /// and use it to remove dead cases. -static bool EliminateDeadSwitchCases(SwitchInst *SI) { +static bool EliminateDeadSwitchCases(SwitchInst *SI, const DataLayout *DL, + AssumptionTracker *AT) { Value *Cond = SI->getCondition(); unsigned Bits = Cond->getType()->getIntegerBitWidth(); APInt KnownZero(Bits, 0), KnownOne(Bits, 0); - computeKnownBits(Cond, KnownZero, KnownOne); + computeKnownBits(Cond, KnownZero, KnownOne, DL, 0, AT, SI); // Gather dead cases. SmallVector<ConstantInt*, 8> DeadCases; @@ -3940,12 +3943,12 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { // see if that predecessor totally determines the outcome of this switch. if (BasicBlock *OnlyPred = BB->getSinglePredecessor()) if (SimplifyEqualityComparisonWithOnlyPredecessor(SI, OnlyPred, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; Value *Cond = SI->getCondition(); if (SelectInst *Select = dyn_cast<SelectInst>(Cond)) if (SimplifySwitchOnSelect(SI, Select)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; // If the block only contains the switch, see if we can fold the block // away into any preds. @@ -3955,22 +3958,22 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { ++BBI; if (SI == &*BBI) if (FoldValueComparisonIntoPredecessors(SI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; } // Try to transform the switch into an icmp and a branch. if (TurnSwitchRangeIntoICmp(SI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; // Remove unreachable cases. - if (EliminateDeadSwitchCases(SI)) - return SimplifyCFG(BB, TTI, DL) | true; + if (EliminateDeadSwitchCases(SI, DL, AT)) + return SimplifyCFG(BB, TTI, DL, AT) | true; if (ForwardSwitchConditionToPHI(SI)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; if (SwitchToLookupTable(SI, Builder, TTI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; return false; } @@ -4007,7 +4010,7 @@ bool SimplifyCFGOpt::SimplifyIndirectBr(IndirectBrInst *IBI) { if (SelectInst *SI = dyn_cast<SelectInst>(IBI->getAddress())) { if (SimplifyIndirectBrOnSelect(IBI, SI)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; } return Changed; } @@ -4031,7 +4034,7 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder){ for (++I; isa<DbgInfoIntrinsic>(I); ++I) ; if (I->isTerminator() && - TryToSimplifyUncondBranchWithICmpInIt(ICI, Builder, TTI, DL)) + TryToSimplifyUncondBranchWithICmpInIt(ICI, Builder, TTI, DL, AT)) return true; } @@ -4040,7 +4043,7 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder){ // predecessor and use logical operations to update the incoming value // for PHI nodes in common successor. if (FoldBranchToCommonDest(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; return false; } @@ -4055,7 +4058,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // switch. if (BasicBlock *OnlyPred = BB->getSinglePredecessor()) if (SimplifyEqualityComparisonWithOnlyPredecessor(BI, OnlyPred, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; // This block must be empty, except for the setcond inst, if it exists. // Ignore dbg intrinsics. @@ -4065,14 +4068,14 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { ++I; if (&*I == BI) { if (FoldValueComparisonIntoPredecessors(BI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; } else if (&*I == cast<Instruction>(BI->getCondition())){ ++I; // Ignore dbg intrinsics. while (isa<DbgInfoIntrinsic>(I)) ++I; if (&*I == BI && FoldValueComparisonIntoPredecessors(BI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; } } @@ -4084,7 +4087,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // branches to us and one of our successors, fold the comparison into the // predecessor and use logical operations to pick the right destination. if (FoldBranchToCommonDest(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; // We have a conditional branch to two blocks that are only reachable // from BI. We know that the condbr dominates the two blocks, so see if @@ -4093,7 +4096,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (BI->getSuccessor(0)->getSinglePredecessor()) { if (BI->getSuccessor(1)->getSinglePredecessor()) { if (HoistThenElseCodeToIf(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; } else { // If Successor #1 has multiple preds, we may be able to conditionally // execute Successor #0 if it branches to Successor #1. @@ -4101,7 +4104,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (Succ0TI->getNumSuccessors() == 1 && Succ0TI->getSuccessor(0) == BI->getSuccessor(1)) if (SpeculativelyExecuteBB(BI, BI->getSuccessor(0), DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; } } else if (BI->getSuccessor(1)->getSinglePredecessor()) { // If Successor #0 has multiple preds, we may be able to conditionally @@ -4110,7 +4113,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (Succ1TI->getNumSuccessors() == 1 && Succ1TI->getSuccessor(0) == BI->getSuccessor(0)) if (SpeculativelyExecuteBB(BI, BI->getSuccessor(1), DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; } // If this is a branch on a phi node in the current block, thread control @@ -4118,14 +4121,14 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (PHINode *PN = dyn_cast<PHINode>(BI->getCondition())) if (PN->getParent() == BI->getParent()) if (FoldCondBranchOnPHI(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; // Scan predecessor blocks for conditional branches. for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) if (BranchInst *PBI = dyn_cast<BranchInst>((*PI)->getTerminator())) if (PBI != BI && PBI->isConditional()) if (SimplifyCondBranchToCondBranch(PBI, BI)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, DL, AT) | true; return false; } @@ -4269,6 +4272,6 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { /// of the CFG. It returns true if a modification was made. /// bool llvm::SimplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, - const DataLayout *DL) { - return SimplifyCFGOpt(TTI, DL).run(BB); + const DataLayout *DL, AssumptionTracker *AT) { + return SimplifyCFGOpt(TTI, DL, AT).run(BB); } diff --git a/llvm/lib/Transforms/Utils/SimplifyInstructions.cpp b/llvm/lib/Transforms/Utils/SimplifyInstructions.cpp index 33b36378027..5632095b124 100644 --- a/llvm/lib/Transforms/Utils/SimplifyInstructions.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyInstructions.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" @@ -41,6 +42,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addRequired<AssumptionTracker>(); AU.addRequired<TargetLibraryInfo>(); } @@ -52,6 +54,7 @@ namespace { DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); const DataLayout *DL = DLP ? &DLP->getDataLayout() : nullptr; const TargetLibraryInfo *TLI = &getAnalysis<TargetLibraryInfo>(); + AssumptionTracker *AT = &getAnalysis<AssumptionTracker>(); SmallPtrSet<const Instruction*, 8> S1, S2, *ToSimplify = &S1, *Next = &S2; bool Changed = false; @@ -68,7 +71,7 @@ namespace { continue; // Don't waste time simplifying unused instructions. if (!I->use_empty()) - if (Value *V = SimplifyInstruction(I, DL, TLI, DT)) { + if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AT)) { // Mark all uses for resimplification next time round the loop. for (User *U : I->users()) Next->insert(cast<Instruction>(U)); @@ -101,6 +104,7 @@ namespace { char InstSimplifier::ID = 0; INITIALIZE_PASS_BEGIN(InstSimplifier, "instsimplify", "Remove redundant instructions", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_PASS_END(InstSimplifier, "instsimplify", "Remove redundant instructions", false, false) |