From 3fffa17b0de3740b6ed571bc4976fe3662e43b4e Mon Sep 17 00:00:00 2001
From: Kruno Tomola Fabro <krunotf@memgraph.io>
Date: Wed, 10 Aug 2016 20:02:54 +0100
Subject: [PATCH] Started adding remove methods. Discovered bugs in
 HashMultiMap and after fixing it add method is now little slower but find
 method is medium faster. So it turned out good.

---
 include/mvcc/edge_record.hpp                |   2 +-
 include/utils/eq_wrapper.hpp                |  20 +
 poc/astar.cpp                               |  13 +-
 src/data_structures/map/rh_hashmap.hpp      | 105 ++++-
 src/data_structures/map/rh_hashmultimap.hpp | 417 ++++++++++++++++----
 src/storage/edge_accessor.hpp               |  11 -
 src/utils/option_ptr.hpp                    |   2 +-
 tests/unit/rh_hashmap.cpp                   |  41 +-
 tests/unit/rh_hashmultimap.cpp              | 137 ++++++-
 9 files changed, 629 insertions(+), 119 deletions(-)
 create mode 100644 include/utils/eq_wrapper.hpp

diff --git a/include/mvcc/edge_record.hpp b/include/mvcc/edge_record.hpp
index d2e98e1a4..e04cde31b 100644
--- a/include/mvcc/edge_record.hpp
+++ b/include/mvcc/edge_record.hpp
@@ -21,7 +21,7 @@ public:
     {
     }
 
-    VertexRecord *&get_key() { return from_v; }
+    VertexRecord *&get_key() { return this->from_v; }
 
     auto from() const { return this->from_v; }
 
diff --git a/include/utils/eq_wrapper.hpp b/include/utils/eq_wrapper.hpp
new file mode 100644
index 000000000..3b802e2a0
--- /dev/null
+++ b/include/utils/eq_wrapper.hpp
@@ -0,0 +1,20 @@
+#pragma once
+
+template <class T>
+class EqWrapper
+{
+public:
+    EqWrapper(T t) : t(t) {}
+
+    friend bool operator==(const EqWrapper &a, const EqWrapper &b)
+    {
+        return a.t == b.t;
+    }
+
+    friend bool operator!=(const EqWrapper &a, const EqWrapper &b)
+    {
+        return !(a == b);
+    }
+
+    T t;
+};
diff --git a/poc/astar.cpp b/poc/astar.cpp
index b53b03ad7..6a4258e1d 100644
--- a/poc/astar.cpp
+++ b/poc/astar.cpp
@@ -8,6 +8,7 @@
 #include <string>
 #include <vector>
 
+#include "data_structures/map/rh_hashmap.hpp"
 #include "database/db.hpp"
 
 using namespace std;
@@ -33,6 +34,8 @@ public:
           record(record)
     {
     }
+
+    VertexRecord *&get_key() { return record; }
 };
 
 // class Iterator : public Crtp<Iterator>
