summaryrefslogtreecommitdiffstats
path: root/polly/lib/CodeGen/IslAst.cpp
diff options
context:
space:
mode:
authorJohannes Doerfert <jdoerfert@codeaurora.org>2014-08-01 08:17:19 +0000
committerJohannes Doerfert <jdoerfert@codeaurora.org>2014-08-01 08:17:19 +0000
commitdc6ad99aada4a2f32a5303ab1d169bf45afea5c2 (patch)
tree69e43ebe75da8a2b60dbf3c64a1f4161d82025b5 /polly/lib/CodeGen/IslAst.cpp
parented67f8baf6847712aab6437050341e00975e1221 (diff)
downloadbcm5719-llvm-dc6ad99aada4a2f32a5303ab1d169bf45afea5c2.tar.gz
bcm5719-llvm-dc6ad99aada4a2f32a5303ab1d169bf45afea5c2.zip
Annotate the IslAst with broken reductions
+ Split all reduction dependences and map them to the causing memory accesses. + Print the types & base addresses of broken reductions for each "reduction parallel" marked loop (OpenMP style). + 3 test cases to show how reductions are now represented in the isl ast. The mapping "(ast) loops -> broken reductions" is also needed to find the memory accesses we need to privatize in a loop. llvm-svn: 214489
Diffstat (limited to 'polly/lib/CodeGen/IslAst.cpp')
-rw-r--r--polly/lib/CodeGen/IslAst.cpp79
1 files changed, 61 insertions, 18 deletions
diff --git a/polly/lib/CodeGen/IslAst.cpp b/polly/lib/CodeGen/IslAst.cpp
index b9da4479463..203c4dc709c 100644
--- a/polly/lib/CodeGen/IslAst.cpp
+++ b/polly/lib/CodeGen/IslAst.cpp
@@ -108,23 +108,48 @@ static isl_printer *printLine(__isl_take isl_printer *Printer,
return isl_printer_end_line(Printer);
}
+/// @brief Return all broken reductions as a string of clauses (OpenMP style).
+static const std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) {
+ IslAstInfo::MemoryAccessSet *BrokenReductions;
+ std::string str;
+
+ BrokenReductions = IslAstInfo::getBrokenReductions(Node);
+ if (!BrokenReductions || BrokenReductions->empty())
+ return "";
+
+ // Map each type of reduction to a comma separated list of the base addresses.
+ std::map<MemoryAccess::ReductionType, std::string> Clauses;
+ for (MemoryAccess *MA : *BrokenReductions)
+ if (MA->isWrite())
+ Clauses[MA->getReductionType()] +=
+ ", " + MA->getBaseAddr()->getName().str();
+
+ // Now print the reductions sorted by type. Each type will cause a clause
+ // like: reduction (+ : sum0, sum1, sum2)
+ for (const auto &ReductionClause : Clauses) {
+ str += " reduction (";
+ str += MemoryAccess::getReductionOperatorStr(ReductionClause.first);
+ // Remove the first two symbols (", ") to make the output look pretty.
+ str += " : " + ReductionClause.second.substr(2) + ")";
+ }
+
+ return str;
+}
+
/// @brief Callback executed for each for node in the ast in order to print it.
static isl_printer *cbPrintFor(__isl_take isl_printer *Printer,
__isl_take isl_ast_print_options *Options,
__isl_keep isl_ast_node *Node, void *) {
- if (IslAstInfo::isInnermostParallel(Node) &&
- !IslAstInfo::isReductionParallel(Node))
- Printer = printLine(Printer, "#pragma simd");
- if (IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node))
- Printer = printLine(Printer, "#pragma simd reduction");
+ const std::string BrokenReductionsStr = getBrokenReductionsStr(Node);
+ const std::string SimdPragmaStr = "#pragma simd";
+ const std::string OmpPragmaStr = "#pragma omp parallel for";
- if (IslAstInfo::isOutermostParallel(Node) &&
- !IslAstInfo::isReductionParallel(Node))
- Printer = printLine(Printer, "#pragma omp parallel for");
+ if (IslAstInfo::isInnermostParallel(Node))
+ Printer = printLine(Printer, SimdPragmaStr + BrokenReductionsStr);
- if (!IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node))
- Printer = printLine(Printer, "#pragma omp parallel for reduction");
+ if (IslAstInfo::isOutermostParallel(Node))
+ Printer = printLine(Printer, OmpPragmaStr + BrokenReductionsStr);
return isl_ast_node_for_print(Node, Printer, Options);
}
@@ -141,7 +166,7 @@ static isl_printer *cbPrintFor(__isl_take isl_printer *Printer,
/// (or non-zero) dependence distance on the dimension in question.
static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build,
Dependences *D,
- bool &IsReductionParallel) {
+ IslAstUserPayload *NodeInfo) {
if (!D->hasValidDependences())
return false;
@@ -153,7 +178,20 @@ static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build,
isl_union_map *RedDeps = D->getDependences(Dependences::TYPE_TC_RED);
if (!D->isParallel(Schedule, RedDeps))
- IsReductionParallel = true;
+ NodeInfo->IsReductionParallel = true;
+
+ if (!NodeInfo->IsReductionParallel && !isl_union_map_free(Schedule))
+ return true;
+
+ // Annotate reduction parallel nodes with the memory accesses which caused the
+ // reduction dependences parallel execution of the node conflicts with.
+ for (const auto &MaRedPair : D->getReductionDependences()) {
+ if (!MaRedPair.second)
+ continue;
+ RedDeps = isl_union_map_from_map(isl_map_copy(MaRedPair.second));
+ if (!D->isParallel(Schedule, RedDeps))
+ NodeInfo->BrokenReductions.insert(MaRedPair.first);
+ }
isl_union_map_free(Schedule);
return true;
@@ -177,8 +215,7 @@ static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build,
// Test for parallelism only if we are not already inside a parallel loop
if (!BuildInfo->InParallelFor)
BuildInfo->InParallelFor = Payload->IsOutermostParallel =
- astScheduleDimIsParallel(Build, BuildInfo->Deps,
- Payload->IsReductionParallel);
+ astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload);
return Id;
}
@@ -206,13 +243,13 @@ astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build,
// Innermost loops that are surrounded by parallel loops have not yet been
// tested for parallelism. Test them here to ensure we check all innermost
// loops for parallelism.
- if (Payload->IsInnermost && BuildInfo->InParallelFor)
+ if (Payload->IsInnermost && BuildInfo->InParallelFor) {
if (Payload->IsOutermostParallel)
Payload->IsInnermostParallel = true;
else
- Payload->IsInnermostParallel = astScheduleDimIsParallel(
- Build, BuildInfo->Deps, Payload->IsReductionParallel);
- else if (Payload->IsOutermostParallel)
+ Payload->IsInnermostParallel =
+ astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload);
+ } else if (Payload->IsOutermostParallel)
BuildInfo->InParallelFor = false;
isl_id_free(Id);
@@ -370,6 +407,12 @@ isl_union_map *IslAstInfo::getSchedule(__isl_keep isl_ast_node *Node) {
return Payload ? isl_ast_build_get_schedule(Payload->Build) : nullptr;
}
+IslAstInfo::MemoryAccessSet *
+IslAstInfo::getBrokenReductions(__isl_keep isl_ast_node *Node) {
+ IslAstUserPayload *Payload = getNodePayload(Node);
+ return Payload ? &Payload->BrokenReductions : nullptr;
+}
+
void IslAstInfo::printScop(raw_ostream &OS) const {
isl_ast_print_options *Options;
isl_ast_node *RootNode = getAst();
OpenPOWER on IntegriCloud