diff options
Diffstat (limited to 'llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp')
-rw-r--r-- | llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp | 154 |
1 files changed, 142 insertions, 12 deletions
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index c19c667a51e..e6f4fa29224 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -210,6 +210,11 @@ struct VirtualCallSite { Value *VTable; CallSite CS; + // If non-null, this field points to the associated unsafe use count stored in + // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description + // of that field for details. + unsigned *NumUnsafeUses; + void replaceAndErase(Value *New) { CS->replaceAllUsesWith(New); if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) { @@ -217,6 +222,9 @@ struct VirtualCallSite { II->getUnwindDest()->removePredecessor(II->getParent()); } CS->eraseFromParent(); + // This use is no longer unsafe. + if (NumUnsafeUses) + --*NumUnsafeUses; } }; @@ -228,11 +236,24 @@ struct DevirtModule { MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots; + // This map keeps track of the number of "unsafe" uses of a loaded function + // pointer. The key is the associated llvm.type.test intrinsic call generated + // by this pass. An unsafe use is one that calls the loaded function pointer + // directly. Every time we eliminate an unsafe use (for example, by + // devirtualizing it or by applying virtual constant propagation), we + // decrement the value stored in this map. If a value reaches zero, we can + // eliminate the type check by RAUWing the associated llvm.type.test call with + // true. + std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest; + DevirtModule(Module &M) : M(M), Int8Ty(Type::getInt8Ty(M.getContext())), Int8PtrTy(Type::getInt8PtrTy(M.getContext())), Int32Ty(Type::getInt32Ty(M.getContext())) {} + void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc); + void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc); + void buildTypeIdentifierMap( std::vector<VTableBits> &Bits, DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap); @@ -375,6 +396,9 @@ bool DevirtModule::trySingleImplDevirt( for (auto &&VCallSite : CallSites) { VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( TheFn, VCallSite.CS.getCalledValue()->getType())); + // This use is no longer unsafe. + if (VCallSite.NumUnsafeUses) + --*VCallSite.NumUnsafeUses; } return true; } @@ -601,6 +625,10 @@ void DevirtModule::rebuildGlobal(VTableBits &B) { NewGV->setSection(B.GV->getSection()); NewGV->setComdat(B.GV->getComdat()); + // Copy the original vtable's metadata to the anonymous global, adjusting + // offsets as required. + NewGV->copyMetadata(B.GV, B.Before.Bytes.size()); + // Build an alias named after the original global, pointing at the second // element (the original initializer). auto Alias = GlobalAlias::create( @@ -617,16 +645,8 @@ void DevirtModule::rebuildGlobal(VTableBits &B) { B.GV->eraseFromParent(); } -bool DevirtModule::run() { - Function *TypeTestFunc = - M.getFunction(Intrinsic::getName(Intrinsic::type_test)); - if (!TypeTestFunc || TypeTestFunc->use_empty()) - return false; - - Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); - if (!AssumeFunc || AssumeFunc->use_empty()) - return false; - +void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, + Function *AssumeFunc) { // Find all virtual calls via a virtual table pointer %p under an assumption // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p // points to a member of the type identifier %md. Group calls by (type ID, @@ -643,7 +663,7 @@ bool DevirtModule::run() { // Search for virtual calls based on %p and add them to DevirtCalls. SmallVector<DevirtCallSite, 1> DevirtCalls; SmallVector<CallInst *, 1> Assumes; - findDevirtualizableCalls(DevirtCalls, Assumes, CI); + findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI); // If we found any, add them to CallSlots. Only do this if we haven't seen // the vtable pointer before, as it may have been CSE'd with pointers from @@ -655,7 +675,7 @@ bool DevirtModule::run() { if (SeenPtrs.insert(Ptr).second) { for (DevirtCallSite Call : DevirtCalls) { CallSlots[{TypeId, Call.Offset}].push_back( - {CI->getArgOperand(0), Call.CS}); + {CI->getArgOperand(0), Call.CS, nullptr}); } } } @@ -668,6 +688,104 @@ bool DevirtModule::run() { if (CI->use_empty()) CI->eraseFromParent(); } +} + +void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { + Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test); + + for (auto I = TypeCheckedLoadFunc->use_begin(), + E = TypeCheckedLoadFunc->use_end(); + I != E;) { + auto CI = dyn_cast<CallInst>(I->getUser()); + ++I; + if (!CI) + continue; + + Value *Ptr = CI->getArgOperand(0); + Value *Offset = CI->getArgOperand(1); + Value *TypeIdValue = CI->getArgOperand(2); + Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata(); + + SmallVector<DevirtCallSite, 1> DevirtCalls; + SmallVector<Instruction *, 1> LoadedPtrs; + SmallVector<Instruction *, 1> Preds; + bool HasNonCallUses = false; + findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds, + HasNonCallUses, CI); + + // Start by generating "pessimistic" code that explicitly loads the function + // pointer from the vtable and performs the type check. If possible, we will + // eliminate the load and the type check later. + + // If possible, only generate the load at the point where it is used. + // This helps avoid unnecessary spills. + IRBuilder<> LoadB( + (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI); + Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); + Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy)); + Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr); + + for (Instruction *LoadedPtr : LoadedPtrs) { + LoadedPtr->replaceAllUsesWith(LoadedValue); + LoadedPtr->eraseFromParent(); + } + + // Likewise for the type test. + IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI); + CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue}); + + for (Instruction *Pred : Preds) { + Pred->replaceAllUsesWith(TypeTestCall); + Pred->eraseFromParent(); + } + + // We have already erased any extractvalue instructions that refer to the + // intrinsic call, but the intrinsic may have other non-extractvalue uses + // (although this is unlikely). In that case, explicitly build a pair and + // RAUW it. + if (!CI->use_empty()) { + Value *Pair = UndefValue::get(CI->getType()); + IRBuilder<> B(CI); + Pair = B.CreateInsertValue(Pair, LoadedValue, {0}); + Pair = B.CreateInsertValue(Pair, TypeTestCall, {1}); + CI->replaceAllUsesWith(Pair); + } + + // The number of unsafe uses is initially the number of uses. + auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall]; + NumUnsafeUses = DevirtCalls.size(); + + // If the function pointer has a non-call user, we cannot eliminate the type + // check, as one of those users may eventually call the pointer. Increment + // the unsafe use count to make sure it cannot reach zero. + if (HasNonCallUses) + ++NumUnsafeUses; + for (DevirtCallSite Call : DevirtCalls) { + CallSlots[{TypeId, Call.Offset}].push_back( + {Ptr, Call.CS, &NumUnsafeUses}); + } + + CI->eraseFromParent(); + } +} + +bool DevirtModule::run() { + Function *TypeTestFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_test)); + Function *TypeCheckedLoadFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); + Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); + + if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || + AssumeFunc->use_empty()) && + (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) + return false; + + if (TypeTestFunc && AssumeFunc) + scanTypeTestUsers(TypeTestFunc, AssumeFunc); + + if (TypeCheckedLoadFunc) + scanTypeCheckedLoadUsers(TypeCheckedLoadFunc); // Rebuild type metadata into a map for easy lookup. std::vector<VTableBits> Bits; @@ -693,6 +811,18 @@ bool DevirtModule::run() { DidVirtualConstProp |= tryVirtualConstProp(TargetsForSlot, S.second); } + // If we were able to eliminate all unsafe uses for a type checked load, + // eliminate the type test by replacing it with true. + if (TypeCheckedLoadFunc) { + auto True = ConstantInt::getTrue(M.getContext()); + for (auto &&U : NumUnsafeUsesForTypeTest) { + if (U.second == 0) { + U.first->replaceAllUsesWith(True); + U.first->eraseFromParent(); + } + } + } + // Rebuild each global we touched as part of virtual constant propagation to // include the before and after bytes. if (DidVirtualConstProp) |