diff options
Diffstat (limited to 'llvm/lib/Analysis')
| -rw-r--r-- | llvm/lib/Analysis/LoopInfo.cpp | 72 | ||||
| -rw-r--r-- | llvm/lib/Analysis/VectorUtils.cpp | 95 |
2 files changed, 161 insertions, 6 deletions
diff --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp index 6c779bf2cca..ef2b1257015 100644 --- a/llvm/lib/Analysis/LoopInfo.cpp +++ b/llvm/lib/Analysis/LoopInfo.cpp @@ -293,16 +293,50 @@ bool Loop::isAnnotatedParallel() const { if (!DesiredLoopIdMetadata) return false; + MDNode *ParallelAccesses = + findOptionMDForLoop(this, "llvm.loop.parallel_accesses"); + SmallPtrSet<MDNode *, 4> + ParallelAccessGroups; // For scalable 'contains' check. + if (ParallelAccesses) { + for (const MDOperand &MD : drop_begin(ParallelAccesses->operands(), 1)) { + MDNode *AccGroup = cast<MDNode>(MD.get()); + assert(isValidAsAccessGroup(AccGroup) && + "List item must be an access group"); + ParallelAccessGroups.insert(AccGroup); + } + } + // The loop branch contains the parallel loop metadata. In order to ensure // that any parallel-loop-unaware optimization pass hasn't added loop-carried // dependencies (thus converted the loop back to a sequential loop), check - // that all the memory instructions in the loop contain parallelism metadata - // that point to the same unique "loop id metadata" the loop branch does. + // that all the memory instructions in the loop belong to an access group that + // is parallel to this loop. for (BasicBlock *BB : this->blocks()) { for (Instruction &I : *BB) { if (!I.mayReadOrWriteMemory()) continue; + if (MDNode *AccessGroup = I.getMetadata(LLVMContext::MD_access_group)) { + auto ContainsAccessGroup = [&ParallelAccessGroups](MDNode *AG) -> bool { + if (AG->getNumOperands() == 0) { + assert(isValidAsAccessGroup(AG) && "Item must be an access group"); + return ParallelAccessGroups.count(AG); + } + + for (const MDOperand &AccessListItem : AG->operands()) { + MDNode *AccGroup = cast<MDNode>(AccessListItem.get()); + assert(isValidAsAccessGroup(AccGroup) && + "List item must be an access group"); + if (ParallelAccessGroups.count(AccGroup)) + return true; + } + return false; + }; + + if (ContainsAccessGroup(AccessGroup)) + continue; + } + // The memory instruction can refer to the loop identifier metadata // directly or indirectly through another list metadata (in case of // nested parallel loops). The loop identifier metadata refers to @@ -693,6 +727,40 @@ void llvm::printLoop(Loop &L, raw_ostream &OS, const std::string &Banner) { } } +MDNode *llvm::findOptionMDForLoopID(MDNode *LoopID, StringRef Name) { + // No loop metadata node, no loop properties. + if (!LoopID) + return nullptr; + + // First operand should refer to the metadata node itself, for legacy reasons. + assert(LoopID->getNumOperands() > 0 && "requires at least one operand"); + assert(LoopID->getOperand(0) == LoopID && "invalid loop id"); + + // Iterate over the metdata node operands and look for MDString metadata. + for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) { + MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); + if (!MD || MD->getNumOperands() < 1) + continue; + MDString *S = dyn_cast<MDString>(MD->getOperand(0)); + if (!S) + continue; + // Return the operand node if MDString holds expected metadata. + if (Name.equals(S->getString())) + return MD; + } + + // Loop property not found. + return nullptr; +} + +MDNode *llvm::findOptionMDForLoop(const Loop *TheLoop, StringRef Name) { + return findOptionMDForLoopID(TheLoop->getLoopID(), Name); +} + +bool llvm::isValidAsAccessGroup(MDNode *Node) { + return Node->getNumOperands() == 0 && Node->isDistinct(); +} + //===----------------------------------------------------------------------===// // LoopInfo implementation // diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp index 9ebb01684c8..e7404be73dc 100644 --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -464,16 +464,100 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB, return MinBWs; } +/// Add all access groups in @p AccGroups to @p List. +template <typename ListT> +static void addToAccessGroupList(ListT &List, MDNode *AccGroups) { + // Interpret an access group as a list containing itself. + if (AccGroups->getNumOperands() == 0) { + assert(isValidAsAccessGroup(AccGroups) && "Node must be an access group"); + List.insert(AccGroups); + return; + } + + for (auto &AccGroupListOp : AccGroups->operands()) { + auto *Item = cast<MDNode>(AccGroupListOp.get()); + assert(isValidAsAccessGroup(Item) && "List item must be an access group"); + List.insert(Item); + } +}; + +MDNode *llvm::uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2) { + if (!AccGroups1) + return AccGroups2; + if (!AccGroups2) + return AccGroups1; + if (AccGroups1 == AccGroups2) + return AccGroups1; + + SmallSetVector<Metadata *, 4> Union; + addToAccessGroupList(Union, AccGroups1); + addToAccessGroupList(Union, AccGroups2); + + if (Union.size() == 0) + return nullptr; + if (Union.size() == 1) + return cast<MDNode>(Union.front()); + + LLVMContext &Ctx = AccGroups1->getContext(); + return MDNode::get(Ctx, Union.getArrayRef()); +} + +MDNode *llvm::intersectAccessGroups(const Instruction *Inst1, + const Instruction *Inst2) { + bool MayAccessMem1 = Inst1->mayReadOrWriteMemory(); + bool MayAccessMem2 = Inst2->mayReadOrWriteMemory(); + + if (!MayAccessMem1 && !MayAccessMem2) + return nullptr; + if (!MayAccessMem1) + return Inst2->getMetadata(LLVMContext::MD_access_group); + if (!MayAccessMem2) + return Inst1->getMetadata(LLVMContext::MD_access_group); + + MDNode *MD1 = Inst1->getMetadata(LLVMContext::MD_access_group); + MDNode *MD2 = Inst2->getMetadata(LLVMContext::MD_access_group); + if (!MD1 || !MD2) + return nullptr; + if (MD1 == MD2) + return MD1; + + // Use set for scalable 'contains' check. + SmallPtrSet<Metadata *, 4> AccGroupSet2; + addToAccessGroupList(AccGroupSet2, MD2); + + SmallVector<Metadata *, 4> Intersection; + if (MD1->getNumOperands() == 0) { + assert(isValidAsAccessGroup(MD1) && "Node must be an access group"); + if (AccGroupSet2.count(MD1)) + Intersection.push_back(MD1); + } else { + for (const MDOperand &Node : MD1->operands()) { + auto *Item = cast<MDNode>(Node.get()); + assert(isValidAsAccessGroup(Item) && "List item must be an access group"); + if (AccGroupSet2.count(Item)) + Intersection.push_back(Item); + } + } + + if (Intersection.size() == 0) + return nullptr; + if (Intersection.size() == 1) + return cast<MDNode>(Intersection.front()); + + LLVMContext &Ctx = Inst1->getContext(); + return MDNode::get(Ctx, Intersection); +} + /// \returns \p I after propagating metadata from \p VL. Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) { Instruction *I0 = cast<Instruction>(VL[0]); SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata; I0->getAllMetadataOtherThanDebugLoc(Metadata); - for (auto Kind : - {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, - LLVMContext::MD_noalias, LLVMContext::MD_fpmath, - LLVMContext::MD_nontemporal, LLVMContext::MD_invariant_load}) { + for (auto Kind : {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, LLVMContext::MD_fpmath, + LLVMContext::MD_nontemporal, LLVMContext::MD_invariant_load, + LLVMContext::MD_access_group}) { MDNode *MD = I0->getMetadata(Kind); for (int J = 1, E = VL.size(); MD && J != E; ++J) { @@ -494,6 +578,9 @@ Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) { case LLVMContext::MD_invariant_load: MD = MDNode::intersect(MD, IMD); break; + case LLVMContext::MD_access_group: + MD = intersectAccessGroups(Inst, IJ); + break; default: llvm_unreachable("unhandled metadata"); } |

