Add support for async replication (#41)
* Add thread pool * Define async replication * Expose replication state * Rename TransactionHandler to ReplicaStream
This commit is contained in:
parent
bc0c944910
commit
03cc568e39
@ -7,12 +7,13 @@ ReplicationClient::ReplicationClient(std::string name,
|
||||
NameIdMapper *name_id_mapper,
|
||||
Config::Items items,
|
||||
const io::network::Endpoint &endpoint,
|
||||
bool use_ssl)
|
||||
bool use_ssl, const ReplicationMode mode)
|
||||
: name_(std::move(name)),
|
||||
name_id_mapper_(name_id_mapper),
|
||||
items_(items),
|
||||
rpc_context_(use_ssl),
|
||||
rpc_client_(endpoint, &rpc_context_) {}
|
||||
rpc_client_(endpoint, &rpc_context_),
|
||||
mode_(mode) {}
|
||||
|
||||
void ReplicationClient::TransferSnapshot(const std::filesystem::path &path) {
|
||||
auto stream{rpc_client_.Stream<SnapshotRpc>()};
|
||||
@ -34,32 +35,76 @@ void ReplicationClient::TransferWalFiles(
|
||||
stream.AwaitResponse();
|
||||
}
|
||||
|
||||
////// TransactionHandler //////
|
||||
ReplicationClient::TransactionHandler::TransactionHandler(
|
||||
ReplicationClient *self)
|
||||
bool ReplicationClient::StartTransactionReplication() {
|
||||
std::unique_lock guard(client_lock_);
|
||||
const auto status = replica_state_.load();
|
||||
switch (status) {
|
||||
case ReplicaState::RECOVERY:
|
||||
DLOG(INFO) << "Replica " << name_ << " is behind MAIN instance";
|
||||
return false;
|
||||
case ReplicaState::REPLICATING:
|
||||
replica_state_.store(ReplicaState::RECOVERY);
|
||||
return false;
|
||||
case ReplicaState::READY:
|
||||
CHECK(!replica_stream_);
|
||||
replica_stream_.emplace(ReplicaStream{this});
|
||||
replica_state_.store(ReplicaState::REPLICATING);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
void ReplicationClient::IfStreamingTransaction(
|
||||
const std::function<void(ReplicaStream &handler)> &callback) {
|
||||
if (replica_stream_) {
|
||||
callback(*replica_stream_);
|
||||
}
|
||||
}
|
||||
|
||||
void ReplicationClient::FinalizeTransactionReplication() {
|
||||
if (mode_ == ReplicationMode::ASYNC) {
|
||||
thread_pool_.AddTask(
|
||||
[this] { this->FinalizeTransactionReplicationInternal(); });
|
||||
} else {
|
||||
FinalizeTransactionReplicationInternal();
|
||||
}
|
||||
}
|
||||
|
||||
void ReplicationClient::FinalizeTransactionReplicationInternal() {
|
||||
if (replica_stream_) {
|
||||
replica_stream_->Finalize();
|
||||
replica_stream_.reset();
|
||||
}
|
||||
|
||||
std::unique_lock guard(client_lock_);
|
||||
if (replica_state_.load() == ReplicaState::REPLICATING) {
|
||||
replica_state_.store(ReplicaState::READY);
|
||||
}
|
||||
}
|
||||
////// ReplicaStream //////
|
||||
ReplicationClient::ReplicaStream::ReplicaStream(ReplicationClient *self)
|
||||
: self_(self), stream_(self_->rpc_client_.Stream<AppendDeltasRpc>()) {}
|
||||
|
||||
void ReplicationClient::TransactionHandler::AppendDelta(
|
||||
void ReplicationClient::ReplicaStream::AppendDelta(
|
||||
const Delta &delta, const Vertex &vertex, uint64_t final_commit_timestamp) {
|
||||
Encoder encoder(stream_.GetBuilder());
|
||||
EncodeDelta(&encoder, self_->name_id_mapper_, self_->items_, delta, vertex,
|
||||
final_commit_timestamp);
|
||||
}
|
||||
|
||||
void ReplicationClient::TransactionHandler::AppendDelta(
|
||||
void ReplicationClient::ReplicaStream::AppendDelta(
|
||||
const Delta &delta, const Edge &edge, uint64_t final_commit_timestamp) {
|
||||
Encoder encoder(stream_.GetBuilder());
|
||||
EncodeDelta(&encoder, self_->name_id_mapper_, delta, edge,
|
||||
final_commit_timestamp);
|
||||
}
|
||||
|
||||
void ReplicationClient::TransactionHandler::AppendTransactionEnd(
|
||||
void ReplicationClient::ReplicaStream::AppendTransactionEnd(
|
||||
uint64_t final_commit_timestamp) {
|
||||
Encoder encoder(stream_.GetBuilder());
|
||||
EncodeTransactionEnd(&encoder, final_commit_timestamp);
|
||||
}
|
||||
|
||||
void ReplicationClient::TransactionHandler::AppendOperation(
|
||||
void ReplicationClient::ReplicaStream::AppendOperation(
|
||||
durability::StorageGlobalOperation operation, LabelId label,
|
||||
const std::set<PropertyId> &properties, uint64_t timestamp) {
|
||||
Encoder encoder(stream_.GetBuilder());
|
||||
@ -67,9 +112,7 @@ void ReplicationClient::TransactionHandler::AppendOperation(
|
||||
properties, timestamp);
|
||||
}
|
||||
|
||||
void ReplicationClient::TransactionHandler::Finalize() {
|
||||
stream_.AwaitResponse();
|
||||
}
|
||||
void ReplicationClient::ReplicaStream::Finalize() { stream_.AwaitResponse(); }
|
||||
|
||||
////// CurrentWalHandler //////
|
||||
ReplicationClient::CurrentWalHandler::CurrentWalHandler(ReplicationClient *self)
|
||||
|
@ -1,5 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <thread>
|
||||
|
||||
#include "rpc/client.hpp"
|
||||
#include "storage/v2/config.hpp"
|
||||
#include "storage/v2/delta.hpp"
|
||||
@ -11,20 +15,27 @@
|
||||
#include "storage/v2/replication/rpc.hpp"
|
||||
#include "storage/v2/replication/serialization.hpp"
|
||||
#include "utils/file.hpp"
|
||||
#include "utils/spin_lock.hpp"
|
||||
#include "utils/synchronized.hpp"
|
||||
#include "utils/thread_pool.hpp"
|
||||
|
||||
namespace storage::replication {
|
||||
|
||||
enum class ReplicationMode : std::uint8_t { SYNC, ASYNC };
|
||||
|
||||
enum class ReplicaState : std::uint8_t { READY, REPLICATING, RECOVERY };
|
||||
|
||||
class ReplicationClient {
|
||||
public:
|
||||
ReplicationClient(std::string name, NameIdMapper *name_id_mapper,
|
||||
Config::Items items, const io::network::Endpoint &endpoint,
|
||||
bool use_ssl);
|
||||
bool use_ssl, ReplicationMode mode);
|
||||
|
||||
// Handler used for transfering the current transaction.
|
||||
class TransactionHandler {
|
||||
class ReplicaStream {
|
||||
private:
|
||||
friend class ReplicationClient;
|
||||
explicit TransactionHandler(ReplicationClient *self);
|
||||
explicit ReplicaStream(ReplicationClient *self);
|
||||
|
||||
public:
|
||||
/// @throw rpc::RpcFailedException
|
||||
@ -43,20 +54,14 @@ class ReplicationClient {
|
||||
LabelId label, const std::set<PropertyId> &properties,
|
||||
uint64_t timestamp);
|
||||
|
||||
private:
|
||||
/// @throw rpc::RpcFailedException
|
||||
void Finalize();
|
||||
|
||||
private:
|
||||
ReplicationClient *self_;
|
||||
rpc::Client::StreamHandler<AppendDeltasRpc> stream_;
|
||||
};
|
||||
|
||||
TransactionHandler ReplicateTransaction() { return TransactionHandler(this); }
|
||||
|
||||
// Transfer the snapshot file.
|
||||
// @param path Path of the snapshot file.
|
||||
void TransferSnapshot(const std::filesystem::path &path);
|
||||
|
||||
// Handler for transfering the current WAL file whose data is
|
||||
// contained in the internal buffer and the file.
|
||||
class CurrentWalHandler {
|
||||
@ -81,6 +86,22 @@ class ReplicationClient {
|
||||
rpc::Client::StreamHandler<WalFilesRpc> stream_;
|
||||
};
|
||||
|
||||
bool StartTransactionReplication();
|
||||
|
||||
// Replication clients can be removed at any point
|
||||
// so to avoid any complexity of checking if the client was removed whenever
|
||||
// we want to send part of transaction and to avoid adding some GC logic this
|
||||
// function will run a callback if, after previously callling
|
||||
// StartTransactionReplication, stream is created.
|
||||
void IfStreamingTransaction(
|
||||
const std::function<void(ReplicaStream &handler)> &callback);
|
||||
|
||||
void FinalizeTransactionReplication();
|
||||
|
||||
// Transfer the snapshot file.
|
||||
// @param path Path of the snapshot file.
|
||||
void TransferSnapshot(const std::filesystem::path &path);
|
||||
|
||||
CurrentWalHandler TransferCurrentWalFile() { return CurrentWalHandler{this}; }
|
||||
|
||||
// Transfer the WAL files
|
||||
@ -88,12 +109,23 @@ class ReplicationClient {
|
||||
|
||||
const auto &Name() const { return name_; }
|
||||
|
||||
auto State() const { return replica_state_.load(); }
|
||||
|
||||
private:
|
||||
void FinalizeTransactionReplicationInternal();
|
||||
|
||||
std::string name_;
|
||||
NameIdMapper *name_id_mapper_;
|
||||
Config::Items items_;
|
||||
communication::ClientContext rpc_context_;
|
||||
rpc::Client rpc_client_;
|
||||
|
||||
std::optional<ReplicaStream> replica_stream_;
|
||||
ReplicationMode mode_{ReplicationMode::SYNC};
|
||||
|
||||
utils::SpinLock client_lock_;
|
||||
utils::ThreadPool thread_pool_{1};
|
||||
std::atomic<ReplicaState> replica_state_{ReplicaState::READY};
|
||||
};
|
||||
|
||||
} // namespace storage::replication
|
||||
|
@ -410,9 +410,9 @@ Storage::Storage(Config config)
|
||||
// For testing purposes until we can define the instance type from
|
||||
// a query.
|
||||
if (FLAGS_main) {
|
||||
SetReplicationState<ReplicationState::MAIN>();
|
||||
SetReplicationRole<ReplicationRole::MAIN>();
|
||||
} else if (FLAGS_replica) {
|
||||
SetReplicationState<ReplicationState::REPLICA>(
|
||||
SetReplicationRole<ReplicationRole::REPLICA>(
|
||||
io::network::Endpoint{"127.0.0.1", 1000});
|
||||
}
|
||||
#endif
|
||||
@ -425,7 +425,6 @@ Storage::~Storage() {
|
||||
#ifdef MG_ENTERPRISE
|
||||
{
|
||||
// Clear replication data
|
||||
std::unique_lock<utils::RWLock> replication_guard(replication_lock_);
|
||||
replication_server_.reset();
|
||||
replication_clients_.WithLock([&](auto &clients) { clients.clear(); });
|
||||
}
|
||||
@ -1354,7 +1353,7 @@ Transaction Storage::CreateTransaction() {
|
||||
std::lock_guard<utils::SpinLock> guard(engine_lock_);
|
||||
transaction_id = transaction_id_++;
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (replication_state_.load() != ReplicationState::REPLICA) {
|
||||
if (replication_role_.load() != ReplicationRole::REPLICA) {
|
||||
start_timestamp = timestamp_++;
|
||||
} else {
|
||||
start_timestamp = timestamp_;
|
||||
@ -1639,15 +1638,14 @@ void Storage::AppendToWal(const Transaction &transaction,
|
||||
// We need to keep this lock because handler takes a pointer to the client
|
||||
// from which it was created
|
||||
std::shared_lock<utils::RWLock> replication_guard(replication_lock_);
|
||||
std::list<replication::ReplicationClient::TransactionHandler> streams;
|
||||
if (replication_state_.load() == ReplicationState::MAIN) {
|
||||
if (replication_role_.load() == ReplicationRole::MAIN) {
|
||||
replication_clients_.WithLock([&](auto &clients) {
|
||||
try {
|
||||
std::transform(
|
||||
clients.begin(), clients.end(), std::back_inserter(streams),
|
||||
[](auto &client) { return client.ReplicateTransaction(); });
|
||||
} catch (const rpc::RpcFailedException &) {
|
||||
LOG(FATAL) << "Couldn't replicate data!";
|
||||
for (auto &client : clients) {
|
||||
try {
|
||||
client.StartTransactionReplication();
|
||||
} catch (const rpc::RpcFailedException &) {
|
||||
LOG(FATAL) << "Couldn't replicate data!";
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@ -1669,13 +1667,17 @@ void Storage::AppendToWal(const Transaction &transaction,
|
||||
if (filter(delta->action)) {
|
||||
wal_file_->AppendDelta(*delta, parent, final_commit_timestamp);
|
||||
#ifdef MG_ENTERPRISE
|
||||
try {
|
||||
for (auto &stream : streams) {
|
||||
stream.AppendDelta(*delta, parent, final_commit_timestamp);
|
||||
replication_clients_.WithLock([&](auto &clients) {
|
||||
for (auto &client : clients) {
|
||||
try {
|
||||
client.IfStreamingTransaction([&](auto &stream) {
|
||||
stream.AppendDelta(*delta, parent, final_commit_timestamp);
|
||||
});
|
||||
} catch (const rpc::RpcFailedException &) {
|
||||
LOG(FATAL) << "Couldn't replicate data!";
|
||||
}
|
||||
}
|
||||
} catch (const rpc::RpcFailedException &) {
|
||||
LOG(FATAL) << "Couldn't replicate data!";
|
||||
}
|
||||
});
|
||||
#endif
|
||||
}
|
||||
auto prev = delta->prev.Get();
|
||||
@ -1811,14 +1813,18 @@ void Storage::AppendToWal(const Transaction &transaction,
|
||||
FinalizeWalFile();
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
try {
|
||||
for (auto &stream : streams) {
|
||||
stream.AppendTransactionEnd(final_commit_timestamp);
|
||||
stream.Finalize();
|
||||
replication_clients_.WithLock([&](auto &clients) {
|
||||
for (auto &client : clients) {
|
||||
try {
|
||||
client.IfStreamingTransaction([&](auto &stream) {
|
||||
stream.AppendTransactionEnd(final_commit_timestamp);
|
||||
});
|
||||
client.FinalizeTransactionReplication();
|
||||
} catch (const rpc::RpcFailedException &) {
|
||||
LOG(FATAL) << "Couldn't replicate data!";
|
||||
}
|
||||
}
|
||||
} catch (const rpc::RpcFailedException &) {
|
||||
LOG(FATAL) << "Couldn't replicate data!";
|
||||
}
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -1831,14 +1837,16 @@ void Storage::AppendToWal(durability::StorageGlobalOperation operation,
|
||||
#ifdef MG_ENTERPRISE
|
||||
{
|
||||
std::shared_lock<utils::RWLock> replication_guard(replication_lock_);
|
||||
if (replication_state_.load() == ReplicationState::MAIN) {
|
||||
if (replication_role_.load() == ReplicationRole::MAIN) {
|
||||
replication_clients_.WithLock([&](auto &clients) {
|
||||
for (auto &client : clients) {
|
||||
auto stream = client.ReplicateTransaction();
|
||||
try {
|
||||
stream.AppendOperation(operation, label, properties,
|
||||
final_commit_timestamp);
|
||||
stream.Finalize();
|
||||
client.StartTransactionReplication();
|
||||
client.IfStreamingTransaction([&](auto &stream) {
|
||||
stream.AppendOperation(operation, label, properties,
|
||||
final_commit_timestamp);
|
||||
});
|
||||
client.FinalizeTransactionReplication();
|
||||
} catch (const rpc::RpcFailedException &) {
|
||||
LOG(FATAL) << "Couldn't replicate data!";
|
||||
}
|
||||
@ -1852,7 +1860,7 @@ void Storage::AppendToWal(durability::StorageGlobalOperation operation,
|
||||
|
||||
void Storage::CreateSnapshot() {
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (replication_state_.load() != ReplicationState::MAIN) {
|
||||
if (replication_role_.load() != ReplicationRole::MAIN) {
|
||||
LOG(WARNING) << "Snapshots are disabled for replicas!";
|
||||
return;
|
||||
}
|
||||
@ -2363,10 +2371,12 @@ void Storage::ConfigureReplica(io::network::Endpoint endpoint) {
|
||||
replication_server_->rpc_server->Start();
|
||||
}
|
||||
|
||||
void Storage::RegisterReplica(std::string name,
|
||||
io::network::Endpoint endpoint) {
|
||||
void Storage::RegisterReplica(
|
||||
std::string name, io::network::Endpoint endpoint,
|
||||
const replication::ReplicationMode replication_mode) {
|
||||
std::shared_lock guard(replication_lock_);
|
||||
CHECK(replication_state_.load() == ReplicationState::MAIN)
|
||||
// TODO (antonio2368): This shouldn't stop the main instance
|
||||
CHECK(replication_role_.load() == ReplicationRole::MAIN)
|
||||
<< "Only main instance can register a replica!";
|
||||
|
||||
// We can safely add new elements to the list because it doesn't validate
|
||||
@ -2377,7 +2387,7 @@ void Storage::RegisterReplica(std::string name,
|
||||
throw utils::BasicException("Replica with a same name already exists!");
|
||||
}
|
||||
clients.emplace_back(std::move(name), &name_id_mapper_, config_.items,
|
||||
endpoint, false);
|
||||
endpoint, false, replication_mode);
|
||||
return clients.back();
|
||||
});
|
||||
|
||||
@ -2454,15 +2464,29 @@ void Storage::RegisterReplica(std::string name,
|
||||
}
|
||||
}
|
||||
|
||||
void Storage::UnregisterReplica(const std::string &name) {
|
||||
void Storage::UnregisterReplica(const std::string_view name) {
|
||||
std::unique_lock<utils::RWLock> replication_guard(replication_lock_);
|
||||
CHECK(replication_state_.load() == ReplicationState::MAIN)
|
||||
CHECK(replication_role_.load() == ReplicationRole::MAIN)
|
||||
<< "Only main instance can unregister a replica!";
|
||||
replication_clients_.WithLock([&](auto &clients) {
|
||||
clients.remove_if(
|
||||
[&](const auto &client) { return client.Name() == name; });
|
||||
});
|
||||
}
|
||||
|
||||
std::optional<replication::ReplicaState> Storage::ReplicaState(
|
||||
const std::string_view name) {
|
||||
return replication_clients_.WithLock(
|
||||
[&](auto &clients) -> std::optional<replication::ReplicaState> {
|
||||
const auto client_it = std::find_if(
|
||||
clients.cbegin(), clients.cend(),
|
||||
[name](auto &client) { return client.Name() == name; });
|
||||
if (client_it == clients.cend()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return client_it->State();
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace storage
|
||||
|
@ -169,7 +169,7 @@ struct StorageInfo {
|
||||
};
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
enum class ReplicationState : uint8_t { MAIN, REPLICA };
|
||||
enum class ReplicationRole : uint8_t { MAIN, REPLICA };
|
||||
#endif
|
||||
|
||||
class Storage final {
|
||||
@ -406,26 +406,30 @@ class Storage final {
|
||||
StorageInfo GetInfo() const;
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
template <ReplicationState state, typename... Args>
|
||||
void SetReplicationState(Args &&... args) {
|
||||
if (replication_state_.load() == state) {
|
||||
template <ReplicationRole role, typename... Args>
|
||||
void SetReplicationRole(Args &&... args) {
|
||||
if (replication_role_.load() == role) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_lock<utils::RWLock> replication_guard(replication_lock_);
|
||||
|
||||
if constexpr (state == ReplicationState::REPLICA) {
|
||||
if constexpr (role == ReplicationRole::REPLICA) {
|
||||
ConfigureReplica(std::forward<Args>(args)...);
|
||||
} else if (state == ReplicationState::MAIN) {
|
||||
} else if (role == ReplicationRole::MAIN) {
|
||||
// Main instance does not need replication server
|
||||
replication_server_.reset();
|
||||
}
|
||||
|
||||
replication_state_.store(state);
|
||||
replication_role_.store(role);
|
||||
}
|
||||
|
||||
void RegisterReplica(std::string name, io::network::Endpoint endpoint);
|
||||
void UnregisterReplica(const std::string &name);
|
||||
void RegisterReplica(std::string name, io::network::Endpoint endpoint,
|
||||
replication::ReplicationMode replication_mode =
|
||||
replication::ReplicationMode::SYNC);
|
||||
void UnregisterReplica(std::string_view name);
|
||||
|
||||
std::optional<replication::ReplicaState> ReplicaState(std::string_view name);
|
||||
#endif
|
||||
|
||||
private:
|
||||
@ -554,7 +558,7 @@ class Storage final {
|
||||
std::optional<ReplicationServer> replication_server_;
|
||||
ReplicationClientList replication_clients_;
|
||||
|
||||
std::atomic<ReplicationState> replication_state_{ReplicationState::MAIN};
|
||||
std::atomic<ReplicationRole> replication_role_{ReplicationRole::MAIN};
|
||||
#endif
|
||||
};
|
||||
|
||||
|
@ -4,6 +4,7 @@ set(utils_src_files
|
||||
memory.cpp
|
||||
signals.cpp
|
||||
thread.cpp
|
||||
thread_pool.cpp
|
||||
uuid.cpp)
|
||||
|
||||
add_library(mg-utils STATIC ${utils_src_files})
|
||||
|
76
src/utils/thread_pool.cpp
Normal file
76
src/utils/thread_pool.cpp
Normal file
@ -0,0 +1,76 @@
|
||||
#include "utils/thread_pool.hpp"
|
||||
|
||||
namespace utils {
|
||||
|
||||
ThreadPool::ThreadPool(const size_t pool_size) {
|
||||
for (size_t i = 0; i < pool_size; ++i) {
|
||||
thread_pool_.emplace_back(([this] { this->ThreadLoop(); }));
|
||||
}
|
||||
}
|
||||
|
||||
void ThreadPool::AddTask(std::function<void()> new_task) {
|
||||
task_queue_.WithLock([&](auto &queue) {
|
||||
queue.emplace(std::make_unique<TaskSignature>(std::move(new_task)));
|
||||
});
|
||||
queue_cv_.notify_one();
|
||||
}
|
||||
|
||||
void ThreadPool::Shutdown() {
|
||||
terminate_pool_.store(true);
|
||||
queue_cv_.notify_all();
|
||||
|
||||
for (auto &thread : thread_pool_) {
|
||||
if (thread.joinable()) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
thread_pool_.clear();
|
||||
stopped_.store(true);
|
||||
}
|
||||
|
||||
ThreadPool::~ThreadPool() {
|
||||
if (!stopped_.load()) {
|
||||
Shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<ThreadPool::TaskSignature> ThreadPool::PopTask() {
|
||||
return task_queue_.WithLock(
|
||||
[](auto &queue) -> std::unique_ptr<TaskSignature> {
|
||||
if (queue.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto front = std::move(queue.front());
|
||||
queue.pop();
|
||||
return front;
|
||||
});
|
||||
}
|
||||
|
||||
void ThreadPool::ThreadLoop() {
|
||||
std::unique_ptr<TaskSignature> task = PopTask();
|
||||
while (true) {
|
||||
while (task) {
|
||||
if (terminate_pool_.load()) {
|
||||
return;
|
||||
}
|
||||
(*task)();
|
||||
task = PopTask();
|
||||
}
|
||||
|
||||
std::unique_lock guard(pool_lock_);
|
||||
idle_thread_num_.fetch_add(1);
|
||||
queue_cv_.wait(guard, [&] {
|
||||
task = PopTask();
|
||||
return task || terminate_pool_.load();
|
||||
});
|
||||
idle_thread_num_.fetch_sub(1);
|
||||
if (terminate_pool_.load()) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t ThreadPool::IdleThreadNum() const { return idle_thread_num_.load(); }
|
||||
|
||||
} // namespace utils
|
51
src/utils/thread_pool.hpp
Normal file
51
src/utils/thread_pool.hpp
Normal file
@ -0,0 +1,51 @@
|
||||
#pragma once
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <thread>
|
||||
|
||||
#include "utils/spin_lock.hpp"
|
||||
#include "utils/synchronized.hpp"
|
||||
#include "utils/thread.hpp"
|
||||
|
||||
namespace utils {
|
||||
|
||||
class ThreadPool {
|
||||
using TaskSignature = std::function<void()>;
|
||||
|
||||
public:
|
||||
explicit ThreadPool(size_t pool_size);
|
||||
|
||||
void AddTask(std::function<void()> new_task);
|
||||
|
||||
void Shutdown();
|
||||
|
||||
~ThreadPool();
|
||||
|
||||
ThreadPool(const ThreadPool &) = delete;
|
||||
ThreadPool(ThreadPool &&) = delete;
|
||||
ThreadPool &operator=(const ThreadPool &) = delete;
|
||||
ThreadPool &operator=(ThreadPool &&) = delete;
|
||||
|
||||
size_t IdleThreadNum() const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<TaskSignature> PopTask();
|
||||
|
||||
void ThreadLoop();
|
||||
|
||||
std::vector<std::thread> thread_pool_;
|
||||
|
||||
std::atomic<size_t> idle_thread_num_{0};
|
||||
std::atomic<bool> terminate_pool_{false};
|
||||
std::atomic<bool> stopped_{false};
|
||||
utils::Synchronized<std::queue<std::unique_ptr<TaskSignature>>,
|
||||
utils::SpinLock>
|
||||
task_queue_;
|
||||
std::mutex pool_lock_;
|
||||
std::condition_variable queue_cv_;
|
||||
};
|
||||
|
||||
} // namespace utils
|
@ -198,6 +198,8 @@ target_link_libraries(${test_prefix}small_vector mg-utils)
|
||||
add_unit_test(utils_file_locker.cpp)
|
||||
target_link_libraries(${test_prefix}utils_file_locker mg-utils fmt)
|
||||
|
||||
add_unit_test(utils_thread_pool.cpp)
|
||||
target_link_libraries(${test_prefix}utils_thread_pool mg-utils fmt)
|
||||
|
||||
# Test mg-storage-v2
|
||||
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <storage/v2/property_value.hpp>
|
||||
#include <storage/v2/replication/replication.hpp>
|
||||
#include <storage/v2/storage.hpp>
|
||||
|
||||
using testing::UnorderedElementsAre;
|
||||
@ -43,7 +44,7 @@ TEST_F(ReplicationTest, BasicSynchronousReplicationTest) {
|
||||
.snapshot_wal_mode = storage::Config::Durability::SnapshotWalMode::
|
||||
PERIODIC_SNAPSHOT_WITH_WAL,
|
||||
}});
|
||||
replica_store.SetReplicationState<storage::ReplicationState::REPLICA>(
|
||||
replica_store.SetReplicationRole<storage::ReplicationRole::REPLICA>(
|
||||
io::network::Endpoint{"127.0.0.1", 10000});
|
||||
|
||||
main_store.RegisterReplica("REPLICA",
|
||||
@ -280,7 +281,7 @@ TEST_F(ReplicationTest, MultipleSynchronousReplicationTest) {
|
||||
.snapshot_wal_mode = storage::Config::Durability::SnapshotWalMode::
|
||||
PERIODIC_SNAPSHOT_WITH_WAL,
|
||||
}});
|
||||
replica_store1.SetReplicationState<storage::ReplicationState::REPLICA>(
|
||||
replica_store1.SetReplicationRole<storage::ReplicationRole::REPLICA>(
|
||||
io::network::Endpoint{"127.0.0.1", 10000});
|
||||
|
||||
storage::Storage replica_store2(
|
||||
@ -289,7 +290,7 @@ TEST_F(ReplicationTest, MultipleSynchronousReplicationTest) {
|
||||
.snapshot_wal_mode = storage::Config::Durability::SnapshotWalMode::
|
||||
PERIODIC_SNAPSHOT_WITH_WAL,
|
||||
}});
|
||||
replica_store2.SetReplicationState<storage::ReplicationState::REPLICA>(
|
||||
replica_store2.SetReplicationRole<storage::ReplicationRole::REPLICA>(
|
||||
io::network::Endpoint{"127.0.0.1", 20000});
|
||||
|
||||
main_store.RegisterReplica("REPLICA1",
|
||||
@ -426,7 +427,7 @@ TEST_F(ReplicationTest, RecoveryProcess) {
|
||||
storage::Storage replica_store(
|
||||
{.durability = {.storage_directory = replica_storage_directory}});
|
||||
|
||||
replica_store.SetReplicationState<storage::ReplicationState::REPLICA>(
|
||||
replica_store.SetReplicationRole<storage::ReplicationRole::REPLICA>(
|
||||
io::network::Endpoint{"127.0.0.1", 10000});
|
||||
|
||||
main_store.RegisterReplica("REPLICA1",
|
||||
@ -481,3 +482,69 @@ TEST_F(ReplicationTest, RecoveryProcess) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ReplicationTest, BasicAsynchronousReplicationTest) {
|
||||
storage::Storage main_store(
|
||||
{.items = {.properties_on_edges = true},
|
||||
.durability = {
|
||||
.storage_directory = storage_directory,
|
||||
.snapshot_wal_mode = storage::Config::Durability::SnapshotWalMode::
|
||||
PERIODIC_SNAPSHOT_WITH_WAL,
|
||||
}});
|
||||
|
||||
storage::Storage replica_store_async(
|
||||
{.items = {.properties_on_edges = true},
|
||||
.durability = {
|
||||
.storage_directory = storage_directory,
|
||||
.snapshot_wal_mode = storage::Config::Durability::SnapshotWalMode::
|
||||
PERIODIC_SNAPSHOT_WITH_WAL,
|
||||
}});
|
||||
|
||||
replica_store_async.SetReplicationRole<storage::ReplicationRole::REPLICA>(
|
||||
io::network::Endpoint{"127.0.0.1", 20000});
|
||||
|
||||
main_store.RegisterReplica("REPLICA_ASYNC",
|
||||
io::network::Endpoint{"127.0.0.1", 20000},
|
||||
storage::replication::ReplicationMode::ASYNC);
|
||||
|
||||
constexpr size_t vertices_create_num = 10;
|
||||
std::vector<storage::Gid> created_vertices;
|
||||
for (size_t i = 0; i < vertices_create_num; ++i) {
|
||||
auto acc = main_store.Access();
|
||||
auto v = acc.CreateVertex();
|
||||
created_vertices.push_back(v.Gid());
|
||||
ASSERT_FALSE(acc.Commit().HasError());
|
||||
|
||||
if (i == 0) {
|
||||
ASSERT_EQ(main_store.ReplicaState("REPLICA_ASYNC"),
|
||||
storage::replication::ReplicaState::REPLICATING);
|
||||
} else {
|
||||
ASSERT_EQ(main_store.ReplicaState("REPLICA_ASYNC"),
|
||||
storage::replication::ReplicaState::RECOVERY);
|
||||
}
|
||||
}
|
||||
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(20));
|
||||
|
||||
ASSERT_EQ(main_store.ReplicaState("REPLICA_ASYNC"),
|
||||
storage::replication::ReplicaState::RECOVERY);
|
||||
// Replica should have at least the first vertex
|
||||
{
|
||||
auto acc = replica_store_async.Access();
|
||||
auto v = acc.FindVertex(created_vertices[0], storage::View::OLD);
|
||||
ASSERT_TRUE(v);
|
||||
ASSERT_FALSE(acc.Commit().HasError());
|
||||
}
|
||||
|
||||
// Most of the later vertices should be skipped because
|
||||
// asyn replica cannot keep up
|
||||
ASSERT_FALSE(std::all_of(created_vertices.begin() + 1, created_vertices.end(),
|
||||
[&](const auto vertex_gid) {
|
||||
auto acc = replica_store_async.Access();
|
||||
auto v =
|
||||
acc.FindVertex(vertex_gid, storage::View::OLD);
|
||||
const bool exists = v.has_value();
|
||||
EXPECT_FALSE(acc.Commit().HasError());
|
||||
return exists;
|
||||
}));
|
||||
}
|
||||
|
29
tests/unit/utils_thread_pool.cpp
Normal file
29
tests/unit/utils_thread_pool.cpp
Normal file
@ -0,0 +1,29 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
|
||||
#include <utils/thread_pool.hpp>
|
||||
|
||||
using namespace std::chrono_literals;
|
||||
|
||||
TEST(ThreadPool, Basic) {
|
||||
constexpr size_t adder_count = 100'000;
|
||||
constexpr std::array<size_t, 5> pool_sizes{1, 2, 4, 8, 100};
|
||||
|
||||
for (const auto pool_size : pool_sizes) {
|
||||
utils::ThreadPool pool{pool_size};
|
||||
|
||||
std::atomic<int> count{0};
|
||||
for (size_t i = 0; i < adder_count; ++i) {
|
||||
pool.AddTask([&] { count.fetch_add(1); });
|
||||
}
|
||||
|
||||
while (pool.IdleThreadNum() != pool_size) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
}
|
||||
|
||||
ASSERT_EQ(count.load(), adder_count);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user