diff options
Diffstat (limited to 'llvm/lib/CodeGen')
-rw-r--r-- | llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp | 12 | ||||
-rw-r--r-- | llvm/lib/CodeGen/CMakeLists.txt | 1 | ||||
-rw-r--r-- | llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp | 375 | ||||
-rw-r--r-- | llvm/lib/CodeGen/JumpInstrTables.cpp | 13 | ||||
-rw-r--r-- | llvm/lib/CodeGen/LLVMTargetMachine.cpp | 9 | ||||
-rw-r--r-- | llvm/lib/CodeGen/TargetOptionsImpl.cpp | 7 |
6 files changed, 404 insertions, 13 deletions
diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp index 61f1d1b653e..087836901a7 100644 --- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -879,16 +879,17 @@ bool AsmPrinter::doFinalization(Module &M) { bool IsThumb = (Arch == Triple::thumb || Arch == Triple::thumbeb); MCInst TrapInst; TM.getSubtargetImpl()->getInstrInfo()->getTrap(TrapInst); + unsigned LogAlignment = llvm::Log2_64(JITI->entryByteAlignment()); + + // Emit the right section for these functions. + OutStreamer.SwitchSection(OutContext.getObjectFileInfo()->getTextSection()); for (const auto &KV : JITI->getTables()) { uint64_t Count = 0; for (const auto &FunPair : KV.second) { // Emit the function labels to make this be a function entry point. MCSymbol *FunSym = OutContext.GetOrCreateSymbol(FunPair.second->getName()); - OutStreamer.EmitSymbolAttribute(FunSym, MCSA_Global); - // FIXME: JumpTableInstrInfo should store information about the required - // alignment of table entries and the size of the padding instruction. - EmitAlignment(3); + EmitAlignment(LogAlignment); if (IsThumb) OutStreamer.EmitThumbFunc(FunSym); if (MAI->hasDotTypeDotSizeDirective()) @@ -910,10 +911,9 @@ bool AsmPrinter::doFinalization(Module &M) { } // Emit enough padding instructions to fill up to the next power of two. - // This assumes that the trap instruction takes 8 bytes or fewer. uint64_t Remaining = NextPowerOf2(Count) - Count; for (uint64_t C = 0; C < Remaining; ++C) { - EmitAlignment(3); + EmitAlignment(LogAlignment); OutStreamer.EmitInstruction(TrapInst, getSubtargetInfo()); } diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt index 0c84a90d4a8..092346bacd7 100644 --- a/llvm/lib/CodeGen/CMakeLists.txt +++ b/llvm/lib/CodeGen/CMakeLists.txt @@ -19,6 +19,7 @@ add_llvm_library(LLVMCodeGen ExecutionDepsFix.cpp ExpandISelPseudos.cpp ExpandPostRAPseudos.cpp + ForwardControlFlowIntegrity.cpp GCMetadata.cpp GCMetadataPrinter.cpp GCStrategy.cpp diff --git a/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp b/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp new file mode 100644 index 00000000000..679faeffefc --- /dev/null +++ b/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp @@ -0,0 +1,375 @@ +//===-- ForwardControlFlowIntegrity.cpp: Forward-Edge CFI -----------------===// +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// \brief A pass that instruments code with fast checks for indirect calls and +/// hooks for a function to check violations. +/// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "cfi" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/JumpInstrTableInfo.h" +#include "llvm/CodeGen/ForwardControlFlowIntegrity.h" +#include "llvm/CodeGen/JumpInstrTables.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/InlineAsm.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +STATISTIC(NumCFIIndirectCalls, + "Number of indirect call sites rewritten by the CFI pass"); + +char ForwardControlFlowIntegrity::ID = 0; +INITIALIZE_PASS_BEGIN(ForwardControlFlowIntegrity, "forward-cfi", + "Control-Flow Integrity", true, true) +INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo); +INITIALIZE_PASS_DEPENDENCY(JumpInstrTables); +INITIALIZE_PASS_END(ForwardControlFlowIntegrity, "forward-cfi", + "Control-Flow Integrity", true, true) + +ModulePass *llvm::createForwardControlFlowIntegrityPass() { + return new ForwardControlFlowIntegrity(); +} + +ModulePass *llvm::createForwardControlFlowIntegrityPass( + JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing, + StringRef CFIFuncName) { + return new ForwardControlFlowIntegrity(JTT, CFIType, CFIEnforcing, + CFIFuncName); +} + +// Checks to see if a given CallSite is making an indirect call, including +// cases where the indirect call is made through a bitcast. +static bool isIndirectCall(CallSite &CS) { + if (CS.getCalledFunction()) + return false; + + // Check the value to see if it is merely a bitcast of a function. In + // this case, it will translate to a direct function call in the resulting + // assembly, so we won't treat it as an indirect call here. + const Value *V = CS.getCalledValue(); + if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { + return !(CE->isCast() && isa<Function>(CE->getOperand(0))); + } + + // Otherwise, since we know it's a call, it must be an indirect call + return true; +} + +static const char cfi_failure_func_name[] = "__llvm_cfi_pointer_warning"; +static const char cfi_func_name_prefix[] = "__llvm_cfi_function_"; + +ForwardControlFlowIntegrity::ForwardControlFlowIntegrity() + : ModulePass(ID), IndirectCalls(), JTType(JumpTable::Single), + CFIType(CFIntegrity::Sub), CFIEnforcing(false), CFIFuncName("") { + initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry()); +} + +ForwardControlFlowIntegrity::ForwardControlFlowIntegrity( + JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing, + std::string CFIFuncName) + : ModulePass(ID), IndirectCalls(), JTType(JTT), CFIType(CFIType), + CFIEnforcing(CFIEnforcing), CFIFuncName(CFIFuncName) { + initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry()); +} + +ForwardControlFlowIntegrity::~ForwardControlFlowIntegrity() {} + +void ForwardControlFlowIntegrity::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<JumpInstrTableInfo>(); + AU.addRequired<JumpInstrTables>(); +} + +void ForwardControlFlowIntegrity::getIndirectCalls(Module &M) { + // To get the indirect calls, we iterate over all functions and iterate over + // the list of basic blocks in each. We extract a total list of indirect calls + // before modifying any of them, since our modifications will modify the list + // of basic blocks. + for (Function &F : M) { + for (BasicBlock &BB : F) { + for (Instruction &I : BB) { + CallSite CS(&I); + if (!(CS && isIndirectCall(CS))) + continue; + + Value *CalledValue = CS.getCalledValue(); + + // Don't rewrite this instruction if the indirect call is actually just + // inline assembly, since our transformation will generate an invalid + // module in that case. + if (isa<InlineAsm>(CalledValue)) + continue; + + IndirectCalls.push_back(&I); + } + } + } +} + +void ForwardControlFlowIntegrity::updateIndirectCalls(Module &M, + CFITables &CFIT) { + Type *Int64Ty = Type::getInt64Ty(M.getContext()); + for (Instruction *I : IndirectCalls) { + CallSite CS(I); + Value *CalledValue = CS.getCalledValue(); + + // Get the function type for this call and look it up in the tables. + Type *VTy = CalledValue->getType(); + PointerType *PTy = dyn_cast<PointerType>(VTy); + Type *EltTy = PTy->getElementType(); + FunctionType *FunTy = dyn_cast<FunctionType>(EltTy); + FunctionType *TransformedTy = JumpInstrTables::transformType(JTType, FunTy); + ++NumCFIIndirectCalls; + Constant *JumpTableStart = nullptr; + Constant *JumpTableMask = nullptr; + Constant *JumpTableSize = nullptr; + + // Some call sites have function types that don't correspond to any + // address-taken function in the module. This happens when function pointers + // are passed in from external code. + auto it = CFIT.find(TransformedTy); + if (it == CFIT.end()) { + // In this case, make sure that the function pointer will change by + // setting the mask and the start to be 0 so that the transformed + // function is 0. + JumpTableStart = ConstantInt::get(Int64Ty, 0); + JumpTableMask = ConstantInt::get(Int64Ty, 0); + JumpTableSize = ConstantInt::get(Int64Ty, 0); + } else { + JumpTableStart = it->second.StartValue; + JumpTableMask = it->second.MaskValue; + JumpTableSize = it->second.Size; + } + + rewriteFunctionPointer(M, I, CalledValue, JumpTableStart, JumpTableMask, + JumpTableSize); + } + + return; +} + +bool ForwardControlFlowIntegrity::runOnModule(Module &M) { + JumpInstrTableInfo *JITI = &getAnalysis<JumpInstrTableInfo>(); + Type *Int64Ty = Type::getInt64Ty(M.getContext()); + Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext()); + + // JumpInstrTableInfo stores information about the alignment of each entry. + // The alignment returned by JumpInstrTableInfo is alignment in bytes, not + // in the exponent. + ByteAlignment = JITI->entryByteAlignment(); + LogByteAlignment = llvm::Log2_64(ByteAlignment); + + // Set up tables for control-flow integrity based on information about the + // jump-instruction tables. + CFITables CFIT; + for (const auto &KV : JITI->getTables()) { + uint64_t Size = static_cast<uint64_t>(KV.second.size()); + uint64_t TableSize = NextPowerOf2(Size); + + int64_t MaskValue = ((TableSize << LogByteAlignment) - 1) & -ByteAlignment; + Constant *JumpTableMaskValue = ConstantInt::get(Int64Ty, MaskValue); + Constant *JumpTableSize = ConstantInt::get(Int64Ty, Size); + + // The base of the table is defined to be the first jumptable function in + // the table. + Function *First = KV.second.begin()->second; + Constant *JumpTableStartValue = ConstantExpr::getBitCast(First, VoidPtrTy); + CFIT[KV.first].StartValue = JumpTableStartValue; + CFIT[KV.first].MaskValue = JumpTableMaskValue; + CFIT[KV.first].Size = JumpTableSize; + } + + if (CFIT.empty()) + return false; + + getIndirectCalls(M); + + if (!CFIEnforcing) { + addWarningFunction(M); + } + + // Update the instructions with the check and the indirect jump through our + // table. + updateIndirectCalls(M, CFIT); + + return true; +} + +void ForwardControlFlowIntegrity::addWarningFunction(Module &M) { + PointerType *CharPtrTy = Type::getInt8PtrTy(M.getContext()); + + // Get the type of the Warning Function: void (i8*, i8*), + // where the first argument is the name of the function in which the violation + // occurs, and the second is the function pointer that violates CFI. + SmallVector<Type *, 2> WarningFunArgs; + WarningFunArgs.push_back(CharPtrTy); + WarningFunArgs.push_back(CharPtrTy); + FunctionType *WarningFunTy = + FunctionType::get(Type::getVoidTy(M.getContext()), WarningFunArgs, false); + + if (!CFIFuncName.empty()) { + Constant *FailureFun = M.getOrInsertFunction(CFIFuncName, WarningFunTy); + if (!FailureFun) + report_fatal_error("Could not get or insert the function specified by" + " -cfi-func-name"); + } else { + // The default warning function swallows the warning and lets the call + // continue, since there's no generic way for it to print out this + // information. + Function *WarningFun = M.getFunction(cfi_failure_func_name); + if (!WarningFun) { + WarningFun = + Function::Create(WarningFunTy, GlobalValue::LinkOnceAnyLinkage, + cfi_failure_func_name, &M); + } + + BasicBlock *Entry = + BasicBlock::Create(M.getContext(), "entry", WarningFun, 0); + ReturnInst::Create(M.getContext(), Entry); + } +} + +void ForwardControlFlowIntegrity::rewriteFunctionPointer( + Module &M, Instruction *I, Value *FunPtr, Constant *JumpTableStart, + Constant *JumpTableMask, Constant *JumpTableSize) { + IRBuilder<> TempBuilder(I); + + Type *OrigFunType = FunPtr->getType(); + + BasicBlock *CurBB = cast<BasicBlock>(I->getParent()); + Function *CurF = cast<Function>(CurBB->getParent()); + Type *Int64Ty = Type::getInt64Ty(M.getContext()); + + Value *TI = TempBuilder.CreatePtrToInt(FunPtr, Int64Ty); + Value *TStartInt = TempBuilder.CreatePtrToInt(JumpTableStart, Int64Ty); + + Value *NewFunPtr = nullptr; + Value *Check = nullptr; + switch (CFIType) { + case CFIntegrity::Sub: { + // This is the subtract, mask, and add version. + // Subtract from the base. + Value *Sub = TempBuilder.CreateSub(TI, TStartInt); + + // Mask the difference to force this to be a table offset. + Value *And = TempBuilder.CreateAnd(Sub, JumpTableMask); + + // Add it back to the base. + Value *Result = TempBuilder.CreateAdd(And, TStartInt); + + // Convert it back into a function pointer that we can call. + NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType); + break; + } + case CFIntegrity::Ror: { + // This is the subtract and rotate version. + // Rotate right by the alignment value. The optimizer should recognize + // this sequence as a rotation. + + // This cast is safe, since unsigned is always a subset of uint64_t. + uint64_t LogByteAlignment64 = static_cast<uint64_t>(LogByteAlignment); + Constant *RightShift = ConstantInt::get(Int64Ty, LogByteAlignment64); + Constant *LeftShift = ConstantInt::get(Int64Ty, 64 - LogByteAlignment64); + + // Subtract from the base. + Value *Sub = TempBuilder.CreateSub(TI, TStartInt); + + // Create the equivalent of a rotate-right instruction. + Value *Shr = TempBuilder.CreateLShr(Sub, RightShift); + Value *Shl = TempBuilder.CreateShl(Sub, LeftShift); + Value *Or = TempBuilder.CreateOr(Shr, Shl); + + // Perform unsigned comparison to check for inclusion in the table. + Check = TempBuilder.CreateICmpULT(Or, JumpTableSize); + NewFunPtr = FunPtr; + break; + } + case CFIntegrity::Add: { + // This is the mask and add version. + // Mask the function pointer to turn it into an offset into the table. + Value *And = TempBuilder.CreateAnd(TI, JumpTableMask); + + // Then or this offset to the base and get the pointer value. + Value *Result = TempBuilder.CreateAdd(And, TStartInt); + + // Convert it back into a function pointer that we can call. + NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType); + break; + } + } + + if (!CFIEnforcing) { + // If a check hasn't been added (in the rotation version), then check to see + // if it's the same as the original function. This check determines whether + // or not we call the CFI failure function. + if (!Check) + Check = TempBuilder.CreateICmpEQ(NewFunPtr, FunPtr); + BasicBlock *InvalidPtrBlock = + BasicBlock::Create(M.getContext(), "invalid.ptr", CurF, 0); + BasicBlock *ContinuationBB = CurBB->splitBasicBlock(I); + + // Remove the unconditional branch that connects the two blocks. + TerminatorInst *TermInst = CurBB->getTerminator(); + TermInst->eraseFromParent(); + + // Add a conditional branch that depends on the Check above. + BranchInst::Create(ContinuationBB, InvalidPtrBlock, Check, CurBB); + + // Call the warning function for this pointer, then continue. + Instruction *BI = BranchInst::Create(ContinuationBB, InvalidPtrBlock); + insertWarning(M, InvalidPtrBlock, BI, FunPtr); + } else { + // Modify the instruction to call this value. + CallSite CS(I); + CS.setCalledFunction(NewFunPtr); + } +} + +void ForwardControlFlowIntegrity::insertWarning(Module &M, BasicBlock *Block, + Instruction *I, Value *FunPtr) { + Function *ParentFun = cast<Function>(Block->getParent()); + + // Get the function to call right before the instruction. + Function *WarningFun = nullptr; + if (CFIFuncName.empty()) { + WarningFun = M.getFunction(cfi_failure_func_name); + } else { + WarningFun = M.getFunction(CFIFuncName); + } + + assert(WarningFun && "Could not find the CFI failure function"); + + Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext()); + + IRBuilder<> WarningInserter(I); + // Create a mergeable GlobalVariable containing the name of the function. + Value *ParentNameGV = + WarningInserter.CreateGlobalString(ParentFun->getName()); + Value *ParentNamePtr = WarningInserter.CreateBitCast(ParentNameGV, VoidPtrTy); + Value *FunVoidPtr = WarningInserter.CreateBitCast(FunPtr, VoidPtrTy); + WarningInserter.CreateCall2(WarningFun, ParentNamePtr, FunVoidPtr); +} diff --git a/llvm/lib/CodeGen/JumpInstrTables.cpp b/llvm/lib/CodeGen/JumpInstrTables.cpp index 750f71f6022..20f775c1245 100644 --- a/llvm/lib/CodeGen/JumpInstrTables.cpp +++ b/llvm/lib/CodeGen/JumpInstrTables.cpp @@ -163,7 +163,7 @@ void JumpInstrTables::getAnalysisUsage(AnalysisUsage &AU) const { Function *JumpInstrTables::insertEntry(Module &M, Function *Target) { FunctionType *OrigFunTy = Target->getFunctionType(); - FunctionType *FunTy = transformType(OrigFunTy); + FunctionType *FunTy = transformType(JTType, OrigFunTy); JumpMap::iterator it = Metadata.find(FunTy); if (Metadata.end() == it) { @@ -191,11 +191,12 @@ Function *JumpInstrTables::insertEntry(Module &M, Function *Target) { } bool JumpInstrTables::hasTable(FunctionType *FunTy) { - FunctionType *TransTy = transformType(FunTy); + FunctionType *TransTy = transformType(JTType, FunTy); return Metadata.end() != Metadata.find(TransTy); } -FunctionType *JumpInstrTables::transformType(FunctionType *FunTy) { +FunctionType *JumpInstrTables::transformType(JumpTable::JumpTableType JTT, + FunctionType *FunTy) { // Returning nullptr forces all types into the same table, since all types map // to the same type Type *VoidPtrTy = Type::getInt8PtrTy(FunTy->getContext()); @@ -211,7 +212,7 @@ FunctionType *JumpInstrTables::transformType(FunctionType *FunTy) { Type *Int32Ty = Type::getInt32Ty(FunTy->getContext()); FunctionType *VoidFnTy = FunctionType::get( Type::getVoidTy(FunTy->getContext()), EmptyParams, false); - switch (JTType) { + switch (JTT) { case JumpTable::Single: return FunctionType::get(RetTy, EmptyParams, false); @@ -253,10 +254,10 @@ FunctionType *JumpInstrTables::transformType(FunctionType *FunTy) { bool JumpInstrTables::runOnModule(Module &M) { JITI = &getAnalysis<JumpInstrTableInfo>(); - // Get the set of jumptable-annotated functions. + // Get the set of jumptable-annotated functions that have their address taken. DenseMap<Function *, Function *> Functions; for (Function &F : M) { - if (F.hasFnAttribute(Attribute::JumpTable)) { + if (F.hasFnAttribute(Attribute::JumpTable) && F.hasAddressTaken()) { assert(F.hasUnnamedAddr() && "Attribute 'jumptable' requires 'unnamed_addr'"); Functions[&F] = nullptr; diff --git a/llvm/lib/CodeGen/LLVMTargetMachine.cpp b/llvm/lib/CodeGen/LLVMTargetMachine.cpp index 8afdf5dc467..61face27f14 100644 --- a/llvm/lib/CodeGen/LLVMTargetMachine.cpp +++ b/llvm/lib/CodeGen/LLVMTargetMachine.cpp @@ -13,8 +13,10 @@ #include "llvm/Target/TargetMachine.h" +#include "llvm/Analysis/JumpInstrTableInfo.h" #include "llvm/Analysis/Passes.h" #include "llvm/CodeGen/AsmPrinter.h" +#include "llvm/CodeGen/ForwardControlFlowIntegrity.h" #include "llvm/CodeGen/JumpInstrTables.h" #include "llvm/CodeGen/MachineFunctionAnalysis.h" #include "llvm/CodeGen/MachineModuleInfo.h" @@ -143,8 +145,13 @@ bool LLVMTargetMachine::addPassesToEmitFile(PassManagerBase &PM, AnalysisID StopAfter) { // Passes to handle jumptable function annotations. These can't be handled at // JIT time, so we don't add them directly to addPassesToGenerateCode. - PM.add(createJumpInstrTableInfoPass()); + PM.add(createJumpInstrTableInfoPass( + getSubtargetImpl()->getInstrInfo()->getJumpInstrTableEntryBound())); PM.add(createJumpInstrTablesPass(Options.JTType)); + if (Options.FCFI) + PM.add(createForwardControlFlowIntegrityPass( + Options.JTType, Options.CFIType, Options.CFIEnforcing, + Options.getCFIFuncName())); // Add common CodeGen passes. MCContext *Context = addPassesToGenerateCode(this, PM, DisableVerify, diff --git a/llvm/lib/CodeGen/TargetOptionsImpl.cpp b/llvm/lib/CodeGen/TargetOptionsImpl.cpp index 3ca2017550c..618d903a090 100644 --- a/llvm/lib/CodeGen/TargetOptionsImpl.cpp +++ b/llvm/lib/CodeGen/TargetOptionsImpl.cpp @@ -51,3 +51,10 @@ bool TargetOptions::HonorSignDependentRoundingFPMath() const { StringRef TargetOptions::getTrapFunctionName() const { return TrapFuncName; } + +/// getCFIFuncName - If this returns a non-empty string, then it is the name of +/// the function that gets called on CFI violations in CFI non-enforcing mode +/// (!TargetOptions::CFIEnforcing). +StringRef TargetOptions::getCFIFuncName() const { + return CFIFuncName; +} |