diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 95b4ccc57..e6f0a718a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -9,6 +9,7 @@ set(memgraph_src_files communication/messaging/protocol.cpp communication/rpc/rpc.cpp data_structures/concurrent/skiplist_gc.cpp + database/counters.cpp database/graph_db.cpp database/graph_db_accessor.cpp database/graph_db_config.cpp diff --git a/src/database/counters.cpp b/src/database/counters.cpp new file mode 100644 index 000000000..799634568 --- /dev/null +++ b/src/database/counters.cpp @@ -0,0 +1,72 @@ +#include "database/counters.hpp" + +#include "boost/archive/binary_iarchive.hpp" +#include "boost/archive/binary_oarchive.hpp" +#include "boost/serialization/export.hpp" +#include "boost/serialization/utility.hpp" + +namespace database { + +const std::string kCountersRpc = "CountersRpc"; +const auto kCountersRpcTimeout = 300ms; + +RPC_SINGLE_MEMBER_MESSAGE(CountersGetReq, std::string); +RPC_SINGLE_MEMBER_MESSAGE(CountersGetRes, int64_t); +using CountersGetRpc = + communication::rpc::RequestResponse; + +using CountersSetReqData = std::pair; +RPC_SINGLE_MEMBER_MESSAGE(CountersSetReq, CountersSetReqData); +RPC_NO_MEMBER_MESSAGE(CountersSetRes); +using CountersSetRpc = + communication::rpc::RequestResponse; + +int64_t SingleNodeCounters::Get(const std::string &name) { + return counters_.access() + .emplace(name, std::make_tuple(name), std::make_tuple(0)) + .first->second.fetch_add(1); +} + +void SingleNodeCounters::Set(const std::string &name, int64_t value) { + auto name_counter_pair = counters_.access().emplace( + name, std::make_tuple(name), std::make_tuple(value)); + if (!name_counter_pair.second) name_counter_pair.first->second.store(value); +} + +MasterCounters::MasterCounters(communication::messaging::System &system) + : rpc_server_(system, kCountersRpc) { + rpc_server_.Register([this](const CountersGetReq &req) { + return std::make_unique(Get(req.member)); + }); + rpc_server_.Register([this](const CountersSetReq &req) { + Set(req.member.first, req.member.second); + return std::make_unique(); + }); +} + +void MasterCounters::Start() { rpc_server_.Start(); } +void MasterCounters::Shutdown() { rpc_server_.Shutdown(); } + +WorkerCounters::WorkerCounters( + communication::messaging::System &system, + const io::network::NetworkEndpoint &master_endpoint) + : rpc_client_(system, master_endpoint, kCountersRpc) {} + +int64_t WorkerCounters::Get(const std::string &name) { + auto response = rpc_client_.Call(kCountersRpcTimeout, name); + CHECK(response) << "CountersGetRpc - failed to get response from master"; + return response->member; +} + +void WorkerCounters::Set(const std::string &name, int64_t value) { + auto response = rpc_client_.Call( + kCountersRpcTimeout, CountersSetReqData{name, value}); + CHECK(response) << "CountersSetRpc - failed to get response from master"; +} + +} // namespace database + +BOOST_CLASS_EXPORT(database::CountersGetReq); +BOOST_CLASS_EXPORT(database::CountersGetRes); +BOOST_CLASS_EXPORT(database::CountersSetReq); +BOOST_CLASS_EXPORT(database::CountersSetRes); diff --git a/src/database/counters.hpp b/src/database/counters.hpp new file mode 100644 index 000000000..82e3533dc --- /dev/null +++ b/src/database/counters.hpp @@ -0,0 +1,69 @@ +#pragma once + +#include +#include +#include + +#include "communication/messaging/distributed.hpp" +#include "communication/rpc/rpc.hpp" +#include "data_structures/concurrent/concurrent_map.hpp" +#include "utils/rpc_pimp.hpp" + +namespace database { + +/** A set of counter that are guaranteed to produce unique, consecutive values + * on each call. */ +class Counters { + public: + virtual ~Counters() {} + + /** + * Returns the current value of the counter with the given name, and + * increments that counter. If the counter with the given name does not exist, + * a new counter is created and this function returns 0. + */ + virtual int64_t Get(const std::string &name) = 0; + + /** + * Sets the counter with the given name to the given value. Returns nothing. + * If the counter with the given name does not exist, a new counter is created + * and set to the given value. + */ + virtual void Set(const std::string &name, int64_t values) = 0; +}; + +/** Implementation for the single-node memgraph */ +class SingleNodeCounters : public Counters { + public: + int64_t Get(const std::string &name) override; + void Set(const std::string &name, int64_t value) override; + + private: + ConcurrentMap> counters_; +}; + +/** Implementation for distributed master. */ +class MasterCounters : public SingleNodeCounters { + public: + MasterCounters(communication::messaging::System &system); + void Start(); + void Shutdown(); + + private: + communication::rpc::Server rpc_server_; +}; + +/** Implementation for distributed worker. */ +class WorkerCounters : public Counters { + public: + WorkerCounters(communication::messaging::System &system, + const io::network::NetworkEndpoint &master_endpoint); + + int64_t Get(const std::string &name) override; + void Set(const std::string &name, int64_t value) override; + + private: + communication::rpc::Client rpc_client_; +}; + +} // namespace database diff --git a/src/database/graph_db.cpp b/src/database/graph_db.cpp index a34f83af1..1490b05c1 100644 --- a/src/database/graph_db.cpp +++ b/src/database/graph_db.cpp @@ -23,6 +23,7 @@ namespace fs = std::experimental::filesystem; GraphDb::GraphDb(Config config) : GraphDb(config, 0) { tx_engine_ = std::make_unique(&wal_); + counters_ = std::make_unique(); INIT_MAPPERS(storage::SingleNodeConcurrentIdMapper); Start(); } @@ -33,6 +34,9 @@ GraphDb::GraphDb(communication::messaging::System &system, auto tx_engine = std::make_unique(&wal_); tx_engine->StartServer(system); tx_engine_ = std::move(tx_engine); + auto counters = std::make_unique(system); + counters->Start(); + counters_ = std::move(counters); INIT_MAPPERS(storage::MasterConcurrentIdMapper, system); get_endpoint_ = [&master](int worker_id) { return master.GetEndpoint(worker_id); @@ -45,6 +49,8 @@ GraphDb::GraphDb(communication::messaging::System &system, int worker_id, Endpoint master_endpoint, Config config) : GraphDb(config, worker_id) { tx_engine_ = std::make_unique(system, master_endpoint); + counters_ = + std::make_unique(system, master_endpoint); INIT_MAPPERS(storage::WorkerConcurrentIdMapper, system, master_endpoint); get_endpoint_ = [&worker](int worker_id) { return worker.GetEndpoint(worker_id); diff --git a/src/database/graph_db.hpp b/src/database/graph_db.hpp index 97eed334b..473017942 100644 --- a/src/database/graph_db.hpp +++ b/src/database/graph_db.hpp @@ -8,6 +8,7 @@ #include "data_structures/concurrent/concurrent_map.hpp" #include "data_structures/concurrent/concurrent_set.hpp" +#include "database/counters.hpp" #include "database/graph_db_datatypes.hpp" #include "database/indexes/key_index.hpp" #include "database/indexes/label_property_index.hpp" @@ -167,7 +168,7 @@ class GraphDb { Scheduler transaction_killer_; // DB level global counters, used in the "counter" function. - ConcurrentMap> counters_; + std::unique_ptr counters_; // Returns Endpoint info for worker ID. Different implementation in master vs. // worker. Unused in single-node version. diff --git a/src/database/graph_db_accessor.cpp b/src/database/graph_db_accessor.cpp index c11e3608d..3db3a8571 100644 --- a/src/database/graph_db_accessor.cpp +++ b/src/database/graph_db_accessor.cpp @@ -407,15 +407,11 @@ const std::string &GraphDbAccessor::PropertyName( } int64_t GraphDbAccessor::Counter(const std::string &name) { - return db_.counters_.access() - .emplace(name, std::make_tuple(name), std::make_tuple(0)) - .first->second.fetch_add(1); + return db_.counters_->Get(name); } void GraphDbAccessor::CounterSet(const std::string &name, int64_t value) { - auto name_counter_pair = db_.counters_.access().emplace( - name, std::make_tuple(name), std::make_tuple(value)); - if (!name_counter_pair.second) name_counter_pair.first->second.store(value); + db_.counters_->Set(name, value); } std::vector GraphDbAccessor::IndexInfo() const { diff --git a/src/storage/concurrent_id_mapper.hpp b/src/storage/concurrent_id_mapper.hpp index 9503ddc54..3b7b486a4 100644 --- a/src/storage/concurrent_id_mapper.hpp +++ b/src/storage/concurrent_id_mapper.hpp @@ -15,9 +15,10 @@ namespace storage { template class ConcurrentIdMapper { public: + virtual ~ConcurrentIdMapper() {} + virtual TId value_to_id(const std::string &value) = 0; virtual const std::string &id_to_value(const TId &id) = 0; - virtual ~ConcurrentIdMapper() {} }; } // namespace storage diff --git a/tests/unit/counters.cpp b/tests/unit/counters.cpp new file mode 100644 index 000000000..dbeee9168 --- /dev/null +++ b/tests/unit/counters.cpp @@ -0,0 +1,34 @@ +#include "gtest/gtest.h" + +#include "communication/messaging/distributed.hpp" +#include "database/counters.hpp" + +const std::string kLocal = "127.0.0.1"; + +TEST(CountersDistributed, All) { + communication::messaging::System master_sys(kLocal, 0); + database::MasterCounters master(master_sys); + master.Start(); + + communication::messaging::System w1_sys(kLocal, 0); + database::WorkerCounters w1(w1_sys, master_sys.endpoint()); + + communication::messaging::System w2_sys(kLocal, 0); + database::WorkerCounters w2(w2_sys, master_sys.endpoint()); + + EXPECT_EQ(w1.Get("a"), 0); + EXPECT_EQ(w1.Get("a"), 1); + EXPECT_EQ(w2.Get("a"), 2); + EXPECT_EQ(w1.Get("a"), 3); + EXPECT_EQ(master.Get("a"), 4); + + EXPECT_EQ(master.Get("b"), 0); + EXPECT_EQ(w2.Get("b"), 1); + w1.Set("b", 42); + EXPECT_EQ(w2.Get("b"), 42); + + w2_sys.Shutdown(); + w1_sys.Shutdown(); + master.Shutdown(); + master_sys.Shutdown(); +}