diff options
-rw-r--r-- | llvm/include/llvm/Module.h | 9 | ||||
-rw-r--r-- | llvm/lib/CodeGen/StackProtector.cpp | 25 | ||||
-rw-r--r-- | llvm/lib/VMCore/Module.cpp | 22 |
3 files changed, 42 insertions, 14 deletions
diff --git a/llvm/include/llvm/Module.h b/llvm/include/llvm/Module.h index de121578750..09eba81b4eb 100644 --- a/llvm/include/llvm/Module.h +++ b/llvm/include/llvm/Module.h @@ -229,6 +229,15 @@ public: return getGlobalVariable(Name, true); } + /// getOrInsertGlobal - Look up the specified global in the module symbol + /// table. + /// 1. If it does not exist, add a declaration of the global and return it. + /// 2. Else, the global exists but has the wrong type: return the function + /// with a constantexpr cast to the right type. + /// 3. Finally, if the existing global is the correct delclaration, return + /// the existing global. + Constant *getOrInsertGlobal(const std::string &Name, const Type *Ty); + /// @} /// @name Global Alias Accessors /// @{ diff --git a/llvm/lib/CodeGen/StackProtector.cpp b/llvm/lib/CodeGen/StackProtector.cpp index 4bbb357bc55..30c3db5b526 100644 --- a/llvm/lib/CodeGen/StackProtector.cpp +++ b/llvm/lib/CodeGen/StackProtector.cpp @@ -52,7 +52,7 @@ namespace { AllocaInst *StackProtFrameSlot; /// StackGuardVar - The global variable for the stack guard. - GlobalVariable *StackGuardVar; + Constant *StackGuardVar; Function *F; Module *M; @@ -115,14 +115,8 @@ void StackProtector::InsertStackProtectorPrologue() { BasicBlock &Entry = F->getEntryBlock(); Instruction &InsertPt = Entry.front(); - const char *StackGuardStr = "__stack_chk_guard"; - StackGuardVar = M->getNamedGlobal(StackGuardStr); - - if (!StackGuardVar) - StackGuardVar = new GlobalVariable(PointerType::getUnqual(Type::Int8Ty), - false, GlobalValue::ExternalLinkage, - 0, StackGuardStr, M); - + StackGuardVar = M->getOrInsertGlobal("__stack_chk_guard", + PointerType::getUnqual(Type::Int8Ty)); StackProtFrameSlot = new AllocaInst(PointerType::getUnqual(Type::Int8Ty), "StackProt_Frame", &InsertPt); LoadInst *LI = new LoadInst(StackGuardVar, "StackGuard", false, &InsertPt); @@ -161,7 +155,7 @@ void StackProtector::InsertStackProtectorEpilogue() { // %3 = cmp i1 %1, %2 // br i1 %3, label %SPRet, label %CallStackCheckFailBlk // - // SPRet: + // SP_return: // ret ... // // CallStackCheckFailBlk: @@ -174,12 +168,15 @@ void StackProtector::InsertStackProtectorEpilogue() { ReturnInst *RI = cast<ReturnInst>(BB->getTerminator()); Function::iterator InsPt = BB; ++InsPt; // Insertion point for new BB. - BasicBlock *NewBB = BasicBlock::Create("SPRet", F, InsPt); + // Split the basic block before the return instruction. + BasicBlock *NewBB = BB->splitBasicBlock(RI, "SP_return"); - // Move the return instruction into the new basic block. - RI->removeFromParent(); - NewBB->getInstList().insert(NewBB->begin(), RI); + // Move the newly created basic block to the point right after the old basic + // block. + NewBB->removeFromParent(); + F->getBasicBlockList().insert(InsPt, NewBB); + // Generate the stack protector instructions in the old basic block. LoadInst *LI2 = new LoadInst(StackGuardVar, "", false, BB); LoadInst *LI1 = new LoadInst(StackProtFrameSlot, "", true, BB); ICmpInst *Cmp = new ICmpInst(CmpInst::ICMP_EQ, LI1, LI2, "", BB); diff --git a/llvm/lib/VMCore/Module.cpp b/llvm/lib/VMCore/Module.cpp index b95f6e3fa86..d4432c6ed95 100644 --- a/llvm/lib/VMCore/Module.cpp +++ b/llvm/lib/VMCore/Module.cpp @@ -224,6 +224,28 @@ GlobalVariable *Module::getGlobalVariable(const std::string &Name, return 0; } +Constant *Module::getOrInsertGlobal(const std::string &Name, const Type *Ty) { + ValueSymbolTable &SymTab = getValueSymbolTable(); + + // See if we have a definition for the specified global already. + GlobalVariable *GV = dyn_cast_or_null<GlobalVariable>(SymTab.lookup(Name)); + if (GV == 0) { + // Nope, add it + GlobalVariable *New = + new GlobalVariable(Ty, false, GlobalVariable::ExternalLinkage, 0, Name); + GlobalList.push_back(New); + return New; // Return the new declaration. + } + + // If the variable exists but has the wrong type, return a bitcast to the + // right type. + if (GV->getType() != PointerType::getUnqual(Ty)) + return ConstantExpr::getBitCast(GV, PointerType::getUnqual(Ty)); + + // Otherwise, we just found the existing function or a prototype. + return GV; +} + //===----------------------------------------------------------------------===// // Methods for easy access to the global variables in the module. // |