diff options
| author | Chris Lattner <sabre@nondot.org> | 2006-03-04 09:31:13 +0000 | 
|---|---|---|
| committer | Chris Lattner <sabre@nondot.org> | 2006-03-04 09:31:13 +0000 | 
| commit | 4c065091d850e52803d8312859490418cc041fe6 (patch) | |
| tree | 9dacdb100a0f4428828d1b8fda63f51aa4291fa8 | |
| parent | c9a318d8fa9b0e155b1ea13cf48148600d4cb42a (diff) | |
| download | bcm5719-llvm-4c065091d850e52803d8312859490418cc041fe6.tar.gz bcm5719-llvm-4c065091d850e52803d8312859490418cc041fe6.zip  | |
Add factoring of multiplications, e.g. turning A*A+A*B into A*(A+B).
Testcase here: Transforms/Reassociate/mulfactor.ll
llvm-svn: 26524
| -rw-r--r-- | llvm/lib/Transforms/Scalar/Reassociate.cpp | 235 | 
1 files changed, 186 insertions, 49 deletions
diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp index 41faae74962..61c5c4953c1 100644 --- a/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -41,6 +41,7 @@ namespace {    Statistic<> NumChanged("reassociate","Number of insts reassociated");    Statistic<> NumSwapped("reassociate","Number of insts with operands swapped");    Statistic<> NumAnnihil("reassociate","Number of expr tree annihilated"); +  Statistic<> NumFactor ("reassociate","Number of multiplies factored");    struct ValueEntry {      unsigned Rank; @@ -50,7 +51,20 @@ namespace {    inline bool operator<(const ValueEntry &LHS, const ValueEntry &RHS) {      return LHS.Rank > RHS.Rank;   // Sort so that highest rank goes to start.    } +} +/// PrintOps - Print out the expression identified in the Ops list. +/// +static void PrintOps(Instruction *I, const std::vector<ValueEntry> &Ops) { +  Module *M = I->getParent()->getParent()->getParent(); +  std::cerr << Instruction::getOpcodeName(I->getOpcode()) << " " +  << *Ops[0].Op->getType(); +  for (unsigned i = 0, e = Ops.size(); i != e; ++i) +    WriteAsOperand(std::cerr << " ", Ops[i].Op, false, true, M) +      << "," << Ops[i].Rank; +} +   +namespace {      class Reassociate : public FunctionPass {      std::map<BasicBlock*, unsigned> RankMap;      std::map<Value*, unsigned> ValueRankMap; @@ -66,10 +80,13 @@ namespace {      unsigned getRank(Value *V);      void RewriteExprTree(BinaryOperator *I, unsigned Idx,                           std::vector<ValueEntry> &Ops); -    void OptimizeExpression(unsigned Opcode, std::vector<ValueEntry> &Ops); +    Value *OptimizeExpression(BinaryOperator *I, std::vector<ValueEntry> &Ops);      void LinearizeExprTree(BinaryOperator *I, std::vector<ValueEntry> &Ops);      void LinearizeExpr(BinaryOperator *I); +    Value *RemoveFactorFromExpression(Value *V, Value *Factor);      void ReassociateBB(BasicBlock *BB); +     +    void RemoveDeadBinaryOp(Value *V);    };    RegisterOpt<Reassociate> X("reassociate", "Reassociate expressions"); @@ -78,6 +95,15 @@ namespace {  // Public interface to the Reassociate pass  FunctionPass *llvm::createReassociatePass() { return new Reassociate(); } +void Reassociate::RemoveDeadBinaryOp(Value *V) { +  BinaryOperator *BOp = dyn_cast<BinaryOperator>(V); +  if (!BOp || !BOp->use_empty()) return; +   +  Value *LHS = BOp->getOperand(0), *RHS = BOp->getOperand(1); +  RemoveDeadBinaryOp(LHS); +  RemoveDeadBinaryOp(RHS); +} +  static bool isUnmovableInstruction(Instruction *I) {    if (I->getOpcode() == Instruction::PHI || @@ -207,9 +233,6 @@ void Reassociate::LinearizeExpr(BinaryOperator *I) {  /// form of the the expression (((a+b)+c)+d), and collects information about the  /// rank of the non-tree operands.  /// -/// This returns the rank of the RHS operand, which is known to be the highest -/// rank value in the expression tree. -///  void Reassociate::LinearizeExprTree(BinaryOperator *I,                                      std::vector<ValueEntry> &Ops) {    Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); @@ -279,12 +302,17 @@ void Reassociate::RewriteExprTree(BinaryOperator *I, unsigned i,    if (i+2 == Ops.size()) {      if (I->getOperand(0) != Ops[i].Op ||          I->getOperand(1) != Ops[i+1].Op) { +      Value *OldLHS = I->getOperand(0);        DEBUG(std::cerr << "RA: " << *I);        I->setOperand(0, Ops[i].Op);        I->setOperand(1, Ops[i+1].Op);        DEBUG(std::cerr << "TO: " << *I);        MadeChange = true;        ++NumChanged; +       +      // If we reassociated a tree to fewer operands (e.g. (1+a+2) -> (a+3) +      // delete the extra, now dead, nodes. +      RemoveDeadBinaryOp(OldLHS);      }      return;    } @@ -297,7 +325,15 @@ void Reassociate::RewriteExprTree(BinaryOperator *I, unsigned i,      MadeChange = true;      ++NumChanged;    } -  RewriteExprTree(cast<BinaryOperator>(I->getOperand(0)), i+1, Ops); +   +  BinaryOperator *LHS = cast<BinaryOperator>(I->getOperand(0)); +  assert(LHS->getOpcode() == I->getOpcode() && +         "Improper expression tree!"); +   +  // Compactify the tree instructions together with each other to guarantee +  // that the expression tree is dominated by all of Ops. +  LHS->moveBefore(I); +  RewriteExprTree(LHS, i+1, Ops);  } @@ -405,19 +441,57 @@ static unsigned FindInOperandList(std::vector<ValueEntry> &Ops, unsigned i,    return i;  } -void Reassociate::OptimizeExpression(unsigned Opcode, -                                     std::vector<ValueEntry> &Ops) { +/// EmitAddTreeOfValues - Emit a tree of add instructions, summing Ops together +/// and returning the result.  Insert the tree before I. +static Value *EmitAddTreeOfValues(Instruction *I, std::vector<Value*> &Ops) { +  if (Ops.size() == 1) return Ops.back(); +   +  Value *V1 = Ops.back(); +  Ops.pop_back(); +  Value *V2 = EmitAddTreeOfValues(I, Ops); +  return BinaryOperator::createAdd(V2, V1, "tmp", I); +} + +/// RemoveFactorFromExpression - If V is an expression tree that is a  +/// multiplication sequence, and if this sequence contains a multiply by Factor, +/// remove Factor from the tree and return the new tree. +Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) { +  BinaryOperator *BO = isReassociableOp(V, Instruction::Mul); +  if (!BO) return 0; +   +  std::vector<ValueEntry> Factors; +  LinearizeExprTree(BO, Factors); + +  bool FoundFactor = false; +  for (unsigned i = 0, e = Factors.size(); i != e; ++i) +    if (Factors[i].Op == Factor) { +      FoundFactor = true; +      Factors.erase(Factors.begin()+i); +      break; +    } +  if (!FoundFactor) return 0; +   +  if (Factors.size() == 1) return Factors[0].Op; +   +  RewriteExprTree(BO, 0, Factors); +  return BO; +} + + +Value *Reassociate::OptimizeExpression(BinaryOperator *I, +                                       std::vector<ValueEntry> &Ops) {    // Now that we have the linearized expression tree, try to optimize it.    // Start by folding any constants that we found.    bool IterateOptimization = false; -  if (Ops.size() == 1) return; +  if (Ops.size() == 1) return Ops[0].Op; +  unsigned Opcode = I->getOpcode(); +      if (Constant *V1 = dyn_cast<Constant>(Ops[Ops.size()-2].Op))      if (Constant *V2 = dyn_cast<Constant>(Ops.back().Op)) {        Ops.pop_back();        Ops.back().Op = ConstantExpr::get(Opcode, V1, V2); -      OptimizeExpression(Opcode, Ops); -      return; +      return OptimizeExpression(I, Ops);      }    // Check for destructive annihilation due to a constant being used. @@ -426,30 +500,24 @@ void Reassociate::OptimizeExpression(unsigned Opcode,      default: break;      case Instruction::And:        if (CstVal->isNullValue()) {           // ... & 0 -> 0 -        Ops[0].Op = CstVal; -        Ops.erase(Ops.begin()+1, Ops.end());          ++NumAnnihil; -        return; +        return CstVal;        } else if (CstVal->isAllOnesValue()) { // ... & -1 -> ...          Ops.pop_back();        }        break;      case Instruction::Mul:        if (CstVal->isNullValue()) {           // ... * 0 -> 0 -        Ops[0].Op = CstVal; -        Ops.erase(Ops.begin()+1, Ops.end());          ++NumAnnihil; -        return; +        return CstVal;        } else if (cast<ConstantInt>(CstVal)->getRawValue() == 1) {          Ops.pop_back();                      // ... * 1 -> ...        }        break;      case Instruction::Or:        if (CstVal->isAllOnesValue()) {        // ... | -1 -> -1 -        Ops[0].Op = CstVal; -        Ops.erase(Ops.begin()+1, Ops.end());          ++NumAnnihil; -        return; +        return CstVal;        }        // FALLTHROUGH!      case Instruction::Add: @@ -458,7 +526,7 @@ void Reassociate::OptimizeExpression(unsigned Opcode,          Ops.pop_back();        break;      } -  if (Ops.size() == 1) return; +  if (Ops.size() == 1) return Ops[0].Op;    // Handle destructive annihilation do to identities between elements in the    // argument list here. @@ -477,15 +545,11 @@ void Reassociate::OptimizeExpression(unsigned Opcode,          unsigned FoundX = FindInOperandList(Ops, i, X);          if (FoundX != i) {            if (Opcode == Instruction::And) {   // ...&X&~X = 0 -            Ops[0].Op = Constant::getNullValue(X->getType()); -            Ops.erase(Ops.begin()+1, Ops.end());              ++NumAnnihil; -            return; +            return Constant::getNullValue(X->getType());            } else if (Opcode == Instruction::Or) {   // ...|X|~X = -1 -            Ops[0].Op = ConstantIntegral::getAllOnesValue(X->getType()); -            Ops.erase(Ops.begin()+1, Ops.end());              ++NumAnnihil; -            return; +            return ConstantIntegral::getAllOnesValue(X->getType());            }          }        } @@ -503,10 +567,8 @@ void Reassociate::OptimizeExpression(unsigned Opcode,          } else {            assert(Opcode == Instruction::Xor);            if (e == 2) { -            Ops[0].Op = Constant::getNullValue(Ops[0].Op->getType()); -            Ops.erase(Ops.begin()+1, Ops.end());              ++NumAnnihil; -            return; +            return Constant::getNullValue(Ops[0].Op->getType());            }            // ... X^X -> ...            Ops.erase(Ops.begin()+i, Ops.begin()+i+2); @@ -520,7 +582,7 @@ void Reassociate::OptimizeExpression(unsigned Opcode,    case Instruction::Add:      // Scan the operand lists looking for X and -X pairs.  If we find any, we -    // can simplify the expression. X+-X == 0 +    // can simplify the expression. X+-X == 0.      for (unsigned i = 0, e = Ops.size(); i != e; ++i) {        assert(i < Ops.size());        // Check for X and -X in the operand list. @@ -530,10 +592,8 @@ void Reassociate::OptimizeExpression(unsigned Opcode,          if (FoundX != i) {            // Remove X and -X from the operand list.            if (Ops.size() == 2) { -            Ops[0].Op = Constant::getNullValue(X->getType()); -            Ops.pop_back();              ++NumAnnihil; -            return; +            return Constant::getNullValue(X->getType());            } else {              Ops.erase(Ops.begin()+i);              if (i < FoundX) @@ -549,30 +609,99 @@ void Reassociate::OptimizeExpression(unsigned Opcode,          }        }      } +     + +    // Scan the operand list, checking to see if there are any common factors +    // between operands.  Consider something like A*A+A*B*C+D.  We would like to +    // reassociate this to A*(A+B*C)+D, which reduces the number of multiplies. +    // To efficiently find this, we count the number of times a factor occurs +    // for any ADD operands that are MULs. +    std::map<Value*, unsigned> FactorOccurrences; +    unsigned MaxOcc = 0; +    Value *MaxOccVal = 0; +    if (!I->getType()->isFloatingPoint()) { +      for (unsigned i = 0, e = Ops.size(); i != e; ++i) { +        if (BinaryOperator *BOp = dyn_cast<BinaryOperator>(Ops[i].Op)) +          if (BOp->getOpcode() == Instruction::Mul && BOp->hasOneUse()) { +            // Compute all of the factors of this added value. +            std::vector<ValueEntry> Factors; +            LinearizeExprTree(BOp, Factors); +            assert(Factors.size() > 1 && "Bad linearize!"); +             +            // Add one to FactorOccurrences for each unique factor in this op. +            if (Factors.size() == 2) { +              unsigned Occ = ++FactorOccurrences[Factors[0].Op]; +              if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[0].Op; } +              if (Factors[0].Op != Factors[1].Op) {   // Don't double count A*A. +                Occ = ++FactorOccurrences[Factors[1].Op]; +                if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[1].Op; } +              } +            } else { +              std::set<Value*> Duplicates; +              for (unsigned i = 0, e = Factors.size(); i != e; ++i) +                if (Duplicates.insert(Factors[i].Op).second) { +                  unsigned Occ = ++FactorOccurrences[Factors[i].Op]; +                  if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[i].Op; } +                } +            } +          } +      } +    } + +    // If any factor occurred more than one time, we can pull it out. +    if (MaxOcc > 1) { +      DEBUG(std::cerr << "\nFACTORING [" << MaxOcc << "]: " +                      << *MaxOccVal << "\n"); +       +      // Create a new instruction that uses the MaxOccVal twice.  If we don't do +      // this, we could otherwise run into situations where removing a factor +      // from an expression will drop a use of maxocc, and this can cause  +      // RemoveFactorFromExpression on successive values to behave differently. +      Instruction *DummyInst = BinaryOperator::createAdd(MaxOccVal, MaxOccVal); +      std::vector<Value*> NewMulOps; +      for (unsigned i = 0, e = Ops.size(); i != e; ++i) { +        if (Value *V = RemoveFactorFromExpression(Ops[i].Op, MaxOccVal)) { +          NewMulOps.push_back(V); +          Ops.erase(Ops.begin()+i); +          --i; --e; +        } +      } +       +      // No need for extra uses anymore. +      delete DummyInst; + +      Value *V = EmitAddTreeOfValues(I, NewMulOps); +      // FIXME: Must optimize V now, to handle this case: +      // A*A*B + A*A*C -> A*(A*B+A*C)   -> A*(A*(B+C)) +      V = BinaryOperator::createMul(V, MaxOccVal, "tmp", I); + +      ++NumFactor; +       +      if (Ops.size() == 0) +        return V; + +      // Add the new value to the list of things being added. +      Ops.insert(Ops.begin(), ValueEntry(getRank(V), V)); +       +      // Rewrite the tree so that there is now a use of V. +      RewriteExprTree(I, 0, Ops); +      return OptimizeExpression(I, Ops); +    }      break;    //case Instruction::Mul:    }    if (IterateOptimization) -    OptimizeExpression(Opcode, Ops); +    return OptimizeExpression(I, Ops); +  return 0;  } -/// PrintOps - Print out the expression identified in the Ops list. -/// -static void PrintOps(unsigned Opcode, const std::vector<ValueEntry> &Ops, -                     BasicBlock *BB) { -  Module *M = BB->getParent()->getParent(); -  std::cerr << Instruction::getOpcodeName(Opcode) << " " -            << *Ops[0].Op->getType(); -  for (unsigned i = 0, e = Ops.size(); i != e; ++i) -    WriteAsOperand(std::cerr << " ", Ops[i].Op, false, true, M) -      << "," << Ops[i].Rank; -}  /// ReassociateBB - Inspect all of the instructions in this basic block,  /// reassociating them as we go.  void Reassociate::ReassociateBB(BasicBlock *BB) { -  for (BasicBlock::iterator BI = BB->begin(); BI != BB->end(); ++BI) { +  for (BasicBlock::iterator BBI = BB->begin(); BBI != BB->end(); ) { +    Instruction *BI = BBI++;      if (BI->getOpcode() == Instruction::Shl &&          isa<ConstantInt>(BI->getOperand(1)))        if (Instruction *NI = ConvertShiftToMul(BI)) { @@ -623,7 +752,7 @@ void Reassociate::ReassociateBB(BasicBlock *BB) {      std::vector<ValueEntry> Ops;      LinearizeExprTree(I, Ops); -    DEBUG(std::cerr << "RAIn:\t"; PrintOps(I->getOpcode(), Ops, BB); +    DEBUG(std::cerr << "RAIn:\t"; PrintOps(I, Ops);            std::cerr << "\n");      // Now that we have linearized the tree to a list and have gathered all of @@ -636,7 +765,14 @@ void Reassociate::ReassociateBB(BasicBlock *BB) {      // OptimizeExpression - Now that we have the expression tree in a convenient      // sorted form, optimize it globally if possible. -    OptimizeExpression(I->getOpcode(), Ops); +    if (Value *V = OptimizeExpression(I, Ops)) { +      // This expression tree simplified to something that isn't a tree, +      // eliminate it. +      DEBUG(std::cerr << "Reassoc to scalar: " << *V << "\n"); +      I->replaceAllUsesWith(V); +      RemoveDeadBinaryOp(I); +      continue; +    }      // We want to sink immediates as deeply as possible except in the case where      // this is a multiply tree used only by an add, and the immediate is a -1. @@ -650,13 +786,14 @@ void Reassociate::ReassociateBB(BasicBlock *BB) {        Ops.pop_back();      } -    DEBUG(std::cerr << "RAOut:\t"; PrintOps(I->getOpcode(), Ops, BB); +    DEBUG(std::cerr << "RAOut:\t"; PrintOps(I, Ops);            std::cerr << "\n");      if (Ops.size() == 1) {        // This expression tree simplified to something that isn't a tree,        // eliminate it.        I->replaceAllUsesWith(Ops[0].Op); +      RemoveDeadBinaryOp(I);      } else {        // Now that we ordered and optimized the expressions, splat them back into        // the expression tree, removing any unneeded nodes.  | 

