diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 102 | 
1 files changed, 47 insertions, 55 deletions
| diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index b90e4d846bc..2be3cdbee6b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -430,6 +430,50 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC,    return false;  } +/// Given a vector that is bitcast to an integer, optionally logically +/// right-shifted, and truncated, convert it to an extractelement. +/// Example (big endian): +///   trunc (lshr (bitcast <4 x i32> %X to i128), 32) to i32 +///   ---> +///   extractelement <4 x i32> %X, 1 +static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC, +                                         const DataLayout &DL) { +  Value *TruncOp = Trunc.getOperand(0); +  Type *DestType = Trunc.getType(); +  if (!TruncOp->hasOneUse() || !isa<IntegerType>(DestType)) +    return nullptr; + +  Value *VecInput = nullptr; +  ConstantInt *ShiftVal = nullptr; +  if (!match(TruncOp, m_CombineOr(m_BitCast(m_Value(VecInput)), +                                  m_LShr(m_BitCast(m_Value(VecInput)), +                                         m_ConstantInt(ShiftVal)))) || +      !isa<VectorType>(VecInput->getType())) +    return nullptr; + +  VectorType *VecType = cast<VectorType>(VecInput->getType()); +  unsigned VecWidth = VecType->getPrimitiveSizeInBits(); +  unsigned DestWidth = DestType->getPrimitiveSizeInBits(); +  unsigned ShiftAmount = ShiftVal ? ShiftVal->getZExtValue() : 0; + +  if ((VecWidth % DestWidth != 0) || (ShiftAmount % DestWidth != 0)) +    return nullptr; + +  // If the element type of the vector doesn't match the result type, +  // bitcast it to a vector type that we can extract from. +  unsigned NumVecElts = VecWidth / DestWidth; +  if (VecType->getElementType() != DestType) { +    VecType = VectorType::get(DestType, NumVecElts); +    VecInput = IC.Builder->CreateBitCast(VecInput, VecType, "bc"); +  } + +  unsigned Elt = ShiftAmount / DestWidth; +  if (DL.isBigEndian()) +    Elt = NumVecElts - 1 - Elt; + +  return ExtractElementInst::Create(VecInput, IC.Builder->getInt32(Elt)); +} +  Instruction *InstCombiner::visitTrunc(TruncInst &CI) {    if (Instruction *Result = commonCastTransforms(CI))      return Result; @@ -528,6 +572,9 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {                                       ConstantExpr::getTrunc(Cst, DestTy));    } +  if (Instruction *I = foldVecTruncToExtElt(CI, *this, DL)) +    return I; +    return nullptr;  } @@ -1740,56 +1787,6 @@ static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast,    return ExtractElementInst::Create(NewBC, ExtElt->getIndexOperand());  } -static Instruction *foldVecTruncToExtElt(Value *VecInput, Type *DestTy, -                                         unsigned ShiftAmt, InstCombiner &IC, -                                         const DataLayout &DL) { -  VectorType *VecTy = cast<VectorType>(VecInput->getType()); -  unsigned DestWidth = DestTy->getPrimitiveSizeInBits(); -  unsigned VecWidth = VecTy->getPrimitiveSizeInBits(); - -  if ((VecWidth % DestWidth != 0) || (ShiftAmt % DestWidth != 0)) -    return nullptr; - -  // If the element type of the vector doesn't match the result type, -  // bitcast it to be a vector type we can extract from. -  unsigned NumVecElts = VecWidth / DestWidth; -  if (VecTy->getElementType() != DestTy) { -    VecTy = VectorType::get(DestTy, NumVecElts); -    VecInput = IC.Builder->CreateBitCast(VecInput, VecTy); -  } - -  unsigned Elt = ShiftAmt / DestWidth; -  if (DL.isBigEndian()) -    Elt = NumVecElts - 1 - Elt; - -  return ExtractElementInst::Create(VecInput, IC.Builder->getInt32(Elt)); -} - -/// See if we can optimize an integer->float/double bitcast. -/// The various long double bitcasts can't get in here. -static Instruction *optimizeIntToFloatBitCast(BitCastInst &CI, InstCombiner &IC, -                                              const DataLayout &DL) { -  Value *Src = CI.getOperand(0); -  Type *DstTy = CI.getType(); - -  // If this is a bitcast from int to float, check to see if the int is an -  // extraction from a vector. -  Value *VecInput = nullptr; -  // bitcast(trunc(bitcast(somevector))) -  if (match(Src, m_Trunc(m_BitCast(m_Value(VecInput)))) && -      isa<VectorType>(VecInput->getType())) -    return foldVecTruncToExtElt(VecInput, DstTy, 0, IC, DL); - -  // bitcast(trunc(lshr(bitcast(somevector), cst)) -  ConstantInt *ShAmt = nullptr; -  if (match(Src, m_Trunc(m_LShr(m_BitCast(m_Value(VecInput)), -                                m_ConstantInt(ShAmt)))) && -      isa<VectorType>(VecInput->getType())) -    return foldVecTruncToExtElt(VecInput, DstTy, ShAmt->getZExtValue(), IC, DL); - -  return nullptr; -} -  Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {    // If the operands are integer typed then apply the integer transforms,    // otherwise just apply the common ones. @@ -1833,11 +1830,6 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {      }    } -  // Try to optimize int -> float bitcasts. -  if ((DestTy->isFloatTy() || DestTy->isDoubleTy()) && isa<IntegerType>(SrcTy)) -    if (Instruction *I = optimizeIntToFloatBitCast(CI, *this, DL)) -      return I; -    if (VectorType *DestVTy = dyn_cast<VectorType>(DestTy)) {      if (DestVTy->getNumElements() == 1 && !SrcTy->isVectorTy()) {        Value *Elem = Builder->CreateBitCast(Src, DestVTy->getElementType()); | 

