diff options
| -rw-r--r-- | llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp | 181 | 
1 files changed, 172 insertions, 9 deletions
diff --git a/llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp b/llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp index e6978100be9..4c1fe29c2a3 100644 --- a/llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp @@ -257,6 +257,21 @@ public:    }    /// @brief Return a Function* for the strlen libcall +  Function* get_strcpy() +  { +    if (!strcpy_func) +    { +      std::vector<const Type*> args; +      args.push_back(PointerType::get(Type::SByteTy)); +      args.push_back(PointerType::get(Type::SByteTy)); +      FunctionType* strcpy_type =  +        FunctionType::get(PointerType::get(Type::SByteTy), args, false); +      strcpy_func = M->getOrInsertFunction("strcpy",strcpy_type); +    } +    return strcpy_func; +  } + +  /// @brief Return a Function* for the strlen libcall    Function* get_strlen()    {      if (!strlen_func) @@ -295,8 +310,8 @@ public:        std::vector<const Type*> args;        args.push_back(PointerType::get(Type::SByteTy));        args.push_back(PointerType::get(Type::SByteTy)); -      args.push_back(Type::IntTy); -      args.push_back(Type::IntTy); +      args.push_back(Type::UIntTy); +      args.push_back(Type::UIntTy);        FunctionType* memcpy_type = FunctionType::get(Type::VoidTy, args, false);        memcpy_func = M->getOrInsertFunction("llvm.memcpy",memcpy_type);      } @@ -314,6 +329,7 @@ private:      memcpy_func = 0;      memchr_func = 0;      sqrt_func   = 0; +    strcpy_func = 0;      strlen_func = 0;    } @@ -323,6 +339,7 @@ private:    Function* memcpy_func; ///< Cached llvm.memcpy function    Function* memchr_func; ///< Cached memchr function    Function* sqrt_func;   ///< Cached sqrt function +  Function* strcpy_func; ///< Cached strcpy function    Function* strlen_func; ///< Cached strlen function    Module* M;             ///< Cached Module    TargetData* TD;        ///< Cached TargetData @@ -493,8 +510,8 @@ public:      std::vector<Value*> vals;      vals.push_back(gep); // destination      vals.push_back(ci->getOperand(2)); // source -    vals.push_back(ConstantSInt::get(Type::IntTy,len)); // length -    vals.push_back(ConstantSInt::get(Type::IntTy,1)); // alignment +    vals.push_back(ConstantUInt::get(Type::UIntTy,len)); // length +    vals.push_back(ConstantUInt::get(Type::UIntTy,1)); // alignment      new CallInst(SLC.get_memcpy(), vals, "", ci);      // Finally, substitute the first operand of the strcat call for the  @@ -862,8 +879,8 @@ public:      std::vector<Value*> vals;      vals.push_back(dest); // destination      vals.push_back(src); // source -    vals.push_back(ConstantSInt::get(Type::IntTy,len)); // length -    vals.push_back(ConstantSInt::get(Type::IntTy,1)); // alignment +    vals.push_back(ConstantUInt::get(Type::UIntTy,len)); // length +    vals.push_back(ConstantUInt::get(Type::UIntTy,1)); // alignment      new CallInst(SLC.get_memcpy(), vals, "", ci);      // Finally, substitute the first operand of the strcat call for the  @@ -1255,7 +1272,8 @@ public:        args.push_back(ConstantUInt::get(SLC.getIntPtrType(),len));        args.push_back(ConstantUInt::get(SLC.getIntPtrType(),1));        args.push_back(ci->getOperand(1)); -      new CallInst(fwrite_func,args,"",ci); +      new CallInst(fwrite_func,args,ci->getName(),ci); +      ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,len));        ci->eraseFromParent();        return true;      } @@ -1281,7 +1299,7 @@ public:          if (!getConstantStringLength(ci->getOperand(3), len, &CA))            return false; -        // fprintf(file,fmt) -> fwrite(fmt,strlen(fmt),1,file)  +        // fprintf(file,"%s",str) -> fwrite(fmt,strlen(fmt),1,file)           const Type* FILEptr_type = ci->getOperand(1)->getType();          Function* fwrite_func = SLC.get_fwrite(FILEptr_type);          if (!fwrite_func) @@ -1291,7 +1309,8 @@ public:          args.push_back(ConstantUInt::get(SLC.getIntPtrType(),len));          args.push_back(ConstantUInt::get(SLC.getIntPtrType(),1));          args.push_back(ci->getOperand(1)); -        new CallInst(fwrite_func,args,"",ci); +        new CallInst(fwrite_func,args,ci->getName(),ci); +        ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,len));          break;        }        case 'c': @@ -1306,6 +1325,7 @@ public:            return false;          CastInst* cast = new CastInst(CI,Type::IntTy,CI->getName()+".int",ci);          new CallInst(fputc_func,cast,ci->getOperand(1),"",ci); +        ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,1));          break;        }        default: @@ -1317,6 +1337,149 @@ public:  } FPrintFOptimizer; +/// This LibCallOptimization will simplify calls to the "sprintf" library  +/// function. It looks for cases where the result of sprintf is not used and the +/// operation can be reduced to something simpler. +/// @brief Simplify the pow library function. +struct SPrintFOptimization : public LibCallOptimization +{ +public: +  /// @brief Default Constructor +  SPrintFOptimization() : LibCallOptimization("sprintf", +      "simplify-libcalls:sprintf", "Number of 'sprintf' calls simplified") {} + +  /// @brief Destructor +  virtual ~SPrintFOptimization() {} + +  /// @brief Make sure that the "fprintf" function has the right prototype +  virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC) +  { +    // Just make sure this has at least 2 arguments +    return (f->getReturnType() == Type::IntTy && f->arg_size() >= 2); +  } + +  /// @brief Perform the sprintf optimization. +  virtual bool OptimizeCall(CallInst* ci, SimplifyLibCalls& SLC) +  { +    // If the call has more than 3 operands, we can't optimize it +    if (ci->getNumOperands() > 4 || ci->getNumOperands() < 3) +      return false; + +    // All the optimizations depend on the length of the second argument and the +    // fact that it is a constant string array. Check that now +    uint64_t len = 0;  +    ConstantArray* CA = 0; +    if (!getConstantStringLength(ci->getOperand(2), len, &CA)) +      return false; + +    if (ci->getNumOperands() == 3) +    { +      if (len == 0) +      { +        // If the length is 0, we just need to store a null byte +        new StoreInst(ConstantInt::get(Type::SByteTy,0),ci->getOperand(1),ci); +        ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,0)); +        ci->eraseFromParent(); +        return true; +      } + +      // Make sure there's no % in the constant array +      for (unsigned i = 0; i < len; ++i) +      { +        if (ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(i))) +        { +          // Check for the null terminator +          if (CI->getRawValue() == '%') +            return false; // we found a %, can't optimize +        } +        else  +          return false; // initializer is not constant int, can't optimize +      } + +      // Increment length because we want to copy the null byte too +      len++; + +      // sprintf(str,fmt) -> llvm.memcpy(str,fmt,strlen(fmt),1)  +      Function* memcpy_func = SLC.get_memcpy(); +      if (!memcpy_func) +        return false; +      std::vector<Value*> args; +      args.push_back(ci->getOperand(1)); +      args.push_back(ci->getOperand(2)); +      args.push_back(ConstantUInt::get(Type::UIntTy,len)); +      args.push_back(ConstantUInt::get(Type::UIntTy,1)); +      new CallInst(memcpy_func,args,"",ci); +      ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,len)); +      ci->eraseFromParent(); +      return true; +    } + +    // The remaining optimizations require the format string to be length 2 +    // "%s" or "%c". +    if (len != 2) +      return false; + +    // The first character has to be a % +    if (ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(0))) +      if (CI->getRawValue() != '%') +        return false; + +    // Get the second character and switch on its value +    ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(1)); +    switch (CI->getRawValue()) +    { +      case 's': +      { +        uint64_t len = 0; +        if (ci->hasNUses(0)) +        { +          // sprintf(dest,"%s",str) -> strcpy(dest,str)  +          Function* strcpy_func = SLC.get_strcpy(); +          if (!strcpy_func) +            return false; +          std::vector<Value*> args; +          args.push_back(ci->getOperand(1)); +          args.push_back(ci->getOperand(3)); +          new CallInst(strcpy_func,args,"",ci); +        } +        else if (getConstantStringLength(ci->getOperand(3),len)) +        { +          // sprintf(dest,"%s",cstr) -> llvm.memcpy(dest,str,strlen(str),1) +          len++; // get the null-terminator +          Function* memcpy_func = SLC.get_memcpy(); +          if (!memcpy_func) +            return false; +          std::vector<Value*> args; +          args.push_back(ci->getOperand(1)); +          args.push_back(ci->getOperand(3)); +          args.push_back(ConstantUInt::get(Type::UIntTy,len)); +          args.push_back(ConstantUInt::get(Type::UIntTy,1)); +          new CallInst(memcpy_func,args,"",ci); +          ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,len)); +        } +        break; +      } +      case 'c': +      { +        // sprintf(dest,"%c",chr) -> store chr, dest +        CastInst* cast =  +          new CastInst(ci->getOperand(3),Type::SByteTy,"char",ci); +        new StoreInst(cast, ci->getOperand(1), ci); +        GetElementPtrInst* gep = new GetElementPtrInst(ci->getOperand(1), +          ConstantUInt::get(Type::UIntTy,1),ci->getOperand(1)->getName()+".end", +          ci); +        new StoreInst(ConstantInt::get(Type::SByteTy,0),gep,ci); +        ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,1)); +        break; +      } +      default: +        return false; +    } +    ci->eraseFromParent(); +    return true; +  } +} SPrintFOptimizer; +  /// This LibCallOptimization will simplify calls to the "fputs" library   /// function. It looks for cases where the result of fputs is not used and the  /// operation can be reduced to something simpler.  | 

