diff options
author | Peter Collingbourne <peter@pcc.me.uk> | 2016-06-25 00:23:04 +0000 |
---|---|---|
committer | Peter Collingbourne <peter@pcc.me.uk> | 2016-06-25 00:23:04 +0000 |
commit | 0312f614b1bfdad55c1832ee37d6d4b738ea70cf (patch) | |
tree | fe32499f559708bbe83f63a35aa420c91b882594 /llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp | |
parent | 6ad3d05b681b36f6ecc98523257d154053e4116d (diff) | |
download | bcm5719-llvm-0312f614b1bfdad55c1832ee37d6d4b738ea70cf.tar.gz bcm5719-llvm-0312f614b1bfdad55c1832ee37d6d4b738ea70cf.zip |
IR: Introduce llvm.type.checked.load intrinsic.
This intrinsic safely loads a function pointer from a virtual table pointer
using type metadata. This intrinsic is used to implement control flow integrity
in conjunction with virtual call optimization. The virtual call optimization
pass will optimize away llvm.type.checked.load intrinsics associated with
devirtualized calls, thereby removing the type check in cases where it is
not needed to enforce the control flow integrity constraint.
This patch also introduces the capability to copy type metadata between
global variables, and teaches the virtual call optimization pass to do so.
Differential Revision: http://reviews.llvm.org/D21121
llvm-svn: 273756
Diffstat (limited to 'llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp')
-rw-r--r-- | llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp | 154 |
1 file changed, 142 insertions, 12 deletions
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index c19c667a51e..e6f4fa29224 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -210,6 +210,11 @@ struct VirtualCallSite { Value *VTable; CallSite CS; + // If non-null, this field points to the associated unsafe use count stored in + // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description + // of that field for details. + unsigned *NumUnsafeUses; + void replaceAndErase(Value *New) { CS->replaceAllUsesWith(New); if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) { @@ -217,6 +222,9 @@ struct VirtualCallSite { II->getUnwindDest()->removePredecessor(II->getParent()); } CS->eraseFromParent(); + // This use is no longer unsafe. + if (NumUnsafeUses) + --*NumUnsafeUses; } }; @@ -228,11 +236,24 @@ struct DevirtModule { MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots; + // This map keeps track of the number of "unsafe" uses of a loaded function + // pointer. The key is the associated llvm.type.test intrinsic call generated + // by this pass. An unsafe use is one that calls the loaded function pointer + // directly. Every time we eliminate an unsafe use (for example, by + // devirtualizing it or by applying virtual constant propagation), we + // decrement the value stored in this map. If a value reaches zero, we can + // eliminate the type check by RAUWing the associated llvm.type.test call with + // true. 
+ std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest; + DevirtModule(Module &M) : M(M), Int8Ty(Type::getInt8Ty(M.getContext())), Int8PtrTy(Type::getInt8PtrTy(M.getContext())), Int32Ty(Type::getInt32Ty(M.getContext())) {} + void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc); + void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc); + void buildTypeIdentifierMap( std::vector<VTableBits> &Bits, DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap); @@ -375,6 +396,9 @@ bool DevirtModule::trySingleImplDevirt( for (auto &&VCallSite : CallSites) { VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( TheFn, VCallSite.CS.getCalledValue()->getType())); + // This use is no longer unsafe. + if (VCallSite.NumUnsafeUses) + --*VCallSite.NumUnsafeUses; } return true; } @@ -601,6 +625,10 @@ void DevirtModule::rebuildGlobal(VTableBits &B) { NewGV->setSection(B.GV->getSection()); NewGV->setComdat(B.GV->getComdat()); + // Copy the original vtable's metadata to the anonymous global, adjusting + // offsets as required. + NewGV->copyMetadata(B.GV, B.Before.Bytes.size()); + // Build an alias named after the original global, pointing at the second // element (the original initializer). auto Alias = GlobalAlias::create( @@ -617,16 +645,8 @@ void DevirtModule::rebuildGlobal(VTableBits &B) { B.GV->eraseFromParent(); } -bool DevirtModule::run() { - Function *TypeTestFunc = - M.getFunction(Intrinsic::getName(Intrinsic::type_test)); - if (!TypeTestFunc || TypeTestFunc->use_empty()) - return false; - - Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); - if (!AssumeFunc || AssumeFunc->use_empty()) - return false; - +void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, + Function *AssumeFunc) { // Find all virtual calls via a virtual table pointer %p under an assumption // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p // points to a member of the type identifier %md. 
Group calls by (type ID, @@ -643,7 +663,7 @@ bool DevirtModule::run() { // Search for virtual calls based on %p and add them to DevirtCalls. SmallVector<DevirtCallSite, 1> DevirtCalls; SmallVector<CallInst *, 1> Assumes; - findDevirtualizableCalls(DevirtCalls, Assumes, CI); + findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI); // If we found any, add them to CallSlots. Only do this if we haven't seen // the vtable pointer before, as it may have been CSE'd with pointers from @@ -655,7 +675,7 @@ bool DevirtModule::run() { if (SeenPtrs.insert(Ptr).second) { for (DevirtCallSite Call : DevirtCalls) { CallSlots[{TypeId, Call.Offset}].push_back( - {CI->getArgOperand(0), Call.CS}); + {CI->getArgOperand(0), Call.CS, nullptr}); } } } @@ -668,6 +688,104 @@ bool DevirtModule::run() { if (CI->use_empty()) CI->eraseFromParent(); } +} + +void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { + Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test); + + for (auto I = TypeCheckedLoadFunc->use_begin(), + E = TypeCheckedLoadFunc->use_end(); + I != E;) { + auto CI = dyn_cast<CallInst>(I->getUser()); + ++I; + if (!CI) + continue; + + Value *Ptr = CI->getArgOperand(0); + Value *Offset = CI->getArgOperand(1); + Value *TypeIdValue = CI->getArgOperand(2); + Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata(); + + SmallVector<DevirtCallSite, 1> DevirtCalls; + SmallVector<Instruction *, 1> LoadedPtrs; + SmallVector<Instruction *, 1> Preds; + bool HasNonCallUses = false; + findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds, + HasNonCallUses, CI); + + // Start by generating "pessimistic" code that explicitly loads the function + // pointer from the vtable and performs the type check. If possible, we will + // eliminate the load and the type check later. + + // If possible, only generate the load at the point where it is used. + // This helps avoid unnecessary spills. 
+ IRBuilder<> LoadB( + (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI); + Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); + Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy)); + Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr); + + for (Instruction *LoadedPtr : LoadedPtrs) { + LoadedPtr->replaceAllUsesWith(LoadedValue); + LoadedPtr->eraseFromParent(); + } + + // Likewise for the type test. + IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI); + CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue}); + + for (Instruction *Pred : Preds) { + Pred->replaceAllUsesWith(TypeTestCall); + Pred->eraseFromParent(); + } + + // We have already erased any extractvalue instructions that refer to the + // intrinsic call, but the intrinsic may have other non-extractvalue uses + // (although this is unlikely). In that case, explicitly build a pair and + // RAUW it. + if (!CI->use_empty()) { + Value *Pair = UndefValue::get(CI->getType()); + IRBuilder<> B(CI); + Pair = B.CreateInsertValue(Pair, LoadedValue, {0}); + Pair = B.CreateInsertValue(Pair, TypeTestCall, {1}); + CI->replaceAllUsesWith(Pair); + } + + // The number of unsafe uses is initially the number of uses. + auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall]; + NumUnsafeUses = DevirtCalls.size(); + + // If the function pointer has a non-call user, we cannot eliminate the type + // check, as one of those users may eventually call the pointer. Increment + // the unsafe use count to make sure it cannot reach zero. 
+ if (HasNonCallUses) + ++NumUnsafeUses; + for (DevirtCallSite Call : DevirtCalls) { + CallSlots[{TypeId, Call.Offset}].push_back( + {Ptr, Call.CS, &NumUnsafeUses}); + } + + CI->eraseFromParent(); + } +} + +bool DevirtModule::run() { + Function *TypeTestFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_test)); + Function *TypeCheckedLoadFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); + Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); + + if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || + AssumeFunc->use_empty()) && + (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) + return false; + + if (TypeTestFunc && AssumeFunc) + scanTypeTestUsers(TypeTestFunc, AssumeFunc); + + if (TypeCheckedLoadFunc) + scanTypeCheckedLoadUsers(TypeCheckedLoadFunc); // Rebuild type metadata into a map for easy lookup. std::vector<VTableBits> Bits; @@ -693,6 +811,18 @@ bool DevirtModule::run() { DidVirtualConstProp |= tryVirtualConstProp(TargetsForSlot, S.second); } + // If we were able to eliminate all unsafe uses for a type checked load, + // eliminate the type test by replacing it with true. + if (TypeCheckedLoadFunc) { + auto True = ConstantInt::getTrue(M.getContext()); + for (auto &&U : NumUnsafeUsesForTypeTest) { + if (U.second == 0) { + U.first->replaceAllUsesWith(True); + U.first->eraseFromParent(); + } + } + } + // Rebuild each global we touched as part of virtual constant propagation to // include the before and after bytes. if (DidVirtualConstProp) |