diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Analysis/TypeMetadataUtils.cpp | 44 | ||||
-rw-r--r-- | llvm/lib/IR/Metadata.cpp | 1 | ||||
-rw-r--r-- | llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp | 154 |
3 files changed, 182 insertions, 17 deletions
diff --git a/llvm/lib/Analysis/TypeMetadataUtils.cpp b/llvm/lib/Analysis/TypeMetadataUtils.cpp index 750cce33856..8d173d77fb6 100644 --- a/llvm/lib/Analysis/TypeMetadataUtils.cpp +++ b/llvm/lib/Analysis/TypeMetadataUtils.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/TypeMetadataUtils.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" @@ -21,15 +22,17 @@ using namespace llvm; // Search for virtual calls that call FPtr and add them to DevirtCalls. static void findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls, - Value *FPtr, uint64_t Offset) { + bool *HasNonCallUses, Value *FPtr, uint64_t Offset) { for (const Use &U : FPtr->uses()) { Value *User = U.getUser(); if (isa<BitCastInst>(User)) { - findCallsAtConstantOffset(DevirtCalls, User, Offset); + findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset); } else if (auto CI = dyn_cast<CallInst>(User)) { DevirtCalls.push_back({Offset, CI}); } else if (auto II = dyn_cast<InvokeInst>(User)) { DevirtCalls.push_back({Offset, II}); + } else if (HasNonCallUses) { + *HasNonCallUses = true; } } } @@ -44,7 +47,7 @@ findLoadCallsAtConstantOffset(Module *M, if (isa<BitCastInst>(User)) { findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset); } else if (isa<LoadInst>(User)) { - findCallsAtConstantOffset(DevirtCalls, User, Offset); + findCallsAtConstantOffset(DevirtCalls, nullptr, User, Offset); } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) { // Take into account the GEP offset. if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) { @@ -57,7 +60,7 @@ findLoadCallsAtConstantOffset(Module *M, } } -void llvm::findDevirtualizableCalls( +void llvm::findDevirtualizableCallsForTypeTest( SmallVectorImpl<DevirtCallSite> &DevirtCalls, SmallVectorImpl<CallInst *> &Assumes, CallInst *CI) { assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_test); @@ -80,3 +83,36 @@ void llvm::findDevirtualizableCalls( findLoadCallsAtConstantOffset(M, DevirtCalls, CI->getArgOperand(0)->stripPointerCasts(), 0); } + +void llvm::findDevirtualizableCallsForTypeCheckedLoad( + SmallVectorImpl<DevirtCallSite> &DevirtCalls, + SmallVectorImpl<Instruction *> &LoadedPtrs, + SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses, CallInst *CI) { + assert(CI->getCalledFunction()->getIntrinsicID() == + Intrinsic::type_checked_load); + + auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + if (!Offset) { + HasNonCallUses = true; + return; + } + + for (Use &U : CI->uses()) { + auto CIU = U.getUser(); + if (auto EVI = dyn_cast<ExtractValueInst>(CIU)) { + if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 0) { + LoadedPtrs.push_back(EVI); + continue; + } + if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 1) { + Preds.push_back(EVI); + continue; + } + } + HasNonCallUses = true; + } + + for (Value *LoadedPtr : LoadedPtrs) + findCallsAtConstantOffset(DevirtCalls, &HasNonCallUses, LoadedPtr, + Offset->getZExtValue()); +} diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp index 965f9737629..ed39fbafcb0 100644 --- a/llvm/lib/IR/Metadata.cpp +++ b/llvm/lib/IR/Metadata.cpp @@ -1369,7 +1369,6 @@ void GlobalObject::clearMetadata() { setHasMetadataHashEntry(false); } - void GlobalObject::setMetadata(unsigned KindID, MDNode *N) { eraseMetadata(KindID); if (N) diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index c19c667a51e..e6f4fa29224 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -210,6 +210,11 @@ struct VirtualCallSite { Value *VTable; CallSite CS; + // If non-null, this field points to the associated unsafe use count stored in + // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description + // of that field for details. + unsigned *NumUnsafeUses; + void replaceAndErase(Value *New) { CS->replaceAllUsesWith(New); if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) { @@ -217,6 +222,9 @@ struct VirtualCallSite { II->getUnwindDest()->removePredecessor(II->getParent()); } CS->eraseFromParent(); + // This use is no longer unsafe. + if (NumUnsafeUses) + --*NumUnsafeUses; } }; @@ -228,11 +236,24 @@ struct DevirtModule { MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots; + // This map keeps track of the number of "unsafe" uses of a loaded function + // pointer. The key is the associated llvm.type.test intrinsic call generated + // by this pass. An unsafe use is one that calls the loaded function pointer + // directly. Every time we eliminate an unsafe use (for example, by + // devirtualizing it or by applying virtual constant propagation), we + // decrement the value stored in this map. If a value reaches zero, we can + // eliminate the type check by RAUWing the associated llvm.type.test call with + // true. + std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest; + DevirtModule(Module &M) : M(M), Int8Ty(Type::getInt8Ty(M.getContext())), Int8PtrTy(Type::getInt8PtrTy(M.getContext())), Int32Ty(Type::getInt32Ty(M.getContext())) {} + void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc); + void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc); + void buildTypeIdentifierMap( std::vector<VTableBits> &Bits, DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap); @@ -375,6 +396,9 @@ bool DevirtModule::trySingleImplDevirt( for (auto &&VCallSite : CallSites) { VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( TheFn, VCallSite.CS.getCalledValue()->getType())); + // This use is no longer unsafe. + if (VCallSite.NumUnsafeUses) + --*VCallSite.NumUnsafeUses; } return true; } @@ -601,6 +625,10 @@ void DevirtModule::rebuildGlobal(VTableBits &B) { NewGV->setSection(B.GV->getSection()); NewGV->setComdat(B.GV->getComdat()); + // Copy the original vtable's metadata to the anonymous global, adjusting + // offsets as required. + NewGV->copyMetadata(B.GV, B.Before.Bytes.size()); + // Build an alias named after the original global, pointing at the second // element (the original initializer). auto Alias = GlobalAlias::create( @@ -617,16 +645,8 @@ void DevirtModule::rebuildGlobal(VTableBits &B) { B.GV->eraseFromParent(); } -bool DevirtModule::run() { - Function *TypeTestFunc = - M.getFunction(Intrinsic::getName(Intrinsic::type_test)); - if (!TypeTestFunc || TypeTestFunc->use_empty()) - return false; - - Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); - if (!AssumeFunc || AssumeFunc->use_empty()) - return false; - +void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, + Function *AssumeFunc) { // Find all virtual calls via a virtual table pointer %p under an assumption // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p // points to a member of the type identifier %md. Group calls by (type ID, @@ -643,7 +663,7 @@ bool DevirtModule::run() { // Search for virtual calls based on %p and add them to DevirtCalls. SmallVector<DevirtCallSite, 1> DevirtCalls; SmallVector<CallInst *, 1> Assumes; - findDevirtualizableCalls(DevirtCalls, Assumes, CI); + findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI); // If we found any, add them to CallSlots. Only do this if we haven't seen // the vtable pointer before, as it may have been CSE'd with pointers from @@ -655,7 +675,7 @@ bool DevirtModule::run() { if (SeenPtrs.insert(Ptr).second) { for (DevirtCallSite Call : DevirtCalls) { CallSlots[{TypeId, Call.Offset}].push_back( - {CI->getArgOperand(0), Call.CS}); + {CI->getArgOperand(0), Call.CS, nullptr}); } } } @@ -668,6 +688,104 @@ bool DevirtModule::run() { if (CI->use_empty()) CI->eraseFromParent(); } +} + +void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { + Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test); + + for (auto I = TypeCheckedLoadFunc->use_begin(), + E = TypeCheckedLoadFunc->use_end(); + I != E;) { + auto CI = dyn_cast<CallInst>(I->getUser()); + ++I; + if (!CI) + continue; + + Value *Ptr = CI->getArgOperand(0); + Value *Offset = CI->getArgOperand(1); + Value *TypeIdValue = CI->getArgOperand(2); + Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata(); + + SmallVector<DevirtCallSite, 1> DevirtCalls; + SmallVector<Instruction *, 1> LoadedPtrs; + SmallVector<Instruction *, 1> Preds; + bool HasNonCallUses = false; + findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds, + HasNonCallUses, CI); + + // Start by generating "pessimistic" code that explicitly loads the function + // pointer from the vtable and performs the type check. If possible, we will + // eliminate the load and the type check later. + + // If possible, only generate the load at the point where it is used. + // This helps avoid unnecessary spills. + IRBuilder<> LoadB( + (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI); + Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); + Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy)); + Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr); + + for (Instruction *LoadedPtr : LoadedPtrs) { + LoadedPtr->replaceAllUsesWith(LoadedValue); + LoadedPtr->eraseFromParent(); + } + + // Likewise for the type test. + IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI); + CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue}); + + for (Instruction *Pred : Preds) { + Pred->replaceAllUsesWith(TypeTestCall); + Pred->eraseFromParent(); + } + + // We have already erased any extractvalue instructions that refer to the + // intrinsic call, but the intrinsic may have other non-extractvalue uses + // (although this is unlikely). In that case, explicitly build a pair and + // RAUW it. + if (!CI->use_empty()) { + Value *Pair = UndefValue::get(CI->getType()); + IRBuilder<> B(CI); + Pair = B.CreateInsertValue(Pair, LoadedValue, {0}); + Pair = B.CreateInsertValue(Pair, TypeTestCall, {1}); + CI->replaceAllUsesWith(Pair); + } + + // The number of unsafe uses is initially the number of uses. + auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall]; + NumUnsafeUses = DevirtCalls.size(); + + // If the function pointer has a non-call user, we cannot eliminate the type + // check, as one of those users may eventually call the pointer. Increment + // the unsafe use count to make sure it cannot reach zero. + if (HasNonCallUses) + ++NumUnsafeUses; + for (DevirtCallSite Call : DevirtCalls) { + CallSlots[{TypeId, Call.Offset}].push_back( + {Ptr, Call.CS, &NumUnsafeUses}); + } + + CI->eraseFromParent(); + } +} + +bool DevirtModule::run() { + Function *TypeTestFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_test)); + Function *TypeCheckedLoadFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); + Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); + + if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || + AssumeFunc->use_empty()) && + (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) + return false; + + if (TypeTestFunc && AssumeFunc) + scanTypeTestUsers(TypeTestFunc, AssumeFunc); + + if (TypeCheckedLoadFunc) + scanTypeCheckedLoadUsers(TypeCheckedLoadFunc); // Rebuild type metadata into a map for easy lookup. std::vector<VTableBits> Bits; @@ -693,6 +811,18 @@ bool DevirtModule::run() { DidVirtualConstProp |= tryVirtualConstProp(TargetsForSlot, S.second); } + // If we were able to eliminate all unsafe uses for a type checked load, + // eliminate the type test by replacing it with true. + if (TypeCheckedLoadFunc) { + auto True = ConstantInt::getTrue(M.getContext()); + for (auto &&U : NumUnsafeUsesForTypeTest) { + if (U.second == 0) { + U.first->replaceAllUsesWith(True); + U.first->eraseFromParent(); + } + } + } + // Rebuild each global we touched as part of virtual constant propagation to // include the before and after bytes. if (DidVirtualConstProp) |