diff options
| author | Johannes Doerfert <jdoerfert@codeaurora.org> | 2014-08-01 08:17:19 +0000 |
|---|---|---|
| committer | Johannes Doerfert <jdoerfert@codeaurora.org> | 2014-08-01 08:17:19 +0000 |
| commit | dc6ad99aada4a2f32a5303ab1d169bf45afea5c2 (patch) | |
| tree | 69e43ebe75da8a2b60dbf3c64a1f4161d82025b5 | |
| parent | ed67f8baf6847712aab6437050341e00975e1221 (diff) | |
| download | bcm5719-llvm-dc6ad99aada4a2f32a5303ab1d169bf45afea5c2.tar.gz bcm5719-llvm-dc6ad99aada4a2f32a5303ab1d169bf45afea5c2.zip | |
Annotate the IslAst with broken reductions
+ Split all reduction dependences and map them to the causing memory accesses.
+ Print the types and base addresses of the broken reductions for each loop
marked "reduction parallel" (OpenMP style).
+ Add 3 test cases to show how reductions are now represented in the isl AST.
The mapping "(AST) loops -> broken reductions" is also needed to find the
memory accesses that must be privatized in a loop.
llvm-svn: 214489
| -rw-r--r-- | polly/include/polly/CodeGen/IslAst.h | 12 | ||||
| -rwxr-xr-x | polly/include/polly/Dependences.h | 18 | ||||
| -rw-r--r-- | polly/lib/Analysis/Dependences.cpp | 50 | ||||
| -rw-r--r-- | polly/lib/CodeGen/IslAst.cpp | 79 |
4 files changed, 140 insertions, 19 deletions
diff --git a/polly/include/polly/CodeGen/IslAst.h b/polly/include/polly/CodeGen/IslAst.h index c30d0c2d03b..d7c2379197a 100644 --- a/polly/include/polly/CodeGen/IslAst.h +++ b/polly/include/polly/CodeGen/IslAst.h @@ -34,15 +34,19 @@ class raw_ostream; struct isl_ast_node; struct isl_ast_expr; struct isl_ast_build; +struct isl_union_map; struct isl_pw_multi_aff; namespace polly { class Scop; class IslAst; +class MemoryAccess; class IslAstInfo : public ScopPass { public: - /// @brief Payload information used to annoate an ast node. + using MemoryAccessSet = SmallPtrSet<MemoryAccess *, 4>; + + /// @brief Payload information used to annotate an AST node. struct IslAstUserPayload { /// @brief Construct and initialize the payload. IslAstUserPayload() @@ -67,6 +71,9 @@ public: /// @brief The build environment at the time this node was constructed. isl_ast_build *Build; + + /// @brief Set of accesses which break reduction dependences. + MemoryAccessSet BrokenReductions; }; private: @@ -119,6 +126,9 @@ public: /// @brief Get the nodes schedule or a nullptr if not available. static __isl_give isl_union_map *getSchedule(__isl_keep isl_ast_node *Node); + /// @brief Get the nodes broken reductions or a nullptr if not available. + static MemoryAccessSet *getBrokenReductions(__isl_keep isl_ast_node *Node); + ///} virtual void getAnalysisUsage(AnalysisUsage &AU) const; diff --git a/polly/include/polly/Dependences.h b/polly/include/polly/Dependences.h index 56f864dc708..3eb836247f5 100755 --- a/polly/include/polly/Dependences.h +++ b/polly/include/polly/Dependences.h @@ -40,6 +40,7 @@ namespace polly { class Scop; class ScopStmt; +class MemoryAccess; class Dependences : public ScopPass { public: @@ -105,6 +106,16 @@ public: /// @brief Report if valid dependences are available. bool hasValidDependences(); + /// @brief Return the reduction dependences caused by @p MA. + /// + /// @return The reduction dependences caused by @p MA or nullptr if None. 
+ __isl_give isl_map *getReductionDependences(MemoryAccess *MA); + + /// @brief Return the reduction dependences mapped by the causing @p MA. + const DenseMap<MemoryAccess *, isl_map *> &getReductionDependences() const { + return ReductionDependences; + } + bool runOnScop(Scop &S); void printScop(raw_ostream &OS) const; virtual void releaseMemory(); @@ -122,6 +133,9 @@ private: /// @brief The (reverse) transitive closure of reduction dependences isl_union_map *TC_RED = nullptr; + /// @brief Map from memory accesses to their reduction dependences. + DenseMap<MemoryAccess *, isl_map *> ReductionDependences; + /// @brief Collect information about the SCoP. void collectInfo(Scop &S, isl_union_map **Read, isl_union_map **Write, isl_union_map **MayWrite, isl_union_map **AccessSchedule, @@ -132,6 +146,10 @@ private: /// @brief Calculate the dependences for a certain SCoP. void calculateDependences(Scop &S); + + /// @brief Set the reduction dependences for @p MA to @p Deps. + void setReductionDependences(MemoryAccess *MA, __isl_take isl_map *Deps); + }; } // End polly namespace. diff --git a/polly/lib/Analysis/Dependences.cpp b/polly/lib/Analysis/Dependences.cpp index 715b88b2ec9..58bd71c8e0d 100644 --- a/polly/lib/Analysis/Dependences.cpp +++ b/polly/lib/Analysis/Dependences.cpp @@ -355,6 +355,42 @@ void Dependences::calculateDependences(Scop &S) { DEBUG(dbgs() << "Final Wrapped Dependences:\n"; printScop(dbgs()); dbgs() << "\n"); + // RED_SIN is used to collect all reduction dependences again after we + // split them according to the causing memory accesses. The current assumption + // is that our method of splitting will not have any leftovers. In the end + // we validate this assumption until we have more confidence in this method. 
+ isl_union_map *RED_SIN = isl_union_map_empty(isl_union_map_get_space(RAW)); + + // For each reduction like memory access, check if there are reduction + // dependences with the access relation of the memory access as a domain + // (wrapped space!). If so these dependences are caused by this memory access. + // We then move this portion of reduction dependences back to the statement -> + // statement space and add a mapping from the memory access to these + // dependences. + for (ScopStmt *Stmt : S) { + for (MemoryAccess *MA : *Stmt) { + if (!MA->isReductionLike()) + continue; + + isl_set *AccDomW = isl_map_wrap(MA->getAccessRelation()); + isl_union_map *AccRedDepU = isl_union_map_intersect_domain( + isl_union_map_copy(TC_RED), isl_union_set_from_set(AccDomW)); + if (isl_union_map_is_empty(AccRedDepU) && !isl_union_map_free(AccRedDepU)) + continue; + + isl_map *AccRedDep = isl_map_from_union_map(AccRedDepU); + RED_SIN = isl_union_map_add_map(RED_SIN, isl_map_copy(AccRedDep)); + AccRedDep = isl_map_zip(AccRedDep); + AccRedDep = isl_set_unwrap(isl_map_domain(AccRedDep)); + setReductionDependences(MA, AccRedDep); + } + } + + assert(isl_union_map_is_equal(RED_SIN, TC_RED) && + "Intersecting the reduction dependence domain with the wrapped access " + "relation is not enough, we need to loosen the access relation also"); + isl_union_map_free(RED_SIN); + RAW = isl_union_map_zip(RAW); WAW = isl_union_map_zip(WAW); WAR = isl_union_map_zip(WAR); @@ -506,6 +542,10 @@ void Dependences::releaseMemory() { isl_union_map_free(TC_RED); RED = RAW = WAR = WAW = TC_RED = nullptr; + + for (auto &ReductionDeps : ReductionDependences) + isl_map_free(ReductionDeps.second); + ReductionDependences.clear(); } isl_union_map *Dependences::getDependences(int Kinds) { @@ -537,6 +577,16 @@ bool Dependences::hasValidDependences() { return (RAW != nullptr) && (WAR != nullptr) && (WAW != nullptr); } +isl_map *Dependences::getReductionDependences(MemoryAccess *MA) { + return 
isl_map_copy(ReductionDependences[MA]); +} + +void Dependences::setReductionDependences(MemoryAccess *MA, isl_map *D) { + assert(ReductionDependences.count(MA) == 0 && + "Reduction dependences set twice!"); + ReductionDependences[MA] = D; +} + void Dependences::getAnalysisUsage(AnalysisUsage &AU) const { ScopPass::getAnalysisUsage(AU); } diff --git a/polly/lib/CodeGen/IslAst.cpp b/polly/lib/CodeGen/IslAst.cpp index b9da4479463..203c4dc709c 100644 --- a/polly/lib/CodeGen/IslAst.cpp +++ b/polly/lib/CodeGen/IslAst.cpp @@ -108,23 +108,48 @@ static isl_printer *printLine(__isl_take isl_printer *Printer, return isl_printer_end_line(Printer); } +/// @brief Return all broken reductions as a string of clauses (OpenMP style). +static const std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) { + IslAstInfo::MemoryAccessSet *BrokenReductions; + std::string str; + + BrokenReductions = IslAstInfo::getBrokenReductions(Node); + if (!BrokenReductions || BrokenReductions->empty()) + return ""; + + // Map each type of reduction to a comma separated list of the base addresses. + std::map<MemoryAccess::ReductionType, std::string> Clauses; + for (MemoryAccess *MA : *BrokenReductions) + if (MA->isWrite()) + Clauses[MA->getReductionType()] += + ", " + MA->getBaseAddr()->getName().str(); + + // Now print the reductions sorted by type. Each type will cause a clause + // like: reduction (+ : sum0, sum1, sum2) + for (const auto &ReductionClause : Clauses) { + str += " reduction ("; + str += MemoryAccess::getReductionOperatorStr(ReductionClause.first); + // Remove the first two symbols (", ") to make the output look pretty. + str += " : " + ReductionClause.second.substr(2) + ")"; + } + + return str; +} + /// @brief Callback executed for each for node in the ast in order to print it. 
static isl_printer *cbPrintFor(__isl_take isl_printer *Printer, __isl_take isl_ast_print_options *Options, __isl_keep isl_ast_node *Node, void *) { - if (IslAstInfo::isInnermostParallel(Node) && - !IslAstInfo::isReductionParallel(Node)) - Printer = printLine(Printer, "#pragma simd"); - if (IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node)) - Printer = printLine(Printer, "#pragma simd reduction"); + const std::string BrokenReductionsStr = getBrokenReductionsStr(Node); + const std::string SimdPragmaStr = "#pragma simd"; + const std::string OmpPragmaStr = "#pragma omp parallel for"; - if (IslAstInfo::isOutermostParallel(Node) && - !IslAstInfo::isReductionParallel(Node)) - Printer = printLine(Printer, "#pragma omp parallel for"); + if (IslAstInfo::isInnermostParallel(Node)) + Printer = printLine(Printer, SimdPragmaStr + BrokenReductionsStr); - if (!IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node)) - Printer = printLine(Printer, "#pragma omp parallel for reduction"); + if (IslAstInfo::isOutermostParallel(Node)) + Printer = printLine(Printer, OmpPragmaStr + BrokenReductionsStr); return isl_ast_node_for_print(Node, Printer, Options); } @@ -141,7 +166,7 @@ static isl_printer *cbPrintFor(__isl_take isl_printer *Printer, /// (or non-zero) dependence distance on the dimension in question. 
static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build, Dependences *D, - bool &IsReductionParallel) { + IslAstUserPayload *NodeInfo) { if (!D->hasValidDependences()) return false; @@ -153,7 +178,20 @@ static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build, isl_union_map *RedDeps = D->getDependences(Dependences::TYPE_TC_RED); if (!D->isParallel(Schedule, RedDeps)) - IsReductionParallel = true; + NodeInfo->IsReductionParallel = true; + + if (!NodeInfo->IsReductionParallel && !isl_union_map_free(Schedule)) + return true; + + // Annotate reduction parallel nodes with the memory accesses which caused the + // reduction dependences parallel execution of the node conflicts with. + for (const auto &MaRedPair : D->getReductionDependences()) { + if (!MaRedPair.second) + continue; + RedDeps = isl_union_map_from_map(isl_map_copy(MaRedPair.second)); + if (!D->isParallel(Schedule, RedDeps)) + NodeInfo->BrokenReductions.insert(MaRedPair.first); + } isl_union_map_free(Schedule); return true; @@ -177,8 +215,7 @@ static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build, // Test for parallelism only if we are not already inside a parallel loop if (!BuildInfo->InParallelFor) BuildInfo->InParallelFor = Payload->IsOutermostParallel = - astScheduleDimIsParallel(Build, BuildInfo->Deps, - Payload->IsReductionParallel); + astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload); return Id; } @@ -206,13 +243,13 @@ astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build, // Innermost loops that are surrounded by parallel loops have not yet been // tested for parallelism. Test them here to ensure we check all innermost // loops for parallelism. 
- if (Payload->IsInnermost && BuildInfo->InParallelFor) + if (Payload->IsInnermost && BuildInfo->InParallelFor) { if (Payload->IsOutermostParallel) Payload->IsInnermostParallel = true; else - Payload->IsInnermostParallel = astScheduleDimIsParallel( - Build, BuildInfo->Deps, Payload->IsReductionParallel); - else if (Payload->IsOutermostParallel) + Payload->IsInnermostParallel = + astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload); + } else if (Payload->IsOutermostParallel) BuildInfo->InParallelFor = false; isl_id_free(Id); @@ -370,6 +407,12 @@ isl_union_map *IslAstInfo::getSchedule(__isl_keep isl_ast_node *Node) { return Payload ? isl_ast_build_get_schedule(Payload->Build) : nullptr; } +IslAstInfo::MemoryAccessSet * +IslAstInfo::getBrokenReductions(__isl_keep isl_ast_node *Node) { + IslAstUserPayload *Payload = getNodePayload(Node); + return Payload ? &Payload->BrokenReductions : nullptr; +} + void IslAstInfo::printScop(raw_ostream &OS) const { isl_ast_print_options *Options; isl_ast_node *RootNode = getAst(); |

