diff options
Diffstat (limited to 'llvm/lib/Transforms')
| -rw-r--r-- | llvm/lib/Transforms/Utils/CodeExtractor.cpp | 96 | 
1 files changed, 39 insertions, 57 deletions
| diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index 1ad610d80af..68cf7310c75 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -26,16 +26,11 @@  #include "Support/Debug.h"  #include "Support/StringExtras.h"  #include <algorithm> -#include <map> -#include <vector> +#include <set>  using namespace llvm;  namespace { -  inline bool contains(const std::vector<BasicBlock*> &V, const BasicBlock *BB){ -    return std::find(V.begin(), V.end(), BB) != V.end(); -  } -    /// getFunctionArg - Return a pointer to F's ARGNOth argument.    ///    Argument *getFunctionArg(Function *F, unsigned argno) { @@ -49,19 +44,16 @@ namespace {      typedef std::vector<std::pair<unsigned, unsigned> > PhiValChangesTy;      typedef std::map<PHINode*, PhiValChangesTy> PhiVal2ArgTy;      PhiVal2ArgTy PhiVal2Arg; - +    std::set<BasicBlock*> BlocksToExtract;    public:      Function *ExtractCodeRegion(const std::vector<BasicBlock*> &code);    private: -    void findInputsOutputs(const std::vector<BasicBlock*> &code, -                           Values &inputs, -                           Values &outputs, +    void findInputsOutputs(Values &inputs, Values &outputs,                             BasicBlock *newHeader,                             BasicBlock *newRootNode);      void processPhiNodeInputs(PHINode *Phi, -                              const std::vector<BasicBlock*> &code,                                Values &inputs,                                BasicBlock *newHeader,                                BasicBlock *newRootNode); @@ -71,15 +63,12 @@ namespace {      Function *constructFunction(const Values &inputs,                                  const Values &outputs,                                  BasicBlock *newRootNode, BasicBlock *newHeader, -                                const std::vector<BasicBlock*> &code,                                  Function *oldFunction, Module *M); -    void moveCodeToFunction(const std::vector<BasicBlock*> &code, -                            Function *newFunction); +    void moveCodeToFunction(Function *newFunction);      void emitCallAndSwitchStatement(Function *newFunction,                                      BasicBlock *newHeader, -                                    const std::vector<BasicBlock*> &code,                                      Values &inputs,                                      Values &outputs); @@ -87,7 +76,6 @@ namespace {  }  void CodeExtractor::processPhiNodeInputs(PHINode *Phi, -                                         const std::vector<BasicBlock*> &code,                                           Values &inputs,                                           BasicBlock *codeReplacer,                                           BasicBlock *newFuncRoot) @@ -102,11 +90,11 @@ void CodeExtractor::processPhiNodeInputs(PHINode *Phi,    for (unsigned i = 0, e = Phi->getNumIncomingValues(); i != e; ++i) {      Value *phiVal = Phi->getIncomingValue(i);      if (Instruction *Inst = dyn_cast<Instruction>(phiVal)) { -      if (contains(code, Inst->getParent())) { -        if (!contains(code, Phi->getIncomingBlock(i))) +      if (BlocksToExtract.count(Inst->getParent())) { +        if (!BlocksToExtract.count(Phi->getIncomingBlock(i)))            IValEBB.push_back(i);        } else { -        if (contains(code, Phi->getIncomingBlock(i))) +        if (BlocksToExtract.count(Phi->getIncomingBlock(i)))            EValIBB.push_back(i);          else            EValEBB.push_back(i); @@ -114,11 +102,11 @@ void CodeExtractor::processPhiNodeInputs(PHINode *Phi,      } else if (Constant *Const = dyn_cast<Constant>(phiVal)) {        // Constants are internal, but considered `external' if they are coming        // from an external block. -      if (!contains(code, Phi->getIncomingBlock(i))) +      if (!BlocksToExtract.count(Phi->getIncomingBlock(i)))          EValEBB.push_back(i);      } else if (Argument *Arg = dyn_cast<Argument>(phiVal)) {        // arguments are external -      if (contains(code, Phi->getIncomingBlock(i))) +      if (BlocksToExtract.count(Phi->getIncomingBlock(i)))          EValIBB.push_back(i);        else          EValEBB.push_back(i); @@ -184,14 +172,13 @@ void CodeExtractor::processPhiNodeInputs(PHINode *Phi,  } -void CodeExtractor::findInputsOutputs(const std::vector<BasicBlock*> &code, -                                      Values &inputs, +void CodeExtractor::findInputsOutputs(Values &inputs,                                        Values &outputs,                                        BasicBlock *newHeader,                                        BasicBlock *newRootNode)  { -  for (std::vector<BasicBlock*>::const_iterator ci = code.begin(),  -       ce = code.end(); ci != ce; ++ci) { +  for (std::set<BasicBlock*>::const_iterator ci = BlocksToExtract.begin(),  +       ce = BlocksToExtract.end(); ci != ce; ++ci) {      BasicBlock *BB = *ci;      for (BasicBlock::iterator BBi = BB->begin(), BBe = BB->end();           BBi != BBe; ++BBi) { @@ -200,7 +187,7 @@ void CodeExtractor::findInputsOutputs(const std::vector<BasicBlock*> &code,        if (Instruction *I = dyn_cast<Instruction>(&*BBi)) {          // If it's a phi node          if (PHINode *Phi = dyn_cast<PHINode>(I)) { -          processPhiNodeInputs(Phi, code, inputs, newHeader, newRootNode); +          processPhiNodeInputs(Phi, inputs, newHeader, newRootNode);          } else {            // All other instructions go through the generic input finder            // Loop over the operands of each instruction (inputs) @@ -208,7 +195,7 @@ void CodeExtractor::findInputsOutputs(const std::vector<BasicBlock*> &code,                 op != opE; ++op) {              if (Instruction *opI = dyn_cast<Instruction>(op->get())) {                // Check if definition of this operand is within the loop -              if (!contains(code, opI->getParent())) { +              if (!BlocksToExtract.count(opI->getParent())) {                  // add this operand to the inputs                  inputs.push_back(opI);                } @@ -220,7 +207,7 @@ void CodeExtractor::findInputsOutputs(const std::vector<BasicBlock*> &code,          for (Value::use_iterator use = I->use_begin(), useE = I->use_end();               use != useE; ++use) {            if (Instruction* inst = dyn_cast<Instruction>(*use)) { -            if (!contains(code, inst->getParent())) { +            if (!BlocksToExtract.count(inst->getParent())) {                // add this op to the outputs                outputs.push_back(I);              } @@ -276,11 +263,10 @@ Function *CodeExtractor::constructFunction(const Values &inputs,                                             const Values &outputs,                                             BasicBlock *newRootNode,                                             BasicBlock *newHeader, -                                           const std::vector<BasicBlock*> &code,                                             Function *oldFunction, Module *M) {    DEBUG(std::cerr << "inputs: " << inputs.size() << "\n");    DEBUG(std::cerr << "outputs: " << outputs.size() << "\n"); -  BasicBlock *header = code[0]; +  BasicBlock *header = *BlocksToExtract.begin();    // This function returns unsigned, outputs will go back by reference.    Type *retTy = Type::UShortTy; @@ -327,7 +313,7 @@ Function *CodeExtractor::constructFunction(const Values &inputs,      for (std::vector<User*>::iterator use = Users.begin(), useE = Users.end();           use != useE; ++use)        if (Instruction* inst = dyn_cast<Instruction>(*use)) -        if (contains(code, inst->getParent())) +        if (BlocksToExtract.count(inst->getParent()))            inst->replaceUsesOfWith(inputs[i], getFunctionArg(newFunction, i));    } @@ -339,7 +325,7 @@ Function *CodeExtractor::constructFunction(const Values &inputs,         i != e; ++i) {      if (BranchInst *inst = dyn_cast<BranchInst>(*i)) {        BasicBlock *BB = inst->getParent(); -      if (!contains(code, BB) && BB->getParent() == oldFunction) { +      if (!BlocksToExtract.count(BB) && BB->getParent() == oldFunction) {          // The BasicBlock which contains the branch is not in the region          // modify the branch target to a new block          inst->replaceUsesOfWith(header, newHeader); @@ -350,29 +336,25 @@ Function *CodeExtractor::constructFunction(const Values &inputs,    return newFunction;  } -void CodeExtractor::moveCodeToFunction(const std::vector<BasicBlock*> &code, -                                       Function *newFunction) +void CodeExtractor::moveCodeToFunction(Function *newFunction)  { -  Function *oldFunc = code[0]->getParent(); +  Function *oldFunc = (*BlocksToExtract.begin())->getParent();    Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList();      Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); -  for (std::vector<BasicBlock*>::const_iterator i = code.begin(), e =code.end(); -       i != e; ++i) { -    BasicBlock *BB = *i; - +  for (std::set<BasicBlock*>::const_iterator i = BlocksToExtract.begin(), +         e = BlocksToExtract.end(); i != e; ++i) {      // Delete the basic block from the old function, and the list of blocks -    oldBlocks.remove(BB); +    oldBlocks.remove(*i);      // Insert this basic block into the new function -    newBlocks.push_back(BB); +    newBlocks.push_back(*i);    }  }  void  CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,                                            BasicBlock *codeReplacer, -                                          const std::vector<BasicBlock*> &code,                                            Values &inputs,                                            Values &outputs)  { @@ -399,7 +381,7 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,        for (std::vector<User*>::iterator use = Users.begin(), useE =Users.end();             use != useE; ++use) {          if (Instruction* inst = dyn_cast<Instruction>(*use)) { -          if (!contains(code, inst->getParent())) { +          if (!BlocksToExtract.count(inst->getParent())) {              inst->replaceUsesOfWith(*i, load);            }          } @@ -425,8 +407,8 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,    // Since there may be multiple exits from the original region, make the new    // function return an unsigned, switch on that number    unsigned switchVal = 0; -  for (std::vector<BasicBlock*>::const_iterator i =code.begin(), e = code.end(); -       i != e; ++i) { +  for (std::set<BasicBlock*>::const_iterator i = BlocksToExtract.begin(), +         e = BlocksToExtract.end(); i != e; ++i) {      BasicBlock *BB = *i;      // rewrite the terminator of the original BasicBlock @@ -436,16 +418,14 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,        // Restore values just before we exit        // FIXME: Use a GetElementPtr to bunch the outputs in a struct        for (unsigned outIdx = 0, outE = outputs.size(); outIdx != outE; ++outIdx) -      {          new StoreInst(outputs[outIdx],                        getFunctionArg(newFunction, outIdx),                        brInst); -      }        // Rewrite branches into exits which return a value based on which        // exit we take from this function        if (brInst->isUnconditional()) { -        if (!contains(code, brInst->getSuccessor(0))) { +        if (!BlocksToExtract.count(brInst->getSuccessor(0))) {            ConstantUInt *brVal = ConstantUInt::get(Type::UShortTy, switchVal);            ReturnInst *newRet = new ReturnInst(brVal);            // add a new target to the switch @@ -461,7 +441,7 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,          // to two new blocks, each of which returns a different code.          for (unsigned idx = 0; idx < 2; ++idx) {            BasicBlock *oldTarget = brInst->getSuccessor(idx); -          if (!contains(code, oldTarget)) { +          if (!BlocksToExtract.count(oldTarget)) {              // add a new basic block which returns the appropriate value              BasicBlock *newTarget = new BasicBlock("newTarget", newFunction);              ConstantUInt *brVal = ConstantUInt::get(Type::UShortTy, switchVal); @@ -475,13 +455,15 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,            }          }        } +    } else if (SwitchInst *swTerm = dyn_cast<SwitchInst>(term)) { + +      assert(0 && "Cannot handle switch instructions just yet."); +      } else if (ReturnInst *retTerm = dyn_cast<ReturnInst>(term)) {        assert(0 && "Cannot handle return instructions just yet.");        // FIXME: what if the terminator is a return!??!        // Need to rewrite: add new basic block, move the return there        // treat the original as an unconditional branch to that basicblock -    } else if (SwitchInst *swTerm = dyn_cast<SwitchInst>(term)) { -      assert(0 && "Cannot handle switch instructions just yet.");      } else if (InvokeInst *invInst = dyn_cast<InvokeInst>(term)) {        assert(0 && "Cannot handle invoke instructions just yet.");      } else { @@ -514,7 +496,8 @@ Function *CodeExtractor::ExtractCodeRegion(const std::vector<BasicBlock*> &code)    //  * Add allocas for defs, pass as args by reference    //  * Pass in uses as args    // 3) Move code region, add call instr to func -  //  +  // +  BlocksToExtract.insert(code.begin(), code.end());    Values inputs, outputs; @@ -548,19 +531,18 @@ Function *CodeExtractor::ExtractCodeRegion(const std::vector<BasicBlock*> &code)    // blocks moving to a new function.    // SOLUTION: move Phi nodes out of the loop header into the codeReplacer, pass    // the values as parameters to the function -  findInputsOutputs(code, inputs, outputs, codeReplacer, newFuncRoot); +  findInputsOutputs(inputs, outputs, codeReplacer, newFuncRoot);    // Step 2: Construct new function based on inputs/outputs,    // Add allocas for all defs    Function *newFunction = constructFunction(inputs, outputs, newFuncRoot,  -                                            codeReplacer, code,  -                                            oldFunction, module); +                                            codeReplacer, oldFunction, module);    rewritePhiNodes(newFunction, newFuncRoot); -  emitCallAndSwitchStatement(newFunction, codeReplacer, code, inputs, outputs); +  emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); -  moveCodeToFunction(code, newFunction); +  moveCodeToFunction(newFunction);    DEBUG(if (verifyFunction(*newFunction)) abort());    return newFunction; | 

