diff options
Diffstat (limited to 'clang/lib')
-rw-r--r-- | clang/lib/CodeGen/CGClass.cpp | 4 | ||||
-rw-r--r-- | clang/lib/CodeGen/CGVTables.cpp | 43 | ||||
-rw-r--r-- | clang/lib/CodeGen/CodeGenFunction.h | 2 | ||||
-rw-r--r-- | clang/lib/CodeGen/CodeGenModule.cpp | 89 | ||||
-rw-r--r-- | clang/lib/CodeGen/CodeGenModule.h | 19 | ||||
-rw-r--r-- | clang/lib/CodeGen/ItaniumCXXABI.cpp | 87 | ||||
-rw-r--r-- | clang/lib/Driver/SanitizerArgs.cpp | 29 | ||||
-rw-r--r-- | clang/lib/Driver/ToolChains/MSVC.cpp | 1 |
8 files changed, 224 insertions, 50 deletions
diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp index 11a327d2d2f..0b9311f7771 100644 --- a/clang/lib/CodeGen/CGClass.cpp +++ b/clang/lib/CodeGen/CGClass.cpp @@ -2688,7 +2688,9 @@ void CodeGenFunction::EmitVTablePtrCheck(const CXXRecordDecl *RD, SSK = llvm::SanStat_CFI_UnrelatedCast; break; case CFITCK_ICall: - llvm_unreachable("not expecting CFITCK_ICall"); + case CFITCK_NVMFCall: + case CFITCK_VMFCall: + llvm_unreachable("unexpected sanitizer kind"); } std::string TypeName = RD->getQualifiedNameAsString(); diff --git a/clang/lib/CodeGen/CGVTables.cpp b/clang/lib/CodeGen/CGVTables.cpp index 86bceb6ee9f..5a2ec65f776 100644 --- a/clang/lib/CodeGen/CGVTables.cpp +++ b/clang/lib/CodeGen/CGVTables.cpp @@ -1012,30 +1012,29 @@ void CodeGenModule::EmitVTableTypeMetadata(llvm::GlobalVariable *VTable, CharUnits PointerWidth = Context.toCharUnitsFromBits(Context.getTargetInfo().getPointerWidth(0)); - typedef std::pair<const CXXRecordDecl *, unsigned> TypeMetadata; - std::vector<TypeMetadata> TypeMetadatas; - // Create type metadata for each address point. + typedef std::pair<const CXXRecordDecl *, unsigned> AddressPoint; + std::vector<AddressPoint> AddressPoints; for (auto &&AP : VTLayout.getAddressPoints()) - TypeMetadatas.push_back(std::make_pair( + AddressPoints.push_back(std::make_pair( AP.first.getBase(), VTLayout.getVTableOffset(AP.second.VTableIndex) + AP.second.AddressPointIndex)); - // Sort the type metadata for determinism. - llvm::sort(TypeMetadatas.begin(), TypeMetadatas.end(), - [this](const TypeMetadata &M1, const TypeMetadata &M2) { - if (&M1 == &M2) + // Sort the address points for determinism. + llvm::sort(AddressPoints.begin(), AddressPoints.end(), + [this](const AddressPoint &AP1, const AddressPoint &AP2) { + if (&AP1 == &AP2) return false; std::string S1; llvm::raw_string_ostream O1(S1); getCXXABI().getMangleContext().mangleTypeName( - QualType(M1.first->getTypeForDecl(), 0), O1); + QualType(AP1.first->getTypeForDecl(), 0), O1); O1.flush(); std::string S2; llvm::raw_string_ostream O2(S2); getCXXABI().getMangleContext().mangleTypeName( - QualType(M2.first->getTypeForDecl(), 0), O2); + QualType(AP2.first->getTypeForDecl(), 0), O2); O2.flush(); if (S1 < S2) @@ -1043,10 +1042,26 @@ void CodeGenModule::EmitVTableTypeMetadata(llvm::GlobalVariable *VTable, if (S1 != S2) return false; - return M1.second < M2.second; + return AP1.second < AP2.second; }); - for (auto TypeMetadata : TypeMetadatas) - AddVTableTypeMetadata(VTable, PointerWidth * TypeMetadata.second, - TypeMetadata.first); + ArrayRef<VTableComponent> Comps = VTLayout.vtable_components(); + for (auto AP : AddressPoints) { + // Create type metadata for the address point. + AddVTableTypeMetadata(VTable, PointerWidth * AP.second, AP.first); + + // The class associated with each address point could also potentially be + // used for indirect calls via a member function pointer, so we need to + // annotate the address of each function pointer with the appropriate member + // function pointer type. + for (unsigned I = 0; I != Comps.size(); ++I) { + if (Comps[I].getKind() != VTableComponent::CK_FunctionPointer) + continue; + llvm::Metadata *MD = CreateMetadataIdentifierForVirtualMemPtrType( + Context.getMemberPointerType( + Comps[I].getFunctionDecl()->getType(), + Context.getRecordType(AP.first).getTypePtr())); + VTable->addTypeMetadata((PointerWidth * I).getQuantity(), MD); + } + } } diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index afe199c4eb5..548a4178ef2 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -1765,6 +1765,8 @@ public: CFITCK_DerivedCast, CFITCK_UnrelatedCast, CFITCK_ICall, + CFITCK_NVMFCall, + CFITCK_VMFCall, }; /// Derived is the presumed address of an object of type T after a diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp index 5a2f2a01d39..35e9dea37a9 100644 --- a/clang/lib/CodeGen/CodeGenModule.cpp +++ b/clang/lib/CodeGen/CodeGenModule.cpp @@ -1132,6 +1132,34 @@ static bool hasUnwindExceptions(const LangOptions &LangOpts) { return true; } +static bool requiresMemberFunctionPointerTypeMetadata(CodeGenModule &CGM, + const CXXMethodDecl *MD) { + // Check that the type metadata can ever actually be used by a call. + if (!CGM.getCodeGenOpts().LTOUnit || + !CGM.HasHiddenLTOVisibility(MD->getParent())) + return false; + + // Only functions whose address can be taken with a member function pointer + // need this sort of type metadata. + return !MD->isStatic() && !MD->isVirtual() && !isa<CXXConstructorDecl>(MD) && + !isa<CXXDestructorDecl>(MD); +} + +std::vector<const CXXRecordDecl *> +CodeGenModule::getMostBaseClasses(const CXXRecordDecl *RD) { + llvm::SetVector<const CXXRecordDecl *> MostBases; + + std::function<void (const CXXRecordDecl *)> CollectMostBases; + CollectMostBases = [&](const CXXRecordDecl *RD) { + if (RD->getNumBases() == 0) + MostBases.insert(RD); + for (const CXXBaseSpecifier &B : RD->bases()) + CollectMostBases(B.getType()->getAsCXXRecordDecl()); + }; + CollectMostBases(RD); + return MostBases.takeVector(); +} + void CodeGenModule::SetLLVMFunctionAttributesForDefinition(const Decl *D, llvm::Function *F) { llvm::AttrBuilder B; @@ -1257,7 +1285,20 @@ void CodeGenModule::SetLLVMFunctionAttributesForDefinition(const Decl *D, // In the cross-dso CFI mode, we want !type attributes on definitions only. if (CodeGenOpts.SanitizeCfiCrossDso) if (auto *FD = dyn_cast<FunctionDecl>(D)) - CreateFunctionTypeMetadata(FD, F); + CreateFunctionTypeMetadataForIcall(FD, F); + + // Emit type metadata on member functions for member function pointer checks. + // These are only ever necessary on definitions; we're guaranteed that the + // definition will be present in the LTO unit as a result of LTO visibility. + auto *MD = dyn_cast<CXXMethodDecl>(D); + if (MD && requiresMemberFunctionPointerTypeMetadata(*this, MD)) { + for (const CXXRecordDecl *Base : getMostBaseClasses(MD->getParent())) { + llvm::Metadata *Id = + CreateMetadataIdentifierForType(Context.getMemberPointerType( + MD->getType(), Context.getRecordType(Base).getTypePtr())); + F->addTypeMetadata(0, Id); + } + } } void CodeGenModule::SetCommonAttributes(GlobalDecl GD, llvm::GlobalValue *GV) { @@ -1378,13 +1419,14 @@ static void setLinkageForGV(llvm::GlobalValue *GV, const NamedDecl *ND) { GV->setLinkage(llvm::GlobalValue::ExternalWeakLinkage); } -void CodeGenModule::CreateFunctionTypeMetadata(const FunctionDecl *FD, - llvm::Function *F) { +void CodeGenModule::CreateFunctionTypeMetadataForIcall(const FunctionDecl *FD, + llvm::Function *F) { // Only if we are checking indirect calls. if (!LangOpts.Sanitize.has(SanitizerKind::CFIICall)) return; - // Non-static class methods are handled via vtable pointer checks elsewhere. + // Non-static class methods are handled via vtable or member function pointer + // checks elsewhere. if (isa<CXXMethodDecl>(FD) && !cast<CXXMethodDecl>(FD)->isStatic()) return; @@ -1476,7 +1518,7 @@ void CodeGenModule::SetFunctionAttributes(GlobalDecl GD, llvm::Function *F, // Don't emit entries for function declarations in the cross-DSO mode. This // is handled with better precision by the receiving DSO. if (!CodeGenOpts.SanitizeCfiCrossDso) - CreateFunctionTypeMetadata(FD, F); + CreateFunctionTypeMetadataForIcall(FD, F); if (getLangOpts().OpenMP && FD->hasAttr<OMPDeclareSimdDeclAttr>()) getOpenMPRuntime().emitDeclareSimdFunction(FD, F); @@ -4925,8 +4967,10 @@ void CodeGenModule::EmitOMPThreadPrivateDecl(const OMPThreadPrivateDecl *D) { } } -llvm::Metadata *CodeGenModule::CreateMetadataIdentifierForType(QualType T) { - llvm::Metadata *&InternalId = MetadataIdMap[T.getCanonicalType()]; +llvm::Metadata * +CodeGenModule::CreateMetadataIdentifierImpl(QualType T, MetadataTypeMap &Map, + StringRef Suffix) { + llvm::Metadata *&InternalId = Map[T.getCanonicalType()]; if (InternalId) return InternalId; @@ -4934,6 +4978,7 @@ llvm::Metadata *CodeGenModule::CreateMetadataIdentifierForType(QualType T) { std::string OutName; llvm::raw_string_ostream Out(OutName); getCXXABI().getMangleContext().mangleTypeName(T, Out); + Out << Suffix; InternalId = llvm::MDString::get(getLLVMContext(), Out.str()); } else { @@ -4944,6 +4989,15 @@ llvm::Metadata *CodeGenModule::CreateMetadataIdentifierForType(QualType T) { return InternalId; } +llvm::Metadata *CodeGenModule::CreateMetadataIdentifierForType(QualType T) { + return CreateMetadataIdentifierImpl(T, MetadataIdMap, ""); +} + +llvm::Metadata * +CodeGenModule::CreateMetadataIdentifierForVirtualMemPtrType(QualType T) { + return CreateMetadataIdentifierImpl(T, VirtualMetadataIdMap, ".virtual"); +} + // Generalize pointer types to a void pointer with the qualifiers of the // originally pointed-to type, e.g. 'const char *' and 'char * const *' // generalize to 'const void *' while 'char *' and 'const char **' generalize to @@ -4977,25 +5031,8 @@ static QualType GeneralizeFunctionType(ASTContext &Ctx, QualType Ty) { } llvm::Metadata *CodeGenModule::CreateMetadataIdentifierGeneralized(QualType T) { - T = GeneralizeFunctionType(getContext(), T); - - llvm::Metadata *&InternalId = GeneralizedMetadataIdMap[T.getCanonicalType()]; - if (InternalId) - return InternalId; - - if (isExternallyVisible(T->getLinkage())) { - std::string OutName; - llvm::raw_string_ostream Out(OutName); - getCXXABI().getMangleContext().mangleTypeName(T, Out); - Out << ".generalized"; - - InternalId = llvm::MDString::get(getLLVMContext(), Out.str()); - } else { - InternalId = llvm::MDNode::getDistinct(getLLVMContext(), - llvm::ArrayRef<llvm::Metadata *>()); - } - - return InternalId; + return CreateMetadataIdentifierImpl(GeneralizeFunctionType(getContext(), T), + GeneralizedMetadataIdMap, ".generalized"); } /// Returns whether this module needs the "all-vtables" type identifier. diff --git a/clang/lib/CodeGen/CodeGenModule.h b/clang/lib/CodeGen/CodeGenModule.h index bf22ad246c8..d2c7b327f98 100644 --- a/clang/lib/CodeGen/CodeGenModule.h +++ b/clang/lib/CodeGen/CodeGenModule.h @@ -503,6 +503,7 @@ private: /// MDNodes. typedef llvm::DenseMap<QualType, llvm::Metadata *> MetadataTypeMap; MetadataTypeMap MetadataIdMap; + MetadataTypeMap VirtualMetadataIdMap; MetadataTypeMap GeneralizedMetadataIdMap; public: @@ -1232,13 +1233,18 @@ public: /// internal identifiers). llvm::Metadata *CreateMetadataIdentifierForType(QualType T); + /// Create a metadata identifier that is intended to be used to check virtual + /// calls via a member function pointer. + llvm::Metadata *CreateMetadataIdentifierForVirtualMemPtrType(QualType T); + /// Create a metadata identifier for the generalization of the given type. /// This may either be an MDString (for external identifiers) or a distinct /// unnamed MDNode (for internal identifiers). llvm::Metadata *CreateMetadataIdentifierGeneralized(QualType T); /// Create and attach type metadata to the given function. - void CreateFunctionTypeMetadata(const FunctionDecl *FD, llvm::Function *F); + void CreateFunctionTypeMetadataForIcall(const FunctionDecl *FD, + llvm::Function *F); /// Returns whether this module needs the "all-vtables" type identifier. bool NeedAllVtablesTypeId() const; @@ -1247,6 +1253,14 @@ public: void AddVTableTypeMetadata(llvm::GlobalVariable *VTable, CharUnits Offset, const CXXRecordDecl *RD); + /// Return a vector of most-base classes for RD. This is used to implement + /// control flow integrity checks for member function pointers. + /// + /// A most-base class of a class C is defined as a recursive base class of C, + /// including C itself, that does not have any bases. + std::vector<const CXXRecordDecl *> + getMostBaseClasses(const CXXRecordDecl *RD); + /// Get the declaration of std::terminate for the platform. llvm::Constant *getTerminateFn(); @@ -1408,6 +1422,9 @@ private: void ConstructDefaultFnAttrList(StringRef Name, bool HasOptnone, bool AttrOnCallSite, llvm::AttrBuilder &FuncAttrs); + + llvm::Metadata *CreateMetadataIdentifierImpl(QualType T, MetadataTypeMap &Map, + StringRef Suffix); }; } // end namespace CodeGen diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp index d855afab22e..8dd94e49556 100644 --- a/clang/lib/CodeGen/ItaniumCXXABI.cpp +++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -622,13 +622,53 @@ CGCallee ItaniumCXXABI::EmitLoadOfMemberFunctionPointer( VTableOffset = Builder.CreateTrunc(VTableOffset, CGF.Int32Ty); VTableOffset = Builder.CreateZExt(VTableOffset, CGM.PtrDiffTy); } - VTable = Builder.CreateGEP(VTable, VTableOffset); + // Compute the address of the virtual function pointer. + llvm::Value *VFPAddr = Builder.CreateGEP(VTable, VTableOffset); + + // Check the address of the function pointer if CFI on member function + // pointers is enabled. + llvm::Constant *CheckSourceLocation; + llvm::Constant *CheckTypeDesc; + bool ShouldEmitCFICheck = CGF.SanOpts.has(SanitizerKind::CFIMFCall) && + CGM.HasHiddenLTOVisibility(RD); + if (ShouldEmitCFICheck) { + CodeGenFunction::SanitizerScope SanScope(&CGF); + + CheckSourceLocation = CGF.EmitCheckSourceLocation(E->getLocStart()); + CheckTypeDesc = CGF.EmitCheckTypeDescriptor(QualType(MPT, 0)); + llvm::Constant *StaticData[] = { + llvm::ConstantInt::get(CGF.Int8Ty, CodeGenFunction::CFITCK_VMFCall), + CheckSourceLocation, + CheckTypeDesc, + }; + + llvm::Metadata *MD = + CGM.CreateMetadataIdentifierForVirtualMemPtrType(QualType(MPT, 0)); + llvm::Value *TypeId = llvm::MetadataAsValue::get(CGF.getLLVMContext(), MD); + + llvm::Value *TypeTest = Builder.CreateCall( + CGM.getIntrinsic(llvm::Intrinsic::type_test), {VFPAddr, TypeId}); + + if (CGM.getCodeGenOpts().SanitizeTrap.has(SanitizerKind::CFIMFCall)) { + CGF.EmitTrapCheck(TypeTest); + } else { + llvm::Value *AllVtables = llvm::MetadataAsValue::get( + CGM.getLLVMContext(), + llvm::MDString::get(CGM.getLLVMContext(), "all-vtables")); + llvm::Value *ValidVtable = Builder.CreateCall( + CGM.getIntrinsic(llvm::Intrinsic::type_test), {VTable, AllVtables}); + CGF.EmitCheck(std::make_pair(TypeTest, SanitizerKind::CFIMFCall), + SanitizerHandler::CFICheckFail, StaticData, + {VTable, ValidVtable}); + } + + FnVirtual = Builder.GetInsertBlock(); + } // Load the virtual function to call. - VTable = Builder.CreateBitCast(VTable, FTy->getPointerTo()->getPointerTo()); - llvm::Value *VirtualFn = - Builder.CreateAlignedLoad(VTable, CGF.getPointerAlign(), - "memptr.virtualfn"); + VFPAddr = Builder.CreateBitCast(VFPAddr, FTy->getPointerTo()->getPointerTo()); + llvm::Value *VirtualFn = Builder.CreateAlignedLoad( + VFPAddr, CGF.getPointerAlign(), "memptr.virtualfn"); CGF.EmitBranch(FnEnd); // In the non-virtual path, the function pointer is actually a @@ -637,6 +677,43 @@ CGCallee ItaniumCXXABI::EmitLoadOfMemberFunctionPointer( llvm::Value *NonVirtualFn = Builder.CreateIntToPtr(FnAsInt, FTy->getPointerTo(), "memptr.nonvirtualfn"); + // Check the function pointer if CFI on member function pointers is enabled. + if (ShouldEmitCFICheck) { + CXXRecordDecl *RD = MPT->getClass()->getAsCXXRecordDecl(); + if (RD->hasDefinition()) { + CodeGenFunction::SanitizerScope SanScope(&CGF); + + llvm::Constant *StaticData[] = { + llvm::ConstantInt::get(CGF.Int8Ty, CodeGenFunction::CFITCK_NVMFCall), + CheckSourceLocation, + CheckTypeDesc, + }; + + llvm::Value *Bit = Builder.getFalse(); + llvm::Value *CastedNonVirtualFn = + Builder.CreateBitCast(NonVirtualFn, CGF.Int8PtrTy); + for (const CXXRecordDecl *Base : CGM.getMostBaseClasses(RD)) { + llvm::Metadata *MD = CGM.CreateMetadataIdentifierForType( + getContext().getMemberPointerType( + MPT->getPointeeType(), + getContext().getRecordType(Base).getTypePtr())); + llvm::Value *TypeId = + llvm::MetadataAsValue::get(CGF.getLLVMContext(), MD); + + llvm::Value *TypeTest = + Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::type_test), + {CastedNonVirtualFn, TypeId}); + Bit = Builder.CreateOr(Bit, TypeTest); + } + + CGF.EmitCheck(std::make_pair(Bit, SanitizerKind::CFIMFCall), + SanitizerHandler::CFICheckFail, StaticData, + {CastedNonVirtualFn, llvm::UndefValue::get(CGF.IntPtrTy)}); + + FnNonVirtual = Builder.GetInsertBlock(); + } + } + // We're done. CGF.EmitBlock(FnEnd); llvm::PHINode *CalleePtr = Builder.CreatePHI(FTy->getPointerTo(), 2); diff --git a/clang/lib/Driver/SanitizerArgs.cpp b/clang/lib/Driver/SanitizerArgs.cpp index 743e042a79d..bdc17d11c92 100644 --- a/clang/lib/Driver/SanitizerArgs.cpp +++ b/clang/lib/Driver/SanitizerArgs.cpp @@ -44,7 +44,8 @@ enum : SanitizerMask { TrappingSupported = (Undefined & ~Vptr) | UnsignedIntegerOverflow | Nullability | LocalBounds | CFI, TrappingDefault = CFI, - CFIClasses = CFIVCall | CFINVCall | CFIDerivedCast | CFIUnrelatedCast, + CFIClasses = + CFIVCall | CFINVCall | CFIMFCall | CFIDerivedCast | CFIUnrelatedCast, CompatibleWithMinimalRuntime = TrappingSupported | Scudo, }; @@ -219,6 +220,10 @@ SanitizerArgs::SanitizerArgs(const ToolChain &TC, // Used to deduplicate diagnostics. SanitizerMask Kinds = 0; const SanitizerMask Supported = setGroupBits(TC.getSupportedSanitizers()); + + CfiCrossDso = Args.hasFlag(options::OPT_fsanitize_cfi_cross_dso, + options::OPT_fno_sanitize_cfi_cross_dso, false); + ToolChain::RTTIMode RTTIMode = TC.getRTTIMode(); const Driver &D = TC.getDriver(); @@ -278,6 +283,24 @@ SanitizerArgs::SanitizerArgs(const ToolChain &TC, Add &= ~NotAllowedWithMinimalRuntime; } + // FIXME: Make CFI on member function calls compatible with cross-DSO CFI. + // There are currently two problems: + // - Virtual function call checks need to pass a pointer to the function + // address to llvm.type.test and a pointer to the address point to the + // diagnostic function. Currently we pass the same pointer to both + // places. + // - Non-virtual function call checks may need to check multiple type + // identifiers. + // Fixing both of those may require changes to the cross-DSO CFI + // interface. + if (CfiCrossDso && (Add & CFIMFCall & ~DiagnosedKinds)) { + D.Diag(diag::err_drv_argument_not_allowed_with) + << "-fsanitize=cfi-mfcall" + << "-fsanitize-cfi-cross-dso"; + Add &= ~CFIMFCall; + DiagnosedKinds |= CFIMFCall; + } + if (SanitizerMask KindsToDiagnose = Add & ~Supported & ~DiagnosedKinds) { std::string Desc = describeSanitizeArg(*I, KindsToDiagnose); D.Diag(diag::err_drv_unsupported_opt_for_target) @@ -316,6 +339,8 @@ SanitizerArgs::SanitizerArgs(const ToolChain &TC, if (MinimalRuntime) { Add &= ~NotAllowedWithMinimalRuntime; } + if (CfiCrossDso) + Add &= ~CFIMFCall; Add &= Supported; if (Add & Fuzzer) @@ -554,8 +579,6 @@ SanitizerArgs::SanitizerArgs(const ToolChain &TC, } if (AllAddedKinds & CFI) { - CfiCrossDso = Args.hasFlag(options::OPT_fsanitize_cfi_cross_dso, - options::OPT_fno_sanitize_cfi_cross_dso, false); // Without PIE, external function address may resolve to a PLT record, which // can not be verified by the target module. NeedPIE |= CfiCrossDso; diff --git a/clang/lib/Driver/ToolChains/MSVC.cpp b/clang/lib/Driver/ToolChains/MSVC.cpp index 22257092bb6..3ce1bfa0aa6 100644 --- a/clang/lib/Driver/ToolChains/MSVC.cpp +++ b/clang/lib/Driver/ToolChains/MSVC.cpp @@ -1297,6 +1297,7 @@ MSVCToolChain::ComputeEffectiveClangTriple(const ArgList &Args, SanitizerMask MSVCToolChain::getSupportedSanitizers() const { SanitizerMask Res = ToolChain::getSupportedSanitizers(); Res |= SanitizerKind::Address; + Res &= ~SanitizerKind::CFIMFCall; return Res; } |