diff options
-rw-r--r-- | llvm/lib/Transforms/Scalar/LoopInterchange.cpp | 70 |
1 files changed, 40 insertions, 30 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index 9addb04e81b..2978165ed8a 100644 --- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -1278,8 +1278,7 @@ bool LoopInterchangeTransform::transform() { else InnerIndexVar = dyn_cast<Instruction>(InductionPHI->getIncomingValue(0)); - // Ensure that InductionPHI is the first Phi node as required by - // splitInnerLoopHeader + // Ensure that InductionPHI is the first Phi node. if (&InductionPHI->getParent()->front() != InductionPHI) InductionPHI->moveBefore(&InductionPHI->getParent()->front()); @@ -1290,8 +1289,9 @@ bool LoopInterchangeTransform::transform() { LLVM_DEBUG(dbgs() << "splitInnerLoopLatch done\n"); // Splits the inner loops phi nodes out into a separate basic block. - splitInnerLoopHeader(); - LLVM_DEBUG(dbgs() << "splitInnerLoopHeader done\n"); + BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); + SplitBlock(InnerLoopHeader, InnerLoopHeader->getFirstNonPHI(), DT, LI); + LLVM_DEBUG(dbgs() << "splitting InnerLoopHeader done\n"); } Transformed |= adjustLoopLinks(); @@ -1309,32 +1309,7 @@ void LoopInterchangeTransform::splitInnerLoopLatch(Instruction *Inc) { InnerLoopLatch = SplitBlock(InnerLoopLatchPred, Inc, DT, LI); } -void LoopInterchangeTransform::splitInnerLoopHeader() { - // Split the inner loop header out. Here make sure that the reduction PHI's - // stay in the innerloop body. - BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); - BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); - SplitBlock(InnerLoopHeader, InnerLoopHeader->getFirstNonPHI(), DT, LI); - if (InnerLoopHasReduction) { - // Adjust Reduction PHI's in the block. The induction PHI must be the first - // PHI in InnerLoopHeader for this to work. - SmallVector<PHINode *, 8> PHIVec; - for (auto I = std::next(InnerLoopHeader->begin()); isa<PHINode>(I); ++I) { - PHINode *PHI = dyn_cast<PHINode>(I); - Value *V = PHI->getIncomingValueForBlock(InnerLoopPreHeader); - PHI->replaceAllUsesWith(V); - PHIVec.push_back((PHI)); - } - for (PHINode *P : PHIVec) { - P->eraseFromParent(); - } - } - - LLVM_DEBUG(dbgs() << "Output of splitInnerLoopHeader InnerLoopHeaderSucc & " - "InnerLoopHeader\n"); -} - -/// Move all instructions except the terminator from FromBB right before +/// \brief Move all instructions except the terminator from FromBB right before /// InsertBefore static void moveBBContents(BasicBlock *FromBB, Instruction *InsertBefore) { auto &ToList = InsertBefore->getParent()->getInstList(); @@ -1470,6 +1445,41 @@ bool LoopInterchangeTransform::adjustLoopBranches() { restructureLoops(OuterLoop, InnerLoop, InnerLoopPreHeader, OuterLoopPreHeader); + // Now update the reduction PHIs in the inner and outer loop headers. + SmallVector<PHINode *, 4> InnerLoopPHIs, OuterLoopPHIs; + for (PHINode &PHI : drop_begin(InnerLoopHeader->phis(), 1)) + InnerLoopPHIs.push_back(cast<PHINode>(&PHI)); + for (PHINode &PHI : drop_begin(OuterLoopHeader->phis(), 1)) + OuterLoopPHIs.push_back(cast<PHINode>(&PHI)); + + for (PHINode *PHI : OuterLoopPHIs) + PHI->moveBefore(InnerLoopHeader->getFirstNonPHI()); + + // Move the PHI nodes from the inner loop header to the outer loop header. + // We have to deal with one kind of PHI nodes: + // 1) PHI nodes that are part of inner loop-only reductions. + // We only have to move the PHI node and update the incoming blocks. + for (PHINode *PHI : InnerLoopPHIs) { + PHI->moveBefore(OuterLoopHeader->getFirstNonPHI()); + for (BasicBlock *InBB : PHI->blocks()) { + if (InnerLoop->contains(InBB)) + continue; + + assert(!isa<PHINode>(PHI->getIncomingValueForBlock(InBB)) && + "Unexpected incoming PHI node, reductions in outer loop are not " + "supported yet"); + PHI->replaceAllUsesWith(PHI->getIncomingValueForBlock(InBB)); + PHI->eraseFromParent(); + break; + } + } + + // Update the incoming blocks for moved PHI nodes. + updateIncomingBlock(OuterLoopHeader, InnerLoopPreHeader, OuterLoopPreHeader); + updateIncomingBlock(OuterLoopHeader, InnerLoopLatch, OuterLoopLatch); + updateIncomingBlock(InnerLoopHeader, OuterLoopPreHeader, InnerLoopPreHeader); + updateIncomingBlock(InnerLoopHeader, OuterLoopLatch, InnerLoopLatch); + return true; } |