diff options
Diffstat (limited to 'polly/lib/Transform/ScheduleOptimizer.cpp')
| -rw-r--r-- | polly/lib/Transform/ScheduleOptimizer.cpp | 40 |
1 files changed, 39 insertions, 1 deletions
diff --git a/polly/lib/Transform/ScheduleOptimizer.cpp b/polly/lib/Transform/ScheduleOptimizer.cpp index ba1373e7be2..f0a54e0ed2c 100644 --- a/polly/lib/Transform/ScheduleOptimizer.cpp +++ b/polly/lib/Transform/ScheduleOptimizer.cpp @@ -1268,6 +1268,38 @@ static isl_schedule_node *markInterIterationAliasFree(isl_schedule_node *Node, return isl_schedule_node_child(isl_schedule_node_insert_mark(Node, Id), 0); } +/// Restore the initial ordering of dimensions of the band node +/// +/// In case the band node represents all the dimensions of the iteration +/// domain, recreate the band node to restore the initial ordering of the +/// dimensions. +/// +/// @param Node The band node to be modified. +/// @return The modified schedule node. +namespace { +isl::schedule_node getBandNodeWithOriginDimOrder(isl::schedule_node Node) { + assert(isl_schedule_node_get_type(Node.keep()) == isl_schedule_node_band); + if (isl_schedule_node_get_type(Node.child(0).keep()) != + isl_schedule_node_leaf) + return Node; + auto Domain = isl::manage(isl_schedule_node_get_universe_domain(Node.keep())); + assert(isl_union_set_n_set(Domain.keep()) == 1); + if (isl_schedule_node_get_schedule_depth(Node.keep()) != 0 || + (isl::set(isl::manage(Domain.copy())).dim(isl::dim::set) != + isl_schedule_node_band_n_member(Node.keep()))) + return Node; + Node = isl::manage(isl_schedule_node_delete(Node.take())); + auto PartialSchedulePwAff = + isl::manage(isl_union_set_identity_union_pw_multi_aff(Domain.take())); + auto PartialScheduleMultiPwAff = + isl::multi_union_pw_aff(PartialSchedulePwAff); + PartialScheduleMultiPwAff = isl::manage(isl_multi_union_pw_aff_reset_tuple_id( + PartialScheduleMultiPwAff.take(), isl_dim_set)); + return isl::manage(isl_schedule_node_insert_partial_schedule( + Node.take(), PartialScheduleMultiPwAff.take())); +} +} // namespace + __isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeMatMulPattern( __isl_take isl_schedule_node *Node, const llvm::TargetTransformInfo *TTI, MatMulInfoTy &MMI) { @@ -1277,6 +1309,7 @@ __isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeMatMulPattern( assert(DimOutNum > 2 && "In case of the matrix multiplication the loop nest " "and, consequently, the corresponding scheduling " "functions have at least three dimensions."); + Node = getBandNodeWithOriginDimOrder(isl::manage(Node)).take(); Node = permuteBandNodeDimensions(Node, MMI.i, DimOutNum - 3); int NewJ = MMI.j == DimOutNum - 3 ? MMI.i : MMI.j; int NewK = MMI.k == DimOutNum - 3 ? MMI.i : MMI.k; @@ -1304,7 +1337,12 @@ bool ScheduleTreeOptimizer::isMatrMultPattern( MatMulInfoTy &MMI) { auto *PartialSchedule = isl_schedule_node_band_get_partial_schedule_union_map(Node); - if (isl_schedule_node_band_n_member(Node) < 3 || + Node = isl_schedule_node_child(Node, 0); + auto LeafType = isl_schedule_node_get_type(Node); + Node = isl_schedule_node_parent(Node); + if (LeafType != isl_schedule_node_leaf || + isl_schedule_node_band_n_member(Node) < 3 || + isl_schedule_node_get_schedule_depth(Node) != 0 || isl_union_map_n_map(PartialSchedule) != 1) { isl_union_map_free(PartialSchedule); return false; |

