commit     e4588bbf80ab24dc593ef8af88ed93010acbc90a
author     Philip Reames <listmail@philipreames.com>  2019-03-20 18:44:58 +0000
committer  Philip Reames <listmail@philipreames.com>  2019-03-20 18:44:58 +0000
tree       0132ee38a540dbc3e8ab21a7ad762e6a3a5d1d73
parent     af8817570456f3b7599d3f6327d00ff88efbe877
Simplify operands of masked stores and scatters based on demanded elements
If we know we're not storing to a lane, we don't need to compute the value for that lane. This could be improved by using the undef element result to further prune the mask, but I want to separate that into its own change since it's relatively likely to expose other problems.
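As an illustration, here is a condensed form of the store_demandedelts test updated below (the declare is added here only to make the snippet self-contained): the mask <i1 true, i1 false> demands only lane 0 of the stored value, so the splat shufflevector feeding the store is dead.

    declare void @llvm.masked.store.v2f64.p0v2f64(<2 x double>, <2 x double>*, i32, <2 x i1>)

    define void @store_demandedelts(<2 x double>* %ptr, double %val) {
      ; only lane 0 is stored, so splatting %val across both lanes is unnecessary
      %valvec1 = insertelement <2 x double> undef, double %val, i32 0
      %valvec2 = shufflevector <2 x double> %valvec1, <2 x double> undef, <2 x i32> zeroinitializer
      call void @llvm.masked.store.v2f64.p0v2f64(<2 x double> %valvec2, <2 x double>* %ptr, i32 4, <2 x i1> <i1 true, i1 false>)
      ret void
    }

With this patch, SimplifyDemandedVectorElts rewrites the value operand to the bare insertelement of lane 0; the scatter case additionally simplifies the pointer operand (see the updated CHECK lines below).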
Differential Revision: https://reviews.llvm.org/D57247
llvm-svn: 356590
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp   | 57
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineInternal.h  |  3
-rw-r--r--  llvm/test/Transforms/InstCombine/masked_intrinsics.ll  |  8
3 files changed, 55 insertions, 13 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 359f617d550..1646c0fd39a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -25,6 +25,7 @@
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
@@ -1175,6 +1176,20 @@ static bool maskIsAllOneOrUndef(Value *Mask) {
   return true;
 }
 
+/// Given a mask vector <Y x i1>, return an APInt (of bitwidth Y) for each lane
+/// which may be active.  TODO: This is a lot like known bits, but for
+/// vectors.  Is there something we can common this with?
+static APInt possiblyDemandedEltsInMask(Value *Mask) {
+
+  const unsigned VWidth = cast<VectorType>(Mask->getType())->getNumElements();
+  APInt DemandedElts = APInt::getAllOnesValue(VWidth);
+  if (auto *CV = dyn_cast<ConstantVector>(Mask))
+    for (unsigned i = 0; i < VWidth; i++)
+      if (CV->getAggregateElement(i)->isNullValue())
+        DemandedElts.clearBit(i);
+  return DemandedElts;
+}
+
 // TODO, Obvious Missing Transforms:
 // * Dereferenceable address -> speculative load/select
 // * Narrow width by halfs excluding zero/undef lanes
@@ -1196,14 +1211,14 @@ static Value *simplifyMaskedLoad(const IntrinsicInst &II,
 // * SimplifyDemandedVectorElts
 // * Single constant active lane -> store
 // * Narrow width by halfs excluding zero/undef lanes
-static Instruction *simplifyMaskedStore(IntrinsicInst &II, InstCombiner &IC) {
+Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) {
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
   if (!ConstMask)
     return nullptr;
 
   // If the mask is all zeros, this instruction does nothing.
   if (ConstMask->isNullValue())
-    return IC.eraseInstFromFunction(II);
+    return eraseInstFromFunction(II);
 
   // If the mask is all ones, this is a plain vector store of the 1st argument.
   if (ConstMask->isAllOnesValue()) {
@@ -1212,6 +1227,15 @@ static Instruction *simplifyMaskedStore(IntrinsicInst &II, InstCombiner &IC) {
     return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment);
   }
 
+  // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
+  APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
+  APInt UndefElts(DemandedElts.getBitWidth(), 0);
+  if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0),
+                                            DemandedElts, UndefElts)) {
+    II.setOperand(0, V);
+    return &II;
+  }
+
   return nullptr;
 }
 
@@ -1268,11 +1292,28 @@ static Instruction *simplifyInvariantGroupIntrinsic(IntrinsicInst &II,
 // * Single constant active lane -> store
 // * Adjacent vector addresses -> masked.store
 // * Narrow store width by halfs excluding zero/undef lanes
-static Instruction *simplifyMaskedScatter(IntrinsicInst &II, InstCombiner &IC) {
-  // If the mask is all zeros, a scatter does nothing.
+Instruction *InstCombiner::simplifyMaskedScatter(IntrinsicInst &II) {
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
-  if (ConstMask && ConstMask->isNullValue())
-    return IC.eraseInstFromFunction(II);
+  if (!ConstMask)
+    return nullptr;
+
+  // If the mask is all zeros, a scatter does nothing.
+  if (ConstMask->isNullValue())
+    return eraseInstFromFunction(II);
+
+  // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
+  APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
+  APInt UndefElts(DemandedElts.getBitWidth(), 0);
+  if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0),
+                                            DemandedElts, UndefElts)) {
+    II.setOperand(0, V);
+    return &II;
+  }
+  if (Value *V = SimplifyDemandedVectorElts(II.getOperand(1),
+                                            DemandedElts, UndefElts)) {
+    II.setOperand(1, V);
+    return &II;
+  }
 
   return nullptr;
 }
 
@@ -1972,11 +2013,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
       return replaceInstUsesWith(CI, SimplifiedMaskedOp);
     break;
   case Intrinsic::masked_store:
-    return simplifyMaskedStore(*II, *this);
+    return simplifyMaskedStore(*II);
   case Intrinsic::masked_gather:
     return simplifyMaskedGather(*II, *this);
   case Intrinsic::masked_scatter:
-    return simplifyMaskedScatter(*II, *this);
+    return simplifyMaskedScatter(*II);
   case Intrinsic::launder_invariant_group:
   case Intrinsic::strip_invariant_group:
     if (auto *SkippedBarrier = simplifyInvariantGroupIntrinsic(*II, *this))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index e4e6228a047..ee1853613bc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -474,6 +474,9 @@ private:
   Instruction *transformCallThroughTrampoline(CallBase &Call,
                                               IntrinsicInst &Tramp);
 
+  Instruction *simplifyMaskedStore(IntrinsicInst &II);
+  Instruction *simplifyMaskedScatter(IntrinsicInst &II);
+
   /// Transform (zext icmp) to bitwise / integer operations in order to
   /// eliminate it.
   ///
diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
index e685e03726c..4417ced6906 100644
--- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
@@ -80,8 +80,7 @@ define void @store_onemask(<2 x double>* %ptr, <2 x double> %val) {
 
 define void @store_demandedelts(<2 x double>* %ptr, double %val) {
 ; CHECK-LABEL: @store_demandedelts(
-; CHECK-NEXT:    [[VALVEC1:%.*]] = insertelement <2 x double> undef, double [[VAL:%.*]], i32 0
-; CHECK-NEXT:    [[VALVEC2:%.*]] = shufflevector <2 x double> [[VALVEC1]], <2 x double> undef, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[VALVEC2:%.*]] = insertelement <2 x double> undef, double [[VAL:%.*]], i32 0
 ; CHECK-NEXT:    call void @llvm.masked.store.v2f64.p0v2f64(<2 x double> [[VALVEC2]], <2 x double>* [[PTR:%.*]], i32 4, <2 x i1> <i1 true, i1 false>)
 ; CHECK-NEXT:    ret void
 ;
@@ -137,9 +136,8 @@ define void @scatter_zeromask(<2 x double*> %ptrs, <2 x double> %val) {
 
 define void @scatter_demandedelts(double* %ptr, double %val) {
 ; CHECK-LABEL: @scatter_demandedelts(
-; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr double, double* [[PTR:%.*]], <2 x i64> <i64 0, i64 1>
-; CHECK-NEXT:    [[VALVEC1:%.*]] = insertelement <2 x double> undef, double [[VAL:%.*]], i32 0
-; CHECK-NEXT:    [[VALVEC2:%.*]] = shufflevector <2 x double> [[VALVEC1]], <2 x double> undef, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr double, double* [[PTR:%.*]], <2 x i64> <i64 0, i64 undef>
+; CHECK-NEXT:    [[VALVEC2:%.*]] = insertelement <2 x double> undef, double [[VAL:%.*]], i32 0
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> [[VALVEC2]], <2 x double*> [[PTRS]], i32 8, <2 x i1> <i1 true, i1 false>)
 ; CHECK-NEXT:    ret void
 ;