diff options
author | Peter Collingbourne <peter@pcc.me.uk> | 2018-06-26 02:15:47 +0000 |
---|---|---|
committer | Peter Collingbourne <peter@pcc.me.uk> | 2018-06-26 02:15:47 +0000 |
commit | e44acadf6ab0ec228dcba2e02b495c1dfe7338dd (patch) | |
tree | 9bee8e35f720ae18caf51815d1b2f0a1afffa08e /clang/lib/CodeGen | |
parent | 689e363ff2294f43d2ee35d08777f7a7ea0ce7dd (diff) | |
download | bcm5719-llvm-e44acadf6ab0ec228dcba2e02b495c1dfe7338dd.tar.gz bcm5719-llvm-e44acadf6ab0ec228dcba2e02b495c1dfe7338dd.zip |
Implement CFI for indirect calls via a member function pointer.
Similarly to CFI on virtual and indirect calls, this implementation
tries to use program type information to make the checks as precise
as possible. The basic way that it works is as follows, where `C`
is the name of the class being defined or the target of a call and
the function type is assumed to be `void()`.
For virtual calls:
- Attach type metadata to the addresses of function pointers in vtables
(not the functions themselves) of type `void (B::*)()` for each `B`
that is a recursive dynamic base class of `C`, including `C` itself.
This type metadata has an annotation that the type is for virtual
calls (to distinguish it from the non-virtual case).
- At the call site, check that the computed address of the function
pointer in the vtable has type `void (C::*)()`.
For non-virtual calls:
- Attach type metadata to each non-virtual member function whose address
can be taken with a member function pointer. The type of a function
in class `C` of type `void()` is each of the types `void (B::*)()`
where `B` is a most-base class of `C`. A most-base class of `C`
is defined as a recursive base class of `C`, including `C` itself,
that does not have any bases.
- At the call site, check that the function pointer has one of the types
`void (B::*)()` where `B` is a most-base class of `C`.
Differential Revision: https://reviews.llvm.org/D47567
llvm-svn: 335569
Diffstat (limited to 'clang/lib/CodeGen')
-rw-r--r-- | clang/lib/CodeGen/CGClass.cpp | 4 | ||||
-rw-r--r-- | clang/lib/CodeGen/CGVTables.cpp | 43 | ||||
-rw-r--r-- | clang/lib/CodeGen/CodeGenFunction.h | 2 | ||||
-rw-r--r-- | clang/lib/CodeGen/CodeGenModule.cpp | 89 | ||||
-rw-r--r-- | clang/lib/CodeGen/CodeGenModule.h | 19 | ||||
-rw-r--r-- | clang/lib/CodeGen/ItaniumCXXABI.cpp | 87 |
6 files changed, 197 insertions, 47 deletions
diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp index 11a327d2d2f..0b9311f7771 100644 --- a/clang/lib/CodeGen/CGClass.cpp +++ b/clang/lib/CodeGen/CGClass.cpp @@ -2688,7 +2688,9 @@ void CodeGenFunction::EmitVTablePtrCheck(const CXXRecordDecl *RD, SSK = llvm::SanStat_CFI_UnrelatedCast; break; case CFITCK_ICall: - llvm_unreachable("not expecting CFITCK_ICall"); + case CFITCK_NVMFCall: + case CFITCK_VMFCall: + llvm_unreachable("unexpected sanitizer kind"); } std::string TypeName = RD->getQualifiedNameAsString(); diff --git a/clang/lib/CodeGen/CGVTables.cpp b/clang/lib/CodeGen/CGVTables.cpp index 86bceb6ee9f..5a2ec65f776 100644 --- a/clang/lib/CodeGen/CGVTables.cpp +++ b/clang/lib/CodeGen/CGVTables.cpp @@ -1012,30 +1012,29 @@ void CodeGenModule::EmitVTableTypeMetadata(llvm::GlobalVariable *VTable, CharUnits PointerWidth = Context.toCharUnitsFromBits(Context.getTargetInfo().getPointerWidth(0)); - typedef std::pair<const CXXRecordDecl *, unsigned> TypeMetadata; - std::vector<TypeMetadata> TypeMetadatas; - // Create type metadata for each address point. + typedef std::pair<const CXXRecordDecl *, unsigned> AddressPoint; + std::vector<AddressPoint> AddressPoints; for (auto &&AP : VTLayout.getAddressPoints()) - TypeMetadatas.push_back(std::make_pair( + AddressPoints.push_back(std::make_pair( AP.first.getBase(), VTLayout.getVTableOffset(AP.second.VTableIndex) + AP.second.AddressPointIndex)); - // Sort the type metadata for determinism. - llvm::sort(TypeMetadatas.begin(), TypeMetadatas.end(), - [this](const TypeMetadata &M1, const TypeMetadata &M2) { - if (&M1 == &M2) + // Sort the address points for determinism. + llvm::sort(AddressPoints.begin(), AddressPoints.end(), + [this](const AddressPoint &AP1, const AddressPoint &AP2) { + if (&AP1 == &AP2) return false; std::string S1; llvm::raw_string_ostream O1(S1); getCXXABI().getMangleContext().mangleTypeName( - QualType(M1.first->getTypeForDecl(), 0), O1); + QualType(AP1.first->getTypeForDecl(), 0), O1); O1.flush(); std::string S2; llvm::raw_string_ostream O2(S2); getCXXABI().getMangleContext().mangleTypeName( - QualType(M2.first->getTypeForDecl(), 0), O2); + QualType(AP2.first->getTypeForDecl(), 0), O2); O2.flush(); if (S1 < S2) @@ -1043,10 +1042,26 @@ void CodeGenModule::EmitVTableTypeMetadata(llvm::GlobalVariable *VTable, if (S1 != S2) return false; - return M1.second < M2.second; + return AP1.second < AP2.second; }); - for (auto TypeMetadata : TypeMetadatas) - AddVTableTypeMetadata(VTable, PointerWidth * TypeMetadata.second, - TypeMetadata.first); + ArrayRef<VTableComponent> Comps = VTLayout.vtable_components(); + for (auto AP : AddressPoints) { + // Create type metadata for the address point. + AddVTableTypeMetadata(VTable, PointerWidth * AP.second, AP.first); + + // The class associated with each address point could also potentially be + // used for indirect calls via a member function pointer, so we need to + // annotate the address of each function pointer with the appropriate member + // function pointer type. + for (unsigned I = 0; I != Comps.size(); ++I) { + if (Comps[I].getKind() != VTableComponent::CK_FunctionPointer) + continue; + llvm::Metadata *MD = CreateMetadataIdentifierForVirtualMemPtrType( + Context.getMemberPointerType( + Comps[I].getFunctionDecl()->getType(), + Context.getRecordType(AP.first).getTypePtr())); + VTable->addTypeMetadata((PointerWidth * I).getQuantity(), MD); + } + } } diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index afe199c4eb5..548a4178ef2 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -1765,6 +1765,8 @@ public: CFITCK_DerivedCast, CFITCK_UnrelatedCast, CFITCK_ICall, + CFITCK_NVMFCall, + CFITCK_VMFCall, }; /// Derived is the presumed address of an object of type T after a diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp index 5a2f2a01d39..35e9dea37a9 100644 --- a/clang/lib/CodeGen/CodeGenModule.cpp +++ b/clang/lib/CodeGen/CodeGenModule.cpp @@ -1132,6 +1132,34 @@ static bool hasUnwindExceptions(const LangOptions &LangOpts) { return true; } +static bool requiresMemberFunctionPointerTypeMetadata(CodeGenModule &CGM, + const CXXMethodDecl *MD) { + // Check that the type metadata can ever actually be used by a call. + if (!CGM.getCodeGenOpts().LTOUnit || + !CGM.HasHiddenLTOVisibility(MD->getParent())) + return false; + + // Only functions whose address can be taken with a member function pointer + // need this sort of type metadata. + return !MD->isStatic() && !MD->isVirtual() && !isa<CXXConstructorDecl>(MD) && + !isa<CXXDestructorDecl>(MD); +} + +std::vector<const CXXRecordDecl *> +CodeGenModule::getMostBaseClasses(const CXXRecordDecl *RD) { + llvm::SetVector<const CXXRecordDecl *> MostBases; + + std::function<void (const CXXRecordDecl *)> CollectMostBases; + CollectMostBases = [&](const CXXRecordDecl *RD) { + if (RD->getNumBases() == 0) + MostBases.insert(RD); + for (const CXXBaseSpecifier &B : RD->bases()) + CollectMostBases(B.getType()->getAsCXXRecordDecl()); + }; + CollectMostBases(RD); + return MostBases.takeVector(); +} + void CodeGenModule::SetLLVMFunctionAttributesForDefinition(const Decl *D, llvm::Function *F) { llvm::AttrBuilder B; @@ -1257,7 +1285,20 @@ void CodeGenModule::SetLLVMFunctionAttributesForDefinition(const Decl *D, // In the cross-dso CFI mode, we want !type attributes on definitions only. if (CodeGenOpts.SanitizeCfiCrossDso) if (auto *FD = dyn_cast<FunctionDecl>(D)) - CreateFunctionTypeMetadata(FD, F); + CreateFunctionTypeMetadataForIcall(FD, F); + + // Emit type metadata on member functions for member function pointer checks. + // These are only ever necessary on definitions; we're guaranteed that the + // definition will be present in the LTO unit as a result of LTO visibility. + auto *MD = dyn_cast<CXXMethodDecl>(D); + if (MD && requiresMemberFunctionPointerTypeMetadata(*this, MD)) { + for (const CXXRecordDecl *Base : getMostBaseClasses(MD->getParent())) { + llvm::Metadata *Id = + CreateMetadataIdentifierForType(Context.getMemberPointerType( + MD->getType(), Context.getRecordType(Base).getTypePtr())); + F->addTypeMetadata(0, Id); + } + } } void CodeGenModule::SetCommonAttributes(GlobalDecl GD, llvm::GlobalValue *GV) { @@ -1378,13 +1419,14 @@ static void setLinkageForGV(llvm::GlobalValue *GV, const NamedDecl *ND) { GV->setLinkage(llvm::GlobalValue::ExternalWeakLinkage); } -void CodeGenModule::CreateFunctionTypeMetadata(const FunctionDecl *FD, - llvm::Function *F) { +void CodeGenModule::CreateFunctionTypeMetadataForIcall(const FunctionDecl *FD, + llvm::Function *F) { // Only if we are checking indirect calls. if (!LangOpts.Sanitize.has(SanitizerKind::CFIICall)) return; - // Non-static class methods are handled via vtable pointer checks elsewhere. + // Non-static class methods are handled via vtable or member function pointer + // checks elsewhere. if (isa<CXXMethodDecl>(FD) && !cast<CXXMethodDecl>(FD)->isStatic()) return; @@ -1476,7 +1518,7 @@ void CodeGenModule::SetFunctionAttributes(GlobalDecl GD, llvm::Function *F, // Don't emit entries for function declarations in the cross-DSO mode. This // is handled with better precision by the receiving DSO. if (!CodeGenOpts.SanitizeCfiCrossDso) - CreateFunctionTypeMetadata(FD, F); + CreateFunctionTypeMetadataForIcall(FD, F); if (getLangOpts().OpenMP && FD->hasAttr<OMPDeclareSimdDeclAttr>()) getOpenMPRuntime().emitDeclareSimdFunction(FD, F); @@ -4925,8 +4967,10 @@ void CodeGenModule::EmitOMPThreadPrivateDecl(const OMPThreadPrivateDecl *D) { } } -llvm::Metadata *CodeGenModule::CreateMetadataIdentifierForType(QualType T) { - llvm::Metadata *&InternalId = MetadataIdMap[T.getCanonicalType()]; +llvm::Metadata * +CodeGenModule::CreateMetadataIdentifierImpl(QualType T, MetadataTypeMap &Map, + StringRef Suffix) { + llvm::Metadata *&InternalId = Map[T.getCanonicalType()]; if (InternalId) return InternalId; @@ -4934,6 +4978,7 @@ llvm::Metadata *CodeGenModule::CreateMetadataIdentifierForType(QualType T) { std::string OutName; llvm::raw_string_ostream Out(OutName); getCXXABI().getMangleContext().mangleTypeName(T, Out); + Out << Suffix; InternalId = llvm::MDString::get(getLLVMContext(), Out.str()); } else { @@ -4944,6 +4989,15 @@ llvm::Metadata *CodeGenModule::CreateMetadataIdentifierForType(QualType T) { return InternalId; } +llvm::Metadata *CodeGenModule::CreateMetadataIdentifierForType(QualType T) { + return CreateMetadataIdentifierImpl(T, MetadataIdMap, ""); +} + +llvm::Metadata * +CodeGenModule::CreateMetadataIdentifierForVirtualMemPtrType(QualType T) { + return CreateMetadataIdentifierImpl(T, VirtualMetadataIdMap, ".virtual"); +} + // Generalize pointer types to a void pointer with the qualifiers of the // originally pointed-to type, e.g. 'const char *' and 'char * const *' // generalize to 'const void *' while 'char *' and 'const char **' generalize to @@ -4977,25 +5031,8 @@ static QualType GeneralizeFunctionType(ASTContext &Ctx, QualType Ty) { } llvm::Metadata *CodeGenModule::CreateMetadataIdentifierGeneralized(QualType T) { - T = GeneralizeFunctionType(getContext(), T); - - llvm::Metadata *&InternalId = GeneralizedMetadataIdMap[T.getCanonicalType()]; - if (InternalId) - return InternalId; - - if (isExternallyVisible(T->getLinkage())) { - std::string OutName; - llvm::raw_string_ostream Out(OutName); - getCXXABI().getMangleContext().mangleTypeName(T, Out); - Out << ".generalized"; - - InternalId = llvm::MDString::get(getLLVMContext(), Out.str()); - } else { - InternalId = llvm::MDNode::getDistinct(getLLVMContext(), - llvm::ArrayRef<llvm::Metadata *>()); - } - - return InternalId; + return CreateMetadataIdentifierImpl(GeneralizeFunctionType(getContext(), T), + GeneralizedMetadataIdMap, ".generalized"); } /// Returns whether this module needs the "all-vtables" type identifier. diff --git a/clang/lib/CodeGen/CodeGenModule.h b/clang/lib/CodeGen/CodeGenModule.h index bf22ad246c8..d2c7b327f98 100644 --- a/clang/lib/CodeGen/CodeGenModule.h +++ b/clang/lib/CodeGen/CodeGenModule.h @@ -503,6 +503,7 @@ private: /// MDNodes. typedef llvm::DenseMap<QualType, llvm::Metadata *> MetadataTypeMap; MetadataTypeMap MetadataIdMap; + MetadataTypeMap VirtualMetadataIdMap; MetadataTypeMap GeneralizedMetadataIdMap; public: @@ -1232,13 +1233,18 @@ public: /// internal identifiers). llvm::Metadata *CreateMetadataIdentifierForType(QualType T); + /// Create a metadata identifier that is intended to be used to check virtual + /// calls via a member function pointer. + llvm::Metadata *CreateMetadataIdentifierForVirtualMemPtrType(QualType T); + /// Create a metadata identifier for the generalization of the given type. /// This may either be an MDString (for external identifiers) or a distinct /// unnamed MDNode (for internal identifiers). llvm::Metadata *CreateMetadataIdentifierGeneralized(QualType T); /// Create and attach type metadata to the given function. - void CreateFunctionTypeMetadata(const FunctionDecl *FD, llvm::Function *F); + void CreateFunctionTypeMetadataForIcall(const FunctionDecl *FD, + llvm::Function *F); /// Returns whether this module needs the "all-vtables" type identifier. bool NeedAllVtablesTypeId() const; @@ -1247,6 +1253,14 @@ public: void AddVTableTypeMetadata(llvm::GlobalVariable *VTable, CharUnits Offset, const CXXRecordDecl *RD); + /// Return a vector of most-base classes for RD. This is used to implement + /// control flow integrity checks for member function pointers. + /// + /// A most-base class of a class C is defined as a recursive base class of C, + /// including C itself, that does not have any bases. + std::vector<const CXXRecordDecl *> + getMostBaseClasses(const CXXRecordDecl *RD); + /// Get the declaration of std::terminate for the platform. llvm::Constant *getTerminateFn(); @@ -1408,6 +1422,9 @@ private: void ConstructDefaultFnAttrList(StringRef Name, bool HasOptnone, bool AttrOnCallSite, llvm::AttrBuilder &FuncAttrs); + + llvm::Metadata *CreateMetadataIdentifierImpl(QualType T, MetadataTypeMap &Map, + StringRef Suffix); }; } // end namespace CodeGen diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp index d855afab22e..8dd94e49556 100644 --- a/clang/lib/CodeGen/ItaniumCXXABI.cpp +++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -622,13 +622,53 @@ CGCallee ItaniumCXXABI::EmitLoadOfMemberFunctionPointer( VTableOffset = Builder.CreateTrunc(VTableOffset, CGF.Int32Ty); VTableOffset = Builder.CreateZExt(VTableOffset, CGM.PtrDiffTy); } - VTable = Builder.CreateGEP(VTable, VTableOffset); + // Compute the address of the virtual function pointer. + llvm::Value *VFPAddr = Builder.CreateGEP(VTable, VTableOffset); + + // Check the address of the function pointer if CFI on member function + // pointers is enabled. + llvm::Constant *CheckSourceLocation; + llvm::Constant *CheckTypeDesc; + bool ShouldEmitCFICheck = CGF.SanOpts.has(SanitizerKind::CFIMFCall) && + CGM.HasHiddenLTOVisibility(RD); + if (ShouldEmitCFICheck) { + CodeGenFunction::SanitizerScope SanScope(&CGF); + + CheckSourceLocation = CGF.EmitCheckSourceLocation(E->getLocStart()); + CheckTypeDesc = CGF.EmitCheckTypeDescriptor(QualType(MPT, 0)); + llvm::Constant *StaticData[] = { + llvm::ConstantInt::get(CGF.Int8Ty, CodeGenFunction::CFITCK_VMFCall), + CheckSourceLocation, + CheckTypeDesc, + }; + + llvm::Metadata *MD = + CGM.CreateMetadataIdentifierForVirtualMemPtrType(QualType(MPT, 0)); + llvm::Value *TypeId = llvm::MetadataAsValue::get(CGF.getLLVMContext(), MD); + + llvm::Value *TypeTest = Builder.CreateCall( + CGM.getIntrinsic(llvm::Intrinsic::type_test), {VFPAddr, TypeId}); + + if (CGM.getCodeGenOpts().SanitizeTrap.has(SanitizerKind::CFIMFCall)) { + CGF.EmitTrapCheck(TypeTest); + } else { + llvm::Value *AllVtables = llvm::MetadataAsValue::get( + CGM.getLLVMContext(), + llvm::MDString::get(CGM.getLLVMContext(), "all-vtables")); + llvm::Value *ValidVtable = Builder.CreateCall( + CGM.getIntrinsic(llvm::Intrinsic::type_test), {VTable, AllVtables}); + CGF.EmitCheck(std::make_pair(TypeTest, SanitizerKind::CFIMFCall), + SanitizerHandler::CFICheckFail, StaticData, + {VTable, ValidVtable}); + } + + FnVirtual = Builder.GetInsertBlock(); + } // Load the virtual function to call. - VTable = Builder.CreateBitCast(VTable, FTy->getPointerTo()->getPointerTo()); - llvm::Value *VirtualFn = - Builder.CreateAlignedLoad(VTable, CGF.getPointerAlign(), - "memptr.virtualfn"); + VFPAddr = Builder.CreateBitCast(VFPAddr, FTy->getPointerTo()->getPointerTo()); + llvm::Value *VirtualFn = Builder.CreateAlignedLoad( + VFPAddr, CGF.getPointerAlign(), "memptr.virtualfn"); CGF.EmitBranch(FnEnd); // In the non-virtual path, the function pointer is actually a @@ -637,6 +677,43 @@ CGCallee ItaniumCXXABI::EmitLoadOfMemberFunctionPointer( llvm::Value *NonVirtualFn = Builder.CreateIntToPtr(FnAsInt, FTy->getPointerTo(), "memptr.nonvirtualfn"); + // Check the function pointer if CFI on member function pointers is enabled. + if (ShouldEmitCFICheck) { + CXXRecordDecl *RD = MPT->getClass()->getAsCXXRecordDecl(); + if (RD->hasDefinition()) { + CodeGenFunction::SanitizerScope SanScope(&CGF); + + llvm::Constant *StaticData[] = { + llvm::ConstantInt::get(CGF.Int8Ty, CodeGenFunction::CFITCK_NVMFCall), + CheckSourceLocation, + CheckTypeDesc, + }; + + llvm::Value *Bit = Builder.getFalse(); + llvm::Value *CastedNonVirtualFn = + Builder.CreateBitCast(NonVirtualFn, CGF.Int8PtrTy); + for (const CXXRecordDecl *Base : CGM.getMostBaseClasses(RD)) { + llvm::Metadata *MD = CGM.CreateMetadataIdentifierForType( + getContext().getMemberPointerType( + MPT->getPointeeType(), + getContext().getRecordType(Base).getTypePtr())); + llvm::Value *TypeId = + llvm::MetadataAsValue::get(CGF.getLLVMContext(), MD); + + llvm::Value *TypeTest = + Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::type_test), + {CastedNonVirtualFn, TypeId}); + Bit = Builder.CreateOr(Bit, TypeTest); + } + + CGF.EmitCheck(std::make_pair(Bit, SanitizerKind::CFIMFCall), + SanitizerHandler::CFICheckFail, StaticData, + {CastedNonVirtualFn, llvm::UndefValue::get(CGF.IntPtrTy)}); + + FnNonVirtual = Builder.GetInsertBlock(); + } + } + // We're done. CGF.EmitBlock(FnEnd); llvm::PHINode *CalleePtr = Builder.CreatePHI(FTy->getPointerTo(), 2); |