summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Transforms
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r--llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp154
1 files changed, 142 insertions, 12 deletions
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index c19c667a51e..e6f4fa29224 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -210,6 +210,11 @@ struct VirtualCallSite {
Value *VTable;
CallSite CS;
+ // If non-null, this field points to the associated unsafe use count stored in
+ // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
+ // of that field for details.
+ unsigned *NumUnsafeUses;
+
void replaceAndErase(Value *New) {
CS->replaceAllUsesWith(New);
if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
@@ -217,6 +222,9 @@ struct VirtualCallSite {
II->getUnwindDest()->removePredecessor(II->getParent());
}
CS->eraseFromParent();
+ // This use is no longer unsafe.
+ if (NumUnsafeUses)
+ --*NumUnsafeUses;
}
};
@@ -228,11 +236,24 @@ struct DevirtModule {
MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots;
+ // This map keeps track of the number of "unsafe" uses of a loaded function
+ // pointer. The key is the associated llvm.type.test intrinsic call generated
+ // by this pass. An unsafe use is one that calls the loaded function pointer
+ // directly. Every time we eliminate an unsafe use (for example, by
+ // devirtualizing it or by applying virtual constant propagation), we
+ // decrement the value stored in this map. If a value reaches zero, we can
+ // eliminate the type check by RAUWing the associated llvm.type.test call with
+ // true.
+ std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;
+
DevirtModule(Module &M)
: M(M), Int8Ty(Type::getInt8Ty(M.getContext())),
Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
Int32Ty(Type::getInt32Ty(M.getContext())) {}
+ void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc);
+ void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);
+
void buildTypeIdentifierMap(
std::vector<VTableBits> &Bits,
DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
@@ -375,6 +396,9 @@ bool DevirtModule::trySingleImplDevirt(
for (auto &&VCallSite : CallSites) {
VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
TheFn, VCallSite.CS.getCalledValue()->getType()));
+ // This use is no longer unsafe.
+ if (VCallSite.NumUnsafeUses)
+ --*VCallSite.NumUnsafeUses;
}
return true;
}
@@ -601,6 +625,10 @@ void DevirtModule::rebuildGlobal(VTableBits &B) {
NewGV->setSection(B.GV->getSection());
NewGV->setComdat(B.GV->getComdat());
+ // Copy the original vtable's metadata to the anonymous global, adjusting
+ // offsets as required.
+ NewGV->copyMetadata(B.GV, B.Before.Bytes.size());
+
// Build an alias named after the original global, pointing at the second
// element (the original initializer).
auto Alias = GlobalAlias::create(
@@ -617,16 +645,8 @@ void DevirtModule::rebuildGlobal(VTableBits &B) {
B.GV->eraseFromParent();
}
-bool DevirtModule::run() {
- Function *TypeTestFunc =
- M.getFunction(Intrinsic::getName(Intrinsic::type_test));
- if (!TypeTestFunc || TypeTestFunc->use_empty())
- return false;
-
- Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
- if (!AssumeFunc || AssumeFunc->use_empty())
- return false;
-
+void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,
+ Function *AssumeFunc) {
// Find all virtual calls via a virtual table pointer %p under an assumption
// of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
// points to a member of the type identifier %md. Group calls by (type ID,
@@ -643,7 +663,7 @@ bool DevirtModule::run() {
// Search for virtual calls based on %p and add them to DevirtCalls.
SmallVector<DevirtCallSite, 1> DevirtCalls;
SmallVector<CallInst *, 1> Assumes;
- findDevirtualizableCalls(DevirtCalls, Assumes, CI);
+ findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI);
// If we found any, add them to CallSlots. Only do this if we haven't seen
// the vtable pointer before, as it may have been CSE'd with pointers from
@@ -655,7 +675,7 @@ bool DevirtModule::run() {
if (SeenPtrs.insert(Ptr).second) {
for (DevirtCallSite Call : DevirtCalls) {
CallSlots[{TypeId, Call.Offset}].push_back(
- {CI->getArgOperand(0), Call.CS});
+ {CI->getArgOperand(0), Call.CS, nullptr});
}
}
}
@@ -668,6 +688,104 @@ bool DevirtModule::run() {
if (CI->use_empty())
CI->eraseFromParent();
}
+}
+
+void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
+ Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test);
+
+ for (auto I = TypeCheckedLoadFunc->use_begin(),
+ E = TypeCheckedLoadFunc->use_end();
+ I != E;) {
+ auto CI = dyn_cast<CallInst>(I->getUser());
+ ++I;
+ if (!CI)
+ continue;
+
+ Value *Ptr = CI->getArgOperand(0);
+ Value *Offset = CI->getArgOperand(1);
+ Value *TypeIdValue = CI->getArgOperand(2);
+ Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();
+
+ SmallVector<DevirtCallSite, 1> DevirtCalls;
+ SmallVector<Instruction *, 1> LoadedPtrs;
+ SmallVector<Instruction *, 1> Preds;
+ bool HasNonCallUses = false;
+ findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
+ HasNonCallUses, CI);
+
+ // Start by generating "pessimistic" code that explicitly loads the function
+ // pointer from the vtable and performs the type check. If possible, we will
+ // eliminate the load and the type check later.
+
+ // If possible, only generate the load at the point where it is used.
+ // This helps avoid unnecessary spills.
+ IRBuilder<> LoadB(
+ (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);
+ Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
+ Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
+ Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);
+
+ for (Instruction *LoadedPtr : LoadedPtrs) {
+ LoadedPtr->replaceAllUsesWith(LoadedValue);
+ LoadedPtr->eraseFromParent();
+ }
+
+ // Likewise for the type test.
+ IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
+ CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue});
+
+ for (Instruction *Pred : Preds) {
+ Pred->replaceAllUsesWith(TypeTestCall);
+ Pred->eraseFromParent();
+ }
+
+ // We have already erased any extractvalue instructions that refer to the
+ // intrinsic call, but the intrinsic may have other non-extractvalue uses
+ // (although this is unlikely). In that case, explicitly build a pair and
+ // RAUW it.
+ if (!CI->use_empty()) {
+ Value *Pair = UndefValue::get(CI->getType());
+ IRBuilder<> B(CI);
+ Pair = B.CreateInsertValue(Pair, LoadedValue, {0});
+ Pair = B.CreateInsertValue(Pair, TypeTestCall, {1});
+ CI->replaceAllUsesWith(Pair);
+ }
+
+ // The number of unsafe uses is initially the number of uses.
+ auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
+ NumUnsafeUses = DevirtCalls.size();
+
+ // If the function pointer has a non-call user, we cannot eliminate the type
+ // check, as one of those users may eventually call the pointer. Increment
+ // the unsafe use count to make sure it cannot reach zero.
+ if (HasNonCallUses)
+ ++NumUnsafeUses;
+ for (DevirtCallSite Call : DevirtCalls) {
+ CallSlots[{TypeId, Call.Offset}].push_back(
+ {Ptr, Call.CS, &NumUnsafeUses});
+ }
+
+ CI->eraseFromParent();
+ }
+}
+
+bool DevirtModule::run() {
+ Function *TypeTestFunc =
+ M.getFunction(Intrinsic::getName(Intrinsic::type_test));
+ Function *TypeCheckedLoadFunc =
+ M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
+ Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
+
+ if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
+ AssumeFunc->use_empty()) &&
+ (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
+ return false;
+
+ if (TypeTestFunc && AssumeFunc)
+ scanTypeTestUsers(TypeTestFunc, AssumeFunc);
+
+ if (TypeCheckedLoadFunc)
+ scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);
// Rebuild type metadata into a map for easy lookup.
std::vector<VTableBits> Bits;
@@ -693,6 +811,18 @@ bool DevirtModule::run() {
DidVirtualConstProp |= tryVirtualConstProp(TargetsForSlot, S.second);
}
+ // If we were able to eliminate all unsafe uses for a type checked load,
+ // eliminate the type test by replacing it with true.
+ if (TypeCheckedLoadFunc) {
+ auto True = ConstantInt::getTrue(M.getContext());
+ for (auto &&U : NumUnsafeUsesForTypeTest) {
+ if (U.second == 0) {
+ U.first->replaceAllUsesWith(True);
+ U.first->eraseFromParent();
+ }
+ }
+ }
+
// Rebuild each global we touched as part of virtual constant propagation to
// include the before and after bytes.
if (DidVirtualConstProp)
OpenPOWER on IntegriCloud