diff options
Diffstat (limited to 'llvm/lib/CodeGen/ExpandMemCmp.cpp')
-rw-r--r-- | llvm/lib/CodeGen/ExpandMemCmp.cpp | 233 |
1 files changed, 137 insertions, 96 deletions
diff --git a/llvm/lib/CodeGen/ExpandMemCmp.cpp b/llvm/lib/CodeGen/ExpandMemCmp.cpp index d7562cbf1e9..ee7683adbcd 100644 --- a/llvm/lib/CodeGen/ExpandMemCmp.cpp +++ b/llvm/lib/CodeGen/ExpandMemCmp.cpp @@ -66,23 +66,18 @@ class MemCmpExpansion { // Represents the decomposition in blocks of the expansion. For example, // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {32, 1}. - // TODO(courbet): Involve the target more in this computation. On X86, 7 - // bytes can be done more efficiently with two overlaping 4-byte loads than - // covering the interval with [{4, 0},{2, 4},{1, 6}}. struct LoadEntry { LoadEntry(unsigned LoadSize, uint64_t Offset) : LoadSize(LoadSize), Offset(Offset) { - assert(Offset % LoadSize == 0 && "invalid load entry"); } - uint64_t getGEPIndex() const { return Offset / LoadSize; } - // The size of the load for this block, in bytes. - const unsigned LoadSize; - // The offset of this load WRT the base pointer, in bytes. - const uint64_t Offset; + unsigned LoadSize; + // The offset of this load from the base pointer, in bytes. + uint64_t Offset; }; - SmallVector<LoadEntry, 8> LoadSequence; + using LoadEntryVector = SmallVector<LoadEntry, 8>; + LoadEntryVector LoadSequence; void createLoadCmpBlocks(); void createResultBlock(); @@ -92,13 +87,23 @@ class MemCmpExpansion { void emitLoadCompareBlock(unsigned BlockIndex); void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex, unsigned &LoadIndex); - void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex); + void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes); void emitMemCmpResultBlock(); Value *getMemCmpExpansionZeroCase(); Value *getMemCmpEqZeroOneBlock(); Value *getMemCmpOneBlock(); + Value *getPtrToElementAtOffset(Value *Source, Type *LoadSizeType, + uint64_t OffsetBytes); + + static LoadEntryVector + computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes, + unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte); + static LoadEntryVector + computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize, + unsigned MaxNumLoads, + unsigned &NumLoadsNonOneByte); - public: +public: MemCmpExpansion(CallInst *CI, uint64_t Size, const TargetTransformInfo::MemCmpExpansionOptions &Options, unsigned MaxNumLoads, const bool IsUsedForZeroCmp, @@ -110,6 +115,76 @@ class MemCmpExpansion { Value *getMemCmpExpansion(); }; +MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence( + uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes, + const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) { + NumLoadsNonOneByte = 0; + LoadEntryVector LoadSequence; + uint64_t Offset = 0; + while (Size && !LoadSizes.empty()) { + const unsigned LoadSize = LoadSizes.front(); + const uint64_t NumLoadsForThisSize = Size / LoadSize; + if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) { + // Do not expand if the total number of loads is larger than what the + // target allows. Note that it's important that we exit before completing + // the expansion to avoid using a ton of memory to store the expansion for + // large sizes. + return {}; + } + if (NumLoadsForThisSize > 0) { + for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) { + LoadSequence.push_back({LoadSize, Offset}); + Offset += LoadSize; + } + if (LoadSize > 1) + ++NumLoadsNonOneByte; + Size = Size % LoadSize; + } + LoadSizes = LoadSizes.drop_front(); + } + return LoadSequence; +} + +MemCmpExpansion::LoadEntryVector +MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size, + const unsigned MaxLoadSize, + const unsigned MaxNumLoads, + unsigned &NumLoadsNonOneByte) { + // These are already handled by the greedy approach. + if (Size < 2 || MaxLoadSize < 2) + return {}; + + // We try to do as many non-overlapping loads as possible starting from the + // beginning. + const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize; + assert(NumNonOverlappingLoads && "there must be at least one load"); + // There remain 0 to (MaxLoadSize - 1) bytes to load, this will be done with + // an overlapping load. + Size = Size - NumNonOverlappingLoads * MaxLoadSize; + // Bail if we do not need an overloapping store, this is already handled by + // the greedy approach. + if (Size == 0) + return {}; + // Bail if the number of loads (non-overlapping + potential overlapping one) + // is larger than the max allowed. + if ((NumNonOverlappingLoads + 1) > MaxNumLoads) + return {}; + + // Add non-overlapping loads. + LoadEntryVector LoadSequence; + uint64_t Offset = 0; + for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) { + LoadSequence.push_back({MaxLoadSize, Offset}); + Offset += MaxLoadSize; + } + + // Add the last overlapping load. + assert(Size > 0 && Size < MaxLoadSize && "broken invariant"); + LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)}); + NumLoadsNonOneByte = 1; + return LoadSequence; +} + // Initialize the basic block structure required for expansion of memcmp call // with given maximum load size and memcmp size parameter. // This structure includes: @@ -133,38 +208,31 @@ MemCmpExpansion::MemCmpExpansion( Builder(CI) { assert(Size > 0 && "zero blocks"); // Scale the max size down if the target can load more bytes than we need. - size_t LoadSizeIndex = 0; - while (LoadSizeIndex < Options.LoadSizes.size() && - Options.LoadSizes[LoadSizeIndex] > Size) { - ++LoadSizeIndex; + llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes); + while (!LoadSizes.empty() && LoadSizes.front() > Size) { + LoadSizes = LoadSizes.drop_front(); } - this->MaxLoadSize = Options.LoadSizes[LoadSizeIndex]; + assert(!LoadSizes.empty() && "cannot load Size bytes"); + MaxLoadSize = LoadSizes.front(); // Compute the decomposition. - uint64_t CurSize = Size; - uint64_t Offset = 0; - while (CurSize && LoadSizeIndex < Options.LoadSizes.size()) { - const unsigned LoadSize = Options.LoadSizes[LoadSizeIndex]; - assert(LoadSize > 0 && "zero load size"); - const uint64_t NumLoadsForThisSize = CurSize / LoadSize; - if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) { - // Do not expand if the total number of loads is larger than what the - // target allows. Note that it's important that we exit before completing - // the expansion to avoid using a ton of memory to store the expansion for - // large sizes. - LoadSequence.clear(); - return; - } - if (NumLoadsForThisSize > 0) { - for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) { - LoadSequence.push_back({LoadSize, Offset}); - Offset += LoadSize; - } - if (LoadSize > 1) { - ++NumLoadsNonOneByte; - } - CurSize = CurSize % LoadSize; + unsigned GreedyNumLoadsNonOneByte = 0; + LoadSequence = computeGreedyLoadSequence(Size, LoadSizes, MaxNumLoads, + GreedyNumLoadsNonOneByte); + NumLoadsNonOneByte = GreedyNumLoadsNonOneByte; + assert(LoadSequence.size() <= MaxNumLoads && "broken invariant"); + // If we allow overlapping loads and the load sequence is not already optimal, + // use overlapping loads. + if (Options.AllowOverlappingLoads && + (LoadSequence.empty() || LoadSequence.size() > 2)) { + unsigned OverlappingNumLoadsNonOneByte = 0; + auto OverlappingLoads = computeOverlappingLoadSequence( + Size, MaxLoadSize, MaxNumLoads, OverlappingNumLoadsNonOneByte); + if (!OverlappingLoads.empty() && + (LoadSequence.empty() || + OverlappingLoads.size() < LoadSequence.size())) { + LoadSequence = OverlappingLoads; + NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte; } - ++LoadSizeIndex; } assert(LoadSequence.size() <= MaxNumLoads && "broken invariant"); } @@ -189,30 +257,32 @@ void MemCmpExpansion::createResultBlock() { EndBlock->getParent(), EndBlock); } +/// Return a pointer to an element of type `LoadSizeType` at offset +/// `OffsetBytes`. +Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source, + Type *LoadSizeType, + uint64_t OffsetBytes) { + if (OffsetBytes > 0) { + auto *ByteType = Type::getInt8Ty(CI->getContext()); + Source = Builder.CreateGEP( + ByteType, Builder.CreateBitCast(Source, ByteType->getPointerTo()), + ConstantInt::get(ByteType, OffsetBytes)); + } + return Builder.CreateBitCast(Source, LoadSizeType->getPointerTo()); +} + // This function creates the IR instructions for loading and comparing 1 byte. // It loads 1 byte from each source of the memcmp parameters with the given // GEPIndex. It then subtracts the two loaded values and adds this result to the // final phi node for selecting the memcmp result. void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex, - unsigned GEPIndex) { - Value *Source1 = CI->getArgOperand(0); - Value *Source2 = CI->getArgOperand(1); - + unsigned OffsetBytes) { Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); Type *LoadSizeType = Type::getInt8Ty(CI->getContext()); - // Cast source to LoadSizeType*. - if (Source1->getType() != LoadSizeType) - Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo()); - if (Source2->getType() != LoadSizeType) - Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo()); - - // Get the base address using the GEPIndex. - if (GEPIndex != 0) { - Source1 = Builder.CreateGEP(LoadSizeType, Source1, - ConstantInt::get(LoadSizeType, GEPIndex)); - Source2 = Builder.CreateGEP(LoadSizeType, Source2, - ConstantInt::get(LoadSizeType, GEPIndex)); - } + Value *Source1 = + getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, OffsetBytes); + Value *Source2 = + getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, OffsetBytes); Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1); Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2); @@ -270,24 +340,10 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex, IntegerType *LoadSizeType = IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8); - Value *Source1 = CI->getArgOperand(0); - Value *Source2 = CI->getArgOperand(1); - - // Cast source to LoadSizeType*. - if (Source1->getType() != LoadSizeType) - Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo()); - if (Source2->getType() != LoadSizeType) - Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo()); - - // Get the base address using a GEP. - if (CurLoadEntry.Offset != 0) { - Source1 = Builder.CreateGEP( - LoadSizeType, Source1, - ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); - Source2 = Builder.CreateGEP( - LoadSizeType, Source2, - ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); - } + Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, + CurLoadEntry.Offset); + Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, + CurLoadEntry.Offset); // Get a constant or load a value for each source address. Value *LoadSrc1 = nullptr; @@ -378,8 +434,7 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) { const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex]; if (CurLoadEntry.LoadSize == 1) { - MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, - CurLoadEntry.getGEPIndex()); + MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset); return; } @@ -388,25 +443,12 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) { Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8); assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type"); - Value *Source1 = CI->getArgOperand(0); - Value *Source2 = CI->getArgOperand(1); - Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); - // Cast source to LoadSizeType*. - if (Source1->getType() != LoadSizeType) - Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo()); - if (Source2->getType() != LoadSizeType) - Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo()); - // Get the base address using a GEP. - if (CurLoadEntry.Offset != 0) { - Source1 = Builder.CreateGEP( - LoadSizeType, Source1, - ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); - Source2 = Builder.CreateGEP( - LoadSizeType, Source2, - ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); - } + Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, + CurLoadEntry.Offset); + Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, + CurLoadEntry.Offset); // Load LoadSizeType from the base address. Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1); @@ -694,7 +736,6 @@ static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI, if (SizeVal == 0) { return false; } - // TTI call to check if target would like to expand memcmp. Also, get the // available load sizes. const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI); |