summaryrefslogtreecommitdiffstats
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Analysis/TypeMetadataUtils.cpp44
-rw-r--r--llvm/lib/IR/Metadata.cpp1
-rw-r--r--llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp154
3 files changed, 182 insertions, 17 deletions
diff --git a/llvm/lib/Analysis/TypeMetadataUtils.cpp b/llvm/lib/Analysis/TypeMetadataUtils.cpp
index 750cce33856..8d173d77fb6 100644
--- a/llvm/lib/Analysis/TypeMetadataUtils.cpp
+++ b/llvm/lib/Analysis/TypeMetadataUtils.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/TypeMetadataUtils.h"
+#include "llvm/IR/Constants.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
@@ -21,15 +22,17 @@ using namespace llvm;
// Search for virtual calls that call FPtr and add them to DevirtCalls.
static void
findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls,
- Value *FPtr, uint64_t Offset) {
+ bool *HasNonCallUses, Value *FPtr, uint64_t Offset) {
for (const Use &U : FPtr->uses()) {
Value *User = U.getUser();
if (isa<BitCastInst>(User)) {
- findCallsAtConstantOffset(DevirtCalls, User, Offset);
+ findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset);
} else if (auto CI = dyn_cast<CallInst>(User)) {
DevirtCalls.push_back({Offset, CI});
} else if (auto II = dyn_cast<InvokeInst>(User)) {
DevirtCalls.push_back({Offset, II});
+ } else if (HasNonCallUses) {
+ *HasNonCallUses = true;
}
}
}
@@ -44,7 +47,7 @@ findLoadCallsAtConstantOffset(Module *M,
if (isa<BitCastInst>(User)) {
findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset);
} else if (isa<LoadInst>(User)) {
- findCallsAtConstantOffset(DevirtCalls, User, Offset);
+ findCallsAtConstantOffset(DevirtCalls, nullptr, User, Offset);
} else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) {
// Take into account the GEP offset.
if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) {
@@ -57,7 +60,7 @@ findLoadCallsAtConstantOffset(Module *M,
}
}
-void llvm::findDevirtualizableCalls(
+void llvm::findDevirtualizableCallsForTypeTest(
SmallVectorImpl<DevirtCallSite> &DevirtCalls,
SmallVectorImpl<CallInst *> &Assumes, CallInst *CI) {
assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_test);
@@ -80,3 +83,36 @@ void llvm::findDevirtualizableCalls(
findLoadCallsAtConstantOffset(M, DevirtCalls,
CI->getArgOperand(0)->stripPointerCasts(), 0);
}
+
+void llvm::findDevirtualizableCallsForTypeCheckedLoad(
+ SmallVectorImpl<DevirtCallSite> &DevirtCalls,
+ SmallVectorImpl<Instruction *> &LoadedPtrs,
+ SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses, CallInst *CI) {
+ assert(CI->getCalledFunction()->getIntrinsicID() ==
+ Intrinsic::type_checked_load);
+
+ auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+ if (!Offset) {
+ HasNonCallUses = true;
+ return;
+ }
+
+ for (Use &U : CI->uses()) {
+ auto CIU = U.getUser();
+ if (auto EVI = dyn_cast<ExtractValueInst>(CIU)) {
+ if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 0) {
+ LoadedPtrs.push_back(EVI);
+ continue;
+ }
+ if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 1) {
+ Preds.push_back(EVI);
+ continue;
+ }
+ }
+ HasNonCallUses = true;
+ }
+
+ for (Value *LoadedPtr : LoadedPtrs)
+ findCallsAtConstantOffset(DevirtCalls, &HasNonCallUses, LoadedPtr,
+ Offset->getZExtValue());
+}
diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp
index 965f9737629..ed39fbafcb0 100644
--- a/llvm/lib/IR/Metadata.cpp
+++ b/llvm/lib/IR/Metadata.cpp
@@ -1369,7 +1369,6 @@ void GlobalObject::clearMetadata() {
setHasMetadataHashEntry(false);
}
-
void GlobalObject::setMetadata(unsigned KindID, MDNode *N) {
eraseMetadata(KindID);
if (N)
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index c19c667a51e..e6f4fa29224 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -210,6 +210,11 @@ struct VirtualCallSite {
Value *VTable;
CallSite CS;
+ // If non-null, this field points to the associated unsafe use count stored in
+ // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
+ // of that field for details.
+ unsigned *NumUnsafeUses;
+
void replaceAndErase(Value *New) {
CS->replaceAllUsesWith(New);
if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
@@ -217,6 +222,9 @@ struct VirtualCallSite {
II->getUnwindDest()->removePredecessor(II->getParent());
}
CS->eraseFromParent();
+ // This use is no longer unsafe.
+ if (NumUnsafeUses)
+ --*NumUnsafeUses;
}
};
@@ -228,11 +236,24 @@ struct DevirtModule {
MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots;
+ // This map keeps track of the number of "unsafe" uses of a loaded function
+ // pointer. The key is the associated llvm.type.test intrinsic call generated
+ // by this pass. An unsafe use is one that calls the loaded function pointer
+ // directly. Every time we eliminate an unsafe use (for example, by
+ // devirtualizing it or by applying virtual constant propagation), we
+ // decrement the value stored in this map. If a value reaches zero, we can
+ // eliminate the type check by RAUWing the associated llvm.type.test call with
+ // true.
+ std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;
+
DevirtModule(Module &M)
: M(M), Int8Ty(Type::getInt8Ty(M.getContext())),
Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
Int32Ty(Type::getInt32Ty(M.getContext())) {}
+ void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc);
+ void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);
+
void buildTypeIdentifierMap(
std::vector<VTableBits> &Bits,
DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
@@ -375,6 +396,9 @@ bool DevirtModule::trySingleImplDevirt(
for (auto &&VCallSite : CallSites) {
VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
TheFn, VCallSite.CS.getCalledValue()->getType()));
+ // This use is no longer unsafe.
+ if (VCallSite.NumUnsafeUses)
+ --*VCallSite.NumUnsafeUses;
}
return true;
}
@@ -601,6 +625,10 @@ void DevirtModule::rebuildGlobal(VTableBits &B) {
NewGV->setSection(B.GV->getSection());
NewGV->setComdat(B.GV->getComdat());
+ // Copy the original vtable's metadata to the anonymous global, adjusting
+ // offsets as required.
+ NewGV->copyMetadata(B.GV, B.Before.Bytes.size());
+
// Build an alias named after the original global, pointing at the second
// element (the original initializer).
auto Alias = GlobalAlias::create(
@@ -617,16 +645,8 @@ void DevirtModule::rebuildGlobal(VTableBits &B) {
B.GV->eraseFromParent();
}
-bool DevirtModule::run() {
- Function *TypeTestFunc =
- M.getFunction(Intrinsic::getName(Intrinsic::type_test));
- if (!TypeTestFunc || TypeTestFunc->use_empty())
- return false;
-
- Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
- if (!AssumeFunc || AssumeFunc->use_empty())
- return false;
-
+void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,
+ Function *AssumeFunc) {
// Find all virtual calls via a virtual table pointer %p under an assumption
// of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
// points to a member of the type identifier %md. Group calls by (type ID,
@@ -643,7 +663,7 @@ bool DevirtModule::run() {
// Search for virtual calls based on %p and add them to DevirtCalls.
SmallVector<DevirtCallSite, 1> DevirtCalls;
SmallVector<CallInst *, 1> Assumes;
- findDevirtualizableCalls(DevirtCalls, Assumes, CI);
+ findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI);
// If we found any, add them to CallSlots. Only do this if we haven't seen
// the vtable pointer before, as it may have been CSE'd with pointers from
@@ -655,7 +675,7 @@ bool DevirtModule::run() {
if (SeenPtrs.insert(Ptr).second) {
for (DevirtCallSite Call : DevirtCalls) {
CallSlots[{TypeId, Call.Offset}].push_back(
- {CI->getArgOperand(0), Call.CS});
+ {CI->getArgOperand(0), Call.CS, nullptr});
}
}
}
@@ -668,6 +688,104 @@ bool DevirtModule::run() {
if (CI->use_empty())
CI->eraseFromParent();
}
+}
+
+void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
+ Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test);
+
+ for (auto I = TypeCheckedLoadFunc->use_begin(),
+ E = TypeCheckedLoadFunc->use_end();
+ I != E;) {
+ auto CI = dyn_cast<CallInst>(I->getUser());
+ ++I;
+ if (!CI)
+ continue;
+
+ Value *Ptr = CI->getArgOperand(0);
+ Value *Offset = CI->getArgOperand(1);
+ Value *TypeIdValue = CI->getArgOperand(2);
+ Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();
+
+ SmallVector<DevirtCallSite, 1> DevirtCalls;
+ SmallVector<Instruction *, 1> LoadedPtrs;
+ SmallVector<Instruction *, 1> Preds;
+ bool HasNonCallUses = false;
+ findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
+ HasNonCallUses, CI);
+
+ // Start by generating "pessimistic" code that explicitly loads the function
+ // pointer from the vtable and performs the type check. If possible, we will
+ // eliminate the load and the type check later.
+
+ // If possible, only generate the load at the point where it is used.
+ // This helps avoid unnecessary spills.
+ IRBuilder<> LoadB(
+ (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);
+ Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
+ Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
+ Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);
+
+ for (Instruction *LoadedPtr : LoadedPtrs) {
+ LoadedPtr->replaceAllUsesWith(LoadedValue);
+ LoadedPtr->eraseFromParent();
+ }
+
+ // Likewise for the type test.
+ IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
+ CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue});
+
+ for (Instruction *Pred : Preds) {
+ Pred->replaceAllUsesWith(TypeTestCall);
+ Pred->eraseFromParent();
+ }
+
+ // We have already erased any extractvalue instructions that refer to the
+ // intrinsic call, but the intrinsic may have other non-extractvalue uses
+ // (although this is unlikely). In that case, explicitly build a pair and
+ // RAUW it.
+ if (!CI->use_empty()) {
+ Value *Pair = UndefValue::get(CI->getType());
+ IRBuilder<> B(CI);
+ Pair = B.CreateInsertValue(Pair, LoadedValue, {0});
+ Pair = B.CreateInsertValue(Pair, TypeTestCall, {1});
+ CI->replaceAllUsesWith(Pair);
+ }
+
+ // The number of unsafe uses is initially the number of uses.
+ auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
+ NumUnsafeUses = DevirtCalls.size();
+
+ // If the function pointer has a non-call user, we cannot eliminate the type
+ // check, as one of those users may eventually call the pointer. Increment
+ // the unsafe use count to make sure it cannot reach zero.
+ if (HasNonCallUses)
+ ++NumUnsafeUses;
+ for (DevirtCallSite Call : DevirtCalls) {
+ CallSlots[{TypeId, Call.Offset}].push_back(
+ {Ptr, Call.CS, &NumUnsafeUses});
+ }
+
+ CI->eraseFromParent();
+ }
+}
+
+bool DevirtModule::run() {
+ Function *TypeTestFunc =
+ M.getFunction(Intrinsic::getName(Intrinsic::type_test));
+ Function *TypeCheckedLoadFunc =
+ M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
+ Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
+
+ if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
+ AssumeFunc->use_empty()) &&
+ (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
+ return false;
+
+ if (TypeTestFunc && AssumeFunc)
+ scanTypeTestUsers(TypeTestFunc, AssumeFunc);
+
+ if (TypeCheckedLoadFunc)
+ scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);
// Rebuild type metadata into a map for easy lookup.
std::vector<VTableBits> Bits;
@@ -693,6 +811,18 @@ bool DevirtModule::run() {
DidVirtualConstProp |= tryVirtualConstProp(TargetsForSlot, S.second);
}
+ // If we were able to eliminate all unsafe uses for a type checked load,
+ // eliminate the type test by replacing it with true.
+ if (TypeCheckedLoadFunc) {
+ auto True = ConstantInt::getTrue(M.getContext());
+ for (auto &&U : NumUnsafeUsesForTypeTest) {
+ if (U.second == 0) {
+ U.first->replaceAllUsesWith(True);
+ U.first->eraseFromParent();
+ }
+ }
+ }
+
// Rebuild each global we touched as part of virtual constant propagation to
// include the before and after bytes.
if (DidVirtualConstProp)
OpenPOWER on IntegriCloud