diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Analysis/CMakeLists.txt | 2 | ||||
-rw-r--r-- | llvm/lib/Analysis/TypeMetadataUtils.cpp (renamed from llvm/lib/Analysis/BitSetUtils.cpp) | 12 | ||||
-rw-r--r-- | llvm/lib/IR/LLVMContext.cpp | 4 | ||||
-rw-r--r-- | llvm/lib/IR/Metadata.cpp | 25 | ||||
-rw-r--r-- | llvm/lib/Linker/IRMover.cpp | 4 | ||||
-rw-r--r-- | llvm/lib/Transforms/IPO/CMakeLists.txt | 2 | ||||
-rw-r--r-- | llvm/lib/Transforms/IPO/CrossDSOCFI.cpp | 51 | ||||
-rw-r--r-- | llvm/lib/Transforms/IPO/IPO.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/Transforms/IPO/LowerTypeTests.cpp (renamed from llvm/lib/Transforms/IPO/LowerBitSets.cpp) | 374 | ||||
-rw-r--r-- | llvm/lib/Transforms/IPO/PassManagerBuilder.cpp | 8 | ||||
-rw-r--r-- | llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp | 150 |
11 files changed, 306 insertions, 328 deletions
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt index a4233df4fb4..56ab9a0b384 100644 --- a/llvm/lib/Analysis/CMakeLists.txt +++ b/llvm/lib/Analysis/CMakeLists.txt @@ -5,7 +5,6 @@ add_llvm_library(LLVMAnalysis Analysis.cpp AssumptionCache.cpp BasicAliasAnalysis.cpp - BitSetUtils.cpp BlockFrequencyInfo.cpp BlockFrequencyInfoImpl.cpp BranchProbabilityInfo.cpp @@ -71,6 +70,7 @@ add_llvm_library(LLVMAnalysis TargetTransformInfo.cpp Trace.cpp TypeBasedAliasAnalysis.cpp + TypeMetadataUtils.cpp ScopedNoAliasAA.cpp ValueTracking.cpp VectorUtils.cpp diff --git a/llvm/lib/Analysis/BitSetUtils.cpp b/llvm/lib/Analysis/TypeMetadataUtils.cpp index d83dca60cce..750cce33856 100644 --- a/llvm/lib/Analysis/BitSetUtils.cpp +++ b/llvm/lib/Analysis/TypeMetadataUtils.cpp @@ -1,4 +1,4 @@ -//===- BitSetUtils.cpp - Utilities related to pointer bitsets -------------===// +//===- TypeMetadataUtils.cpp - Utilities related to type metadata ---------===// // // The LLVM Compiler Infrastructure // @@ -7,12 +7,12 @@ // //===----------------------------------------------------------------------===// // -// This file contains functions that make it easier to manipulate bitsets for -// devirtualization. +// This file contains functions that make it easier to manipulate type metadata +// for devirtualization. // //===----------------------------------------------------------------------===// -#include "llvm/Analysis/BitSetUtils.h" +#include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" @@ -60,11 +60,11 @@ findLoadCallsAtConstantOffset(Module *M, void llvm::findDevirtualizableCalls( SmallVectorImpl<DevirtCallSite> &DevirtCalls, SmallVectorImpl<CallInst *> &Assumes, CallInst *CI) { - assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::bitset_test); + assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_test); Module *M = CI->getParent()->getParent()->getParent(); - // Find llvm.assume intrinsics for this llvm.bitset.test call. + // Find llvm.assume intrinsics for this llvm.type.test call. for (const Use &CIU : CI->uses()) { auto AssumeCI = dyn_cast<CallInst>(CIU.getUser()); if (AssumeCI) { diff --git a/llvm/lib/IR/LLVMContext.cpp b/llvm/lib/IR/LLVMContext.cpp index 9f8680cf6c0..c7ba400ced1 100644 --- a/llvm/lib/IR/LLVMContext.cpp +++ b/llvm/lib/IR/LLVMContext.cpp @@ -134,6 +134,10 @@ LLVMContext::LLVMContext() : pImpl(new LLVMContextImpl(*this)) { assert(LoopID == MD_loop && "llvm.loop kind id drifted"); (void)LoopID; + unsigned TypeID = getMDKindID("type"); + assert(TypeID == MD_type && "type kind id drifted"); + (void)TypeID; + auto *DeoptEntry = pImpl->getOrInsertBundleTag("deopt"); assert(DeoptEntry->second == LLVMContext::OB_deopt && "deopt operand bundle id drifted!"); diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp index 3b5f87a2b40..965f9737629 100644 --- a/llvm/lib/IR/Metadata.cpp +++ b/llvm/lib/IR/Metadata.cpp @@ -1393,11 +1393,32 @@ MDNode *GlobalObject::getMetadata(StringRef Kind) const { return getMetadata(getContext().getMDKindID(Kind)); } -void GlobalObject::copyMetadata(const GlobalObject *Other) { +void GlobalObject::copyMetadata(const GlobalObject *Other, unsigned Offset) { SmallVector<std::pair<unsigned, MDNode *>, 8> MDs; Other->getAllMetadata(MDs); - for (auto &MD : MDs) + for (auto &MD : MDs) { + // We need to adjust the type metadata offset. + if (Offset != 0 && MD.first == LLVMContext::MD_type) { + auto *OffsetConst = cast<ConstantInt>( + cast<ConstantAsMetadata>(MD.second->getOperand(0))->getValue()); + Metadata *TypeId = MD.second->getOperand(1); + auto *NewOffsetMD = ConstantAsMetadata::get(ConstantInt::get( + OffsetConst->getType(), OffsetConst->getValue() + Offset)); + addMetadata(LLVMContext::MD_type, + *MDNode::get(getContext(), {NewOffsetMD, TypeId})); + continue; + } addMetadata(MD.first, *MD.second); + } +} + +void GlobalObject::addTypeMetadata(unsigned Offset, Metadata *TypeID) { + addMetadata( + LLVMContext::MD_type, + *MDTuple::get(getContext(), + {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get( + Type::getInt64Ty(getContext()), Offset)), + TypeID})); } void Function::setSubprogram(DISubprogram *SP) { diff --git a/llvm/lib/Linker/IRMover.cpp b/llvm/lib/Linker/IRMover.cpp index 2f3813f6184..4935868c00f 100644 --- a/llvm/lib/Linker/IRMover.cpp +++ b/llvm/lib/Linker/IRMover.cpp @@ -641,7 +641,7 @@ GlobalValue *IRLinker::copyGlobalValueProto(const GlobalValue *SGV, if (auto *NewGO = dyn_cast<GlobalObject>(NewGV)) { // Metadata for global variables and function declarations is copied eagerly. if (isa<GlobalVariable>(SGV) || SGV->isDeclaration()) - NewGO->copyMetadata(cast<GlobalObject>(SGV)); + NewGO->copyMetadata(cast<GlobalObject>(SGV), 0); } // Remove these copied constants in case this stays a declaration, since @@ -967,7 +967,7 @@ Error IRLinker::linkFunctionBody(Function &Dst, Function &Src) { Dst.setPersonalityFn(Src.getPersonalityFn()); // Copy over the metadata attachments without remapping. - Dst.copyMetadata(&Src); + Dst.copyMetadata(&Src, 0); // Steal arguments and splice the body of Src into Dst. Dst.stealArgumentListFrom(Src); diff --git a/llvm/lib/Transforms/IPO/CMakeLists.txt b/llvm/lib/Transforms/IPO/CMakeLists.txt index 367b395d7e3..d6782c738cb 100644 --- a/llvm/lib/Transforms/IPO/CMakeLists.txt +++ b/llvm/lib/Transforms/IPO/CMakeLists.txt @@ -19,7 +19,7 @@ add_llvm_library(LLVMipo Inliner.cpp Internalize.cpp LoopExtractor.cpp - LowerBitSets.cpp + LowerTypeTests.cpp MergeFunctions.cpp PartialInlining.cpp PassManagerBuilder.cpp diff --git a/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp b/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp index 4c62cbe2da5..99d58333954 100644 --- a/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp +++ b/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp @@ -36,7 +36,7 @@ using namespace llvm; #define DEBUG_TYPE "cross-dso-cfi" -STATISTIC(TypeIds, "Number of unique type identifiers"); +STATISTIC(NumTypeIds, "Number of unique type identifiers"); namespace { @@ -49,7 +49,7 @@ struct CrossDSOCFI : public ModulePass { Module *M; MDNode *VeryLikelyWeights; - ConstantInt *extractBitSetTypeId(MDNode *MD); + ConstantInt *extractNumericTypeId(MDNode *MD); void buildCFICheck(); bool doInitialization(Module &M) override; @@ -73,10 +73,10 @@ bool CrossDSOCFI::doInitialization(Module &Mod) { return false; } -/// extractBitSetTypeId - Extracts TypeId from a hash-based bitset MDNode. -ConstantInt *CrossDSOCFI::extractBitSetTypeId(MDNode *MD) { +/// Extracts a numeric type identifier from an MDNode containing type metadata. +ConstantInt *CrossDSOCFI::extractNumericTypeId(MDNode *MD) { // This check excludes vtables for classes inside anonymous namespaces. - auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(0)); + auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(1)); if (!TM) return nullptr; auto C = dyn_cast_or_null<ConstantInt>(TM->getValue()); @@ -84,29 +84,27 @@ ConstantInt *CrossDSOCFI::extractBitSetTypeId(MDNode *MD) { // We are looking for i64 constants. if (C->getBitWidth() != 64) return nullptr; - // Sanity check. - auto FM = dyn_cast_or_null<ValueAsMetadata>(MD->getOperand(1)); - // Can be null if a function was removed by an optimization. - if (FM) { - auto F = dyn_cast<Function>(FM->getValue()); - // But can never be a function declaration. - assert(!F || !F->isDeclaration()); - (void)F; // Suppress unused variable warning in the no-asserts build. - } return C; } /// buildCFICheck - emits __cfi_check for the current module. void CrossDSOCFI::buildCFICheck() { // FIXME: verify that __cfi_check ends up near the end of the code section, - // but before the jump slots created in LowerBitSets. - llvm::DenseSet<uint64_t> BitSetIds; - NamedMDNode *BitSetNM = M->getNamedMetadata("llvm.bitsets"); - - if (BitSetNM) - for (unsigned I = 0, E = BitSetNM->getNumOperands(); I != E; ++I) - if (ConstantInt *TypeId = extractBitSetTypeId(BitSetNM->getOperand(I))) - BitSetIds.insert(TypeId->getZExtValue()); + // but before the jump slots created in LowerTypeTests. + llvm::DenseSet<uint64_t> TypeIds; + SmallVector<MDNode *, 2> Types; + for (GlobalObject &GO : M->global_objects()) { + Types.clear(); + GO.getMetadata(LLVMContext::MD_type, Types); + for (MDNode *Type : Types) { + // Sanity check. GO must not be a function declaration. + auto F = dyn_cast<Function>(&GO); + assert(!F || !F->isDeclaration()); + + if (ConstantInt *TypeId = extractNumericTypeId(Type)) + TypeIds.insert(TypeId->getZExtValue()); + } + } LLVMContext &Ctx = M->getContext(); Constant *C = M->getOrInsertFunction( @@ -138,13 +136,12 @@ void CrossDSOCFI::buildCFICheck() { IRBExit.CreateRetVoid(); IRBuilder<> IRB(BB); - SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, BitSetIds.size()); - for (uint64_t TypeId : BitSetIds) { + SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, TypeIds.size()); + for (uint64_t TypeId : TypeIds) { ConstantInt *CaseTypeId = ConstantInt::get(Type::getInt64Ty(Ctx), TypeId); BasicBlock *TestBB = BasicBlock::Create(Ctx, "test", F); IRBuilder<> IRBTest(TestBB); - Function *BitsetTestFn = - Intrinsic::getDeclaration(M, Intrinsic::bitset_test); + Function *BitsetTestFn = Intrinsic::getDeclaration(M, Intrinsic::type_test); Value *Test = IRBTest.CreateCall( BitsetTestFn, {&Addr, MetadataAsValue::get( @@ -153,7 +150,7 @@ void CrossDSOCFI::buildCFICheck() { BI->setMetadata(LLVMContext::MD_prof, VeryLikelyWeights); SI->addCase(CaseTypeId, TestBB); - ++TypeIds; + ++NumTypeIds; } } diff --git a/llvm/lib/Transforms/IPO/IPO.cpp b/llvm/lib/Transforms/IPO/IPO.cpp index c80f08787bf..35520e4a27e 100644 --- a/llvm/lib/Transforms/IPO/IPO.cpp +++ b/llvm/lib/Transforms/IPO/IPO.cpp @@ -39,7 +39,7 @@ void llvm::initializeIPO(PassRegistry &Registry) { initializeLoopExtractorPass(Registry); initializeBlockExtractorPassPass(Registry); initializeSingleLoopExtractorPass(Registry); - initializeLowerBitSetsPass(Registry); + initializeLowerTypeTestsPass(Registry); initializeMergeFunctionsPass(Registry); initializePartialInlinerPass(Registry); initializePostOrderFunctionAttrsLegacyPassPass(Registry); diff --git a/llvm/lib/Transforms/IPO/LowerBitSets.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp index 78981fdedad..06e66468c83 100644 --- a/llvm/lib/Transforms/IPO/LowerBitSets.cpp +++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp @@ -1,4 +1,4 @@ -//===-- LowerBitSets.cpp - Bitset lowering pass ---------------------------===// +//===-- LowerTypeTests.cpp - type metadata lowering pass ------------------===// // // The LLVM Compiler Infrastructure // @@ -7,12 +7,12 @@ // //===----------------------------------------------------------------------===// // -// This pass lowers bitset metadata and calls to the llvm.bitset.test intrinsic. -// See http://llvm.org/docs/LangRef.html#bitsets for more information. +// This pass lowers type metadata and calls to the llvm.type.test intrinsic. +// See http://llvm.org/docs/TypeMetadata.html for more information. // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO/LowerBitSets.h" +#include "llvm/Transforms/IPO/LowerTypeTests.h" #include "llvm/Transforms/IPO.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/Statistic.h" @@ -33,18 +33,18 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; -using namespace lowerbitsets; +using namespace lowertypetests; -#define DEBUG_TYPE "lowerbitsets" +#define DEBUG_TYPE "lowertypetests" STATISTIC(ByteArraySizeBits, "Byte array size in bits"); STATISTIC(ByteArraySizeBytes, "Byte array size in bytes"); STATISTIC(NumByteArraysCreated, "Number of byte arrays created"); -STATISTIC(NumBitSetCallsLowered, "Number of bitset calls lowered"); -STATISTIC(NumBitSetDisjointSets, "Number of disjoint sets of bitsets"); +STATISTIC(NumTypeTestCallsLowered, "Number of type test calls lowered"); +STATISTIC(NumTypeIdDisjointSets, "Number of disjoint sets of type identifiers"); static cl::opt<bool> AvoidReuse( - "lowerbitsets-avoid-reuse", + "lowertypetests-avoid-reuse", cl::desc("Try to avoid reuse of byte array addresses using aliases"), cl::Hidden, cl::init(true)); @@ -204,10 +204,10 @@ struct ByteArrayInfo { Constant *Mask; }; -struct LowerBitSets : public ModulePass { +struct LowerTypeTests : public ModulePass { static char ID; - LowerBitSets() : ModulePass(ID) { - initializeLowerBitSetsPass(*PassRegistry::getPassRegistry()); + LowerTypeTests() : ModulePass(ID) { + initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry()); } Module *M; @@ -222,41 +222,37 @@ struct LowerBitSets : public ModulePass { IntegerType *Int64Ty; IntegerType *IntPtrTy; - // The llvm.bitsets named metadata. - NamedMDNode *BitSetNM; - - // Mapping from bitset identifiers to the call sites that test them. - DenseMap<Metadata *, std::vector<CallInst *>> BitSetTestCallSites; + // Mapping from type identifiers to the call sites that test them. + DenseMap<Metadata *, std::vector<CallInst *>> TypeTestCallSites; std::vector<ByteArrayInfo> ByteArrayInfos; BitSetInfo - buildBitSet(Metadata *BitSet, + buildBitSet(Metadata *TypeId, const DenseMap<GlobalObject *, uint64_t> &GlobalLayout); ByteArrayInfo *createByteArray(BitSetInfo &BSI); void allocateByteArrays(); Value *createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI, ByteArrayInfo *&BAI, Value *BitOffset); - void lowerBitSetCalls(ArrayRef<Metadata *> BitSets, - Constant *CombinedGlobalAddr, - const DenseMap<GlobalObject *, uint64_t> &GlobalLayout); + void + lowerTypeTestCalls(ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr, + const DenseMap<GlobalObject *, uint64_t> &GlobalLayout); Value * lowerBitSetCall(CallInst *CI, BitSetInfo &BSI, ByteArrayInfo *&BAI, Constant *CombinedGlobal, const DenseMap<GlobalObject *, uint64_t> &GlobalLayout); - void buildBitSetsFromGlobalVariables(ArrayRef<Metadata *> BitSets, + void buildBitSetsFromGlobalVariables(ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalVariable *> Globals); unsigned getJumpTableEntrySize(); Type *getJumpTableEntryType(); Constant *createJumpTableEntry(GlobalObject *Src, Function *Dest, unsigned Distance); - void verifyBitSetMDNode(MDNode *Op); - void buildBitSetsFromFunctions(ArrayRef<Metadata *> BitSets, + void verifyTypeMDNode(GlobalObject *GO, MDNode *Type); + void buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds, ArrayRef<Function *> Functions); - void buildBitSetsFromDisjointSet(ArrayRef<Metadata *> BitSets, + void buildBitSetsFromDisjointSet(ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalObject *> Globals); - bool buildBitSets(); - bool eraseBitSetMetadata(); + bool lower(); bool doInitialization(Module &M) override; bool runOnModule(Module &M) override; @@ -264,15 +260,13 @@ struct LowerBitSets : public ModulePass { } // anonymous namespace -INITIALIZE_PASS_BEGIN(LowerBitSets, "lowerbitsets", - "Lower bitset metadata", false, false) -INITIALIZE_PASS_END(LowerBitSets, "lowerbitsets", - "Lower bitset metadata", false, false) -char LowerBitSets::ID = 0; +INITIALIZE_PASS(LowerTypeTests, "lowertypetests", "Lower type metadata", false, + false) +char LowerTypeTests::ID = 0; -ModulePass *llvm::createLowerBitSetsPass() { return new LowerBitSets; } +ModulePass *llvm::createLowerTypeTestsPass() { return new LowerTypeTests; } -bool LowerBitSets::doInitialization(Module &Mod) { +bool LowerTypeTests::doInitialization(Module &Mod) { M = &Mod; const DataLayout &DL = Mod.getDataLayout(); @@ -288,39 +282,31 @@ bool LowerBitSets::doInitialization(Module &Mod) { Int64Ty = Type::getInt64Ty(M->getContext()); IntPtrTy = DL.getIntPtrType(M->getContext(), 0); - BitSetNM = M->getNamedMetadata("llvm.bitsets"); - - BitSetTestCallSites.clear(); + TypeTestCallSites.clear(); return false; } -/// Build a bit set for BitSet using the object layouts in +/// Build a bit set for TypeId using the object layouts in /// GlobalLayout. -BitSetInfo LowerBitSets::buildBitSet( - Metadata *BitSet, +BitSetInfo LowerTypeTests::buildBitSet( + Metadata *TypeId, const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) { BitSetBuilder BSB; - // Compute the byte offset of each element of this bitset. - if (BitSetNM) { - for (MDNode *Op : BitSetNM->operands()) { - if (Op->getOperand(0) != BitSet || !Op->getOperand(1)) - continue; - Constant *OpConst = - cast<ConstantAsMetadata>(Op->getOperand(1))->getValue(); - if (auto GA = dyn_cast<GlobalAlias>(OpConst)) - OpConst = GA->getAliasee(); - auto OpGlobal = dyn_cast<GlobalObject>(OpConst); - if (!OpGlobal) + // Compute the byte offset of each address associated with this type + // identifier. + SmallVector<MDNode *, 2> Types; + for (auto &GlobalAndOffset : GlobalLayout) { + Types.clear(); + GlobalAndOffset.first->getMetadata(LLVMContext::MD_type, Types); + for (MDNode *Type : Types) { + if (Type->getOperand(1) != TypeId) continue; uint64_t Offset = - cast<ConstantInt>(cast<ConstantAsMetadata>(Op->getOperand(2)) + cast<ConstantInt>(cast<ConstantAsMetadata>(Type->getOperand(0)) ->getValue())->getZExtValue(); - - Offset += GlobalLayout.find(OpGlobal)->second; - - BSB.addOffset(Offset); + BSB.addOffset(GlobalAndOffset.second + Offset); } } @@ -342,7 +328,7 @@ static Value *createMaskedBitTest(IRBuilder<> &B, Value *Bits, return B.CreateICmpNE(MaskedBits, ConstantInt::get(BitsType, 0)); } -ByteArrayInfo *LowerBitSets::createByteArray(BitSetInfo &BSI) { +ByteArrayInfo *LowerTypeTests::createByteArray(BitSetInfo &BSI) { // Create globals to stand in for byte arrays and masks. These never actually // get initialized, we RAUW and erase them later in allocateByteArrays() once // we know the offset and mask to use. @@ -361,7 +347,7 @@ ByteArrayInfo *LowerBitSets::createByteArray(BitSetInfo &BSI) { return BAI; } -void LowerBitSets::allocateByteArrays() { +void LowerTypeTests::allocateByteArrays() { std::stable_sort(ByteArrayInfos.begin(), ByteArrayInfos.end(), [](const ByteArrayInfo &BAI1, const ByteArrayInfo &BAI2) { return BAI1.BitSize > BAI2.BitSize; @@ -414,8 +400,8 @@ void LowerBitSets::allocateByteArrays() { /// Build a test that bit BitOffset is set in BSI, where /// BitSetGlobal is a global containing the bits in BSI. -Value *LowerBitSets::createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI, - ByteArrayInfo *&BAI, Value *BitOffset) { +Value *LowerTypeTests::createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI, + ByteArrayInfo *&BAI, Value *BitOffset) { if (BSI.BitSize <= 64) { // If the bit set is sufficiently small, we can avoid a load by bit testing // a constant. @@ -455,9 +441,9 @@ Value *LowerBitSets::createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI, } } -/// Lower a llvm.bitset.test call to its implementation. Returns the value to +/// Lower a llvm.type.test call to its implementation. Returns the value to /// replace the call with. -Value *LowerBitSets::lowerBitSetCall( +Value *LowerTypeTests::lowerBitSetCall( CallInst *CI, BitSetInfo &BSI, ByteArrayInfo *&BAI, Constant *CombinedGlobalIntAddr, const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) { @@ -525,10 +511,10 @@ Value *LowerBitSets::lowerBitSetCall( return P; } -/// Given a disjoint set of bitsets and globals, layout the globals, build the -/// bit sets and lower the llvm.bitset.test calls. -void LowerBitSets::buildBitSetsFromGlobalVariables( - ArrayRef<Metadata *> BitSets, ArrayRef<GlobalVariable *> Globals) { +/// Given a disjoint set of type identifiers and globals, lay out the globals, +/// build the bit sets and lower the llvm.type.test calls. +void LowerTypeTests::buildBitSetsFromGlobalVariables( + ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalVariable *> Globals) { // Build a new global with the combined contents of the referenced globals. // This global is a struct whose even-indexed elements contain the original // contents of the referenced globals and whose odd-indexed elements contain @@ -566,7 +552,7 @@ void LowerBitSets::buildBitSetsFromGlobalVariables( // Multiply by 2 to account for padding elements. GlobalLayout[Globals[I]] = CombinedGlobalLayout->getElementOffset(I * 2); - lowerBitSetCalls(BitSets, CombinedGlobal, GlobalLayout); + lowerTypeTestCalls(TypeIds, CombinedGlobal, GlobalLayout); // Build aliases pointing to offsets into the combined global for each // global from which we built the combined global, and replace references @@ -592,19 +578,19 @@ void LowerBitSets::buildBitSetsFromGlobalVariables( } } -void LowerBitSets::lowerBitSetCalls( - ArrayRef<Metadata *> BitSets, Constant *CombinedGlobalAddr, +void LowerTypeTests::lowerTypeTestCalls( + ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr, const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) { Constant *CombinedGlobalIntAddr = ConstantExpr::getPtrToInt(CombinedGlobalAddr, IntPtrTy); - // For each bitset in this disjoint set... - for (Metadata *BS : BitSets) { + // For each type identifier in this disjoint set... + for (Metadata *TypeId : TypeIds) { // Build the bitset. - BitSetInfo BSI = buildBitSet(BS, GlobalLayout); + BitSetInfo BSI = buildBitSet(TypeId, GlobalLayout); DEBUG({ - if (auto BSS = dyn_cast<MDString>(BS)) - dbgs() << BSS->getString() << ": "; + if (auto MDS = dyn_cast<MDString>(TypeId)) + dbgs() << MDS->getString() << ": "; else dbgs() << "<unnamed>: "; BSI.print(dbgs()); @@ -612,9 +598,9 @@ void LowerBitSets::lowerBitSetCalls( ByteArrayInfo *BAI = nullptr; - // Lower each call to llvm.bitset.test for this bitset. - for (CallInst *CI : BitSetTestCallSites[BS]) { - ++NumBitSetCallsLowered; + // Lower each call to llvm.type.test for this type identifier. + for (CallInst *CI : TypeTestCallSites[TypeId]) { + ++NumTypeTestCallsLowered; Value *Lowered = lowerBitSetCall(CI, BSI, BAI, CombinedGlobalIntAddr, GlobalLayout); CI->replaceAllUsesWith(Lowered); @@ -623,40 +609,32 @@ void LowerBitSets::lowerBitSetCalls( } } -void LowerBitSets::verifyBitSetMDNode(MDNode *Op) { - if (Op->getNumOperands() != 3) +void LowerTypeTests::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) { + if (Type->getNumOperands() != 2) report_fatal_error( - "All operands of llvm.bitsets metadata must have 3 elements"); - if (!Op->getOperand(1)) - return; + "All operands of type metadata must have 2 elements"); - auto OpConstMD = dyn_cast<ConstantAsMetadata>(Op->getOperand(1)); - if (!OpConstMD) - report_fatal_error("Bit set element must be a constant"); - auto OpGlobal = dyn_cast<GlobalObject>(OpConstMD->getValue()); - if (!OpGlobal) - return; - - if (OpGlobal->isThreadLocal()) + if (GO->isThreadLocal()) report_fatal_error("Bit set element may not be thread-local"); - if (isa<GlobalVariable>(OpGlobal) && OpGlobal->hasSection()) + if (isa<GlobalVariable>(GO) && GO->hasSection()) report_fatal_error( - "Bit set global var element may not have an explicit section"); + "A member of a type identifier may not have an explicit section"); - if (isa<GlobalVariable>(OpGlobal) && OpGlobal->isDeclarationForLinker()) - report_fatal_error("Bit set global var element must be a definition"); + if (isa<GlobalVariable>(GO) && GO->isDeclarationForLinker()) + report_fatal_error( + "A global var member of a type identifier must be a definition"); - auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Op->getOperand(2)); + auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Type->getOperand(0)); if (!OffsetConstMD) - report_fatal_error("Bit set element offset must be a constant"); + report_fatal_error("Type offset must be a constant"); auto OffsetInt = dyn_cast<ConstantInt>(OffsetConstMD->getValue()); if (!OffsetInt) - report_fatal_error("Bit set element offset must be an integer constant"); + report_fatal_error("Type offset must be an integer constant"); } static const unsigned kX86JumpTableEntrySize = 8; -unsigned LowerBitSets::getJumpTableEntrySize() { +unsigned LowerTypeTests::getJumpTableEntrySize() { if (Arch != Triple::x86 && Arch != Triple::x86_64) report_fatal_error("Unsupported architecture for jump tables"); @@ -667,8 +645,9 @@ unsigned LowerBitSets::getJumpTableEntrySize() { // consists of an instruction sequence containing a relative branch to Dest. The // constant will be laid out at address Src+(Len*Distance) where Len is the // target-specific jump table entry size. -Constant *LowerBitSets::createJumpTableEntry(GlobalObject *Src, Function *Dest, - unsigned Distance) { +Constant *LowerTypeTests::createJumpTableEntry(GlobalObject *Src, + Function *Dest, + unsigned Distance) { if (Arch != Triple::x86 && Arch != Triple::x86_64) report_fatal_error("Unsupported architecture for jump tables"); @@ -695,7 +674,7 @@ Constant *LowerBitSets::createJumpTableEntry(GlobalObject *Src, Function *Dest, return ConstantStruct::getAnon(Fields, /*Packed=*/true); } -Type *LowerBitSets::getJumpTableEntryType() { +Type *LowerTypeTests::getJumpTableEntryType() { if (Arch != Triple::x86 && Arch != Triple::x86_64) report_fatal_error("Unsupported architecture for jump tables"); @@ -704,10 +683,10 @@ Type *LowerBitSets::getJumpTableEntryType() { /*Packed=*/true); } -/// Given a disjoint set of bitsets and functions, build a jump table for the -/// functions, build the bit sets and lower the llvm.bitset.test calls. -void LowerBitSets::buildBitSetsFromFunctions(ArrayRef<Metadata *> BitSets, - ArrayRef<Function *> Functions) { +/// Given a disjoint set of type identifiers and functions, build a jump table +/// for the functions, build the bit sets and lower the llvm.type.test calls. +void LowerTypeTests::buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds, + ArrayRef<Function *> Functions) { // Unlike the global bitset builder, the function bitset builder cannot // re-arrange functions in a particular order and base its calculations on the // layout of the functions' entry points, as we have no idea how large a @@ -721,8 +700,7 @@ void LowerBitSets::buildBitSetsFromFunctions(ArrayRef<Metadata *> BitSets, // verification done inside the module. // // In more concrete terms, suppose we have three functions f, g, h which are - // members of a single bitset, and a function foo that returns their - // addresses: + // of the same type, and a function foo that returns their addresses: // // f: // mov 0, %eax @@ -805,7 +783,7 @@ void LowerBitSets::buildBitSetsFromFunctions(ArrayRef<Metadata *> BitSets, JumpTable->setSection(ObjectFormat == Triple::MachO ? "__TEXT,__text,regular,pure_instructions" : ".text"); - lowerBitSetCalls(BitSets, JumpTable, GlobalLayout); + lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout); // Build aliases pointing to offsets into the jump table, and replace // references to the original functions with references to the aliases. @@ -840,39 +818,32 @@ void LowerBitSets::buildBitSetsFromFunctions(ArrayRef<Metadata *> BitSets, ConstantArray::get(JumpTableType, JumpTableEntries)); } -void LowerBitSets::buildBitSetsFromDisjointSet( - ArrayRef<Metadata *> BitSets, ArrayRef<GlobalObject *> Globals) { - llvm::DenseMap<Metadata *, uint64_t> BitSetIndices; - llvm::DenseMap<GlobalObject *, uint64_t> GlobalIndices; - for (unsigned I = 0; I != BitSets.size(); ++I) - BitSetIndices[BitSets[I]] = I; - for (unsigned I = 0; I != Globals.size(); ++I) - GlobalIndices[Globals[I]] = I; - - // For each bitset, build a set of indices that refer to globals referenced by - // the bitset. - std::vector<std::set<uint64_t>> BitSetMembers(BitSets.size()); - if (BitSetNM) { - for (MDNode *Op : BitSetNM->operands()) { - // Op = { bitset name, global, offset } - if (!Op->getOperand(1)) - continue; - auto I = BitSetIndices.find(Op->getOperand(0)); - if (I == BitSetIndices.end()) - continue; - - auto OpGlobal = dyn_cast<GlobalObject>( - cast<ConstantAsMetadata>(Op->getOperand(1))->getValue()); - if (!OpGlobal) - continue; - BitSetMembers[I->second].insert(GlobalIndices[OpGlobal]); +void LowerTypeTests::buildBitSetsFromDisjointSet( + ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalObject *> Globals) { + llvm::DenseMap<Metadata *, uint64_t> TypeIdIndices; + for (unsigned I = 0; I != TypeIds.size(); ++I) + TypeIdIndices[TypeIds[I]] = I; + + // For each type identifier, build a set of indices that refer to members of + // the type identifier. + std::vector<std::set<uint64_t>> TypeMembers(TypeIds.size()); + SmallVector<MDNode *, 2> Types; + unsigned GlobalIndex = 0; + for (GlobalObject *GO : Globals) { + Types.clear(); + GO->getMetadata(LLVMContext::MD_type, Types); + for (MDNode *Type : Types) { + // Type = { offset, type identifier } + unsigned TypeIdIndex = TypeIdIndices[Type->getOperand(1)]; + TypeMembers[TypeIdIndex].insert(GlobalIndex); } + GlobalIndex++; } // Order the sets of indices by size. The GlobalLayoutBuilder works best // when given small index sets first. std::stable_sort( - BitSetMembers.begin(), BitSetMembers.end(), + TypeMembers.begin(), TypeMembers.end(), [](const std::set<uint64_t> &O1, const std::set<uint64_t> &O2) { return O1.size() < O2.size(); }); @@ -881,7 +852,7 @@ void LowerBitSets::buildBitSetsFromDisjointSet( // fragments. The GlobalLayoutBuilder tries to lay out members of fragments as // close together as possible. GlobalLayoutBuilder GLB(Globals.size()); - for (auto &&MemSet : BitSetMembers) + for (auto &&MemSet : TypeMembers) GLB.addFragment(MemSet); // Build the bitsets from this disjoint set. @@ -893,13 +864,13 @@ void LowerBitSets::buildBitSetsFromDisjointSet( for (auto &&Offset : F) { auto GV = dyn_cast<GlobalVariable>(Globals[Offset]); if (!GV) - report_fatal_error( - "Bit set may not contain both global variables and functions"); + report_fatal_error("Type identifier may not contain both global " + "variables and functions"); *OGI++ = GV; } } - buildBitSetsFromGlobalVariables(BitSets, OrderedGVs); + buildBitSetsFromGlobalVariables(TypeIds, OrderedGVs); } else { // Build a vector of functions with the computed layout. std::vector<Function *> OrderedFns(Globals.size()); @@ -908,102 +879,97 @@ void LowerBitSets::buildBitSetsFromDisjointSet( for (auto &&Offset : F) { auto Fn = dyn_cast<Function>(Globals[Offset]); if (!Fn) - report_fatal_error( - "Bit set may not contain both global variables and functions"); + report_fatal_error("Type identifier may not contain both global " + "variables and functions"); *OFI++ = Fn; } } - buildBitSetsFromFunctions(BitSets, OrderedFns); + buildBitSetsFromFunctions(TypeIds, OrderedFns); } } -/// Lower all bit sets in this module. -bool LowerBitSets::buildBitSets() { - Function *BitSetTestFunc = - M->getFunction(Intrinsic::getName(Intrinsic::bitset_test)); - if (!BitSetTestFunc || BitSetTestFunc->use_empty()) +/// Lower all type tests in this module. +bool LowerTypeTests::lower() { + Function *TypeTestFunc = + M->getFunction(Intrinsic::getName(Intrinsic::type_test)); + if (!TypeTestFunc || TypeTestFunc->use_empty()) return false; - // Equivalence class set containing bitsets and the globals they reference. - // This is used to partition the set of bitsets in the module into disjoint - // sets. + // Equivalence class set containing type identifiers and the globals that + // reference them. This is used to partition the set of type identifiers in + // the module into disjoint sets. typedef EquivalenceClasses<PointerUnion<GlobalObject *, Metadata *>> GlobalClassesTy; GlobalClassesTy GlobalClasses; - // Verify the bitset metadata and build a mapping from bitset identifiers to - // their last observed index in BitSetNM. This will used later to - // deterministically order the list of bitset identifiers. - llvm::DenseMap<Metadata *, unsigned> BitSetIdIndices; - if (BitSetNM) { - for (unsigned I = 0, E = BitSetNM->getNumOperands(); I != E; ++I) { - MDNode *Op = BitSetNM->getOperand(I); - verifyBitSetMDNode(Op); - BitSetIdIndices[Op->getOperand(0)] = I; + // Verify the type metadata and build a mapping from type identifiers to their + // last observed index in the list of globals. This will be used later to + // deterministically order the list of type identifiers. + llvm::DenseMap<Metadata *, unsigned> TypeIdIndices; + unsigned I = 0; + SmallVector<MDNode *, 2> Types; + for (GlobalObject &GO : M->global_objects()) { + Types.clear(); + GO.getMetadata(LLVMContext::MD_type, Types); + for (MDNode *Type : Types) { + verifyTypeMDNode(&GO, Type); + TypeIdIndices[cast<MDNode>(Type)->getOperand(1)] = ++I; } } - for (const Use &U : BitSetTestFunc->uses()) { + for (const Use &U : TypeTestFunc->uses()) { auto CI = cast<CallInst>(U.getUser()); auto BitSetMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1)); if (!BitSetMDVal) report_fatal_error( - "Second argument of llvm.bitset.test must be metadata"); + "Second argument of llvm.type.test must be metadata"); auto BitSet = BitSetMDVal->getMetadata(); - // Add the call site to the list of call sites for this bit set. We also use - // BitSetTestCallSites to keep track of whether we have seen this bit set - // before. If we have, we don't need to re-add the referenced globals to the - // equivalence class. - std::pair<DenseMap<Metadata *, std::vector<CallInst *>>::iterator, - bool> Ins = - BitSetTestCallSites.insert( + // Add the call site to the list of call sites for this type identifier. We + // also use TypeTestCallSites to keep track of whether we have seen this + // type identifier before. If we have, we don't need to re-add the + // referenced globals to the equivalence class. + std::pair<DenseMap<Metadata *, std::vector<CallInst *>>::iterator, bool> + Ins = TypeTestCallSites.insert( std::make_pair(BitSet, std::vector<CallInst *>())); Ins.first->second.push_back(CI); if (!Ins.second) continue; - // Add the bitset to the equivalence class. + // Add the type identifier to the equivalence class. GlobalClassesTy::iterator GCI = GlobalClasses.insert(BitSet); GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI); - if (!BitSetNM) - continue; - - // Add the referenced globals to the bitset's equivalence class. - for (MDNode *Op : BitSetNM->operands()) { - if (Op->getOperand(0) != BitSet || !Op->getOperand(1)) - continue; - - auto OpGlobal = dyn_cast<GlobalObject>( - cast<ConstantAsMetadata>(Op->getOperand(1))->getValue()); - if (!OpGlobal) - continue; - - CurSet = GlobalClasses.unionSets( - CurSet, GlobalClasses.findLeader(GlobalClasses.insert(OpGlobal))); + // Add the referenced globals to the type identifier's equivalence class. + for (GlobalObject &GO : M->global_objects()) { + Types.clear(); + GO.getMetadata(LLVMContext::MD_type, Types); + for (MDNode *Type : Types) + if (Type->getOperand(1) == BitSet) + CurSet = GlobalClasses.unionSets( + CurSet, GlobalClasses.findLeader(GlobalClasses.insert(&GO))); } } if (GlobalClasses.empty()) return false; - // Build a list of disjoint sets ordered by their maximum BitSetNM index - // for determinism. + // Build a list of disjoint sets ordered by their maximum global index for + // determinism. std::vector<std::pair<GlobalClassesTy::iterator, unsigned>> Sets; for (GlobalClassesTy::iterator I = GlobalClasses.begin(), E = GlobalClasses.end(); I != E; ++I) { if (!I->isLeader()) continue; - ++NumBitSetDisjointSets; + ++NumTypeIdDisjointSets; unsigned MaxIndex = 0; for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(I); MI != GlobalClasses.member_end(); ++MI) { if ((*MI).is<Metadata *>()) - MaxIndex = std::max(MaxIndex, BitSetIdIndices[MI->get<Metadata *>()]); + MaxIndex = std::max(MaxIndex, TypeIdIndices[MI->get<Metadata *>()]); } Sets.emplace_back(I, MaxIndex); } @@ -1015,26 +981,26 @@ bool LowerBitSets::buildBitSets() { // For each disjoint set we found... for (const auto &S : Sets) { - // Build the list of bitsets in this disjoint set. - std::vector<Metadata *> BitSets; + // Build the list of type identifiers in this disjoint set. + std::vector<Metadata *> TypeIds; std::vector<GlobalObject *> Globals; for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(S.first); MI != GlobalClasses.member_end(); ++MI) { if ((*MI).is<Metadata *>()) - BitSets.push_back(MI->get<Metadata *>()); + TypeIds.push_back(MI->get<Metadata *>()); else Globals.push_back(MI->get<GlobalObject *>()); } - // Order bitsets by BitSetNM index for determinism. This ordering is stable - // as there is a one-to-one mapping between metadata and indices. - std::sort(BitSets.begin(), BitSets.end(), [&](Metadata *M1, Metadata *M2) { - return BitSetIdIndices[M1] < BitSetIdIndices[M2]; + // Order type identifiers by global index for determinism. This ordering is + // stable as there is a one-to-one mapping between metadata and indices. + std::sort(TypeIds.begin(), TypeIds.end(), [&](Metadata *M1, Metadata *M2) { + return TypeIdIndices[M1] < TypeIdIndices[M2]; }); - // Lower the bitsets in this disjoint set. - buildBitSetsFromDisjointSet(BitSets, Globals); + // Build bitsets for this disjoint set. + buildBitSetsFromDisjointSet(TypeIds, Globals); } allocateByteArrays(); @@ -1042,19 +1008,9 @@ bool LowerBitSets::buildBitSets() { return true; } -bool LowerBitSets::eraseBitSetMetadata() { - if (!BitSetNM) - return false; - - M->eraseNamedMetadata(BitSetNM); - return true; -} - -bool LowerBitSets::runOnModule(Module &M) { +bool LowerTypeTests::runOnModule(Module &M) { if (skipModule(M)) return false; - bool Changed = buildBitSets(); - Changed |= eraseBitSetMetadata(); - return Changed; + return lower(); } diff --git a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp index 094dda8cc1f..2a209b34ccf 100644 --- a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -748,10 +748,10 @@ void PassManagerBuilder::populateLTOPassManager(legacy::PassManagerBase &PM) { // in the current module. PM.add(createCrossDSOCFIPass()); - // Lower bit sets to globals. This pass supports Clang's control flow - // integrity mechanisms (-fsanitize=cfi*) and needs to run at link time if CFI - // is enabled. The pass does nothing if CFI is disabled. - PM.add(createLowerBitSetsPass()); + // Lower type metadata and the type.test intrinsic. This pass supports Clang's + // control flow integrity mechanisms (-fsanitize=cfi*) and needs to run at + // link time if CFI is enabled. The pass does nothing if CFI is disabled. + PM.add(createLowerTypeTestsPass()); if (OptLevel != 0) addLateLTOOptimizationPasses(PM); diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index 93ee07cf64e..c19c667a51e 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// // // This pass implements whole program optimization of virtual calls in cases -// where we know (via bitset information) that the list of callee is fixed. This +// where we know (via !type metadata) that the list of callees is fixed. This // includes the following: // - Single implementation devirtualization: if a virtual call has a single // possible callee, replace all calls with a direct call to that callee. @@ -31,7 +31,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" -#include "llvm/Analysis/BitSetUtils.h" +#include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -89,8 +89,8 @@ wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets, // at MinByte. std::vector<ArrayRef<uint8_t>> Used; for (const VirtualCallTarget &Target : Targets) { - ArrayRef<uint8_t> VTUsed = IsAfter ? Target.BS->Bits->After.BytesUsed - : Target.BS->Bits->Before.BytesUsed; + ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed + : Target.TM->Bits->Before.BytesUsed; uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes() : MinByte - Target.minBeforeBytes(); @@ -163,17 +163,17 @@ void wholeprogramdevirt::setAfterReturnValues( } } -VirtualCallTarget::VirtualCallTarget(Function *Fn, const BitSetInfo *BS) - : Fn(Fn), BS(BS), +VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM) + : Fn(Fn), TM(TM), IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()) {} namespace { -// A slot in a set of virtual tables. The BitSetID identifies the set of virtual +// A slot in a set of virtual tables. The TypeID identifies the set of virtual // tables, and the ByteOffset is the offset in bytes from the address point to // the virtual function pointer. struct VTableSlot { - Metadata *BitSetID; + Metadata *TypeID; uint64_t ByteOffset; }; @@ -191,12 +191,12 @@ template <> struct DenseMapInfo<VTableSlot> { DenseMapInfo<uint64_t>::getTombstoneKey()}; } static unsigned getHashValue(const VTableSlot &I) { - return DenseMapInfo<Metadata *>::getHashValue(I.BitSetID) ^ + return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^ DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset); } static bool isEqual(const VTableSlot &LHS, const VTableSlot &RHS) { - return LHS.BitSetID == RHS.BitSetID && LHS.ByteOffset == RHS.ByteOffset; + return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset; } }; @@ -233,11 +233,13 @@ struct DevirtModule { Int8PtrTy(Type::getInt8PtrTy(M.getContext())), Int32Ty(Type::getInt32Ty(M.getContext())) {} - void buildBitSets(std::vector<VTableBits> &Bits, - DenseMap<Metadata *, std::set<BitSetInfo>> &BitSets); - bool tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot, - const std::set<BitSetInfo> &BitSetInfos, - uint64_t ByteOffset); + void buildTypeIdentifierMap( + std::vector<VTableBits> &Bits, + DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap); + bool + tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot, + const std::set<TypeMemberInfo> &TypeMemberInfos, + uint64_t ByteOffset); bool trySingleImplDevirt(ArrayRef<VirtualCallTarget> TargetsForSlot, MutableArrayRef<VirtualCallSite> CallSites); bool tryEvaluateFunctionsWithArgs( @@ -287,60 +289,55 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M, return PreservedAnalyses::none(); } -void DevirtModule::buildBitSets( +void DevirtModule::buildTypeIdentifierMap( std::vector<VTableBits> &Bits, - DenseMap<Metadata *, std::set<BitSetInfo>> &BitSets) { - NamedMDNode *BitSetNM = M.getNamedMetadata("llvm.bitsets"); - if (!BitSetNM) - return; - + DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) { DenseMap<GlobalVariable *, VTableBits *> GVToBits; - Bits.reserve(BitSetNM->getNumOperands()); - for (auto Op : BitSetNM->operands()) { - auto OpConstMD = dyn_cast_or_null<ConstantAsMetadata>(Op->getOperand(1)); - if (!OpConstMD) + Bits.reserve(M.getGlobalList().size()); + SmallVector<MDNode *, 2> Types; + for (GlobalVariable &GV : M.globals()) { + Types.clear(); + GV.getMetadata(LLVMContext::MD_type, Types); + if (Types.empty()) continue; - auto BitSetID = Op->getOperand(0).get(); - - Constant *OpConst = OpConstMD->getValue(); - if (auto GA = dyn_cast<GlobalAlias>(OpConst)) - OpConst = GA->getAliasee(); - auto OpGlobal = dyn_cast<GlobalVariable>(OpConst); - if (!OpGlobal) - continue; - - uint64_t Offset = - cast<ConstantInt>( - cast<ConstantAsMetadata>(Op->getOperand(2))->getValue()) - ->getZExtValue(); - VTableBits *&BitsPtr = GVToBits[OpGlobal]; + VTableBits *&BitsPtr = GVToBits[&GV]; if (!BitsPtr) { Bits.emplace_back(); - Bits.back().GV = OpGlobal; - Bits.back().ObjectSize = M.getDataLayout().getTypeAllocSize( - OpGlobal->getInitializer()->getType()); + Bits.back().GV = &GV; + Bits.back().ObjectSize = + M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType()); BitsPtr = &Bits.back(); } - BitSets[BitSetID].insert({BitsPtr, Offset}); + + for (MDNode *Type : Types) { + auto TypeID = Type->getOperand(1).get(); + + uint64_t Offset = + cast<ConstantInt>( + cast<ConstantAsMetadata>(Type->getOperand(0))->getValue()) + ->getZExtValue(); + + TypeIdMap[TypeID].insert({BitsPtr, Offset}); + } } } bool DevirtModule::tryFindVirtualCallTargets( std::vector<VirtualCallTarget> &TargetsForSlot, - const std::set<BitSetInfo> &BitSetInfos, uint64_t ByteOffset) { - for (const BitSetInfo &BS : BitSetInfos) { - if (!BS.Bits->GV->isConstant()) + const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) { + for (const TypeMemberInfo &TM : TypeMemberInfos) { + if (!TM.Bits->GV->isConstant()) return false; - auto Init = dyn_cast<ConstantArray>(BS.Bits->GV->getInitializer()); + auto Init = dyn_cast<ConstantArray>(TM.Bits->GV->getInitializer()); if (!Init) return false; ArrayType *VTableTy = Init->getType(); uint64_t ElemSize = M.getDataLayout().getTypeAllocSize(VTableTy->getElementType()); - uint64_t GlobalSlotOffset = BS.Offset + ByteOffset; + uint64_t GlobalSlotOffset = TM.Offset + ByteOffset; if (GlobalSlotOffset % ElemSize != 0) return false; @@ -357,7 +354,7 @@ bool DevirtModule::tryFindVirtualCallTargets( if (Fn->getName() == "__cxa_pure_virtual") continue; - TargetsForSlot.push_back({Fn, &BS}); + TargetsForSlot.push_back({Fn, &TM}); } // Give up if we couldn't find any targets. @@ -430,24 +427,24 @@ bool DevirtModule::tryUniqueRetValOpt( MutableArrayRef<VirtualCallSite> CallSites) { // IsOne controls whether we look for a 0 or a 1. auto tryUniqueRetValOptFor = [&](bool IsOne) { - const BitSetInfo *UniqueBitSet = 0; + const TypeMemberInfo *UniqueMember = 0; for (const VirtualCallTarget &Target : TargetsForSlot) { if (Target.RetVal == (IsOne ? 1 : 0)) { - if (UniqueBitSet) + if (UniqueMember) return false; - UniqueBitSet = Target.BS; + UniqueMember = Target.TM; } } - // We should have found a unique bit set or bailed out by now. We already + // We should have found a unique member or bailed out by now. We already // checked for a uniform return value in tryUniformRetValOpt. - assert(UniqueBitSet); + assert(UniqueMember); // Replace each call with the comparison. for (auto &&Call : CallSites) { IRBuilder<> B(Call.CS.getInstruction()); - Value *OneAddr = B.CreateBitCast(UniqueBitSet->Bits->GV, Int8PtrTy); - OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueBitSet->Offset); + Value *OneAddr = B.CreateBitCast(UniqueMember->Bits->GV, Int8PtrTy); + OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueMember->Offset); Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable, OneAddr); Call.replaceAndErase(Cmp); @@ -526,7 +523,8 @@ bool DevirtModule::tryVirtualConstProp( if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second)) continue; - // Find an allocation offset in bits in all vtables in the bitset. + // Find an allocation offset in bits in all vtables associated with the + // type. uint64_t AllocBefore = findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth); uint64_t AllocAfter = @@ -620,9 +618,9 @@ void DevirtModule::rebuildGlobal(VTableBits &B) { } bool DevirtModule::run() { - Function *BitSetTestFunc = - M.getFunction(Intrinsic::getName(Intrinsic::bitset_test)); - if (!BitSetTestFunc || BitSetTestFunc->use_empty()) + Function *TypeTestFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_test)); + if (!TypeTestFunc || TypeTestFunc->use_empty()) return false; Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); @@ -630,11 +628,12 @@ bool DevirtModule::run() { return false; // Find all virtual calls via a virtual table pointer %p under an assumption - // of the form llvm.assume(llvm.bitset.test(%p, %md)). This indicates that %p - // points to a vtable in the bitset %md. Group calls by (bitset, offset) pair - // (effectively the identity of the virtual function) and store to CallSlots. + // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p + // points to a member of the type identifier %md. Group calls by (type ID, + // offset) pair (effectively the identity of the virtual function) and store + // to CallSlots. DenseSet<Value *> SeenPtrs; - for (auto I = BitSetTestFunc->use_begin(), E = BitSetTestFunc->use_end(); + for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end(); I != E;) { auto CI = dyn_cast<CallInst>(I->getUser()); ++I; @@ -650,18 +649,18 @@ bool DevirtModule::run() { // the vtable pointer before, as it may have been CSE'd with pointers from // other call sites, and we don't want to process call sites multiple times. if (!Assumes.empty()) { - Metadata *BitSet = + Metadata *TypeId = cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata(); Value *Ptr = CI->getArgOperand(0)->stripPointerCasts(); if (SeenPtrs.insert(Ptr).second) { for (DevirtCallSite Call : DevirtCalls) { - CallSlots[{BitSet, Call.Offset}].push_back( + CallSlots[{TypeId, Call.Offset}].push_back( {CI->getArgOperand(0), Call.CS}); } } } - // We no longer need the assumes or the bitset test. + // We no longer need the assumes or the type test. for (auto Assume : Assumes) Assume->eraseFromParent(); // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we @@ -670,20 +669,21 @@ bool DevirtModule::run() { CI->eraseFromParent(); } - // Rebuild llvm.bitsets metadata into a map for easy lookup. + // Rebuild type metadata into a map for easy lookup. std::vector<VTableBits> Bits; - DenseMap<Metadata *, std::set<BitSetInfo>> BitSets; - buildBitSets(Bits, BitSets); - if (BitSets.empty()) + DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap; + buildTypeIdentifierMap(Bits, TypeIdMap); + if (TypeIdMap.empty()) return true; - // For each (bitset, offset) pair: + // For each (type, offset) pair: bool DidVirtualConstProp = false; for (auto &S : CallSlots) { - // Search each of the vtables in the bitset for the virtual function - // implementation at offset S.first.ByteOffset, and add to TargetsForSlot. + // Search each of the members of the type identifier for the virtual + // function implementation at offset S.first.ByteOffset, and add to + // TargetsForSlot. std::vector<VirtualCallTarget> TargetsForSlot; - if (!tryFindVirtualCallTargets(TargetsForSlot, BitSets[S.first.BitSetID], + if (!tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID], S.first.ByteOffset)) continue; |