From dee15acd1e446408d2b08f56c04964814fa51bb1 Mon Sep 17 00:00:00 2001
From: Florijan Stamenkovic <florijan.stamenkovic@memgraph.io>
Date: Mon, 23 Jan 2017 14:56:37 +0100
Subject: [PATCH] data_structures/union_find - minor refactors and tests added

Summary: Minor refactor of the union_find data structure. Testing added for union_find.

Test Plan: Manual

Reviewers: sale, buda

Reviewed By: sale, buda

Subscribers: pullbot, florijan, sale, buda

Differential Revision: https://phabricator.memgraph.io/D42
---
 .../data_structures/union_find/union_find.hpp | 83 ++++++++++-------
 tests/unit/union_find.cpp                     | 88 +++++++++++++++++++
 2 files changed, 138 insertions(+), 33 deletions(-)
 create mode 100644 tests/unit/union_find.cpp

diff --git a/include/data_structures/union_find/union_find.hpp b/include/data_structures/union_find/union_find.hpp
index 5d284b310..c43db20b9 100644
--- a/include/data_structures/union_find/union_find.hpp
+++ b/include/data_structures/union_find/union_find.hpp
@@ -1,28 +1,37 @@
 #pragma once
 
+#include <vector>
 #include <memory>
 
-template <class uintXX_t = uint32_t,
-          class allocator = std::allocator<uintXX_t>>
+template <class uintXX_t = uint32_t>
+/**
+ * UnionFind data structure. Provides means of connectivity
+ * setting and checking in logarithmic complexity. Memory
+ * complexity is linear.
+ */
 class UnionFind
 {
 public:
-    UnionFind(uintXX_t n) : N(n), n(n)
+    /**
+     * Constructor, creates a UnionFind structure of fixed size.
+     *
+     * @param n Number of elements in the data structure.
+     */
+    UnionFind(uintXX_t n) : set_count(n), count(n), parent(n)
     {
-        count = alloc.allocate(n);
-        parent = alloc.allocate(n);
-
         for(auto i = 0; i < n; ++i)
             count[i] = 1, parent[i] = i;
     }
 
-    ~UnionFind()
-    {
-        alloc.deallocate(count, N);
-        alloc.deallocate(parent, N);
-    }
-
-    // this is O(lg* n)
+    /**
+     * Connects two elements (and thereby the sets they belong
+     * to). If they are already connected the function has no effect.
+     *
+     * Has O(alpha(n)) time complexity.
+     *
+     * @param p First element.
+     * @param q Second element.
+     */
     void connect(uintXX_t p, uintXX_t q)
     {
         auto rp = root(p);
@@ -39,16 +48,41 @@ public:
             parent[rq] = rp, count[rp] += count[rq];
 
         // update the number of groups
-        n--;
+        set_count--;
     }
 
-    // O(lg* n)
+    /**
+     * Indicates if two elements are connected. Has O(alpha(n)) time
+     * complexity.
+     *
+     * @param p First element.
+     * @param q Second element.
+     * @return True if p and q belong to the same set, false otherwise.
+     */
     bool find(uintXX_t p, uintXX_t q)
     {
         return root(p) == root(q);
     }
 
-    // O(lg* n)
+    /**
+     * Returns the number of disjoint sets in this UnionFind.
+     *
+     * @return Number of disjoint sets currently in the structure.
+     */
+    uintXX_t size() const
+    {
+        return set_count;
+    }
+
+private:
+    uintXX_t set_count;
+
+    // array of subtree counts
+    std::vector<uintXX_t> count;
+
+    // array of tree indices
+    std::vector<uintXX_t> parent;
+    
     uintXX_t root(uintXX_t p)
     {
         auto r = p;
@@ -64,21 +98,4 @@ public:
 
         return r;
     }
-
-    uintXX_t size() const
-    {
-        return n;
-    }
-
-private:
-    allocator alloc;
-
-    const uintXX_t N;
-    uintXX_t n;
-
-    // array of subtree counts
-    uintXX_t* count;
-
-    // array of tree indices
-    uintXX_t* parent;
 };
