Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp   | 57
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineInternal.h  |  3
2 files changed, 52 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 359f617d550..1646c0fd39a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -25,6 +25,7 @@
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
@@ -1175,6 +1176,20 @@ static bool maskIsAllOneOrUndef(Value *Mask) {
   return true;
 }
 
+/// Given a mask vector <Y x i1>, return an APInt (of bitwidth Y) for each lane
+/// which may be active. TODO: This is a lot like known bits, but for
+/// vectors. Is there something we can common this with?
+static APInt possiblyDemandedEltsInMask(Value *Mask) {
+
+  const unsigned VWidth = cast<VectorType>(Mask->getType())->getNumElements();
+  APInt DemandedElts = APInt::getAllOnesValue(VWidth);
+  if (auto *CV = dyn_cast<ConstantVector>(Mask))
+    for (unsigned i = 0; i < VWidth; i++)
+      if (CV->getAggregateElement(i)->isNullValue())
+        DemandedElts.clearBit(i);
+  return DemandedElts;
+}
+
 // TODO, Obvious Missing Transforms:
 // * Dereferenceable address -> speculative load/select
 // * Narrow width by halfs excluding zero/undef lanes
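To make the new helper's behavior concrete, here is a minimal standalone sketch of the same demanded-lanes computation, written against a plain bool array rather than LLVM's ConstantVector and APInt types. The name possiblyDemandedElts, the uint64_t bitmask representation, and the VWidth <= 64 restriction are illustrative assumptions, not part of the patch.

// Standalone model of possiblyDemandedEltsInMask (illustrative only):
// bit i of the result stays set unless the constant mask proves lane i
// inactive. A non-constant mask leaves every bit set, which is the
// conservative answer.
#include <cstddef>
#include <cstdint>

static uint64_t possiblyDemandedElts(const bool *Mask, std::size_t VWidth) {
  // Start with all VWidth lanes demanded (assumes VWidth <= 64 here).
  uint64_t Demanded = (VWidth >= 64) ? ~0ULL : ((1ULL << VWidth) - 1);
  for (std::size_t I = 0; I != VWidth; ++I)
    if (!Mask[I])               // lane proven inactive by the constant mask
      Demanded &= ~(1ULL << I); // clear its demanded bit
  return Demanded;
}

// For a mask equivalent to <i1 true, i1 false, i1 true, i1 false>, this
// returns 0b0101: only lanes 0 and 2 may be active.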
@@ -1196,14 +1211,14 @@ static Value *simplifyMaskedLoad(const IntrinsicInst &II,
 // * SimplifyDemandedVectorElts
 // * Single constant active lane -> store
 // * Narrow width by halfs excluding zero/undef lanes
-static Instruction *simplifyMaskedStore(IntrinsicInst &II, InstCombiner &IC) {
+Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) {
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
   if (!ConstMask)
     return nullptr;
 
   // If the mask is all zeros, this instruction does nothing.
   if (ConstMask->isNullValue())
-    return IC.eraseInstFromFunction(II);
+    return eraseInstFromFunction(II);
 
   // If the mask is all ones, this is a plain vector store of the 1st argument.
   if (ConstMask->isAllOnesValue()) {
@@ -1212,6 +1227,15 @@ static Instruction *simplifyMaskedStore(IntrinsicInst &II, InstCombiner &IC) {
     return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment);
   }
 
+  // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
+  APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
+  APInt UndefElts(DemandedElts.getBitWidth(), 0);
+  if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0),
+                                            DemandedElts, UndefElts)) {
+    II.setOperand(0, V);
+    return &II;
+  }
+
   return nullptr;
 }
 
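Why feeding DemandedElts into SimplifyDemandedVectorElts is sound for the value operand: a lane whose mask bit is false never reaches memory, so whatever the stored vector holds in that lane is irrelevant and may be rewritten (for example to undef). A scalar model of llvm.masked.store's semantics, for intuition only; maskedStoreModel is a hypothetical name, not an LLVM API.

#include <array>
#include <cstddef>

// Scalar model of llvm.masked.store semantics (illustrative sketch):
// lane I of Val is written only when Mask[I] is true, so masked-off
// lanes of Val can be changed without affecting the observable stores.
template <typename T, std::size_t N>
void maskedStoreModel(const std::array<T, N> &Val, std::array<T, N> &Mem,
                      const std::array<bool, N> &Mask) {
  for (std::size_t I = 0; I != N; ++I)
    if (Mask[I])
      Mem[I] = Val[I]; // a false lane never touches Mem
}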
@@ -1268,11 +1292,28 @@ static Instruction *simplifyInvariantGroupIntrinsic(IntrinsicInst &II,
 // * Single constant active lane -> store
 // * Adjacent vector addresses -> masked.store
 // * Narrow store width by halfs excluding zero/undef lanes
-static Instruction *simplifyMaskedScatter(IntrinsicInst &II, InstCombiner &IC) {
-  // If the mask is all zeros, a scatter does nothing.
+Instruction *InstCombiner::simplifyMaskedScatter(IntrinsicInst &II) {
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
-  if (ConstMask && ConstMask->isNullValue())
-    return IC.eraseInstFromFunction(II);
+  if (!ConstMask)
+    return nullptr;
+
+  // If the mask is all zeros, a scatter does nothing.
+  if (ConstMask->isNullValue())
+    return eraseInstFromFunction(II);
+
+  // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
+  APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
+  APInt UndefElts(DemandedElts.getBitWidth(), 0);
+  if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0),
+                                            DemandedElts, UndefElts)) {
+    II.setOperand(0, V);
+    return &II;
+  }
+  if (Value *V = SimplifyDemandedVectorElts(II.getOperand(1),
+                                            DemandedElts, UndefElts)) {
+    II.setOperand(1, V);
+    return &II;
+  }
 
   return nullptr;
 }
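The scatter case applies the same demanded-elements reasoning twice, because a masked-off lane kills both the value and the address for that lane. A scalar model, again only a sketch with the hypothetical name maskedScatterModel:

#include <array>
#include <cstddef>

// Scalar model of llvm.masked.scatter semantics (illustrative sketch):
// lane I writes Val[I] through Ptrs[I] only when Mask[I] is true, so
// both the value vector (operand 0) and the pointer vector (operand 1)
// are dead in masked-off lanes, hence the two
// SimplifyDemandedVectorElts calls above.
template <typename T, std::size_t N>
void maskedScatterModel(const std::array<T, N> &Val,
                        const std::array<T *, N> &Ptrs,
                        const std::array<bool, N> &Mask) {
  for (std::size_t I = 0; I != N; ++I)
    if (Mask[I])
      *Ptrs[I] = Val[I]; // neither Val[I] nor Ptrs[I] matters otherwise
}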
@@ -1972,11 +2013,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
       return replaceInstUsesWith(CI, SimplifiedMaskedOp);
     break;
   case Intrinsic::masked_store:
-    return simplifyMaskedStore(*II, *this);
+    return simplifyMaskedStore(*II);
   case Intrinsic::masked_gather:
     return simplifyMaskedGather(*II, *this);
   case Intrinsic::masked_scatter:
-    return simplifyMaskedScatter(*II, *this);
+    return simplifyMaskedScatter(*II);
   case Intrinsic::launder_invariant_group:
   case Intrinsic::strip_invariant_group:
     if (auto *SkippedBarrier = simplifyInvariantGroupIntrinsic(*II, *this))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index e4e6228a047..ee1853613bc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -474,6 +474,9 @@ private:
   Instruction *transformCallThroughTrampoline(CallBase &Call,
                                               IntrinsicInst &Tramp);
 
+  Instruction *simplifyMaskedStore(IntrinsicInst &II);
+  Instruction *simplifyMaskedScatter(IntrinsicInst &II);
+
   /// Transform (zext icmp) to bitwise / integer operations in order to
   /// eliminate it.
  ///