diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils')
| -rw-r--r-- | llvm/lib/Transforms/Utils/CodeExtractor.cpp | 126 | 
1 files changed, 110 insertions, 16 deletions
| diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index 8d0bc036d72..c514c9c9cd4 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -17,6 +17,9 @@  #include "llvm/ADT/STLExtras.h"  #include "llvm/ADT/SetVector.h"  #include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/Analysis/BranchProbabilityInfo.h"  #include "llvm/Analysis/LoopInfo.h"  #include "llvm/Analysis/RegionInfo.h"  #include "llvm/Analysis/RegionIterator.h" @@ -26,9 +29,11 @@  #include "llvm/IR/Instructions.h"  #include "llvm/IR/Intrinsics.h"  #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h"  #include "llvm/IR/Module.h"  #include "llvm/IR/Verifier.h"  #include "llvm/Pass.h" +#include "llvm/Support/BlockFrequency.h"  #include "llvm/Support/CommandLine.h"  #include "llvm/Support/Debug.h"  #include "llvm/Support/ErrorHandling.h" @@ -119,23 +124,30 @@ buildExtractionBlockSet(const RegionNode &RN) {    return buildExtractionBlockSet(R.block_begin(), R.block_end());  } -CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs) -  : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt), -    Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {} +CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs, +                             BlockFrequencyInfo *BFI, +                             BranchProbabilityInfo *BPI) +    : DT(nullptr), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), +      BPI(BPI), Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}  CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, -                             bool AggregateArgs) -  : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), -    Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {} - -CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs) -  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), -    Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {} +                             bool AggregateArgs, BlockFrequencyInfo *BFI, +                             BranchProbabilityInfo *BPI) +    : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), +      BPI(BPI), Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {} + +CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, +                             BlockFrequencyInfo *BFI, +                             BranchProbabilityInfo *BPI) +    : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), +      BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks())), +      NumExitBlocks(~0U) {}  CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN, -                             bool AggregateArgs) -  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), -    Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {} +                             bool AggregateArgs, BlockFrequencyInfo *BFI, +                             BranchProbabilityInfo *BPI) +    : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), +      BPI(BPI), Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}  /// definedInRegion - Return true if the specified value is defined in the  /// extracted region. @@ -687,6 +699,51 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) {    }  } +void CodeExtractor::calculateNewCallTerminatorWeights( +    BasicBlock *CodeReplacer, +    DenseMap<BasicBlock *, BlockFrequency> &ExitWeights, +    BranchProbabilityInfo *BPI) { +  typedef BlockFrequencyInfoImplBase::Distribution Distribution; +  typedef BlockFrequencyInfoImplBase::BlockNode BlockNode; + +  // Update the branch weights for the exit block. +  TerminatorInst *TI = CodeReplacer->getTerminator(); +  SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0); + +  // Block Frequency distribution with dummy node. +  Distribution BranchDist; + +  // Add each of the frequencies of the successors. +  for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) { +    BlockNode ExitNode(i); +    uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency(); +    if (ExitFreq != 0) +      BranchDist.addExit(ExitNode, ExitFreq); +    else +      BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero()); +  } + +  // Check for no total weight. +  if (BranchDist.Total == 0) +    return; + +  // Normalize the distribution so that they can fit in unsigned. +  BranchDist.normalize(); + +  // Create normalized branch weights and set the metadata. +  for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) { +    const auto &Weight = BranchDist.Weights[I]; + +    // Get the weight and update the current BFI. +    BranchWeights[Weight.TargetNode.Index] = Weight.Amount; +    BranchProbability BP(Weight.Amount, BranchDist.Total); +    BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP); +  } +  TI->setMetadata( +      LLVMContext::MD_prof, +      MDBuilder(TI->getContext()).createBranchWeights(BranchWeights)); +} +  Function *CodeExtractor::extractCodeRegion() {    if (!isEligible())      return nullptr; @@ -697,6 +754,19 @@ Function *CodeExtractor::extractCodeRegion() {    // block in the region.    BasicBlock *header = *Blocks.begin(); +  // Calculate the entry frequency of the new function before we change the root +  //   block. +  BlockFrequency EntryFreq; +  if (BFI) { +    assert(BPI && "Both BPI and BFI are required to preserve profile info"); +    for (BasicBlock *Pred : predecessors(header)) { +      if (Blocks.count(Pred)) +        continue; +      EntryFreq += +          BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header); +    } +  } +    // If we have to split PHI nodes or the entry block, do so now.    severSplitPHINodes(header); @@ -720,12 +790,23 @@ Function *CodeExtractor::extractCodeRegion() {    // Find inputs to, outputs from the code region.    findInputsOutputs(inputs, outputs); +  // Calculate the exit blocks for the extracted region and the total exit +  //  weights for each of those blocks. +  DenseMap<BasicBlock *, BlockFrequency> ExitWeights;    SmallPtrSet<BasicBlock *, 1> ExitBlocks; -  for (BasicBlock *Block : Blocks) +  for (BasicBlock *Block : Blocks) {      for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE; -         ++SI) -      if (!Blocks.count(*SI)) +         ++SI) { +      if (!Blocks.count(*SI)) { +        // Update the branch weight for this successor. +        if (BFI) { +          BlockFrequency &BF = ExitWeights[*SI]; +          BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI); +        }          ExitBlocks.insert(*SI); +      } +    } +  }    NumExitBlocks = ExitBlocks.size();    // Construct new function based on inputs/outputs & add allocas for all defs. @@ -734,10 +815,23 @@ Function *CodeExtractor::extractCodeRegion() {                                              codeReplacer, oldFunction,                                              oldFunction->getParent()); +  // Update the entry count of the function. +  if (BFI) { +    Optional<uint64_t> EntryCount = +        BFI->getProfileCountFromFreq(EntryFreq.getFrequency()); +    if (EntryCount.hasValue()) +      newFunction->setEntryCount(EntryCount.getValue()); +    BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency()); +  } +    emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);    moveCodeToFunction(newFunction); +  // Update the branch weights for the exit block. +  if (BFI && NumExitBlocks > 1) +    calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI); +    // Loop over all of the PHI nodes in the header block, and change any    // references to the old incoming edge to be the new incoming edge.    for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) { | 

