summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--llvm/include/llvm/ADT/DirectedGraph.h270
-rw-r--r--llvm/unittests/ADT/CMakeLists.txt1
-rw-r--r--llvm/unittests/ADT/DirectedGraphTest.cpp295
3 files changed, 566 insertions, 0 deletions
diff --git a/llvm/include/llvm/ADT/DirectedGraph.h b/llvm/include/llvm/ADT/DirectedGraph.h
new file mode 100644
index 00000000000..f6a358d99cd
--- /dev/null
+++ b/llvm/include/llvm/ADT/DirectedGraph.h
@@ -0,0 +1,270 @@
+//===- llvm/ADT/DirectedGraph.h - Directed Graph ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the interface and a base class implementation for a
+// directed graph.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_DIRECTEDGRAPH_H
+#define LLVM_ADT_DIRECTEDGRAPH_H
+
+#include "llvm/ADT/GraphTraits.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace llvm {
+
+/// Represent an edge in the directed graph.
+/// The edge contains the target node it connects to.
+template <class NodeType, class EdgeType> class DGEdge {
+public:
+ DGEdge() = delete;
+ /// Create an edge pointing to the given node \p N.
+ explicit DGEdge(NodeType &N) : TargetNode(N) {}
+ explicit DGEdge(const DGEdge<NodeType, EdgeType> &E)
+ : TargetNode(E.TargetNode) {}
+ DGEdge<NodeType, EdgeType> &operator=(const DGEdge<NodeType, EdgeType> &E) {
+ TargetNode = E.TargetNode;
+ return *this;
+ }
+
+ /// Static polymorphism: delegate implementation (via isEqualTo) to the
+ /// derived class.
+ bool operator==(const EdgeType &E) const { return getDerived().isEqualTo(E); }
+ bool operator!=(const EdgeType &E) const { return !operator==(E); }
+
+ /// Retrieve the target node this edge connects to.
+ const NodeType &getTargetNode() const { return TargetNode; }
+ NodeType &getTargetNode() {
+ return const_cast<NodeType &>(
+ static_cast<const DGEdge<NodeType, EdgeType> &>(*this).getTargetNode());
+ }
+
+protected:
+ // As the default implementation use address comparison for equality.
+ bool isEqualTo(const EdgeType &E) const { return this == &E; }
+
+ // Cast the 'this' pointer to the derived type and return a reference.
+ EdgeType &getDerived() { return *static_cast<EdgeType *>(this); }
+ const EdgeType &getDerived() const {
+ return *static_cast<const EdgeType *>(this);
+ }
+
+ // The target node this edge connects to.
+ NodeType &TargetNode;
+};
+
+/// Represent a node in the directed graph.
+/// The node has a (possibly empty) list of outgoing edges.
+template <class NodeType, class EdgeType> class DGNode {
+public:
+ using EdgeListTy = SetVector<EdgeType *>;
+ using iterator = typename EdgeListTy::iterator;
+ using const_iterator = typename EdgeListTy::const_iterator;
+
+ /// Create a node with a single outgoing edge \p E.
+ explicit DGNode(EdgeType &E) : Edges() { Edges.insert(&E); }
+ DGNode() = default;
+
+ explicit DGNode(const DGNode<NodeType, EdgeType> &N) : Edges(N.Edges) {}
+ DGNode(DGNode<NodeType, EdgeType> &&N) : Edges(std::move(N.Edges)) {}
+
+ DGNode<NodeType, EdgeType> &operator=(const DGNode<NodeType, EdgeType> &N) {
+ Edges = N.Edges;
+ return *this;
+ }
+ DGNode<NodeType, EdgeType> &operator=(const DGNode<NodeType, EdgeType> &&N) {
+ Edges = std::move(N.Edges);
+ return *this;
+ }
+
+ /// Static polymorphism: delegate implementation (via isEqualTo) to the
+ /// derived class.
+ bool operator==(const NodeType &N) const { return getDerived().isEqualTo(N); }
+ bool operator!=(const NodeType &N) const { return !operator==(N); }
+
+ const_iterator begin() const { return Edges.begin(); }
+ const_iterator end() const { return Edges.end(); }
+ iterator begin() { return Edges.begin(); }
+ iterator end() { return Edges.end(); }
+ const EdgeType &front() const { return *Edges.front(); }
+ EdgeType &front() { return *Edges.front(); }
+ const EdgeType &back() const { return *Edges.back(); }
+ EdgeType &back() { return *Edges.back(); }
+
+ /// Collect in \p EL, all the edges from this node to \p N.
+ /// Return true if at least one edge was found, and false otherwise.
+ /// Note that this implementation allows more than one edge to connect
+ /// a given pair of nodes.
+ bool findEdgesTo(const NodeType &N, SmallVectorImpl<EdgeType *> &EL) const {
+ assert(EL.empty() && "Expected the list of edges to be empty.");
+ for (auto *E : Edges)
+ if (E->getTargetNode() == N)
+ EL.push_back(E);
+ return !EL.empty();
+ }
+
+ /// Add the given edge \p E to this node, if it doesn't exist already. Returns
+ /// true if the edge is added and false otherwise.
+ bool addEdge(EdgeType &E) { return Edges.insert(&E); }
+
+ /// Remove the given edge \p E from this node, if it exists.
+ void removeEdge(EdgeType &E) { Edges.remove(&E); }
+
+ /// Test whether there is an edge that goes from this node to \p N.
+ bool hasEdgeTo(const NodeType &N) const {
+ return (findEdgeTo(N) != Edges.end());
+ }
+
+ /// Retrieve the outgoing edges for the node.
+ const EdgeListTy &getEdges() const { return Edges; }
+ EdgeListTy &getEdges() {
+ return const_cast<EdgeListTy &>(
+ static_cast<const DGNode<NodeType, EdgeType> &>(*this).Edges);
+ }
+
+ /// Clear the outgoing edges.
+ void clear() { Edges.clear(); }
+
+protected:
+ // As the default implementation use address comparison for equality.
+ bool isEqualTo(const NodeType &N) const { return this == &N; }
+
+ // Cast the 'this' pointer to the derived type and return a reference.
+ NodeType &getDerived() { return *static_cast<NodeType *>(this); }
+ const NodeType &getDerived() const {
+ return *static_cast<const NodeType *>(this);
+ }
+
+ /// Find an edge to \p N. If more than one edge exists, this will return
+ /// the first one in the list of edges.
+ const_iterator findEdgeTo(const NodeType &N) const {
+ return llvm::find_if(
+ Edges, [&N](const EdgeType *E) { return E->getTargetNode() == N; });
+ }
+
+ // The list of outgoing edges.
+ EdgeListTy Edges;
+};
+
+/// Directed graph
+///
+/// The graph is represented by a table of nodes.
+/// Each node contains a (possibly empty) list of outgoing edges.
+/// Each edge contains the target node it connects to.
+template <class NodeType, class EdgeType> class DirectedGraph {
+protected:
+ using NodeListTy = SmallVector<NodeType *, 10>;
+ using EdgeListTy = SmallVector<EdgeType *, 10>;
+public:
+ using iterator = typename NodeListTy::iterator;
+ using const_iterator = typename NodeListTy::const_iterator;
+ using DGraphType = DirectedGraph<NodeType, EdgeType>;
+
+ DirectedGraph() = default;
+ explicit DirectedGraph(NodeType &N) : Nodes() { addNode(N); }
+ DirectedGraph(const DGraphType &G) : Nodes(G.Nodes) {}
+ DirectedGraph(DGraphType &&RHS) : Nodes(std::move(RHS.Nodes)) {}
+ DGraphType &operator=(const DGraphType &G) {
+ Nodes = G.Nodes;
+ return *this;
+ }
+ DGraphType &operator=(const DGraphType &&G) {
+ Nodes = std::move(G.Nodes);
+ return *this;
+ }
+
+ const_iterator begin() const { return Nodes.begin(); }
+ const_iterator end() const { return Nodes.end(); }
+ iterator begin() { return Nodes.begin(); }
+ iterator end() { return Nodes.end(); }
+ const NodeType &front() const { return *Nodes.front(); }
+ NodeType &front() { return *Nodes.front(); }
+ const NodeType &back() const { return *Nodes.back(); }
+ NodeType &back() { return *Nodes.back(); }
+
+ size_t size() const { return Nodes.size(); }
+
+ /// Find the given node \p N in the table.
+ const_iterator findNode(const NodeType &N) const {
+ return llvm::find_if(Nodes,
+ [&N](const NodeType *Node) { return *Node == N; });
+ }
+ iterator findNode(const NodeType &N) {
+ return const_cast<iterator>(
+ static_cast<const DGraphType &>(*this).findNode(N));
+ }
+
+ /// Add the given node \p N to the graph if it is not already present.
+ bool addNode(NodeType &N) {
+ if (findNode(N) != Nodes.end())
+ return false;
+ Nodes.push_back(&N);
+ return true;
+ }
+
+ /// Collect in \p EL all edges that are coming into node \p N. Return true
+ /// if at least one edge was found, and false otherwise.
+ bool findIncomingEdgesToNode(const NodeType &N, SmallVectorImpl<EdgeType*> &EL) const {
+ assert(EL.empty() && "Expected the list of edges to be empty.");
+ EdgeListTy TempList;
+ for (auto *Node : Nodes) {
+ if (*Node == N)
+ continue;
+ Node->findEdgesTo(N, TempList);
+ EL.insert(EL.end(), TempList.begin(), TempList.end());
+ TempList.clear();
+ }
+ return !EL.empty();
+ }
+
+ /// Remove the given node \p N from the graph. If the node has incoming or
+ /// outgoing edges, they are also removed. Return true if the node was found
+ /// and then removed, and false if the node was not found in the graph to
+ /// begin with.
+ bool removeNode(NodeType &N) {
+ iterator IT = findNode(N);
+ if (IT == Nodes.end())
+ return false;
+ // Remove incoming edges.
+ EdgeListTy EL;
+ for (auto *Node : Nodes) {
+ if (*Node == N)
+ continue;
+ Node->findEdgesTo(N, EL);
+ for (auto *E : EL)
+ Node->removeEdge(*E);
+ EL.clear();
+ }
+ N.clear();
+ Nodes.erase(IT);
+ return true;
+ }
+
+ /// Assuming nodes \p Src and \p Dst are already in the graph, connect node \p
+ /// Src to node \p Dst using the provided edge \p E. Return true if \p Src is
+ /// not already connected to \p Dst via \p E, and false otherwise.
+ bool connect(NodeType &Src, NodeType &Dst, EdgeType &E) {
+ assert(findNode(Src) != Nodes.end() && "Src node should be present.");
+ assert(findNode(Dst) != Nodes.end() && "Dst node should be present.");
+ assert((E.getTargetNode() == Dst) &&
+ "Target of the given edge does not match Dst.");
+ return Src.addEdge(E);
+ }
+
+protected:
+ // The list of nodes in the graph.
+ NodeListTy Nodes;
+};
+
+} // namespace llvm
+
+#endif // LLVM_ADT_DIRECTEDGRAPH_H
diff --git a/llvm/unittests/ADT/CMakeLists.txt b/llvm/unittests/ADT/CMakeLists.txt
index 676ce181871..3a7be5b5522 100644
--- a/llvm/unittests/ADT/CMakeLists.txt
+++ b/llvm/unittests/ADT/CMakeLists.txt
@@ -17,6 +17,7 @@ add_llvm_unittest(ADTTests
DenseMapTest.cpp
DenseSetTest.cpp
DepthFirstIteratorTest.cpp
+ DirectedGraphTest.cpp
EquivalenceClassesTest.cpp
FallibleIteratorTest.cpp
FoldingSet.cpp
diff --git a/llvm/unittests/ADT/DirectedGraphTest.cpp b/llvm/unittests/ADT/DirectedGraphTest.cpp
new file mode 100644
index 00000000000..ae1f6b01ef2
--- /dev/null
+++ b/llvm/unittests/ADT/DirectedGraphTest.cpp
@@ -0,0 +1,295 @@
+//===- llvm/unittest/ADT/DirectedGraphTest.cpp ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines concrete derivations of the directed-graph base classes
+// for testing purposes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/DirectedGraph.h"
+#include "llvm/ADT/GraphTraits.h"
+#include "llvm/ADT/SCCIterator.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "gtest/gtest.h"
+
+namespace llvm {
+
+//===--------------------------------------------------------------------===//
+// Derived nodes, edges and graph types based on DirectedGraph.
+//===--------------------------------------------------------------------===//
+
+class DGTestNode;
+class DGTestEdge;
+using DGTestNodeBase = DGNode<DGTestNode, DGTestEdge>;
+using DGTestEdgeBase = DGEdge<DGTestNode, DGTestEdge>;
+using DGTestBase = DirectedGraph<DGTestNode, DGTestEdge>;
+
+class DGTestNode : public DGTestNodeBase {
+public:
+ DGTestNode() = default;
+};
+
+class DGTestEdge : public DGTestEdgeBase {
+public:
+ DGTestEdge() = delete;
+ DGTestEdge(DGTestNode &N) : DGTestEdgeBase(N) {}
+};
+
+class DGTestGraph : public DGTestBase {
+public:
+ DGTestGraph() = default;
+ ~DGTestGraph(){};
+};
+
+using EdgeListTy = SmallVector<DGTestEdge *, 2>;
+
+//===--------------------------------------------------------------------===//
+// GraphTraits specializations for the DGTest
+//===--------------------------------------------------------------------===//
+
+template <> struct GraphTraits<DGTestNode *> {
+ using NodeRef = DGTestNode *;
+
+ static DGTestNode *DGTestGetTargetNode(DGEdge<DGTestNode, DGTestEdge> *P) {
+ return &P->getTargetNode();
+ }
+
+ // Provide a mapped iterator so that the GraphTrait-based implementations can
+ // find the target nodes without having to explicitly go through the edges.
+ using ChildIteratorType =
+ mapped_iterator<DGTestNode::iterator, decltype(&DGTestGetTargetNode)>;
+ using ChildEdgeIteratorType = DGTestNode::iterator;
+
+ static NodeRef getEntryNode(NodeRef N) { return N; }
+ static ChildIteratorType child_begin(NodeRef N) {
+ return ChildIteratorType(N->begin(), &DGTestGetTargetNode);
+ }
+ static ChildIteratorType child_end(NodeRef N) {
+ return ChildIteratorType(N->end(), &DGTestGetTargetNode);
+ }
+
+ static ChildEdgeIteratorType child_edge_begin(NodeRef N) {
+ return N->begin();
+ }
+ static ChildEdgeIteratorType child_edge_end(NodeRef N) { return N->end(); }
+};
+
+template <>
+struct GraphTraits<DGTestGraph *> : public GraphTraits<DGTestNode *> {
+ using nodes_iterator = DGTestGraph::iterator;
+ static NodeRef getEntryNode(DGTestGraph *DG) { return *DG->begin(); }
+ static nodes_iterator nodes_begin(DGTestGraph *DG) { return DG->begin(); }
+ static nodes_iterator nodes_end(DGTestGraph *DG) { return DG->end(); }
+};
+
+//===--------------------------------------------------------------------===//
+// Test various modification and query functions.
+//===--------------------------------------------------------------------===//
+
+TEST(DirectedGraphTest, AddAndConnectNodes) {
+ DGTestGraph DG;
+ DGTestNode N1, N2, N3;
+ DGTestEdge E1(N1), E2(N2), E3(N3);
+
+ // Check that new nodes can be added successfully.
+ EXPECT_TRUE(DG.addNode(N1));
+ EXPECT_TRUE(DG.addNode(N2));
+ EXPECT_TRUE(DG.addNode(N3));
+
+ // Check that duplicate nodes are not added to the graph.
+ EXPECT_FALSE(DG.addNode(N1));
+
+ // Check that nodes can be connected using valid edges with no errors.
+ EXPECT_TRUE(DG.connect(N1, N2, E2));
+ EXPECT_TRUE(DG.connect(N2, N3, E3));
+ EXPECT_TRUE(DG.connect(N3, N1, E1));
+
+ // The graph looks like this now:
+ //
+ // +---------------+
+ // v |
+ // N1 -> N2 -> N3 -+
+
+ // Check that already connected nodes with the given edge are not connected
+ // again (ie. edges are between nodes are not duplicated).
+ EXPECT_FALSE(DG.connect(N3, N1, E1));
+
+ // Check that there are 3 nodes in the graph.
+ EXPECT_TRUE(DG.size() == 3);
+
+ // Check that the added nodes can be found in the graph.
+ EXPECT_NE(DG.findNode(N3), DG.end());
+
+ // Check that nodes that are not part of the graph are not found.
+ DGTestNode N4;
+ EXPECT_EQ(DG.findNode(N4), DG.end());
+
+ // Check that findIncommingEdgesToNode works correctly.
+ EdgeListTy EL;
+ EXPECT_TRUE(DG.findIncomingEdgesToNode(N1, EL));
+ EXPECT_TRUE(EL.size() == 1);
+ EXPECT_EQ(*EL[0], E1);
+}
+
+TEST(DirectedGraphTest, AddRemoveEdge) {
+ DGTestGraph DG;
+ DGTestNode N1, N2, N3;
+ DGTestEdge E1(N1), E2(N2), E3(N3);
+ DG.addNode(N1);
+ DG.addNode(N2);
+ DG.addNode(N3);
+ DG.connect(N1, N2, E2);
+ DG.connect(N2, N3, E3);
+ DG.connect(N3, N1, E1);
+
+ // The graph looks like this now:
+ //
+ // +---------------+
+ // v |
+ // N1 -> N2 -> N3 -+
+
+ // Check that there are 3 nodes in the graph.
+ EXPECT_TRUE(DG.size() == 3);
+
+ // Check that the target nodes of the edges are correct.
+ EXPECT_EQ(E1.getTargetNode(), N1);
+ EXPECT_EQ(E2.getTargetNode(), N2);
+ EXPECT_EQ(E3.getTargetNode(), N3);
+
+ // Remove the edge from N1 to N2.
+ N1.removeEdge(E2);
+
+ // The graph looks like this now:
+ //
+ // N2 -> N3 -> N1
+
+ // Check that there are no incoming edges to N2.
+ EdgeListTy EL;
+ EXPECT_FALSE(DG.findIncomingEdgesToNode(N2, EL));
+ EXPECT_TRUE(EL.empty());
+
+ // Put the edge from N1 to N2 back in place.
+ N1.addEdge(E2);
+
+ // Check that E2 is the only incoming edge to N2.
+ EL.clear();
+ EXPECT_TRUE(DG.findIncomingEdgesToNode(N2, EL));
+ EXPECT_EQ(*EL[0], E2);
+}
+
+TEST(DirectedGraphTest, hasEdgeTo) {
+ DGTestGraph DG;
+ DGTestNode N1, N2, N3;
+ DGTestEdge E1(N1), E2(N2), E3(N3), E4(N1);
+ DG.addNode(N1);
+ DG.addNode(N2);
+ DG.addNode(N3);
+ DG.connect(N1, N2, E2);
+ DG.connect(N2, N3, E3);
+ DG.connect(N3, N1, E1);
+ DG.connect(N2, N1, E4);
+
+ // The graph looks like this now:
+ //
+ // +-----+
+ // v |
+ // N1 -> N2 -> N3
+ // ^ |
+ // +-----------+
+
+ EXPECT_TRUE(N2.hasEdgeTo(N1));
+ EXPECT_TRUE(N3.hasEdgeTo(N1));
+}
+
+TEST(DirectedGraphTest, AddRemoveNode) {
+ DGTestGraph DG;
+ DGTestNode N1, N2, N3;
+ DGTestEdge E1(N1), E2(N2), E3(N3);
+ DG.addNode(N1);
+ DG.addNode(N2);
+ DG.addNode(N3);
+ DG.connect(N1, N2, E2);
+ DG.connect(N2, N3, E3);
+ DG.connect(N3, N1, E1);
+
+ // The graph looks like this now:
+ //
+ // +---------------+
+ // v |
+ // N1 -> N2 -> N3 -+
+
+ // Check that there are 3 nodes in the graph.
+ EXPECT_TRUE(DG.size() == 3);
+
+ // Check that a node in the graph can be removed, but not more than once.
+ EXPECT_TRUE(DG.removeNode(N1));
+ EXPECT_EQ(DG.findNode(N1), DG.end());
+ EXPECT_FALSE(DG.removeNode(N1));
+
+ // The graph looks like this now:
+ //
+ // N2 -> N3
+
+ // Check that there are 2 nodes in the graph and only N2 is connected to N3.
+ EXPECT_TRUE(DG.size() == 2);
+ EXPECT_TRUE(N3.getEdges().empty());
+ EdgeListTy EL;
+ EXPECT_FALSE(DG.findIncomingEdgesToNode(N2, EL));
+ EXPECT_TRUE(EL.empty());
+}
+
+TEST(DirectedGraphTest, SCC) {
+
+ DGTestGraph DG;
+ DGTestNode N1, N2, N3, N4;
+ DGTestEdge E1(N1), E2(N2), E3(N3), E4(N4);
+ DG.addNode(N1);
+ DG.addNode(N2);
+ DG.addNode(N3);
+ DG.addNode(N4);
+ DG.connect(N1, N2, E2);
+ DG.connect(N2, N3, E3);
+ DG.connect(N3, N1, E1);
+ DG.connect(N3, N4, E4);
+
+ // The graph looks like this now:
+ //
+ // +---------------+
+ // v |
+ // N1 -> N2 -> N3 -+ N4
+ // | ^
+ // +--------+
+
+ // Test that there are two SCCs:
+ // 1. {N1, N2, N3}
+ // 2. {N4}
+ using NodeListTy = SmallPtrSet<DGTestNode *, 3>;
+ SmallVector<NodeListTy, 4> ListOfSCCs;
+ for (auto &SCC : make_range(scc_begin(&DG), scc_end(&DG)))
+ ListOfSCCs.push_back(NodeListTy(SCC.begin(), SCC.end()));
+
+ EXPECT_TRUE(ListOfSCCs.size() == 2);
+
+ for (auto &SCC : ListOfSCCs) {
+ if (SCC.size() > 1)
+ continue;
+ EXPECT_TRUE(SCC.size() == 1);
+ EXPECT_TRUE(SCC.count(&N4) == 1);
+ }
+ for (auto &SCC : ListOfSCCs) {
+ if (SCC.size() <= 1)
+ continue;
+ EXPECT_TRUE(SCC.size() == 3);
+ EXPECT_TRUE(SCC.count(&N1) == 1);
+ EXPECT_TRUE(SCC.count(&N2) == 1);
+ EXPECT_TRUE(SCC.count(&N3) == 1);
+ EXPECT_TRUE(SCC.count(&N4) == 0);
+ }
+}
+
+} // namespace llvm
OpenPOWER on IntegriCloud