diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/BoundsChecking.cpp | 115 | 
1 files changed, 84 insertions, 31 deletions
diff --git a/llvm/lib/Transforms/Scalar/BoundsChecking.cpp b/llvm/lib/Transforms/Scalar/BoundsChecking.cpp index c92ae2697ea..004f34f3bd6 100644 --- a/llvm/lib/Transforms/Scalar/BoundsChecking.cpp +++ b/llvm/lib/Transforms/Scalar/BoundsChecking.cpp @@ -19,6 +19,7 @@  #include "llvm/Support/Debug.h"  #include "llvm/Support/InstIterator.h"  #include "llvm/Support/IRBuilder.h" +#include "llvm/Support/raw_ostream.h"  #include "llvm/Support/TargetFolder.h"  #include "llvm/Target/TargetData.h"  #include "llvm/Transforms/Utils/Local.h" @@ -41,26 +42,29 @@ namespace {    };    struct BoundsChecking : public FunctionPass { -    const TargetData *TD; -    BuilderTy *Builder; -    Function *Fn; -    BasicBlock *TrapBB; -    unsigned Penalty;      static char ID;      BoundsChecking(unsigned _Penalty = 5) : FunctionPass(ID), Penalty(_Penalty){        initializeBoundsCheckingPass(*PassRegistry::getPassRegistry());      } -    BasicBlock *getTrapBB(); -    ConstTriState computeAllocSize(Value *Alloc, uint64_t &Size, Value* &SizeValue); -    bool instrument(Value *Ptr, Value *Val); -      virtual bool runOnFunction(Function &F);      virtual void getAnalysisUsage(AnalysisUsage &AU) const {        AU.addRequired<TargetData>();      } + +  private: +    const TargetData *TD; +    BuilderTy *Builder; +    Function *Fn; +    BasicBlock *TrapBB; +    unsigned Penalty; + +    BasicBlock *getTrapBB(); +    ConstTriState computeAllocSize(Value *Alloc, uint64_t &Size, +                                   Value* &SizeValue); +    bool instrument(Value *Ptr, Value *Val);   };  } @@ -126,29 +130,73 @@ ConstTriState BoundsChecking::computeAllocSize(Value *Alloc, uint64_t &Size,      SizeValue = Builder->CreateMul(SizeValue, ArraySize);      return NotConst; -  } else if (CallInst *MI = extractMallocCall(Alloc)) { -    SizeValue = MI->getArgOperand(0); -    if (ConstantInt *CI = dyn_cast<ConstantInt>(SizeValue)) { -      Size = CI->getZExtValue(); -      return Const; -    } -    return Penalty >= 2 ? NotConst : Dunno; - -  } else if (CallInst *MI = extractCallocCall(Alloc)) { -    Value *Arg1 = MI->getArgOperand(0); -    Value *Arg2 = MI->getArgOperand(1); -    if (ConstantInt *CI1 = dyn_cast<ConstantInt>(Arg1)) { -      if (ConstantInt *CI2 = dyn_cast<ConstantInt>(Arg2)) { -        Size = (CI1->getValue() * CI2->getValue()).getZExtValue(); -        return Const; +  } else if (CallInst *CI = dyn_cast<CallInst>(Alloc)) { +    Function *Callee = CI->getCalledFunction(); +    if (!Callee || !Callee->isDeclaration()) +      return Dunno; + +    FunctionType *FTy = Callee->getFunctionType(); +    if (FTy->getNumParams() == 1) { +      // alloc(size) +      if ((FTy->getParamType(0)->isIntegerTy(32) || +           FTy->getParamType(0)->isIntegerTy(64)) && +          (Callee->getName() == "malloc" || +           Callee->getName() == "valloc" || +           Callee->getName() == "_Znwj"  || // operator new(unsigned int) +           Callee->getName() == "_Znwm"  || // operator new(unsigned long) +           Callee->getName() == "_Znaj"  || // operator new[](unsigned int) +           Callee->getName() == "_Znam")) { // operator new[](unsigned long) +        SizeValue = CI->getArgOperand(0); +        if (ConstantInt *Arg = dyn_cast<ConstantInt>(SizeValue)) { +          Size = Arg->getZExtValue(); +          return Const; +        } +        return Penalty >= 2 ? NotConst : Dunno;        } +      return Dunno;      } -    if (Penalty < 2) -      return Dunno; +    if (FTy->getNumParams() == 2) { +      // alloc(x, y) and return buffer of size x * y +      if (((FTy->getParamType(0)->isIntegerTy(32) && +            FTy->getParamType(1)->isIntegerTy(32)) || +           (FTy->getParamType(0)->isIntegerTy(64) && +            FTy->getParamType(1)->isIntegerTy(64))) && +          Callee->getName() == "calloc") { +        Value *Arg1 = CI->getArgOperand(0); +        Value *Arg2 = CI->getArgOperand(1); +        if (ConstantInt *CI1 = dyn_cast<ConstantInt>(Arg1)) { +          if (ConstantInt *CI2 = dyn_cast<ConstantInt>(Arg2)) { +            Size = (CI1->getValue() * CI2->getValue()).getZExtValue(); +            return Const; +          } +        } + +        if (Penalty < 2) +          return Dunno; + +        SizeValue = Builder->CreateMul(Arg1, Arg2); +        return NotConst; +      } -    SizeValue = Builder->CreateMul(Arg1, Arg2); -    return NotConst; +      // realloc(ptr, size) +      if ((FTy->getParamType(1)->isIntegerTy(32) || +           FTy->getParamType(1)->isIntegerTy(64)) && +          (Callee->getName() == "realloc" || +           Callee->getName() == "reallocf")) { +        SizeValue = CI->getArgOperand(1); +        if (ConstantInt *Arg = dyn_cast<ConstantInt>(SizeValue)) { +          Size = Arg->getZExtValue(); +          return Const; +        } +        return Penalty >= 2 ? NotConst : Dunno; +      } +    } +    // TODO: handle more standard functions: +    // - strdup / strndup +    // - strcpy / strncpy +    // - memcpy / memmove +    // - strcat / strncat    }    DEBUG(dbgs() << "computeAllocSize failed:\n" << *Alloc); @@ -156,6 +204,11 @@ ConstTriState BoundsChecking::computeAllocSize(Value *Alloc, uint64_t &Size,  } +/// instrument - adds run-time bounds checks to memory accessing instructions. +/// Ptr is the pointer that will be read/written, and InstVal is either the +/// result from the load or the value being stored. It is used to determine the +/// size of memory block that is touched. +/// Returns true if any change was made to the IR, false otherwise.  bool BoundsChecking::instrument(Value *Ptr, Value *InstVal) {    uint64_t NeededSize = TD->getTypeStoreSize(InstVal->getType());    DEBUG(dbgs() << "Instrument " << *Ptr << " for " << Twine(NeededSize) @@ -255,9 +308,9 @@ bool BoundsChecking::runOnFunction(Function &F) {    }    bool MadeChange = false; -  while (!WorkList.empty()) { -    Instruction *I = WorkList.back(); -    WorkList.pop_back(); +  for (std::vector<Instruction*>::iterator i = WorkList.begin(), +       e = WorkList.end(); i != e; ++i) { +    Instruction *I = *i;      Builder->SetInsertPoint(I);      if (LoadInst *LI = dyn_cast<LoadInst>(I)) {  | 