diff --git a/tests/unit/union_find.cpp b/tests/unit/union_find.cpp
new file mode 100644
index 000000000..8f68d36ea
--- /dev/null
+++ b/tests/unit/union_find.cpp
@@ -0,0 +1,88 @@
+#include <stdlib.h>
+#include <iostream>
+
+#include "gtest/gtest.h"
+
+#include "data_structures/union_find/union_find.hpp"
+
+
+void _expect_fully(UnionFind<> &uf, bool connected, int from=0, int to=-1) {
+    // Expects find(a, b) == connected for every distinct pair in [from, to).
+    // Caution: the default upper bound is size(), i.e. the disjoint-set count.
+    const int end = (to == -1) ? static_cast<int>(uf.size()) : to;
+    for (int a = from; a < end; ++a) {
+        for (int b = from; b < end; ++b) {
+            if (a != b) EXPECT_EQ(uf.find(a, b), connected);
+        }
+    }
+}
+
+TEST(UnionFindTest, InitialSizeTest) {
+    for (int n = 0; n < 10; ++n) {  // n elements are expected to start as n sets
+        UnionFind<> fresh(n);
+        EXPECT_EQ(n, fresh.size());
+    }
+}
+
+TEST(UnionFindTest, ModifiedSizeTest) {
+    UnionFind<> sets(10);
+    EXPECT_EQ(10, sets.size());
+    // Connecting an element to itself is expected to leave the count alone.
+    sets.connect(0, 0);
+    EXPECT_EQ(10, sets.size());
+    // Joining two singletons is expected to remove exactly one set.
+    sets.connect(0, 1);
+    EXPECT_EQ(9, sets.size());
+    // Same again for a second, independent pair.
+    sets.connect(2, 3);
+    EXPECT_EQ(8, sets.size());
+    // Merging the two pairs {0, 1} and {2, 3} removes one more set.
+    sets.connect(0, 2);
+    EXPECT_EQ(7, sets.size());
+    // 1 and 3 already share a set, so the count is expected to stay at 7.
+    sets.connect(1, 3);
+    EXPECT_EQ(7, sets.size());
+}
+
+
+TEST(UnionFindTest, Disconectivity) {
+    UnionFind<> fresh(10);  // no connect() calls: size() still equals the element count
+    _expect_fully(fresh, false);
+}
+
+TEST(UnionFindTest, ConnectivityAlongChain) {
+    UnionFind<> uf(10);
+    for (unsigned int i = 1 ; i < 10 ; i++)  // not uf.size(): the set count drops on each connect
+        uf.connect(i - 1, i);
+    _expect_fully(uf, true, 0, 10);  // explicit bounds: the default upper bound is the set count
+}
+
+TEST(UnionFindTest, ConnectivityOnTree) {
+    UnionFind<> uf(10);
+    _expect_fully(uf, false, 0, 10);
+    // Upper bounds are explicit: the default, size(), is the shrinking set count.
+    uf.connect(0, 1);
+    uf.connect(0, 2);
+    _expect_fully(uf, true, 0, 3);
+    _expect_fully(uf, false, 2, 10);
+
+    uf.connect(2, 3);
+    _expect_fully(uf, true, 0, 4);
+    _expect_fully(uf, false, 3, 10);
+}
+
+TEST(UnionFindTest, DisjointChains) {
+    UnionFind<> uf(30);
+    for (int k = 0; k < 30; ++k)  // link k to k - 1, except at multiples of 10
+        uf.connect(k, k % 10 == 0 ? k : k - 1);
+    // Two elements should be connected exactly when they fall in the same block of ten.
+    for (int a = 0; a < 30; ++a)
+        for (int b = 0; b < 30; ++b)
+            EXPECT_EQ(uf.find(a, b), (b - (b % 10)) == (a - (a % 10)));
+}
+
+int main(int argc, char **argv) {
+    ::testing::InitGoogleTest(&argc, argv);  // gtest parses its own command-line flags
+    const int exit_code = RUN_ALL_TESTS();   // 0 when all tests pass, non-zero otherwise
+    return exit_code;
+}