diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 38 |
1 files changed, 21 insertions, 17 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 6e7e11a15ae..9fd65001d14 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -132,18 +132,30 @@ static bool IsMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, /// \brief A helper routine of InstCombiner::visitMul(). /// -/// If C is a vector of known powers of 2, then this function returns -/// a new vector obtained from C replacing each element with its logBase2. +/// If C is a scalar/vector of known powers of 2, then this function returns +/// a new scalar/vector obtained from logBase2 of C. /// Return a null pointer otherwise. -static Constant *getLogBase2Vector(ConstantDataVector *CV) { +static Constant *getLogBase2(Type *Ty, Constant *C) { const APInt *IVal; - SmallVector<Constant *, 4> Elts; + if (const auto *CI = dyn_cast<ConstantInt>(C)) + if (match(C, m_APInt(IVal)) && IVal->isPowerOf2()) + return ConstantInt::get(Ty, IVal->logBase2()); + + if (!Ty->isVectorTy()) + return nullptr; - for (unsigned I = 0, E = CV->getNumElements(); I != E; ++I) { - Constant *Elt = CV->getElementAsConstant(I); + SmallVector<Constant *, 4> Elts; + for (unsigned I = 0, E = Ty->getVectorNumElements(); I != E; ++I) { + Constant *Elt = C->getAggregateElement(I); + if (!Elt) + return nullptr; + if (isa<UndefValue>(Elt)) { + Elts.push_back(UndefValue::get(Ty->getScalarType())); + continue; + } if (!match(Elt, m_APInt(IVal)) || !IVal->isPowerOf2()) return nullptr; - Elts.push_back(ConstantInt::get(Elt->getType(), IVal->logBase2())); + Elts.push_back(ConstantInt::get(Ty->getScalarType(), IVal->logBase2())); } return ConstantVector::get(Elts); @@ -232,16 +244,8 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } if (match(&I, m_Mul(m_Value(NewOp), m_Constant(C1)))) { - Constant *NewCst = nullptr; - if (match(C1, m_APInt(IVal)) && IVal->isPowerOf2()) - // Replace X*(2^C) with X << C, where C is either a scalar or a splat. - NewCst = ConstantInt::get(NewOp->getType(), IVal->logBase2()); - else if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(C1)) - // Replace X*(2^C) with X << C, where C is a vector of known - // constant powers of 2. - NewCst = getLogBase2Vector(CV); - - if (NewCst) { + // Replace X*(2^C) with X << C, where C is either a scalar or a vector. + if (Constant *NewCst = getLogBase2(NewOp->getType(), C1)) { unsigned Width = NewCst->getType()->getPrimitiveSizeInBits(); BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst); |