summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp29
1 files changed, 20 insertions, 9 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index a258ac56568..25f692c6fb9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1038,13 +1038,20 @@ static Value *simplifyX86vpcom(const IntrinsicInst &II,
// masked intrinsics.
static Value *emitX86MaskSelect(Value *Mask, Value *Op0, Value *Op1,
InstCombiner::BuilderTy &Builder) {
+ unsigned VWidth = Op0->getType()->getVectorNumElements();
+
+ // If the mask is all ones we don't need the select. But we need to check
+ // only the bit thats will be used in case VWidth is less than 8.
+ if (auto *C = dyn_cast<ConstantInt>(Mask))
+ if (C->getValue().zextOrTrunc(VWidth).isAllOnesValue())
+ return Op0;
+
auto *MaskTy = VectorType::get(Builder.getInt1Ty(),
cast<IntegerType>(Mask->getType())->getBitWidth());
Mask = Builder.CreateBitCast(Mask, MaskTy);
// If we have less than 8 elements, then the starting mask was an i8 and
// we need to extract down to the right number of elements.
- unsigned VWidth = Op0->getType()->getVectorNumElements();
if (VWidth < 8) {
uint32_t Indices[4];
for (unsigned i = 0; i != VWidth; ++i)
@@ -1873,16 +1880,20 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
}
// Handle the masking aspect of the intrinsic.
- // Cast the mask to an i1 vector and then extract the lowest element.
Value *Mask = II->getArgOperand(3);
- auto *MaskTy = VectorType::get(Builder->getInt1Ty(),
+ auto *C = dyn_cast<ConstantInt>(Mask);
+ // We don't need a select if we know the mask bit is a 1.
+ if (!C || !C->getValue()[0]) {
+ // Cast the mask to an i1 vector and then extract the lowest element.
+ auto *MaskTy = VectorType::get(Builder->getInt1Ty(),
cast<IntegerType>(Mask->getType())->getBitWidth());
- Mask = Builder->CreateBitCast(Mask, MaskTy);
- Mask = Builder->CreateExtractElement(Mask, (uint64_t)0);
- // Extract the lowest element from the passthru operand.
- Value *Passthru = Builder->CreateExtractElement(II->getArgOperand(2),
- (uint64_t)0);
- V = Builder->CreateSelect(Mask, V, Passthru);
+ Mask = Builder->CreateBitCast(Mask, MaskTy);
+ Mask = Builder->CreateExtractElement(Mask, (uint64_t)0);
+ // Extract the lowest element from the passthru operand.
+ Value *Passthru = Builder->CreateExtractElement(II->getArgOperand(2),
+ (uint64_t)0);
+ V = Builder->CreateSelect(Mask, V, Passthru);
+ }
// Insert the result back into the original argument 0.
V = Builder->CreateInsertElement(Arg0, V, (uint64_t)0);
OpenPOWER on IntegriCloud