author    Peter Collingbourne <peter@pcc.me.uk>  2018-03-09 19:11:44 +0000
committer Peter Collingbourne <peter@pcc.me.uk>  2018-03-09 19:11:44 +0000
commit    2974856ad4326989052f04299affaa516985e77a (patch)
tree      9be1d673e3f12d12980c7a093f7c12554c44ace4 /llvm/lib/Transforms
parent    dee18b82c23c1637bdfad001bfe8c62bdf8c5955 (diff)
Use branch funnels for virtual calls when retpoline mitigation is enabled.
The retpoline mitigation for variant 2 of CVE-2017-5715 inhibits the branch predictor, and as a result it can lead to a measurable loss of performance. We can reduce the performance impact of retpolined virtual calls by replacing them with a special construct known as a branch funnel, which is an instruction sequence that implements virtual calls to a set of known targets using a binary tree of direct branches. This allows the processor to speculatively execute valid implementations of the virtual function without allowing speculative execution of calls to arbitrary addresses.

This patch extends the whole-program devirtualization pass to replace certain virtual calls with calls to branch funnels, which are represented using a new llvm.icall.branch.funnel intrinsic. It also extends the LowerTypeTests pass to recognize the new intrinsic, generate code for the branch funnels (x86_64 only for now) and lay out virtual tables as required for each branch funnel.

The implementation supports full LTO as well as ThinLTO, and extends the ThinLTO summary format used for whole-program devirtualization to support branch funnels.

For more details see RFC:
http://lists.llvm.org/pipermail/llvm-dev/2018-January/120672.html

Differential Revision: https://reviews.llvm.org/D42453

llvm-svn: 327163
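As a sketch of the construct (not part of the patch): the funnel that LowerTypeTests emits is target machine code, but its control flow resembles the following hedged C++ analogue, in which VT1..VT3, Impl1..Impl3 and BranchFunnel are illustrative names for the candidate vtables, the known implementations and the funnel itself:

    #include <cstdint>

    // Illustrative only: the candidate vtables (laid out by LowerTypeTests at
    // known, ordered addresses) and the known virtual function implementations.
    extern const char VT1[], VT2[], VT3[];
    void Impl1(void *Obj);
    void Impl2(void *Obj);
    void Impl3(void *Obj);

    // A branch funnel over three known targets: a binary tree of comparisons
    // on the incoming vtable address, each leaf ending in a direct call. The
    // processor can speculate through the direct branches, but there is no
    // indirect branch for variant 2 of CVE-2017-5715 to exploit.
    void BranchFunnel(const void *VTable, void *Obj) {
      std::uintptr_t Addr = reinterpret_cast<std::uintptr_t>(VTable);
      if (Addr < reinterpret_cast<std::uintptr_t>(VT2)) {
        if (Addr == reinterpret_cast<std::uintptr_t>(VT1))
          return Impl1(Obj);
      } else if (Addr == reinterpret_cast<std::uintptr_t>(VT2)) {
        return Impl2(Obj);
      } else if (Addr == reinterpret_cast<std::uintptr_t>(VT3)) {
        return Impl3(Obj);
      }
    }

The binary search is only valid because LowerTypeTests lays the candidate vtables out in a known order, which is why the pass also takes over vtable layout for each funnel; the real funnel additionally receives the vtable address in the nest register (r10 on x86_64) and musttail-forwards the original call arguments to the selected target.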
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r--  llvm/lib/Transforms/IPO/LowerTypeTests.cpp      117
-rw-r--r--  llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp  211
2 files changed, 291 insertions, 37 deletions
diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
index 3ec6e4045aa..72ce332f581 100644
--- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
+++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
@@ -8,6 +8,8 @@
//===----------------------------------------------------------------------===//
//
// This pass lowers type metadata and calls to the llvm.type.test intrinsic.
+// It also ensures that globals are properly laid out for the
+// llvm.icall.branch.funnel intrinsic.
// See http://llvm.org/docs/TypeMetadata.html for more information.
//
//===----------------------------------------------------------------------===//
@@ -25,6 +27,7 @@
#include "llvm/ADT/TinyPtrVector.h"
#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
@@ -291,6 +294,29 @@ public:
}
};
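+// A call to the llvm.icall.branch.funnel intrinsic, together with the
+// GlobalTypeMember for each vtable through which the funnel may dispatch.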
+struct ICallBranchFunnel final
+ : TrailingObjects<ICallBranchFunnel, GlobalTypeMember *> {
+ static ICallBranchFunnel *create(BumpPtrAllocator &Alloc, CallInst *CI,
+ ArrayRef<GlobalTypeMember *> Targets) {
+ auto *Call = static_cast<ICallBranchFunnel *>(
+ Alloc.Allocate(totalSizeToAlloc<GlobalTypeMember *>(Targets.size()),
+ alignof(ICallBranchFunnel)));
+ Call->CI = CI;
+ Call->NTargets = Targets.size();
+ std::uninitialized_copy(Targets.begin(), Targets.end(),
+ Call->getTrailingObjects<GlobalTypeMember *>());
+ return Call;
+ }
+
+ CallInst *CI;
+ ArrayRef<GlobalTypeMember *> targets() const {
+ return makeArrayRef(getTrailingObjects<GlobalTypeMember *>(), NTargets);
+ }
+
+private:
+ size_t NTargets;
+};
+
class LowerTypeTestsModule {
Module &M;
@@ -372,6 +398,7 @@ class LowerTypeTestsModule {
const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout);
Value *lowerTypeTestCall(Metadata *TypeId, CallInst *CI,
const TypeIdLowering &TIL);
+
void buildBitSetsFromGlobalVariables(ArrayRef<Metadata *> TypeIds,
ArrayRef<GlobalTypeMember *> Globals);
unsigned getJumpTableEntrySize();
@@ -383,11 +410,13 @@ class LowerTypeTestsModule {
void buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds,
ArrayRef<GlobalTypeMember *> Functions);
void buildBitSetsFromFunctionsNative(ArrayRef<Metadata *> TypeIds,
- ArrayRef<GlobalTypeMember *> Functions);
+ ArrayRef<GlobalTypeMember *> Functions);
void buildBitSetsFromFunctionsWASM(ArrayRef<Metadata *> TypeIds,
ArrayRef<GlobalTypeMember *> Functions);
- void buildBitSetsFromDisjointSet(ArrayRef<Metadata *> TypeIds,
- ArrayRef<GlobalTypeMember *> Globals);
+ void
+ buildBitSetsFromDisjointSet(ArrayRef<Metadata *> TypeIds,
+ ArrayRef<GlobalTypeMember *> Globals,
+ ArrayRef<ICallBranchFunnel *> ICallBranchFunnels);
void replaceWeakDeclarationWithJumpTablePtr(Function *F, Constant *JT);
void moveInitializerToModuleConstructor(GlobalVariable *GV);
@@ -1462,7 +1491,8 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsWASM(
}
void LowerTypeTestsModule::buildBitSetsFromDisjointSet(
- ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals) {
+ ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals,
+ ArrayRef<ICallBranchFunnel *> ICallBranchFunnels) {
DenseMap<Metadata *, uint64_t> TypeIdIndices;
for (unsigned I = 0; I != TypeIds.size(); ++I)
TypeIdIndices[TypeIds[I]] = I;
@@ -1471,15 +1501,25 @@ void LowerTypeTestsModule::buildBitSetsFromDisjointSet(
// the type identifier.
std::vector<std::set<uint64_t>> TypeMembers(TypeIds.size());
unsigned GlobalIndex = 0;
+ DenseMap<GlobalTypeMember *, uint64_t> GlobalIndices;
for (GlobalTypeMember *GTM : Globals) {
for (MDNode *Type : GTM->types()) {
// Type = { offset, type identifier }
- unsigned TypeIdIndex = TypeIdIndices[Type->getOperand(1)];
- TypeMembers[TypeIdIndex].insert(GlobalIndex);
+ auto I = TypeIdIndices.find(Type->getOperand(1));
+ if (I != TypeIdIndices.end())
+ TypeMembers[I->second].insert(GlobalIndex);
}
+ GlobalIndices[GTM] = GlobalIndex;
GlobalIndex++;
}
+ for (ICallBranchFunnel *JT : ICallBranchFunnels) {
+ TypeMembers.emplace_back();
+ std::set<uint64_t> &TMSet = TypeMembers.back();
+ for (GlobalTypeMember *T : JT->targets())
+ TMSet.insert(GlobalIndices[T]);
+ }
+
// Order the sets of indices by size. The GlobalLayoutBuilder works best
// when given small index sets first.
std::stable_sort(
@@ -1567,8 +1607,11 @@ bool LowerTypeTestsModule::runForTesting(Module &M) {
bool LowerTypeTestsModule::lower() {
Function *TypeTestFunc =
M.getFunction(Intrinsic::getName(Intrinsic::type_test));
- if ((!TypeTestFunc || TypeTestFunc->use_empty()) && !ExportSummary &&
- !ImportSummary)
+ Function *ICallBranchFunnelFunc =
+ M.getFunction(Intrinsic::getName(Intrinsic::icall_branch_funnel));
+ if ((!TypeTestFunc || TypeTestFunc->use_empty()) &&
+ (!ICallBranchFunnelFunc || ICallBranchFunnelFunc->use_empty()) &&
+ !ExportSummary && !ImportSummary)
return false;
if (ImportSummary) {
@@ -1580,6 +1623,10 @@ bool LowerTypeTestsModule::lower() {
}
}
+ if (ICallBranchFunnelFunc && !ICallBranchFunnelFunc->use_empty())
+ report_fatal_error(
+ "unexpected call to llvm.icall.branch.funnel during import phase");
+
SmallVector<Function *, 8> Defs;
SmallVector<Function *, 8> Decls;
for (auto &F : M) {
@@ -1604,8 +1651,8 @@ bool LowerTypeTestsModule::lower() {
// Equivalence class set containing type identifiers and the globals that
// reference them. This is used to partition the set of type identifiers in
// the module into disjoint sets.
- using GlobalClassesTy =
- EquivalenceClasses<PointerUnion<GlobalTypeMember *, Metadata *>>;
+ using GlobalClassesTy = EquivalenceClasses<
+ PointerUnion3<GlobalTypeMember *, Metadata *, ICallBranchFunnel *>>;
GlobalClassesTy GlobalClasses;
// Verify the type metadata and build a few data structures to let us
@@ -1688,14 +1735,13 @@ bool LowerTypeTestsModule::lower() {
}
}
+ DenseMap<GlobalObject *, GlobalTypeMember *> GlobalTypeMembers;
for (GlobalObject &GO : M.global_objects()) {
if (isa<GlobalVariable>(GO) && GO.isDeclarationForLinker())
continue;
Types.clear();
GO.getMetadata(LLVMContext::MD_type, Types);
- if (Types.empty())
- continue;
bool IsDefinition = !GO.isDeclarationForLinker();
bool IsExported = false;
@@ -1706,6 +1752,7 @@ bool LowerTypeTestsModule::lower() {
auto *GTM =
GlobalTypeMember::create(Alloc, &GO, IsDefinition, IsExported, Types);
+ GlobalTypeMembers[&GO] = GTM;
for (MDNode *Type : Types) {
verifyTypeMDNode(&GO, Type);
auto &Info = TypeIdInfo[Type->getOperand(1)];
@@ -1746,6 +1793,43 @@ bool LowerTypeTestsModule::lower() {
}
}
+ if (ICallBranchFunnelFunc) {
+ for (const Use &U : ICallBranchFunnelFunc->uses()) {
+ if (Arch != Triple::x86_64)
+ report_fatal_error(
+ "llvm.icall.branch.funnel not supported on this target");
+
+ auto CI = cast<CallInst>(U.getUser());
+
+ std::vector<GlobalTypeMember *> Targets;
+ if (CI->getNumArgOperands() % 2 != 1)
+ report_fatal_error("number of arguments should be odd");
+
+ GlobalClassesTy::member_iterator CurSet;
+ for (unsigned I = 1; I != CI->getNumArgOperands(); I += 2) {
+ int64_t Offset;
+ auto *Base = dyn_cast<GlobalObject>(GetPointerBaseWithConstantOffset(
+ CI->getOperand(I), Offset, M.getDataLayout()));
+ if (!Base)
+ report_fatal_error(
+ "Expected branch funnel operand to be global value");
+
+ GlobalTypeMember *GTM = GlobalTypeMembers[Base];
+ Targets.push_back(GTM);
+ GlobalClassesTy::member_iterator NewSet =
+ GlobalClasses.findLeader(GlobalClasses.insert(GTM));
+ if (I == 1)
+ CurSet = NewSet;
+ else
+ CurSet = GlobalClasses.unionSets(CurSet, NewSet);
+ }
+
+ GlobalClasses.unionSets(
+ CurSet, GlobalClasses.findLeader(GlobalClasses.insert(
+ ICallBranchFunnel::create(Alloc, CI, Targets))));
+ }
+ }
+
if (ExportSummary) {
DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
for (auto &P : TypeIdInfo) {
@@ -1798,13 +1882,16 @@ bool LowerTypeTestsModule::lower() {
// Build the list of type identifiers in this disjoint set.
std::vector<Metadata *> TypeIds;
std::vector<GlobalTypeMember *> Globals;
+ std::vector<ICallBranchFunnel *> ICallBranchFunnels;
for (GlobalClassesTy::member_iterator MI =
GlobalClasses.member_begin(S.first);
MI != GlobalClasses.member_end(); ++MI) {
- if ((*MI).is<Metadata *>())
+ if (MI->is<Metadata *>())
TypeIds.push_back(MI->get<Metadata *>());
- else
+ else if (MI->is<GlobalTypeMember *>())
Globals.push_back(MI->get<GlobalTypeMember *>());
+ else
+ ICallBranchFunnels.push_back(MI->get<ICallBranchFunnel *>());
}
// Order type identifiers by global index for determinism. This ordering is
@@ -1814,7 +1901,7 @@ bool LowerTypeTestsModule::lower() {
});
// Build bitsets for this disjoint set.
- buildBitSetsFromDisjointSet(TypeIds, Globals);
+ buildBitSetsFromDisjointSet(TypeIds, Globals, ICallBranchFunnels);
}
allocateByteArrays();
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index aa1755bb097..a3aa7c42608 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -316,12 +316,17 @@ struct CallSiteInfo {
/// cases we are directly operating on the call sites at the IR level.
std::vector<VirtualCallSite> CallSites;
+ /// Whether all call sites represented by this CallSiteInfo, including those
+ /// in summaries, have been devirtualized. This starts off as true because a
+ /// default constructed CallSiteInfo represents no call sites.
+ bool AllCallSitesDevirted = true;
+
// These fields are used during the export phase of ThinLTO and reflect
// information collected from function summaries.
/// Whether any function summary contains an llvm.assume(llvm.type.test) for
/// this slot.
- bool SummaryHasTypeTestAssumeUsers;
+ bool SummaryHasTypeTestAssumeUsers = false;
/// CFI-specific: a vector containing the list of function summaries that use
/// the llvm.type.checked.load intrinsic and therefore will require
@@ -337,8 +342,22 @@ struct CallSiteInfo {
!SummaryTypeCheckedLoadUsers.empty();
}
- /// As explained in the comment for SummaryTypeCheckedLoadUsers.
- void markDevirt() { SummaryTypeCheckedLoadUsers.clear(); }
+ void markSummaryHasTypeTestAssumeUsers() {
+ SummaryHasTypeTestAssumeUsers = true;
+ AllCallSitesDevirted = false;
+ }
+
+ void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) {
+ SummaryTypeCheckedLoadUsers.push_back(FS);
+ AllCallSitesDevirted = false;
+ }
+
+ void markDevirt() {
+ AllCallSitesDevirted = true;
+
+ // As explained in the comment for SummaryTypeCheckedLoadUsers.
+ SummaryTypeCheckedLoadUsers.clear();
+ }
};
// Call site information collected for a specific VTableSlot.
@@ -373,7 +392,9 @@ CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
unsigned *NumUnsafeUses) {
- findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses});
+ auto &CSI = findCallSiteInfo(CS);
+ CSI.AllCallSitesDevirted = false;
+ CSI.CallSites.push_back({VTable, CS, NumUnsafeUses});
}
struct DevirtModule {
@@ -438,6 +459,12 @@ struct DevirtModule {
VTableSlotInfo &SlotInfo,
WholeProgramDevirtResolution *Res);
+ void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT,
+ bool &IsExported);
+ void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
+ VTableSlotInfo &SlotInfo,
+ WholeProgramDevirtResolution *Res, VTableSlot Slot);
+
bool tryEvaluateFunctionsWithArgs(
MutableArrayRef<VirtualCallTarget> TargetsForSlot,
ArrayRef<uint64_t> Args);
@@ -471,6 +498,8 @@ struct DevirtModule {
StringRef Name, IntegerType *IntTy,
uint32_t Storage);
+ Constant *getMemberAddr(const TypeMemberInfo *M);
+
void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
Constant *UniqueMemberAddr);
bool tryUniqueRetValOpt(unsigned BitWidth,
@@ -726,10 +755,9 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
if (VCallSite.NumUnsafeUses)
--*VCallSite.NumUnsafeUses;
}
- if (CSInfo.isExported()) {
+ if (CSInfo.isExported())
IsExported = true;
- CSInfo.markDevirt();
- }
+ CSInfo.markDevirt();
};
Apply(SlotInfo.CSInfo);
for (auto &P : SlotInfo.ConstCSInfo)
@@ -785,6 +813,134 @@ bool DevirtModule::trySingleImplDevirt(
return true;
}
+void DevirtModule::tryICallBranchFunnel(
+ MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
+ WholeProgramDevirtResolution *Res, VTableSlot Slot) {
+ Triple T(M.getTargetTriple());
+ if (T.getArch() != Triple::x86_64)
+ return;
+
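+ // Only build a funnel for slots with at most kBranchFunnelThreshold targets.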
+ const unsigned kBranchFunnelThreshold = 10;
+ if (TargetsForSlot.size() > kBranchFunnelThreshold)
+ return;
+
+ bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted;
+ if (!HasNonDevirt)
+ for (auto &P : SlotInfo.ConstCSInfo)
+ if (!P.second.AllCallSitesDevirted) {
+ HasNonDevirt = true;
+ break;
+ }
+
+ if (!HasNonDevirt)
+ return;
+
+ FunctionType *FT =
+ FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true);
+ Function *JT;
+ if (isa<MDString>(Slot.TypeID)) {
+ JT = Function::Create(FT, Function::ExternalLinkage,
+ getGlobalName(Slot, {}, "branch_funnel"), &M);
+ JT->setVisibility(GlobalValue::HiddenVisibility);
+ } else {
+ JT = Function::Create(FT, Function::InternalLinkage, "branch_funnel", &M);
+ }
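+ // The first funnel parameter carries the vtable address; marking it 'nest'
+ // makes it use a separate register (r10 on x86_64), leaving the normal
+ // argument registers for the forwarded call arguments.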
+ JT->addAttribute(1, Attribute::Nest);
+
+ std::vector<Value *> JTArgs;
+ JTArgs.push_back(JT->arg_begin());
+ for (auto &T : TargetsForSlot) {
+ JTArgs.push_back(getMemberAddr(T.TM));
+ JTArgs.push_back(T.Fn);
+ }
+
+ BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr);
+ Constant *Intr =
+ Intrinsic::getDeclaration(&M, llvm::Intrinsic::icall_branch_funnel, {});
+
+ auto *CI = CallInst::Create(Intr, JTArgs, "", BB);
+ CI->setTailCallKind(CallInst::TCK_MustTail);
+ ReturnInst::Create(M.getContext(), nullptr, BB);
+
+ bool IsExported = false;
+ applyICallBranchFunnel(SlotInfo, JT, IsExported);
+ if (IsExported)
+ Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;
+}
+
+void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
+ Constant *JT, bool &IsExported) {
+ auto Apply = [&](CallSiteInfo &CSInfo) {
+ if (CSInfo.isExported())
+ IsExported = true;
+ if (CSInfo.AllCallSitesDevirted)
+ return;
+ for (auto &&VCallSite : CSInfo.CallSites) {
+ CallSite CS = VCallSite.CS;
+
+ // Jump tables are only profitable if the retpoline mitigation is enabled.
+ Attribute FSAttr = CS.getCaller()->getFnAttribute("target-features");
+ if (FSAttr.hasAttribute(Attribute::None) ||
+ !FSAttr.getValueAsString().contains("+retpoline"))
+ continue;
+
+ if (RemarksEnabled)
+ VCallSite.emitRemark("branch-funnel", JT->getName(), OREGetter);
+
+ // Pass the address of the vtable in the nest register, which is r10 on
+ // x86_64.
+ std::vector<Type *> NewArgs;
+ NewArgs.push_back(Int8PtrTy);
+ for (Type *T : CS.getFunctionType()->params())
+ NewArgs.push_back(T);
+ PointerType *NewFT = PointerType::getUnqual(
+ FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs,
+ CS.getFunctionType()->isVarArg()));
+
+ IRBuilder<> IRB(CS.getInstruction());
+ std::vector<Value *> Args;
+ Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
+ for (unsigned I = 0; I != CS.getNumArgOperands(); ++I)
+ Args.push_back(CS.getArgOperand(I));
+
+ CallSite NewCS;
+ if (CS.isCall())
+ NewCS = IRB.CreateCall(IRB.CreateBitCast(JT, NewFT), Args);
+ else
+ NewCS = IRB.CreateInvoke(
+ IRB.CreateBitCast(JT, NewFT),
+ cast<InvokeInst>(CS.getInstruction())->getNormalDest(),
+ cast<InvokeInst>(CS.getInstruction())->getUnwindDest(), Args);
+ NewCS.setCallingConv(CS.getCallingConv());
+
+ AttributeList Attrs = CS.getAttributes();
+ std::vector<AttributeSet> NewArgAttrs;
+ NewArgAttrs.push_back(AttributeSet::get(
+ M.getContext(), ArrayRef<Attribute>{Attribute::get(
+ M.getContext(), Attribute::Nest)}));
+ for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I)
+ NewArgAttrs.push_back(Attrs.getParamAttributes(I));
+ NewCS.setAttributes(
+ AttributeList::get(M.getContext(), Attrs.getFnAttributes(),
+ Attrs.getRetAttributes(), NewArgAttrs));
+
+ CS->replaceAllUsesWith(NewCS.getInstruction());
+ CS->eraseFromParent();
+
+ // This use is no longer unsafe.
+ if (VCallSite.NumUnsafeUses)
+ --*VCallSite.NumUnsafeUses;
+ }
+ // Don't mark as devirtualized because there may be callers compiled without
+ // retpoline mitigation, which would mean that they are lowered to
+ // llvm.type.test and therefore require an llvm.type.test resolution for the
+ // type identifier.
+ };
+ Apply(SlotInfo.CSInfo);
+ for (auto &P : SlotInfo.ConstCSInfo)
+ Apply(P.second);
+}
+
bool DevirtModule::tryEvaluateFunctionsWithArgs(
MutableArrayRef<VirtualCallTarget> TargetsForSlot,
ArrayRef<uint64_t> Args) {
@@ -937,6 +1093,12 @@ void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
CSInfo.markDevirt();
}
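+// Return the address of type member M as an i8*: the address of its vtable
+// global plus M's byte offset.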
+Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) {
+ Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy);
+ return ConstantExpr::getGetElementPtr(Int8Ty, C,
+ ConstantInt::get(Int64Ty, M->Offset));
+}
+
bool DevirtModule::tryUniqueRetValOpt(
unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res,
@@ -956,12 +1118,7 @@ bool DevirtModule::tryUniqueRetValOpt(
// checked for a uniform return value in tryUniformRetValOpt.
assert(UniqueMember);
- Constant *UniqueMemberAddr =
- ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy);
- UniqueMemberAddr = ConstantExpr::getGetElementPtr(
- Int8Ty, UniqueMemberAddr,
- ConstantInt::get(Int64Ty, UniqueMember->Offset));
-
+ Constant *UniqueMemberAddr = getMemberAddr(UniqueMember);
if (CSInfo.isExported()) {
Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal;
Res->Info = IsOne;
@@ -1348,6 +1505,14 @@ void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
break;
}
}
+
+ if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
+ auto *JT = M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"),
+ Type::getVoidTy(M.getContext()));
+ bool IsExported = false;
+ applyICallBranchFunnel(SlotInfo, JT, IsExported);
+ assert(!IsExported);
+ }
}
void DevirtModule::removeRedundantTypeTests() {
@@ -1417,14 +1582,13 @@ bool DevirtModule::run() {
// FIXME: Only add live functions.
for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
for (Metadata *MD : MetadataByGUID[VF.GUID]) {
- CallSlots[{MD, VF.Offset}].CSInfo.SummaryHasTypeTestAssumeUsers =
- true;
+ CallSlots[{MD, VF.Offset}]
+ .CSInfo.markSummaryHasTypeTestAssumeUsers();
}
}
for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
for (Metadata *MD : MetadataByGUID[VF.GUID]) {
- CallSlots[{MD, VF.Offset}]
- .CSInfo.SummaryTypeCheckedLoadUsers.push_back(FS);
+ CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
}
}
for (const FunctionSummary::ConstVCall &VC :
@@ -1432,7 +1596,7 @@ bool DevirtModule::run() {
for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
CallSlots[{MD, VC.VFunc.Offset}]
.ConstCSInfo[VC.Args]
- .SummaryHasTypeTestAssumeUsers = true;
+ .markSummaryHasTypeTestAssumeUsers();
}
}
for (const FunctionSummary::ConstVCall &VC :
@@ -1440,7 +1604,7 @@ bool DevirtModule::run() {
for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
CallSlots[{MD, VC.VFunc.Offset}]
.ConstCSInfo[VC.Args]
- .SummaryTypeCheckedLoadUsers.push_back(FS);
+ .addSummaryTypeCheckedLoadUser(FS);
}
}
}
@@ -1464,9 +1628,12 @@ bool DevirtModule::run() {
cast<MDString>(S.first.TypeID)->getString())
.WPDRes[S.first.ByteOffset];
- if (!trySingleImplDevirt(TargetsForSlot, S.second, Res) &&
- tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first))
- DidVirtualConstProp = true;
+ if (!trySingleImplDevirt(TargetsForSlot, S.second, Res)) {
+ DidVirtualConstProp |=
+ tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first);
+
+ tryICallBranchFunnel(TargetsForSlot, S.second, Res, S.first);
+ }
// Collect functions devirtualized at least for one call site for stats.
if (RemarksEnabled)