summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Analysis
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis')
-rw-r--r--llvm/lib/Analysis/LoopInfo.cpp72
-rw-r--r--llvm/lib/Analysis/VectorUtils.cpp95
2 files changed, 161 insertions, 6 deletions
diff --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp
index 6c779bf2cca..ef2b1257015 100644
--- a/llvm/lib/Analysis/LoopInfo.cpp
+++ b/llvm/lib/Analysis/LoopInfo.cpp
@@ -293,16 +293,50 @@ bool Loop::isAnnotatedParallel() const {
if (!DesiredLoopIdMetadata)
return false;
+ MDNode *ParallelAccesses =
+ findOptionMDForLoop(this, "llvm.loop.parallel_accesses");
+ SmallPtrSet<MDNode *, 4>
+ ParallelAccessGroups; // For scalable 'contains' check.
+ if (ParallelAccesses) {
+ for (const MDOperand &MD : drop_begin(ParallelAccesses->operands(), 1)) {
+ MDNode *AccGroup = cast<MDNode>(MD.get());
+ assert(isValidAsAccessGroup(AccGroup) &&
+ "List item must be an access group");
+ ParallelAccessGroups.insert(AccGroup);
+ }
+ }
+
// The loop branch contains the parallel loop metadata. In order to ensure
// that any parallel-loop-unaware optimization pass hasn't added loop-carried
// dependencies (thus converted the loop back to a sequential loop), check
- // that all the memory instructions in the loop contain parallelism metadata
- // that point to the same unique "loop id metadata" the loop branch does.
+ // that all the memory instructions in the loop belong to an access group that
+ // is parallel to this loop.
for (BasicBlock *BB : this->blocks()) {
for (Instruction &I : *BB) {
if (!I.mayReadOrWriteMemory())
continue;
+ if (MDNode *AccessGroup = I.getMetadata(LLVMContext::MD_access_group)) {
+ auto ContainsAccessGroup = [&ParallelAccessGroups](MDNode *AG) -> bool {
+ if (AG->getNumOperands() == 0) {
+ assert(isValidAsAccessGroup(AG) && "Item must be an access group");
+ return ParallelAccessGroups.count(AG);
+ }
+
+ for (const MDOperand &AccessListItem : AG->operands()) {
+ MDNode *AccGroup = cast<MDNode>(AccessListItem.get());
+ assert(isValidAsAccessGroup(AccGroup) &&
+ "List item must be an access group");
+ if (ParallelAccessGroups.count(AccGroup))
+ return true;
+ }
+ return false;
+ };
+
+ if (ContainsAccessGroup(AccessGroup))
+ continue;
+ }
+
// The memory instruction can refer to the loop identifier metadata
// directly or indirectly through another list metadata (in case of
// nested parallel loops). The loop identifier metadata refers to
@@ -693,6 +727,40 @@ void llvm::printLoop(Loop &L, raw_ostream &OS, const std::string &Banner) {
}
}
+MDNode *llvm::findOptionMDForLoopID(MDNode *LoopID, StringRef Name) {
+ // No loop metadata node, no loop properties.
+ if (!LoopID)
+ return nullptr;
+
+ // First operand should refer to the metadata node itself, for legacy reasons.
+ assert(LoopID->getNumOperands() > 0 && "requires at least one operand");
+ assert(LoopID->getOperand(0) == LoopID && "invalid loop id");
+
+ // Iterate over the metdata node operands and look for MDString metadata.
+ for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) {
+ MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i));
+ if (!MD || MD->getNumOperands() < 1)
+ continue;
+ MDString *S = dyn_cast<MDString>(MD->getOperand(0));
+ if (!S)
+ continue;
+ // Return the operand node if MDString holds expected metadata.
+ if (Name.equals(S->getString()))
+ return MD;
+ }
+
+ // Loop property not found.
+ return nullptr;
+}
+
+MDNode *llvm::findOptionMDForLoop(const Loop *TheLoop, StringRef Name) {
+ return findOptionMDForLoopID(TheLoop->getLoopID(), Name);
+}
+
+bool llvm::isValidAsAccessGroup(MDNode *Node) {
+ return Node->getNumOperands() == 0 && Node->isDistinct();
+}
+
//===----------------------------------------------------------------------===//
// LoopInfo implementation
//
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 9ebb01684c8..e7404be73dc 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -464,16 +464,100 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
return MinBWs;
}
+/// Add all access groups in @p AccGroups to @p List.
+template <typename ListT>
+static void addToAccessGroupList(ListT &List, MDNode *AccGroups) {
+ // Interpret an access group as a list containing itself.
+ if (AccGroups->getNumOperands() == 0) {
+ assert(isValidAsAccessGroup(AccGroups) && "Node must be an access group");
+ List.insert(AccGroups);
+ return;
+ }
+
+ for (auto &AccGroupListOp : AccGroups->operands()) {
+ auto *Item = cast<MDNode>(AccGroupListOp.get());
+ assert(isValidAsAccessGroup(Item) && "List item must be an access group");
+ List.insert(Item);
+ }
+};
+
+MDNode *llvm::uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2) {
+ if (!AccGroups1)
+ return AccGroups2;
+ if (!AccGroups2)
+ return AccGroups1;
+ if (AccGroups1 == AccGroups2)
+ return AccGroups1;
+
+ SmallSetVector<Metadata *, 4> Union;
+ addToAccessGroupList(Union, AccGroups1);
+ addToAccessGroupList(Union, AccGroups2);
+
+ if (Union.size() == 0)
+ return nullptr;
+ if (Union.size() == 1)
+ return cast<MDNode>(Union.front());
+
+ LLVMContext &Ctx = AccGroups1->getContext();
+ return MDNode::get(Ctx, Union.getArrayRef());
+}
+
+MDNode *llvm::intersectAccessGroups(const Instruction *Inst1,
+ const Instruction *Inst2) {
+ bool MayAccessMem1 = Inst1->mayReadOrWriteMemory();
+ bool MayAccessMem2 = Inst2->mayReadOrWriteMemory();
+
+ if (!MayAccessMem1 && !MayAccessMem2)
+ return nullptr;
+ if (!MayAccessMem1)
+ return Inst2->getMetadata(LLVMContext::MD_access_group);
+ if (!MayAccessMem2)
+ return Inst1->getMetadata(LLVMContext::MD_access_group);
+
+ MDNode *MD1 = Inst1->getMetadata(LLVMContext::MD_access_group);
+ MDNode *MD2 = Inst2->getMetadata(LLVMContext::MD_access_group);
+ if (!MD1 || !MD2)
+ return nullptr;
+ if (MD1 == MD2)
+ return MD1;
+
+ // Use set for scalable 'contains' check.
+ SmallPtrSet<Metadata *, 4> AccGroupSet2;
+ addToAccessGroupList(AccGroupSet2, MD2);
+
+ SmallVector<Metadata *, 4> Intersection;
+ if (MD1->getNumOperands() == 0) {
+ assert(isValidAsAccessGroup(MD1) && "Node must be an access group");
+ if (AccGroupSet2.count(MD1))
+ Intersection.push_back(MD1);
+ } else {
+ for (const MDOperand &Node : MD1->operands()) {
+ auto *Item = cast<MDNode>(Node.get());
+ assert(isValidAsAccessGroup(Item) && "List item must be an access group");
+ if (AccGroupSet2.count(Item))
+ Intersection.push_back(Item);
+ }
+ }
+
+ if (Intersection.size() == 0)
+ return nullptr;
+ if (Intersection.size() == 1)
+ return cast<MDNode>(Intersection.front());
+
+ LLVMContext &Ctx = Inst1->getContext();
+ return MDNode::get(Ctx, Intersection);
+}
+
/// \returns \p I after propagating metadata from \p VL.
Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) {
Instruction *I0 = cast<Instruction>(VL[0]);
SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata;
I0->getAllMetadataOtherThanDebugLoc(Metadata);
- for (auto Kind :
- {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
- LLVMContext::MD_noalias, LLVMContext::MD_fpmath,
- LLVMContext::MD_nontemporal, LLVMContext::MD_invariant_load}) {
+ for (auto Kind : {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
+ LLVMContext::MD_noalias, LLVMContext::MD_fpmath,
+ LLVMContext::MD_nontemporal, LLVMContext::MD_invariant_load,
+ LLVMContext::MD_access_group}) {
MDNode *MD = I0->getMetadata(Kind);
for (int J = 1, E = VL.size(); MD && J != E; ++J) {
@@ -494,6 +578,9 @@ Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) {
case LLVMContext::MD_invariant_load:
MD = MDNode::intersect(MD, IMD);
break;
+ case LLVMContext::MD_access_group:
+ MD = intersectAccessGroups(Inst, IJ);
+ break;
default:
llvm_unreachable("unhandled metadata");
}
OpenPOWER on IntegriCloud