From 534c0175b60ae1adf17cdd913c11afea973046ae Mon Sep 17 00:00:00 2001 From: Peter Collingbourne Date: Tue, 14 Feb 2017 22:12:23 +0000 Subject: WholeProgramDevirt: Change internal vcall data structures to match summary. Group calls into constant and non-constant arguments up front, and use uint64_t instead of ConstantInt to represent constant arguments. The goal is to allow the information from the summary to fit naturally into this data structure in a future change (specifically, it will be added to CallSiteInfo). This has two side effects: - We disallow VCP for constant integer arguments of width >64 bits. - We remove the restriction that the bitwidth of a vcall's argument and return types must match those of the vfunc definitions. I don't expect either of these to matter in practice. The first case is uncommon, and the second one will lead to UB (so we can do anything we like). Differential Revision: https://reviews.llvm.org/D29744 llvm-svn: 295110 --- llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp | 168 ++++++++++++++----------- 1 file changed, 94 insertions(+), 74 deletions(-) (limited to 'llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp') diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index 73ea5c19267..470a9f62101 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -282,6 +282,48 @@ struct VirtualCallSite { } }; +// Call site information collected for a specific VTableSlot and possibly a list +// of constant integer arguments. The grouping by arguments is handled by the +// VTableSlotInfo class. +struct CallSiteInfo { + std::vector CallSites; +}; + +// Call site information collected for a specific VTableSlot. +struct VTableSlotInfo { + // The set of call sites which do not have all constant integer arguments + // (excluding "this"). + CallSiteInfo CSInfo; + + // The set of call sites with all constant integer arguments (excluding + // "this"), grouped by argument list. + std::map, CallSiteInfo> ConstCSInfo; + + void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses); + +private: + CallSiteInfo &findCallSiteInfo(CallSite CS); +}; + +CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) { + std::vector Args; + auto *CI = dyn_cast(CS.getType()); + if (!CI || CI->getBitWidth() > 64) + return CSInfo; + for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) { + auto *CI = dyn_cast(Arg); + if (!CI || CI->getBitWidth() > 64) + return CSInfo; + Args.push_back(CI->getZExtValue()); + } + return ConstCSInfo[Args]; +} + +void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS, + unsigned *NumUnsafeUses) { + findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses}); +} + struct DevirtModule { Module &M; @@ -294,7 +336,7 @@ struct DevirtModule { bool RemarksEnabled; - MapVector> CallSlots; + MapVector CallSlots; // This map keeps track of the number of "unsafe" uses of a loaded function // pointer. The key is the associated llvm.type.test intrinsic call generated @@ -327,18 +369,17 @@ struct DevirtModule { const std::set &TypeMemberInfos, uint64_t ByteOffset); bool trySingleImplDevirt(MutableArrayRef TargetsForSlot, - MutableArrayRef CallSites); + VTableSlotInfo &SlotInfo); bool tryEvaluateFunctionsWithArgs( MutableArrayRef TargetsForSlot, - ArrayRef Args); - bool tryUniformRetValOpt(IntegerType *RetType, - MutableArrayRef TargetsForSlot, - MutableArrayRef CallSites); + ArrayRef Args); + bool tryUniformRetValOpt(MutableArrayRef TargetsForSlot, + CallSiteInfo &CSInfo); bool tryUniqueRetValOpt(unsigned BitWidth, MutableArrayRef TargetsForSlot, - MutableArrayRef CallSites); + CallSiteInfo &CSInfo); bool tryVirtualConstProp(MutableArrayRef TargetsForSlot, - ArrayRef CallSites); + VTableSlotInfo &SlotInfo); void rebuildGlobal(VTableBits &B); @@ -521,7 +562,7 @@ bool DevirtModule::tryFindVirtualCallTargets( bool DevirtModule::trySingleImplDevirt( MutableArrayRef TargetsForSlot, - MutableArrayRef CallSites) { + VTableSlotInfo &SlotInfo) { // See if the program contains a single implementation of this virtual // function. Function *TheFn = TargetsForSlot[0].Fn; @@ -532,36 +573,44 @@ bool DevirtModule::trySingleImplDevirt( if (RemarksEnabled) TargetsForSlot[0].WasDevirt = true; // If so, update each call site to call that implementation directly. - for (auto &&VCallSite : CallSites) { - if (RemarksEnabled) - VCallSite.emitRemark("single-impl", TheFn->getName()); - VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( - TheFn, VCallSite.CS.getCalledValue()->getType())); - // This use is no longer unsafe. - if (VCallSite.NumUnsafeUses) - --*VCallSite.NumUnsafeUses; - } + auto Apply = [&](CallSiteInfo &CSInfo) { + for (auto &&VCallSite : CSInfo.CallSites) { + if (RemarksEnabled) + VCallSite.emitRemark("single-impl", TheFn->getName()); + VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( + TheFn, VCallSite.CS.getCalledValue()->getType())); + // This use is no longer unsafe. + if (VCallSite.NumUnsafeUses) + --*VCallSite.NumUnsafeUses; + } + }; + Apply(SlotInfo.CSInfo); + for (auto &P : SlotInfo.ConstCSInfo) + Apply(P.second); return true; } bool DevirtModule::tryEvaluateFunctionsWithArgs( MutableArrayRef TargetsForSlot, - ArrayRef Args) { + ArrayRef Args) { // Evaluate each function and store the result in each target's RetVal // field. for (VirtualCallTarget &Target : TargetsForSlot) { if (Target.Fn->arg_size() != Args.size() + 1) return false; - for (unsigned I = 0; I != Args.size(); ++I) - if (Target.Fn->getFunctionType()->getParamType(I + 1) != - Args[I]->getType()) - return false; Evaluator Eval(M.getDataLayout(), nullptr); SmallVector EvalArgs; EvalArgs.push_back( Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0))); - EvalArgs.insert(EvalArgs.end(), Args.begin(), Args.end()); + for (unsigned I = 0; I != Args.size(); ++I) { + auto *ArgTy = dyn_cast( + Target.Fn->getFunctionType()->getParamType(I + 1)); + if (!ArgTy) + return false; + EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I])); + } + Constant *RetVal; if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) || !isa(RetVal)) @@ -572,8 +621,7 @@ bool DevirtModule::tryEvaluateFunctionsWithArgs( } bool DevirtModule::tryUniformRetValOpt( - IntegerType *RetType, MutableArrayRef TargetsForSlot, - MutableArrayRef CallSites) { + MutableArrayRef TargetsForSlot, CallSiteInfo &CSInfo) { // Uniform return value optimization. If all functions return the same // constant, replace all calls with that constant. uint64_t TheRetVal = TargetsForSlot[0].RetVal; @@ -581,10 +629,10 @@ bool DevirtModule::tryUniformRetValOpt( if (Target.RetVal != TheRetVal) return false; - auto TheRetValConst = ConstantInt::get(RetType, TheRetVal); - for (auto Call : CallSites) + for (auto Call : CSInfo.CallSites) Call.replaceAndErase("uniform-ret-val", TargetsForSlot[0].Fn->getName(), - RemarksEnabled, TheRetValConst); + RemarksEnabled, + ConstantInt::get(Call.CS->getType(), TheRetVal)); if (RemarksEnabled) for (auto &&Target : TargetsForSlot) Target.WasDevirt = true; @@ -593,7 +641,7 @@ bool DevirtModule::tryUniformRetValOpt( bool DevirtModule::tryUniqueRetValOpt( unsigned BitWidth, MutableArrayRef TargetsForSlot, - MutableArrayRef CallSites) { + CallSiteInfo &CSInfo) { // IsOne controls whether we look for a 0 or a 1. auto tryUniqueRetValOptFor = [&](bool IsOne) { const TypeMemberInfo *UniqueMember = nullptr; @@ -610,12 +658,13 @@ bool DevirtModule::tryUniqueRetValOpt( assert(UniqueMember); // Replace each call with the comparison. - for (auto &&Call : CallSites) { + for (auto &&Call : CSInfo.CallSites) { IRBuilder<> B(Call.CS.getInstruction()); Value *OneAddr = B.CreateBitCast(UniqueMember->Bits->GV, Int8PtrTy); OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueMember->Offset); Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable, OneAddr); + Cmp = B.CreateZExt(Cmp, Call.CS->getType()); Call.replaceAndErase("unique-ret-val", TargetsForSlot[0].Fn->getName(), RemarksEnabled, Cmp); } @@ -638,7 +687,7 @@ bool DevirtModule::tryUniqueRetValOpt( bool DevirtModule::tryVirtualConstProp( MutableArrayRef TargetsForSlot, - ArrayRef CallSites) { + VTableSlotInfo &SlotInfo) { // This only works if the function returns an integer. auto RetType = dyn_cast(TargetsForSlot[0].Fn->getReturnType()); if (!RetType) @@ -657,42 +706,11 @@ bool DevirtModule::tryVirtualConstProp( return false; } - // Group call sites by the list of constant arguments they pass. - // The comparator ensures deterministic ordering. - struct ByAPIntValue { - bool operator()(const std::vector &A, - const std::vector &B) const { - return std::lexicographical_compare( - A.begin(), A.end(), B.begin(), B.end(), - [](ConstantInt *AI, ConstantInt *BI) { - return AI->getValue().ult(BI->getValue()); - }); - } - }; - std::map, std::vector, - ByAPIntValue> - VCallSitesByConstantArg; - for (auto &&VCallSite : CallSites) { - std::vector Args; - if (VCallSite.CS.getType() != RetType) - continue; - for (auto &&Arg : - make_range(VCallSite.CS.arg_begin() + 1, VCallSite.CS.arg_end())) { - if (!isa(Arg)) - break; - Args.push_back(cast(&Arg)); - } - if (Args.size() + 1 != VCallSite.CS.arg_size()) - continue; - - VCallSitesByConstantArg[Args].push_back(VCallSite); - } - - for (auto &&CSByConstantArg : VCallSitesByConstantArg) { + for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) { if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first)) continue; - if (tryUniformRetValOpt(RetType, TargetsForSlot, CSByConstantArg.second)) + if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second)) continue; if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second)) @@ -736,20 +754,22 @@ bool DevirtModule::tryVirtualConstProp( Target.WasDevirt = true; // Rewrite each call to a load from OffsetByte/OffsetBit. - for (auto Call : CSByConstantArg.second) { + for (auto Call : CSByConstantArg.second.CallSites) { + auto *CSRetType = cast(Call.CS.getType()); IRBuilder<> B(Call.CS.getInstruction()); Value *Addr = B.CreateConstGEP1_64(Call.VTable, OffsetByte); - if (BitWidth == 1) { + if (CSRetType->getBitWidth() == 1) { Value *Bits = B.CreateLoad(Addr); Value *Bit = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); Value *BitsAndBit = B.CreateAnd(Bits, Bit); - auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); + Value *IsBitSet = + B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); Call.replaceAndErase("virtual-const-prop-1-bit", TargetsForSlot[0].Fn->getName(), RemarksEnabled, IsBitSet); } else { - Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); - Value *Val = B.CreateLoad(RetType, ValAddr); + Value *ValAddr = B.CreateBitCast(Addr, CSRetType->getPointerTo()); + Value *Val = B.CreateLoad(CSRetType, ValAddr); Call.replaceAndErase("virtual-const-prop", TargetsForSlot[0].Fn->getName(), RemarksEnabled, Val); @@ -842,8 +862,8 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, Value *Ptr = CI->getArgOperand(0)->stripPointerCasts(); if (SeenPtrs.insert(Ptr).second) { for (DevirtCallSite Call : DevirtCalls) { - CallSlots[{TypeId, Call.Offset}].push_back( - {CI->getArgOperand(0), Call.CS, nullptr}); + CallSlots[{TypeId, Call.Offset}].addCallSite(CI->getArgOperand(0), + Call.CS, nullptr); } } } @@ -929,8 +949,8 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { if (HasNonCallUses) ++NumUnsafeUses; for (DevirtCallSite Call : DevirtCalls) { - CallSlots[{TypeId, Call.Offset}].push_back( - {Ptr, Call.CS, &NumUnsafeUses}); + CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, + &NumUnsafeUses); } CI->eraseFromParent(); -- cgit v1.2.3