summary refs log tree commit diff stats
path: root/polly/lib/Transform/ScheduleOptimizer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'polly/lib/Transform/ScheduleOptimizer.cpp')
-rw-r--r-- polly/lib/Transform/ScheduleOptimizer.cpp 40
1 files changed, 39 insertions, 1 deletions
diff --git a/polly/lib/Transform/ScheduleOptimizer.cpp b/polly/lib/Transform/ScheduleOptimizer.cpp
index ba1373e7be2..f0a54e0ed2c 100644
--- a/polly/lib/Transform/ScheduleOptimizer.cpp
+++ b/polly/lib/Transform/ScheduleOptimizer.cpp
@@ -1268,6 +1268,38 @@ static isl_schedule_node *markInterIterationAliasFree(isl_schedule_node *Node,
return isl_schedule_node_child(isl_schedule_node_insert_mark(Node, Id), 0);
}
+/// Restore the initial ordering of the dimensions of a band node.
+///
+/// If child 0 of the band is a leaf, the band is at schedule depth 0, and its
+/// member count equals the number of set dimensions of its universe domain,
+/// delete the band and insert a new partial schedule that is the identity on
+/// that domain (tuple id reset). Otherwise the node is returned unchanged.
+/// @param Node The band node to be modified (asserted to be a band node).
+/// @return The node produced by the re-insertion, or \p Node unchanged.
+namespace {
+isl::schedule_node getBandNodeWithOriginDimOrder(isl::schedule_node Node) {
+ assert(isl_schedule_node_get_type(Node.keep()) == isl_schedule_node_band); // Precondition: Node is a band.
+ if (isl_schedule_node_get_type(Node.child(0).keep()) !=
+ isl_schedule_node_leaf)
+ return Node; // Child 0 is not a leaf: leave the band untouched.
+ auto Domain = isl::manage(isl_schedule_node_get_universe_domain(Node.keep()));
+ assert(isl_union_set_n_set(Domain.keep()) == 1); // The union set is converted to a single isl::set below.
+ if (isl_schedule_node_get_schedule_depth(Node.keep()) != 0 ||
+ (isl::set(isl::manage(Domain.copy())).dim(isl::dim::set) !=
+ isl_schedule_node_band_n_member(Node.keep())))
+ return Node; // Not at depth 0, or member count != number of set dimensions.
+ Node = isl::manage(isl_schedule_node_delete(Node.take())); // Remove the existing band.
+ auto PartialSchedulePwAff =
+ isl::manage(isl_union_set_identity_union_pw_multi_aff(Domain.take())); // Identity function on the domain.
+ auto PartialScheduleMultiPwAff =
+ isl::multi_union_pw_aff(PartialSchedulePwAff);
+ PartialScheduleMultiPwAff = isl::manage(isl_multi_union_pw_aff_reset_tuple_id(
+ PartialScheduleMultiPwAff.take(), isl_dim_set)); // Drop the set tuple id before using it as a schedule.
+ return isl::manage(isl_schedule_node_insert_partial_schedule(
+ Node.take(), PartialScheduleMultiPwAff.take())); // Insert the identity schedule at the current position.
+}
+} // namespace
+
__isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeMatMulPattern(
__isl_take isl_schedule_node *Node, const llvm::TargetTransformInfo *TTI,
MatMulInfoTy &MMI) {
@@ -1277,6 +1309,7 @@ __isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeMatMulPattern(
assert(DimOutNum > 2 && "In case of the matrix multiplication the loop nest "
"and, consequently, the corresponding scheduling "
"functions have at least three dimensions.");
+ Node = getBandNodeWithOriginDimOrder(isl::manage(Node)).take();
Node = permuteBandNodeDimensions(Node, MMI.i, DimOutNum - 3);
int NewJ = MMI.j == DimOutNum - 3 ? MMI.i : MMI.j;
int NewK = MMI.k == DimOutNum - 3 ? MMI.i : MMI.k;
@@ -1304,7 +1337,12 @@ bool ScheduleTreeOptimizer::isMatrMultPattern(
MatMulInfoTy &MMI) {
auto *PartialSchedule =
isl_schedule_node_band_get_partial_schedule_union_map(Node);
- if (isl_schedule_node_band_n_member(Node) < 3 ||
+ Node = isl_schedule_node_child(Node, 0);
+ auto LeafType = isl_schedule_node_get_type(Node);
+ Node = isl_schedule_node_parent(Node);
+ if (LeafType != isl_schedule_node_leaf ||
+ isl_schedule_node_band_n_member(Node) < 3 ||
+ isl_schedule_node_get_schedule_depth(Node) != 0 ||
isl_union_map_n_map(PartialSchedule) != 1) {
isl_union_map_free(PartialSchedule);
return false;
OpenPOWER on IntegriCloud