diff options
Diffstat (limited to 'polly/lib/CodeGen/IslAst.cpp')
| -rw-r--r-- | polly/lib/CodeGen/IslAst.cpp | 85 |
1 files changed, 38 insertions, 47 deletions
diff --git a/polly/lib/CodeGen/IslAst.cpp b/polly/lib/CodeGen/IslAst.cpp index 091083534ac..c36dea96f91 100644 --- a/polly/lib/CodeGen/IslAst.cpp +++ b/polly/lib/CodeGen/IslAst.cpp @@ -100,43 +100,31 @@ struct AstBuildUserInfo { isl_id *LastForNodeId; }; -// Print a loop annotated with OpenMP or vector pragmas. -static __isl_give isl_printer * -printParallelFor(__isl_keep isl_ast_node *Node, __isl_take isl_printer *Printer, - __isl_take isl_ast_print_options *PrintOptions, - IslAstUserPayload *Info) { - if (Info) { - if (Info->IsInnermostParallel) { - Printer = isl_printer_start_line(Printer); - Printer = isl_printer_print_str(Printer, "#pragma simd"); - if (Info->IsReductionParallel) - Printer = isl_printer_print_str(Printer, " reduction"); - Printer = isl_printer_end_line(Printer); - } - if (Info->IsOutermostParallel) { - Printer = isl_printer_start_line(Printer); - Printer = isl_printer_print_str(Printer, "#pragma omp parallel for"); - if (Info->IsReductionParallel) - Printer = isl_printer_print_str(Printer, " reduction"); - Printer = isl_printer_end_line(Printer); - } - } - return isl_ast_node_for_print(Node, Printer, PrintOptions); +/// @brief Print a string @p str in a single line using @p Printer. +static isl_printer *printLine(__isl_take isl_printer *Printer, + const std::string &str) { + Printer = isl_printer_start_line(Printer); + Printer = isl_printer_print_str(Printer, str.c_str()); + return isl_printer_end_line(Printer); } -// Print an isl_ast_for. -static __isl_give isl_printer * -printFor(__isl_take isl_printer *Printer, - __isl_take isl_ast_print_options *PrintOptions, - __isl_keep isl_ast_node *Node, void *User) { - isl_id *Id = isl_ast_node_get_annotation(Node); - if (!Id) - return isl_ast_node_for_print(Node, Printer, PrintOptions); +/// @brief Callback executed for each for node in the ast in order to print it. +static isl_printer *cbPrintFor(__isl_take isl_printer *Printer, + __isl_take isl_ast_print_options *Options, + __isl_keep isl_ast_node *Node, void *) { + if (IslAstInfo::isInnermostParallel(Node)) + Printer = printLine(Printer, "#pragma simd"); - IslAstUserPayload *Info = (IslAstUserPayload *)isl_id_get_user(Id); - Printer = printParallelFor(Node, Printer, PrintOptions, Info); - isl_id_free(Id); - return Printer; + if (IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node)) + Printer = printLine(Printer, "#pragma simd reduction"); + + if (IslAstInfo::isOuterParallel(Node)) + Printer = printLine(Printer, "#pragma omp parallel for"); + + if (!IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node)) + Printer = printLine(Printer, "#pragma omp parallel for reduction"); + + return isl_ast_node_for_print(Node, Printer, Options); } /// @brief Check if the current scheduling dimension is parallel @@ -219,18 +207,16 @@ astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build, IslAstUserPayload *Info = (IslAstUserPayload *)isl_id_get_user(Id); AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; - bool IsInnermost = (Id == BuildInfo->LastForNodeId); - - if (Info) { - if (Info->IsOutermostParallel) - BuildInfo->InParallelFor = 0; - if (IsInnermost) - if (astScheduleDimIsParallel(Build, BuildInfo->Deps, - Info->IsReductionParallel)) - Info->IsInnermostParallel = 1; - if (!Info->Build) - Info->Build = isl_ast_build_copy(Build); - } + Info->IsInnermost = (Id == BuildInfo->LastForNodeId); + + if (Info->IsOutermostParallel) + BuildInfo->InParallelFor = 0; + if (Info->IsInnermost) + if (astScheduleDimIsParallel(Build, BuildInfo->Deps, + Info->IsReductionParallel)) + Info->IsInnermostParallel = 1; + if (!Info->Build) + Info->Build = isl_ast_build_copy(Build); isl_id_free(Id); return Node; @@ -356,6 +342,11 @@ IslAstUserPayload *IslAstInfo::getNodePayload(__isl_keep isl_ast_node *Node) { return Payload; } +bool IslAstInfo::isInnermost(__isl_keep isl_ast_node *Node) { + IslAstUserPayload *Payload = getNodePayload(Node); + return Payload && Payload->IsInnermost; +} + bool IslAstInfo::isParallel(__isl_keep isl_ast_node *Node) { return (isInnermostParallel(Node) || isOuterParallel(Node)) && !isReductionParallel(Node); @@ -391,7 +382,7 @@ void IslAstInfo::printScop(raw_ostream &OS) const { Scop &S = getCurScop(); Options = isl_ast_print_options_alloc(S.getIslCtx()); - Options = isl_ast_print_options_set_print_for(Options, printFor, nullptr); + Options = isl_ast_print_options_set_print_for(Options, cbPrintFor, nullptr); isl_printer *P = isl_printer_to_str(S.getIslCtx()); P = isl_printer_print_ast_expr(P, RunCondition); |

