diff options
Diffstat (limited to 'clang/lib/CodeGen/CodeGenPGO.cpp')
-rw-r--r-- | clang/lib/CodeGen/CodeGenPGO.cpp | 241 |
1 files changed, 111 insertions, 130 deletions
diff --git a/clang/lib/CodeGen/CodeGenPGO.cpp b/clang/lib/CodeGen/CodeGenPGO.cpp index c90b025e551..d9016774fa1 100644 --- a/clang/lib/CodeGen/CodeGenPGO.cpp +++ b/clang/lib/CodeGen/CodeGenPGO.cpp @@ -264,6 +264,12 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { } } + /// Set and return the current count. + uint64_t setCount(uint64_t Count) { + PGO.setCurrentRegionCount(Count); + return Count; + } + void VisitStmt(const Stmt *S) { RecordStmtCount(S); for (Stmt::const_child_range I = S->children(); I; ++I) { @@ -274,9 +280,8 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { void VisitFunctionDecl(const FunctionDecl *D) { // Counter tracks entry to the function body. - RegionCounter Cnt(PGO, D->getBody()); - Cnt.beginRegion(); - CountMap[D->getBody()] = PGO.getCurrentRegionCount(); + uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); + CountMap[D->getBody()] = BodyCount; Visit(D->getBody()); } @@ -287,25 +292,22 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { void VisitCapturedDecl(const CapturedDecl *D) { // Counter tracks entry to the capture body. - RegionCounter Cnt(PGO, D->getBody()); - Cnt.beginRegion(); - CountMap[D->getBody()] = PGO.getCurrentRegionCount(); + uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); + CountMap[D->getBody()] = BodyCount; Visit(D->getBody()); } void VisitObjCMethodDecl(const ObjCMethodDecl *D) { // Counter tracks entry to the method body. - RegionCounter Cnt(PGO, D->getBody()); - Cnt.beginRegion(); - CountMap[D->getBody()] = PGO.getCurrentRegionCount(); + uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); + CountMap[D->getBody()] = BodyCount; Visit(D->getBody()); } void VisitBlockDecl(const BlockDecl *D) { // Counter tracks entry to the block body. - RegionCounter Cnt(PGO, D->getBody()); - Cnt.beginRegion(); - CountMap[D->getBody()] = PGO.getCurrentRegionCount(); + uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); + CountMap[D->getBody()] = BodyCount; Visit(D->getBody()); } @@ -334,9 +336,8 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { void VisitLabelStmt(const LabelStmt *S) { RecordNextStmtCount = false; // Counter tracks the block following the label. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(); - CountMap[S] = PGO.getCurrentRegionCount(); + uint64_t BlockCount = setCount(PGO.getRegionCount(S)); + CountMap[S] = BlockCount; Visit(S->getSubStmt()); } @@ -358,52 +359,47 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { void VisitWhileStmt(const WhileStmt *S) { RecordStmtCount(S); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); + uint64_t ParentCount = PGO.getCurrentRegionCount(); + BreakContinueStack.push_back(BreakContinue()); // Visit the body region first so the break/continue adjustments can be // included when visiting the condition. - Cnt.beginRegion(); + uint64_t BodyCount = setCount(PGO.getRegionCount(S)); CountMap[S->getBody()] = PGO.getCurrentRegionCount(); Visit(S->getBody()); - Cnt.adjustForControlFlow(); + uint64_t BackedgeCount = PGO.getCurrentRegionCount(); // ...then go back and propagate counts through the condition. The count // at the start of the condition is the sum of the incoming edges, // the backedge from the end of the loop body, and the edges from // continue statements. BreakContinue BC = BreakContinueStack.pop_back_val(); - Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() + - BC.ContinueCount); - CountMap[S->getCond()] = PGO.getCurrentRegionCount(); + uint64_t CondCount = + setCount(ParentCount + BackedgeCount + BC.ContinueCount); + CountMap[S->getCond()] = CondCount; Visit(S->getCond()); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); + setCount(BC.BreakCount + CondCount - BodyCount); RecordNextStmtCount = true; } void VisitDoStmt(const DoStmt *S) { RecordStmtCount(S); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); + uint64_t LoopCount = PGO.getRegionCount(S); + BreakContinueStack.push_back(BreakContinue()); - Cnt.beginRegion(/*AddIncomingFallThrough=*/true); - CountMap[S->getBody()] = PGO.getCurrentRegionCount(); + // The count doesn't include the fallthrough from the parent scope. Add it. + uint64_t BodyCount = setCount(LoopCount + PGO.getCurrentRegionCount()); + CountMap[S->getBody()] = BodyCount; Visit(S->getBody()); - Cnt.adjustForControlFlow(); + uint64_t BackedgeCount = PGO.getCurrentRegionCount(); BreakContinue BC = BreakContinueStack.pop_back_val(); // The count at the start of the condition is equal to the count at the - // end of the body. The adjusted count does not include either the - // fall-through count coming into the loop or the continue count, so add - // both of those separately. This is coincidentally the same equation as - // with while loops but for different reasons. - Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() + - BC.ContinueCount); - CountMap[S->getCond()] = PGO.getCurrentRegionCount(); + // end of the body, plus any continues. + uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount); + CountMap[S->getCond()] = CondCount; Visit(S->getCond()); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); + setCount(BC.BreakCount + CondCount - LoopCount); RecordNextStmtCount = true; } @@ -411,37 +407,34 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { RecordStmtCount(S); if (S->getInit()) Visit(S->getInit()); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); + + uint64_t ParentCount = PGO.getCurrentRegionCount(); + BreakContinueStack.push_back(BreakContinue()); // Visit the body region first. (This is basically the same as a while // loop; see further comments in VisitWhileStmt.) - Cnt.beginRegion(); - CountMap[S->getBody()] = PGO.getCurrentRegionCount(); + uint64_t BodyCount = setCount(PGO.getRegionCount(S)); + CountMap[S->getBody()] = BodyCount; Visit(S->getBody()); - Cnt.adjustForControlFlow(); + uint64_t BackedgeCount = PGO.getCurrentRegionCount(); + BreakContinue BC = BreakContinueStack.pop_back_val(); // The increment is essentially part of the body but it needs to include // the count for all the continue statements. if (S->getInc()) { - Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + - BreakContinueStack.back().ContinueCount); - CountMap[S->getInc()] = PGO.getCurrentRegionCount(); + uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); + CountMap[S->getInc()] = IncCount; Visit(S->getInc()); - Cnt.adjustForControlFlow(); } - BreakContinue BC = BreakContinueStack.pop_back_val(); - // ...then go back and propagate counts through the condition. + uint64_t CondCount = + setCount(ParentCount + BackedgeCount + BC.ContinueCount); if (S->getCond()) { - Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() + - BC.ContinueCount); - CountMap[S->getCond()] = PGO.getCurrentRegionCount(); + CountMap[S->getCond()] = CondCount; Visit(S->getCond()); - Cnt.adjustForControlFlow(); } - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); + setCount(BC.BreakCount + CondCount - BodyCount); RecordNextStmtCount = true; } @@ -450,47 +443,47 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { Visit(S->getLoopVarStmt()); Visit(S->getRangeStmt()); Visit(S->getBeginEndStmt()); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); + + uint64_t ParentCount = PGO.getCurrentRegionCount(); + BreakContinueStack.push_back(BreakContinue()); // Visit the body region first. (This is basically the same as a while // loop; see further comments in VisitWhileStmt.) - Cnt.beginRegion(); - CountMap[S->getBody()] = PGO.getCurrentRegionCount(); + uint64_t BodyCount = setCount(PGO.getRegionCount(S)); + CountMap[S->getBody()] = BodyCount; Visit(S->getBody()); - Cnt.adjustForControlFlow(); + uint64_t BackedgeCount = PGO.getCurrentRegionCount(); + BreakContinue BC = BreakContinueStack.pop_back_val(); // The increment is essentially part of the body but it needs to include // the count for all the continue statements. - Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + - BreakContinueStack.back().ContinueCount); - CountMap[S->getInc()] = PGO.getCurrentRegionCount(); + uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); + CountMap[S->getInc()] = IncCount; Visit(S->getInc()); - Cnt.adjustForControlFlow(); - - BreakContinue BC = BreakContinueStack.pop_back_val(); // ...then go back and propagate counts through the condition. - Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() + - BC.ContinueCount); - CountMap[S->getCond()] = PGO.getCurrentRegionCount(); + uint64_t CondCount = + setCount(ParentCount + BackedgeCount + BC.ContinueCount); + CountMap[S->getCond()] = CondCount; Visit(S->getCond()); - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); + setCount(BC.BreakCount + CondCount - BodyCount); RecordNextStmtCount = true; } void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { RecordStmtCount(S); Visit(S->getElement()); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); + uint64_t ParentCount = PGO.getCurrentRegionCount(); BreakContinueStack.push_back(BreakContinue()); - Cnt.beginRegion(); - CountMap[S->getBody()] = PGO.getCurrentRegionCount(); + // Counter tracks the body of the loop. + uint64_t BodyCount = setCount(PGO.getRegionCount(S)); + CountMap[S->getBody()] = BodyCount; Visit(S->getBody()); + uint64_t BackedgeCount = PGO.getCurrentRegionCount(); BreakContinue BC = BreakContinueStack.pop_back_val(); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); + + setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount - + BodyCount); RecordNextStmtCount = true; } @@ -505,53 +498,45 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { if (!BreakContinueStack.empty()) BreakContinueStack.back().ContinueCount += BC.ContinueCount; // Counter tracks the exit block of the switch. - RegionCounter ExitCnt(PGO, S); - ExitCnt.beginRegion(); + setCount(PGO.getRegionCount(S)); RecordNextStmtCount = true; } - void VisitCaseStmt(const CaseStmt *S) { + void VisitSwitchCase(const SwitchCase *S) { RecordNextStmtCount = false; // Counter for this particular case. This counts only jumps from the // switch header and does not include fallthrough from the case before // this one. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(/*AddIncomingFallThrough=*/true); - CountMap[S] = Cnt.getCount(); - RecordNextStmtCount = true; - Visit(S->getSubStmt()); - } - - void VisitDefaultStmt(const DefaultStmt *S) { - RecordNextStmtCount = false; - // Counter for this default case. This does not include fallthrough from - // the previous case. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(/*AddIncomingFallThrough=*/true); - CountMap[S] = Cnt.getCount(); + uint64_t CaseCount = PGO.getRegionCount(S); + setCount(PGO.getCurrentRegionCount() + CaseCount); + // We need the count without fallthrough in the mapping, so it's more useful + // for branch probabilities. + CountMap[S] = CaseCount; RecordNextStmtCount = true; Visit(S->getSubStmt()); } void VisitIfStmt(const IfStmt *S) { RecordStmtCount(S); - // Counter tracks the "then" part of an if statement. The count for - // the "else" part, if it exists, will be calculated from this counter. - RegionCounter Cnt(PGO, S); + uint64_t ParentCount = PGO.getCurrentRegionCount(); Visit(S->getCond()); - Cnt.beginRegion(); - CountMap[S->getThen()] = PGO.getCurrentRegionCount(); + // Counter tracks the "then" part of an if statement. The count for + // the "else" part, if it exists, will be calculated from this counter. + uint64_t ThenCount = setCount(PGO.getRegionCount(S)); + CountMap[S->getThen()] = ThenCount; Visit(S->getThen()); - Cnt.adjustForControlFlow(); + uint64_t OutCount = PGO.getCurrentRegionCount(); + uint64_t ElseCount = ParentCount - ThenCount; if (S->getElse()) { - Cnt.beginElseRegion(); - CountMap[S->getElse()] = PGO.getCurrentRegionCount(); + setCount(ElseCount); + CountMap[S->getElse()] = ElseCount; Visit(S->getElse()); - Cnt.adjustForControlFlow(); - } - Cnt.applyAdjustmentsToRegion(0); + OutCount += PGO.getCurrentRegionCount(); + } else + OutCount += ElseCount; + setCount(OutCount); RecordNextStmtCount = true; } @@ -561,64 +546,60 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) Visit(S->getHandler(I)); // Counter tracks the continuation block of the try statement. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(); + setCount(PGO.getRegionCount(S)); RecordNextStmtCount = true; } void VisitCXXCatchStmt(const CXXCatchStmt *S) { RecordNextStmtCount = false; // Counter tracks the catch statement's handler block. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(); - CountMap[S] = PGO.getCurrentRegionCount(); + uint64_t CatchCount = setCount(PGO.getRegionCount(S)); + CountMap[S] = CatchCount; Visit(S->getHandlerBlock()); } void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) { RecordStmtCount(E); - // Counter tracks the "true" part of a conditional operator. The - // count in the "false" part will be calculated from this counter. - RegionCounter Cnt(PGO, E); + uint64_t ParentCount = PGO.getCurrentRegionCount(); Visit(E->getCond()); - Cnt.beginRegion(); - CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount(); + // Counter tracks the "true" part of a conditional operator. The + // count in the "false" part will be calculated from this counter. + uint64_t TrueCount = setCount(PGO.getRegionCount(E)); + CountMap[E->getTrueExpr()] = TrueCount; Visit(E->getTrueExpr()); - Cnt.adjustForControlFlow(); + uint64_t OutCount = PGO.getCurrentRegionCount(); - Cnt.beginElseRegion(); - CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount(); + uint64_t FalseCount = setCount(ParentCount - TrueCount); + CountMap[E->getFalseExpr()] = FalseCount; Visit(E->getFalseExpr()); - Cnt.adjustForControlFlow(); + OutCount += PGO.getCurrentRegionCount(); - Cnt.applyAdjustmentsToRegion(0); + setCount(OutCount); RecordNextStmtCount = true; } void VisitBinLAnd(const BinaryOperator *E) { RecordStmtCount(E); - // Counter tracks the right hand side of a logical and operator. - RegionCounter Cnt(PGO, E); + uint64_t ParentCount = PGO.getCurrentRegionCount(); Visit(E->getLHS()); - Cnt.beginRegion(); - CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); + // Counter tracks the right hand side of a logical and operator. + uint64_t RHSCount = setCount(PGO.getRegionCount(E)); + CountMap[E->getRHS()] = RHSCount; Visit(E->getRHS()); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(0); + setCount(ParentCount + RHSCount - PGO.getCurrentRegionCount()); RecordNextStmtCount = true; } void VisitBinLOr(const BinaryOperator *E) { RecordStmtCount(E); - // Counter tracks the right hand side of a logical or operator. - RegionCounter Cnt(PGO, E); + uint64_t ParentCount = PGO.getCurrentRegionCount(); Visit(E->getLHS()); - Cnt.beginRegion(); - CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); + // Counter tracks the right hand side of a logical or operator. + uint64_t RHSCount = setCount(PGO.getRegionCount(E)); + CountMap[E->getRHS()] = RHSCount; Visit(E->getRHS()); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(0); + setCount(ParentCount + RHSCount - PGO.getCurrentRegionCount()); RecordNextStmtCount = true; } }; |