Diffstat (limited to 'llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp')
-rw-r--r-- | llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp | 150
1 file changed, 75 insertions, 75 deletions
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 93ee07cf64e..c19c667a51e 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -8,7 +8,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This pass implements whole program optimization of virtual calls in cases
-// where we know (via bitset information) that the list of callee is fixed. This
+// where we know (via !type metadata) that the list of callees is fixed. This
 // includes the following:
 // - Single implementation devirtualization: if a virtual call has a single
 //   possible callee, replace all calls with a direct call to that callee.
@@ -31,7 +31,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/MapVector.h"
-#include "llvm/Analysis/BitSetUtils.h"
+#include "llvm/Analysis/TypeMetadataUtils.h"
 #include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -89,8 +89,8 @@ wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
   // at MinByte.
   std::vector<ArrayRef<uint8_t>> Used;
   for (const VirtualCallTarget &Target : Targets) {
-    ArrayRef<uint8_t> VTUsed = IsAfter ? Target.BS->Bits->After.BytesUsed
-                                       : Target.BS->Bits->Before.BytesUsed;
+    ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed
+                                       : Target.TM->Bits->Before.BytesUsed;
     uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
                               : MinByte - Target.minBeforeBytes();
@@ -163,17 +163,17 @@ void wholeprogramdevirt::setAfterReturnValues(
   }
 }
 
-VirtualCallTarget::VirtualCallTarget(Function *Fn, const BitSetInfo *BS)
-    : Fn(Fn), BS(BS),
+VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM)
+    : Fn(Fn), TM(TM),
       IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()) {}
 
 namespace {
 
-// A slot in a set of virtual tables. The BitSetID identifies the set of virtual
+// A slot in a set of virtual tables. The TypeID identifies the set of virtual
 // tables, and the ByteOffset is the offset in bytes from the address point to
 // the virtual function pointer.
 struct VTableSlot {
-  Metadata *BitSetID;
+  Metadata *TypeID;
   uint64_t ByteOffset;
 };
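Note: a source-level illustration of the "single implementation" case described in the header comment above. This is a hypothetical example, not part of the patch; it assumes whole-program visibility (e.g. LTO with hidden visibility), under which A::f is the only possible callee:

    // Hypothetical input program.
    struct A {
      virtual int f() { return 1; } // the only implementation in the program
    };

    int call(A *a) {
      // The pass proves A::f is the sole target and replaces this indirect
      // call through the vtable with a direct call to A::f.
      return a->f();
    }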
@@ -191,12 +191,12 @@ template <> struct DenseMapInfo<VTableSlot> {
             DenseMapInfo<uint64_t>::getTombstoneKey()};
   }
   static unsigned getHashValue(const VTableSlot &I) {
-    return DenseMapInfo<Metadata *>::getHashValue(I.BitSetID) ^
+    return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^
           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
   }
   static bool isEqual(const VTableSlot &LHS, const VTableSlot &RHS) {
-    return LHS.BitSetID == RHS.BitSetID && LHS.ByteOffset == RHS.ByteOffset;
+    return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
   }
 };
 
@@ -233,11 +233,13 @@ struct DevirtModule {
         Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
         Int32Ty(Type::getInt32Ty(M.getContext())) {}
 
-  void buildBitSets(std::vector<VTableBits> &Bits,
-                    DenseMap<Metadata *, std::set<BitSetInfo>> &BitSets);
-  bool tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
-                                 const std::set<BitSetInfo> &BitSetInfos,
-                                 uint64_t ByteOffset);
+  void buildTypeIdentifierMap(
+      std::vector<VTableBits> &Bits,
+      DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
+  bool
+  tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
+                            const std::set<TypeMemberInfo> &TypeMemberInfos,
+                            uint64_t ByteOffset);
   bool trySingleImplDevirt(ArrayRef<VirtualCallTarget> TargetsForSlot,
                            MutableArrayRef<VirtualCallSite> CallSites);
   bool tryEvaluateFunctionsWithArgs(
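Note: buildTypeIdentifierMap (declared above, implemented in the next hunk) consumes !type metadata attached to vtable globals: operand 0 is the byte offset of the address point and operand 1 is the type identifier, matching the getOperand(0)/getOperand(1) calls below. A minimal sketch of the producer side; the global, offset, and identifier here are illustrative assumptions:

    #include "llvm/IR/Constants.h"
    #include "llvm/IR/GlobalVariable.h"
    #include "llvm/IR/Metadata.h"
    #include "llvm/IR/Module.h"
    using namespace llvm;

    // Attaches the equivalent of !type !{i64 16, !"_ZTS1A"} to a vtable
    // global VT, marking VT as a member of type "_ZTS1A" with its address
    // point at byte offset 16.
    void attachTypeMetadata(Module &M, GlobalVariable *VT) {
      LLVMContext &Ctx = M.getContext();
      Metadata *Ops[] = {
          ConstantAsMetadata::get(
              ConstantInt::get(Type::getInt64Ty(Ctx), 16)), // address point
          MDString::get(Ctx, "_ZTS1A")};                    // type identifier
      VT->addMetadata(LLVMContext::MD_type, *MDNode::get(Ctx, Ops));
    }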
@@ -287,60 +289,55 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
   return PreservedAnalyses::none();
 }
 
-void DevirtModule::buildBitSets(
+void DevirtModule::buildTypeIdentifierMap(
     std::vector<VTableBits> &Bits,
-    DenseMap<Metadata *, std::set<BitSetInfo>> &BitSets) {
-  NamedMDNode *BitSetNM = M.getNamedMetadata("llvm.bitsets");
-  if (!BitSetNM)
-    return;
-
+    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
   DenseMap<GlobalVariable *, VTableBits *> GVToBits;
-  Bits.reserve(BitSetNM->getNumOperands());
-  for (auto Op : BitSetNM->operands()) {
-    auto OpConstMD = dyn_cast_or_null<ConstantAsMetadata>(Op->getOperand(1));
-    if (!OpConstMD)
+  Bits.reserve(M.getGlobalList().size());
+  SmallVector<MDNode *, 2> Types;
+  for (GlobalVariable &GV : M.globals()) {
+    Types.clear();
+    GV.getMetadata(LLVMContext::MD_type, Types);
+    if (Types.empty())
       continue;
-    auto BitSetID = Op->getOperand(0).get();
-
-    Constant *OpConst = OpConstMD->getValue();
-    if (auto GA = dyn_cast<GlobalAlias>(OpConst))
-      OpConst = GA->getAliasee();
-    auto OpGlobal = dyn_cast<GlobalVariable>(OpConst);
-    if (!OpGlobal)
-      continue;
-
-    uint64_t Offset =
-        cast<ConstantInt>(
-            cast<ConstantAsMetadata>(Op->getOperand(2))->getValue())
-            ->getZExtValue();
 
-    VTableBits *&BitsPtr = GVToBits[OpGlobal];
+    VTableBits *&BitsPtr = GVToBits[&GV];
     if (!BitsPtr) {
       Bits.emplace_back();
-      Bits.back().GV = OpGlobal;
-      Bits.back().ObjectSize = M.getDataLayout().getTypeAllocSize(
-          OpGlobal->getInitializer()->getType());
+      Bits.back().GV = &GV;
+      Bits.back().ObjectSize =
+          M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType());
       BitsPtr = &Bits.back();
     }
-    BitSets[BitSetID].insert({BitsPtr, Offset});
+
+    for (MDNode *Type : Types) {
+      auto TypeID = Type->getOperand(1).get();
+
+      uint64_t Offset =
+          cast<ConstantInt>(
+              cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
+              ->getZExtValue();
+
+      TypeIdMap[TypeID].insert({BitsPtr, Offset});
+    }
   }
 }
 
 bool DevirtModule::tryFindVirtualCallTargets(
     std::vector<VirtualCallTarget> &TargetsForSlot,
-    const std::set<BitSetInfo> &BitSetInfos, uint64_t ByteOffset) {
-  for (const BitSetInfo &BS : BitSetInfos) {
-    if (!BS.Bits->GV->isConstant())
+    const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) {
+  for (const TypeMemberInfo &TM : TypeMemberInfos) {
+    if (!TM.Bits->GV->isConstant())
       return false;
 
-    auto Init = dyn_cast<ConstantArray>(BS.Bits->GV->getInitializer());
+    auto Init = dyn_cast<ConstantArray>(TM.Bits->GV->getInitializer());
     if (!Init)
       return false;
     ArrayType *VTableTy = Init->getType();
 
     uint64_t ElemSize =
         M.getDataLayout().getTypeAllocSize(VTableTy->getElementType());
-    uint64_t GlobalSlotOffset = BS.Offset + ByteOffset;
+    uint64_t GlobalSlotOffset = TM.Offset + ByteOffset;
     if (GlobalSlotOffset % ElemSize != 0)
       return false;
 
@@ -357,7 +354,7 @@ bool DevirtModule::tryFindVirtualCallTargets(
     if (Fn->getName() == "__cxa_pure_virtual")
       continue;
 
-    TargetsForSlot.push_back({Fn, &BS});
+    TargetsForSlot.push_back({Fn, &TM});
   }
 
   // Give up if we couldn't find any targets.
@@ -430,24 +427,24 @@ bool DevirtModule::tryUniqueRetValOpt(
     MutableArrayRef<VirtualCallSite> CallSites) {
   // IsOne controls whether we look for a 0 or a 1.
   auto tryUniqueRetValOptFor = [&](bool IsOne) {
-    const BitSetInfo *UniqueBitSet = 0;
+    const TypeMemberInfo *UniqueMember = 0;
     for (const VirtualCallTarget &Target : TargetsForSlot) {
       if (Target.RetVal == (IsOne ? 1 : 0)) {
-        if (UniqueBitSet)
+        if (UniqueMember)
           return false;
-        UniqueBitSet = Target.BS;
+        UniqueMember = Target.TM;
       }
     }
 
-    // We should have found a unique bit set or bailed out by now. We already
+    // We should have found a unique member or bailed out by now. We already
     // checked for a uniform return value in tryUniformRetValOpt.
-    assert(UniqueBitSet);
+    assert(UniqueMember);
 
     // Replace each call with the comparison.
     for (auto &&Call : CallSites) {
       IRBuilder<> B(Call.CS.getInstruction());
-      Value *OneAddr = B.CreateBitCast(UniqueBitSet->Bits->GV, Int8PtrTy);
-      OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueBitSet->Offset);
+      Value *OneAddr = B.CreateBitCast(UniqueMember->Bits->GV, Int8PtrTy);
+      OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueMember->Offset);
       Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
                                 Call.VTable, OneAddr);
       Call.replaceAndErase(Cmp);
@@ -526,7 +523,8 @@ bool DevirtModule::tryVirtualConstProp(
     if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second))
       continue;
 
-    // Find an allocation offset in bits in all vtables in the bitset.
+    // Find an allocation offset in bits in all vtables associated with the
+    // type.
     uint64_t AllocBefore =
         findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
     uint64_t AllocAfter =
@@ -620,9 +618,9 @@ void DevirtModule::rebuildGlobal(VTableBits &B) {
 }
 
 bool DevirtModule::run() {
-  Function *BitSetTestFunc =
-      M.getFunction(Intrinsic::getName(Intrinsic::bitset_test));
-  if (!BitSetTestFunc || BitSetTestFunc->use_empty())
+  Function *TypeTestFunc =
+      M.getFunction(Intrinsic::getName(Intrinsic::type_test));
+  if (!TypeTestFunc || TypeTestFunc->use_empty())
     return false;
 
   Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
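Note: the next hunk updates the comment describing the guard pattern that run() scans for. A sketch of how a frontend might emit that pattern through the C++ API (hypothetical helper; actual frontend code generation differs):

    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/Intrinsics.h"
    #include "llvm/IR/Metadata.h"
    #include "llvm/IR/Module.h"
    using namespace llvm;

    // Before a virtual call through VTablePtr, emit:
    //   %x = call i1 @llvm.type.test(i8* %vtable, metadata %typeid)
    //   call void @llvm.assume(i1 %x)
    void emitTypeTestAssume(Module &M, IRBuilder<> &B, Value *VTablePtr,
                            Metadata *TypeId) {
      LLVMContext &Ctx = M.getContext();
      Value *P = B.CreateBitCast(VTablePtr, Type::getInt8PtrTy(Ctx));
      Function *TypeTest = Intrinsic::getDeclaration(&M, Intrinsic::type_test);
      CallInst *Test =
          B.CreateCall(TypeTest, {P, MetadataAsValue::get(Ctx, TypeId)});
      B.CreateAssumption(Test); // run() keys the call site on (TypeId, offset)
    }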
@@ -630,11 +628,12 @@
     return false;
 
   // Find all virtual calls via a virtual table pointer %p under an assumption
-  // of the form llvm.assume(llvm.bitset.test(%p, %md)). This indicates that %p
-  // points to a vtable in the bitset %md. Group calls by (bitset, offset) pair
-  // (effectively the identity of the virtual function) and store to CallSlots.
+  // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
+  // points to a member of the type identifier %md. Group calls by (type ID,
+  // offset) pair (effectively the identity of the virtual function) and store
+  // to CallSlots.
   DenseSet<Value *> SeenPtrs;
-  for (auto I = BitSetTestFunc->use_begin(), E = BitSetTestFunc->use_end();
+  for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
        I != E;) {
     auto CI = dyn_cast<CallInst>(I->getUser());
     ++I;
@@ -650,18 +649,18 @@
     // the vtable pointer before, as it may have been CSE'd with pointers from
     // other call sites, and we don't want to process call sites multiple times.
     if (!Assumes.empty()) {
-      Metadata *BitSet =
+      Metadata *TypeId =
           cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
       Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
       if (SeenPtrs.insert(Ptr).second) {
         for (DevirtCallSite Call : DevirtCalls) {
-          CallSlots[{BitSet, Call.Offset}].push_back(
+          CallSlots[{TypeId, Call.Offset}].push_back(
              {CI->getArgOperand(0), Call.CS});
        }
      }
    }

-    // We no longer need the assumes or the bitset test.
+    // We no longer need the assumes or the type test.
     for (auto Assume : Assumes)
       Assume->eraseFromParent();
     // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
     // may use the vtable argument later.
@@ -670,20 +669,21 @@
     if (CI->use_empty())
       CI->eraseFromParent();
   }
 
-  // Rebuild llvm.bitsets metadata into a map for easy lookup.
+  // Rebuild type metadata into a map for easy lookup.
   std::vector<VTableBits> Bits;
-  DenseMap<Metadata *, std::set<BitSetInfo>> BitSets;
-  buildBitSets(Bits, BitSets);
-  if (BitSets.empty())
+  DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
+  buildTypeIdentifierMap(Bits, TypeIdMap);
+  if (TypeIdMap.empty())
     return true;
 
-  // For each (bitset, offset) pair:
+  // For each (type, offset) pair:
   bool DidVirtualConstProp = false;
   for (auto &S : CallSlots) {
-    // Search each of the vtables in the bitset for the virtual function
-    // implementation at offset S.first.ByteOffset, and add to TargetsForSlot.
+    // Search each of the members of the type identifier for the virtual
+    // function implementation at offset S.first.ByteOffset, and add to
+    // TargetsForSlot.
     std::vector<VirtualCallTarget> TargetsForSlot;
-    if (!tryFindVirtualCallTargets(TargetsForSlot, BitSets[S.first.BitSetID],
+    if (!tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID],
                                    S.first.ByteOffset))
       continue;
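Note: the unique return value optimization (tryUniqueRetValOpt above) is easiest to see at the source level. A hypothetical example, not from the patch: exactly one member of the type hierarchy returns true, so the virtual call can be lowered to a vtable pointer comparison with no call at all:

    struct Base { virtual bool isB() { return false; } };
    struct B : Base { bool isB() override { return true; } };
    struct C : Base {};

    bool check(Base *o) {
      // B::isB is the unique implementation returning true, so this becomes
      // a comparison of o's vtable pointer against the address point of
      // B's vtable.
      return o->isB();
    }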