diff options
Diffstat (limited to 'llvm/lib/CodeGen')
-rw-r--r-- | llvm/lib/CodeGen/ExpandMemCmp.cpp | 233 |
1 files changed, 96 insertions, 137 deletions
diff --git a/llvm/lib/CodeGen/ExpandMemCmp.cpp b/llvm/lib/CodeGen/ExpandMemCmp.cpp index ee7683adbcd..d7562cbf1e9 100644 --- a/llvm/lib/CodeGen/ExpandMemCmp.cpp +++ b/llvm/lib/CodeGen/ExpandMemCmp.cpp @@ -66,18 +66,23 @@ class MemCmpExpansion { // Represents the decomposition in blocks of the expansion. For example, // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {32, 1}. + // TODO(courbet): Involve the target more in this computation. On X86, 7 + // bytes can be done more efficiently with two overlaping 4-byte loads than + // covering the interval with [{4, 0},{2, 4},{1, 6}}. struct LoadEntry { LoadEntry(unsigned LoadSize, uint64_t Offset) : LoadSize(LoadSize), Offset(Offset) { + assert(Offset % LoadSize == 0 && "invalid load entry"); } + uint64_t getGEPIndex() const { return Offset / LoadSize; } + // The size of the load for this block, in bytes. - unsigned LoadSize; - // The offset of this load from the base pointer, in bytes. - uint64_t Offset; + const unsigned LoadSize; + // The offset of this load WRT the base pointer, in bytes. + const uint64_t Offset; }; - using LoadEntryVector = SmallVector<LoadEntry, 8>; - LoadEntryVector LoadSequence; + SmallVector<LoadEntry, 8> LoadSequence; void createLoadCmpBlocks(); void createResultBlock(); @@ -87,23 +92,13 @@ class MemCmpExpansion { void emitLoadCompareBlock(unsigned BlockIndex); void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex, unsigned &LoadIndex); - void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes); + void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex); void emitMemCmpResultBlock(); Value *getMemCmpExpansionZeroCase(); Value *getMemCmpEqZeroOneBlock(); Value *getMemCmpOneBlock(); - Value *getPtrToElementAtOffset(Value *Source, Type *LoadSizeType, - uint64_t OffsetBytes); - - static LoadEntryVector - computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes, - unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte); - static LoadEntryVector - computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize, - unsigned MaxNumLoads, - unsigned &NumLoadsNonOneByte); -public: + public: MemCmpExpansion(CallInst *CI, uint64_t Size, const TargetTransformInfo::MemCmpExpansionOptions &Options, unsigned MaxNumLoads, const bool IsUsedForZeroCmp, @@ -115,76 +110,6 @@ public: Value *getMemCmpExpansion(); }; -MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence( - uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes, - const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) { - NumLoadsNonOneByte = 0; - LoadEntryVector LoadSequence; - uint64_t Offset = 0; - while (Size && !LoadSizes.empty()) { - const unsigned LoadSize = LoadSizes.front(); - const uint64_t NumLoadsForThisSize = Size / LoadSize; - if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) { - // Do not expand if the total number of loads is larger than what the - // target allows. Note that it's important that we exit before completing - // the expansion to avoid using a ton of memory to store the expansion for - // large sizes. - return {}; - } - if (NumLoadsForThisSize > 0) { - for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) { - LoadSequence.push_back({LoadSize, Offset}); - Offset += LoadSize; - } - if (LoadSize > 1) - ++NumLoadsNonOneByte; - Size = Size % LoadSize; - } - LoadSizes = LoadSizes.drop_front(); - } - return LoadSequence; -} - -MemCmpExpansion::LoadEntryVector -MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size, - const unsigned MaxLoadSize, - const unsigned MaxNumLoads, - unsigned &NumLoadsNonOneByte) { - // These are already handled by the greedy approach. - if (Size < 2 || MaxLoadSize < 2) - return {}; - - // We try to do as many non-overlapping loads as possible starting from the - // beginning. - const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize; - assert(NumNonOverlappingLoads && "there must be at least one load"); - // There remain 0 to (MaxLoadSize - 1) bytes to load, this will be done with - // an overlapping load. - Size = Size - NumNonOverlappingLoads * MaxLoadSize; - // Bail if we do not need an overloapping store, this is already handled by - // the greedy approach. - if (Size == 0) - return {}; - // Bail if the number of loads (non-overlapping + potential overlapping one) - // is larger than the max allowed. - if ((NumNonOverlappingLoads + 1) > MaxNumLoads) - return {}; - - // Add non-overlapping loads. - LoadEntryVector LoadSequence; - uint64_t Offset = 0; - for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) { - LoadSequence.push_back({MaxLoadSize, Offset}); - Offset += MaxLoadSize; - } - - // Add the last overlapping load. - assert(Size > 0 && Size < MaxLoadSize && "broken invariant"); - LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)}); - NumLoadsNonOneByte = 1; - return LoadSequence; -} - // Initialize the basic block structure required for expansion of memcmp call // with given maximum load size and memcmp size parameter. // This structure includes: @@ -208,31 +133,38 @@ MemCmpExpansion::MemCmpExpansion( Builder(CI) { assert(Size > 0 && "zero blocks"); // Scale the max size down if the target can load more bytes than we need. - llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes); - while (!LoadSizes.empty() && LoadSizes.front() > Size) { - LoadSizes = LoadSizes.drop_front(); + size_t LoadSizeIndex = 0; + while (LoadSizeIndex < Options.LoadSizes.size() && + Options.LoadSizes[LoadSizeIndex] > Size) { + ++LoadSizeIndex; } - assert(!LoadSizes.empty() && "cannot load Size bytes"); - MaxLoadSize = LoadSizes.front(); + this->MaxLoadSize = Options.LoadSizes[LoadSizeIndex]; // Compute the decomposition. - unsigned GreedyNumLoadsNonOneByte = 0; - LoadSequence = computeGreedyLoadSequence(Size, LoadSizes, MaxNumLoads, - GreedyNumLoadsNonOneByte); - NumLoadsNonOneByte = GreedyNumLoadsNonOneByte; - assert(LoadSequence.size() <= MaxNumLoads && "broken invariant"); - // If we allow overlapping loads and the load sequence is not already optimal, - // use overlapping loads. - if (Options.AllowOverlappingLoads && - (LoadSequence.empty() || LoadSequence.size() > 2)) { - unsigned OverlappingNumLoadsNonOneByte = 0; - auto OverlappingLoads = computeOverlappingLoadSequence( - Size, MaxLoadSize, MaxNumLoads, OverlappingNumLoadsNonOneByte); - if (!OverlappingLoads.empty() && - (LoadSequence.empty() || - OverlappingLoads.size() < LoadSequence.size())) { - LoadSequence = OverlappingLoads; - NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte; + uint64_t CurSize = Size; + uint64_t Offset = 0; + while (CurSize && LoadSizeIndex < Options.LoadSizes.size()) { + const unsigned LoadSize = Options.LoadSizes[LoadSizeIndex]; + assert(LoadSize > 0 && "zero load size"); + const uint64_t NumLoadsForThisSize = CurSize / LoadSize; + if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) { + // Do not expand if the total number of loads is larger than what the + // target allows. Note that it's important that we exit before completing + // the expansion to avoid using a ton of memory to store the expansion for + // large sizes. + LoadSequence.clear(); + return; } + if (NumLoadsForThisSize > 0) { + for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) { + LoadSequence.push_back({LoadSize, Offset}); + Offset += LoadSize; + } + if (LoadSize > 1) { + ++NumLoadsNonOneByte; + } + CurSize = CurSize % LoadSize; + } + ++LoadSizeIndex; } assert(LoadSequence.size() <= MaxNumLoads && "broken invariant"); } @@ -257,32 +189,30 @@ void MemCmpExpansion::createResultBlock() { EndBlock->getParent(), EndBlock); } -/// Return a pointer to an element of type `LoadSizeType` at offset -/// `OffsetBytes`. -Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source, - Type *LoadSizeType, - uint64_t OffsetBytes) { - if (OffsetBytes > 0) { - auto *ByteType = Type::getInt8Ty(CI->getContext()); - Source = Builder.CreateGEP( - ByteType, Builder.CreateBitCast(Source, ByteType->getPointerTo()), - ConstantInt::get(ByteType, OffsetBytes)); - } - return Builder.CreateBitCast(Source, LoadSizeType->getPointerTo()); -} - // This function creates the IR instructions for loading and comparing 1 byte. // It loads 1 byte from each source of the memcmp parameters with the given // GEPIndex. It then subtracts the two loaded values and adds this result to the // final phi node for selecting the memcmp result. void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex, - unsigned OffsetBytes) { + unsigned GEPIndex) { + Value *Source1 = CI->getArgOperand(0); + Value *Source2 = CI->getArgOperand(1); + Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); Type *LoadSizeType = Type::getInt8Ty(CI->getContext()); - Value *Source1 = - getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, OffsetBytes); - Value *Source2 = - getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, OffsetBytes); + // Cast source to LoadSizeType*. + if (Source1->getType() != LoadSizeType) + Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo()); + if (Source2->getType() != LoadSizeType) + Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo()); + + // Get the base address using the GEPIndex. + if (GEPIndex != 0) { + Source1 = Builder.CreateGEP(LoadSizeType, Source1, + ConstantInt::get(LoadSizeType, GEPIndex)); + Source2 = Builder.CreateGEP(LoadSizeType, Source2, + ConstantInt::get(LoadSizeType, GEPIndex)); + } Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1); Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2); @@ -340,10 +270,24 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex, IntegerType *LoadSizeType = IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8); - Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, - CurLoadEntry.Offset); - Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, - CurLoadEntry.Offset); + Value *Source1 = CI->getArgOperand(0); + Value *Source2 = CI->getArgOperand(1); + + // Cast source to LoadSizeType*. + if (Source1->getType() != LoadSizeType) + Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo()); + if (Source2->getType() != LoadSizeType) + Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo()); + + // Get the base address using a GEP. + if (CurLoadEntry.Offset != 0) { + Source1 = Builder.CreateGEP( + LoadSizeType, Source1, + ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); + Source2 = Builder.CreateGEP( + LoadSizeType, Source2, + ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); + } // Get a constant or load a value for each source address. Value *LoadSrc1 = nullptr; @@ -434,7 +378,8 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) { const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex]; if (CurLoadEntry.LoadSize == 1) { - MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset); + MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, + CurLoadEntry.getGEPIndex()); return; } @@ -443,12 +388,25 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) { Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8); assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type"); + Value *Source1 = CI->getArgOperand(0); + Value *Source2 = CI->getArgOperand(1); + Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); + // Cast source to LoadSizeType*. + if (Source1->getType() != LoadSizeType) + Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo()); + if (Source2->getType() != LoadSizeType) + Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo()); - Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, - CurLoadEntry.Offset); - Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, - CurLoadEntry.Offset); + // Get the base address using a GEP. + if (CurLoadEntry.Offset != 0) { + Source1 = Builder.CreateGEP( + LoadSizeType, Source1, + ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); + Source2 = Builder.CreateGEP( + LoadSizeType, Source2, + ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); + } // Load LoadSizeType from the base address. Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1); @@ -736,6 +694,7 @@ static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI, if (SizeVal == 0) { return false; } + // TTI call to check if target would like to expand memcmp. Also, get the // available load sizes. const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI); |