summaryrefslogtreecommitdiffstats
path: root/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp
diff options
context:
space:
mode:
authorTom Roeder <tmroeder@google.com>2014-11-11 21:08:02 +0000
committerTom Roeder <tmroeder@google.com>2014-11-11 21:08:02 +0000
commiteb7a303d1beb57484d8e559801552fd9745a0d78 (patch)
tree2e4e605f3d054b9b1cb7d363d073410581729127 /llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp
parenteb4675fb29bd689a1ecd5709bbd39d8ae2426feb (diff)
downloadbcm5719-llvm-eb7a303d1beb57484d8e559801552fd9745a0d78.tar.gz
bcm5719-llvm-eb7a303d1beb57484d8e559801552fd9745a0d78.zip
Add Forward Control-Flow Integrity.
This commit adds a new pass that can inject checks before indirect calls to make sure that these calls target known locations. It supports three types of checks and, at compile time, it can take the name of a custom function to call when an indirect call check fails. The default failure function ignores the error and continues. This pass incidentally moves the function JumpInstrTables::transformType from private to public and makes it static (with a new argument that specifies the table type to use); this is so that the CFI code can transform function types at call sites to determine which jump-instruction table to use for the check at that site. Also, this removes support for jumptables in ARM, pending further performance analysis and discussion. Review: http://reviews.llvm.org/D4167 llvm-svn: 221708
Diffstat (limited to 'llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp')
-rw-r--r--llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp375
1 files changed, 375 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp b/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp
new file mode 100644
index 00000000000..679faeffefc
--- /dev/null
+++ b/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp
@@ -0,0 +1,375 @@
+//===-- ForwardControlFlowIntegrity.cpp: Forward-Edge CFI -----------------===//
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// \brief A pass that instruments code with fast checks for indirect calls and
+/// hooks for a function to check violations.
+///
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "cfi"
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/JumpInstrTableInfo.h"
+#include "llvm/CodeGen/ForwardControlFlowIntegrity.h"
+#include "llvm/CodeGen/JumpInstrTables.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+STATISTIC(NumCFIIndirectCalls,
+ "Number of indirect call sites rewritten by the CFI pass");
+
+char ForwardControlFlowIntegrity::ID = 0;
+INITIALIZE_PASS_BEGIN(ForwardControlFlowIntegrity, "forward-cfi",
+ "Control-Flow Integrity", true, true)
+INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo);
+INITIALIZE_PASS_DEPENDENCY(JumpInstrTables);
+INITIALIZE_PASS_END(ForwardControlFlowIntegrity, "forward-cfi",
+ "Control-Flow Integrity", true, true)
+
+ModulePass *llvm::createForwardControlFlowIntegrityPass() {
+ return new ForwardControlFlowIntegrity();
+}
+
+ModulePass *llvm::createForwardControlFlowIntegrityPass(
+ JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
+ StringRef CFIFuncName) {
+ return new ForwardControlFlowIntegrity(JTT, CFIType, CFIEnforcing,
+ CFIFuncName);
+}
+
+// Checks to see if a given CallSite is making an indirect call, including
+// cases where the indirect call is made through a bitcast.
+static bool isIndirectCall(CallSite &CS) {
+ if (CS.getCalledFunction())
+ return false;
+
+ // Check the value to see if it is merely a bitcast of a function. In
+ // this case, it will translate to a direct function call in the resulting
+ // assembly, so we won't treat it as an indirect call here.
+ const Value *V = CS.getCalledValue();
+ if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
+ return !(CE->isCast() && isa<Function>(CE->getOperand(0)));
+ }
+
+ // Otherwise, since we know it's a call, it must be an indirect call
+ return true;
+}
+
+static const char cfi_failure_func_name[] = "__llvm_cfi_pointer_warning";
+static const char cfi_func_name_prefix[] = "__llvm_cfi_function_";
+
+ForwardControlFlowIntegrity::ForwardControlFlowIntegrity()
+ : ModulePass(ID), IndirectCalls(), JTType(JumpTable::Single),
+ CFIType(CFIntegrity::Sub), CFIEnforcing(false), CFIFuncName("") {
+ initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
+}
+
+ForwardControlFlowIntegrity::ForwardControlFlowIntegrity(
+ JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
+ std::string CFIFuncName)
+ : ModulePass(ID), IndirectCalls(), JTType(JTT), CFIType(CFIType),
+ CFIEnforcing(CFIEnforcing), CFIFuncName(CFIFuncName) {
+ initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
+}
+
+ForwardControlFlowIntegrity::~ForwardControlFlowIntegrity() {}
+
+void ForwardControlFlowIntegrity::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addRequired<JumpInstrTableInfo>();
+ AU.addRequired<JumpInstrTables>();
+}
+
+void ForwardControlFlowIntegrity::getIndirectCalls(Module &M) {
+ // To get the indirect calls, we iterate over all functions and iterate over
+ // the list of basic blocks in each. We extract a total list of indirect calls
+ // before modifying any of them, since our modifications will modify the list
+ // of basic blocks.
+ for (Function &F : M) {
+ for (BasicBlock &BB : F) {
+ for (Instruction &I : BB) {
+ CallSite CS(&I);
+ if (!(CS && isIndirectCall(CS)))
+ continue;
+
+ Value *CalledValue = CS.getCalledValue();
+
+ // Don't rewrite this instruction if the indirect call is actually just
+ // inline assembly, since our transformation will generate an invalid
+ // module in that case.
+ if (isa<InlineAsm>(CalledValue))
+ continue;
+
+ IndirectCalls.push_back(&I);
+ }
+ }
+ }
+}
+
+void ForwardControlFlowIntegrity::updateIndirectCalls(Module &M,
+ CFITables &CFIT) {
+ Type *Int64Ty = Type::getInt64Ty(M.getContext());
+ for (Instruction *I : IndirectCalls) {
+ CallSite CS(I);
+ Value *CalledValue = CS.getCalledValue();
+
+ // Get the function type for this call and look it up in the tables.
+ Type *VTy = CalledValue->getType();
+ PointerType *PTy = dyn_cast<PointerType>(VTy);
+ Type *EltTy = PTy->getElementType();
+ FunctionType *FunTy = dyn_cast<FunctionType>(EltTy);
+ FunctionType *TransformedTy = JumpInstrTables::transformType(JTType, FunTy);
+ ++NumCFIIndirectCalls;
+ Constant *JumpTableStart = nullptr;
+ Constant *JumpTableMask = nullptr;
+ Constant *JumpTableSize = nullptr;
+
+ // Some call sites have function types that don't correspond to any
+ // address-taken function in the module. This happens when function pointers
+ // are passed in from external code.
+ auto it = CFIT.find(TransformedTy);
+ if (it == CFIT.end()) {
+ // In this case, make sure that the function pointer will change by
+ // setting the mask and the start to be 0 so that the transformed
+ // function is 0.
+ JumpTableStart = ConstantInt::get(Int64Ty, 0);
+ JumpTableMask = ConstantInt::get(Int64Ty, 0);
+ JumpTableSize = ConstantInt::get(Int64Ty, 0);
+ } else {
+ JumpTableStart = it->second.StartValue;
+ JumpTableMask = it->second.MaskValue;
+ JumpTableSize = it->second.Size;
+ }
+
+ rewriteFunctionPointer(M, I, CalledValue, JumpTableStart, JumpTableMask,
+ JumpTableSize);
+ }
+
+ return;
+}
+
+bool ForwardControlFlowIntegrity::runOnModule(Module &M) {
+ JumpInstrTableInfo *JITI = &getAnalysis<JumpInstrTableInfo>();
+ Type *Int64Ty = Type::getInt64Ty(M.getContext());
+ Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());
+
+ // JumpInstrTableInfo stores information about the alignment of each entry.
+ // The alignment returned by JumpInstrTableInfo is alignment in bytes, not
+ // in the exponent.
+ ByteAlignment = JITI->entryByteAlignment();
+ LogByteAlignment = llvm::Log2_64(ByteAlignment);
+
+ // Set up tables for control-flow integrity based on information about the
+ // jump-instruction tables.
+ CFITables CFIT;
+ for (const auto &KV : JITI->getTables()) {
+ uint64_t Size = static_cast<uint64_t>(KV.second.size());
+ uint64_t TableSize = NextPowerOf2(Size);
+
+ int64_t MaskValue = ((TableSize << LogByteAlignment) - 1) & -ByteAlignment;
+ Constant *JumpTableMaskValue = ConstantInt::get(Int64Ty, MaskValue);
+ Constant *JumpTableSize = ConstantInt::get(Int64Ty, Size);
+
+ // The base of the table is defined to be the first jumptable function in
+ // the table.
+ Function *First = KV.second.begin()->second;
+ Constant *JumpTableStartValue = ConstantExpr::getBitCast(First, VoidPtrTy);
+ CFIT[KV.first].StartValue = JumpTableStartValue;
+ CFIT[KV.first].MaskValue = JumpTableMaskValue;
+ CFIT[KV.first].Size = JumpTableSize;
+ }
+
+ if (CFIT.empty())
+ return false;
+
+ getIndirectCalls(M);
+
+ if (!CFIEnforcing) {
+ addWarningFunction(M);
+ }
+
+ // Update the instructions with the check and the indirect jump through our
+ // table.
+ updateIndirectCalls(M, CFIT);
+
+ return true;
+}
+
+void ForwardControlFlowIntegrity::addWarningFunction(Module &M) {
+ PointerType *CharPtrTy = Type::getInt8PtrTy(M.getContext());
+
+ // Get the type of the Warning Function: void (i8*, i8*),
+ // where the first argument is the name of the function in which the violation
+ // occurs, and the second is the function pointer that violates CFI.
+ SmallVector<Type *, 2> WarningFunArgs;
+ WarningFunArgs.push_back(CharPtrTy);
+ WarningFunArgs.push_back(CharPtrTy);
+ FunctionType *WarningFunTy =
+ FunctionType::get(Type::getVoidTy(M.getContext()), WarningFunArgs, false);
+
+ if (!CFIFuncName.empty()) {
+ Constant *FailureFun = M.getOrInsertFunction(CFIFuncName, WarningFunTy);
+ if (!FailureFun)
+ report_fatal_error("Could not get or insert the function specified by"
+ " -cfi-func-name");
+ } else {
+ // The default warning function swallows the warning and lets the call
+ // continue, since there's no generic way for it to print out this
+ // information.
+ Function *WarningFun = M.getFunction(cfi_failure_func_name);
+ if (!WarningFun) {
+ WarningFun =
+ Function::Create(WarningFunTy, GlobalValue::LinkOnceAnyLinkage,
+ cfi_failure_func_name, &M);
+ }
+
+ BasicBlock *Entry =
+ BasicBlock::Create(M.getContext(), "entry", WarningFun, 0);
+ ReturnInst::Create(M.getContext(), Entry);
+ }
+}
+
+void ForwardControlFlowIntegrity::rewriteFunctionPointer(
+ Module &M, Instruction *I, Value *FunPtr, Constant *JumpTableStart,
+ Constant *JumpTableMask, Constant *JumpTableSize) {
+ IRBuilder<> TempBuilder(I);
+
+ Type *OrigFunType = FunPtr->getType();
+
+ BasicBlock *CurBB = cast<BasicBlock>(I->getParent());
+ Function *CurF = cast<Function>(CurBB->getParent());
+ Type *Int64Ty = Type::getInt64Ty(M.getContext());
+
+ Value *TI = TempBuilder.CreatePtrToInt(FunPtr, Int64Ty);
+ Value *TStartInt = TempBuilder.CreatePtrToInt(JumpTableStart, Int64Ty);
+
+ Value *NewFunPtr = nullptr;
+ Value *Check = nullptr;
+ switch (CFIType) {
+ case CFIntegrity::Sub: {
+ // This is the subtract, mask, and add version.
+ // Subtract from the base.
+ Value *Sub = TempBuilder.CreateSub(TI, TStartInt);
+
+ // Mask the difference to force this to be a table offset.
+ Value *And = TempBuilder.CreateAnd(Sub, JumpTableMask);
+
+ // Add it back to the base.
+ Value *Result = TempBuilder.CreateAdd(And, TStartInt);
+
+ // Convert it back into a function pointer that we can call.
+ NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
+ break;
+ }
+ case CFIntegrity::Ror: {
+ // This is the subtract and rotate version.
+ // Rotate right by the alignment value. The optimizer should recognize
+ // this sequence as a rotation.
+
+ // This cast is safe, since unsigned is always a subset of uint64_t.
+ uint64_t LogByteAlignment64 = static_cast<uint64_t>(LogByteAlignment);
+ Constant *RightShift = ConstantInt::get(Int64Ty, LogByteAlignment64);
+ Constant *LeftShift = ConstantInt::get(Int64Ty, 64 - LogByteAlignment64);
+
+ // Subtract from the base.
+ Value *Sub = TempBuilder.CreateSub(TI, TStartInt);
+
+ // Create the equivalent of a rotate-right instruction.
+ Value *Shr = TempBuilder.CreateLShr(Sub, RightShift);
+ Value *Shl = TempBuilder.CreateShl(Sub, LeftShift);
+ Value *Or = TempBuilder.CreateOr(Shr, Shl);
+
+ // Perform unsigned comparison to check for inclusion in the table.
+ Check = TempBuilder.CreateICmpULT(Or, JumpTableSize);
+ NewFunPtr = FunPtr;
+ break;
+ }
+ case CFIntegrity::Add: {
+ // This is the mask and add version.
+ // Mask the function pointer to turn it into an offset into the table.
+ Value *And = TempBuilder.CreateAnd(TI, JumpTableMask);
+
+ // Then or this offset to the base and get the pointer value.
+ Value *Result = TempBuilder.CreateAdd(And, TStartInt);
+
+ // Convert it back into a function pointer that we can call.
+ NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
+ break;
+ }
+ }
+
+ if (!CFIEnforcing) {
+ // If a check hasn't been added (in the rotation version), then check to see
+ // if it's the same as the original function. This check determines whether
+ // or not we call the CFI failure function.
+ if (!Check)
+ Check = TempBuilder.CreateICmpEQ(NewFunPtr, FunPtr);
+ BasicBlock *InvalidPtrBlock =
+ BasicBlock::Create(M.getContext(), "invalid.ptr", CurF, 0);
+ BasicBlock *ContinuationBB = CurBB->splitBasicBlock(I);
+
+ // Remove the unconditional branch that connects the two blocks.
+ TerminatorInst *TermInst = CurBB->getTerminator();
+ TermInst->eraseFromParent();
+
+ // Add a conditional branch that depends on the Check above.
+ BranchInst::Create(ContinuationBB, InvalidPtrBlock, Check, CurBB);
+
+ // Call the warning function for this pointer, then continue.
+ Instruction *BI = BranchInst::Create(ContinuationBB, InvalidPtrBlock);
+ insertWarning(M, InvalidPtrBlock, BI, FunPtr);
+ } else {
+ // Modify the instruction to call this value.
+ CallSite CS(I);
+ CS.setCalledFunction(NewFunPtr);
+ }
+}
+
+void ForwardControlFlowIntegrity::insertWarning(Module &M, BasicBlock *Block,
+ Instruction *I, Value *FunPtr) {
+ Function *ParentFun = cast<Function>(Block->getParent());
+
+ // Get the function to call right before the instruction.
+ Function *WarningFun = nullptr;
+ if (CFIFuncName.empty()) {
+ WarningFun = M.getFunction(cfi_failure_func_name);
+ } else {
+ WarningFun = M.getFunction(CFIFuncName);
+ }
+
+ assert(WarningFun && "Could not find the CFI failure function");
+
+ Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());
+
+ IRBuilder<> WarningInserter(I);
+ // Create a mergeable GlobalVariable containing the name of the function.
+ Value *ParentNameGV =
+ WarningInserter.CreateGlobalString(ParentFun->getName());
+ Value *ParentNamePtr = WarningInserter.CreateBitCast(ParentNameGV, VoidPtrTy);
+ Value *FunVoidPtr = WarningInserter.CreateBitCast(FunPtr, VoidPtrTy);
+ WarningInserter.CreateCall2(WarningFun, ParentNamePtr, FunVoidPtr);
+}
OpenPOWER on IntegriCloud