author    Peter Collingbourne <peter@pcc.me.uk>  2018-03-09 19:11:44 +0000
committer Peter Collingbourne <peter@pcc.me.uk>  2018-03-09 19:11:44 +0000
commit    2974856ad4326989052f04299affaa516985e77a (patch)
tree      9be1d673e3f12d12980c7a093f7c12554c44ace4 /llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
parent    dee18b82c23c1637bdfad001bfe8c62bdf8c5955 (diff)
Use branch funnels for virtual calls when retpoline mitigation is enabled.
The retpoline mitigation for variant 2 of CVE-2017-5715 inhibits the branch predictor, and as a result it can lead to a measurable loss of performance. We can reduce the performance impact of retpolined virtual calls by replacing them with a special construct known as a branch funnel, which is an instruction sequence that implements virtual calls to a set of known targets using a binary tree of direct branches. This allows the processor to speculatively execute valid implementations of the virtual function without allowing speculative execution of calls to arbitrary addresses.

This patch extends the whole-program devirtualization pass to replace certain virtual calls with calls to branch funnels, which are represented using a new llvm.icall.branch.funnel intrinsic. It also extends the LowerTypeTests pass to recognize the new intrinsic, generate code for the branch funnels (x86_64 only for now) and lay out virtual tables as required for each branch funnel.

The implementation supports full LTO as well as ThinLTO, and extends the ThinLTO summary format used for whole-program devirtualization to support branch funnels.

For more details see the RFC: http://lists.llvm.org/pipermail/llvm-dev/2018-January/120672.html

Differential Revision: https://reviews.llvm.org/D42453

llvm-svn: 327163
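To illustrate the idea, here is a minimal C++ sketch of what a branch funnel does conceptually. This is not the code the pass generates (LowerTypeTests emits the equivalent instruction sequence directly in machine code, with the vtable address passed in the nest register, r10 on x86_64); the names Base, f1..f3, VT1..VT3 and funnel are illustrative only:

#include <cstdint>

struct Base { virtual int f(int) = 0; };

// Statically known implementations of Base::f, one per derived class,
// discovered by whole-program devirtualization.
int f1(Base *, int);
int f2(Base *, int);
int f3(Base *, int);

// Addresses of the corresponding vtables, ordered so the funnel can
// binary-search them.
extern const std::uintptr_t VT1, VT2, VT3;

// The funnel receives the vtable address and selects a target using a
// binary tree of compares and direct branches. Whole-program analysis
// guarantees vt is one of the known vtable addresses. The processor may
// speculate into any of the direct calls, but never into an arbitrary
// indirect target, so no retpoline is needed on this path.
int funnel(std::uintptr_t vt, Base *obj, int arg) {
  if (vt <= VT2) {
    if (vt == VT1) return f1(obj, arg);  // direct call
    return f2(obj, arg);                 // direct call
  }
  return f3(obj, arg);                   // direct call
}

At a call site compiled with +retpoline, the pass rewrites the indirect virtual call into a direct call to the funnel, passing the vtable address as an extra nest argument, as applyICallBranchFunnel does in the diff below.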
Diffstat (limited to 'llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp')
-rw-r--r--  llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp | 211
1 file changed, 189 insertions(+), 22 deletions(-)
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index aa1755bb097..a3aa7c42608 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -316,12 +316,17 @@ struct CallSiteInfo {
/// cases we are directly operating on the call sites at the IR level.
std::vector<VirtualCallSite> CallSites;
+ /// Whether all call sites represented by this CallSiteInfo, including those
+ /// in summaries, have been devirtualized. This starts off as true because a
+ /// default constructed CallSiteInfo represents no call sites.
+ bool AllCallSitesDevirted = true;
+
// These fields are used during the export phase of ThinLTO and reflect
// information collected from function summaries.
/// Whether any function summary contains an llvm.assume(llvm.type.test) for
/// this slot.
- bool SummaryHasTypeTestAssumeUsers;
+ bool SummaryHasTypeTestAssumeUsers = false;
/// CFI-specific: a vector containing the list of function summaries that use
/// the llvm.type.checked.load intrinsic and therefore will require
@@ -337,8 +342,22 @@ struct CallSiteInfo {
!SummaryTypeCheckedLoadUsers.empty();
}
- /// As explained in the comment for SummaryTypeCheckedLoadUsers.
- void markDevirt() { SummaryTypeCheckedLoadUsers.clear(); }
+ void markSummaryHasTypeTestAssumeUsers() {
+ SummaryHasTypeTestAssumeUsers = true;
+ AllCallSitesDevirted = false;
+ }
+
+ void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) {
+ SummaryTypeCheckedLoadUsers.push_back(FS);
+ AllCallSitesDevirted = false;
+ }
+
+ void markDevirt() {
+ AllCallSitesDevirted = true;
+
+ // As explained in the comment for SummaryTypeCheckedLoadUsers.
+ SummaryTypeCheckedLoadUsers.clear();
+ }
};
// Call site information collected for a specific VTableSlot.
@@ -373,7 +392,9 @@ CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
unsigned *NumUnsafeUses) {
- findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses});
+ auto &CSI = findCallSiteInfo(CS);
+ CSI.AllCallSitesDevirted = false;
+ CSI.CallSites.push_back({VTable, CS, NumUnsafeUses});
}
struct DevirtModule {
@@ -438,6 +459,12 @@ struct DevirtModule {
VTableSlotInfo &SlotInfo,
WholeProgramDevirtResolution *Res);
+ void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT,
+ bool &IsExported);
+ void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
+ VTableSlotInfo &SlotInfo,
+ WholeProgramDevirtResolution *Res, VTableSlot Slot);
+
bool tryEvaluateFunctionsWithArgs(
MutableArrayRef<VirtualCallTarget> TargetsForSlot,
ArrayRef<uint64_t> Args);
@@ -471,6 +498,8 @@ struct DevirtModule {
StringRef Name, IntegerType *IntTy,
uint32_t Storage);
+ Constant *getMemberAddr(const TypeMemberInfo *M);
+
void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
Constant *UniqueMemberAddr);
bool tryUniqueRetValOpt(unsigned BitWidth,
@@ -726,10 +755,9 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
if (VCallSite.NumUnsafeUses)
--*VCallSite.NumUnsafeUses;
}
- if (CSInfo.isExported()) {
+ if (CSInfo.isExported())
IsExported = true;
- CSInfo.markDevirt();
- }
+ CSInfo.markDevirt();
};
Apply(SlotInfo.CSInfo);
for (auto &P : SlotInfo.ConstCSInfo)
@@ -785,6 +813,134 @@ bool DevirtModule::trySingleImplDevirt(
return true;
}
+void DevirtModule::tryICallBranchFunnel(
+ MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
+ WholeProgramDevirtResolution *Res, VTableSlot Slot) {
+ Triple T(M.getTargetTriple());
+ if (T.getArch() != Triple::x86_64)
+ return;
+
+ const unsigned kBranchFunnelThreshold = 10;
+ if (TargetsForSlot.size() > kBranchFunnelThreshold)
+ return;
+
+ bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted;
+ if (!HasNonDevirt)
+ for (auto &P : SlotInfo.ConstCSInfo)
+ if (!P.second.AllCallSitesDevirted) {
+ HasNonDevirt = true;
+ break;
+ }
+
+ if (!HasNonDevirt)
+ return;
+
+ FunctionType *FT =
+ FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true);
+ Function *JT;
+ if (isa<MDString>(Slot.TypeID)) {
+ JT = Function::Create(FT, Function::ExternalLinkage,
+ getGlobalName(Slot, {}, "branch_funnel"), &M);
+ JT->setVisibility(GlobalValue::HiddenVisibility);
+ } else {
+ JT = Function::Create(FT, Function::InternalLinkage, "branch_funnel", &M);
+ }
+ JT->addAttribute(1, Attribute::Nest);
+
+ std::vector<Value *> JTArgs;
+ JTArgs.push_back(JT->arg_begin());
+ for (auto &T : TargetsForSlot) {
+ JTArgs.push_back(getMemberAddr(T.TM));
+ JTArgs.push_back(T.Fn);
+ }
+
+ BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr);
+ Constant *Intr =
+ Intrinsic::getDeclaration(&M, llvm::Intrinsic::icall_branch_funnel, {});
+
+ auto *CI = CallInst::Create(Intr, JTArgs, "", BB);
+ CI->setTailCallKind(CallInst::TCK_MustTail);
+ ReturnInst::Create(M.getContext(), nullptr, BB);
+
+ bool IsExported = false;
+ applyICallBranchFunnel(SlotInfo, JT, IsExported);
+ if (IsExported)
+ Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;
+}
+
+void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
+ Constant *JT, bool &IsExported) {
+ auto Apply = [&](CallSiteInfo &CSInfo) {
+ if (CSInfo.isExported())
+ IsExported = true;
+ if (CSInfo.AllCallSitesDevirted)
+ return;
+ for (auto &&VCallSite : CSInfo.CallSites) {
+ CallSite CS = VCallSite.CS;
+
+ // Jump tables are only profitable if the retpoline mitigation is enabled.
+ Attribute FSAttr = CS.getCaller()->getFnAttribute("target-features");
+ if (FSAttr.hasAttribute(Attribute::None) ||
+ !FSAttr.getValueAsString().contains("+retpoline"))
+ continue;
+
+ if (RemarksEnabled)
+ VCallSite.emitRemark("branch-funnel", JT->getName(), OREGetter);
+
+ // Pass the address of the vtable in the nest register, which is r10 on
+ // x86_64.
+ std::vector<Type *> NewArgs;
+ NewArgs.push_back(Int8PtrTy);
+ for (Type *T : CS.getFunctionType()->params())
+ NewArgs.push_back(T);
+ PointerType *NewFT = PointerType::getUnqual(
+ FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs,
+ CS.getFunctionType()->isVarArg()));
+
+ IRBuilder<> IRB(CS.getInstruction());
+ std::vector<Value *> Args;
+ Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
+ for (unsigned I = 0; I != CS.getNumArgOperands(); ++I)
+ Args.push_back(CS.getArgOperand(I));
+
+ CallSite NewCS;
+ if (CS.isCall())
+ NewCS = IRB.CreateCall(IRB.CreateBitCast(JT, NewFT), Args);
+ else
+ NewCS = IRB.CreateInvoke(
+ IRB.CreateBitCast(JT, NewFT),
+ cast<InvokeInst>(CS.getInstruction())->getNormalDest(),
+ cast<InvokeInst>(CS.getInstruction())->getUnwindDest(), Args);
+ NewCS.setCallingConv(CS.getCallingConv());
+
+ AttributeList Attrs = CS.getAttributes();
+ std::vector<AttributeSet> NewArgAttrs;
+ NewArgAttrs.push_back(AttributeSet::get(
+ M.getContext(), ArrayRef<Attribute>{Attribute::get(
+ M.getContext(), Attribute::Nest)}));
+ for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I)
+ NewArgAttrs.push_back(Attrs.getParamAttributes(I));
+ NewCS.setAttributes(
+ AttributeList::get(M.getContext(), Attrs.getFnAttributes(),
+ Attrs.getRetAttributes(), NewArgAttrs));
+
+ CS->replaceAllUsesWith(NewCS.getInstruction());
+ CS->eraseFromParent();
+
+ // This use is no longer unsafe.
+ if (VCallSite.NumUnsafeUses)
+ --*VCallSite.NumUnsafeUses;
+ }
+ // Don't mark as devirtualized because there may be callers compiled without
+ // retpoline mitigation, which would mean that they are lowered to
+ // llvm.type.test and therefore require an llvm.type.test resolution for the
+ // type identifier.
+ };
+ Apply(SlotInfo.CSInfo);
+ for (auto &P : SlotInfo.ConstCSInfo)
+ Apply(P.second);
+}
+
bool DevirtModule::tryEvaluateFunctionsWithArgs(
MutableArrayRef<VirtualCallTarget> TargetsForSlot,
ArrayRef<uint64_t> Args) {
@@ -937,6 +1093,12 @@ void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
CSInfo.markDevirt();
}
+Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) {
+ Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy);
+ return ConstantExpr::getGetElementPtr(Int8Ty, C,
+ ConstantInt::get(Int64Ty, M->Offset));
+}
+
bool DevirtModule::tryUniqueRetValOpt(
unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res,
@@ -956,12 +1118,7 @@ bool DevirtModule::tryUniqueRetValOpt(
// checked for a uniform return value in tryUniformRetValOpt.
assert(UniqueMember);
- Constant *UniqueMemberAddr =
- ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy);
- UniqueMemberAddr = ConstantExpr::getGetElementPtr(
- Int8Ty, UniqueMemberAddr,
- ConstantInt::get(Int64Ty, UniqueMember->Offset));
-
+ Constant *UniqueMemberAddr = getMemberAddr(UniqueMember);
if (CSInfo.isExported()) {
Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal;
Res->Info = IsOne;
@@ -1348,6 +1505,14 @@ void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
break;
}
}
+
+ if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
+ auto *JT = M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"),
+ Type::getVoidTy(M.getContext()));
+ bool IsExported = false;
+ applyICallBranchFunnel(SlotInfo, JT, IsExported);
+ assert(!IsExported);
+ }
}
void DevirtModule::removeRedundantTypeTests() {
@@ -1417,14 +1582,13 @@ bool DevirtModule::run() {
// FIXME: Only add live functions.
for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
for (Metadata *MD : MetadataByGUID[VF.GUID]) {
- CallSlots[{MD, VF.Offset}].CSInfo.SummaryHasTypeTestAssumeUsers =
- true;
+ CallSlots[{MD, VF.Offset}]
+ .CSInfo.markSummaryHasTypeTestAssumeUsers();
}
}
for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
for (Metadata *MD : MetadataByGUID[VF.GUID]) {
- CallSlots[{MD, VF.Offset}]
- .CSInfo.SummaryTypeCheckedLoadUsers.push_back(FS);
+ CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
}
}
for (const FunctionSummary::ConstVCall &VC :
@@ -1432,7 +1596,7 @@ bool DevirtModule::run() {
for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
CallSlots[{MD, VC.VFunc.Offset}]
.ConstCSInfo[VC.Args]
- .SummaryHasTypeTestAssumeUsers = true;
+ .markSummaryHasTypeTestAssumeUsers();
}
}
for (const FunctionSummary::ConstVCall &VC :
@@ -1440,7 +1604,7 @@ bool DevirtModule::run() {
for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
CallSlots[{MD, VC.VFunc.Offset}]
.ConstCSInfo[VC.Args]
- .SummaryTypeCheckedLoadUsers.push_back(FS);
+ .addSummaryTypeCheckedLoadUser(FS);
}
}
}
@@ -1464,9 +1628,12 @@ bool DevirtModule::run() {
cast<MDString>(S.first.TypeID)->getString())
.WPDRes[S.first.ByteOffset];
- if (!trySingleImplDevirt(TargetsForSlot, S.second, Res) &&
- tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first))
- DidVirtualConstProp = true;
+ if (!trySingleImplDevirt(TargetsForSlot, S.second, Res)) {
+ DidVirtualConstProp |=
+ tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first);
+
+ tryICallBranchFunnel(TargetsForSlot, S.second, Res, S.first);
+ }
// Collect functions devirtualized at least for one call site for stats.
if (RemarksEnabled)