summaryrefslogtreecommitdiffstats
path: root/llvm/lib
diff options
context:
space:
mode:
authorTom Roeder <tmroeder@google.com>2014-11-11 21:08:02 +0000
committerTom Roeder <tmroeder@google.com>2014-11-11 21:08:02 +0000
commiteb7a303d1beb57484d8e559801552fd9745a0d78 (patch)
tree2e4e605f3d054b9b1cb7d363d073410581729127 /llvm/lib
parenteb4675fb29bd689a1ecd5709bbd39d8ae2426feb (diff)
downloadbcm5719-llvm-eb7a303d1beb57484d8e559801552fd9745a0d78.tar.gz
bcm5719-llvm-eb7a303d1beb57484d8e559801552fd9745a0d78.zip
Add Forward Control-Flow Integrity.
This commit adds a new pass that can inject checks before indirect calls to make sure that these calls target known locations. It supports three types of checks and, at compile time, it can take the name of a custom function to call when an indirect call check fails. The default failure function ignores the error and continues. This pass incidentally moves the function JumpInstrTables::transformType from private to public and makes it static (with a new argument that specifies the table type to use); this is so that the CFI code can transform function types at call sites to determine which jump-instruction table to use for the check at that site. Also, this removes support for jumptables in ARM, pending further performance analysis and discussion. Review: http://reviews.llvm.org/D4167 llvm-svn: 221708
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Analysis/JumpInstrTableInfo.cpp17
-rw-r--r--llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp12
-rw-r--r--llvm/lib/CodeGen/CMakeLists.txt1
-rw-r--r--llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp375
-rw-r--r--llvm/lib/CodeGen/JumpInstrTables.cpp13
-rw-r--r--llvm/lib/CodeGen/LLVMTargetMachine.cpp9
-rw-r--r--llvm/lib/CodeGen/TargetOptionsImpl.cpp7
-rw-r--r--llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp23
-rw-r--r--llvm/lib/Target/ARM/ARMBaseInstrInfo.h6
-rw-r--r--llvm/lib/Target/X86/X86InstrInfo.cpp16
-rw-r--r--llvm/lib/Target/X86/X86InstrInfo.h2
11 files changed, 438 insertions, 43 deletions
diff --git a/llvm/lib/Analysis/JumpInstrTableInfo.cpp b/llvm/lib/Analysis/JumpInstrTableInfo.cpp
index b5b426533ff..7aae2a5592e 100644
--- a/llvm/lib/Analysis/JumpInstrTableInfo.cpp
+++ b/llvm/lib/Analysis/JumpInstrTableInfo.cpp
@@ -17,6 +17,7 @@
#include "llvm/Analysis/Passes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
+#include "llvm/Support/MathExtras.h"
using namespace llvm;
@@ -28,7 +29,21 @@ ImmutablePass *llvm::createJumpInstrTableInfoPass() {
return new JumpInstrTableInfo();
}
-JumpInstrTableInfo::JumpInstrTableInfo() : ImmutablePass(ID), Tables() {
+ModulePass *llvm::createJumpInstrTableInfoPass(unsigned Bound) {
+ // This cast is always safe, since Bound is always in a subset of uint64_t.
+ uint64_t B = static_cast<uint64_t>(Bound);
+ return new JumpInstrTableInfo(B);
+}
+
+JumpInstrTableInfo::JumpInstrTableInfo(uint64_t ByteAlign)
+ : ImmutablePass(ID), Tables(), ByteAlignment(ByteAlign) {
+ if (!llvm::isPowerOf2_64(ByteAlign)) {
+ // Note that we don't explicitly handle overflow here, since we handle the 0
+ // case explicitly when a caller actually tries to create jumptable entries,
+ // and this is the return value on overflow.
+ ByteAlignment = llvm::NextPowerOf2(ByteAlign);
+ }
+
initializeJumpInstrTableInfoPass(*PassRegistry::getPassRegistry());
}
diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index 61f1d1b653e..087836901a7 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -879,16 +879,17 @@ bool AsmPrinter::doFinalization(Module &M) {
bool IsThumb = (Arch == Triple::thumb || Arch == Triple::thumbeb);
MCInst TrapInst;
TM.getSubtargetImpl()->getInstrInfo()->getTrap(TrapInst);
+ unsigned LogAlignment = llvm::Log2_64(JITI->entryByteAlignment());
+
+ // Emit the right section for these functions.
+ OutStreamer.SwitchSection(OutContext.getObjectFileInfo()->getTextSection());
for (const auto &KV : JITI->getTables()) {
uint64_t Count = 0;
for (const auto &FunPair : KV.second) {
// Emit the function labels to make this be a function entry point.
MCSymbol *FunSym =
OutContext.GetOrCreateSymbol(FunPair.second->getName());
- OutStreamer.EmitSymbolAttribute(FunSym, MCSA_Global);
- // FIXME: JumpTableInstrInfo should store information about the required
- // alignment of table entries and the size of the padding instruction.
- EmitAlignment(3);
+ EmitAlignment(LogAlignment);
if (IsThumb)
OutStreamer.EmitThumbFunc(FunSym);
if (MAI->hasDotTypeDotSizeDirective())
@@ -910,10 +911,9 @@ bool AsmPrinter::doFinalization(Module &M) {
}
// Emit enough padding instructions to fill up to the next power of two.
- // This assumes that the trap instruction takes 8 bytes or fewer.
uint64_t Remaining = NextPowerOf2(Count) - Count;
for (uint64_t C = 0; C < Remaining; ++C) {
- EmitAlignment(3);
+ EmitAlignment(LogAlignment);
OutStreamer.EmitInstruction(TrapInst, getSubtargetInfo());
}
diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt
index 0c84a90d4a8..092346bacd7 100644
--- a/llvm/lib/CodeGen/CMakeLists.txt
+++ b/llvm/lib/CodeGen/CMakeLists.txt
@@ -19,6 +19,7 @@ add_llvm_library(LLVMCodeGen
ExecutionDepsFix.cpp
ExpandISelPseudos.cpp
ExpandPostRAPseudos.cpp
+ ForwardControlFlowIntegrity.cpp
GCMetadata.cpp
GCMetadataPrinter.cpp
GCStrategy.cpp
diff --git a/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp b/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp
new file mode 100644
index 00000000000..679faeffefc
--- /dev/null
+++ b/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp
@@ -0,0 +1,375 @@
+//===-- ForwardControlFlowIntegrity.cpp: Forward-Edge CFI -----------------===//
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// \brief A pass that instruments code with fast checks for indirect calls and
+/// hooks for a function to check violations.
+///
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "cfi"
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/JumpInstrTableInfo.h"
+#include "llvm/CodeGen/ForwardControlFlowIntegrity.h"
+#include "llvm/CodeGen/JumpInstrTables.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+STATISTIC(NumCFIIndirectCalls,
+ "Number of indirect call sites rewritten by the CFI pass");
+
+char ForwardControlFlowIntegrity::ID = 0;
+INITIALIZE_PASS_BEGIN(ForwardControlFlowIntegrity, "forward-cfi",
+ "Control-Flow Integrity", true, true)
+INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo);
+INITIALIZE_PASS_DEPENDENCY(JumpInstrTables);
+INITIALIZE_PASS_END(ForwardControlFlowIntegrity, "forward-cfi",
+ "Control-Flow Integrity", true, true)
+
+ModulePass *llvm::createForwardControlFlowIntegrityPass() {
+ return new ForwardControlFlowIntegrity();
+}
+
+ModulePass *llvm::createForwardControlFlowIntegrityPass(
+ JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
+ StringRef CFIFuncName) {
+ return new ForwardControlFlowIntegrity(JTT, CFIType, CFIEnforcing,
+ CFIFuncName);
+}
+
+// Checks to see if a given CallSite is making an indirect call, including
+// cases where the indirect call is made through a bitcast.
+static bool isIndirectCall(CallSite &CS) {
+ if (CS.getCalledFunction())
+ return false;
+
+ // Check the value to see if it is merely a bitcast of a function. In
+ // this case, it will translate to a direct function call in the resulting
+ // assembly, so we won't treat it as an indirect call here.
+ const Value *V = CS.getCalledValue();
+ if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
+ return !(CE->isCast() && isa<Function>(CE->getOperand(0)));
+ }
+
+ // Otherwise, since we know it's a call, it must be an indirect call
+ return true;
+}
+
+static const char cfi_failure_func_name[] = "__llvm_cfi_pointer_warning";
+static const char cfi_func_name_prefix[] = "__llvm_cfi_function_";
+
+ForwardControlFlowIntegrity::ForwardControlFlowIntegrity()
+ : ModulePass(ID), IndirectCalls(), JTType(JumpTable::Single),
+ CFIType(CFIntegrity::Sub), CFIEnforcing(false), CFIFuncName("") {
+ initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
+}
+
+ForwardControlFlowIntegrity::ForwardControlFlowIntegrity(
+ JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
+ std::string CFIFuncName)
+ : ModulePass(ID), IndirectCalls(), JTType(JTT), CFIType(CFIType),
+ CFIEnforcing(CFIEnforcing), CFIFuncName(CFIFuncName) {
+ initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
+}
+
+ForwardControlFlowIntegrity::~ForwardControlFlowIntegrity() {}
+
+void ForwardControlFlowIntegrity::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addRequired<JumpInstrTableInfo>();
+ AU.addRequired<JumpInstrTables>();
+}
+
+void ForwardControlFlowIntegrity::getIndirectCalls(Module &M) {
+ // To get the indirect calls, we iterate over all functions and iterate over
+ // the list of basic blocks in each. We extract a total list of indirect calls
+ // before modifying any of them, since our modifications will modify the list
+ // of basic blocks.
+ for (Function &F : M) {
+ for (BasicBlock &BB : F) {
+ for (Instruction &I : BB) {
+ CallSite CS(&I);
+ if (!(CS && isIndirectCall(CS)))
+ continue;
+
+ Value *CalledValue = CS.getCalledValue();
+
+ // Don't rewrite this instruction if the indirect call is actually just
+ // inline assembly, since our transformation will generate an invalid
+ // module in that case.
+ if (isa<InlineAsm>(CalledValue))
+ continue;
+
+ IndirectCalls.push_back(&I);
+ }
+ }
+ }
+}
+
+void ForwardControlFlowIntegrity::updateIndirectCalls(Module &M,
+ CFITables &CFIT) {
+ Type *Int64Ty = Type::getInt64Ty(M.getContext());
+ for (Instruction *I : IndirectCalls) {
+ CallSite CS(I);
+ Value *CalledValue = CS.getCalledValue();
+
+ // Get the function type for this call and look it up in the tables.
+ Type *VTy = CalledValue->getType();
+ PointerType *PTy = dyn_cast<PointerType>(VTy);
+ Type *EltTy = PTy->getElementType();
+ FunctionType *FunTy = dyn_cast<FunctionType>(EltTy);
+ FunctionType *TransformedTy = JumpInstrTables::transformType(JTType, FunTy);
+ ++NumCFIIndirectCalls;
+ Constant *JumpTableStart = nullptr;
+ Constant *JumpTableMask = nullptr;
+ Constant *JumpTableSize = nullptr;
+
+ // Some call sites have function types that don't correspond to any
+ // address-taken function in the module. This happens when function pointers
+ // are passed in from external code.
+ auto it = CFIT.find(TransformedTy);
+ if (it == CFIT.end()) {
+ // In this case, make sure that the function pointer will change by
+ // setting the mask and the start to be 0 so that the transformed
+ // function is 0.
+ JumpTableStart = ConstantInt::get(Int64Ty, 0);
+ JumpTableMask = ConstantInt::get(Int64Ty, 0);
+ JumpTableSize = ConstantInt::get(Int64Ty, 0);
+ } else {
+ JumpTableStart = it->second.StartValue;
+ JumpTableMask = it->second.MaskValue;
+ JumpTableSize = it->second.Size;
+ }
+
+ rewriteFunctionPointer(M, I, CalledValue, JumpTableStart, JumpTableMask,
+ JumpTableSize);
+ }
+
+ return;
+}
+
+bool ForwardControlFlowIntegrity::runOnModule(Module &M) {
+ JumpInstrTableInfo *JITI = &getAnalysis<JumpInstrTableInfo>();
+ Type *Int64Ty = Type::getInt64Ty(M.getContext());
+ Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());
+
+ // JumpInstrTableInfo stores information about the alignment of each entry.
+ // The alignment returned by JumpInstrTableInfo is alignment in bytes, not
+ // in the exponent.
+ ByteAlignment = JITI->entryByteAlignment();
+ LogByteAlignment = llvm::Log2_64(ByteAlignment);
+
+ // Set up tables for control-flow integrity based on information about the
+ // jump-instruction tables.
+ CFITables CFIT;
+ for (const auto &KV : JITI->getTables()) {
+ uint64_t Size = static_cast<uint64_t>(KV.second.size());
+ uint64_t TableSize = NextPowerOf2(Size);
+
+ int64_t MaskValue = ((TableSize << LogByteAlignment) - 1) & -ByteAlignment;
+ Constant *JumpTableMaskValue = ConstantInt::get(Int64Ty, MaskValue);
+ Constant *JumpTableSize = ConstantInt::get(Int64Ty, Size);
+
+ // The base of the table is defined to be the first jumptable function in
+ // the table.
+ Function *First = KV.second.begin()->second;
+ Constant *JumpTableStartValue = ConstantExpr::getBitCast(First, VoidPtrTy);
+ CFIT[KV.first].StartValue = JumpTableStartValue;
+ CFIT[KV.first].MaskValue = JumpTableMaskValue;
+ CFIT[KV.first].Size = JumpTableSize;
+ }
+
+ if (CFIT.empty())
+ return false;
+
+ getIndirectCalls(M);
+
+ if (!CFIEnforcing) {
+ addWarningFunction(M);
+ }
+
+ // Update the instructions with the check and the indirect jump through our
+ // table.
+ updateIndirectCalls(M, CFIT);
+
+ return true;
+}
+
+void ForwardControlFlowIntegrity::addWarningFunction(Module &M) {
+ PointerType *CharPtrTy = Type::getInt8PtrTy(M.getContext());
+
+ // Get the type of the Warning Function: void (i8*, i8*),
+ // where the first argument is the name of the function in which the violation
+ // occurs, and the second is the function pointer that violates CFI.
+ SmallVector<Type *, 2> WarningFunArgs;
+ WarningFunArgs.push_back(CharPtrTy);
+ WarningFunArgs.push_back(CharPtrTy);
+ FunctionType *WarningFunTy =
+ FunctionType::get(Type::getVoidTy(M.getContext()), WarningFunArgs, false);
+
+ if (!CFIFuncName.empty()) {
+ Constant *FailureFun = M.getOrInsertFunction(CFIFuncName, WarningFunTy);
+ if (!FailureFun)
+ report_fatal_error("Could not get or insert the function specified by"
+ " -cfi-func-name");
+ } else {
+ // The default warning function swallows the warning and lets the call
+ // continue, since there's no generic way for it to print out this
+ // information.
+ Function *WarningFun = M.getFunction(cfi_failure_func_name);
+ if (!WarningFun) {
+ WarningFun =
+ Function::Create(WarningFunTy, GlobalValue::LinkOnceAnyLinkage,
+ cfi_failure_func_name, &M);
+ }
+
+ BasicBlock *Entry =
+ BasicBlock::Create(M.getContext(), "entry", WarningFun, 0);
+ ReturnInst::Create(M.getContext(), Entry);
+ }
+}
+
+void ForwardControlFlowIntegrity::rewriteFunctionPointer(
+ Module &M, Instruction *I, Value *FunPtr, Constant *JumpTableStart,
+ Constant *JumpTableMask, Constant *JumpTableSize) {
+ IRBuilder<> TempBuilder(I);
+
+ Type *OrigFunType = FunPtr->getType();
+
+ BasicBlock *CurBB = cast<BasicBlock>(I->getParent());
+ Function *CurF = cast<Function>(CurBB->getParent());
+ Type *Int64Ty = Type::getInt64Ty(M.getContext());
+
+ Value *TI = TempBuilder.CreatePtrToInt(FunPtr, Int64Ty);
+ Value *TStartInt = TempBuilder.CreatePtrToInt(JumpTableStart, Int64Ty);
+
+ Value *NewFunPtr = nullptr;
+ Value *Check = nullptr;
+ switch (CFIType) {
+ case CFIntegrity::Sub: {
+ // This is the subtract, mask, and add version.
+ // Subtract from the base.
+ Value *Sub = TempBuilder.CreateSub(TI, TStartInt);
+
+ // Mask the difference to force this to be a table offset.
+ Value *And = TempBuilder.CreateAnd(Sub, JumpTableMask);
+
+ // Add it back to the base.
+ Value *Result = TempBuilder.CreateAdd(And, TStartInt);
+
+ // Convert it back into a function pointer that we can call.
+ NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
+ break;
+ }
+ case CFIntegrity::Ror: {
+ // This is the subtract and rotate version.
+ // Rotate right by the alignment value. The optimizer should recognize
+ // this sequence as a rotation.
+
+ // This cast is safe, since unsigned is always a subset of uint64_t.
+ uint64_t LogByteAlignment64 = static_cast<uint64_t>(LogByteAlignment);
+ Constant *RightShift = ConstantInt::get(Int64Ty, LogByteAlignment64);
+ Constant *LeftShift = ConstantInt::get(Int64Ty, 64 - LogByteAlignment64);
+
+ // Subtract from the base.
+ Value *Sub = TempBuilder.CreateSub(TI, TStartInt);
+
+ // Create the equivalent of a rotate-right instruction.
+ Value *Shr = TempBuilder.CreateLShr(Sub, RightShift);
+ Value *Shl = TempBuilder.CreateShl(Sub, LeftShift);
+ Value *Or = TempBuilder.CreateOr(Shr, Shl);
+
+ // Perform unsigned comparison to check for inclusion in the table.
+ Check = TempBuilder.CreateICmpULT(Or, JumpTableSize);
+ NewFunPtr = FunPtr;
+ break;
+ }
+ case CFIntegrity::Add: {
+ // This is the mask and add version.
+ // Mask the function pointer to turn it into an offset into the table.
+ Value *And = TempBuilder.CreateAnd(TI, JumpTableMask);
+
+ // Then or this offset to the base and get the pointer value.
+ Value *Result = TempBuilder.CreateAdd(And, TStartInt);
+
+ // Convert it back into a function pointer that we can call.
+ NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
+ break;
+ }
+ }
+
+ if (!CFIEnforcing) {
+ // If a check hasn't been added (in the rotation version), then check to see
+ // if it's the same as the original function. This check determines whether
+ // or not we call the CFI failure function.
+ if (!Check)
+ Check = TempBuilder.CreateICmpEQ(NewFunPtr, FunPtr);
+ BasicBlock *InvalidPtrBlock =
+ BasicBlock::Create(M.getContext(), "invalid.ptr", CurF, 0);
+ BasicBlock *ContinuationBB = CurBB->splitBasicBlock(I);
+
+ // Remove the unconditional branch that connects the two blocks.
+ TerminatorInst *TermInst = CurBB->getTerminator();
+ TermInst->eraseFromParent();
+
+ // Add a conditional branch that depends on the Check above.
+ BranchInst::Create(ContinuationBB, InvalidPtrBlock, Check, CurBB);
+
+ // Call the warning function for this pointer, then continue.
+ Instruction *BI = BranchInst::Create(ContinuationBB, InvalidPtrBlock);
+ insertWarning(M, InvalidPtrBlock, BI, FunPtr);
+ } else {
+ // Modify the instruction to call this value.
+ CallSite CS(I);
+ CS.setCalledFunction(NewFunPtr);
+ }
+}
+
+void ForwardControlFlowIntegrity::insertWarning(Module &M, BasicBlock *Block,
+ Instruction *I, Value *FunPtr) {
+ Function *ParentFun = cast<Function>(Block->getParent());
+
+ // Get the function to call right before the instruction.
+ Function *WarningFun = nullptr;
+ if (CFIFuncName.empty()) {
+ WarningFun = M.getFunction(cfi_failure_func_name);
+ } else {
+ WarningFun = M.getFunction(CFIFuncName);
+ }
+
+ assert(WarningFun && "Could not find the CFI failure function");
+
+ Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());
+
+ IRBuilder<> WarningInserter(I);
+ // Create a mergeable GlobalVariable containing the name of the function.
+ Value *ParentNameGV =
+ WarningInserter.CreateGlobalString(ParentFun->getName());
+ Value *ParentNamePtr = WarningInserter.CreateBitCast(ParentNameGV, VoidPtrTy);
+ Value *FunVoidPtr = WarningInserter.CreateBitCast(FunPtr, VoidPtrTy);
+ WarningInserter.CreateCall2(WarningFun, ParentNamePtr, FunVoidPtr);
+}
diff --git a/llvm/lib/CodeGen/JumpInstrTables.cpp b/llvm/lib/CodeGen/JumpInstrTables.cpp
index 750f71f6022..20f775c1245 100644
--- a/llvm/lib/CodeGen/JumpInstrTables.cpp
+++ b/llvm/lib/CodeGen/JumpInstrTables.cpp
@@ -163,7 +163,7 @@ void JumpInstrTables::getAnalysisUsage(AnalysisUsage &AU) const {
Function *JumpInstrTables::insertEntry(Module &M, Function *Target) {
FunctionType *OrigFunTy = Target->getFunctionType();
- FunctionType *FunTy = transformType(OrigFunTy);
+ FunctionType *FunTy = transformType(JTType, OrigFunTy);
JumpMap::iterator it = Metadata.find(FunTy);
if (Metadata.end() == it) {
@@ -191,11 +191,12 @@ Function *JumpInstrTables::insertEntry(Module &M, Function *Target) {
}
bool JumpInstrTables::hasTable(FunctionType *FunTy) {
- FunctionType *TransTy = transformType(FunTy);
+ FunctionType *TransTy = transformType(JTType, FunTy);
return Metadata.end() != Metadata.find(TransTy);
}
-FunctionType *JumpInstrTables::transformType(FunctionType *FunTy) {
+FunctionType *JumpInstrTables::transformType(JumpTable::JumpTableType JTT,
+ FunctionType *FunTy) {
// Returning nullptr forces all types into the same table, since all types map
// to the same type
Type *VoidPtrTy = Type::getInt8PtrTy(FunTy->getContext());
@@ -211,7 +212,7 @@ FunctionType *JumpInstrTables::transformType(FunctionType *FunTy) {
Type *Int32Ty = Type::getInt32Ty(FunTy->getContext());
FunctionType *VoidFnTy = FunctionType::get(
Type::getVoidTy(FunTy->getContext()), EmptyParams, false);
- switch (JTType) {
+ switch (JTT) {
case JumpTable::Single:
return FunctionType::get(RetTy, EmptyParams, false);
@@ -253,10 +254,10 @@ FunctionType *JumpInstrTables::transformType(FunctionType *FunTy) {
bool JumpInstrTables::runOnModule(Module &M) {
JITI = &getAnalysis<JumpInstrTableInfo>();
- // Get the set of jumptable-annotated functions.
+ // Get the set of jumptable-annotated functions that have their address taken.
DenseMap<Function *, Function *> Functions;
for (Function &F : M) {
- if (F.hasFnAttribute(Attribute::JumpTable)) {
+ if (F.hasFnAttribute(Attribute::JumpTable) && F.hasAddressTaken()) {
assert(F.hasUnnamedAddr() &&
"Attribute 'jumptable' requires 'unnamed_addr'");
Functions[&F] = nullptr;
diff --git a/llvm/lib/CodeGen/LLVMTargetMachine.cpp b/llvm/lib/CodeGen/LLVMTargetMachine.cpp
index 8afdf5dc467..61face27f14 100644
--- a/llvm/lib/CodeGen/LLVMTargetMachine.cpp
+++ b/llvm/lib/CodeGen/LLVMTargetMachine.cpp
@@ -13,8 +13,10 @@
#include "llvm/Target/TargetMachine.h"
+#include "llvm/Analysis/JumpInstrTableInfo.h"
#include "llvm/Analysis/Passes.h"
#include "llvm/CodeGen/AsmPrinter.h"
+#include "llvm/CodeGen/ForwardControlFlowIntegrity.h"
#include "llvm/CodeGen/JumpInstrTables.h"
#include "llvm/CodeGen/MachineFunctionAnalysis.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
@@ -143,8 +145,13 @@ bool LLVMTargetMachine::addPassesToEmitFile(PassManagerBase &PM,
AnalysisID StopAfter) {
// Passes to handle jumptable function annotations. These can't be handled at
// JIT time, so we don't add them directly to addPassesToGenerateCode.
- PM.add(createJumpInstrTableInfoPass());
+ PM.add(createJumpInstrTableInfoPass(
+ getSubtargetImpl()->getInstrInfo()->getJumpInstrTableEntryBound()));
PM.add(createJumpInstrTablesPass(Options.JTType));
+ if (Options.FCFI)
+ PM.add(createForwardControlFlowIntegrityPass(
+ Options.JTType, Options.CFIType, Options.CFIEnforcing,
+ Options.getCFIFuncName()));
// Add common CodeGen passes.
MCContext *Context = addPassesToGenerateCode(this, PM, DisableVerify,
diff --git a/llvm/lib/CodeGen/TargetOptionsImpl.cpp b/llvm/lib/CodeGen/TargetOptionsImpl.cpp
index 3ca2017550c..618d903a090 100644
--- a/llvm/lib/CodeGen/TargetOptionsImpl.cpp
+++ b/llvm/lib/CodeGen/TargetOptionsImpl.cpp
@@ -51,3 +51,10 @@ bool TargetOptions::HonorSignDependentRoundingFPMath() const {
StringRef TargetOptions::getTrapFunctionName() const {
return TrapFuncName;
}
+
+/// getCFIFuncName - If this returns a non-empty string, then it is the name of
+/// the function that gets called on CFI violations in CFI non-enforcing mode
+/// (!TargetOptions::CFIEnforcing).
+StringRef TargetOptions::getCFIFuncName() const {
+ return CFIFuncName;
+}
diff --git a/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp b/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
index fdb31c88cce..4ab05f910ff 100644
--- a/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
@@ -4489,29 +4489,6 @@ breakPartialRegDependency(MachineBasicBlock::iterator MI,
MI->addRegisterKilled(DReg, TRI, true);
}
-void ARMBaseInstrInfo::getUnconditionalBranch(
- MCInst &Branch, const MCSymbolRefExpr *BranchTarget) const {
- if (Subtarget.isThumb())
- Branch.setOpcode(ARM::tB);
- else if (Subtarget.isThumb2())
- Branch.setOpcode(ARM::t2B);
- else
- Branch.setOpcode(ARM::Bcc);
-
- Branch.addOperand(MCOperand::CreateExpr(BranchTarget));
- Branch.addOperand(MCOperand::CreateImm(ARMCC::AL));
- Branch.addOperand(MCOperand::CreateReg(0));
-}
-
-void ARMBaseInstrInfo::getTrap(MCInst &MI) const {
- if (Subtarget.isThumb())
- MI.setOpcode(ARM::tTRAP);
- else if (Subtarget.useNaClTrap())
- MI.setOpcode(ARM::TRAPNaCl);
- else
- MI.setOpcode(ARM::TRAP);
-}
-
bool ARMBaseInstrInfo::hasNOP() const {
return (Subtarget.getFeatureBits() & ARM::HasV6T2Ops) != 0;
}
diff --git a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
index ab5dc661faf..0ae291bccf9 100644
--- a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
+++ b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
@@ -289,12 +289,6 @@ public:
void breakPartialRegDependency(MachineBasicBlock::iterator, unsigned,
const TargetRegisterInfo *TRI) const override;
- void
- getUnconditionalBranch(MCInst &Branch,
- const MCSymbolRefExpr *BranchTarget) const override;
-
- void getTrap(MCInst &MI) const override;
-
/// Get the number of addresses by LDM or VLDM or zero for unknown.
unsigned getNumLDMAddresses(const MachineInstr *MI) const;
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
index 68c5ff44ca9..12de514a3d3 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -5477,16 +5477,32 @@ void X86InstrInfo::getNoopForMachoTarget(MCInst &NopInst) const {
NopInst.setOpcode(X86::NOOP);
}
+// This code must remain in sync with getJumpInstrTableEntryBound in this class!
+// In particular, getJumpInstrTableEntryBound must always return an upper bound
+// on the encoding lengths of the instructions generated by
+// getUnconditionalBranch and getTrap.
void X86InstrInfo::getUnconditionalBranch(
MCInst &Branch, const MCSymbolRefExpr *BranchTarget) const {
Branch.setOpcode(X86::JMP_4);
Branch.addOperand(MCOperand::CreateExpr(BranchTarget));
}
+// This code must remain in sync with getJumpInstrTableEntryBound in this class!
+// In particular, getJumpInstrTableEntryBound must always return an upper bound
+// on the encoding lengths of the instructions generated by
+// getUnconditionalBranch and getTrap.
void X86InstrInfo::getTrap(MCInst &MI) const {
MI.setOpcode(X86::TRAP);
}
+// See getTrap and getUnconditionalBranch for conditions on the value returned
+// by this function.
+unsigned X86InstrInfo::getJumpInstrTableEntryBound() const {
+ // 5 bytes suffice: JMP_4 Symbol@PLT is uses 1 byte (E9) for the JMP_4 and 4
+ // bytes for the symbol offset. And TRAP is ud2, which is two bytes (0F 0B).
+ return 5;
+}
+
bool X86InstrInfo::isHighLatencyDef(int opc) const {
switch (opc) {
default: return false;
diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h
index f3f54ae5664..57b19589545 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.h
+++ b/llvm/lib/Target/X86/X86InstrInfo.h
@@ -413,6 +413,8 @@ public:
void getTrap(MCInst &MI) const override;
+ unsigned getJumpInstrTableEntryBound() const override;
+
bool isHighLatencyDef(int opc) const override;
bool hasHighOperandLatency(const InstrItineraryData *ItinData,
OpenPOWER on IntegriCloud