diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp')
-rw-r--r-- | llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp | 55 |
1 files changed, 37 insertions, 18 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 4a6a35c0ab1..9051b7ceb3a 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -780,6 +780,41 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount, return SE->getMinusSCEV(Start, Index); } +/// Compute the number of bytes as a SCEV from the backedge taken count. +/// +/// This also maps the SCEV into the provided type and tries to handle the +/// computation in a way that will fold cleanly. +static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr, + unsigned StoreSize, Loop *CurLoop, + const DataLayout *DL, ScalarEvolution *SE) { + const SCEV *NumBytesS; + // The # stored bytes is (BECount+1)*Size. Expand the trip count out to + // pointer size if it isn't already. + // + // If we're going to need to zero extend the BE count, check if we can add + // one to it prior to zero extending without overflow. Provided this is safe, + // it allows better simplification of the +1. + if (DL->getTypeSizeInBits(BECount->getType()) < + DL->getTypeSizeInBits(IntPtr) && + SE->isLoopEntryGuardedByCond( + CurLoop, ICmpInst::ICMP_NE, BECount, + SE->getNegativeSCEV(SE->getOne(BECount->getType())))) { + NumBytesS = SE->getZeroExtendExpr( + SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW), + IntPtr); + } else { + NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr), + SE->getOne(IntPtr), SCEV::FlagNUW); + } + + // And scale it based on the store size. + if (StoreSize != 1) { + NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize), + SCEV::FlagNUW); + } + return NumBytesS; +} + /// processLoopStridedStore - We see a strided store of some value. If we can /// transform this into a memset or memset_pattern in the loop preheader, do so. bool LoopIdiomRecognize::processLoopStridedStore( @@ -837,16 +872,8 @@ bool LoopIdiomRecognize::processLoopStridedStore( // Okay, everything looks good, insert the memset. - // The # stored bytes is (BECount+1)*Size. Expand the trip count out to - // pointer size if it isn't already. - BECount = SE->getTruncateOrZeroExtend(BECount, IntPtr); - const SCEV *NumBytesS = - SE->getAddExpr(BECount, SE->getOne(IntPtr), SCEV::FlagNUW); - if (StoreSize != 1) { - NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize), - SCEV::FlagNUW); - } + getNumBytes(BECount, IntPtr, StoreSize, CurLoop, DL, SE); // TODO: ideally we should still be able to generate memset if SCEV expander // is taught to generate the dependencies at the latest point. @@ -976,16 +1003,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, // Okay, everything is safe, we can transform this! - // The # stored bytes is (BECount+1)*Size. Expand the trip count out to - // pointer size if it isn't already. - BECount = SE->getTruncateOrZeroExtend(BECount, IntPtrTy); - const SCEV *NumBytesS = - SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW); - - if (StoreSize != 1) - NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize), - SCEV::FlagNUW); + getNumBytes(BECount, IntPtrTy, StoreSize, CurLoop, DL, SE); Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntPtrTy, Preheader->getTerminator()); |