data_structures/union_find - minor refactors and tests added

Summary: Minor refactor of the union_find data structure. Testing added for union_find.

Test Plan: Manual

Reviewers: sale, buda

Reviewed By: sale, buda

Subscribers: pullbot, florijan, sale, buda

Differential Revision: https://phabricator.memgraph.io/D42
This commit is contained in:
Florijan Stamenkovic 2017-01-23 14:56:37 +01:00
parent 45c31b08e7
commit dee15acd1e
2 changed files with 138 additions and 33 deletions

View File

@ -1,28 +1,37 @@
#pragma once
#include <vector>
#include <memory>
template <class uintXX_t = uint32_t,
class allocator = std::allocator<uintXX_t>>
template <class uintXX_t = uint32_t>
/**
* UnionFind data structure. Provides means of connectivity
* setting and checking in logarithmic complexity. Memory
* complexity is linear.
*/
class UnionFind
{
public:
UnionFind(uintXX_t n) : N(n), n(n)
/**
* Constructor, creates a UnionFind structure of fixed size.
*
* @param n Number of elements in the data structure.
*/
UnionFind(uintXX_t n) : set_count(n), count(n), parent(n)
{
count = alloc.allocate(n);
parent = alloc.allocate(n);
for(auto i = 0; i < n; ++i)
count[i] = 1, parent[i] = i;
}
~UnionFind()
{
alloc.deallocate(count, N);
alloc.deallocate(parent, N);
}
// this is O(lg* n)
/**
* Connects two elements (and thereby the sets they belong
* to). If they are already connected the function has no effect.
*
* Has O(alpha(n)) time complexity.
*
* @param p First element.
* @param q Second element.
*/
void connect(uintXX_t p, uintXX_t q)
{
auto rp = root(p);
@ -39,16 +48,41 @@ public:
parent[rq] = rp, count[rp] += count[rq];
// update the number of groups
n--;
set_count--;
}
// O(lg* n)
/**
* Indicates if two elements are connected. Has O(alpha(n)) time
* complexity.
*
* @param p First element.
* @param q Second element.
* @return See above.
*/
bool find(uintXX_t p, uintXX_t q)
{
return root(p) == root(q);
}
// O(lg* n)
/**
* Returns the number of disjoint sets in this UnionFind.
*
* @return See above.
*/
uintXX_t size() const
{
return set_count;
}
private:
uintXX_t set_count;
// array of subtree counts
std::vector<uintXX_t> count;
// array of tree indices
std::vector<uintXX_t> parent;
uintXX_t root(uintXX_t p)
{
auto r = p;
@ -64,21 +98,4 @@ public:
return r;
}
uintXX_t size() const
{
return n;
}
private:
allocator alloc;
const uintXX_t N;
uintXX_t n;
// array of subtree counts
uintXX_t* count;
// array of tree indices
uintXX_t* parent;
};

88
tests/unit/union_find.cpp Normal file
View File

@ -0,0 +1,88 @@
#include <stdlib.h>
#include <iostream>
#include "gtest/gtest.h"
#include "data_structures/union_find/union_find.hpp"
void _expect_fully(UnionFind<> &uf, bool connected, int from=0, int to=-1) {
if (to == -1)
to = uf.size();
for (int i = from ; i < to ; i++)
for (int j = from ; j < to ; j++)
if (i != j)
EXPECT_EQ(uf.find(i, j), connected);
}
TEST(UnionFindTest, InitialSizeTest) {
for (int i = 0 ; i < 10 ; i ++) {
UnionFind<> uf(i);
EXPECT_EQ(i, uf.size());
}
}
TEST(UnionFindTest, ModifiedSizeTest) {
UnionFind<> uf(10);
EXPECT_EQ(10, uf.size());
uf.connect(0, 0);
EXPECT_EQ(10, uf.size());
uf.connect(0, 1);
EXPECT_EQ(9, uf.size());
uf.connect(2, 3);
EXPECT_EQ(8, uf.size());
uf.connect(0, 2);
EXPECT_EQ(7, uf.size());
uf.connect(1, 3);
EXPECT_EQ(7, uf.size());
}
TEST(UnionFindTest, Disconectivity) {
UnionFind<> uf(10);
_expect_fully(uf, false);
}
TEST(UnionFindTest, ConnectivityAlongChain) {
UnionFind<> uf(10);
for (unsigned int i = 1 ; i < uf.size() ; i++)
uf.connect(i - 1, i);
_expect_fully(uf, true);
}
TEST(UnionFindTest, ConnectivityOnTree) {
UnionFind<> uf(10);
_expect_fully(uf, false);
uf.connect(0, 1);
uf.connect(0, 2);
_expect_fully(uf, true, 0, 3);
_expect_fully(uf, false, 2);
uf.connect(2, 3);
_expect_fully(uf, true, 0, 4);
_expect_fully(uf, false, 3);
}
TEST(UnionFindTest, DisjointChains) {
UnionFind<> uf(30);
for (int i = 0 ; i < 30 ; i++)
uf.connect(i, i % 10 == 0 ? i : i - 1);
for (int i = 0 ; i < 30 ; i++)
for (int j = 0 ; j < 30 ; j++)
EXPECT_EQ(uf.find(i, j), (j - (j % 10)) == (i - (i % 10)));
}
int main(int argc, char **argv)
{
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}