diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 27 |
1 files changed, 20 insertions, 7 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index bff46597266..efe41b49274 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -2949,14 +2949,27 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Convert to a vector select if we can bypass casts and find a boolean // vector condition value. Value *BoolVec; - if (match(peekThroughBitcast(Mask), m_SExt(m_Value(BoolVec)))) { - auto *VTy = dyn_cast<VectorType>(BoolVec->getType()); - if (VTy && VTy->getScalarSizeInBits() == 1 && - VTy->getVectorNumElements() == II->getType()->getVectorNumElements()) + Mask = peekThroughBitcast(Mask); + if (match(Mask, m_SExt(m_Value(BoolVec))) && + BoolVec->getType()->isVectorTy() && + BoolVec->getType()->getScalarSizeInBits() == 1) { + assert(Mask->getType()->getPrimitiveSizeInBits() == + II->getType()->getPrimitiveSizeInBits() && + "Not expecting mask and operands with different sizes"); + + unsigned NumMaskElts = Mask->getType()->getVectorNumElements(); + unsigned NumOperandElts = II->getType()->getVectorNumElements(); + if (NumMaskElts == NumOperandElts) return SelectInst::Create(BoolVec, Op1, Op0); - // TODO: If we can find a boolean vector condition with less elements, - // then we can form a vector select by bitcasting Op0/Op1 to a - // vector type with wider elements and bitcasting the result. + + // If the mask has less elements than the operands, each mask bit maps to + // multiple elements of the operands. Bitcast back and forth. + if (NumMaskElts < NumOperandElts) { + Value *CastOp0 = Builder.CreateBitCast(Op0, Mask->getType()); + Value *CastOp1 = Builder.CreateBitCast(Op1, Mask->getType()); + Value *Sel = Builder.CreateSelect(BoolVec, CastOp1, CastOp0); + return new BitCastInst(Sel, II->getType()); + } } break; |