diff options
| author | Johannes Doerfert <jdoerfert@codeaurora.org> | 2014-08-01 08:17:19 +0000 |
|---|---|---|
| committer | Johannes Doerfert <jdoerfert@codeaurora.org> | 2014-08-01 08:17:19 +0000 |
| commit | dc6ad99aada4a2f32a5303ab1d169bf45afea5c2 (patch) | |
| tree | 69e43ebe75da8a2b60dbf3c64a1f4161d82025b5 /polly/lib/CodeGen/IslAst.cpp | |
| parent | ed67f8baf6847712aab6437050341e00975e1221 (diff) | |
| download | bcm5719-llvm-dc6ad99aada4a2f32a5303ab1d169bf45afea5c2.tar.gz bcm5719-llvm-dc6ad99aada4a2f32a5303ab1d169bf45afea5c2.zip | |
Annotate the IslAst with broken reductions
+ Split all reduction dependences and map them to the causing memory accesses.
+ Print the types & base addresses of broken reductions for each "reduction
parallel" marked loop (OpenMP style).
+ 3 test cases to show how reductions are now represented in the isl ast.
The mapping "(ast) loops -> broken reductions" is also needed to find the
memory accesses we need to privatize in a loop.
llvm-svn: 214489
Diffstat (limited to 'polly/lib/CodeGen/IslAst.cpp')
| -rw-r--r-- | polly/lib/CodeGen/IslAst.cpp | 79 |
1 files changed, 61 insertions, 18 deletions
diff --git a/polly/lib/CodeGen/IslAst.cpp b/polly/lib/CodeGen/IslAst.cpp index b9da4479463..203c4dc709c 100644 --- a/polly/lib/CodeGen/IslAst.cpp +++ b/polly/lib/CodeGen/IslAst.cpp @@ -108,23 +108,48 @@ static isl_printer *printLine(__isl_take isl_printer *Printer, return isl_printer_end_line(Printer); } +/// @brief Return all broken reductions as a string of clauses (OpenMP style). +static const std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) { + IslAstInfo::MemoryAccessSet *BrokenReductions; + std::string str; + + BrokenReductions = IslAstInfo::getBrokenReductions(Node); + if (!BrokenReductions || BrokenReductions->empty()) + return ""; + + // Map each type of reduction to a comma separated list of the base addresses. + std::map<MemoryAccess::ReductionType, std::string> Clauses; + for (MemoryAccess *MA : *BrokenReductions) + if (MA->isWrite()) + Clauses[MA->getReductionType()] += + ", " + MA->getBaseAddr()->getName().str(); + + // Now print the reductions sorted by type. Each type will cause a clause + // like: reduction (+ : sum0, sum1, sum2) + for (const auto &ReductionClause : Clauses) { + str += " reduction ("; + str += MemoryAccess::getReductionOperatorStr(ReductionClause.first); + // Remove the first two symbols (", ") to make the output look pretty. + str += " : " + ReductionClause.second.substr(2) + ")"; + } + + return str; +} + /// @brief Callback executed for each for node in the ast in order to print it. static isl_printer *cbPrintFor(__isl_take isl_printer *Printer, __isl_take isl_ast_print_options *Options, __isl_keep isl_ast_node *Node, void *) { - if (IslAstInfo::isInnermostParallel(Node) && - !IslAstInfo::isReductionParallel(Node)) - Printer = printLine(Printer, "#pragma simd"); - if (IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node)) - Printer = printLine(Printer, "#pragma simd reduction"); + const std::string BrokenReductionsStr = getBrokenReductionsStr(Node); + const std::string SimdPragmaStr = "#pragma simd"; + const std::string OmpPragmaStr = "#pragma omp parallel for"; - if (IslAstInfo::isOutermostParallel(Node) && - !IslAstInfo::isReductionParallel(Node)) - Printer = printLine(Printer, "#pragma omp parallel for"); + if (IslAstInfo::isInnermostParallel(Node)) + Printer = printLine(Printer, SimdPragmaStr + BrokenReductionsStr); - if (!IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node)) - Printer = printLine(Printer, "#pragma omp parallel for reduction"); + if (IslAstInfo::isOutermostParallel(Node)) + Printer = printLine(Printer, OmpPragmaStr + BrokenReductionsStr); return isl_ast_node_for_print(Node, Printer, Options); } @@ -141,7 +166,7 @@ static isl_printer *cbPrintFor(__isl_take isl_printer *Printer, /// (or non-zero) dependence distance on the dimension in question. static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build, Dependences *D, - bool &IsReductionParallel) { + IslAstUserPayload *NodeInfo) { if (!D->hasValidDependences()) return false; @@ -153,7 +178,20 @@ static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build, isl_union_map *RedDeps = D->getDependences(Dependences::TYPE_TC_RED); if (!D->isParallel(Schedule, RedDeps)) - IsReductionParallel = true; + NodeInfo->IsReductionParallel = true; + + if (!NodeInfo->IsReductionParallel && !isl_union_map_free(Schedule)) + return true; + + // Annotate reduction parallel nodes with the memory accesses which caused the + // reduction dependences parallel execution of the node conflicts with. + for (const auto &MaRedPair : D->getReductionDependences()) { + if (!MaRedPair.second) + continue; + RedDeps = isl_union_map_from_map(isl_map_copy(MaRedPair.second)); + if (!D->isParallel(Schedule, RedDeps)) + NodeInfo->BrokenReductions.insert(MaRedPair.first); + } isl_union_map_free(Schedule); return true; @@ -177,8 +215,7 @@ static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build, // Test for parallelism only if we are not already inside a parallel loop if (!BuildInfo->InParallelFor) BuildInfo->InParallelFor = Payload->IsOutermostParallel = - astScheduleDimIsParallel(Build, BuildInfo->Deps, - Payload->IsReductionParallel); + astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload); return Id; } @@ -206,13 +243,13 @@ astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build, // Innermost loops that are surrounded by parallel loops have not yet been // tested for parallelism. Test them here to ensure we check all innermost // loops for parallelism. - if (Payload->IsInnermost && BuildInfo->InParallelFor) + if (Payload->IsInnermost && BuildInfo->InParallelFor) { if (Payload->IsOutermostParallel) Payload->IsInnermostParallel = true; else - Payload->IsInnermostParallel = astScheduleDimIsParallel( - Build, BuildInfo->Deps, Payload->IsReductionParallel); - else if (Payload->IsOutermostParallel) + Payload->IsInnermostParallel = + astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload); + } else if (Payload->IsOutermostParallel) BuildInfo->InParallelFor = false; isl_id_free(Id); @@ -370,6 +407,12 @@ isl_union_map *IslAstInfo::getSchedule(__isl_keep isl_ast_node *Node) { return Payload ? isl_ast_build_get_schedule(Payload->Build) : nullptr; } +IslAstInfo::MemoryAccessSet * +IslAstInfo::getBrokenReductions(__isl_keep isl_ast_node *Node) { + IslAstUserPayload *Payload = getNodePayload(Node); + return Payload ? &Payload->BrokenReductions : nullptr; +} + void IslAstInfo::printScop(raw_ostream &OS) const { isl_ast_print_options *Options; isl_ast_node *RootNode = getAst(); |