@@ -144,6 +147,7 @@ void a_star(Db &db, int64_t sys_id_start, uint max_depth, EdgeFilter e_filter[],
             int limit)
 {
     auto &t = db.tx_engine.begin();
+    RhHashMap<VertexRecord *, Node> visited;
 
     auto cmp = [](Node *left, Node *right) { return left->cost > right->cost; };
     std::priority_queue<Node *, std::vector<Node *>, decltype(cmp)> queue(cmp);
@@ -155,6 +159,11 @@ void a_star(Db &db, int64_t sys_id_start, uint max_depth, EdgeFilter e_filter[],
     do {
         auto now = queue.top();
         queue.pop();
+
+        // if(!visited.insert(now)){
+        //     continue;
+        // }
+
         if (max_depth <= now->depth) {
             found_result(now);
             count++;
@@ -176,8 +185,8 @@ void a_star(Db &db, int64_t sys_id_start, uint max_depth, EdgeFilter e_filter[],
             }
         }
     } while (!queue.empty());
-
-    // GUBI SE MEMORIJA JER SE NODOVI NEBRISU
+    std::cout << "Found: " << count << " resoults\n";
+    // TODO: GUBI SE MEMORIJA JER SE NODOVI NEBRISU
 
     t.commit();
 }
diff --git a/src/data_structures/map/rh_hashmap.hpp b/src/data_structures/map/rh_hashmap.hpp
index e902395de..75b45cfa1 100644
--- a/src/data_structures/map/rh_hashmap.hpp
+++ b/src/data_structures/map/rh_hashmap.hpp
@@ -1,11 +1,12 @@
 #include "utils/crtp.hpp"
 #include "utils/option_ptr.hpp"
+#include <functional>
 
 // HashMap with RobinHood collision resolution policy.
 // Single threaded.
 // Entrys are saved as pointers alligned to 8B.
 // Entrys must know thers key.
-// D must have method K& get_key()
+// D must have method const K & get_key()
 // K must be comparable with ==.
 // HashMap behaves as if it isn't owner of entrys.
 template <class K, class D, size_t init_size_pow2 = 2>
@@ -20,7 +21,7 @@ private:
 
         Combined(D *data, size_t off)
         {
-            assert((data & 0x7) == 0 && off < 8);
+            // assert(((((size_t)(data)) & 0x7) == 0) && off < 8);
             this->data = ((size_t)data) | off;
         }
 
@@ -28,6 +29,17 @@ private:
 
         size_t off() { return data & 0x7; }
 
+        void decrement_off() { data--; }
+
+        bool increment_off()
+        {
+            if (off() < 7) {
+                data++;
+                return true;
+            }
+            return false;
+        }
+
         D *ptr() { return (D *)(data & (~(0x7))); }
 
     private:
@@ -169,7 +181,7 @@ public:
     void increase_size()
     {
         if (capacity == 0) {
-            assert(array == nullptr && count == 0);
+            // assert(array == nullptr && count == 0);
             size_t new_size = 1 << init_size_pow2;
             init_array(new_size);
             return;
@@ -189,7 +201,9 @@ public:
         free(a);
     }
 
-    OptionPtr<D> find(const K &key)
+    bool contains(const K &key) { return find(key).is_present(); }
+
+    OptionPtr<D> find(const K key)
     {
         size_t mask = this->mask();
         size_t now = index(key, mask);
@@ -217,9 +231,55 @@ public:
 
     // Inserts element. Returns true if element wasn't in the map.
     bool insert(D *data)
+    {
+        if (count < capacity) {
+            size_t mask = this->mask();
+            auto key = std::ref(data->get_key());
+            size_t now = index(key, mask);
+            size_t off = 0;
+            size_t border = 8 <= capacity ? 8 : capacity;
+            while (off < border) {
+                Combined other = array[now];
+                if (other.valid()) {
+                    auto other_off = other.off();
+                    if (other_off == off && key == other.ptr()->get_key()) {
+                        return false;
+
+                    } else if (other_off < off) { // Other is rich
+                        array[now] = Combined(data, off);
+
+                        while (other.increment_off()) {
+                            now = (now + 1) & mask;
+                            auto tmp = array[now];
+                            array[now] = other;
+                            other = tmp;
+                            if (!other.valid()) {
+                                count++;
+                                return true;
+                            }
+                        }
+                        data = other.ptr();
+                        break; // Cant insert removed element
+                    } // Else other has equal or greater offset, so he is poor.
+                } else {
+                    array[now] = Combined(data, off);
+                    count++;
+                    return true;
+                }
+
+                off++;
+                now = (now + 1) & mask;
+            }
+        }
+
+        increase_size();
+        return insert(data);
+    }
+
+    // Removes element. Returns removed element if it existed.
+    OptionPtr<D> remove(const K &key)
     {
         size_t mask = this->mask();
-        auto key = data->get_key();
         size_t now = index(key, mask);
         size_t off = 0;
         size_t border = 8 <= capacity ? 8 : capacity;
@@ -227,29 +287,36 @@ public:
             Combined other = array[now];
             if (other.valid()) {
                 auto other_off = other.off();
-                if (other_off == off && key == other.ptr()->get_key()) {
-                    return false;
+                auto other_ptr = other.ptr();
+                if (other_off == off &&
+                    key == other_ptr->get_key()) { // Found it
+
+                    auto before = now;
+                    do {
+                        other.decrement_off(); // This is alright even for off=0
+                                               // on found element because it
+                                               // wont be seen.
+                        array[before] = other;
+                        before = now;
+                        now = (now + 1) & mask;
+                        other = array[now];
+                    } while (other.valid() && other.off() > 0);
+
+                    array[before] = Combined();
+                    count--;
+                    return OptionPtr<D>(other_ptr);
 
                 } else if (other_off < off) { // Other is rich
-                    array[now] = Combined(data, off);
-
-                    // Hacked reusing of function
-                    data = other.ptr();
-                    key = data->get_key();
-                    off = other_off;
+                    break;
                 } // Else other has equal or greater offset, so he is poor.
             } else {
-                array[now] = Combined(data, off);
-                count++;
-                return true;
+                break;
             }
 
             off++;
             now = (now + 1) & mask;
         }
-
-        increase_size();
-        return insert(data);
+        return OptionPtr<D>();
     }
 
     void clear()
diff --git a/src/data_structures/map/rh_hashmultimap.hpp b/src/data_structures/map/rh_hashmultimap.hpp
index 3d1103171..111c21a3b 100644
--- a/src/data_structures/map/rh_hashmultimap.hpp
+++ b/src/data_structures/map/rh_hashmultimap.hpp
@@ -1,10 +1,11 @@
 #include "utils/crtp.hpp"
 #include "utils/option_ptr.hpp"
 #include <cstring>
+#include <functional>
 
 // HashMultiMap with RobinHood collision resolution policy.
 // Single threaded.
-// Entrys are saved as pointers alligned to 8B.
+// Entrys are POINTERS alligned to 8B.
 // Entrys must know thers key.
 // D must have method K& get_key()
 // K must be comparable with ==.
@@ -29,6 +30,17 @@ private:
 
         size_t off() { return data & 0x7; }
 
+        void decrement_off() { data--; }
+
+        bool increment_off()
+        {
+            if (off() < 7) {
+                data++;
+                return true;
+            }
+            return false;
+        }
+
         D *ptr() { return (D *)(data & (~(0x7))); }
 
     private:
@@ -89,7 +101,7 @@ private:
                     advanced = index = ~((size_t)0);
                     break;
                 }
-                index = advanced & mask;
+                index = (index + 1) & mask;
             } while (!map->array[index].valid());
 
             return this->derived();
@@ -224,122 +236,360 @@ public:
 
     bool contains(const K &key) { return find(key) != end(); }
 
-    Iterator find(const K &key)
+    Iterator find(const K &key_in)
     {
-        size_t mask = this->mask();
-        size_t now = index(key, mask);
-        size_t off = 0;
-
-        bool bef_init = false;
-        size_t before_off;
-        auto before_key = key;
-
-        size_t border = 8 <= capacity ? 8 : capacity;
-        while (off < border) {
+        if (count > 0) {
+            auto key = std::ref(key_in);
+            size_t mask = this->mask();
+            size_t now = index(key, mask);
+            size_t off = 0;
+            size_t checked = 0;
+            size_t border = 8 <= capacity ? 8 : capacity;
             Combined other = array[now];
-            if (other.valid()) {
+            while (other.valid() && off < border) {
                 auto other_off = other.off();
-                auto other_key = other.ptr()->get_key();
-                if (other_off == off && key == other_key) {
+                if (other_off == off && key == other.ptr()->get_key()) {
                     return Iterator(this, now);
 
                 } else if (other_off < off) { // Other is rich
                     break;
 
-                } else if (bef_init) { // Else other has equal or greater
-                                       // offset, so he is poor.
-                    if (before_off == other_off && before_key == other_key) {
-                        if (count == capacity) {
+                } else { // Else other has equal or greater
+                         // offset, so he is poor.
+                    auto other_key = other.ptr()->get_key();
+                    do {
+                        now = (now + 1) & mask;
+                        other = array[now];
+                        checked++;
+                        if (checked >= count) { // Reason is possibility of map
+                                                // full of same values.
                             break;
                         }
-                        // Proceed
-                    } else {
-                        before_off = other_off;
-                        before_key = other_key;
-                        off++;
-                    }
-                } else {
-                    bef_init = true;
-                    before_off = other_off;
-                    before_key = other_key;
+                    } while (other.valid() && other.off() == other_off &&
+                             other.ptr()->get_key() == other_key);
                     off++;
                 }
-
-            } else {
-                break;
             }
-
-            now = (now + 1) & mask;
         }
+
         return end();
     }
 
     // Inserts element with the given key.
-    void add(K &key, D *data)
+    void add(const K &key_in, D *data)
     {
-        assert(key == data->get_key());
+        assert(key_in == data->get_key());
 
-        size_t mask = this->mask();
-        size_t now = index(key, mask);
-        size_t off = 0;
+        if (count < capacity) {
+            auto key = std::ref(key_in);
+            size_t mask = this->mask();
+            size_t now = index(key, mask);
+            size_t start = now;
+            size_t off = 0;
+            size_t border = 8 <= capacity ? 8 : capacity;
+            bool multi = false;
 
-        bool bef_init = false;
-        size_t before_off;
-        auto before_key = key;
-
-        size_t border = 8 <= capacity ? 8 : capacity;
-        while (off < border) {
             Combined other = array[now];
-            if (other.valid()) {
-                auto other_off = other.off();
-                auto other_key = other.ptr()->get_key();
-                if (other_off == off && key == other_key) {
-                    // Proceed
+            while (off < border) {
+                if (other.valid()) {
+                    auto other_off = other.off();
+                    if (other_off == off &&
+                        other.ptr()->get_key() == key) { // Found the
+                        do {                             // same
+                            now = (now + 1) & mask;
+                            other = array[now];
+                            if (!other.valid()) {
+                                set(now, data, off);
+                                return;
+                            }
+                            other_off = other.off();
+                        } while (other_off == off &&
+                                 other.ptr()->get_key() == key);
+                        multi = true;
+                    } else if (other_off > off ||
+                               other_poor(other, mask, start,
+                                          now)) { // Other is poor or the same
+                        auto other_key = other.ptr()->get_key();
 
-                } else if (other_off < off) { // Other is rich
-                    array[now] = Combined(data, off);
-
-                    // Hacked reusing of function
-                    data = other.ptr();
-                    key = other_key;
-                    off = other_off;
-
-                    off++;
-                } else if (bef_init) { // Else other has equal or greater
-                                       // offset, so he is poor.
-                    if (before_off == other_off && before_key == other_key) {
-                        if (count == capacity) {
-                            break;
-                        }
-                        // Proceed
-                    } else {
-                        before_off = other_off;
-                        before_key = other_key;
+                        do {
+                            now = (now + 1) & mask;
+                            other = array[now];
+                        } while (other.valid() && other.off() == other_off &&
+                                 other.ptr()->get_key() == other_key);
                         off++;
+                        continue;
                     }
+
+                    array[now] = Combined(data, off);
+                    auto start = now;
+                    while (adjust_off(other, mask, start, now, multi)) {
+                        now = (now + 1) & mask;
+                        auto tmp = array[now];
+                        array[now] = other;
+                        other = tmp;
+                        if (!other.valid()) {
+                            count++;
+                            return;
+                        }
+                    }
+                    data = other.ptr();
+                    break; // Cant insert removed element
                 } else {
-                    bef_init = true;
-                    before_off = other_off;
-                    before_key = other_key;
-                    off++;
+                    set(now, data, off);
+                    return;
                 }
-
-            } else {
-                array[now] = Combined(data, off);
-                count++;
-                return;
             }
-
-            now = (now + 1) & mask;
         }
 
         increase_size();
         add(data);
     }
 
+private:
+    void set(size_t now, D *data, size_t off)
+    {
+        array[now] = Combined(data, off);
+        count++;
+    }
+
+    bool adjust_off(Combined &com, size_t mask, size_t start, size_t now,
+                    bool multi)
+    {
+        if (com.off() == 0) {
+            com.increment_off();
+            return true;
+        }
+        size_t cin = index(com.ptr()->get_key(), mask);
+        if ((start <= now && (cin < start || cin > now)) ||
+            (now < start && cin < start && cin > now)) {
+            return multi || com.increment_off();
+        }
+        auto a = array[cin];
+        auto b = array[(cin + 1) & mask];
+        return (a.off() == b.off() &&
+                a.ptr()->get_key() == b.ptr()->get_key()) ||
+               com.increment_off();
+    }
+
+    bool other_poor(Combined other, size_t mask, size_t start, size_t now)
+    {
+        auto cin = index(other.ptr()->get_key(), mask);
+        return (start <= now && (cin <= start || cin > now)) ||
+               (now < start && cin <= start && cin > now);
+    }
+    //     void skip(size_t &now, size_t mask)
+    //     {
+    //         Combined start = array[now];
+    //         size_t end = now;
+    //         auto off = start.off();
+    //         auto key = start.ptr()->get_key();
+    //         do {
+    //             now = (now + 1) & mask;
+    //             start = array[now];
+    //         } while (start.valid() && start.off() == off &&
+    //                  start.ptr()->get_key() == key && now != end);
+    //     }
+    //
+    //     void _insert(size_t now, Combined com)
+    //     {
+    //         Combined other = array[now];
+    //         array[now] = com;
+    //         if (other.valid()) {
+    //             _add(now, other.off(), other.ptr());
+    //         } else {
+    //             count++;
+    //         }
+    //     }
+    //
+    //     void _add(size_t now, size_t off, Data *data) {
+    //         size_t mask = this->mask();
+    //         auto key = std::ref(data->get_key());
+    //         size_t border = 8 <= capacity ? 8 : capacity;
+    //
+    //         skip()
+    //
+    //         while(off<border){
+    //             Combined other = array[now];
+    //             if (other.valid()) {
+    //                 _add(now, other.off(), other.ptr());
+    //             } else {
+    //                 _insert(now, RhHashMultiMap::Combined com)
+    //             }
+    //         }
+    //
+    //         increase_size();
+    //         add(data);
+    //     }
+
+public:
+    // // Inserts element with the given key.
+    // void add(const K &key_in, D *data)
+    // {
+    //     assert(key_in == data->get_key());
+    //
+    //     if (count < capacity) {
+    //         auto key = std::ref(key_in);
+    //         size_t mask = this->mask();
+    //         size_t now = index(key, mask);
+    //         size_t off = 0;
+    //
+    //         bool bef_init = false;
+    //         size_t before_off;
+    //         auto before_key = std::ref(key);
+    //         bool found_it = false;
+    //
+    //         size_t border = 8 <= capacity ? 8 : capacity;
+    //         while (off < border) {
+    //             Combined other = array[now];
+    //             if (other.valid()) {
+    //                 auto other_off = other.off();
+    //                 auto other_key = std::ref<const
+    //                 K>(other.ptr()->get_key());
+    //                 if (other_off == off && key == other_key) {
+    //                     found_it = true;
+    //                     // Proceed
+    //
+    //                 } else if (other_off < off || found_it) { // Other is
+    //                 rich
+    //                                                           // or after
+    //                                                           list
+    //                                                           // of my keys
+    //                     assert(other_off <= off);
+    //
+    //                     array[now] = Combined(data, off);
+    //                     // add(other.ptr()->get_key(), other.ptr());
+    //                     // return;
+    //
+    //                     // Hacked reusing of function
+    //                     before_off = off;
+    //                     before_key = key;
+    //                     data = other.ptr();
+    //                     key = other_key;
+    //                     off = other_off;
+    //
+    //                     if (found_it) { // Offset isn't increased
+    //                         found_it = false;
+    //                     } else {
+    //                         off++;
+    //                     }
+    //                 } else if (bef_init) { // Else other has equal or greater
+    //                                        // offset, so he is poor.
+    //                     if (before_off == other_off &&
+    //                         before_key == other_key) {
+    //                         if (count == capacity) {
+    //                             break;
+    //                         }
+    //                         // Proceed
+    //                     } else {
+    //                         before_off = other_off;
+    //                         before_key = other_key;
+    //                         off++;
+    //                     }
+    //                 } else {
+    //                     bef_init = true;
+    //                     before_off = other_off;
+    //                     before_key = other_key;
+    //                     off++;
+    //                 }
+    //
+    //             } else {
+    //                 array[now] = Combined(data, off);
+    //                 count++;
+    //                 return;
+    //             }
+    //
+    //             now = (now + 1) & mask;
+    //         }
+    //     }
+    //
+    //     increase_size();
+    //     add(data);
+    // }
+
     // Inserts element.
     void add(D *data) { add(data->get_key(), data); }
 
+    // Removes element. Returns removed element if it existed. It doesn't
+    // specify which element from same key group will be removed.
+    OptionPtr<D> remove(const K &key_in)
+    {
+        // auto key = std::ref(key_in);
+        // size_t mask = this->mask();
+        // size_t now = index(key, mask);
+        // size_t off = 0;
+        //
+        // bool bef_init = false;
+        // size_t before_off;
+        // auto before_key = key;
+        // bool found_it = false;
+        //
+        // size_t border = 8 <= capacity ? 8 : capacity;
+        // while (off < border) {
+        //     Combined other = array[now];
+        //     if (other.valid()) {
+        //         auto other_off = other.off();
+        //         auto other_ptr = other.ptr();
+        //         auto other_key = std::ref<const K>(other_ptr->get_key());
+        //         if (other_off == off && key == other_key) { // Found it
+        //             found_it = true;
+        //
+        //         } else if (found_it) { // Found first element after last
+        //         element
+        //             // for remove.
+        //             auto before = before_index(now, mask);
+        //             auto ret = OptionPtr<D>(array[before].ptr());
+        //             std::cout << "<-" << ret.get()->get_key() << "\n";
+        //             while (other.valid() && other.off() > 0) {
+        //                 std::cout << "<>" << other.ptr()->get_key() << "\n";
+        //                 other.decrement_off();
+        //                 array[before] = other;
+        //                 before = now;
+        //                 now = (now + 1) & mask;
+        //                 other = array[now];
+        //             }
+        //
+        //             array[before] = Combined();
+        //             count--;
+        //             return ret;
+        //         } else if (other_off < off) { // Other is rich
+        //             break;
+        //
+        //         } else if (bef_init) { // Else other has equal or greater
+        //                                // offset, so he is poor.
+        //             if (before_off == other_off && before_key == other_key) {
+        //                 if (count == capacity) { // I am stuck.
+        //                     break;
+        //                 }
+        //                 // Proceed
+        //             } else {
+        //                 before_off = other_off;
+        //                 before_key = other_key;
+        //                 off++;
+        //             }
+        //         } else {
+        //             bef_init = true;
+        //             before_off = other_off;
+        //             before_key = other_key;
+        //             off++;
+        //         }
+        //
+        //     } else if (found_it) { // Found empty space after last element
+        //     for
+        //                            // remove.
+        //         auto before = before_index(now, mask);
+        //         auto ret = OptionPtr<D>(array[before].ptr());
+        //         array[before] = Combined();
+        //
+        //         return ret;
+        //     } else {
+        //         break;
+        //     }
+        //
+        //     now = (now + 1) & mask;
+        // }
+        return OptionPtr<D>();
+    }
+
     void clear()
     {
         free(array);
@@ -351,6 +601,11 @@ public:
     size_t size() const { return count; }
 
 private:
+    size_t before_index(size_t now, size_t mask)
+    {
+        return (now - 1) & mask; // THIS IS VALID
+    }
+
     size_t index(const K &key, size_t mask) const
     {
         return hash(std::hash<K>()(key)) & mask;
diff --git a/src/storage/edge_accessor.hpp b/src/storage/edge_accessor.hpp
index bef5fdaee..88b977c91 100644
--- a/src/storage/edge_accessor.hpp
+++ b/src/storage/edge_accessor.hpp
@@ -27,17 +27,6 @@ public:
         return edge_type_ref_t(*this->record->data.edge_type);
     }
 
-    // TODO: VertexAccessor
-    // void from(VertexRecord *vertex_record)
-    // {
-    //     this->record->data.from = vertex_record;
-    // }
-    //
-    // void to(VertexRecord *vertex_record)
-    // {
-    //     this->record->data.to = vertex_record;
-    // }
-
     auto from() const { return this->vlist->from(); }
 
     auto to() const { return this->vlist->to(); }
diff --git a/src/utils/option_ptr.hpp b/src/utils/option_ptr.hpp
index 1c312216a..b3fb0f550 100644
--- a/src/utils/option_ptr.hpp
+++ b/src/utils/option_ptr.hpp
@@ -1,4 +1,4 @@
-
+#pragma once
 
 template <class T>
 class OptionPtr
diff --git a/tests/unit/rh_hashmap.cpp b/tests/unit/rh_hashmap.cpp
index 7a40f1857..a0d66baf1 100644
--- a/tests/unit/rh_hashmap.cpp
+++ b/tests/unit/rh_hashmap.cpp
@@ -13,9 +13,11 @@ private:
 public:
     Data(int key) : key(key) {}
 
-    int &get_key() { return key; }
+    const int &get_key() const { return key; }
 };
 
+void cross_validate(RhHashMap<int, Data> &map, std::map<int, Data *> &s_map);
+
 TEST_CASE("Robin hood hashmap basic functionality")
 {
     RhHashMap<int, Data> map;
@@ -25,6 +27,16 @@ TEST_CASE("Robin hood hashmap basic functionality")
     REQUIRE(map.size() == 1);
 }
 
+TEST_CASE("Robin hood hashmap remove functionality")
+{
+    RhHashMap<int, Data> map;
+
+    REQUIRE(map.insert(new Data(0)));
+    REQUIRE(map.remove(0).is_present());
+    REQUIRE(map.size() == 0);
+    REQUIRE(!map.find(0).is_present());
+}
+
 TEST_CASE("Robin hood hashmap insert/get check")
 {
     RhHashMap<int, Data> map;
@@ -97,6 +109,33 @@ TEST_CASE("Robin hood hashmap checked")
         }
     }
 
+    cross_validate(map, s_map);
+}
+
+TEST_CASE("Robin hood hashmap checked with remove")
+{
+    RhHashMap<int, Data> map;
+    std::map<int, Data *> s_map;
+
+    for (int i = 0; i < 1280; i++) {
+        int key = std::rand() % 100;
+        auto data = new Data(key);
+        if (map.insert(data)) {
+            REQUIRE(s_map.find(key) == s_map.end());
+            s_map[key] = data;
+            cross_validate(map, s_map);
+        } else {
+            REQUIRE(map.remove(key).is_present());
+            REQUIRE(s_map.erase(key) == 1);
+            cross_validate(map, s_map);
+        }
+    }
+
+    cross_validate(map, s_map);
+}
+
+void cross_validate(RhHashMap<int, Data> &map, std::map<int, Data *> &s_map)
+{
     for (auto e : map) {
         REQUIRE(s_map.find(e->get_key()) != s_map.end());
     }
diff --git a/tests/unit/rh_hashmultimap.cpp b/tests/unit/rh_hashmultimap.cpp
index 6bb916d49..3c19deb7e 100644
--- a/tests/unit/rh_hashmultimap.cpp
+++ b/tests/unit/rh_hashmultimap.cpp
@@ -13,9 +13,15 @@ private:
 public:
     Data(int key) : key(key) {}
 
-    int &get_key() { return key; }
+    const int &get_key() { return key; }
 };
 
+void cross_validate(RhHashMultiMap<int, Data> &map,
+                    std::multimap<int, Data *> &s_map);
+
+void cross_validate_weak(RhHashMultiMap<int, Data> &map,
+                         std::multimap<int, Data *> &s_map);
+
 TEST_CASE("Robin hood hashmultimap basic functionality")
 {
     RhHashMultiMap<int, Data> map;
@@ -36,6 +42,19 @@ TEST_CASE("Robin hood hashmultimap insert/get check")
     REQUIRE(*map.find(0) == ptr0);
 }
 
+// TEST_CASE("Robin hood hasmultihmap remove functionality")
+// {
+//     RhHashMultiMap<int, Data> map;
+//
+//     REQUIRE(map.find(0) == map.end());
+//     auto ptr0 = new Data(0);
+//     map.add(ptr0);
+//     REQUIRE(map.find(0) != map.end());
+//     REQUIRE(*map.find(0) == ptr0);
+//     REQUIRE(map.remove(0).get() == ptr0);
+//     REQUIRE(map.find(0) == map.end());
+// }
+
 TEST_CASE("Robin hood hashmultimap double insert")
 {
     RhHashMultiMap<int, Data> map;
@@ -107,6 +126,61 @@ TEST_CASE("Robin hood hashmultimap checked")
         map.add(data);
         s_map.insert(std::pair<int, Data *>(key, data));
     }
+    cross_validate(map, s_map);
+}
+
+TEST_CASE("Robin hood hashmultimap checked rand")
+{
+    RhHashMultiMap<int, Data> map;
+    std::multimap<int, Data *> s_map;
+    std::srand(std::time(0));
+    for (int i = 0; i < 164308; i++) {
+        int key = (std::rand() % 10000) << 3;
+
+        auto data = new Data(key);
+        map.add(data);
+        s_map.insert(std::pair<int, Data *>(key, data));
+    }
+    cross_validate(map, s_map);
+}
+
+// TEST_CASE("Robin hood hashmultimap with remove checked")
+// {
+//     RhHashMultiMap<int, Data> map;
+//     std::multimap<int, Data *> s_map;
+//
+//     for (int i = 0; i < 2638; i++) {
+//         int key = (std::rand() % 100) << 3;
+//         if ((std::rand() % 3) == 0) {
+//             std::cout << "Remove: " << key << "\n";
+//             auto removed = map.remove(key);
+//             // auto it = s_map.find(key);
+//             if (removed.is_present()) {
+//                 // while (it != s_map.end() && it->second != removed.get()) {
+//                 //     it++;
+//                 // }
+//                 // REQUIRE(it != s_map.end());
+//                 // s_map.erase(it);
+//                 // cross_validate(map, s_map);
+//
+//             } else {
+//                 // REQUIRE(it == s_map.end());
+//             }
+//         } else {
+//             std::cout << "Insert: " << key << "\n";
+//             auto data = new Data(key);
+//             map.add(data);
+//             s_map.insert(std::pair<int, Data *>(key, data));
+//             cross_validate(map, s_map);
+//         }
+//     }
+//
+//     cross_validate_weak(map, s_map);
+// }
+
+void cross_validate(RhHashMultiMap<int, Data> &map,
+                    std::multimap<int, Data *> &s_map)
+{
 
     for (auto e : map) {
         auto it = s_map.find(e->get_key());
@@ -114,15 +188,72 @@ TEST_CASE("Robin hood hashmultimap checked")
         while (it != s_map.end() && it->second != e) {
             it++;
         }
-        REQUIRE(it->second == e);
+        REQUIRE(it != s_map.end());
     }
 
     for (auto e : s_map) {
         auto it = map.find(e.first);
+        // std::cout << "s_map: " << e.first << "\n";
 
         while (it != map.end() && *it != e.second) {
             it++;
         }
-        REQUIRE(e.second == *it);
+        REQUIRE(it != map.end());
+    }
+}
+
+void cross_validate_weak(RhHashMultiMap<int, Data> &map,
+                         std::multimap<int, Data *> &s_map)
+{
+    int count = 0;
+    int key = 0;
+    for (auto e : map) {
+        if (e->get_key() == key) {
+            count++;
+        } else {
+            auto it = s_map.find(key);
+
+            while (it != s_map.end() && it->first == key) {
+                it++;
+                count--;
+            }
+            REQUIRE(count == 0);
+            key = e->get_key();
+            count = 1;
+        }
+    }
+    {
+        auto it = s_map.find(key);
+
+        while (it != s_map.end() && it->first == key) {
+            it++;
+            count--;
+        }
+        REQUIRE(count == 0);
+    }
+
+    for (auto e : s_map) {
+        if (e.first == key) {
+            count++;
+        } else {
+            auto it = map.find(key);
+
+            while (it != map.end() && it->get_key() == key) {
+                it++;
+                count--;
+            }
+            REQUIRE(count == 0);
+            key = e.first;
+            count = 1;
+        }
+    }
+    {
+        auto it = map.find(key);
+
+        while (it != map.end() && it->get_key() == key) {
+            it++;
+            count--;
+        }
+        REQUIRE(count == 0);
     }
 }