diff options
Diffstat (limited to 'polly/lib/CodeGen/IslCodeGeneration.cpp')
-rw-r--r-- | polly/lib/CodeGen/IslCodeGeneration.cpp | 336 |
1 files changed, 329 insertions, 7 deletions
diff --git a/polly/lib/CodeGen/IslCodeGeneration.cpp b/polly/lib/CodeGen/IslCodeGeneration.cpp index bb1991511b3..5308da3c921 100644 --- a/polly/lib/CodeGen/IslCodeGeneration.cpp +++ b/polly/lib/CodeGen/IslCodeGeneration.cpp @@ -30,7 +30,11 @@ #include "polly/ScopInfo.h" #include "polly/Support/GICHelper.h" #include "polly/Support/ScopHelper.h" +#include "polly/Support/SCEVValidator.h" #include "polly/TempScopInfo.h" + +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" @@ -56,11 +60,12 @@ using namespace llvm; class IslNodeBuilder { public: IslNodeBuilder(PollyIRBuilder &Builder, ScopAnnotator &Annotator, Pass *P, - LoopInfo &LI, ScalarEvolution &SE, DominatorTree &DT) - : Builder(Builder), Annotator(Annotator), + const DataLayout &DL, LoopInfo &LI, ScalarEvolution &SE, + DominatorTree &DT, Scop &S) + : S(S), Builder(Builder), Annotator(Annotator), Rewriter(new SCEVExpander(SE, "polly")), - ExprBuilder(Builder, IDToValue, *Rewriter), P(P), LI(LI), SE(SE), - DT(DT) {} + ExprBuilder(Builder, IDToValue, *Rewriter), P(P), DL(DL), LI(LI), + SE(SE), DT(DT) {} ~IslNodeBuilder() { delete Rewriter; } @@ -69,6 +74,7 @@ public: IslExprBuilder &getExprBuilder() { return ExprBuilder; } private: + Scop &S; PollyIRBuilder &Builder; ScopAnnotator &Annotator; @@ -77,10 +83,17 @@ private: IslExprBuilder ExprBuilder; Pass *P; + const DataLayout &DL; LoopInfo &LI; ScalarEvolution &SE; DominatorTree &DT; + /// @brief The current iteration of out-of-scop loops + /// + /// This map provides for a given loop a llvm::Value that contains the current + /// loop iteration. + LoopToScevMapT OutsideLoopIterations; + // This maps an isl_id* to the Value* it has in the generated program. For now // on, the only isl_ids that are stored here are the newly calculated loop // ivs. 
@@ -95,6 +108,12 @@ private: /// @param Expr The expression to code generate. Value *generateSCEV(const SCEV *Expr); + /// A set of Value -> Value remappings to apply when generating new code. + /// + /// When generating new code for a ScopStmt this map is used to map certain + /// llvm::Values to new llvm::Values. + ValueMapT ValueMap; + // Extract the upper bound of this loop // // The isl code generation can generate arbitrary expressions to check if the @@ -119,10 +138,49 @@ private: unsigned getNumberOfIterations(__isl_keep isl_ast_node *For); + /// Compute the values and loops referenced in this subtree. + /// + /// This function looks at all ScopStmts scheduled below the provided For node + /// and finds the llvm::Value[s] and llvm::Loops[s] which are referenced but + /// not locally defined. + /// + /// Values that can be synthesized or that are available as globals are + /// considered locally defined. + /// + /// Loops that contain the scop or that are part of the scop are considered + /// locally defined. Loops that are before the scop, but do not contain the + /// scop itself are considered not locally defined. + /// + /// @param For The node defining the subtree. + /// @param Values A vector that will be filled with the Values referenced in + /// this subtree. + /// @param Loops A vector that will be filled with the Loops referenced in + /// this subtree. + void getReferencesInSubtree(__isl_keep isl_ast_node *For, + SetVector<Value *> &Values, + SetVector<const Loop *> &Loops); + + /// Change the llvm::Value(s) used for code generation. + /// + /// When generating code certain values (e.g., references to induction + /// variables or array base pointers) in the original code may be replaced by + /// new values. This function allows to (partially) update the set of values + /// used. 
A typical use case for this function is the case when we continue + /// code generation in a subfunction/kernel function and need to explicitly + /// pass down certain values. + /// + /// @param NewValues A map that maps certain llvm::Values to new llvm::Values. + void updateValues(ParallelLoopGenerator::ValueToValueMapTy &NewValues); + void createFor(__isl_take isl_ast_node *For); void createForVector(__isl_take isl_ast_node *For, int VectorWidth); void createForSequential(__isl_take isl_ast_node *For); + /// Create LLVM-IR that executes a for node thread parallel. + /// + /// @param For The FOR isl_ast_node for which code is generated. + void createForParallel(__isl_take isl_ast_node *For); + /// Generate LLVM-IR that computes the values of the original induction /// variables in function of the newly generated loop induction variables. /// @@ -238,6 +296,98 @@ unsigned IslNodeBuilder::getNumberOfIterations(__isl_keep isl_ast_node *For) { return NumberOfIterations + 1; } +struct FindValuesUser { + LoopInfo &LI; + ScalarEvolution &SE; + Region &R; + SetVector<Value *> &Values; + SetVector<const SCEV *> &SCEVs; +}; + +/// Extract the values and SCEVs needed to generate code for a ScopStmt. +/// +/// This function extracts a ScopStmt from a given isl_set and computes the +/// Values this statement depends on as well as a set of SCEV expressions that +/// need to be synthesized when generating code for this statement. +static int findValuesInStmt(isl_set *Set, void *UserPtr) { + isl_id *Id = isl_set_get_tuple_id(Set); + struct FindValuesUser &User = *static_cast<struct FindValuesUser *>(UserPtr); + const ScopStmt *Stmt = static_cast<const ScopStmt *>(isl_id_get_user(Id)); + const BasicBlock *BB = Stmt->getBasicBlock(); + + // Check all the operands of instructions in the basic block.
+ for (const Instruction &Inst : *BB) { + for (Value *SrcVal : Inst.operands()) { + if (Instruction *OpInst = dyn_cast<Instruction>(SrcVal)) + if (canSynthesize(OpInst, &User.LI, &User.SE, &User.R)) { + User.SCEVs.insert( + User.SE.getSCEVAtScope(OpInst, User.LI.getLoopFor(BB))); + continue; + } + if (Instruction *OpInst = dyn_cast<Instruction>(SrcVal)) + if (Stmt->getParent()->getRegion().contains(OpInst)) + continue; + + if (isa<Instruction>(SrcVal) || isa<Argument>(SrcVal)) + User.Values.insert(SrcVal); + } + } + isl_id_free(Id); + isl_set_free(Set); + return 0; +} + +void IslNodeBuilder::getReferencesInSubtree(__isl_keep isl_ast_node *For, + SetVector<Value *> &Values, + SetVector<const Loop *> &Loops) { + + SetVector<const SCEV *> SCEVs; + struct FindValuesUser FindValues = {LI, SE, S.getRegion(), Values, SCEVs}; + + for (const auto &I : IDToValue) + Values.insert(I.second); + + for (const auto &I : OutsideLoopIterations) + Values.insert(cast<SCEVUnknown>(I.second)->getValue()); + + isl_union_set *Schedule = isl_union_map_domain(IslAstInfo::getSchedule(For)); + + isl_union_set_foreach_set(Schedule, findValuesInStmt, &FindValues); + isl_union_set_free(Schedule); + + for (const SCEV *Expr : SCEVs) { + findValues(Expr, Values); + findLoops(Expr, Loops); + } + + Values.remove_if([](const Value *V) { return isa<GlobalValue>(V); }); + + /// Remove loops that contain the scop or that are part of the scop, as they + /// are considered local. This leaves only loops that are before the scop, but + /// do not contain the scop itself. 
+ Loops.remove_if([this](const Loop *L) { + return this->S.getRegion().contains(L) || + L->contains(S.getRegion().getEntry()); + }); +} + +void IslNodeBuilder::updateValues( + ParallelLoopGenerator::ValueToValueMapTy &NewValues) { + SmallPtrSet<Value *, 5> Inserted; + + for (const auto &I : IDToValue) { + IDToValue[I.first] = NewValues[I.second]; + Inserted.insert(I.second); + } + + for (const auto &I : NewValues) { + if (Inserted.count(I.first)) + continue; + + ValueMap[I.first] = I.second; + } +} + void IslNodeBuilder::createUserVector(__isl_take isl_ast_node *User, std::vector<Value *> &IVS, __isl_take isl_id *IteratorID, @@ -315,7 +465,7 @@ void IslNodeBuilder::createForVector(__isl_take isl_ast_node *For, llvm_unreachable("Unhandled isl_ast_node in vectorizer"); } - IDToValue.erase(IteratorID); + IDToValue.erase(IDToValue.find(IteratorID)); isl_id_free(IteratorID); isl_union_map_free(Schedule); @@ -379,7 +529,7 @@ void IslNodeBuilder::createForSequential(__isl_take isl_ast_node *For) { Annotator.popLoop(Parallel); - IDToValue.erase(IteratorID); + IDToValue.erase(IDToValue.find(IteratorID)); Builder.SetInsertPoint(ExitBlock->begin()); @@ -388,6 +538,139 @@ void IslNodeBuilder::createForSequential(__isl_take isl_ast_node *For) { isl_id_free(IteratorID); } +/// @brief Remove the BBs contained in a (sub)function from the dominator tree. +/// +/// This function removes the basic blocks that are part of a subfunction from +/// the dominator tree. Specifically, when generating code it may happen that at +/// some point the code generation continues in a new sub-function (e.g., when +/// generating OpenMP code). The basic blocks that are created in this +/// sub-function are then still part of the dominator tree of the original +/// function, such that the dominator tree reaches over function boundaries. +/// This is not only incorrect, but also causes crashes. 
This function now +/// removes from the dominator tree all basic blocks that are dominated (and +/// consequently reachable) from the entry block of this (sub)function. +/// +/// FIXME: A LLVM (function or region) pass should not touch anything outside of +/// the function/region it runs on. Hence, the pure need for this function shows +/// that we do not comply to this rule. At the moment, this does not cause any +/// issues, but we should be aware that such issues may appear. Unfortunately +/// the current LLVM pass infrastructure does not allow to make Polly a module +/// or call-graph pass to solve this issue, as such a pass would not have access +/// to the per-function analyses passes needed by Polly. A future pass manager +/// infrastructure is supposed to enable such kind of access possibly allowing +/// us to create a cleaner solution here. +/// +/// FIXME: Instead of adding the dominance information and then dropping it +/// later on, we should try to just not add it in the first place. This requires +/// some careful testing to make sure this does not break in interaction with +/// the SCEVBuilder and SplitBlock which may rely on the dominator tree or +/// which may try to update it. +/// +/// @param F The function which contains the BBs to be removed. +/// @param DT The dominator tree from which to remove the BBs. +static void removeSubFuncFromDomTree(Function *F, DominatorTree &DT) { + DomTreeNode *N = DT.getNode(&F->getEntryBlock()); + std::vector<BasicBlock *> Nodes; + + // We can only remove an element from the dominator tree, if all its children + // have been removed. To ensure this we obtain the list of nodes to remove + // using a post-order tree traversal.
+ for (po_iterator<DomTreeNode *> I = po_begin(N), E = po_end(N); I != E; ++I) + Nodes.push_back(I->getBlock()); + + for (BasicBlock *BB : Nodes) + DT.eraseNode(BB); +} + +void IslNodeBuilder::createForParallel(__isl_take isl_ast_node *For) { + isl_ast_node *Body; + isl_ast_expr *Init, *Inc, *Iterator, *UB; + isl_id *IteratorID; + Value *ValueLB, *ValueUB, *ValueInc; + Type *MaxType; + Value *IV; + CmpInst::Predicate Predicate; + + Body = isl_ast_node_for_get_body(For); + Init = isl_ast_node_for_get_init(For); + Inc = isl_ast_node_for_get_inc(For); + Iterator = isl_ast_node_for_get_iterator(For); + IteratorID = isl_ast_expr_get_id(Iterator); + UB = getUpperBound(For, Predicate); + + ValueLB = ExprBuilder.create(Init); + ValueUB = ExprBuilder.create(UB); + ValueInc = ExprBuilder.create(Inc); + + // OpenMP always uses SLE. In case the isl generated AST uses a SLT + // expression, we need to adjust the loop bound by one. + if (Predicate == CmpInst::ICMP_SLT) + ValueUB = Builder.CreateAdd( + ValueUB, Builder.CreateSExt(Builder.getTrue(), ValueUB->getType())); + + MaxType = ExprBuilder.getType(Iterator); + MaxType = ExprBuilder.getWidestType(MaxType, ValueLB->getType()); + MaxType = ExprBuilder.getWidestType(MaxType, ValueUB->getType()); + MaxType = ExprBuilder.getWidestType(MaxType, ValueInc->getType()); + + if (MaxType != ValueLB->getType()) + ValueLB = Builder.CreateSExt(ValueLB, MaxType); + if (MaxType != ValueUB->getType()) + ValueUB = Builder.CreateSExt(ValueUB, MaxType); + if (MaxType != ValueInc->getType()) + ValueInc = Builder.CreateSExt(ValueInc, MaxType); + + BasicBlock::iterator LoopBody; + + SetVector<Value *> SubtreeValues; + SetVector<const Loop *> Loops; + + getReferencesInSubtree(For, SubtreeValues, Loops); + + // Create for all loops we depend on values that contain the current loop + // iteration. These values are necessary to generate code for SCEVs that + // depend on such loops. As a result we need to pass them to the subfunction.
+ for (const Loop *L : Loops) { + const SCEV *OuterLIV = SE.getAddRecExpr(SE.getUnknown(Builder.getInt64(0)), + SE.getUnknown(Builder.getInt64(1)), + L, SCEV::FlagAnyWrap); + Value *V = generateSCEV(OuterLIV); + OutsideLoopIterations[L] = SE.getUnknown(V); + SubtreeValues.insert(V); + } + + ParallelLoopGenerator::ValueToValueMapTy NewValues; + ParallelLoopGenerator ParallelLoopGen(Builder, P, LI, DT, DL); + + IV = ParallelLoopGen.createParallelLoop(ValueLB, ValueUB, ValueInc, + SubtreeValues, NewValues, &LoopBody); + BasicBlock::iterator AfterLoop = Builder.GetInsertPoint(); + Builder.SetInsertPoint(LoopBody); + + // Save the current values. + ValueMapT ValueMapCopy = ValueMap; + IslExprBuilder::IDToValueTy IDToValueCopy = IDToValue; + + updateValues(NewValues); + IDToValue[IteratorID] = IV; + + create(Body); + + // Restore the original values. + ValueMap = ValueMapCopy; + IDToValue = IDToValueCopy; + + Builder.SetInsertPoint(AfterLoop); + removeSubFuncFromDomTree((*LoopBody).getParent()->getParent(), DT); + + for (const Loop *L : Loops) + OutsideLoopIterations.erase(L); + + isl_ast_node_free(For); + isl_ast_expr_free(Iterator); + isl_id_free(IteratorID); +} + void IslNodeBuilder::createFor(__isl_take isl_ast_node *For) { bool Vector = PollyVectorizerChoice != VECTORIZER_NONE; @@ -399,6 +682,11 @@ void IslNodeBuilder::createFor(__isl_take isl_ast_node *For) { return; } } + + if (IslAstInfo::isExecutedInParallel(For)) { + createForParallel(For); + return; + } createForSequential(For); } @@ -474,6 +762,12 @@ void IslNodeBuilder::createSubstitutions(isl_ast_expr *Expr, ScopStmt *Stmt, } } + // Add the current ValueMap to our per-statement value map. + // + // This is needed e.g. to rewrite array base addresses when moving code + // into a subfunction that is executed in parallel.
+ VMap.insert(ValueMap.begin(), ValueMap.end()); + isl_ast_expr_free(Expr); } @@ -506,6 +800,8 @@ void IslNodeBuilder::createUser(__isl_take isl_ast_node *User) { Id = isl_ast_expr_get_id(StmtExpr); isl_ast_expr_free(StmtExpr); + LTS.insert(OutsideLoopIterations.begin(), OutsideLoopIterations.end()); + Stmt = (ScopStmt *)isl_id_get_user(Id); createSubstitutions(Expr, Stmt, VMap, LTS); @@ -558,6 +854,27 @@ void IslNodeBuilder::addParameters(__isl_take isl_set *Context) { isl_id_free(Id); } + // Generate values for the current loop iteration for all surrounding loops. + // + // We may also reference loops outside of the scop which do not contain the + // scop itself, but as the number of such loops may be arbitrarily large we do + // not generate code for them here, but only at the point of code generation + // where these values are needed. + Region &R = S.getRegion(); + Loop *L = LI.getLoopFor(R.getEntry()); + + while (L != nullptr && R.contains(L)) + L = L->getParentLoop(); + + while (L != nullptr) { + const SCEV *OuterLIV = SE.getAddRecExpr(SE.getUnknown(Builder.getInt64(0)), + SE.getUnknown(Builder.getInt64(1)), + L, SCEV::FlagAnyWrap); + Value *V = generateSCEV(OuterLIV); + OutsideLoopIterations[L] = SE.getUnknown(V); + L = L->getParentLoop(); + } + isl_set_free(Context); } @@ -574,6 +891,9 @@ public: IslCodeGeneration() : ScopPass(ID) {} + /// @brief The datalayout used + const DataLayout *DL; + /// @name The analysis passes we need to generate code.
/// ///{ @@ -605,6 +925,7 @@ public: AI = &getAnalysis<IslAstInfo>(); DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); SE = &getAnalysis<ScalarEvolution>(); + DL = &getAnalysis<DataLayoutPass>().getDataLayout(); assert(!S.getRegion().isTopLevelRegion() && "Top level regions are not supported"); @@ -616,7 +937,7 @@ public: BasicBlock *EnteringBB = simplifyRegion(&S, this); PollyIRBuilder Builder = createPollyIRBuilder(EnteringBB, Annotator); - IslNodeBuilder NodeBuilder(Builder, Annotator, this, *LI, *SE, *DT); + IslNodeBuilder NodeBuilder(Builder, Annotator, this, *DL, *LI, *SE, *DT, S); NodeBuilder.addParameters(S.getContext()); Value *RTC = buildRTC(Builder, NodeBuilder.getExprBuilder()); @@ -630,6 +951,7 @@ public: virtual void printScop(raw_ostream &OS) const {} virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<DataLayoutPass>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<IslAstInfo>(); AU.addRequired<RegionInfoPass>(); |