diff --git a/src/data_structures/bitset/static_bitset.hpp b/src/data_structures/bitset/static_bitset.hpp new file mode 100644 index 000000000..8f869c06c --- /dev/null +++ b/src/data_structures/bitset/static_bitset.hpp @@ -0,0 +1,99 @@ +#pragma once + +#include <algorithm> +#include <bitset> +#include <iostream> +#include <string> + +#include "utils/assert.hpp" + +/** + * Bitset data structure with a number of bits provided in constructor. + * @tparam TStore type of underlying bit storage: int32, int64, char, etc. + */ +template <typename TStore> +class Bitset { + public: + /** + * Create bitset. + * @param sz size of bitset + */ + Bitset(size_t sz) : block_size_(8 * sizeof(TStore)) { + if (sz % block_size_ != 0) sz += block_size_; + blocks_.resize(sz / block_size_); + } + /** + * Set bit to one. + * @param idx position of bit. + */ + void Set(int idx) { + debug_assert(idx >= 0, "Invalid bit location."); + debug_assert(idx < static_cast<int64_t>(blocks_.size()) * block_size_, + "Invalid bit location."); + int bucket = idx / block_size_; + blocks_[bucket] |= TStore(1) << idx % block_size_; + } + /** + * Return bit at position. + * @param idx position of bit. + * @return 1/0. + */ + bool At(int idx) const { + debug_assert(idx >= 0, "Invalid bit location."); + debug_assert(idx < static_cast<int64_t>(blocks_.size()) * block_size_, + "Invalid bit location."); + int bucket = idx / block_size_; + return (blocks_[bucket] >> (idx % block_size_)) & 1; + } + /** + * Intersect two bitsets + * @param other bitset. + * @return intersection. + */ + Bitset<TStore> Intersect(const Bitset<TStore> &other) const { + debug_assert(this->blocks_.size() == other.blocks_.size(), + "Bitsets are not of equal size."); + Bitset<TStore> ret(this->blocks_.size() * this->block_size_); + for (int i = 0; i < (int)this->blocks_.size(); ++i) { + ret.blocks_[i] = this->blocks_[i] & other.blocks_[i]; + continue; + } + return ret; + } + /** + * Positions of bits set to 1. + * @return positions of bits set to 1. + */ + std::vector<int> Ones() const { + std::vector<int> ret; + int ret_idx = 0; + for (auto x : blocks_) { + while (x) { + auto pos = CountTrailingZeroes(x & -x); + x -= x & -x; + ret.push_back(ret_idx + pos); + } + ret_idx += block_size_; + } + return ret; + } + + private: + /** + * Calculate number of trailing zeroes in a binary number, usually a power + * of two. + * @return number of trailing zeroes. + */ + size_t CountTrailingZeroes(TStore v) const { + size_t ret = 0; + while (v >> 32) ret += 32, v >>= 32; + if (v >> 16) ret += 16, v >>= 16; + if (v >> 8) ret += 8, v >>= 8; + if (v >> 4) ret += 4, v >>= 4; + if (v >> 2) ret += 2, v >>= 2; + if (v >> 1) ret += 1, v >>= 1; + return ret; + } + std::vector<TStore> blocks_; + const size_t block_size_; +}; diff --git a/src/query/plan_template_cpp b/src/query/plan_template_cpp index 1b39dc044..79b1e6e59 100644 --- a/src/query/plan_template_cpp +++ b/src/query/plan_template_cpp @@ -1,6 +1,7 @@ #include <iostream> #include <string> +#include "data_structures/bitset/static_bitset.hpp" #include "communication/bolt/v1/serialization/record_stream.hpp" #include "io/network/socket.hpp" #include "query/backend/cpp/typed_value.hpp" diff --git a/tests/integration/hardcoded_query/clique.hpp b/tests/integration/hardcoded_query/clique.hpp index 08e8c6107..295b2684d 100644 --- a/tests/integration/hardcoded_query/clique.hpp +++ b/tests/integration/hardcoded_query/clique.hpp @@ -3,6 +3,7 @@ #include <iostream> #include <string> +#include "data_structures/bitset/static_bitset.hpp" #include "query/backend/cpp/typed_value.hpp" #include "query/plan_interface.hpp" #include "storage/edge_accessor.hpp" @@ -25,95 +26,6 @@ using std::endl; // s1.score+s2.score+s3.score+s4.score ORDER BY // s1.score+s2.score+s3.score+s4.score DESC LIMIT 10 -template <typename TStore> -/** - * Bitset data structure with a number of bits provided in constructor. - * @tparam type of underlying bit storage: int32, int64, char, etc. - */ -class Bitset { - public: - /** - * Create bitset. - * @param sz size of bitset - */ - Bitset(size_t sz) : block_size_(8 * sizeof(TStore)) { - if (sz % block_size_ != 0) sz += block_size_; - blocks_.resize(sz / block_size_); - } - /** - * Set bit to one. - * @param idx position of bit. - */ - void Set(int idx) { - debug_assert(idx < static_cast<int64_t>(blocks_.size()) * block_size_, - "Invalid bit location."); - int bucket = idx / block_size_; - blocks_[bucket] |= TStore(1) << idx % block_size_; - } - /** - * Return bit at position. - * @param idx position of bit. - * @return 1/0. - */ - bool At(int idx) const { - debug_assert(idx < static_cast<int64_t>(blocks_.size()) * block_size_, - "Invalid bit location."); - int bucket = idx / block_size_; - return (blocks_[bucket] >> (idx % block_size_)) & 1; - } - /** - * Intersect two bitsets - * @param other bitset. - * @return intersection. - */ - Bitset<TStore> Intersect(const Bitset<TStore> &other) { - debug_assert(this->blocks_.size() == other.blocks_.size(), - "Bitsets are not of equal size."); - Bitset<TStore> ret(this->blocks_.size() * this->block_size_); - for (int i = 0; i < (int)this->blocks_.size(); ++i) { - ret.blocks_[i] = this->blocks_[i] & other.blocks_[i]; - continue; - } - return ret; - } - /** - * Positions of bits set to 1. - * @return positions of bits set to 1. - */ - std::vector<int> Ones() const { - std::vector<int> ret; - int ret_idx = 0; - for (auto x : blocks_) { - while (x) { - auto pos = CountTrailingZeroes(x & -x); - x -= x & -x; - ret.push_back(ret_idx + pos); - } - ret_idx += block_size_; - } - return ret; - } - - private: - /** - * Calculate number of trailing zeroes in a binary number, usually a power - * of two. - * @return number of trailing zeroes. - */ - size_t CountTrailingZeroes(TStore v) const { - size_t ret = 0; - while (v >> 32) ret += 32, v >>= 32; - if (v >> 16) ret += 16, v >>= 16; - if (v >> 8) ret += 8, v >>= 8; - if (v >> 4) ret += 4, v >>= 4; - if (v >> 2) ret += 2, v >>= 2; - if (v >> 1) ret += 1, v >>= 1; - return ret; - } - std::vector<TStore> blocks_; - const size_t block_size_; -}; - enum CliqueQuery { SCORE_AND_LIMIT, FIND_ALL }; bool run_general_query(GraphDbAccessor &db_accessor, diff --git a/tests/integration/hardcoded_query/using.hpp b/tests/integration/hardcoded_query/using.hpp index 832c4ff7d..a7335f4c4 100644 --- a/tests/integration/hardcoded_query/using.hpp +++ b/tests/integration/hardcoded_query/using.hpp @@ -10,3 +10,4 @@ using Stream = bolt::RecordStream<io::network::Socket>; #include "../stream/print_record_stream.hpp" using Stream = PrintRecordStream; #endif +#include "data_structures/bitset/static_bitset.hpp" diff --git a/tests/unit/static_bitset.cpp b/tests/unit/static_bitset.cpp new file mode 100644 index 000000000..830d6c8ee --- /dev/null +++ b/tests/unit/static_bitset.cpp @@ -0,0 +1,80 @@ +#include "data_structures/bitset/static_bitset.hpp" +#include <gmock/gmock.h> +#include <vector> +#include "gtest/gtest-spi.h" +#include "gtest/gtest.h" + +using testing::UnorderedElementsAreArray; + +TEST(StaticBitset, Intersection) { + const int n = 50; + Bitset<int64_t> bitset(n); + Bitset<int64_t> bitset2(n); + std::vector<int> V; + std::vector<int> V2; + for (int i = 0; i < n / 2; ++i) { + const int pos = rand() % n; + bitset.Set(pos); + V.push_back(pos); + } + for (int i = 0; i < n / 2; ++i) { + const int pos = rand() % n; + bitset2.Set(pos); + V2.push_back(pos); + } + Bitset<int64_t> intersected = bitset.Intersect(bitset); + sort(V.begin(), V.end()); + V.resize(unique(V.begin(), V.end()) - V.begin()); + EXPECT_THAT(V, UnorderedElementsAreArray(intersected.Ones())); + + sort(V2.begin(), V2.end()); + V2.resize(unique(V2.begin(), V2.end()) - V2.begin()); + + std::vector<int> V3; + set_intersection(V.begin(), V.end(), V2.begin(), V2.end(), back_inserter(V3)); + Bitset<int64_t> intersected_two = bitset.Intersect(bitset2); + EXPECT_THAT(V3, UnorderedElementsAreArray(intersected_two.Ones())); +} + +TEST(StaticBitset, BasicFunctionality) { + const int n = 50; + Bitset<int64_t> bitset(n); + std::vector<int> V; + for (int i = 0; i < n / 2; ++i) { + const int pos = rand() % n; + bitset.Set(pos); + V.push_back(pos); + } + sort(V.begin(), V.end()); + V.resize(unique(V.begin(), V.end()) - V.begin()); + EXPECT_THAT(V, UnorderedElementsAreArray(bitset.Ones())); +} + +TEST(StaticBitset, SetAndReadBit) { + const int n = 50; + Bitset<char> bitset(n); + bitset.Set(4); + EXPECT_EQ(bitset.At(4), true); + EXPECT_EQ(bitset.At(3), false); +} + +TEST(StaticBitset, SetOutOfRange) { + const int n = 50; + Bitset<char> bitset(n); + EXPECT_DEATH(bitset.Set(-1), "Invalid bit location."); + EXPECT_DEATH(bitset.Set(150), "Invalid bit location."); + bitset.Set(49); +} + +TEST(StaticBitset, AtOutOfRange) { + const int n = 50; + Bitset<char> bitset(n); + bitset.Set(33); + EXPECT_DEATH(bitset.At(150), "Invalid bit location."); + EXPECT_DEATH(bitset.At(-1), "Invalid bit location."); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}