diff options
Diffstat (limited to 'llvm/lib/VMCore/ConstantFolding.cpp')
| -rw-r--r-- | llvm/lib/VMCore/ConstantFolding.cpp | 337 | 
1 files changed, 177 insertions, 160 deletions
| diff --git a/llvm/lib/VMCore/ConstantFolding.cpp b/llvm/lib/VMCore/ConstantFolding.cpp index ee74280c62c..f07d4c3708f 100644 --- a/llvm/lib/VMCore/ConstantFolding.cpp +++ b/llvm/lib/VMCore/ConstantFolding.cpp @@ -11,6 +11,11 @@  // (internal) ConstantFolding.h interface, which is used by the  // ConstantExpr::get* methods to automatically fold constants when possible.  // +// The current constant folding implementation is implemented in two pieces: the +// template-based folder for simple primitive constants like ConstantInt, and +// the special case hackery that we use to symbolically evaluate expressions +// that use ConstantExprs. +//  //===----------------------------------------------------------------------===//  #include "ConstantFolding.h" @@ -22,11 +27,6 @@  #include <cmath>  using namespace llvm; -static unsigned getSize(const Type *Ty) { -  unsigned S = Ty->getPrimitiveSize(); -  return S ? S : 8;  // Treat pointers at 8 bytes -} -  namespace {    struct ConstRules {      ConstRules() {} @@ -71,158 +71,6 @@ namespace {  } -Constant *llvm::ConstantFoldCastInstruction(const Constant *V, -                                            const Type *DestTy) { -  if (V->getType() == DestTy) return (Constant*)V; - -  if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) -    if (CE->getOpcode() == Instruction::Cast) { -      Constant *Op = const_cast<Constant*>(CE->getOperand(0)); -      // Try to not produce a cast of a cast, which is almost always redundant. -      if (!Op->getType()->isFloatingPoint() && -          !CE->getType()->isFloatingPoint() && -          !DestTy->getType()->isFloatingPoint()) { -        unsigned S1 = getSize(Op->getType()), S2 = getSize(CE->getType()); -        unsigned S3 = getSize(DestTy); -        if (Op->getType() == DestTy && S3 >= S2) -          return Op; -        if (S1 >= S2 && S2 >= S3) -          return ConstantExpr::getCast(Op, DestTy); -        if (S1 <= S2 && S2 >= S3 && S1 <= S3) -          return ConstantExpr::getCast(Op, DestTy); -      } -    } else if (CE->getOpcode() == Instruction::GetElementPtr) { -      // If all of the indexes in the GEP are null values, there is no pointer -      // adjustment going on.  We might as well cast the source pointer. -      bool isAllNull = true; -      for (unsigned i = 1, e = CE->getNumOperands(); i != e; ++i) -        if (!CE->getOperand(i)->isNullValue()) { -          isAllNull = false; -          break; -        } -      if (isAllNull) -        return ConstantExpr::getCast(CE->getOperand(0), DestTy); -    } - -  ConstRules &Rules = ConstRules::get(V, V); - -  switch (DestTy->getPrimitiveID()) { -  case Type::BoolTyID:    return Rules.castToBool(V); -  case Type::UByteTyID:   return Rules.castToUByte(V); -  case Type::SByteTyID:   return Rules.castToSByte(V); -  case Type::UShortTyID:  return Rules.castToUShort(V); -  case Type::ShortTyID:   return Rules.castToShort(V); -  case Type::UIntTyID:    return Rules.castToUInt(V); -  case Type::IntTyID:     return Rules.castToInt(V); -  case Type::ULongTyID:   return Rules.castToULong(V); -  case Type::LongTyID:    return Rules.castToLong(V); -  case Type::FloatTyID:   return Rules.castToFloat(V); -  case Type::DoubleTyID:  return Rules.castToDouble(V); -  case Type::PointerTyID: -    return Rules.castToPointer(V, cast<PointerType>(DestTy)); -  default: return 0; -  } -} - -Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, -                                              const Constant *V1, -                                              const Constant *V2) { -  Constant *C; -  switch (Opcode) { -  default:                   return 0; -  case Instruction::Add:     return ConstRules::get(V1, V2).add(V1, V2); -  case Instruction::Sub:     return ConstRules::get(V1, V2).sub(V1, V2); -  case Instruction::Mul:     return ConstRules::get(V1, V2).mul(V1, V2); -  case Instruction::Div:     return ConstRules::get(V1, V2).div(V1, V2); -  case Instruction::Rem:     return ConstRules::get(V1, V2).rem(V1, V2); -  case Instruction::And:     return ConstRules::get(V1, V2).op_and(V1, V2); -  case Instruction::Or:      return ConstRules::get(V1, V2).op_or (V1, V2); -  case Instruction::Xor:     return ConstRules::get(V1, V2).op_xor(V1, V2); - -  case Instruction::Shl:     return ConstRules::get(V1, V2).shl(V1, V2); -  case Instruction::Shr:     return ConstRules::get(V1, V2).shr(V1, V2); - -  case Instruction::SetEQ:   return ConstRules::get(V1, V2).equalto(V1, V2); -  case Instruction::SetLT:   return ConstRules::get(V1, V2).lessthan(V1, V2); -  case Instruction::SetGT:   return ConstRules::get(V1, V2).lessthan(V2, V1); -  case Instruction::SetNE:   // V1 != V2  ===  !(V1 == V2) -    C = ConstRules::get(V1, V2).equalto(V1, V2); -    break; -  case Instruction::SetLE:   // V1 <= V2  ===  !(V2 < V1) -    C = ConstRules::get(V1, V2).lessthan(V2, V1); -    break; -  case Instruction::SetGE:   // V1 >= V2  ===  !(V1 < V2) -    C = ConstRules::get(V1, V2).lessthan(V1, V2); -    break; -  } - -  // If the folder broke out of the switch statement, invert the boolean -  // constant value, if it exists, and return it. -  if (!C) return 0; -  return ConstantExpr::get(Instruction::Xor, ConstantBool::True, C); -} - -Constant *llvm::ConstantFoldGetElementPtr(const Constant *C, -                                        const std::vector<Constant*> &IdxList) { -  if (IdxList.size() == 0 || -      (IdxList.size() == 1 && IdxList[0]->isNullValue())) -    return const_cast<Constant*>(C); - -  // TODO If C is null and all idx's are null, return null of the right type. - - -  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(const_cast<Constant*>(C))) { -    // Combine Indices - If the source pointer to this getelementptr instruction -    // is a getelementptr instruction, combine the indices of the two -    // getelementptr instructions into a single instruction. -    // -    if (CE->getOpcode() == Instruction::GetElementPtr) { -      const Type *LastTy = 0; -      for (gep_type_iterator I = gep_type_begin(CE), E = gep_type_end(CE); -           I != E; ++I) -        LastTy = *I; - -      if ((LastTy && isa<ArrayType>(LastTy)) || IdxList[0]->isNullValue()) { -        std::vector<Constant*> NewIndices; -        NewIndices.reserve(IdxList.size() + CE->getNumOperands()); -        for (unsigned i = 1, e = CE->getNumOperands()-1; i != e; ++i) -          NewIndices.push_back(cast<Constant>(CE->getOperand(i))); - -        // Add the last index of the source with the first index of the new GEP. -        // Make sure to handle the case when they are actually different types. -        Constant *Combined = CE->getOperand(CE->getNumOperands()-1); -        if (!IdxList[0]->isNullValue())   // Otherwise it must be an array -          Combined =  -            ConstantExpr::get(Instruction::Add, -                              ConstantExpr::getCast(IdxList[0], Type::LongTy), -                              ConstantExpr::getCast(Combined, Type::LongTy)); -         -        NewIndices.push_back(Combined); -        NewIndices.insert(NewIndices.end(), IdxList.begin()+1, IdxList.end()); -        return ConstantExpr::getGetElementPtr(CE->getOperand(0), NewIndices); -      } -    } - -    // Implement folding of: -    //    int* getelementptr ([2 x int]* cast ([3 x int]* %X to [2 x int]*), -    //                        long 0, long 0) -    // To: int* getelementptr ([3 x int]* %X, long 0, long 0) -    // -    if (CE->getOpcode() == Instruction::Cast && IdxList.size() > 1 && -        IdxList[0]->isNullValue()) -      if (const PointerType *SPT =  -          dyn_cast<PointerType>(CE->getOperand(0)->getType())) -        if (const ArrayType *SAT = dyn_cast<ArrayType>(SPT->getElementType())) -          if (const ArrayType *CAT = -              dyn_cast<ArrayType>(cast<PointerType>(C->getType())->getElementType())) -            if (CAT->getElementType() == SAT->getElementType()) -              return ConstantExpr::getGetElementPtr( -                      (Constant*)CE->getOperand(0), IdxList); -  } -  return 0; -} - -  //===----------------------------------------------------------------------===//  //                             TemplateRules Class  //===----------------------------------------------------------------------===// @@ -604,9 +452,9 @@ struct DirectIntRules  //                           DirectFPRules Class  //===----------------------------------------------------------------------===//  // -// DirectFPRules provides implementations of functions that are valid on -// floating point types, but not all types in general. -// +/// DirectFPRules provides implementations of functions that are valid on +/// floating point types, but not all types in general. +///  template <class ConstantClass, class BuiltinType, Type **Ty>  struct DirectFPRules    : public DirectRules<ConstantClass, BuiltinType, Ty, @@ -619,6 +467,9 @@ struct DirectFPRules    }  }; + +/// ConstRules::get - This method returns the constant rules implementation that +/// implements the semantics of the two specified constants.  ConstRules &ConstRules::get(const Constant *V1, const Constant *V2) {    static EmptyRules       EmptyR;    static BoolRules        BoolR; @@ -654,3 +505,169 @@ ConstRules &ConstRules::get(const Constant *V1, const Constant *V2) {    case Type::DoubleTyID:  return DoubleR;    }  } + + +//===----------------------------------------------------------------------===// +//                ConstantFold*Instruction Implementations +//===----------------------------------------------------------------------===// +// +// These methods contain the special case hackery required to symbolically +// evaluate some constant expression cases, and use the ConstantRules class to +// evaluate normal constants. +// +static unsigned getSize(const Type *Ty) { +  unsigned S = Ty->getPrimitiveSize(); +  return S ? S : 8;  // Treat pointers at 8 bytes +} + +Constant *llvm::ConstantFoldCastInstruction(const Constant *V, +                                            const Type *DestTy) { +  if (V->getType() == DestTy) return (Constant*)V; + +  if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) +    if (CE->getOpcode() == Instruction::Cast) { +      Constant *Op = const_cast<Constant*>(CE->getOperand(0)); +      // Try to not produce a cast of a cast, which is almost always redundant. +      if (!Op->getType()->isFloatingPoint() && +          !CE->getType()->isFloatingPoint() && +          !DestTy->getType()->isFloatingPoint()) { +        unsigned S1 = getSize(Op->getType()), S2 = getSize(CE->getType()); +        unsigned S3 = getSize(DestTy); +        if (Op->getType() == DestTy && S3 >= S2) +          return Op; +        if (S1 >= S2 && S2 >= S3) +          return ConstantExpr::getCast(Op, DestTy); +        if (S1 <= S2 && S2 >= S3 && S1 <= S3) +          return ConstantExpr::getCast(Op, DestTy); +      } +    } else if (CE->getOpcode() == Instruction::GetElementPtr) { +      // If all of the indexes in the GEP are null values, there is no pointer +      // adjustment going on.  We might as well cast the source pointer. +      bool isAllNull = true; +      for (unsigned i = 1, e = CE->getNumOperands(); i != e; ++i) +        if (!CE->getOperand(i)->isNullValue()) { +          isAllNull = false; +          break; +        } +      if (isAllNull) +        return ConstantExpr::getCast(CE->getOperand(0), DestTy); +    } + +  ConstRules &Rules = ConstRules::get(V, V); + +  switch (DestTy->getPrimitiveID()) { +  case Type::BoolTyID:    return Rules.castToBool(V); +  case Type::UByteTyID:   return Rules.castToUByte(V); +  case Type::SByteTyID:   return Rules.castToSByte(V); +  case Type::UShortTyID:  return Rules.castToUShort(V); +  case Type::ShortTyID:   return Rules.castToShort(V); +  case Type::UIntTyID:    return Rules.castToUInt(V); +  case Type::IntTyID:     return Rules.castToInt(V); +  case Type::ULongTyID:   return Rules.castToULong(V); +  case Type::LongTyID:    return Rules.castToLong(V); +  case Type::FloatTyID:   return Rules.castToFloat(V); +  case Type::DoubleTyID:  return Rules.castToDouble(V); +  case Type::PointerTyID: +    return Rules.castToPointer(V, cast<PointerType>(DestTy)); +  default: return 0; +  } +} + +Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, +                                              const Constant *V1, +                                              const Constant *V2) { +  Constant *C; +  switch (Opcode) { +  default:                   return 0; +  case Instruction::Add:     return ConstRules::get(V1, V2).add(V1, V2); +  case Instruction::Sub:     return ConstRules::get(V1, V2).sub(V1, V2); +  case Instruction::Mul:     return ConstRules::get(V1, V2).mul(V1, V2); +  case Instruction::Div:     return ConstRules::get(V1, V2).div(V1, V2); +  case Instruction::Rem:     return ConstRules::get(V1, V2).rem(V1, V2); +  case Instruction::And:     return ConstRules::get(V1, V2).op_and(V1, V2); +  case Instruction::Or:      return ConstRules::get(V1, V2).op_or (V1, V2); +  case Instruction::Xor:     return ConstRules::get(V1, V2).op_xor(V1, V2); + +  case Instruction::Shl:     return ConstRules::get(V1, V2).shl(V1, V2); +  case Instruction::Shr:     return ConstRules::get(V1, V2).shr(V1, V2); + +  case Instruction::SetEQ:   return ConstRules::get(V1, V2).equalto(V1, V2); +  case Instruction::SetLT:   return ConstRules::get(V1, V2).lessthan(V1, V2); +  case Instruction::SetGT:   return ConstRules::get(V1, V2).lessthan(V2, V1); +  case Instruction::SetNE:   // V1 != V2  ===  !(V1 == V2) +    C = ConstRules::get(V1, V2).equalto(V1, V2); +    break; +  case Instruction::SetLE:   // V1 <= V2  ===  !(V2 < V1) +    C = ConstRules::get(V1, V2).lessthan(V2, V1); +    break; +  case Instruction::SetGE:   // V1 >= V2  ===  !(V1 < V2) +    C = ConstRules::get(V1, V2).lessthan(V1, V2); +    break; +  } + +  // If the folder broke out of the switch statement, invert the boolean +  // constant value, if it exists, and return it. +  if (!C) return 0; +  return ConstantExpr::get(Instruction::Xor, ConstantBool::True, C); +} + +Constant *llvm::ConstantFoldGetElementPtr(const Constant *C, +                                        const std::vector<Constant*> &IdxList) { +  if (IdxList.size() == 0 || +      (IdxList.size() == 1 && IdxList[0]->isNullValue())) +    return const_cast<Constant*>(C); + +  // TODO If C is null and all idx's are null, return null of the right type. + + +  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(const_cast<Constant*>(C))) { +    // Combine Indices - If the source pointer to this getelementptr instruction +    // is a getelementptr instruction, combine the indices of the two +    // getelementptr instructions into a single instruction. +    // +    if (CE->getOpcode() == Instruction::GetElementPtr) { +      const Type *LastTy = 0; +      for (gep_type_iterator I = gep_type_begin(CE), E = gep_type_end(CE); +           I != E; ++I) +        LastTy = *I; + +      if ((LastTy && isa<ArrayType>(LastTy)) || IdxList[0]->isNullValue()) { +        std::vector<Constant*> NewIndices; +        NewIndices.reserve(IdxList.size() + CE->getNumOperands()); +        for (unsigned i = 1, e = CE->getNumOperands()-1; i != e; ++i) +          NewIndices.push_back(cast<Constant>(CE->getOperand(i))); + +        // Add the last index of the source with the first index of the new GEP. +        // Make sure to handle the case when they are actually different types. +        Constant *Combined = CE->getOperand(CE->getNumOperands()-1); +        if (!IdxList[0]->isNullValue())   // Otherwise it must be an array +          Combined =  +            ConstantExpr::get(Instruction::Add, +                              ConstantExpr::getCast(IdxList[0], Type::LongTy), +                              ConstantExpr::getCast(Combined, Type::LongTy)); +         +        NewIndices.push_back(Combined); +        NewIndices.insert(NewIndices.end(), IdxList.begin()+1, IdxList.end()); +        return ConstantExpr::getGetElementPtr(CE->getOperand(0), NewIndices); +      } +    } + +    // Implement folding of: +    //    int* getelementptr ([2 x int]* cast ([3 x int]* %X to [2 x int]*), +    //                        long 0, long 0) +    // To: int* getelementptr ([3 x int]* %X, long 0, long 0) +    // +    if (CE->getOpcode() == Instruction::Cast && IdxList.size() > 1 && +        IdxList[0]->isNullValue()) +      if (const PointerType *SPT =  +          dyn_cast<PointerType>(CE->getOperand(0)->getType())) +        if (const ArrayType *SAT = dyn_cast<ArrayType>(SPT->getElementType())) +          if (const ArrayType *CAT = +              dyn_cast<ArrayType>(cast<PointerType>(C->getType())->getElementType())) +            if (CAT->getElementType() == SAT->getElementType()) +              return ConstantExpr::getGetElementPtr( +                      (Constant*)CE->getOperand(0), IdxList); +  } +  return 0; +} + | 

