From 4ef6a1f9c3c94afad405a358065ac3108491b0fe Mon Sep 17 00:00:00 2001
From: Gareth Andrew Lloyd <gareth.lloyd@memgraph.io>
Date: Tue, 6 Feb 2024 17:07:38 +0000
Subject: [PATCH 1/4] Improve memory handling of Deltas (#1688)

- Reduce delta from 104B to 80B
- Hold and pass them around in a deque
- Detect and delete deltas within commit if safe to do so
---
 libs/setup.sh                            |   2 +
 src/storage/v2/delta.hpp                 |  77 +--
 src/storage/v2/disk/storage.cpp          |  12 +-
 src/storage/v2/durability/wal.cpp        |   4 +-
 src/storage/v2/edge_accessor.cpp         |  10 +-
 src/storage/v2/id_types.hpp              |  26 +-
 src/storage/v2/indices/indices_utils.hpp |  10 +-
 src/storage/v2/inmemory/storage.cpp      | 571 +++++++++++-------
 src/storage/v2/inmemory/storage.hpp      |   8 +-
 .../v2/inmemory/unique_constraints.cpp   |  12 +-
 src/storage/v2/mvcc.hpp                  |  16 +-
 src/storage/v2/property_value.hpp        | 168 +++---
 src/storage/v2/transaction.hpp           |   7 +-
 src/storage/v2/vertex_info_helpers.hpp   |  24 +-
 src/utils/disk_utils.hpp                 |   4 +-
 src/utils/string.hpp                     |   7 +
 tests/unit/storage_v2_property_store.cpp |   8 +-
 tests/unit/storage_v2_wal_file.cpp       |   8 +-
 18 files changed, 547 insertions(+), 427 deletions(-)

diff --git a/libs/setup.sh b/libs/setup.sh
index 4b5b81dfc..76fb4fcfa 100755
--- a/libs/setup.sh
+++ b/libs/setup.sh
@@ -270,6 +270,8 @@ pushd jemalloc
 MALLOC_CONF="retain:false,percpu_arena:percpu,oversize_threshold:0,muzzy_decay_ms:5000,dirty_decay_ms:5000" \
 ./configure \
     --disable-cxx \
+    --with-lg-page=12 \
+    --with-lg-hugepage=21 \
     --enable-shared=no --prefix=$working_dir \
     --with-malloc-conf="retain:false,percpu_arena:percpu,oversize_threshold:0,muzzy_decay_ms:5000,dirty_decay_ms:5000"

diff --git a/src/storage/v2/delta.hpp b/src/storage/v2/delta.hpp
index bcb2930eb..9c70bdc4c 100644
--- a/src/storage/v2/delta.hpp
+++ b/src/storage/v2/delta.hpp
@@ -1,4 +1,4 @@
-// Copyright 2023 Memgraph Ltd.
+// Copyright 2024 Memgraph Ltd.
 //
 // Use of this software is governed by the Business Source License
 // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@@ -57,9 +57,11 @@ class PreviousPtr {
     explicit Pointer(Edge *edge) : type(Type::EDGE), edge(edge) {}

     Type type{Type::NULLPTR};
-    Delta *delta{nullptr};
-    Vertex *vertex{nullptr};
-    Edge *edge{nullptr};
+    union {
+      Delta *delta = nullptr;
+      Vertex *vertex;
+      Edge *edge;
+    };
   };

   PreviousPtr() : storage_(0) {}
@@ -157,59 +159,51 @@ struct Delta {
   // DELETE_DESERIALIZED_OBJECT is used to load data from disk committed by past txs.
   // Because of this object was created in past txs, we create timestamp by ourselves inside instead of having it from
   // current tx. This timestamp we got from RocksDB timestamp stored in key.
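The 104B to 80B Delta reduction comes from three cooperating changes in this patch: the ID types shrink from 64-bit to 32-bit, the SET_PROPERTY payload moves behind a `std::unique_ptr<PropertyValue>`, and the `action` tag is folded into every union member so it occupies the union's padding rather than a separate slot (it stays readable through any member via the common-initial-sequence rule, the same guarantee the patch cites for `PropertyValue`). A minimal standalone sketch of the folding trick, with illustrative names not taken from the patch, assuming a typical 64-bit ABI:

```cpp
#include <cstdint>

enum class Action : uint32_t { kAddLabel, kRemoveLabel };

struct Before {     // tag outside the union: 4 bytes of tag + 4 bytes padding
  Action action;
  uint64_t *timestamp;
  union {
    uint32_t label;
  };
};

struct After {      // tag repeated inside each member: it hides in union padding
  uint64_t *timestamp;
  union {
    Action action;  // common initial sequence: readable whichever member is active
    struct {
      Action action;
      uint32_t label;
    } label;
  };
};

static_assert(sizeof(Before) == 24, "tag costs a full aligned slot");
static_assert(sizeof(After) == 16, "tag reuses padding inside the union");
```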
- Delta(DeleteDeserializedObjectTag /*tag*/, uint64_t ts, const std::optional<std::string> &old_disk_key) - : action(Action::DELETE_DESERIALIZED_OBJECT), - timestamp(new std::atomic<uint64_t>(ts)), - command_id(0), - old_disk_key(old_disk_key) {} + Delta(DeleteDeserializedObjectTag /*tag*/, uint64_t ts, std::optional<std::string> old_disk_key) + : timestamp(new std::atomic<uint64_t>(ts)), command_id(0), old_disk_key{.value = std::move(old_disk_key)} {} Delta(DeleteObjectTag /*tag*/, std::atomic<uint64_t> *timestamp, uint64_t command_id) - : action(Action::DELETE_OBJECT), timestamp(timestamp), command_id(command_id) {} + : timestamp(timestamp), command_id(command_id), action(Action::DELETE_OBJECT) {} Delta(RecreateObjectTag /*tag*/, std::atomic<uint64_t> *timestamp, uint64_t command_id) - : action(Action::RECREATE_OBJECT), timestamp(timestamp), command_id(command_id) {} + : timestamp(timestamp), command_id(command_id), action(Action::RECREATE_OBJECT) {} Delta(AddLabelTag /*tag*/, LabelId label, std::atomic<uint64_t> *timestamp, uint64_t command_id) - : action(Action::ADD_LABEL), timestamp(timestamp), command_id(command_id), label(label) {} + : timestamp(timestamp), command_id(command_id), label{.action = Action::ADD_LABEL, .value = label} {} Delta(RemoveLabelTag /*tag*/, LabelId label, std::atomic<uint64_t> *timestamp, uint64_t command_id) - : action(Action::REMOVE_LABEL), timestamp(timestamp), command_id(command_id), label(label) {} + : timestamp(timestamp), command_id(command_id), label{.action = Action::REMOVE_LABEL, .value = label} {} - Delta(SetPropertyTag /*tag*/, PropertyId key, const PropertyValue &value, std::atomic<uint64_t> *timestamp, + Delta(SetPropertyTag /*tag*/, PropertyId key, PropertyValue value, std::atomic<uint64_t> *timestamp, uint64_t command_id) - : action(Action::SET_PROPERTY), timestamp(timestamp), command_id(command_id), property({key, value}) {} - - Delta(SetPropertyTag /*tag*/, PropertyId key, PropertyValue &&value, std::atomic<uint64_t> *timestamp, - uint64_t command_id) - : action(Action::SET_PROPERTY), timestamp(timestamp), command_id(command_id), property({key, std::move(value)}) {} + : timestamp(timestamp), + command_id(command_id), + property{ + .action = Action::SET_PROPERTY, .key = key, .value = std::make_unique<PropertyValue>(std::move(value))} {} Delta(AddInEdgeTag /*tag*/, EdgeTypeId edge_type, Vertex *vertex, EdgeRef edge, std::atomic<uint64_t> *timestamp, uint64_t command_id) - : action(Action::ADD_IN_EDGE), - timestamp(timestamp), + : timestamp(timestamp), command_id(command_id), - vertex_edge({edge_type, vertex, edge}) {} + vertex_edge{.action = Action::ADD_IN_EDGE, .edge_type = edge_type, vertex, edge} {} Delta(AddOutEdgeTag /*tag*/, EdgeTypeId edge_type, Vertex *vertex, EdgeRef edge, std::atomic<uint64_t> *timestamp, uint64_t command_id) - : action(Action::ADD_OUT_EDGE), - timestamp(timestamp), + : timestamp(timestamp), command_id(command_id), - vertex_edge({edge_type, vertex, edge}) {} + vertex_edge{.action = Action::ADD_OUT_EDGE, .edge_type = edge_type, vertex, edge} {} Delta(RemoveInEdgeTag /*tag*/, EdgeTypeId edge_type, Vertex *vertex, EdgeRef edge, std::atomic<uint64_t> *timestamp, uint64_t command_id) - : action(Action::REMOVE_IN_EDGE), - timestamp(timestamp), + : timestamp(timestamp), command_id(command_id), - vertex_edge({edge_type, vertex, edge}) {} + vertex_edge{.action = Action::REMOVE_IN_EDGE, .edge_type = edge_type, vertex, edge} {} Delta(RemoveOutEdgeTag /*tag*/, EdgeTypeId edge_type, Vertex *vertex, EdgeRef edge, 
std::atomic<uint64_t> *timestamp, uint64_t command_id) - : action(Action::REMOVE_OUT_EDGE), - timestamp(timestamp), + : timestamp(timestamp), command_id(command_id), - vertex_edge({edge_type, vertex, edge}) {} + vertex_edge{.action = Action::REMOVE_OUT_EDGE, .edge_type = edge_type, vertex, edge} {} Delta(const Delta &) = delete; Delta(Delta &&) = delete; @@ -228,18 +222,16 @@ struct Delta { case Action::REMOVE_OUT_EDGE: break; case Action::DELETE_DESERIALIZED_OBJECT: - old_disk_key.reset(); + old_disk_key.value.reset(); delete timestamp; timestamp = nullptr; break; case Action::SET_PROPERTY: - property.value.~PropertyValue(); + property.value.reset(); break; } } - Action action; - // TODO: optimize with in-place copy std::atomic<uint64_t> *timestamp; uint64_t command_id; @@ -247,13 +239,22 @@ struct Delta { std::atomic<Delta *> next{nullptr}; union { - std::optional<std::string> old_disk_key; - LabelId label; + Action action; struct { + Action action = Action::DELETE_DESERIALIZED_OBJECT; + std::optional<std::string> value; + } old_disk_key; + struct { + Action action; + LabelId value; + } label; + struct { + Action action; PropertyId key; - storage::PropertyValue value; + std::unique_ptr<storage::PropertyValue> value; } property; struct { + Action action; EdgeTypeId edge_type; Vertex *vertex; EdgeRef edge; diff --git a/src/storage/v2/disk/storage.cpp b/src/storage/v2/disk/storage.cpp index f3c3aa0f4..adc0e92f4 100644 --- a/src/storage/v2/disk/storage.cpp +++ b/src/storage/v2/disk/storage.cpp @@ -137,14 +137,14 @@ bool VertexHasLabel(const Vertex &vertex, LabelId label, Transaction *transactio ApplyDeltasForRead(transaction, delta, view, [&deleted, &has_label, label](const Delta &delta) { switch (delta.action) { case Delta::Action::REMOVE_LABEL: { - if (delta.label == label) { + if (delta.label.value == label) { MG_ASSERT(has_label, "Invalid database state!"); has_label = false; } break; } case Delta::Action::ADD_LABEL: { - if (delta.label == label) { + if (delta.label.value == label) { MG_ASSERT(!has_label, "Invalid database state!"); has_label = true; } @@ -177,7 +177,7 @@ PropertyValue GetVertexProperty(const Vertex &vertex, PropertyId property, Trans switch (delta.action) { case Delta::Action::SET_PROPERTY: { if (delta.property.key == property) { - value = delta.property.value; + value = *delta.property.value; } break; } @@ -1682,9 +1682,9 @@ utils::BasicResult<StorageManipulationError, void> DiskStorage::DiskAccessor::Co } break; } } - } else if (transaction_.deltas.use().empty() || + } else if (transaction_.deltas.empty() || (!edge_import_mode_active && - std::all_of(transaction_.deltas.use().begin(), transaction_.deltas.use().end(), [](const Delta &delta) { + std::all_of(transaction_.deltas.begin(), transaction_.deltas.end(), [](const Delta &delta) { return delta.action == Delta::Action::DELETE_DESERIALIZED_OBJECT; }))) { } else { @@ -1812,7 +1812,7 @@ void DiskStorage::DiskAccessor::UpdateObjectsCountOnAbort() { auto *disk_storage = static_cast<DiskStorage *>(storage_); uint64_t transaction_id = transaction_.transaction_id; - for (const auto &delta : transaction_.deltas.use()) { + for (const auto &delta : transaction_.deltas) { auto prev = delta.prev.Get(); switch (prev.type) { case PreviousPtr::Type::VERTEX: { diff --git a/src/storage/v2/durability/wal.cpp b/src/storage/v2/durability/wal.cpp index e808f01a3..52e916052 100644 --- a/src/storage/v2/durability/wal.cpp +++ b/src/storage/v2/durability/wal.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. 
+// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -580,7 +580,7 @@ void EncodeDelta(BaseEncoder *encoder, NameIdMapper *name_id_mapper, SalientConf case Delta::Action::REMOVE_LABEL: { encoder->WriteMarker(VertexActionToMarker(delta.action)); encoder->WriteUint(vertex.gid.AsUint()); - encoder->WriteString(name_id_mapper->IdToName(delta.label.AsUint())); + encoder->WriteString(name_id_mapper->IdToName(delta.label.value.AsUint())); break; } case Delta::Action::ADD_OUT_EDGE: diff --git a/src/storage/v2/edge_accessor.cpp b/src/storage/v2/edge_accessor.cpp index 3ab2e3d79..03522ba16 100644 --- a/src/storage/v2/edge_accessor.cpp +++ b/src/storage/v2/edge_accessor.cpp @@ -237,7 +237,7 @@ Result<PropertyValue> EdgeAccessor::GetProperty(PropertyId property, View view) switch (delta.action) { case Delta::Action::SET_PROPERTY: { if (delta.property.key == property) { - *value = delta.property.value; + *value = *delta.property.value; } break; } @@ -281,15 +281,15 @@ Result<std::map<PropertyId, PropertyValue>> EdgeAccessor::Properties(View view) case Delta::Action::SET_PROPERTY: { auto it = properties.find(delta.property.key); if (it != properties.end()) { - if (delta.property.value.IsNull()) { + if (delta.property.value->IsNull()) { // remove the property properties.erase(it); } else { // set the value - it->second = delta.property.value; + it->second = *delta.property.value; } - } else if (!delta.property.value.IsNull()) { - properties.emplace(delta.property.key, delta.property.value); + } else if (!delta.property.value->IsNull()) { + properties.emplace(delta.property.key, *delta.property.value); } break; } diff --git a/src/storage/v2/id_types.hpp b/src/storage/v2/id_types.hpp index 3f2c8aa40..5e1809c67 100644 --- a/src/storage/v2/id_types.hpp +++ b/src/storage/v2/id_types.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -23,24 +23,24 @@ namespace memgraph::storage { -#define STORAGE_DEFINE_ID_TYPE(name) \ +#define STORAGE_DEFINE_ID_TYPE(name, type_store, type_conv, parse) \ class name final { \ private: \ - explicit name(uint64_t id) : id_(id) {} \ + explicit name(type_store id) : id_(id) {} \ \ public: \ /* Default constructor to allow serialization or preallocation. 
*/ \ name() = default; \ \ - static name FromUint(uint64_t id) { return name{id}; } \ - static name FromInt(int64_t id) { return name{utils::MemcpyCast<uint64_t>(id)}; } \ - uint64_t AsUint() const { return id_; } \ - int64_t AsInt() const { return utils::MemcpyCast<int64_t>(id_); } \ - static name FromString(std::string_view id) { return name{utils::ParseStringToUint64(id)}; } \ + static name FromUint(type_store id) { return name{id}; } \ + static name FromInt(type_conv id) { return name{utils::MemcpyCast<type_store>(id)}; } \ + type_store AsUint() const { return id_; } \ + type_conv AsInt() const { return utils::MemcpyCast<type_conv>(id_); } \ + static name FromString(std::string_view id) { return name{parse(id)}; } \ std::string ToString() const { return std::to_string(id_); } \ \ private: \ - uint64_t id_; \ + type_store id_; \ }; \ static_assert(std::is_trivially_copyable_v<name>, "storage::" #name " must be trivially copyable!"); \ inline bool operator==(const name &first, const name &second) { return first.AsUint() == second.AsUint(); } \ @@ -50,10 +50,10 @@ namespace memgraph::storage { inline bool operator<=(const name &first, const name &second) { return first.AsUint() <= second.AsUint(); } \ inline bool operator>=(const name &first, const name &second) { return first.AsUint() >= second.AsUint(); } -STORAGE_DEFINE_ID_TYPE(Gid); -STORAGE_DEFINE_ID_TYPE(LabelId); -STORAGE_DEFINE_ID_TYPE(PropertyId); -STORAGE_DEFINE_ID_TYPE(EdgeTypeId); +STORAGE_DEFINE_ID_TYPE(Gid, uint64_t, int64_t, utils::ParseStringToUint64); +STORAGE_DEFINE_ID_TYPE(LabelId, uint32_t, int32_t, utils::ParseStringToUint32); +STORAGE_DEFINE_ID_TYPE(PropertyId, uint32_t, int32_t, utils::ParseStringToUint32); +STORAGE_DEFINE_ID_TYPE(EdgeTypeId, uint32_t, int32_t, utils::ParseStringToUint32); #undef STORAGE_DEFINE_ID_TYPE diff --git a/src/storage/v2/indices/indices_utils.hpp b/src/storage/v2/indices/indices_utils.hpp index 054609188..52938a1db 100644 --- a/src/storage/v2/indices/indices_utils.hpp +++ b/src/storage/v2/indices/indices_utils.hpp @@ -72,13 +72,13 @@ inline bool AnyVersionHasLabel(const Vertex &vertex, LabelId label, uint64_t tim return AnyVersionSatisfiesPredicate<interesting>(timestamp, delta, [&has_label, &deleted, label](const Delta &delta) { switch (delta.action) { case Delta::Action::ADD_LABEL: - if (delta.label == label) { + if (delta.label.value == label) { MG_ASSERT(!has_label, "Invalid database state!"); has_label = true; } break; case Delta::Action::REMOVE_LABEL: - if (delta.label == label) { + if (delta.label.value == label) { MG_ASSERT(has_label, "Invalid database state!"); has_label = false; } @@ -135,20 +135,20 @@ inline bool AnyVersionHasLabelProperty(const Vertex &vertex, LabelId label, Prop timestamp, delta, [&has_label, ¤t_value_equal_to_value, &deleted, label, key, &value](const Delta &delta) { switch (delta.action) { case Delta::Action::ADD_LABEL: - if (delta.label == label) { + if (delta.label.value == label) { MG_ASSERT(!has_label, "Invalid database state!"); has_label = true; } break; case Delta::Action::REMOVE_LABEL: - if (delta.label == label) { + if (delta.label.value == label) { MG_ASSERT(has_label, "Invalid database state!"); has_label = false; } break; case Delta::Action::SET_PROPERTY: if (delta.property.key == key) { - current_value_equal_to_value = delta.property.value == value; + current_value_equal_to_value = *delta.property.value == value; } break; case Delta::Action::RECREATE_OBJECT: { diff --git a/src/storage/v2/inmemory/storage.cpp b/src/storage/v2/inmemory/storage.cpp 
index 381a67d3f..c97d12072 100644 --- a/src/storage/v2/inmemory/storage.cpp +++ b/src/storage/v2/inmemory/storage.cpp @@ -176,9 +176,9 @@ InMemoryStorage::~InMemoryStorage() { committed_transactions_.WithLock([](auto &transactions) { transactions.clear(); }); } -InMemoryStorage::InMemoryAccessor::InMemoryAccessor(auto tag, InMemoryStorage *storage, IsolationLevel isolation_level, - StorageMode storage_mode, - memgraph::replication_coordination_glue::ReplicationRole replication_role) +InMemoryStorage::InMemoryAccessor::InMemoryAccessor( + auto tag, InMemoryStorage *storage, IsolationLevel isolation_level, StorageMode storage_mode, + memgraph::replication_coordination_glue::ReplicationRole replication_role) : Accessor(tag, storage, isolation_level, storage_mode, replication_role), config_(storage->config_.salient.items) {} InMemoryStorage::InMemoryAccessor::InMemoryAccessor(InMemoryAccessor &&other) noexcept @@ -757,7 +757,7 @@ utils::BasicResult<StorageManipulationError, void> InMemoryStorage::InMemoryAcce auto *mem_storage = static_cast<InMemoryStorage *>(storage_); // TODO: duplicated transaction finalisation in md_deltas and deltas processing cases - if (transaction_.deltas.use().empty() && transaction_.md_deltas.empty()) { + if (transaction_.deltas.empty() && transaction_.md_deltas.empty()) { // We don't have to update the commit timestamp here because no one reads // it. mem_storage->commit_log_->MarkFinished(transaction_.start_timestamp); @@ -836,25 +836,37 @@ utils::BasicResult<StorageManipulationError, void> InMemoryStorage::InMemoryAcce // Replica can log only the write transaction received from Main // so the Wal files are consistent if (is_main_or_replica_write) { - could_replicate_all_sync_replicas = mem_storage->AppendToWal(transaction_, *commit_timestamp_, - std::move(db_acc)); // protected by engine_guard + could_replicate_all_sync_replicas = + mem_storage->AppendToWal(transaction_, *commit_timestamp_, std::move(db_acc)); // TODO: release lock, and update all deltas to have a local copy of the commit timestamp MG_ASSERT(transaction_.commit_timestamp != nullptr, "Invalid database state!"); - transaction_.commit_timestamp->store(*commit_timestamp_, - std::memory_order_release); // protected by engine_guard + transaction_.commit_timestamp->store(*commit_timestamp_, std::memory_order_release); // Replica can only update the last commit timestamp with // the commits received from main. // Update the last commit timestamp - mem_storage->repl_storage_state_.last_commit_timestamp_.store( - *commit_timestamp_); // protected by engine_guard + mem_storage->repl_storage_state_.last_commit_timestamp_.store(*commit_timestamp_); } - // Release engine lock because we don't have to hold it anymore - engine_guard.unlock(); + // TODO: can and should this be moved earlier? mem_storage->commit_log_->MarkFinished(start_timestamp); + + // while still holding engine lock + // and after durability + replication + // check if we can fast discard deltas (ie. 
do not hand over to GC)
+      bool no_older_transactions = mem_storage->commit_log_->OldestActive() == *commit_timestamp_;
+      bool no_newer_transactions = mem_storage->transaction_id_ == transaction_.transaction_id + 1;
+      if (no_older_transactions && no_newer_transactions) [[unlikely]] {
+        // STEP 0) Can only do fast discard if GC is not running
+        // We can't unlink our transaction's deltas until all of the older deltas in GC have been unlinked
+        // must do a try here, to avoid deadlock between the transaction's `engine_lock_` and the GC's `gc_lock_`
+        auto gc_guard = std::unique_lock{mem_storage->gc_lock_, std::defer_lock};
+        if (gc_guard.try_lock()) {
+          FastDiscardOfDeltas(*commit_timestamp_, std::move(gc_guard));
+        }
+      }
     }
-  }
+  }  // Release engine lock because we don't have to hold it anymore

   if (unique_constraint_violation) {
     Abort();
@@ -873,241 +885,332 @@ utils::BasicResult<StorageManipulationError, void> InMemoryStorage::InMemoryAcce
   return {};
 }

+void InMemoryStorage::InMemoryAccessor::FastDiscardOfDeltas(uint64_t oldest_active_timestamp,
+                                                            std::unique_lock<std::mutex> /*gc_guard*/) {
+  auto *mem_storage = static_cast<InMemoryStorage *>(storage_);
+  std::list<Gid> current_deleted_edges;
+  std::list<Gid> current_deleted_vertices;
+
+  auto const unlink_remove_clear = [&](std::deque<Delta> &deltas) {
+    for (auto &delta : deltas) {
+      auto prev = delta.prev.Get();
+      switch (prev.type) {
+        case PreviousPtr::Type::NULLPTR:
+        case PreviousPtr::Type::DELTA:
+          break;
+        case PreviousPtr::Type::VERTEX: {
+          // safe because no other txn can be reading this while we have engine lock
+          auto &vertex = *prev.vertex;
+          vertex.delta = nullptr;
+          if (vertex.deleted) {
+            DMG_ASSERT(delta.action == Delta::Action::RECREATE_OBJECT);
+            current_deleted_vertices.push_back(vertex.gid);
+          }
+          break;
+        }
+        case PreviousPtr::Type::EDGE: {
+          // safe because no other txn can be reading this while we have engine lock
+          auto &edge = *prev.edge;
+          edge.delta = nullptr;
+          if (edge.deleted) {
+            DMG_ASSERT(delta.action == Delta::Action::RECREATE_OBJECT);
+            current_deleted_edges.push_back(edge.gid);
+          }
+          break;
+        }
+      }
+    }
+    // delete deltas
+    deltas.clear();
+  };
+
+  // STEP 1) ensure everything in GC is gone
+
+  // 1.a) old garbage_undo_buffers are safe to remove
+  // we are the only transaction, no one is reading those unlinked deltas
+  mem_storage->garbage_undo_buffers_.WithLock([&](auto &garbage_undo_buffers) { garbage_undo_buffers.clear(); });
+
+  // 1.b.0) old committed_transactions_ need minimal unlinking + remove + clear
+  // must be done before this transaction's delta unlinking
+  auto linked_undo_buffers = std::list<GCDeltas>{};
+  mem_storage->committed_transactions_.WithLock(
+      [&](auto &committed_transactions) { committed_transactions.swap(linked_undo_buffers); });
+
+  // 1.b.1) unlink, gathering the removals
+  for (auto &gc_deltas : linked_undo_buffers) {
+    unlink_remove_clear(gc_deltas.deltas_);
+  }
+  // 1.b.2) clear the list of delta deques
+  linked_undo_buffers.clear();
+
+  // STEP 2) this transaction's deltas also need minimal unlinking + remove + clear
+  unlink_remove_clear(transaction_.deltas);
+
+  // STEP 3) skip_list removals
+  if (!current_deleted_vertices.empty()) {
+    // 3.a) clear from indexes first
+    std::stop_source dummy;
+    mem_storage->indices_.RemoveObsoleteEntries(oldest_active_timestamp, dummy.get_token());
+    auto *mem_unique_constraints =
+        static_cast<InMemoryUniqueConstraints *>(mem_storage->constraints_.unique_constraints_.get());
+    mem_unique_constraints->RemoveObsoleteEntries(oldest_active_timestamp, dummy.get_token());
+
+    // 3.b) remove from vertex skip_list
+    auto vertex_acc = mem_storage->vertices_.access();
+    for (auto gid : current_deleted_vertices) {
+      vertex_acc.remove(gid);
+    }
+  }
+
+  if (!current_deleted_edges.empty()) {
+    // 3.c) remove from edge skip_list
+    auto edge_acc = mem_storage->edges_.access();
+    for (auto gid : current_deleted_edges) {
+      edge_acc.remove(gid);
+    }
+  }
+}
+
 void InMemoryStorage::InMemoryAccessor::Abort() {
   MG_ASSERT(is_transaction_active_, "The transaction is already terminated!");

-  // We collect vertices and edges we've created here and then splice them into
-  // `deleted_vertices_` and `deleted_edges_` lists, instead of adding them one
-  // by one and acquiring lock every time.
-  std::list<Gid> my_deleted_vertices;
-  std::list<Gid> my_deleted_edges;
+  auto *mem_storage = static_cast<InMemoryStorage *>(storage_);

-  std::map<LabelId, std::vector<Vertex *>> label_cleanup;
-  std::map<LabelId, std::vector<std::pair<PropertyValue, Vertex *>>> label_property_cleanup;
-  std::map<PropertyId, std::vector<std::pair<PropertyValue, Vertex *>>> property_cleanup;
+  // if we have no deltas then no need to do any undo work during Abort
+  // note: this check also saves on unnecessary contention on `engine_lock_`
+  if (!transaction_.deltas.empty()) {
+    // CONSTRAINTS
+    if (transaction_.constraint_verification_info.NeedsUniqueConstraintVerification()) {
+      // Need to remove elements from constraints before handling of the deltas, so the elements match the correct
+      // values
+      auto vertices_to_check = transaction_.constraint_verification_info.GetVerticesForUniqueConstraintChecking();
+      auto vertices_to_check_v = std::vector<Vertex const *>{vertices_to_check.begin(), vertices_to_check.end()};
+      storage_->constraints_.AbortEntries(vertices_to_check_v, transaction_.start_timestamp);
+    }

-  // CONSTRAINTS
-  if (transaction_.constraint_verification_info.NeedsUniqueConstraintVerification()) {
-    // Need to remove elements from constraints before handling of the deltas, so the elements match the correct
-    // values
-    auto vertices_to_check = transaction_.constraint_verification_info.GetVerticesForUniqueConstraintChecking();
-    auto vertices_to_check_v = std::vector<Vertex const *>{vertices_to_check.begin(), vertices_to_check.end()};
-    storage_->constraints_.AbortEntries(vertices_to_check_v, transaction_.start_timestamp);
-  }
+    const auto index_stats = storage_->indices_.Analysis();

-  const auto index_stats = storage_->indices_.Analysis();
+    // We collect vertices and edges we've created here and then splice them into
+    // `deleted_vertices_` and `deleted_edges_` lists, instead of adding them one
+    // by one and acquiring lock every time.
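The `try_lock` in the fast-discard path above is a lock-ordering guard: the committing transaction already holds `engine_lock_`, while the GC acquires `gc_lock_` before it touches the engine, so blocking on `gc_lock_` here could deadlock. A self-contained sketch of the pattern (the lock names are illustrative stand-ins, not the storage's actual members):

```cpp
#include <mutex>

std::mutex engine_lock;  // stand-in for the storage engine lock
std::mutex gc_lock;      // stand-in for the GC lock

// Commit path: already inside engine_lock; GC takes gc_lock first.
// Blocking on gc_lock here would invert the lock order, so only try it.
bool TryFastDiscard() {
  std::unique_lock engine_guard{engine_lock};
  std::unique_lock gc_guard{gc_lock, std::defer_lock};
  if (!gc_guard.try_lock()) {
    return false;  // GC is running: fall back to handing deltas over to GC
  }
  // ... unlink and drop this transaction's deltas here ...
  return true;
}

int main() { return TryFastDiscard() ? 0 : 1; }
```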
+ std::list<Gid> my_deleted_vertices; + std::list<Gid> my_deleted_edges; - for (const auto &delta : transaction_.deltas.use()) { - auto prev = delta.prev.Get(); - switch (prev.type) { - case PreviousPtr::Type::VERTEX: { - auto *vertex = prev.vertex; - auto guard = std::unique_lock{vertex->lock}; - Delta *current = vertex->delta; - while (current != nullptr && - current->timestamp->load(std::memory_order_acquire) == transaction_.transaction_id) { - switch (current->action) { - case Delta::Action::REMOVE_LABEL: { - auto it = std::find(vertex->labels.begin(), vertex->labels.end(), current->label); - MG_ASSERT(it != vertex->labels.end(), "Invalid database state!"); - std::swap(*it, *vertex->labels.rbegin()); - vertex->labels.pop_back(); + std::map<LabelId, std::vector<Vertex *>> label_cleanup; + std::map<LabelId, std::vector<std::pair<PropertyValue, Vertex *>>> label_property_cleanup; + std::map<PropertyId, std::vector<std::pair<PropertyValue, Vertex *>>> property_cleanup; - // For label index - // check if there is a label index for the label and add entry if so - // For property label index - // check if we care about the label; this will return all the propertyIds we care about and then get - // the current property value - if (std::binary_search(index_stats.label.begin(), index_stats.label.end(), current->label)) { - label_cleanup[current->label].emplace_back(vertex); - } - const auto &properties = index_stats.property_label.l2p.find(current->label); - if (properties != index_stats.property_label.l2p.end()) { - for (const auto &property : properties->second) { - auto current_value = vertex->properties.GetProperty(property); - if (!current_value.IsNull()) { - label_property_cleanup[current->label].emplace_back(std::move(current_value), vertex); + for (const auto &delta : transaction_.deltas) { + auto prev = delta.prev.Get(); + switch (prev.type) { + case PreviousPtr::Type::VERTEX: { + auto *vertex = prev.vertex; + auto guard = std::unique_lock{vertex->lock}; + Delta *current = vertex->delta; + while (current != nullptr && + current->timestamp->load(std::memory_order_acquire) == transaction_.transaction_id) { + switch (current->action) { + case Delta::Action::REMOVE_LABEL: { + auto it = std::find(vertex->labels.begin(), vertex->labels.end(), current->label.value); + MG_ASSERT(it != vertex->labels.end(), "Invalid database state!"); + std::swap(*it, *vertex->labels.rbegin()); + vertex->labels.pop_back(); + + // For label index + // check if there is a label index for the label and add entry if so + // For property label index + // check if we care about the label; this will return all the propertyIds we care about and then get + // the current property value + if (std::binary_search(index_stats.label.begin(), index_stats.label.end(), current->label.value)) { + label_cleanup[current->label.value].emplace_back(vertex); + } + const auto &properties = index_stats.property_label.l2p.find(current->label.value); + if (properties != index_stats.property_label.l2p.end()) { + for (const auto &property : properties->second) { + auto current_value = vertex->properties.GetProperty(property); + if (!current_value.IsNull()) { + label_property_cleanup[current->label.value].emplace_back(std::move(current_value), vertex); + } } } + break; } - break; - } - case Delta::Action::ADD_LABEL: { - auto it = std::find(vertex->labels.begin(), vertex->labels.end(), current->label); - MG_ASSERT(it == vertex->labels.end(), "Invalid database state!"); - vertex->labels.push_back(current->label); - break; - } - case 
Delta::Action::SET_PROPERTY: { - // For label index nothing - // For property label index - // check if we care about the property, this will return all the labels and then get current property - // value - const auto &labels = index_stats.property_label.p2l.find(current->property.key); - if (labels != index_stats.property_label.p2l.end()) { - auto current_value = vertex->properties.GetProperty(current->property.key); - if (!current_value.IsNull()) { - property_cleanup[current->property.key].emplace_back(std::move(current_value), vertex); + case Delta::Action::ADD_LABEL: { + auto it = std::find(vertex->labels.begin(), vertex->labels.end(), current->label.value); + MG_ASSERT(it == vertex->labels.end(), "Invalid database state!"); + vertex->labels.push_back(current->label.value); + break; + } + case Delta::Action::SET_PROPERTY: { + // For label index nothing + // For property label index + // check if we care about the property, this will return all the labels and then get current property + // value + const auto &labels = index_stats.property_label.p2l.find(current->property.key); + if (labels != index_stats.property_label.p2l.end()) { + auto current_value = vertex->properties.GetProperty(current->property.key); + if (!current_value.IsNull()) { + property_cleanup[current->property.key].emplace_back(std::move(current_value), vertex); + } } + // Setting the correct value + vertex->properties.SetProperty(current->property.key, *current->property.value); + break; + } + case Delta::Action::ADD_IN_EDGE: { + std::tuple<EdgeTypeId, Vertex *, EdgeRef> link{current->vertex_edge.edge_type, + current->vertex_edge.vertex, current->vertex_edge.edge}; + auto it = std::find(vertex->in_edges.begin(), vertex->in_edges.end(), link); + MG_ASSERT(it == vertex->in_edges.end(), "Invalid database state!"); + vertex->in_edges.push_back(link); + break; + } + case Delta::Action::ADD_OUT_EDGE: { + std::tuple<EdgeTypeId, Vertex *, EdgeRef> link{current->vertex_edge.edge_type, + current->vertex_edge.vertex, current->vertex_edge.edge}; + auto it = std::find(vertex->out_edges.begin(), vertex->out_edges.end(), link); + MG_ASSERT(it == vertex->out_edges.end(), "Invalid database state!"); + vertex->out_edges.push_back(link); + // Increment edge count. We only increment the count here because + // the information in `ADD_IN_EDGE` and `Edge/RECREATE_OBJECT` is + // redundant. Also, `Edge/RECREATE_OBJECT` isn't available when + // edge properties are disabled. + storage_->edge_count_.fetch_add(1, std::memory_order_acq_rel); + break; + } + case Delta::Action::REMOVE_IN_EDGE: { + std::tuple<EdgeTypeId, Vertex *, EdgeRef> link{current->vertex_edge.edge_type, + current->vertex_edge.vertex, current->vertex_edge.edge}; + auto it = std::find(vertex->in_edges.begin(), vertex->in_edges.end(), link); + MG_ASSERT(it != vertex->in_edges.end(), "Invalid database state!"); + std::swap(*it, *vertex->in_edges.rbegin()); + vertex->in_edges.pop_back(); + break; + } + case Delta::Action::REMOVE_OUT_EDGE: { + std::tuple<EdgeTypeId, Vertex *, EdgeRef> link{current->vertex_edge.edge_type, + current->vertex_edge.vertex, current->vertex_edge.edge}; + auto it = std::find(vertex->out_edges.begin(), vertex->out_edges.end(), link); + MG_ASSERT(it != vertex->out_edges.end(), "Invalid database state!"); + std::swap(*it, *vertex->out_edges.rbegin()); + vertex->out_edges.pop_back(); + // Decrement edge count. We only decrement the count here because + // the information in `REMOVE_IN_EDGE` and `Edge/DELETE_OBJECT` is + // redundant. 
Also, `Edge/DELETE_OBJECT` isn't available when edge + // properties are disabled. + storage_->edge_count_.fetch_add(-1, std::memory_order_acq_rel); + break; + } + case Delta::Action::DELETE_DESERIALIZED_OBJECT: + case Delta::Action::DELETE_OBJECT: { + vertex->deleted = true; + my_deleted_vertices.push_back(vertex->gid); + break; + } + case Delta::Action::RECREATE_OBJECT: { + vertex->deleted = false; + break; } - // Setting the correct value - vertex->properties.SetProperty(current->property.key, current->property.value); - break; - } - case Delta::Action::ADD_IN_EDGE: { - std::tuple<EdgeTypeId, Vertex *, EdgeRef> link{current->vertex_edge.edge_type, - current->vertex_edge.vertex, current->vertex_edge.edge}; - auto it = std::find(vertex->in_edges.begin(), vertex->in_edges.end(), link); - MG_ASSERT(it == vertex->in_edges.end(), "Invalid database state!"); - vertex->in_edges.push_back(link); - break; - } - case Delta::Action::ADD_OUT_EDGE: { - std::tuple<EdgeTypeId, Vertex *, EdgeRef> link{current->vertex_edge.edge_type, - current->vertex_edge.vertex, current->vertex_edge.edge}; - auto it = std::find(vertex->out_edges.begin(), vertex->out_edges.end(), link); - MG_ASSERT(it == vertex->out_edges.end(), "Invalid database state!"); - vertex->out_edges.push_back(link); - // Increment edge count. We only increment the count here because - // the information in `ADD_IN_EDGE` and `Edge/RECREATE_OBJECT` is - // redundant. Also, `Edge/RECREATE_OBJECT` isn't available when - // edge properties are disabled. - storage_->edge_count_.fetch_add(1, std::memory_order_acq_rel); - break; - } - case Delta::Action::REMOVE_IN_EDGE: { - std::tuple<EdgeTypeId, Vertex *, EdgeRef> link{current->vertex_edge.edge_type, - current->vertex_edge.vertex, current->vertex_edge.edge}; - auto it = std::find(vertex->in_edges.begin(), vertex->in_edges.end(), link); - MG_ASSERT(it != vertex->in_edges.end(), "Invalid database state!"); - std::swap(*it, *vertex->in_edges.rbegin()); - vertex->in_edges.pop_back(); - break; - } - case Delta::Action::REMOVE_OUT_EDGE: { - std::tuple<EdgeTypeId, Vertex *, EdgeRef> link{current->vertex_edge.edge_type, - current->vertex_edge.vertex, current->vertex_edge.edge}; - auto it = std::find(vertex->out_edges.begin(), vertex->out_edges.end(), link); - MG_ASSERT(it != vertex->out_edges.end(), "Invalid database state!"); - std::swap(*it, *vertex->out_edges.rbegin()); - vertex->out_edges.pop_back(); - // Decrement edge count. We only decrement the count here because - // the information in `REMOVE_IN_EDGE` and `Edge/DELETE_OBJECT` is - // redundant. Also, `Edge/DELETE_OBJECT` isn't available when edge - // properties are disabled. 
- storage_->edge_count_.fetch_add(-1, std::memory_order_acq_rel); - break; - } - case Delta::Action::DELETE_DESERIALIZED_OBJECT: - case Delta::Action::DELETE_OBJECT: { - vertex->deleted = true; - my_deleted_vertices.push_back(vertex->gid); - break; - } - case Delta::Action::RECREATE_OBJECT: { - vertex->deleted = false; - break; } + current = current->next.load(std::memory_order_acquire); } - current = current->next.load(std::memory_order_acquire); - } - vertex->delta = current; - if (current != nullptr) { - current->prev.Set(vertex); - } - - break; - } - case PreviousPtr::Type::EDGE: { - auto *edge = prev.edge; - auto guard = std::lock_guard{edge->lock}; - Delta *current = edge->delta; - while (current != nullptr && - current->timestamp->load(std::memory_order_acquire) == transaction_.transaction_id) { - switch (current->action) { - case Delta::Action::SET_PROPERTY: { - edge->properties.SetProperty(current->property.key, current->property.value); - break; - } - case Delta::Action::DELETE_DESERIALIZED_OBJECT: - case Delta::Action::DELETE_OBJECT: { - edge->deleted = true; - my_deleted_edges.push_back(edge->gid); - break; - } - case Delta::Action::RECREATE_OBJECT: { - edge->deleted = false; - break; - } - case Delta::Action::REMOVE_LABEL: - case Delta::Action::ADD_LABEL: - case Delta::Action::ADD_IN_EDGE: - case Delta::Action::ADD_OUT_EDGE: - case Delta::Action::REMOVE_IN_EDGE: - case Delta::Action::REMOVE_OUT_EDGE: { - LOG_FATAL("Invalid database state!"); - break; - } + vertex->delta = current; + if (current != nullptr) { + current->prev.Set(vertex); } - current = current->next.load(std::memory_order_acquire); + + break; } - edge->delta = current; - if (current != nullptr) { - current->prev.Set(edge); + case PreviousPtr::Type::EDGE: { + auto *edge = prev.edge; + auto guard = std::lock_guard{edge->lock}; + Delta *current = edge->delta; + while (current != nullptr && + current->timestamp->load(std::memory_order_acquire) == transaction_.transaction_id) { + switch (current->action) { + case Delta::Action::SET_PROPERTY: { + edge->properties.SetProperty(current->property.key, *current->property.value); + break; + } + case Delta::Action::DELETE_DESERIALIZED_OBJECT: + case Delta::Action::DELETE_OBJECT: { + edge->deleted = true; + my_deleted_edges.push_back(edge->gid); + break; + } + case Delta::Action::RECREATE_OBJECT: { + edge->deleted = false; + break; + } + case Delta::Action::REMOVE_LABEL: + case Delta::Action::ADD_LABEL: + case Delta::Action::ADD_IN_EDGE: + case Delta::Action::ADD_OUT_EDGE: + case Delta::Action::REMOVE_IN_EDGE: + case Delta::Action::REMOVE_OUT_EDGE: { + LOG_FATAL("Invalid database state!"); + break; + } + } + current = current->next.load(std::memory_order_acquire); + } + edge->delta = current; + if (current != nullptr) { + current->prev.Set(edge); + } + + break; } - - break; - } - case PreviousPtr::Type::DELTA: - // pointer probably couldn't be set because allocation failed - case PreviousPtr::Type::NULLPTR: - break; - } - } - - auto *mem_storage = static_cast<InMemoryStorage *>(storage_); - { - auto engine_guard = std::unique_lock(storage_->engine_lock_); - uint64_t mark_timestamp = storage_->timestamp_; - // Take garbage_undo_buffers lock while holding the engine lock to make - // sure that entries are sorted by mark timestamp in the list. - mem_storage->garbage_undo_buffers_.WithLock([&](auto &garbage_undo_buffers) { - // Release engine lock because we don't have to hold it anymore and - // emplace back could take a long time. 
- engine_guard.unlock(); - - garbage_undo_buffers.emplace_back(mark_timestamp, std::move(transaction_.deltas), - std::move(transaction_.commit_timestamp)); - }); - - /// We MUST unlink (aka. remove) entries in indexes and constraints - /// before we unlink (aka. remove) vertices from storage - /// this is because they point into vertices skip_list - - // INDICES - for (auto const &[label, vertices] : label_cleanup) { - storage_->indices_.AbortEntries(label, vertices, transaction_.start_timestamp); - } - for (auto const &[label, prop_vertices] : label_property_cleanup) { - storage_->indices_.AbortEntries(label, prop_vertices, transaction_.start_timestamp); - } - for (auto const &[property, prop_vertices] : property_cleanup) { - storage_->indices_.AbortEntries(property, prop_vertices, transaction_.start_timestamp); - } - - // VERTICES - { - auto vertices_acc = mem_storage->vertices_.access(); - for (auto gid : my_deleted_vertices) { - vertices_acc.remove(gid); + case PreviousPtr::Type::DELTA: + // pointer probably couldn't be set because allocation failed + case PreviousPtr::Type::NULLPTR: + break; } } - // EDGES { - auto edges_acc = mem_storage->edges_.access(); - for (auto gid : my_deleted_edges) { - edges_acc.remove(gid); + auto engine_guard = std::unique_lock(storage_->engine_lock_); + uint64_t mark_timestamp = storage_->timestamp_; + // Take garbage_undo_buffers lock while holding the engine lock to make + // sure that entries are sorted by mark timestamp in the list. + mem_storage->garbage_undo_buffers_.WithLock([&](auto &garbage_undo_buffers) { + // Release engine lock because we don't have to hold it anymore and + // emplace back could take a long time. + engine_guard.unlock(); + + garbage_undo_buffers.emplace_back(mark_timestamp, std::move(transaction_.deltas), + std::move(transaction_.commit_timestamp)); + }); + + /// We MUST unlink (aka. remove) entries in indexes and constraints + /// before we unlink (aka. remove) vertices from storage + /// this is because they point into vertices skip_list + + // INDICES + for (auto const &[label, vertices] : label_cleanup) { + storage_->indices_.AbortEntries(label, vertices, transaction_.start_timestamp); + } + for (auto const &[label, prop_vertices] : label_property_cleanup) { + storage_->indices_.AbortEntries(label, prop_vertices, transaction_.start_timestamp); + } + for (auto const &[property, prop_vertices] : property_cleanup) { + storage_->indices_.AbortEntries(property, prop_vertices, transaction_.start_timestamp); + } + + // VERTICES + { + auto vertices_acc = mem_storage->vertices_.access(); + for (auto gid : my_deleted_vertices) { + vertices_acc.remove(gid); + } + } + + // EDGES + { + auto edges_acc = mem_storage->edges_.access(); + for (auto gid : my_deleted_edges) { + edges_acc.remove(gid); + } } } } @@ -1121,7 +1224,7 @@ void InMemoryStorage::InMemoryAccessor::FinalizeTransaction() { auto *mem_storage = static_cast<InMemoryStorage *>(storage_); mem_storage->commit_log_->MarkFinished(*commit_timestamp_); - if (!transaction_.deltas.use().empty()) { + if (!transaction_.deltas.empty()) { // Only hand over delta to be GC'ed if there was any deltas mem_storage->committed_transactions_.WithLock([&](auto &committed_transactions) { // using mark of 0 as GC will assign a mark_timestamp after unlinking @@ -1462,7 +1565,7 @@ void InMemoryStorage::CollectGarbage(std::unique_lock<utils::ResourceLock> main_ // chain in a broken state. // The chain can be only read without taking any locks. 
- for (Delta &delta : linked_entry->deltas_.use()) { + for (Delta &delta : linked_entry->deltas_) { while (true) { auto prev = delta.prev.Get(); switch (prev.type) { @@ -1781,7 +1884,7 @@ bool InMemoryStorage::AppendToWal(const Transaction &transaction, uint64_t final // 1. Process all Vertex deltas and store all operations that create vertices // and modify vertex data. - for (const auto &delta : transaction.deltas.use()) { + for (const auto &delta : transaction.deltas) { auto prev = delta.prev.Get(); MG_ASSERT(prev.type != PreviousPtr::Type::NULLPTR, "Invalid pointer!"); if (prev.type != PreviousPtr::Type::VERTEX) continue; @@ -1804,7 +1907,7 @@ bool InMemoryStorage::AppendToWal(const Transaction &transaction, uint64_t final }); } // 2. Process all Vertex deltas and store all operations that create edges. - for (const auto &delta : transaction.deltas.use()) { + for (const auto &delta : transaction.deltas) { auto prev = delta.prev.Get(); MG_ASSERT(prev.type != PreviousPtr::Type::NULLPTR, "Invalid pointer!"); if (prev.type != PreviousPtr::Type::VERTEX) continue; @@ -1826,7 +1929,7 @@ bool InMemoryStorage::AppendToWal(const Transaction &transaction, uint64_t final }); } // 3. Process all Edge deltas and store all operations that modify edge data. - for (const auto &delta : transaction.deltas.use()) { + for (const auto &delta : transaction.deltas) { auto prev = delta.prev.Get(); MG_ASSERT(prev.type != PreviousPtr::Type::NULLPTR, "Invalid pointer!"); if (prev.type != PreviousPtr::Type::EDGE) continue; @@ -1848,7 +1951,7 @@ bool InMemoryStorage::AppendToWal(const Transaction &transaction, uint64_t final }); } // 4. Process all Vertex deltas and store all operations that delete edges. - for (const auto &delta : transaction.deltas.use()) { + for (const auto &delta : transaction.deltas) { auto prev = delta.prev.Get(); MG_ASSERT(prev.type != PreviousPtr::Type::NULLPTR, "Invalid pointer!"); if (prev.type != PreviousPtr::Type::VERTEX) continue; @@ -1870,7 +1973,7 @@ bool InMemoryStorage::AppendToWal(const Transaction &transaction, uint64_t final }); } // 5. Process all Vertex deltas and store all operations that delete vertices. 
-  for (const auto &delta : transaction.deltas.use()) {
+  for (const auto &delta : transaction.deltas) {
     auto prev = delta.prev.Get();
     MG_ASSERT(prev.type != PreviousPtr::Type::NULLPTR, "Invalid pointer!");
     if (prev.type != PreviousPtr::Type::VERTEX) continue;
@@ -1894,7 +1997,7 @@ bool InMemoryStorage::AppendToWal(const Transaction &transaction, uint64_t final
   };

   // Handle MVCC deltas
-  if (!transaction.deltas.use().empty()) {
+  if (!transaction.deltas.empty()) {
     append_deltas([&](const Delta &delta, const auto &parent, uint64_t timestamp) {
       wal_file_->AppendDelta(delta, parent, timestamp);
       repl_storage_state_.AppendDelta(delta, parent, timestamp);
diff --git a/src/storage/v2/inmemory/storage.hpp b/src/storage/v2/inmemory/storage.hpp
index 6f8806c26..26abe4faf 100644
--- a/src/storage/v2/inmemory/storage.hpp
+++ b/src/storage/v2/inmemory/storage.hpp
@@ -302,6 +302,9 @@ class InMemoryStorage final : public Storage {
   /// @throw std::bad_alloc
   Result<EdgeAccessor> CreateEdgeEx(VertexAccessor *from, VertexAccessor *to, EdgeTypeId edge_type, storage::Gid gid);

+  /// During commit, in some cases you do not need to hand over deltas to GC;
+  /// in those cases this method is a lightweight way to unlink and discard our deltas
+  void FastDiscardOfDeltas(uint64_t oldest_active_timestamp, std::unique_lock<std::mutex> gc_guard);
   SalientConfig::Items config_;
 };

@@ -429,16 +432,15 @@ class InMemoryStorage final : public Storage {
   utils::Scheduler gc_runner_;
   std::mutex gc_lock_;

-  using BondPmrLd = Bond<utils::pmr::list<Delta>>;
   struct GCDeltas {
-    GCDeltas(uint64_t mark_timestamp, BondPmrLd deltas, std::unique_ptr<std::atomic<uint64_t>> commit_timestamp)
+    GCDeltas(uint64_t mark_timestamp, std::deque<Delta> deltas, std::unique_ptr<std::atomic<uint64_t>> commit_timestamp)
         : mark_timestamp_{mark_timestamp}, deltas_{std::move(deltas)}, commit_timestamp_{std::move(commit_timestamp)} {}

     GCDeltas(GCDeltas &&) = default;
     GCDeltas &operator=(GCDeltas &&) = default;

     uint64_t mark_timestamp_{};  //!< a timestamp no active transaction currently has
-    BondPmrLd deltas_;           //!< the deltas that need cleaning
+    std::deque<Delta> deltas_;   //!< the deltas that need cleaning
     std::unique_ptr<std::atomic<uint64_t>> commit_timestamp_{};  //!< the timestamp the deltas are pointing at
   };

diff --git a/src/storage/v2/inmemory/unique_constraints.cpp b/src/storage/v2/inmemory/unique_constraints.cpp
index 667d0229f..e08965eab 100644
--- a/src/storage/v2/inmemory/unique_constraints.cpp
+++ b/src/storage/v2/inmemory/unique_constraints.cpp
@@ -80,7 +80,7 @@ bool LastCommittedVersionHasLabelProperty(const Vertex &vertex, LabelId label, c
       case Delta::Action::SET_PROPERTY: {
         auto pos = FindPropertyPosition(property_array, delta->property.key);
         if (pos) {
-          current_value_equal_to_value[*pos] = delta->property.value == value_array[*pos];
+          current_value_equal_to_value[*pos] = *delta->property.value == value_array[*pos];
         }
         break;
       }
@@ -96,14 +96,14 @@ bool LastCommittedVersionHasLabelProperty(const Vertex &vertex, LabelId label, c
         break;
       }
       case Delta::Action::ADD_LABEL: {
-        if (delta->label == label) {
+        if (delta->label.value == label) {
           MG_ASSERT(!has_label, "Invalid database state!");
           has_label = true;
           break;
         }
       }
       case Delta::Action::REMOVE_LABEL: {
-        if (delta->label == label) {
+        if (delta->label.value == label) {
           MG_ASSERT(has_label, "Invalid database state!");
           has_label = false;
           break;
@@ -190,13 +190,13 @@ bool AnyVersionHasLabelProperty(const Vertex &vertex, LabelId label, const std::
     }
     switch (delta->action) {
       case
Delta::Action::ADD_LABEL: - if (delta->label == label) { + if (delta->label.value == label) { MG_ASSERT(!has_label, "Invalid database state!"); has_label = true; } break; case Delta::Action::REMOVE_LABEL: - if (delta->label == label) { + if (delta->label.value == label) { MG_ASSERT(has_label, "Invalid database state!"); has_label = false; } @@ -204,7 +204,7 @@ bool AnyVersionHasLabelProperty(const Vertex &vertex, LabelId label, const std:: case Delta::Action::SET_PROPERTY: { auto pos = FindPropertyPosition(property_array, delta->property.key); if (pos) { - current_value_equal_to_value[*pos] = delta->property.value == values[*pos]; + current_value_equal_to_value[*pos] = *delta->property.value == values[*pos]; } break; } diff --git a/src/storage/v2/mvcc.hpp b/src/storage/v2/mvcc.hpp index f046a9b01..1cf057d9d 100644 --- a/src/storage/v2/mvcc.hpp +++ b/src/storage/v2/mvcc.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -114,8 +114,8 @@ inline Delta *CreateDeleteObjectDelta(Transaction *transaction) { return nullptr; } transaction->EnsureCommitTimestampExists(); - return &transaction->deltas.use().emplace_back(Delta::DeleteObjectTag(), transaction->commit_timestamp.get(), - transaction->command_id); + return &transaction->deltas.emplace_back(Delta::DeleteObjectTag(), transaction->commit_timestamp.get(), + transaction->command_id); } inline Delta *CreateDeleteObjectDelta(Transaction *transaction, std::list<Delta> *deltas) { @@ -133,19 +133,19 @@ inline Delta *CreateDeleteDeserializedObjectDelta(Transaction *transaction, std: transaction->EnsureCommitTimestampExists(); // Should use utils::DecodeFixed64(ts.c_str()) once we will move to RocksDB real timestamps uint64_t ts_id = utils::ParseStringToUint64(ts); - return &transaction->deltas.use().emplace_back(Delta::DeleteDeserializedObjectTag(), ts_id, old_disk_key); + return &transaction->deltas.emplace_back(Delta::DeleteDeserializedObjectTag(), ts_id, std::move(old_disk_key)); } inline Delta *CreateDeleteDeserializedObjectDelta(std::list<Delta> *deltas, std::optional<std::string> old_disk_key, std::string &&ts) { // Should use utils::DecodeFixed64(ts.c_str()) once we will move to RocksDB real timestamps uint64_t ts_id = utils::ParseStringToUint64(ts); - return &deltas->emplace_back(Delta::DeleteDeserializedObjectTag(), ts_id, old_disk_key); + return &deltas->emplace_back(Delta::DeleteDeserializedObjectTag(), ts_id, std::move(old_disk_key)); } inline Delta *CreateDeleteDeserializedIndexObjectDelta(std::list<Delta> &deltas, std::optional<std::string> old_disk_key, const uint64_t ts) { - return &deltas.emplace_back(Delta::DeleteDeserializedObjectTag(), ts, old_disk_key); + return &deltas.emplace_back(Delta::DeleteDeserializedObjectTag(), ts, std::move(old_disk_key)); } /// TODO: what if in-memory analytical @@ -165,8 +165,8 @@ inline void CreateAndLinkDelta(Transaction *transaction, TObj *object, Args &&.. 
return; } transaction->EnsureCommitTimestampExists(); - auto delta = &transaction->deltas.use().emplace_back(std::forward<Args>(args)..., transaction->commit_timestamp.get(), - transaction->command_id); + auto delta = &transaction->deltas.emplace_back(std::forward<Args>(args)..., transaction->commit_timestamp.get(), + transaction->command_id); // The operations are written in such order so that both `next` and `prev` // chains are valid at all times. The chains must be valid at all times diff --git a/src/storage/v2/property_value.hpp b/src/storage/v2/property_value.hpp index 727c75377..e48be008a 100644 --- a/src/storage/v2/property_value.hpp +++ b/src/storage/v2/property_value.hpp @@ -57,38 +57,24 @@ class PropertyValue { PropertyValue() : type_(Type::Null) {} // constructors for primitive types - explicit PropertyValue(const bool value) : type_(Type::Bool) { bool_v = value; } - explicit PropertyValue(const int value) : type_(Type::Int) { int_v = value; } - explicit PropertyValue(const int64_t value) : type_(Type::Int) { int_v = value; } - explicit PropertyValue(const double value) : type_(Type::Double) { double_v = value; } - explicit PropertyValue(const TemporalData value) : type_{Type::TemporalData} { temporal_data_v = value; } + explicit PropertyValue(const bool value) : bool_v{.val_ = value} {} + explicit PropertyValue(const int value) : int_v{.val_ = value} {} + explicit PropertyValue(const int64_t value) : int_v{.val_ = value} {} + explicit PropertyValue(const double value) : double_v{.val_ = value} {} + explicit PropertyValue(const TemporalData value) : temporal_data_v{.val_ = value} {} // copy constructors for non-primitive types /// @throw std::bad_alloc - explicit PropertyValue(const std::string &value) : type_(Type::String) { new (&string_v) std::string(value); } + explicit PropertyValue(std::string value) : string_v{.val_ = std::move(value)} {} /// @throw std::bad_alloc /// @throw std::length_error if length of value exceeds /// std::string::max_length(). 
- explicit PropertyValue(const char *value) : type_(Type::String) { new (&string_v) std::string(value); } + explicit PropertyValue(std::string_view value) : string_v{.val_ = std::string(value)} {} + explicit PropertyValue(char const *value) : string_v{.val_ = std::string(value)} {} /// @throw std::bad_alloc - explicit PropertyValue(const std::vector<PropertyValue> &value) : type_(Type::List) { - new (&list_v) std::vector<PropertyValue>(value); - } + explicit PropertyValue(std::vector<PropertyValue> value) : list_v{.val_ = std::move(value)} {} /// @throw std::bad_alloc - explicit PropertyValue(const std::map<std::string, PropertyValue> &value) : type_(Type::Map) { - new (&map_v) std::map<std::string, PropertyValue>(value); - } - - // move constructors for non-primitive types - explicit PropertyValue(std::string &&value) noexcept : type_(Type::String) { - new (&string_v) std::string(std::move(value)); - } - explicit PropertyValue(std::vector<PropertyValue> &&value) noexcept : type_(Type::List) { - new (&list_v) std::vector<PropertyValue>(std::move(value)); - } - explicit PropertyValue(std::map<std::string, PropertyValue> &&value) noexcept : type_(Type::Map) { - new (&map_v) std::map<std::string, PropertyValue>(std::move(value)); - } + explicit PropertyValue(std::map<std::string, PropertyValue> value) : map_v{.val_ = std::move(value)} {} // copy constructor /// @throw std::bad_alloc @@ -126,21 +112,21 @@ class PropertyValue { if (type_ != Type::Bool) [[unlikely]] { throw PropertyValueException("The value isn't a bool!"); } - return bool_v; + return bool_v.val_; } /// @throw PropertyValueException if value isn't of correct type. int64_t ValueInt() const { if (type_ != Type::Int) [[unlikely]] { throw PropertyValueException("The value isn't an int!"); } - return int_v; + return int_v.val_; } /// @throw PropertyValueException if value isn't of correct type. double ValueDouble() const { if (type_ != Type::Double) [[unlikely]] { throw PropertyValueException("The value isn't a double!"); } - return double_v; + return double_v.val_; } /// @throw PropertyValueException if value isn't of correct type. @@ -149,7 +135,7 @@ class PropertyValue { throw PropertyValueException("The value isn't a temporal data!"); } - return temporal_data_v; + return temporal_data_v.val_; } // const value getters for non-primitive types @@ -158,7 +144,7 @@ class PropertyValue { if (type_ != Type::String) [[unlikely]] { throw PropertyValueException("The value isn't a string!"); } - return string_v; + return string_v.val_; } /// @throw PropertyValueException if value isn't of correct type. @@ -166,7 +152,7 @@ class PropertyValue { if (type_ != Type::List) [[unlikely]] { throw PropertyValueException("The value isn't a list!"); } - return list_v; + return list_v.val_; } /// @throw PropertyValueException if value isn't of correct type. @@ -174,7 +160,7 @@ class PropertyValue { if (type_ != Type::Map) [[unlikely]] { throw PropertyValueException("The value isn't a map!"); } - return map_v; + return map_v.val_; } // reference value getters for non-primitive types @@ -183,7 +169,7 @@ class PropertyValue { if (type_ != Type::String) [[unlikely]] { throw PropertyValueException("The value isn't a string!"); } - return string_v; + return string_v.val_; } /// @throw PropertyValueException if value isn't of correct type. 
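A note on the constructor changes earlier in this file's diff: the `const&`/`&&` overload pairs collapse into single by-value "sink" parameters that are moved into the member, costing at most one extra move while halving the overload set. A small sketch of the idiom (the `Holder` type is illustrative, not from the patch):

```cpp
#include <string>
#include <utility>

class Holder {
 public:
  explicit Holder(std::string value) : value_{std::move(value)} {}  // sink parameter

 private:
  std::string value_;
};

int main() {
  std::string s = "copied into the parameter, then moved into the member";
  Holder a{s};             // one copy + one move
  Holder b{std::move(s)};  // two moves, no copy
}
```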
@@ -191,7 +177,7 @@ class PropertyValue { if (type_ != Type::List) [[unlikely]] { throw PropertyValueException("The value isn't a list!"); } - return list_v; + return list_v.val_; } /// @throw PropertyValueException if value isn't of correct type. @@ -199,23 +185,45 @@ class PropertyValue { if (type_ != Type::Map) [[unlikely]] { throw PropertyValueException("The value isn't a map!"); } - return map_v; + return map_v.val_; } private: void DestroyValue() noexcept; + // NOTE: this may look strange but it is for better data layout + // https://eel.is/c++draft/class.union#general-note-1 union { - bool bool_v; - int64_t int_v; - double double_v; - std::string string_v; - std::vector<PropertyValue> list_v; - std::map<std::string, PropertyValue> map_v; - TemporalData temporal_data_v; + Type type_; + struct { + Type type_ = Type::Bool; + bool val_; + } bool_v; + struct { + Type type_ = Type::Int; + int64_t val_; + } int_v; + struct { + Type type_ = Type::Double; + double val_; + } double_v; + struct { + Type type_ = Type::String; + std::string val_; + } string_v; + struct { + Type type_ = Type::List; + std::vector<PropertyValue> val_; + } list_v; + struct { + Type type_ = Type::Map; + std::map<std::string, PropertyValue> val_; + } map_v; + struct { + Type type_ = Type::TemporalData; + TemporalData val_; + } temporal_data_v; }; - - Type type_; }; // stream output @@ -340,25 +348,25 @@ inline PropertyValue::PropertyValue(const PropertyValue &other) : type_(other.ty case Type::Null: return; case Type::Bool: - this->bool_v = other.bool_v; + this->bool_v.val_ = other.bool_v.val_; return; case Type::Int: - this->int_v = other.int_v; + this->int_v.val_ = other.int_v.val_; return; case Type::Double: - this->double_v = other.double_v; + this->double_v.val_ = other.double_v.val_; return; case Type::String: - new (&string_v) std::string(other.string_v); + new (&string_v.val_) std::string(other.string_v.val_); return; case Type::List: - new (&list_v) std::vector<PropertyValue>(other.list_v); + new (&list_v.val_) std::vector<PropertyValue>(other.list_v.val_); return; case Type::Map: - new (&map_v) std::map<std::string, PropertyValue>(other.map_v); + new (&map_v.val_) std::map<std::string, PropertyValue>(other.map_v.val_); return; case Type::TemporalData: - this->temporal_data_v = other.temporal_data_v; + this->temporal_data_v.val_ = other.temporal_data_v.val_; return; } } @@ -368,28 +376,28 @@ inline PropertyValue::PropertyValue(PropertyValue &&other) noexcept : type_(std: case Type::Null: break; case Type::Bool: - bool_v = other.bool_v; + bool_v.val_ = other.bool_v.val_; break; case Type::Int: - int_v = other.int_v; + int_v.val_ = other.int_v.val_; break; case Type::Double: - double_v = other.double_v; + double_v.val_ = other.double_v.val_; break; case Type::String: - std::construct_at(&string_v, std::move(other.string_v)); - std::destroy_at(&other.string_v); + std::construct_at(&string_v.val_, std::move(other.string_v.val_)); + std::destroy_at(&other.string_v.val_); break; case Type::List: - std::construct_at(&list_v, std::move(other.list_v)); - std::destroy_at(&other.list_v); + std::construct_at(&list_v.val_, std::move(other.list_v.val_)); + std::destroy_at(&other.list_v.val_); break; case Type::Map: - std::construct_at(&map_v, std::move(other.map_v)); - std::destroy_at(&other.map_v); + std::construct_at(&map_v.val_, std::move(other.map_v.val_)); + std::destroy_at(&other.map_v.val_); break; case Type::TemporalData: - temporal_data_v = other.temporal_data_v; + temporal_data_v.val_ = 
other.temporal_data_v.val_; break; } } @@ -404,25 +412,25 @@ inline PropertyValue &PropertyValue::operator=(const PropertyValue &other) { case Type::Null: break; case Type::Bool: - this->bool_v = other.bool_v; + this->bool_v.val_ = other.bool_v.val_; break; case Type::Int: - this->int_v = other.int_v; + this->int_v.val_ = other.int_v.val_; break; case Type::Double: - this->double_v = other.double_v; + this->double_v.val_ = other.double_v.val_; break; case Type::String: - new (&string_v) std::string(other.string_v); + new (&string_v.val_) std::string(other.string_v.val_); break; case Type::List: - new (&list_v) std::vector<PropertyValue>(other.list_v); + new (&list_v.val_) std::vector<PropertyValue>(other.list_v.val_); break; case Type::Map: - new (&map_v) std::map<std::string, PropertyValue>(other.map_v); + new (&map_v.val_) std::map<std::string, PropertyValue>(other.map_v.val_); break; case Type::TemporalData: - this->temporal_data_v = other.temporal_data_v; + this->temporal_data_v.val_ = other.temporal_data_v.val_; break; } @@ -438,28 +446,28 @@ inline PropertyValue &PropertyValue::operator=(PropertyValue &&other) noexcept { case Type::Null: break; case Type::Bool: - bool_v = other.bool_v; + bool_v.val_ = other.bool_v.val_; break; case Type::Int: - int_v = other.int_v; + int_v.val_ = other.int_v.val_; break; case Type::Double: - double_v = other.double_v; + double_v.val_ = other.double_v.val_; break; case Type::String: - string_v = std::move(other.string_v); - std::destroy_at(&other.string_v); + string_v.val_ = std::move(other.string_v.val_); + std::destroy_at(&other.string_v.val_); break; case Type::List: - list_v = std::move(other.list_v); - std::destroy_at(&other.list_v); + list_v.val_ = std::move(other.list_v.val_); + std::destroy_at(&other.list_v.val_); break; case Type::Map: - map_v = std::move(other.map_v); - std::destroy_at(&other.map_v); + map_v.val_ = std::move(other.map_v.val_); + std::destroy_at(&other.map_v.val_); break; case Type::TemporalData: - temporal_data_v = other.temporal_data_v; + temporal_data_v.val_ = other.temporal_data_v.val_; break; } other.type_ = Type::Null; @@ -482,13 +490,13 @@ inline void PropertyValue::DestroyValue() noexcept { // destructor for non primitive types since we used placement new case Type::String: - std::destroy_at(&string_v); + std::destroy_at(&string_v.val_); return; case Type::List: - std::destroy_at(&list_v); + std::destroy_at(&list_v.val_); return; case Type::Map: - std::destroy_at(&map_v); + std::destroy_at(&map_v.val_); return; } } diff --git a/src/storage/v2/transaction.hpp b/src/storage/v2/transaction.hpp index 2bdd68a94..9f973cbf0 100644 --- a/src/storage/v2/transaction.hpp +++ b/src/storage/v2/transaction.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
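The move and copy special members above follow a strict manual-lifetime protocol for the non-trivial union alternatives: construct the destination member in place, then explicitly end the lifetime of the source member before the source is retagged as Null. A condensed sketch of just that protocol, using a hypothetical Slot union rather than PropertyValue itself:

#include <memory>
#include <string>
#include <utility>

// A union with a non-trivial member must provide its own special functions;
// which member is alive, and when it dies, is managed entirely from outside.
union Slot {
  Slot() : dummy_{} {}
  ~Slot() {}  // intentionally empty: lifetimes are handled manually
  char dummy_;
  std::string str_;
};

int main() {
  Slot src;
  std::construct_at(&src.str_, "payload");  // begin lifetime of src.str_

  Slot dst;
  std::construct_at(&dst.str_, std::move(src.str_));  // move-construct in place
  std::destroy_at(&src.str_);  // end the moved-from member's lifetime

  std::destroy_at(&dst.str_);  // clean up before dst goes out of scope
}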
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -13,7 +13,6 @@ #include <atomic> #include <limits> -#include <list> #include <memory> #include "utils/memory.hpp" @@ -39,7 +38,6 @@ namespace memgraph::storage { const uint64_t kTimestampInitialId = 0; const uint64_t kTransactionInitialId = 1ULL << 63U; -using PmrListDelta = utils::pmr::list<Delta>; struct Transaction { Transaction(uint64_t transaction_id, uint64_t start_timestamp, IsolationLevel isolation_level, @@ -47,7 +45,6 @@ struct Transaction { : transaction_id(transaction_id), start_timestamp(start_timestamp), command_id(0), - deltas(0), md_deltas(utils::NewDeleteResource()), must_abort(false), isolation_level(isolation_level), @@ -91,7 +88,7 @@ struct Transaction { std::unique_ptr<std::atomic<uint64_t>> commit_timestamp{}; uint64_t command_id{}; - Bond<PmrListDelta> deltas; + std::deque<Delta> deltas; utils::pmr::list<MetadataDelta> md_deltas; bool must_abort{}; IsolationLevel isolation_level{}; diff --git a/src/storage/v2/vertex_info_helpers.hpp b/src/storage/v2/vertex_info_helpers.hpp index 27ecb8398..7c8e2a652 100644 --- a/src/storage/v2/vertex_info_helpers.hpp +++ b/src/storage/v2/vertex_info_helpers.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -76,13 +76,13 @@ inline auto HasLabel_ActionMethod(bool &has_label, LabelId label) { // clang-format off return utils::Overloaded{ ActionMethod<REMOVE_LABEL>([&, label](Delta const &delta) { - if (delta.label == label) { + if (delta.label.value == label) { MG_ASSERT(has_label, "Invalid database state!"); has_label = false; } }), ActionMethod<ADD_LABEL>([&, label](Delta const &delta) { - if (delta.label == label) { + if (delta.label.value == label) { MG_ASSERT(!has_label, "Invalid database state!"); has_label = true; } @@ -96,14 +96,14 @@ inline auto Labels_ActionMethod(std::vector<LabelId> &labels) { // clang-format off return utils::Overloaded{ ActionMethod<REMOVE_LABEL>([&](Delta const &delta) { - auto it = std::find(labels.begin(), labels.end(), delta.label); + auto it = std::find(labels.begin(), labels.end(), delta.label.value); DMG_ASSERT(it != labels.end(), "Invalid database state!"); *it = labels.back(); labels.pop_back(); }), ActionMethod<ADD_LABEL>([&](Delta const &delta) { - DMG_ASSERT(std::find(labels.begin(), labels.end(), delta.label) == labels.end(), "Invalid database state!"); - labels.emplace_back(delta.label); + DMG_ASSERT(std::find(labels.begin(), labels.end(), delta.label.value) == labels.end(), "Invalid database state!"); + labels.emplace_back(delta.label.value); }) }; // clang-format on @@ -113,7 +113,7 @@ inline auto PropertyValue_ActionMethod(PropertyValue &value, PropertyId property using enum Delta::Action; return ActionMethod<SET_PROPERTY>([&, property](Delta const &delta) { if (delta.property.key == property) { - value = delta.property.value; + value = *delta.property.value; } }); } @@ -121,7 +121,7 @@ inline auto PropertyValue_ActionMethod(PropertyValue &value, PropertyId property inline auto PropertyValueMatch_ActionMethod(bool &match, PropertyId property, PropertyValue const &value) { using enum Delta::Action; return ActionMethod<SET_PROPERTY>([&, property](Delta const &delta) { - if 
(delta.property.key == property) match = (value == delta.property.value); + if (delta.property.key == property) match = (value == *delta.property.value); }); } @@ -130,15 +130,15 @@ inline auto Properties_ActionMethod(std::map<PropertyId, PropertyValue> &propert return ActionMethod<SET_PROPERTY>([&](Delta const &delta) { auto it = properties.find(delta.property.key); if (it != properties.end()) { - if (delta.property.value.IsNull()) { + if (delta.property.value->IsNull()) { // remove the property properties.erase(it); } else { // set the value - it->second = delta.property.value; + it->second = *delta.property.value; } - } else if (!delta.property.value.IsNull()) { - properties.emplace(delta.property.key, delta.property.value); + } else if (!delta.property.value->IsNull()) { + properties.emplace(delta.property.key, *delta.property.value); } }); } diff --git a/src/utils/disk_utils.hpp b/src/utils/disk_utils.hpp index 34bc704f6..c4b9accd6 100644 --- a/src/utils/disk_utils.hpp +++ b/src/utils/disk_utils.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -21,7 +21,7 @@ inline std::optional<std::string> GetOldDiskKeyOrNull(storage::Delta *head) { head = head->next; } if (head->action == storage::Delta::Action::DELETE_DESERIALIZED_OBJECT) { - return head->old_disk_key; + return head->old_disk_key.value; } return std::nullopt; } diff --git a/src/utils/string.hpp b/src/utils/string.hpp index 8593fc57f..e5c4c4f3c 100644 --- a/src/utils/string.hpp +++ b/src/utils/string.hpp @@ -338,6 +338,13 @@ inline uint64_t ParseStringToUint64(const std::string_view s) { throw utils::ParseException(s); } +inline uint32_t ParseStringToUint32(const std::string_view s) { + if (uint32_t value = 0; std::from_chars(s.data(), s.data() + s.size(), value).ec == std::errc{}) { + return value; + } + throw utils::ParseException(s); +} + /** * Parse a double floating point value from a string using classic locale. * Note, the current implementation copies the given string which may perform a diff --git a/tests/unit/storage_v2_property_store.cpp b/tests/unit/storage_v2_property_store.cpp index 59b38c632..683146f2d 100644 --- a/tests/unit/storage_v2_property_store.cpp +++ b/tests/unit/storage_v2_property_store.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
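ParseStringToUint32 above mirrors the existing ParseStringToUint64: std::from_chars signals success through the returned error code, and overflow surfaces as std::errc::result_out_of_range rather than silently wrapping. A small standalone check of that contract (the full-consumption test on ptr is extra strictness this sketch adds; the helper itself accepts trailing characters):

#include <charconv>
#include <cstdint>
#include <iostream>
#include <string_view>

// Returns true and fills out only when the entire string is a valid uint32_t.
bool TryParseUint32(std::string_view s, uint32_t &out) {
  const auto [ptr, ec] = std::from_chars(s.data(), s.data() + s.size(), out);
  return ec == std::errc{} && ptr == s.data() + s.size();
}

int main() {
  uint32_t v = 0;
  std::cout << TryParseUint32("4294967295", v) << '\n';  // 1 (fits: uint32_t max)
  std::cout << TryParseUint32("4294967296", v) << '\n';  // 0 (result_out_of_range)
  std::cout << TryParseUint32("123abc", v) << '\n';      // 0 (trailing junk)
}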
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -419,9 +419,9 @@ TEST(PropertyStore, IntEncoding) { {memgraph::storage::PropertyId::FromUint(1048576UL), memgraph::storage::PropertyValue(1048576L)}, {memgraph::storage::PropertyId::FromUint(std::numeric_limits<uint32_t>::max()), memgraph::storage::PropertyValue(std::numeric_limits<int32_t>::max())}, - {memgraph::storage::PropertyId::FromUint(4294967296UL), memgraph::storage::PropertyValue(4294967296L)}, - {memgraph::storage::PropertyId::FromUint(137438953472UL), memgraph::storage::PropertyValue(137438953472L)}, - {memgraph::storage::PropertyId::FromUint(std::numeric_limits<uint64_t>::max()), + {memgraph::storage::PropertyId::FromUint(1048577UL), memgraph::storage::PropertyValue(4294967296L)}, + {memgraph::storage::PropertyId::FromUint(1048578UL), memgraph::storage::PropertyValue(137438953472L)}, + {memgraph::storage::PropertyId::FromUint(std::numeric_limits<uint32_t>::max()), memgraph::storage::PropertyValue(std::numeric_limits<int64_t>::max())}}; memgraph::storage::PropertyStore props; diff --git a/tests/unit/storage_v2_wal_file.cpp b/tests/unit/storage_v2_wal_file.cpp index a67b09305..07a35d754 100644 --- a/tests/unit/storage_v2_wal_file.cpp +++ b/tests/unit/storage_v2_wal_file.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -154,8 +154,8 @@ class DeltaGenerator final { void Finalize(bool append_transaction_end = true) { auto commit_timestamp = gen_->timestamp_++; - if (transaction_.deltas.use().empty()) return; - for (const auto &delta : transaction_.deltas.use()) { + if (transaction_.deltas.empty()) return; + for (const auto &delta : transaction_.deltas) { auto owner = delta.prev.Get(); while (owner.type == memgraph::storage::PreviousPtr::Type::DELTA) { owner = owner.delta->prev.Get(); @@ -171,7 +171,7 @@ class DeltaGenerator final { if (append_transaction_end) { gen_->wal_file_.AppendTransactionEnd(commit_timestamp); if (gen_->valid_) { - gen_->UpdateStats(commit_timestamp, transaction_.deltas.use().size() + 1); + gen_->UpdateStats(commit_timestamp, transaction_.deltas.size() + 1); for (auto &data : data_) { if (data.type == memgraph::storage::durability::WalDeltaData::Type::VERTEX_SET_PROPERTY) { // We need to put the final property value into the SET_PROPERTY From c15b62a88d2983122cd19531d3a09b376bbbaea3 Mon Sep 17 00:00:00 2001 From: Antonio Filipovic <61245998+antoniofilipovic@users.noreply.github.com> Date: Wed, 7 Feb 2024 11:20:47 +0100 Subject: [PATCH 2/4] HA: Disable replication from old main (#1674) --- src/auth/auth.cpp | 15 +- src/auth/replication_handlers.cpp | 32 +- src/auth/replication_handlers.hpp | 10 +- src/auth/rpc.cpp | 4 + src/auth/rpc.hpp | 20 +- src/coordination/coordinator_client.cpp | 24 +- src/coordination/coordinator_data.cpp | 113 ++++- src/coordination/coordinator_handlers.cpp | 41 +- src/coordination/coordinator_instance.cpp | 18 +- src/coordination/coordinator_rpc.cpp | 2 + src/coordination/coordinator_server.cpp | 2 +- .../coordination/coordinator_client.hpp | 10 +- .../include/coordination/coordinator_data.hpp | 8 +- .../coordination/coordinator_handlers.hpp | 2 + .../coordination/coordinator_instance.hpp | 17 +- 
.../include/coordination/coordinator_rpc.hpp | 17 +- ...gister_main_replica_coordinator_status.hpp | 1 + src/dbms/dbms_handler.cpp | 16 +- src/dbms/dbms_handler.hpp | 11 + src/dbms/inmemory/replication_handlers.cpp | 155 ++++-- src/dbms/inmemory/replication_handlers.hpp | 24 +- src/dbms/replication_client.cpp | 0 src/dbms/replication_handler.cpp | 0 src/dbms/replication_handler.hpp | 0 src/dbms/replication_handlers.cpp | 26 +- src/dbms/replication_handlers.hpp | 14 +- src/dbms/rpc.hpp | 16 +- src/dbms/utils.hpp | 0 src/query/interpreter.cpp | 8 +- src/query/replication_query_handler.hpp | 16 +- src/replication/CMakeLists.txt | 2 +- .../include/replication/messages.hpp | 0 .../replication/replication_client.hpp | 5 +- src/replication/include/replication/state.hpp | 22 +- .../include/replication/status.hpp | 7 +- src/replication/messages.cpp | 0 src/replication/replication_server.cpp | 2 +- src/replication/state.cpp | 74 ++- src/replication/status.cpp | 42 +- .../CMakeLists.txt | 1 + src/replication_coordination_glue/handler.hpp | 41 ++ .../messages.cpp | 39 +- .../messages.hpp | 36 +- src/replication_handler/CMakeLists.txt | 2 +- .../replication_handler.hpp | 48 +- .../system_replication.hpp | 18 +- .../replication_handler/system_rpc.hpp | 11 +- .../replication_handler.cpp | 81 +-- .../system_replication.cpp | 37 +- src/replication_handler/system_rpc.cpp | 10 +- src/rpc/client.hpp | 1 - src/rpc/version.hpp | 7 +- .../v2/inmemory/replication/recovery.cpp | 22 +- .../v2/inmemory/replication/recovery.hpp | 7 +- src/storage/v2/inmemory/storage.cpp | 1 + .../v2/replication/replication_client.cpp | 86 ++-- .../v2/replication/replication_client.hpp | 8 +- src/storage/v2/replication/rpc.cpp | 12 + src/storage/v2/replication/rpc.hpp | 25 +- src/system/action.cpp | 2 +- src/system/include/system/action.hpp | 2 +- src/system/include/system/transaction.hpp | 2 +- src/utils/typeinfo.hpp | 4 + .../CMakeLists.txt | 1 + .../automatic_failover.py | 85 ++-- .../not_replicate_from_old_main.py | 117 +++++ .../workloads.yaml | 4 + tests/e2e/interactive_mg_runner.py | 5 + tests/unit/replication_persistence_helper.cpp | 39 ++ tests/unit/storage_v2_replication.cpp | 462 ++++++++++-------- 70 files changed, 1419 insertions(+), 573 deletions(-) create mode 100644 src/dbms/replication_client.cpp create mode 100644 src/dbms/replication_handler.cpp create mode 100644 src/dbms/replication_handler.hpp create mode 100644 src/dbms/utils.hpp create mode 100644 src/replication/include/replication/messages.hpp create mode 100644 src/replication/messages.cpp create mode 100644 src/replication_coordination_glue/handler.hpp create mode 100644 tests/e2e/high_availability_experimental/not_replicate_from_old_main.py diff --git a/src/auth/auth.cpp b/src/auth/auth.cpp index 405c04c45..3bb7648db 100644 --- a/src/auth/auth.cpp +++ b/src/auth/auth.cpp @@ -57,16 +57,19 @@ struct UpdateAuthData : memgraph::system::ISystemAction { void DoDurability() override { /* Done during Auth execution */ } - bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch, + bool DoReplication(replication::ReplicationClient &client, const utils::UUID &main_uuid, + replication::ReplicationEpoch const &epoch, memgraph::system::Transaction const &txn) const override { auto check_response = [](const replication::UpdateAuthDataRes &response) { return response.success; }; if (user_) { return client.SteamAndFinalizeDelta<replication::UpdateAuthDataRpc>( - check_response, std::string{epoch.id()}, 
txn.last_committed_system_timestamp(), txn.timestamp(), *user_); + check_response, main_uuid, std::string{epoch.id()}, txn.last_committed_system_timestamp(), txn.timestamp(), + *user_); } if (role_) { return client.SteamAndFinalizeDelta<replication::UpdateAuthDataRpc>( - check_response, std::string{epoch.id()}, txn.last_committed_system_timestamp(), txn.timestamp(), *role_); + check_response, main_uuid, std::string{epoch.id()}, txn.last_committed_system_timestamp(), txn.timestamp(), + *role_); } // Should never get here MG_ASSERT(false, "Trying to update auth data that is not a user nor a role"); @@ -88,7 +91,8 @@ struct DropAuthData : memgraph::system::ISystemAction { void DoDurability() override { /* Done during Auth execution */ } - bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch, + bool DoReplication(replication::ReplicationClient &client, const utils::UUID &main_uuid, + replication::ReplicationEpoch const &epoch, memgraph::system::Transaction const &txn) const override { auto check_response = [](const replication::DropAuthDataRes &response) { return response.success; }; @@ -102,7 +106,8 @@ struct DropAuthData : memgraph::system::ISystemAction { break; } return client.SteamAndFinalizeDelta<replication::DropAuthDataRpc>( - check_response, std::string{epoch.id()}, txn.last_committed_system_timestamp(), txn.timestamp(), type, name_); + check_response, main_uuid, std::string{epoch.id()}, txn.last_committed_system_timestamp(), txn.timestamp(), + type, name_); } void PostReplication(replication::RoleMainData &mainData) const override {} diff --git a/src/auth/replication_handlers.cpp b/src/auth/replication_handlers.cpp index 8ee0cd7f3..9cd0c6a24 100644 --- a/src/auth/replication_handlers.cpp +++ b/src/auth/replication_handlers.cpp @@ -17,8 +17,15 @@ namespace memgraph::auth { +void LogWrongMain(const std::optional<utils::UUID> &current_main_uuid, const utils::UUID &main_req_id, + std::string_view rpc_req) { + spdlog::error(fmt::format("Received {} with main_id: {} != current_main_uuid: {}", rpc_req, std::string(main_req_id), + current_main_uuid.has_value() ? std::string(current_main_uuid.value()) : "")); +} + #ifdef MG_ENTERPRISE -void UpdateAuthDataHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth, +void UpdateAuthDataHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, + const std::optional<utils::UUID> &current_main_uuid, auth::SynchedAuth &auth, slk::Reader *req_reader, slk::Builder *res_builder) { replication::UpdateAuthDataReq req; memgraph::slk::Load(&req, req_reader); @@ -26,6 +33,12 @@ void UpdateAuthDataHandler(memgraph::system::ReplicaHandlerAccessToState &system using memgraph::replication::UpdateAuthDataRes; UpdateAuthDataRes res(false); + if (!current_main_uuid.has_value() || req.main_uuid != current_main_uuid) [[unlikely]] { + LogWrongMain(current_main_uuid, req.main_uuid, replication::UpdateAuthDataReq::kType.name); + memgraph::slk::Save(res, res_builder); + return; + } + // Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot // of the set of databases. Hence no history exists to maintain regarding epoch change. 
// If MAIN has changed we need to check this new group_timestamp is consistent with @@ -53,7 +66,8 @@ void UpdateAuthDataHandler(memgraph::system::ReplicaHandlerAccessToState &system memgraph::slk::Save(res, res_builder); } -void DropAuthDataHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth, +void DropAuthDataHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, + const std::optional<utils::UUID> &current_main_uuid, auth::SynchedAuth &auth, slk::Reader *req_reader, slk::Builder *res_builder) { replication::DropAuthDataReq req; memgraph::slk::Load(&req, req_reader); @@ -61,6 +75,12 @@ void DropAuthDataHandler(memgraph::system::ReplicaHandlerAccessToState &system_s using memgraph::replication::DropAuthDataRes; DropAuthDataRes res(false); + if (!current_main_uuid.has_value() || req.main_uuid != current_main_uuid) [[unlikely]] { + LogWrongMain(current_main_uuid, req.main_uuid, replication::DropAuthDataRes::kType.name); + memgraph::slk::Save(res, res_builder); + return; + } + // Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot // of the set of databases. Hence no history exists to maintain regarding epoch change. // If MAIN has changed we need to check this new group_timestamp is consistent with @@ -155,14 +175,14 @@ void Register(replication::RoleReplicaData const &data, system::ReplicaHandlerAc auth::SynchedAuth &auth) { // NOTE: Register even without license as the user could add a license at run-time data.server->rpc_server_.Register<replication::UpdateAuthDataRpc>( - [system_state_access, &auth](auto *req_reader, auto *res_builder) mutable { + [&data, system_state_access, &auth](auto *req_reader, auto *res_builder) mutable { spdlog::debug("Received UpdateAuthDataRpc"); - UpdateAuthDataHandler(system_state_access, auth, req_reader, res_builder); + UpdateAuthDataHandler(system_state_access, data.uuid_, auth, req_reader, res_builder); }); data.server->rpc_server_.Register<replication::DropAuthDataRpc>( - [system_state_access, &auth](auto *req_reader, auto *res_builder) mutable { + [&data, system_state_access, &auth](auto *req_reader, auto *res_builder) mutable { spdlog::debug("Received DropAuthDataRpc"); - DropAuthDataHandler(system_state_access, auth, req_reader, res_builder); + DropAuthDataHandler(system_state_access, data.uuid_, auth, req_reader, res_builder); }); } #endif diff --git a/src/auth/replication_handlers.hpp b/src/auth/replication_handlers.hpp index 0d46e957f..38ec5277f 100644 --- a/src/auth/replication_handlers.hpp +++ b/src/auth/replication_handlers.hpp @@ -17,10 +17,16 @@ #include "system/state.hpp" namespace memgraph::auth { + +void LogWrongMain(const std::optional<utils::UUID> &current_main_uuid, const utils::UUID &main_req_id, + std::string_view rpc_req); + #ifdef MG_ENTERPRISE -void UpdateAuthDataHandler(system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth, +void UpdateAuthDataHandler(system::ReplicaHandlerAccessToState &system_state_access, + const std::optional<utils::UUID> &current_main_uuid, auth::SynchedAuth &auth, slk::Reader *req_reader, slk::Builder *res_builder); -void DropAuthDataHandler(system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth, +void DropAuthDataHandler(system::ReplicaHandlerAccessToState &system_state_access, + const std::optional<utils::UUID> &current_main_uuid, auth::SynchedAuth &auth, slk::Reader *req_reader, slk::Builder *res_builder); bool SystemRecoveryHandler(auth::SynchedAuth &auth, auth::Auth::Config 
auth_config, diff --git a/src/auth/rpc.cpp b/src/auth/rpc.cpp index f1d09eb01..b658c9491 100644 --- a/src/auth/rpc.cpp +++ b/src/auth/rpc.cpp @@ -89,6 +89,7 @@ void Load(auth::Auth::Config *self, memgraph::slk::Reader *reader) { // Serialize code for UpdateAuthDataReq void Save(const memgraph::replication::UpdateAuthDataReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.main_uuid, builder); memgraph::slk::Save(self.epoch_id, builder); memgraph::slk::Save(self.expected_group_timestamp, builder); memgraph::slk::Save(self.new_group_timestamp, builder); @@ -96,6 +97,7 @@ void Save(const memgraph::replication::UpdateAuthDataReq &self, memgraph::slk::B memgraph::slk::Save(self.role, builder); } void Load(memgraph::replication::UpdateAuthDataReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->main_uuid, reader); memgraph::slk::Load(&self->epoch_id, reader); memgraph::slk::Load(&self->expected_group_timestamp, reader); memgraph::slk::Load(&self->new_group_timestamp, reader); @@ -113,6 +115,7 @@ void Load(memgraph::replication::UpdateAuthDataRes *self, memgraph::slk::Reader // Serialize code for DropAuthDataReq void Save(const memgraph::replication::DropAuthDataReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.main_uuid, builder); memgraph::slk::Save(self.epoch_id, builder); memgraph::slk::Save(self.expected_group_timestamp, builder); memgraph::slk::Save(self.new_group_timestamp, builder); @@ -120,6 +123,7 @@ void Save(const memgraph::replication::DropAuthDataReq &self, memgraph::slk::Bui memgraph::slk::Save(self.name, builder); } void Load(memgraph::replication::DropAuthDataReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->main_uuid, reader); memgraph::slk::Load(&self->epoch_id, reader); memgraph::slk::Load(&self->expected_group_timestamp, reader); memgraph::slk::Load(&self->new_group_timestamp, reader); diff --git a/src/auth/rpc.hpp b/src/auth/rpc.hpp index 55bd403c7..a4ce653b8 100644 --- a/src/auth/rpc.hpp +++ b/src/auth/rpc.hpp @@ -27,17 +27,22 @@ struct UpdateAuthDataReq { static void Load(UpdateAuthDataReq *self, memgraph::slk::Reader *reader); static void Save(const UpdateAuthDataReq &self, memgraph::slk::Builder *builder); UpdateAuthDataReq() = default; - UpdateAuthDataReq(std::string epoch_id, uint64_t expected_ts, uint64_t new_ts, auth::User user) - : epoch_id{std::move(epoch_id)}, + UpdateAuthDataReq(const utils::UUID &main_uuid, std::string epoch_id, uint64_t expected_ts, uint64_t new_ts, + auth::User user) + : main_uuid(main_uuid), + epoch_id{std::move(epoch_id)}, expected_group_timestamp{expected_ts}, new_group_timestamp{new_ts}, user{std::move(user)} {} - UpdateAuthDataReq(std::string epoch_id, uint64_t expected_ts, uint64_t new_ts, auth::Role role) - : epoch_id{std::move(epoch_id)}, + UpdateAuthDataReq(const utils::UUID &main_uuid, std::string epoch_id, uint64_t expected_ts, uint64_t new_ts, + auth::Role role) + : main_uuid(main_uuid), + epoch_id{std::move(epoch_id)}, expected_group_timestamp{expected_ts}, new_group_timestamp{new_ts}, role{std::move(role)} {} + utils::UUID main_uuid; std::string epoch_id; uint64_t expected_group_timestamp; uint64_t new_group_timestamp; @@ -69,13 +74,16 @@ struct DropAuthDataReq { enum class DataType { USER, ROLE }; - DropAuthDataReq(std::string epoch_id, uint64_t expected_ts, uint64_t new_ts, DataType type, std::string_view name) - : epoch_id{std::move(epoch_id)}, + DropAuthDataReq(const utils::UUID &main_uuid, std::string epoch_id, uint64_t expected_ts, 
uint64_t new_ts, + DataType type, std::string_view name) + : main_uuid(main_uuid), + epoch_id{std::move(epoch_id)}, expected_group_timestamp{expected_ts}, new_group_timestamp{new_ts}, type{type}, name{name} {} + utils::UUID main_uuid; std::string epoch_id; uint64_t expected_group_timestamp; uint64_t new_group_timestamp; diff --git a/src/coordination/coordinator_client.cpp b/src/coordination/coordinator_client.cpp index db8d692f6..ce2cb0bda 100644 --- a/src/coordination/coordinator_client.cpp +++ b/src/coordination/coordinator_client.cpp @@ -9,6 +9,7 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. +#include "utils/uuid.hpp" #ifdef MG_ENTERPRISE #include "coordination/coordinator_client.hpp" @@ -71,16 +72,17 @@ auto CoordinatorClient::SetCallbacks(HealthCheckCallback succ_cb, HealthCheckCal auto CoordinatorClient::ReplicationClientInfo() const -> ReplClientInfo { return config_.replication_client_info; } -auto CoordinatorClient::SendPromoteReplicaToMainRpc(ReplicationClientsInfo replication_clients_info) const -> bool { +auto CoordinatorClient::SendPromoteReplicaToMainRpc(const utils::UUID &uuid, + ReplicationClientsInfo replication_clients_info) const -> bool { try { - auto stream{rpc_client_.Stream<PromoteReplicaToMainRpc>(std::move(replication_clients_info))}; + auto stream{rpc_client_.Stream<PromoteReplicaToMainRpc>(uuid, std::move(replication_clients_info))}; if (!stream.AwaitResponse().success) { - spdlog::error("Failed to receive successful RPC failover response!"); + spdlog::error("Failed to receive successful PromoteReplicaToMainRpc response!"); return false; } return true; } catch (rpc::RpcFailedException const &) { - spdlog::error("RPC error occurred while sending failover RPC!"); + spdlog::error("RPC error occurred while sending PromoteReplicaToMainRpc!"); } return false; } @@ -101,5 +103,19 @@ auto CoordinatorClient::DemoteToReplica() const -> bool { return false; } +auto CoordinatorClient::SendSwapMainUUIDRpc(const utils::UUID &uuid) const -> bool { + try { + auto stream{rpc_client_.Stream<replication_coordination_glue::SwapMainUUIDRpc>(uuid)}; + if (!stream.AwaitResponse().success) { + spdlog::error("Failed to receive successful RPC swapping of uuid response!"); + return false; + } + return true; + } catch (const rpc::RpcFailedException &) { + spdlog::error("RPC error occurred while sending swapping uuid RPC!"); + } + return false; +} + } // namespace memgraph::coordination #endif diff --git a/src/coordination/coordinator_data.cpp b/src/coordination/coordinator_data.cpp index 856c3e84d..3eb251003 100644 --- a/src/coordination/coordinator_data.cpp +++ b/src/coordination/coordinator_data.cpp @@ -11,6 +11,7 @@ #include "coordination/coordinator_instance.hpp" #include "coordination/register_main_replica_coordinator_status.hpp" +#include "utils/uuid.hpp" #ifdef MG_ENTERPRISE #include "coordination/coordinator_data.hpp" @@ -32,60 +33,94 @@ CoordinatorData::CoordinatorData() { return *instance; }; - replica_succ_cb_ = [find_instance](CoordinatorData *coord_data, std::string_view instance_name) -> void { + replica_succ_cb_ = [this, find_instance](CoordinatorData *coord_data, std::string_view instance_name) -> void { auto lock = std::lock_guard{coord_data->coord_data_lock_}; spdlog::trace("Instance {} performing replica successful callback", instance_name); - find_instance(coord_data, instance_name).OnSuccessPing(); + auto &instance = find_instance(coord_data, instance_name); + + if (!instance.GetMainUUID().has_value() || main_uuid_ != 
instance.GetMainUUID().value()) { + if (!instance.SendSwapAndUpdateUUID(main_uuid_)) { + spdlog::error( + fmt::format("Failed to swap uuid for replica instance {} which is alive", instance.InstanceName())); + return; + } + } + + instance.OnSuccessPing(); }; replica_fail_cb_ = [find_instance](CoordinatorData *coord_data, std::string_view instance_name) -> void { auto lock = std::lock_guard{coord_data->coord_data_lock_}; spdlog::trace("Instance {} performing replica failure callback", instance_name); - find_instance(coord_data, instance_name).OnFailPing(); + auto &instance = find_instance(coord_data, instance_name); + instance.OnFailPing(); + // We need to restart main uuid from instance since it was "down" at least a second + // There is slight delay, if we choose to use isAlive, instance can be down and back up in less than + // our isAlive time difference, which would lead to instance setting UUID to nullopt and stopping accepting any + // incoming RPCs from valid main + // TODO(antoniofilipovic) this needs here more complex logic + // We need to get id of main replica is listening to on successful ping + // and swap it to correct uuid if it failed + instance.SetNewMainUUID(); }; - main_succ_cb_ = [find_instance](CoordinatorData *coord_data, std::string_view instance_name) -> void { + main_succ_cb_ = [this, find_instance](CoordinatorData *coord_data, std::string_view instance_name) -> void { auto lock = std::lock_guard{coord_data->coord_data_lock_}; spdlog::trace("Instance {} performing main successful callback", instance_name); auto &instance = find_instance(coord_data, instance_name); - if (instance.IsAlive() || !coord_data->ClusterHasAliveMain_()) { + const auto &instance_uuid = instance.GetMainUUID(); + MG_ASSERT(instance_uuid.has_value(), "Instance must have uuid set"); + if (main_uuid_ == instance_uuid.value()) { instance.OnSuccessPing(); return; } + // TODO(antoniof) make demoteToReplica idempotent since main can be demoted to replica but + // swapUUID can fail bool const demoted = instance.DemoteToReplica(coord_data->replica_succ_cb_, coord_data->replica_fail_cb_); if (demoted) { instance.OnSuccessPing(); spdlog::info("Instance {} demoted to replica", instance_name); } else { spdlog::error("Instance {} failed to become replica", instance_name); + return; + } + + if (!instance.SendSwapAndUpdateUUID(main_uuid_)) { + spdlog::error(fmt::format("Failed to swap uuid for demoted main instance {}", instance.InstanceName())); + return; } }; - main_fail_cb_ = [find_instance](CoordinatorData *coord_data, std::string_view instance_name) -> void { + main_fail_cb_ = [this, find_instance](CoordinatorData *coord_data, std::string_view instance_name) -> void { auto lock = std::lock_guard{coord_data->coord_data_lock_}; spdlog::trace("Instance {} performing main failure callback", instance_name); - find_instance(coord_data, instance_name).OnFailPing(); + auto &instance = find_instance(coord_data, instance_name); + instance.OnFailPing(); + const auto &instance_uuid = instance.GetMainUUID(); + MG_ASSERT(instance_uuid.has_value(), "Instance must have uuid set"); - if (!coord_data->ClusterHasAliveMain_()) { + if (!instance.IsAlive() && main_uuid_ == instance_uuid.value()) { spdlog::info("Cluster without main instance, trying automatic failover"); coord_data->TryFailover(); } }; } -auto CoordinatorData::ClusterHasAliveMain_() const -> bool { - auto const alive_main = [](CoordinatorInstance const &instance) { return instance.IsMain() && instance.IsAlive(); }; - return 
std::ranges::any_of(registered_instances_, alive_main); -} - auto CoordinatorData::TryFailover() -> void { - auto replica_instances = registered_instances_ | ranges::views::filter(&CoordinatorInstance::IsReplica); + std::vector<CoordinatorInstance *> alive_registered_replica_instances{}; + std::ranges::transform(registered_instances_ | ranges::views::filter(&CoordinatorInstance::IsReplica) | + ranges::views::filter(&CoordinatorInstance::IsAlive), + std::back_inserter(alive_registered_replica_instances), + [](CoordinatorInstance &instance) { return &instance; }); - auto chosen_replica_instance = std::ranges::find_if(replica_instances, &CoordinatorInstance::IsAlive); - if (chosen_replica_instance == replica_instances.end()) { + // TODO(antoniof) more complex logic of choosing replica instance + CoordinatorInstance *chosen_replica_instance = + !alive_registered_replica_instances.empty() ? alive_registered_replica_instances[0] : nullptr; + + if (nullptr == chosen_replica_instance) { spdlog::warn("Failover failed since all replicas are down!"); return; } @@ -93,21 +128,39 @@ auto CoordinatorData::TryFailover() -> void { chosen_replica_instance->PauseFrequentCheck(); utils::OnScopeExit scope_exit{[&chosen_replica_instance] { chosen_replica_instance->ResumeFrequentCheck(); }}; - std::vector<ReplClientInfo> repl_clients_info; - repl_clients_info.reserve(std::ranges::distance(replica_instances)); + utils::UUID potential_new_main_uuid = utils::UUID{}; + spdlog::trace("Generated potential new main uuid"); - auto const not_chosen_replica_instance = [&chosen_replica_instance](CoordinatorInstance const &instance) { - return instance != *chosen_replica_instance; + auto not_chosen_instance = [chosen_replica_instance](auto *instance) { + return *instance != *chosen_replica_instance; }; + // If for some replicas swap fails, for others on successful ping we will revert back on next change + // or we will do failover first again and then it will be consistent again + for (auto *other_replica_instance : alive_registered_replica_instances | ranges::views::filter(not_chosen_instance)) { + if (!other_replica_instance->SendSwapAndUpdateUUID(potential_new_main_uuid)) { + spdlog::error(fmt::format("Failed to swap uuid for instance {} which is alive, aborting failover", + other_replica_instance->InstanceName())); + return; + } + } - std::ranges::transform(registered_instances_ | ranges::views::filter(not_chosen_replica_instance), + std::vector<ReplClientInfo> repl_clients_info; + repl_clients_info.reserve(registered_instances_.size() - 1); + + std::ranges::transform(registered_instances_ | ranges::views::filter([chosen_replica_instance](const auto &instance) { + return *chosen_replica_instance != instance; + }), std::back_inserter(repl_clients_info), [](const CoordinatorInstance &instance) { return instance.ReplicationClientInfo(); }); - if (!chosen_replica_instance->PromoteToMain(std::move(repl_clients_info), main_succ_cb_, main_fail_cb_)) { + if (!chosen_replica_instance->PromoteToMain(potential_new_main_uuid, std::move(repl_clients_info), main_succ_cb_, + main_fail_cb_)) { spdlog::warn("Failover failed since promoting replica to main failed!"); return; } + chosen_replica_instance->SetNewMainUUID(potential_new_main_uuid); + main_uuid_ = potential_new_main_uuid; + spdlog::info("Failover successful! 
Instance {} promoted to main.", chosen_replica_instance->InstanceName()); } @@ -160,14 +213,28 @@ auto CoordinatorData::SetInstanceToMain(std::string instance_name) -> SetInstanc auto const is_not_new_main = [&instance_name](CoordinatorInstance const &instance) { return instance.InstanceName() != instance_name; }; + + auto potential_new_main_uuid = utils::UUID{}; + spdlog::trace("Generated potential new main uuid"); + + for (auto &other_instance : registered_instances_ | ranges::views::filter(is_not_new_main)) { + if (!other_instance.SendSwapAndUpdateUUID(potential_new_main_uuid)) { + spdlog::error( + fmt::format("Failed to swap uuid for instance {}, aborting failover", other_instance.InstanceName())); + return SetInstanceToMainCoordinatorStatus::SWAP_UUID_FAILED; + } + } + std::ranges::transform(registered_instances_ | ranges::views::filter(is_not_new_main), std::back_inserter(repl_clients_info), [](const CoordinatorInstance &instance) { return instance.ReplicationClientInfo(); }); - if (!new_main->PromoteToMain(std::move(repl_clients_info), main_succ_cb_, main_fail_cb_)) { + if (!new_main->PromoteToMain(potential_new_main_uuid, std::move(repl_clients_info), main_succ_cb_, main_fail_cb_)) { return SetInstanceToMainCoordinatorStatus::COULD_NOT_PROMOTE_TO_MAIN; } + new_main->SetNewMainUUID(potential_new_main_uuid); + main_uuid_ = potential_new_main_uuid; spdlog::info("Instance {} promoted to main", instance_name); return SetInstanceToMainCoordinatorStatus::SUCCESS; } diff --git a/src/coordination/coordinator_handlers.cpp b/src/coordination/coordinator_handlers.cpp index 63e1e4f8f..fb0750935 100644 --- a/src/coordination/coordinator_handlers.cpp +++ b/src/coordination/coordinator_handlers.cpp @@ -16,6 +16,7 @@ #include "coordination/coordinator_rpc.hpp" #include "coordination/include/coordination/coordinator_server.hpp" +#include "replication/state.hpp" namespace memgraph::dbms { @@ -32,6 +33,29 @@ void CoordinatorHandlers::Register(memgraph::coordination::CoordinatorServer &se spdlog::info("Received DemoteMainToReplicaRpc from coordinator server"); CoordinatorHandlers::DemoteMainToReplicaHandler(replication_handler, req_reader, res_builder); }); + + server.Register<replication_coordination_glue::SwapMainUUIDRpc>( + [&replication_handler](slk::Reader *req_reader, slk::Builder *res_builder) -> void { + spdlog::info("Received SwapMainUUIDRPC on coordinator server"); + CoordinatorHandlers::SwapMainUUIDHandler(replication_handler, req_reader, res_builder); + }); +} + +void CoordinatorHandlers::SwapMainUUIDHandler(replication::ReplicationHandler &replication_handler, + slk::Reader *req_reader, slk::Builder *res_builder) { + if (!replication_handler.IsReplica()) { + spdlog::error("Setting main uuid must be performed on replica."); + slk::Save(replication_coordination_glue::SwapMainUUIDRes{false}, res_builder); + return; + } + + replication_coordination_glue::SwapMainUUIDReq req; + slk::Load(&req, req_reader); + spdlog::info(fmt::format("Set replica data UUID to main uuid {}", std::string(req.uuid))); + std::get<memgraph::replication::RoleReplicaData>(replication_handler.GetReplState().ReplicationData()).uuid_ = + req.uuid; + + slk::Save(replication_coordination_glue::SwapMainUUIDRes{true}, res_builder); } void CoordinatorHandlers::DemoteMainToReplicaHandler(replication::ReplicationHandler &replication_handler, @@ -51,7 +75,7 @@ void CoordinatorHandlers::DemoteMainToReplicaHandler(replication::ReplicationHan .ip_address = req.replication_client_info.replication_ip_address, .port = 
req.replication_client_info.replication_port}; - if (!replication_handler.SetReplicationRoleReplica(clients_config)) { + if (!replication_handler.SetReplicationRoleReplica(clients_config, std::nullopt)) { spdlog::error("Demoting main to replica failed!"); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); return; @@ -67,18 +91,17 @@ void CoordinatorHandlers::PromoteReplicaToMainHandler(replication::ReplicationHa slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); return; } + coordination::PromoteReplicaToMainReq req; + slk::Load(&req, req_reader); // This can fail because of disk. If it does, the cluster state could get inconsistent. // We don't handle disk issues. - if (!replication_handler.DoReplicaToMainPromotion()) { + if (const bool success = replication_handler.DoReplicaToMainPromotion(req.main_uuid_); !success) { spdlog::error("Promoting replica to main failed!"); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); return; } - coordination::PromoteReplicaToMainReq req; - slk::Load(&req, req_reader); - auto const converter = [](const auto &repl_info_config) { return replication::ReplicationClientConfig{ .name = repl_info_config.instance_name, @@ -90,7 +113,7 @@ void CoordinatorHandlers::PromoteReplicaToMainHandler(replication::ReplicationHa // registering replicas for (auto const &config : req.replication_clients_info | ranges::views::transform(converter)) { - auto instance_client = replication_handler.RegisterReplica(config); + auto instance_client = replication_handler.RegisterReplica(config, false); if (instance_client.HasError()) { using enum memgraph::replication::RegisterReplicaError; switch (instance_client.GetError()) { @@ -109,13 +132,17 @@ void CoordinatorHandlers::PromoteReplicaToMainHandler(replication::ReplicationHa spdlog::error("Registered replica could not be persisted!"); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); return; + case memgraph::query::RegisterReplicaError::ERROR_ACCEPTING_MAIN: + spdlog::error("Replica didn't accept change of main!"); + slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); + return; case memgraph::query::RegisterReplicaError::CONNECTION_FAILED: // Connection failure is not a fatal error break; } } } - + spdlog::info(fmt::format("Promoting replica to main succeeded, new main uuid: {}", std::string(req.main_uuid_))); slk::Save(coordination::PromoteReplicaToMainRes{true}, res_builder); } diff --git a/src/coordination/coordinator_instance.cpp b/src/coordination/coordinator_instance.cpp index 0180b99c1..a759a2505 100644 --- a/src/coordination/coordinator_instance.cpp +++ b/src/coordination/coordinator_instance.cpp @@ -49,9 +49,9 @@ auto CoordinatorInstance::IsMain() const -> bool { return replication_role_ == replication_coordination_glue::ReplicationRole::MAIN; } -auto CoordinatorInstance::PromoteToMain(ReplicationClientsInfo repl_clients_info, HealthCheckCallback main_succ_cb, - HealthCheckCallback main_fail_cb) -> bool { - if (!client_.SendPromoteReplicaToMainRpc(std::move(repl_clients_info))) { +auto CoordinatorInstance::PromoteToMain(utils::UUID uuid, ReplicationClientsInfo repl_clients_info, + HealthCheckCallback main_succ_cb, HealthCheckCallback main_fail_cb) -> bool { + if (!client_.SendPromoteReplicaToMainRpc(uuid, std::move(repl_clients_info))) { return false; } @@ -80,5 +80,17 @@ auto CoordinatorInstance::ReplicationClientInfo() const -> CoordinatorClientConf return client_.ReplicationClientInfo(); } +auto CoordinatorInstance::GetClient() -> 
CoordinatorClient & { return client_; } +void CoordinatorInstance::SetNewMainUUID(const std::optional<utils::UUID> &main_uuid) { main_uuid_ = main_uuid; } +auto CoordinatorInstance::GetMainUUID() -> const std::optional<utils::UUID> & { return main_uuid_; } + +auto CoordinatorInstance::SendSwapAndUpdateUUID(const utils::UUID &main_uuid) -> bool { + if (!replication_coordination_glue::SendSwapMainUUIDRpc(client_.RpcClient(), main_uuid)) { + return false; + } + SetNewMainUUID(main_uuid); + return true; +} + } // namespace memgraph::coordination #endif diff --git a/src/coordination/coordinator_rpc.cpp b/src/coordination/coordinator_rpc.cpp index 053e46f13..2b5752a07 100644 --- a/src/coordination/coordinator_rpc.cpp +++ b/src/coordination/coordinator_rpc.cpp @@ -77,10 +77,12 @@ void Load(memgraph::coordination::PromoteReplicaToMainRes *self, memgraph::slk:: } void Save(const memgraph::coordination::PromoteReplicaToMainReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.main_uuid_, builder); memgraph::slk::Save(self.replication_clients_info, builder); } void Load(memgraph::coordination::PromoteReplicaToMainReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->main_uuid_, reader); memgraph::slk::Load(&self->replication_clients_info, reader); } diff --git a/src/coordination/coordinator_server.cpp b/src/coordination/coordinator_server.cpp index a8253cf25..60dc5e348 100644 --- a/src/coordination/coordinator_server.cpp +++ b/src/coordination/coordinator_server.cpp @@ -12,7 +12,7 @@ #ifdef MG_ENTERPRISE #include "coordination/coordinator_server.hpp" -#include "replication_coordination_glue/messages.hpp" +#include "replication_coordination_glue/handler.hpp" namespace memgraph::coordination { diff --git a/src/coordination/include/coordination/coordinator_client.hpp b/src/coordination/include/coordination/coordinator_client.hpp index a8d49e00e..00695acd7 100644 --- a/src/coordination/include/coordination/coordinator_client.hpp +++ b/src/coordination/include/coordination/coordinator_client.hpp @@ -11,6 +11,7 @@ #pragma once +#include "utils/uuid.hpp" #ifdef MG_ENTERPRISE #include "coordination/coordinator_config.hpp" @@ -44,13 +45,20 @@ class CoordinatorClient { auto InstanceName() const -> std::string; auto SocketAddress() const -> std::string; - [[nodiscard]] auto SendPromoteReplicaToMainRpc(ReplicationClientsInfo replication_clients_info) const -> bool; [[nodiscard]] auto DemoteToReplica() const -> bool; + auto SendPromoteReplicaToMainRpc(const utils::UUID &uuid, ReplicationClientsInfo replication_clients_info) const + -> bool; + + + auto SendSwapMainUUIDRpc(const utils::UUID &uuid) const -> bool; auto ReplicationClientInfo() const -> ReplClientInfo; + auto SetCallbacks(HealthCheckCallback succ_cb, HealthCheckCallback fail_cb) -> void; + auto RpcClient() -> rpc::Client & { return rpc_client_; } + friend bool operator==(CoordinatorClient const &first, CoordinatorClient const &second) { return first.config_ == second.config_; } diff --git a/src/coordination/include/coordination/coordinator_data.hpp b/src/coordination/include/coordination/coordinator_data.hpp index 4dff209f0..73bebdf7e 100644 --- a/src/coordination/include/coordination/coordinator_data.hpp +++ b/src/coordination/include/coordination/coordinator_data.hpp @@ -11,17 +11,18 @@ #pragma once +#include "utils/uuid.hpp" #ifdef MG_ENTERPRISE +#include <list> #include "coordination/coordinator_instance.hpp" #include "coordination/coordinator_instance_status.hpp" #include "coordination/coordinator_server.hpp" 
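All of this UUID plumbing exists so that replica-facing RPCs can cheaply reject a demoted main: an old main keeps streaming with its stale UUID, while the coordinator swaps a fresh one onto the replicas during failover. The guard each handler gains in this patch reduces to roughly the following shape (simplified stand-ins, not Memgraph's types; real handlers also deserialize the request and reply with a typed failure response):

#include <optional>

// Stand-in for utils::UUID, just for the sketch.
struct UUID { unsigned long hi, lo; };
bool operator==(UUID a, UUID b) { return a.hi == b.hi && a.lo == b.lo; }

// Accept only when a current main is known and the request carries its UUID.
// nullopt means the replica restarted and is not following any main yet, so
// everything is rejected until the coordinator swaps a UUID back in.
bool AcceptFromMain(const std::optional<UUID> &current_main_uuid, const UUID &req_main_uuid) {
  return current_main_uuid.has_value() && req_main_uuid == *current_main_uuid;
}

int main() {
  const UUID fresh{1, 2};
  const UUID stale{9, 9};
  return AcceptFromMain(fresh, stale) ? 1 : 0;  // exits 0: stale main rejected
}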
#include "coordination/register_main_replica_coordinator_status.hpp" +#include "replication_coordination_glue/handler.hpp" #include "utils/rw_lock.hpp" #include "utils/thread_pool.hpp" -#include <list> - namespace memgraph::coordination { class CoordinatorData { public: @@ -36,12 +37,11 @@ class CoordinatorData { auto ShowInstances() const -> std::vector<CoordinatorInstanceStatus>; private: - auto ClusterHasAliveMain_() const -> bool; - mutable utils::RWLock coord_data_lock_{utils::RWLock::Priority::READ}; HealthCheckCallback main_succ_cb_, main_fail_cb_, replica_succ_cb_, replica_fail_cb_; // NOTE: Must be std::list because we rely on pointer stability std::list<CoordinatorInstance> registered_instances_; + utils::UUID main_uuid_; }; struct CoordinatorMainReplicaData { diff --git a/src/coordination/include/coordination/coordinator_handlers.hpp b/src/coordination/include/coordination/coordinator_handlers.hpp index a5cd4929e..1f170bd61 100644 --- a/src/coordination/include/coordination/coordinator_handlers.hpp +++ b/src/coordination/include/coordination/coordinator_handlers.hpp @@ -31,6 +31,8 @@ class CoordinatorHandlers { slk::Builder *res_builder); static void DemoteMainToReplicaHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader, slk::Builder *res_builder); + static void SwapMainUUIDHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader, slk::Builder *res_builder); + }; } // namespace memgraph::dbms diff --git a/src/coordination/include/coordination/coordinator_instance.hpp b/src/coordination/include/coordination/coordinator_instance.hpp index e03fa30c7..f3fd3deca 100644 --- a/src/coordination/include/coordination/coordinator_instance.hpp +++ b/src/coordination/include/coordination/coordinator_instance.hpp @@ -16,6 +16,7 @@ #include "coordination/coordinator_client.hpp" #include "coordination/coordinator_cluster_config.hpp" #include "coordination/coordinator_exceptions.hpp" +#include "replication_coordination_glue/handler.hpp" #include "replication_coordination_glue/role.hpp" namespace memgraph::coordination { @@ -44,7 +45,7 @@ class CoordinatorInstance { auto IsReplica() const -> bool; auto IsMain() const -> bool; - auto PromoteToMain(ReplicationClientsInfo repl_clients_info, HealthCheckCallback main_succ_cb, + auto PromoteToMain(utils::UUID main_uuid, ReplicationClientsInfo repl_clients_info, HealthCheckCallback main_succ_cb, HealthCheckCallback main_fail_cb) -> bool; auto DemoteToReplica(HealthCheckCallback replica_succ_cb, HealthCheckCallback replica_fail_cb) -> bool; @@ -53,11 +54,25 @@ class CoordinatorInstance { auto ReplicationClientInfo() const -> ReplClientInfo; + auto GetClient() -> CoordinatorClient &; + + void SetNewMainUUID(const std::optional<utils::UUID> &main_uuid = std::nullopt); + auto GetMainUUID() -> const std::optional<utils::UUID> &; + + auto SendSwapAndUpdateUUID(const utils::UUID &main_uuid) -> bool; + private: CoordinatorClient client_; replication_coordination_glue::ReplicationRole replication_role_; std::chrono::system_clock::time_point last_response_time_{}; + // TODO this needs to be atomic? 
What if instance is alive and then we read it and it has changed bool is_alive_{false}; + // for replica this is main uuid of current main + // for "main" main this same as in CoordinatorData + // it is set to nullopt when replica is down + // TLDR; when replica is down and comes back up we reset uuid of main replica is listening to + // so we need to send swap uuid again + std::optional<utils::UUID> main_uuid_; friend bool operator==(CoordinatorInstance const &first, CoordinatorInstance const &second) { return first.client_ == second.client_ && first.replication_role_ == second.replication_role_; diff --git a/src/coordination/include/coordination/coordinator_rpc.hpp b/src/coordination/include/coordination/coordinator_rpc.hpp index d2786ef62..56cfdb403 100644 --- a/src/coordination/include/coordination/coordinator_rpc.hpp +++ b/src/coordination/include/coordination/coordinator_rpc.hpp @@ -11,6 +11,7 @@ #pragma once +#include "utils/uuid.hpp" #ifdef MG_ENTERPRISE #include "coordination/coordinator_config.hpp" @@ -26,10 +27,13 @@ struct PromoteReplicaToMainReq { static void Load(PromoteReplicaToMainReq *self, memgraph::slk::Reader *reader); static void Save(const PromoteReplicaToMainReq &self, memgraph::slk::Builder *builder); - explicit PromoteReplicaToMainReq(std::vector<CoordinatorClientConfig::ReplicationClientInfo> replication_clients_info) - : replication_clients_info(std::move(replication_clients_info)) {} + explicit PromoteReplicaToMainReq(const utils::UUID &uuid, + std::vector<CoordinatorClientConfig::ReplicationClientInfo> replication_clients_info) + : main_uuid_(uuid), replication_clients_info(std::move(replication_clients_info)) {} PromoteReplicaToMainReq() = default; + // get uuid here + utils::UUID main_uuid_; std::vector<CoordinatorClientConfig::ReplicationClientInfo> replication_clients_info; }; @@ -83,22 +87,19 @@ using DemoteMainToReplicaRpc = rpc::RequestResponse<DemoteMainToReplicaReq, Demo // SLK serialization declarations namespace memgraph::slk { +// PromoteReplicaToMainRpc void Save(const memgraph::coordination::PromoteReplicaToMainRes &self, memgraph::slk::Builder *builder); - void Load(memgraph::coordination::PromoteReplicaToMainRes *self, memgraph::slk::Reader *reader); - void Save(const memgraph::coordination::PromoteReplicaToMainReq &self, memgraph::slk::Builder *builder); - void Load(memgraph::coordination::PromoteReplicaToMainReq *self, memgraph::slk::Reader *reader); +// DemoteMainToReplicaRpc void Save(const memgraph::coordination::DemoteMainToReplicaRes &self, memgraph::slk::Builder *builder); - void Load(memgraph::coordination::DemoteMainToReplicaRes *self, memgraph::slk::Reader *reader); - void Save(const memgraph::coordination::DemoteMainToReplicaReq &self, memgraph::slk::Builder *builder); - void Load(memgraph::coordination::DemoteMainToReplicaReq *self, memgraph::slk::Reader *reader); + } // namespace memgraph::slk #endif diff --git a/src/coordination/include/coordination/register_main_replica_coordinator_status.hpp b/src/coordination/include/coordination/register_main_replica_coordinator_status.hpp index 3e742fb3b..bf35e9156 100644 --- a/src/coordination/include/coordination/register_main_replica_coordinator_status.hpp +++ b/src/coordination/include/coordination/register_main_replica_coordinator_status.hpp @@ -30,6 +30,7 @@ enum class SetInstanceToMainCoordinatorStatus : uint8_t { NOT_COORDINATOR, SUCCESS, COULD_NOT_PROMOTE_TO_MAIN, + SWAP_UUID_FAILED }; } // namespace memgraph::coordination diff --git a/src/dbms/dbms_handler.cpp 
b/src/dbms/dbms_handler.cpp index 861bcf701..a68fbc72c 100644 --- a/src/dbms/dbms_handler.cpp +++ b/src/dbms/dbms_handler.cpp @@ -38,6 +38,8 @@ std::string RegisterReplicaErrorToString(query::RegisterReplicaError error) { return "CONNECTION_FAILED"; case COULD_NOT_BE_PERSISTED: return "COULD_NOT_BE_PERSISTED"; + case ERROR_ACCEPTING_MAIN: + return "ERROR_ACCEPTING_MAIN"; } } @@ -52,7 +54,7 @@ void RestoreReplication(replication::RoleMainData &mainData, DatabaseAccess db_a spdlog::info("Replica {} restoration started for {}.", instance_client.name_, db_acc->name()); const auto &ret = db_acc->storage()->repl_storage_state_.replication_clients_.WithLock( [&, db_acc](auto &storage_clients) mutable -> utils::BasicResult<query::RegisterReplicaError> { - auto client = std::make_unique<storage::ReplicationStorageClient>(instance_client); + auto client = std::make_unique<storage::ReplicationStorageClient>(instance_client, mainData.uuid_); auto *storage = db_acc->storage(); client->Start(storage, std::move(db_acc)); // After start the storage <-> replica state should be READY or RECOVERING (if correctly started) @@ -239,14 +241,16 @@ struct DropDatabase : memgraph::system::ISystemAction { void DoDurability() override { /* Done during DBMS execution */ } - bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch, + bool DoReplication(replication::ReplicationClient &client, const utils::UUID &main_uuid, + replication::ReplicationEpoch const &epoch, memgraph::system::Transaction const &txn) const override { auto check_response = [](const storage::replication::DropDatabaseRes &response) { return response.result != storage::replication::DropDatabaseRes::Result::FAILURE; }; return client.SteamAndFinalizeDelta<storage::replication::DropDatabaseRpc>( - check_response, epoch.id(), txn.last_committed_system_timestamp(), txn.timestamp(), uuid_); + check_response, main_uuid, std::string(epoch.id()), txn.last_committed_system_timestamp(), txn.timestamp(), + uuid_); } void PostReplication(replication::RoleMainData &mainData) const override {} @@ -323,14 +327,16 @@ struct CreateDatabase : memgraph::system::ISystemAction { // Done during dbms execution } - bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch, + bool DoReplication(replication::ReplicationClient &client, const utils::UUID &main_uuid, + replication::ReplicationEpoch const &epoch, memgraph::system::Transaction const &txn) const override { auto check_response = [](const storage::replication::CreateDatabaseRes &response) { return response.result != storage::replication::CreateDatabaseRes::Result::FAILURE; }; return client.SteamAndFinalizeDelta<storage::replication::CreateDatabaseRpc>( - check_response, epoch.id(), txn.last_committed_system_timestamp(), txn.timestamp(), config_); + check_response, main_uuid, std::string(epoch.id()), txn.last_committed_system_timestamp(), txn.timestamp(), + config_); } void PostReplication(replication::RoleMainData &mainData) const override { diff --git a/src/dbms/dbms_handler.hpp b/src/dbms/dbms_handler.hpp index 24a7599a2..a373f751b 100644 --- a/src/dbms/dbms_handler.hpp +++ b/src/dbms/dbms_handler.hpp @@ -29,6 +29,7 @@ #include "kvstore/kvstore.hpp" #include "license/license.hpp" #include "replication/replication_client.hpp" +#include "replication_coordination_glue/handler.hpp" #include "storage/v2/config.hpp" #include "storage/v2/transaction.hpp" #include "system/system.hpp" @@ -261,6 +262,16 @@ class DbmsHandler { #endif } + 
replication::ReplicationState &ReplicationState() { return repl_state_; }
+  replication::ReplicationState const &ReplicationState() const { return repl_state_; }
+
+  bool IsMain() const { return repl_state_.IsMain(); }
+  bool IsReplica() const { return repl_state_.IsReplica(); }
+
+#ifdef MG_ENTERPRISE
+  // coordination::CoordinatorState &CoordinatorState() { return coordinator_state_; }
+#endif
+
   /**
    * @brief Return the statistics all databases.
    *
diff --git a/src/dbms/inmemory/replication_handlers.cpp b/src/dbms/inmemory/replication_handlers.cpp
index cef2bf8c6..61d8d3bbd 100644
--- a/src/dbms/inmemory/replication_handlers.cpp
+++ b/src/dbms/inmemory/replication_handlers.cpp
@@ -76,47 +76,84 @@ std::optional<DatabaseAccess> GetDatabaseAccessor(dbms::DbmsHandler *dbms_handle
     return std::nullopt;
   }
 }
+
+void LogWrongMain(const std::optional<utils::UUID> &current_main_uuid, const utils::UUID &main_req_id,
+                  std::string_view rpc_req) {
+  spdlog::error("Received {} with main_id: {} != current_main_uuid: {}", rpc_req, std::string(main_req_id),
+                current_main_uuid.has_value() ? std::string(current_main_uuid.value()) : "");
+}
 }  // namespace

-void InMemoryReplicationHandlers::Register(dbms::DbmsHandler *dbms_handler, replication::ReplicationServer &server) {
-  server.rpc_server_.Register<storage::replication::HeartbeatRpc>([dbms_handler](auto *req_reader, auto *res_builder) {
-    spdlog::debug("Received HeartbeatRpc");
-    InMemoryReplicationHandlers::HeartbeatHandler(dbms_handler, req_reader, res_builder);
-  });
-  server.rpc_server_.Register<storage::replication::AppendDeltasRpc>(
-      [dbms_handler](auto *req_reader, auto *res_builder) {
-        spdlog::debug("Received AppendDeltasRpc");
-        InMemoryReplicationHandlers::AppendDeltasHandler(dbms_handler, req_reader, res_builder);
+void InMemoryReplicationHandlers::Register(dbms::DbmsHandler *dbms_handler, replication::RoleReplicaData &data) {
+  auto &server = *data.server;
+  server.rpc_server_.Register<storage::replication::HeartbeatRpc>(
+      [&data, dbms_handler](auto *req_reader, auto *res_builder) {
+        spdlog::debug("Received HeartbeatRpc");
+        InMemoryReplicationHandlers::HeartbeatHandler(dbms_handler, data.uuid_, req_reader, res_builder);
+      });
+  server.rpc_server_.Register<storage::replication::AppendDeltasRpc>(
+      [&data, dbms_handler](auto *req_reader, auto *res_builder) {
+        spdlog::debug("Received AppendDeltasRpc");
+        InMemoryReplicationHandlers::AppendDeltasHandler(dbms_handler, data.uuid_, req_reader, res_builder);
+      });
+  server.rpc_server_.Register<storage::replication::SnapshotRpc>(
+      [&data, dbms_handler](auto *req_reader, auto *res_builder) {
+        spdlog::debug("Received SnapshotRpc");
+        InMemoryReplicationHandlers::SnapshotHandler(dbms_handler, data.uuid_, req_reader, res_builder);
+      });
+  server.rpc_server_.Register<storage::replication::WalFilesRpc>(
+      [&data, dbms_handler](auto *req_reader, auto *res_builder) {
+        spdlog::debug("Received WalFilesRpc");
+        InMemoryReplicationHandlers::WalFilesHandler(dbms_handler, data.uuid_, req_reader, res_builder);
+      });
+  server.rpc_server_.Register<storage::replication::CurrentWalRpc>(
+      [&data, dbms_handler](auto *req_reader, auto *res_builder) {
+        spdlog::debug("Received CurrentWalRpc");
+        InMemoryReplicationHandlers::CurrentWalHandler(dbms_handler, data.uuid_, req_reader, res_builder);
+      });
+  server.rpc_server_.Register<storage::replication::TimestampRpc>(
+      [&data, dbms_handler](auto *req_reader, auto *res_builder) {
+        spdlog::debug("Received TimestampRpc");
+        InMemoryReplicationHandlers::TimestampHandler(dbms_handler, data.uuid_, req_reader, res_builder);
+      });
+  server.rpc_server_.Register<replication_coordination_glue::SwapMainUUIDRpc>(
+      [&data, dbms_handler](auto *req_reader, auto *res_builder) {
+        spdlog::debug("Received SwapMainUUIDRpc");
+        InMemoryReplicationHandlers::SwapMainUUIDHandler(dbms_handler, data, req_reader, res_builder);
       });
-  server.rpc_server_.Register<storage::replication::SnapshotRpc>([dbms_handler](auto *req_reader, auto *res_builder) {
-    spdlog::debug("Received SnapshotRpc");
-    InMemoryReplicationHandlers::SnapshotHandler(dbms_handler, req_reader, res_builder);
-  });
-  server.rpc_server_.Register<storage::replication::WalFilesRpc>([dbms_handler](auto *req_reader, auto *res_builder) {
-    spdlog::debug("Received WalFilesRpc");
-    InMemoryReplicationHandlers::WalFilesHandler(dbms_handler, req_reader, res_builder);
-  });
-  server.rpc_server_.Register<storage::replication::CurrentWalRpc>([dbms_handler](auto *req_reader, auto *res_builder) {
-    spdlog::debug("Received CurrentWalRpc");
-    InMemoryReplicationHandlers::CurrentWalHandler(dbms_handler, req_reader, res_builder);
-  });
-  server.rpc_server_.Register<storage::replication::TimestampRpc>([dbms_handler](auto *req_reader, auto *res_builder) {
-    spdlog::debug("Received TimestampRpc");
-    InMemoryReplicationHandlers::TimestampHandler(dbms_handler, req_reader, res_builder);
-  });
 }

-void InMemoryReplicationHandlers::HeartbeatHandler(dbms::DbmsHandler *dbms_handler, slk::Reader *req_reader,
-                                                   slk::Builder *res_builder) {
+void InMemoryReplicationHandlers::SwapMainUUIDHandler(dbms::DbmsHandler *dbms_handler,
+                                                      replication::RoleReplicaData &role_replica_data,
+                                                      slk::Reader *req_reader, slk::Builder *res_builder) {
+  if (!dbms_handler->IsReplica()) {
+    spdlog::error("Setting main uuid must be performed on replica.");
+    slk::Save(replication_coordination_glue::SwapMainUUIDRes{false}, res_builder);
+    return;
+  }
+
+  replication_coordination_glue::SwapMainUUIDReq req;
+  slk::Load(&req, req_reader);
+  spdlog::info(fmt::format("Set replica data UUID to main uuid {}", std::string(req.uuid)));
+  dbms_handler->ReplicationState().TryPersistRoleReplica(role_replica_data.config, req.uuid);
+  role_replica_data.uuid_ = req.uuid;
+
+  slk::Save(replication_coordination_glue::SwapMainUUIDRes{true}, res_builder);
+}
+
+void InMemoryReplicationHandlers::HeartbeatHandler(dbms::DbmsHandler *dbms_handler,
+                                                   const std::optional<utils::UUID> &current_main_uuid,
+                                                   slk::Reader *req_reader, slk::Builder *res_builder) {
   storage::replication::HeartbeatReq req;
   slk::Load(&req, req_reader);
   auto const db_acc = GetDatabaseAccessor(dbms_handler, req.uuid);
   if (!db_acc) {
     storage::replication::HeartbeatRes res{false, 0, ""};
     slk::Save(res, res_builder);
     return;
   }
-  // TODO: this handler is agnostic of InMemory, move to be reused by on-disk
+
+  if (!current_main_uuid.has_value() || req.main_uuid != *current_main_uuid) [[unlikely]] {
+    LogWrongMain(current_main_uuid, req.main_uuid, storage::replication::HeartbeatReq::kType.name);
+    storage::replication::HeartbeatRes res{false, 0, ""};
+    slk::Save(res, res_builder);
+    return;
+  }
+
   auto const *storage = db_acc->get()->storage();
   storage::replication::HeartbeatRes res{true, storage->repl_storage_state_.last_commit_timestamp_.load(),
@@ -124,10 +161,19 @@ void InMemoryReplicationHandlers::HeartbeatHandl
   slk::Save(res, res_builder);
 }

-void InMemoryReplicationHandlers::AppendDeltasHandler(dbms::DbmsHandler *dbms_handler, slk::Reader *req_reader,
-                                                      slk::Builder *res_builder) {
+void InMemoryReplicationHandlers::AppendDeltasHandler(dbms::DbmsHandler *dbms_handler,
+                                                      const std::optional<utils::UUID> &current_main_uuid,
+                                                      slk::Reader *req_reader, slk::Builder *res_builder) {
   storage::replication::AppendDeltasReq req;
   slk::Load(&req, req_reader);
+
+  if (!current_main_uuid.has_value() || req.main_uuid != current_main_uuid) [[unlikely]] {
+    LogWrongMain(current_main_uuid, req.main_uuid, storage::replication::AppendDeltasReq::kType.name);
+    storage::replication::AppendDeltasRes res{false, 0};
+    slk::Save(res, res_builder);
+    return;
+  }
+
   auto db_acc = GetDatabaseAccessor(dbms_handler, req.uuid);
   if (!db_acc) {
     storage::replication::AppendDeltasRes res{false, 0};
@@ -187,8 +233,9 @@ void InMemoryReplicationHandlers::AppendDeltasHa
   spdlog::debug("Replication recovery from append deltas finished, replica is now up to date!");
 }

-void InMemoryReplicationHandlers::SnapshotHandler(dbms::DbmsHandler *dbms_handler, slk::Reader *req_reader,
-                                                  slk::Builder *res_builder) {
+void InMemoryReplicationHandlers::SnapshotHandler(dbms::DbmsHandler *dbms_handler,
+                                                  const std::optional<utils::UUID> &current_main_uuid,
+                                                  slk::Reader *req_reader, slk::Builder *res_builder) {
   storage::replication::SnapshotReq req;
   slk::Load(&req, req_reader);
   auto db_acc = GetDatabaseAccessor(dbms_handler, req.uuid);
@@ -197,6 +244,12 @@ void InMemoryReplicationHandlers::SnapshotHandle
     slk::Save(res, res_builder);
     return;
   }
+  if (!current_main_uuid.has_value() || req.main_uuid != current_main_uuid) [[unlikely]] {
+    LogWrongMain(current_main_uuid, req.main_uuid, storage::replication::SnapshotReq::kType.name);
+    storage::replication::SnapshotRes res{false, 0};
+    slk::Save(res, res_builder);
+    return;
+  }

   storage::replication::Decoder decoder(req_reader);
@@ -270,8 +323,9 @@ void InMemoryReplicationHandlers::SnapshotHandle
   spdlog::debug("Replication recovery from snapshot finished!");
 }

-void InMemoryReplicationHandlers::WalFilesHandler(dbms::DbmsHandler *dbms_handler, slk::Reader *req_reader,
-                                                  slk::Builder *res_builder) {
+void InMemoryReplicationHandlers::WalFilesHandler(dbms::DbmsHandler *dbms_handler,
+                                                  const std::optional<utils::UUID> &current_main_uuid,
+                                                  slk::Reader *req_reader, slk::Builder *res_builder) {
   storage::replication::WalFilesReq req;
   slk::Load(&req, req_reader);
   auto db_acc = GetDatabaseAccessor(dbms_handler, req.uuid);
@@ -280,6 +334,12 @@ void InMemoryReplicationHandlers::WalFilesHandle
     slk::Save(res, res_builder);
     return;
   }
+  if (!current_main_uuid.has_value() || req.main_uuid != current_main_uuid) [[unlikely]] {
+    LogWrongMain(current_main_uuid, req.main_uuid, storage::replication::WalFilesReq::kType.name);
+    storage::replication::WalFilesRes res{false, 0};
+    slk::Save(res, res_builder);
+    return;
+  }

   const auto wal_file_number = req.file_number;
   spdlog::debug("Received WAL files: {}", wal_file_number);
@@ -298,8 +358,9 @@ void InMemoryReplicationHandlers::WalFilesHandle
   spdlog::debug("Replication recovery from WAL files ended successfully, replica is now up to date!");
 }

-void InMemoryReplicationHandlers::CurrentWalHandler(dbms::DbmsHandler *dbms_handler, slk::Reader *req_reader,
-                                                    slk::Builder *res_builder) {
+void InMemoryReplicationHandlers::CurrentWalHandler(dbms::DbmsHandler *dbms_handler,
+                                                    const std::optional<utils::UUID> &current_main_uuid,
+                                                    slk::Reader *req_reader, slk::Builder *res_builder) {
   storage::replication::CurrentWalReq req;
   slk::Load(&req, req_reader);
   auto db_acc = GetDatabaseAccessor(dbms_handler, req.uuid);
@@ -309,6 +370,13 @@ void InMemoryReplicationHandlers::CurrentWalHand
     return;
   }

+  if (!current_main_uuid.has_value() || req.main_uuid != current_main_uuid) [[unlikely]] {
+    LogWrongMain(current_main_uuid, req.main_uuid, storage::replication::CurrentWalReq::kType.name);
+    storage::replication::CurrentWalRes res{false, 0};
+    slk::Save(res, res_builder);
+    return;
+  }
+
   storage::replication::Decoder decoder(req_reader);

   auto *storage = static_cast<storage::InMemoryStorage *>(db_acc->get()->storage());
@@ -370,8 +438,9 @@ void InMemoryReplicationHandlers::LoadWal(storage::InMemoryStorage *storage, sto
   }
 }

-void InMemoryReplicationHandlers::TimestampHandler(dbms::DbmsHandler *dbms_handler, slk::Reader *req_reader,
-                                                   slk::Builder *res_builder) {
+void InMemoryReplicationHandlers::TimestampHandler(dbms::DbmsHandler *dbms_handler,
+                                                   const std::optional<utils::UUID> &current_main_uuid,
+                                                   slk::Reader *req_reader, slk::Builder *res_builder) {
   storage::replication::TimestampReq req;
   slk::Load(&req, req_reader);
   auto const db_acc = GetDatabaseAccessor(dbms_handler, req.uuid);
@@ -381,12 +450,20 @@ void InMemoryReplicationHandlers::TimestampHandl
     return;
   }

+  if (!current_main_uuid.has_value() || req.main_uuid != current_main_uuid) [[unlikely]] {
+    LogWrongMain(current_main_uuid, req.main_uuid, storage::replication::TimestampReq::kType.name);
+    storage::replication::TimestampRes res{false, 0};
+    slk::Save(res, res_builder);
+    return;
+  }
+
   // TODO: this handler is agnostic of InMemory, move to be reused by on-disk
   auto const *storage = db_acc->get()->storage();
   storage::replication::TimestampRes res{true, storage->repl_storage_state_.last_commit_timestamp_.load()};
   slk::Save(res, res_builder);
 }
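Every replica-side handler above repeats the same three-step guard: compare the request's main_uuid with the UUID this replica currently follows, log a mismatch through LogWrongMain, and answer with a failure response. A sketch of how that guard could be hoisted into a shared helper; ValidMainUuidOr is a hypothetical name and is not part of this patch:

    // Hypothetical helper; TReq/TRes are the RPC's request/response types.
    // Relies only on LogWrongMain and slk::Save, exactly as the handlers above do.
    template <typename TReq, typename TRes>
    bool ValidMainUuidOr(const std::optional<utils::UUID> &current_main_uuid, const TReq &req,
                         slk::Builder *res_builder, TRes fail_res) {
      if (current_main_uuid.has_value() && req.main_uuid == *current_main_uuid) [[likely]] {
        return true;  // The request really comes from the MAIN we listen to.
      }
      LogWrongMain(current_main_uuid, req.main_uuid, TReq::kType.name);
      slk::Save(fail_res, res_builder);  // e.g. HeartbeatRes{false, 0, ""}
      return false;
    }

With such a helper each handler would shrink to a single early return, e.g. if (!ValidMainUuidOr(current_main_uuid, req, res_builder, storage::replication::HeartbeatRes{false, 0, ""})) return;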
 uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage *storage,
                                                         storage::durability::BaseDecoder *decoder,
                                                         const uint64_t version) {
diff --git a/src/dbms/inmemory/replication_handlers.hpp b/src/dbms/inmemory/replication_handlers.hpp
index 4f6523747..4406b8338 100644
--- a/src/dbms/inmemory/replication_handlers.hpp
+++ b/src/dbms/inmemory/replication_handlers.hpp
@@ -12,6 +12,7 @@
 #pragma once

 #include "replication/replication_server.hpp"
+#include "replication/state.hpp"
 #include "storage/v2/replication/serialization.hpp"

 namespace memgraph::storage {
@@ -23,21 +24,30 @@ class DbmsHandler;

 class InMemoryReplicationHandlers {
  public:
-  static void Register(dbms::DbmsHandler *dbms_handler, replication::ReplicationServer &server);
+  static void Register(dbms::DbmsHandler *dbms_handler, replication::RoleReplicaData &data);

  private:
   // RPC handlers
-  static void HeartbeatHandler(dbms::DbmsHandler *dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
+  static void HeartbeatHandler(dbms::DbmsHandler *dbms_handler, const std::optional<utils::UUID> &current_main_uuid,
+                               slk::Reader *req_reader, slk::Builder *res_builder);

-  static void AppendDeltasHandler(dbms::DbmsHandler *dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
+  static void AppendDeltasHandler(dbms::DbmsHandler *dbms_handler, const std::optional<utils::UUID> &current_main_uuid,
+                                  slk::Reader *req_reader, slk::Builder *res_builder);

-  static void SnapshotHandler(dbms::DbmsHandler *dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
+  static void SnapshotHandler(dbms::DbmsHandler *dbms_handler, const std::optional<utils::UUID> &current_main_uuid,
+                              slk::Reader *req_reader, slk::Builder *res_builder);

-  static void WalFilesHandler(dbms::DbmsHandler *dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
+  static void WalFilesHandler(dbms::DbmsHandler *dbms_handler, const std::optional<utils::UUID> &current_main_uuid,
+                              slk::Reader *req_reader, slk::Builder *res_builder);

-  static void CurrentWalHandler(dbms::DbmsHandler *dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
+  static void CurrentWalHandler(dbms::DbmsHandler *dbms_handler, const std::optional<utils::UUID> &current_main_uuid,
+                                slk::Reader *req_reader, slk::Builder *res_builder);

-  static void TimestampHandler(dbms::DbmsHandler *dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
+  static void TimestampHandler(dbms::DbmsHandler *dbms_handler, const std::optional<utils::UUID> &current_main_uuid,
+                               slk::Reader *req_reader, slk::Builder *res_builder);
+
+  static void SwapMainUUIDHandler(dbms::DbmsHandler *dbms_handler, replication::RoleReplicaData &role_replica_data,
+                                  slk::Reader *req_reader, slk::Builder *res_builder);

   static void LoadWal(storage::InMemoryStorage *storage, storage::replication::Decoder *decoder);
diff --git a/src/dbms/replication_client.cpp b/src/dbms/replication_client.cpp
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/dbms/replication_handler.cpp b/src/dbms/replication_handler.cpp
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/dbms/replication_handler.hpp b/src/dbms/replication_handler.hpp
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/dbms/replication_handlers.cpp b/src/dbms/replication_handlers.cpp
index 2c77262fa..d2ad025b2 100644
--- a/src/dbms/replication_handlers.cpp
+++ b/src/dbms/replication_handlers.cpp
@@ -21,7 +21,8 @@ namespace memgraph::dbms {
 #ifdef MG_ENTERPRISE
 void CreateDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access,
-                           DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder) {
+                           const std::optional<utils::UUID> &current_main_uuid, DbmsHandler &dbms_handler,
+                           slk::Reader *req_reader, slk::Builder *res_builder) {
   using memgraph::storage::replication::CreateDatabaseRes;
   CreateDatabaseRes res(CreateDatabaseRes::Result::FAILURE);
@@ -35,6 +36,12 @@ void CreateDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system
   memgraph::storage::replication::CreateDatabaseReq req;
   memgraph::slk::Load(&req, req_reader);

+  if (!current_main_uuid.has_value() || req.main_uuid != current_main_uuid) [[unlikely]] {
+    LogWrongMain(current_main_uuid, req.main_uuid, memgraph::storage::replication::CreateDatabaseReq::kType.name);
+    memgraph::slk::Save(res, res_builder);
+    return;
+  }
+
   // Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot
   //       of the set of databases. Hence no history exists to maintain regarding epoch change.
   //       If MAIN has changed we need to check this new group_timestamp is consistent with
@@ -63,7 +70,8 @@ void CreateDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system
   memgraph::slk::Save(res, res_builder);
 }

-void DropDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, DbmsHandler &dbms_handler,
+void DropDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access,
+                         const std::optional<utils::UUID> &current_main_uuid, DbmsHandler &dbms_handler,
                          slk::Reader *req_reader, slk::Builder *res_builder) {
   using memgraph::storage::replication::DropDatabaseRes;
   DropDatabaseRes res(DropDatabaseRes::Result::FAILURE);
@@ -78,6 +86,12 @@ void DropDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_s
   memgraph::storage::replication::DropDatabaseReq req;
   memgraph::slk::Load(&req, req_reader);

+  if (!current_main_uuid.has_value() || req.main_uuid != current_main_uuid) [[unlikely]] {
+    LogWrongMain(current_main_uuid, req.main_uuid, memgraph::storage::replication::DropDatabaseReq::kType.name);
+    memgraph::slk::Save(res, res_builder);
+    return;
+  }
+
   // Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot
   //       of the set of databases. Hence no history exists to maintain regarding epoch change.
   //       If MAIN has changed we need to check this new group_timestamp is consistent with
@@ -177,14 +191,14 @@ void Register(replication::RoleReplicaData const &data, system::ReplicaHandlerAc
               dbms::DbmsHandler &dbms_handler) {
   // NOTE: Register even without license as the user could add a license at run-time
   data.server->rpc_server_.Register<storage::replication::CreateDatabaseRpc>(
-      [system_state_access, &dbms_handler](auto *req_reader, auto *res_builder) mutable {
+      [&data, system_state_access, &dbms_handler](auto *req_reader, auto *res_builder) mutable {
         spdlog::debug("Received CreateDatabaseRpc");
-        CreateDatabaseHandler(system_state_access, dbms_handler, req_reader, res_builder);
+        CreateDatabaseHandler(system_state_access, data.uuid_, dbms_handler, req_reader, res_builder);
       });
   data.server->rpc_server_.Register<storage::replication::DropDatabaseRpc>(
-      [system_state_access, &dbms_handler](auto *req_reader, auto *res_builder) mutable {
+      [&data, system_state_access, &dbms_handler](auto *req_reader, auto *res_builder) mutable {
         spdlog::debug("Received DropDatabaseRpc");
-        DropDatabaseHandler(system_state_access, dbms_handler, req_reader, res_builder);
+        DropDatabaseHandler(system_state_access, data.uuid_, dbms_handler, req_reader, res_builder);
       });
 }
 #endif
diff --git a/src/dbms/replication_handlers.hpp b/src/dbms/replication_handlers.hpp
index 48e91e384..cf45882af 100644
--- a/src/dbms/replication_handlers.hpp
+++ b/src/dbms/replication_handlers.hpp
@@ -17,11 +17,21 @@
 #include "system/state.hpp"

 namespace memgraph::dbms {
+
 #ifdef MG_ENTERPRISE
+
+inline void LogWrongMain(const std::optional<utils::UUID> &current_main_uuid, const utils::UUID &main_req_id,
+                         std::string_view rpc_req) {
+  spdlog::error("Received {} with main_id: {} != current_main_uuid: {}", rpc_req, std::string(main_req_id),
+                current_main_uuid.has_value() ? std::string(current_main_uuid.value()) : "");
+}
+
 // RPC handlers
 void CreateDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access,
-                           DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
-void DropDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, DbmsHandler &dbms_handler,
+                           const std::optional<utils::UUID> &current_main_uuid, DbmsHandler &dbms_handler,
+                           slk::Reader *req_reader, slk::Builder *res_builder);
+void DropDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access,
+                         const std::optional<utils::UUID> &current_main_uuid, DbmsHandler &dbms_handler,
                          slk::Reader *req_reader, slk::Builder *res_builder);
 bool SystemRecoveryHandler(DbmsHandler &dbms_handler, const std::vector<storage::SalientConfig> &database_configs);
diff --git a/src/dbms/rpc.hpp b/src/dbms/rpc.hpp
index b08928e80..1edcc74e0 100644
--- a/src/dbms/rpc.hpp
+++ b/src/dbms/rpc.hpp
@@ -29,13 +29,15 @@ struct CreateDatabaseReq {
   static void Load(CreateDatabaseReq *self, memgraph::slk::Reader *reader);
   static void Save(const CreateDatabaseReq &self, memgraph::slk::Builder *builder);
   CreateDatabaseReq() = default;
-  CreateDatabaseReq(std::string_view epoch_id, uint64_t expected_group_timestamp, uint64_t new_group_timestamp,
-                    storage::SalientConfig config)
-      : epoch_id(std::string(epoch_id)),
+  CreateDatabaseReq(const utils::UUID &main_uuid, std::string epoch_id, uint64_t expected_group_timestamp,
+                    uint64_t new_group_timestamp, storage::SalientConfig config)
+      : main_uuid(main_uuid),
+        epoch_id(std::move(epoch_id)),
         expected_group_timestamp{expected_group_timestamp},
         new_group_timestamp(new_group_timestamp),
         config(std::move(config)) {}

+  utils::UUID main_uuid;
   std::string epoch_id;
   uint64_t expected_group_timestamp;
   uint64_t new_group_timestamp;
@@ -65,13 +67,15 @@ struct DropDatabaseReq {
   static void Load(DropDatabaseReq *self, memgraph::slk::Reader *reader);
   static void Save(const DropDatabaseReq &self, memgraph::slk::Builder *builder);
   DropDatabaseReq() = default;
-  DropDatabaseReq(std::string_view epoch_id, uint64_t expected_group_timestamp, uint64_t new_group_timestamp,
-                  const utils::UUID &uuid)
-      : epoch_id(std::string(epoch_id)),
+  DropDatabaseReq(const utils::UUID &main_uuid, std::string epoch_id, uint64_t expected_group_timestamp,
+                  uint64_t new_group_timestamp, const utils::UUID &uuid)
+      : main_uuid(main_uuid),
+        epoch_id(std::move(epoch_id)),
         expected_group_timestamp{expected_group_timestamp},
         new_group_timestamp(new_group_timestamp),
         uuid(uuid) {}

+  utils::UUID main_uuid;
   std::string epoch_id;
   uint64_t expected_group_timestamp;
   uint64_t new_group_timestamp;
diff --git a/src/dbms/utils.hpp b/src/dbms/utils.hpp
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp
index 783ff6ae9..ea175a18e 100644
--- a/src/query/interpreter.cpp
+++ b/src/query/interpreter.cpp
@@ -327,7 +327,7 @@ class ReplQueryHandler {
           .port = static_cast<uint16_t>(*port),
       };

-      if (!handler_->SetReplicationRoleReplica(config)) {
+      if (!handler_->SetReplicationRoleReplica(config, std::nullopt)) {
         throw QueryRuntimeException("Couldn't set role to replica!");
       }
     }
@@ -368,7 +368,7 @@ class ReplQueryHandler {
                                          .replica_check_frequency = replica_check_frequency,
                                          .ssl = std::nullopt};

-      const auto error = handler_->TryRegisterReplica(replication_config).HasError();
+      const auto error = handler_->TryRegisterReplica(replication_config, true).HasError();

      if (error) {
        throw QueryRuntimeException(fmt::format("Couldn't register replica '{}'!", name));
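The true passed to TryRegisterReplica above is the new send_swap_uuid flag: a user-driven REGISTER REPLICA first asks the replica, through SwapMainUUIDRpc, to adopt this MAIN's UUID before any data replication starts. An illustrative caller, a sketch assuming only names introduced in this patch:

    // handler_ is a query::ReplicationQueryHandler (interface declared below).
    auto err = handler_->TryRegisterReplica(replication_config, /*send_swap_uuid=*/true);
    if (err.HasError() && err.GetError() == query::RegisterReplicaError::ERROR_ACCEPTING_MAIN) {
      // The replica refused the SwapMainUUID handshake, e.g. it already follows another MAIN.
    }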
@@ -518,7 +518,9 @@ class CoordQueryHandler final : public query::CoordinatorQueryHandler {
         throw QueryRuntimeException("SET INSTANCE TO MAIN query can only be run on a coordinator!");
       case COULD_NOT_PROMOTE_TO_MAIN:
         throw QueryRuntimeException(
-            "Couldn't set replica instance to main!. Check coordinator and replica for more logs");
+            "Couldn't set replica instance to main! Check coordinator and replica for more logs");
+      case SWAP_UUID_FAILED:
+        throw QueryRuntimeException("Couldn't set replica instance to main! Replicas didn't accept the new main's uuid.");
       case SUCCESS:
         break;
     }
diff --git a/src/query/replication_query_handler.hpp b/src/query/replication_query_handler.hpp
index f391b4867..aa0611a43 100644
--- a/src/query/replication_query_handler.hpp
+++ b/src/query/replication_query_handler.hpp
@@ -13,6 +13,7 @@

 #include "replication_coordination_glue/role.hpp"
 #include "utils/result.hpp"
+#include "utils/uuid.hpp"

 // BEGIN fwd declares
 namespace memgraph::replication {
@@ -23,7 +24,13 @@ struct ReplicationClientConfig;

 namespace memgraph::query {

-enum class RegisterReplicaError : uint8_t { NAME_EXISTS, ENDPOINT_EXISTS, CONNECTION_FAILED, COULD_NOT_BE_PERSISTED };
+enum class RegisterReplicaError : uint8_t {
+  NAME_EXISTS,
+  ENDPOINT_EXISTS,
+  CONNECTION_FAILED,
+  COULD_NOT_BE_PERSISTED,
+  ERROR_ACCEPTING_MAIN
+};
 enum class UnregisterReplicaResult : uint8_t {
   NOT_MAIN,
   COULD_NOT_BE_PERSISTED,
@@ -39,13 +46,14 @@ struct ReplicationQueryHandler {
   virtual bool SetReplicationRoleMain() = 0;

   // as MAIN, become REPLICA
-  virtual bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) = 0;
+  virtual bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config,
+                                         const std::optional<utils::UUID> &main_uuid) = 0;

   // as MAIN, define and connect to REPLICAs
-  virtual auto TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
+  virtual auto TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config, bool send_swap_uuid)
       -> utils::BasicResult<RegisterReplicaError> = 0;

-  virtual auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
+  virtual auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config, bool send_swap_uuid)
       -> utils::BasicResult<RegisterReplicaError> = 0;

   // as MAIN, remove a REPLICA connection
diff --git a/src/replication/CMakeLists.txt b/src/replication/CMakeLists.txt
index 676dce744..3a6613ab9 100644
--- a/src/replication/CMakeLists.txt
+++ b/src/replication/CMakeLists.txt
@@ -21,6 +21,6 @@ target_include_directories(mg-replication PUBLIC include)
 find_package(fmt REQUIRED)
 target_link_libraries(mg-replication
-    PUBLIC mg::utils mg::kvstore lib::json mg::rpc mg::slk mg::io mg::repl_coord_glue
+    PUBLIC mg::utils mg::kvstore lib::json mg::rpc mg::slk mg::io mg::repl_coord_glue mg-flags
     PRIVATE fmt::fmt
 )
diff --git a/src/replication/include/replication/messages.hpp b/src/replication/include/replication/messages.hpp
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/replication/include/replication/replication_client.hpp b/src/replication/include/replication/replication_client.hpp
index 0bf0e424f..92c321ac6 100644
--- a/src/replication/include/replication/replication_client.hpp
+++ b/src/replication/include/replication/replication_client.hpp
@@ -54,7 +54,7 @@ struct ReplicationClient {
     } catch (const rpc::RpcFailedException &) {
       // Nothing to do...wait for a reconnect
       // NOTE: Here we are communicating with the instance connection.
-      // We don't have access to the undelying client; so the only thing we can do it
+      // We don't have access to the underlying client; so the only thing we can do is
       // tell the callback that this is a reconnection and to check the state
       reconnect = true;
     }
@@ -106,6 +106,9 @@ struct ReplicationClient {
   communication::ClientContext rpc_context_;
   rpc::Client rpc_client_;
   std::chrono::seconds replica_check_frequency_;
+  // True only when migrating replication durability from V1 or V2 to V3,
+  // where we still need to tell the replica which MAIN to listen to
+  bool try_set_uuid{false};
   // TODO: Better, this was the easiest place to put this
   enum class State {
diff --git a/src/replication/include/replication/state.hpp b/src/replication/include/replication/state.hpp
index a53885aff..18f9efd4e 100644
--- a/src/replication/include/replication/state.hpp
+++ b/src/replication/include/replication/state.hpp
@@ -21,10 +21,12 @@
 #include "status.hpp"
 #include "utils/result.hpp"
 #include "utils/synchronized.hpp"
+#include "utils/uuid.hpp"

 #include <atomic>
 #include <cstdint>
 #include <list>
+#include <optional>
 #include <variant>
 #include <vector>

@@ -37,7 +39,11 @@ enum class RegisterReplicaError : uint8_t { NAME_EXISTS, ENDPOINT_EXISTS, COULD_
 struct RoleMainData {
   RoleMainData() = default;
-  explicit RoleMainData(ReplicationEpoch e) : epoch_(std::move(e)) {}
+  explicit RoleMainData(ReplicationEpoch e, std::optional<utils::UUID> uuid = std::nullopt) : epoch_(std::move(e)) {
+    if (uuid) {
+      uuid_ = *uuid;
+    }
+  }
   ~RoleMainData() = default;

   RoleMainData(RoleMainData const &) = delete;
@@ -47,11 +53,14 @@ struct RoleMainData {
   ReplicationEpoch epoch_;
   std::list<ReplicationClient> registered_replicas_{};  // TODO: data race issues
+  utils::UUID uuid_;
 };

 struct RoleReplicaData {
   ReplicationServerConfig config;
   std::unique_ptr<ReplicationServer> server;
+  // uuid of the MAIN this replica is listening to
+  std::optional<utils::UUID> uuid_;
 };

 // Global (instance) level object
@@ -83,18 +92,19 @@ struct ReplicationState {

   bool HasDurability() const { return nullptr != durability_; }

-  bool TryPersistRoleMain(std::string new_epoch);
-  bool TryPersistRoleReplica(const ReplicationServerConfig &config);
+  bool TryPersistRoleMain(std::string new_epoch, utils::UUID main_uuid);
+  bool TryPersistRoleReplica(const ReplicationServerConfig &config, const std::optional<utils::UUID> &main_uuid);
   bool TryPersistUnregisterReplica(std::string_view name);
-  bool TryPersistRegisteredReplica(const ReplicationClientConfig &config);
+  bool TryPersistRegisteredReplica(const ReplicationClientConfig &config, utils::UUID main_uuid);

   // TODO: locked access
   auto ReplicationData() -> ReplicationData_t & { return replication_data_; }
   auto ReplicationData() const -> ReplicationData_t const & { return replication_data_; }

   utils::BasicResult<RegisterReplicaError, ReplicationClient *> RegisterReplica(const ReplicationClientConfig &config);

-  bool SetReplicationRoleMain();
-  bool SetReplicationRoleReplica(const ReplicationServerConfig &config);
+  bool SetReplicationRoleMain(const utils::UUID &main_uuid);
+  bool SetReplicationRoleReplica(const ReplicationServerConfig &config,
+                                 const std::optional<utils::UUID> &main_uuid = std::nullopt);

  private:
   bool HandleVersionMigration(durability::ReplicationRoleEntry &data) const;
diff --git a/src/replication/include/replication/status.hpp b/src/replication/include/replication/status.hpp
index 4dfba6aaa..484cba848 100644
--- a/src/replication/include/replication/status.hpp +++ b/src/replication/include/replication/status.hpp @@ -31,25 +31,28 @@ constexpr auto *kReplicationReplicaPrefix{"__replication_replica:"}; // introdu enum class DurabilityVersion : uint8_t { V1, // no distinct key for replicas - V2, // this version, epoch, replica prefix introduced + V2, // epoch, replica prefix introduced + V3, // this version, main uuid introduced }; // fragment of key: "__replication_role" struct MainRole { ReplicationEpoch epoch{}; + std::optional<utils::UUID> main_uuid{}; friend bool operator==(MainRole const &, MainRole const &) = default; }; // fragment of key: "__replication_role" struct ReplicaRole { ReplicationServerConfig config{}; + std::optional<utils::UUID> main_uuid{}; friend bool operator==(ReplicaRole const &, ReplicaRole const &) = default; }; // from key: "__replication_role" struct ReplicationRoleEntry { DurabilityVersion version = - DurabilityVersion::V2; // if not latest then migration required for kReplicationReplicaPrefix + DurabilityVersion::V3; // if not latest then migration required for kReplicationReplicaPrefix std::variant<MainRole, ReplicaRole> role; friend bool operator==(ReplicationRoleEntry const &, ReplicationRoleEntry const &) = default; diff --git a/src/replication/messages.cpp b/src/replication/messages.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/src/replication/replication_server.cpp b/src/replication/replication_server.cpp index 03c48d298..5b7baa17f 100644 --- a/src/replication/replication_server.cpp +++ b/src/replication/replication_server.cpp @@ -10,7 +10,7 @@ // licenses/APL.txt. #include "replication/replication_server.hpp" -#include "replication_coordination_glue/messages.hpp" +#include "replication_coordination_glue/handler.hpp" namespace memgraph::replication { namespace { diff --git a/src/replication/state.cpp b/src/replication/state.cpp index d04a3d245..6cc4ff951 100644 --- a/src/replication/state.cpp +++ b/src/replication/state.cpp @@ -10,12 +10,15 @@ // licenses/APL.txt. 
#include "replication/state.hpp" +#include <optional> +#include "flags/replication.hpp" #include "replication/replication_client.hpp" #include "replication/replication_server.hpp" #include "replication/status.hpp" #include "utils/file.hpp" #include "utils/result.hpp" +#include "utils/uuid.hpp" #include "utils/variant_helpers.hpp" constexpr auto kReplicationDirectory = std::string_view{"replication"}; @@ -36,9 +39,9 @@ ReplicationState::ReplicationState(std::optional<std::filesystem::path> durabili durability_ = std::make_unique<kvstore::KVStore>(std::move(repl_dir)); spdlog::info("Replication configuration will be stored and will be automatically restored in case of a crash."); - auto replicationData = FetchReplicationData(); - if (replicationData.HasError()) { - switch (replicationData.GetError()) { + auto fetched_replication_data = FetchReplicationData(); + if (fetched_replication_data.HasError()) { + switch (fetched_replication_data.GetError()) { using enum ReplicationState::FetchReplicationError; case NOTHING_FETCHED: { spdlog::debug("Cannot find data needed for restore replication role in persisted metadata."); @@ -51,15 +54,21 @@ ReplicationState::ReplicationState(std::optional<std::filesystem::path> durabili } } } - replication_data_ = std::move(replicationData).GetValue(); + auto replication_data = std::move(fetched_replication_data).GetValue(); +#ifdef MG_ENTERPRISE + if (FLAGS_coordinator_server_port && std::holds_alternative<RoleReplicaData>(replication_data)) { + std::get<RoleReplicaData>(replication_data).uuid_.reset(); + } +#endif + replication_data_ = std::move(replication_data); } -bool ReplicationState::TryPersistRoleReplica(const ReplicationServerConfig &config) { +bool ReplicationState::TryPersistRoleReplica(const ReplicationServerConfig &config, + const std::optional<utils::UUID> &main_uuid) { if (!HasDurability()) return true; - auto data = durability::ReplicationRoleEntry{.role = durability::ReplicaRole{ - .config = config, - }}; + auto data = + durability::ReplicationRoleEntry{.role = durability::ReplicaRole{.config = config, .main_uuid = main_uuid}}; if (!durability_->Put(durability::kReplicationRoleName, nlohmann::json(data).dump())) { spdlog::error("Error when saving REPLICA replication role in settings."); @@ -78,11 +87,11 @@ bool ReplicationState::TryPersistRoleReplica(const ReplicationServerConfig &conf return true; } -bool ReplicationState::TryPersistRoleMain(std::string new_epoch) { +bool ReplicationState::TryPersistRoleMain(std::string new_epoch, utils::UUID main_uuid) { if (!HasDurability()) return true; - auto data = - durability::ReplicationRoleEntry{.role = durability::MainRole{.epoch = ReplicationEpoch{std::move(new_epoch)}}}; + auto data = durability::ReplicationRoleEntry{ + .role = durability::MainRole{.epoch = ReplicationEpoch{std::move(new_epoch)}, .main_uuid = main_uuid}}; if (durability_->Put(durability::kReplicationRoleName, nlohmann::json(data).dump())) { role_persisted = RolePersisted::YES; @@ -128,7 +137,8 @@ auto ReplicationState::FetchReplicationData() -> FetchReplicationResult_t { return std::visit( utils::Overloaded{ [&](durability::MainRole &&r) -> FetchReplicationResult_t { - auto res = RoleMainData{std::move(r.epoch)}; + auto res = + RoleMainData{std::move(r.epoch), r.main_uuid.has_value() ? 
r.main_uuid.value() : utils::UUID{}}; auto b = durability_->begin(durability::kReplicationReplicaPrefix); auto e = durability_->end(durability::kReplicationReplicaPrefix); for (; b != e; ++b) { @@ -143,6 +153,8 @@ auto ReplicationState::FetchReplicationData() -> FetchReplicationResult_t { } // Instance clients res.registered_replicas_.emplace_back(data.config); + // Bump for each replica uuid + res.registered_replicas_.back().try_set_uuid = !r.main_uuid.has_value(); } catch (...) { return FetchReplicationError::PARSE_ERROR; } @@ -150,7 +162,9 @@ auto ReplicationState::FetchReplicationData() -> FetchReplicationResult_t { return {std::move(res)}; }, [&](durability::ReplicaRole &&r) -> FetchReplicationResult_t { - return {RoleReplicaData{r.config, std::make_unique<ReplicationServer>(r.config)}}; + // False positive report for the std::make_unique + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + return {RoleReplicaData{r.config, std::make_unique<ReplicationServer>(r.config), r.main_uuid}}; }, }, std::move(data.role)); @@ -192,21 +206,29 @@ bool ReplicationState::HandleVersionMigration(durability::ReplicationRoleEntry & [[fallthrough]]; } case durability::DurabilityVersion::V2: { - // do nothing - add code if V3 ever happens + if (std::holds_alternative<durability::MainRole>(data.role)) { + auto &main = std::get<durability::MainRole>(data.role); + main.main_uuid = utils::UUID{}; + } + data.version = durability::DurabilityVersion::V3; + break; + } + case durability::DurabilityVersion::V3: { + // do nothing - add code if V4 ever happens break; } } return true; } -bool ReplicationState::TryPersistRegisteredReplica(const ReplicationClientConfig &config) { +bool ReplicationState::TryPersistRegisteredReplica(const ReplicationClientConfig &config, utils::UUID main_uuid) { if (!HasDurability()) return true; // If any replicas are persisted then Role must be persisted if (role_persisted != RolePersisted::YES) { DMG_ASSERT(IsMain(), "MAIN is expected"); auto epoch_str = std::string(std::get<RoleMainData>(replication_data_).epoch_.id()); - if (!TryPersistRoleMain(std::move(epoch_str))) return false; + if (!TryPersistRoleMain(std::move(epoch_str), main_uuid)) return false; } auto data = durability::ReplicationReplicaEntry{.config = config}; @@ -217,22 +239,28 @@ bool ReplicationState::TryPersistRegisteredReplica(const ReplicationClientConfig return false; } -bool ReplicationState::SetReplicationRoleMain() { +bool ReplicationState::SetReplicationRoleMain(const utils::UUID &main_uuid) { auto new_epoch = utils::GenerateUUID(); - if (!TryPersistRoleMain(new_epoch)) { + if (!TryPersistRoleMain(new_epoch, main_uuid)) { return false; } - replication_data_ = RoleMainData{ReplicationEpoch{new_epoch}}; + + replication_data_ = RoleMainData{ReplicationEpoch{new_epoch}, main_uuid}; return true; } -bool ReplicationState::SetReplicationRoleReplica(const ReplicationServerConfig &config) { - if (!TryPersistRoleReplica(config)) { +bool ReplicationState::SetReplicationRoleReplica(const ReplicationServerConfig &config, + const std::optional<utils::UUID> &main_uuid) { + // False positive report for the std::make_unique + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + if (!TryPersistRoleReplica(config, main_uuid)) { return false; } - replication_data_ = RoleReplicaData{config, std::make_unique<ReplicationServer>(config)}; + // False positive report for the std::make_unique + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + replication_data_ = RoleReplicaData{config, 
std::make_unique<ReplicationServer>(config), std::nullopt}; return true; } @@ -264,7 +292,7 @@ utils::BasicResult<RegisterReplicaError, ReplicationClient *> ReplicationState:: } // Durability - if (!TryPersistRegisteredReplica(config)) { + if (!TryPersistRegisteredReplica(config, mainData.uuid_)) { return RegisterReplicaError::COULD_NOT_BE_PERSISTED; } diff --git a/src/replication/status.cpp b/src/replication/status.cpp index de1af9589..f578a85c5 100644 --- a/src/replication/status.cpp +++ b/src/replication/status.cpp @@ -26,21 +26,28 @@ constexpr auto *kSSLCertFile = "replica_ssl_cert_file"; constexpr auto *kReplicationRole = "replication_role"; constexpr auto *kEpoch = "epoch"; constexpr auto *kVersion = "durability_version"; +constexpr auto *kMainUUID = "main_uuid"; void to_json(nlohmann::json &j, const ReplicationRoleEntry &p) { auto processMAIN = [&](MainRole const &main) { - j = nlohmann::json{{kVersion, p.version}, - {kReplicationRole, replication_coordination_glue::ReplicationRole::MAIN}, - {kEpoch, main.epoch.id()}}; + auto common = nlohmann::json{{kVersion, p.version}, + {kReplicationRole, replication_coordination_glue::ReplicationRole::MAIN}, + {kEpoch, main.epoch.id()}}; + if (p.version != DurabilityVersion::V1 && p.version != DurabilityVersion::V2) { + MG_ASSERT(main.main_uuid.has_value(), "Main should have id ready on version >= V3"); + common[kMainUUID] = main.main_uuid.value(); + } + j = std::move(common); }; auto processREPLICA = [&](ReplicaRole const &replica) { - j = nlohmann::json{ - {kVersion, p.version}, - {kReplicationRole, replication_coordination_glue::ReplicationRole::REPLICA}, - {kIpAddress, replica.config.ip_address}, - {kPort, replica.config.port} - // TODO: SSL - }; + auto common = nlohmann::json{{kVersion, p.version}, + {kReplicationRole, replication_coordination_glue::ReplicationRole::REPLICA}, + {kIpAddress, replica.config.ip_address}, + {kPort, replica.config.port}}; + if (replica.main_uuid.has_value()) { + common[kMainUUID] = replica.main_uuid.value(); + } + j = std::move(common); }; std::visit(utils::Overloaded{processMAIN, processREPLICA}, p.role); } @@ -56,7 +63,12 @@ void from_json(const nlohmann::json &j, ReplicationRoleEntry &p) { auto json_epoch = j.value(kEpoch, std::string{}); auto epoch = ReplicationEpoch{}; if (!json_epoch.empty()) epoch.SetEpoch(json_epoch); - p = ReplicationRoleEntry{.version = version, .role = MainRole{.epoch = std::move(epoch)}}; + auto main_role = MainRole{.epoch = std::move(epoch)}; + + if (j.contains(kMainUUID)) { + main_role.main_uuid = j.at(kMainUUID); + } + p = ReplicationRoleEntry{.version = version, .role = std::move(main_role)}; break; } case memgraph::replication_coordination_glue::ReplicationRole::REPLICA: { @@ -66,7 +78,13 @@ void from_json(const nlohmann::json &j, ReplicationRoleEntry &p) { j.at(kIpAddress).get_to(ip_address); j.at(kPort).get_to(port); auto config = ReplicationServerConfig{.ip_address = std::move(ip_address), .port = port}; - p = ReplicationRoleEntry{.version = version, .role = ReplicaRole{.config = std::move(config)}}; + auto replica_role = ReplicaRole{.config = std::move(config)}; + if (j.contains(kMainUUID)) { + replica_role.main_uuid = j.at(kMainUUID); + } + + p = ReplicationRoleEntry{.version = version, .role = std::move(replica_role)}; + break; } } diff --git a/src/replication_coordination_glue/CMakeLists.txt b/src/replication_coordination_glue/CMakeLists.txt index 010a7b596..f81aed4ba 100644 --- a/src/replication_coordination_glue/CMakeLists.txt +++ 
b/src/replication_coordination_glue/CMakeLists.txt @@ -6,6 +6,7 @@ target_sources(mg-repl_coord_glue messages.hpp mode.hpp role.hpp + handler.hpp PRIVATE messages.cpp diff --git a/src/replication_coordination_glue/handler.hpp b/src/replication_coordination_glue/handler.hpp new file mode 100644 index 000000000..2076b8fa2 --- /dev/null +++ b/src/replication_coordination_glue/handler.hpp @@ -0,0 +1,41 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. +#pragma once + +#include "rpc/client.hpp" +#include "utils/uuid.hpp" + +#include "messages.hpp" +#include "rpc/messages.hpp" + +namespace memgraph::replication_coordination_glue { +inline bool SendSwapMainUUIDRpc(memgraph::rpc::Client &rpc_client_, const memgraph::utils::UUID &uuid) { + try { + auto stream{rpc_client_.Stream<SwapMainUUIDRpc>(uuid)}; + if (!stream.AwaitResponse().success) { + spdlog::error("Failed to receive successful RPC swapping of uuid response!"); + return false; + } + return true; + } catch (const memgraph::rpc::RpcFailedException &) { + spdlog::error("RPC error occurred while sending swapping uuid RPC!"); + } + return false; +} + +inline void FrequentHeartbeatHandler(slk::Reader *req_reader, slk::Builder *res_builder) { + FrequentHeartbeatReq req; + FrequentHeartbeatReq::Load(&req, req_reader); + memgraph::slk::Load(&req, req_reader); + FrequentHeartbeatRes res{}; + memgraph::slk::Save(res, res_builder); +} +} // namespace memgraph::replication_coordination_glue diff --git a/src/replication_coordination_glue/messages.cpp b/src/replication_coordination_glue/messages.cpp index c7cf0b15c..ad9d21a37 100644 --- a/src/replication_coordination_glue/messages.cpp +++ b/src/replication_coordination_glue/messages.cpp @@ -29,6 +29,25 @@ void Load(memgraph::replication_coordination_glue::FrequentHeartbeatReq * /*self /* Nothing to serialize */ } +// Serialize code for SwapMainUUIDRes + +void Save(const memgraph::replication_coordination_glue::SwapMainUUIDRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.success, builder); +} + +void Load(memgraph::replication_coordination_glue::SwapMainUUIDRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->success, reader); +} + +// Serialize code for SwapMainUUIDReq +void Save(const memgraph::replication_coordination_glue::SwapMainUUIDReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.uuid, builder); +} + +void Load(memgraph::replication_coordination_glue::SwapMainUUIDReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->uuid, reader); +} + } // namespace memgraph::slk namespace memgraph::replication_coordination_glue { @@ -39,6 +58,10 @@ constexpr utils::TypeInfo FrequentHeartbeatReq::kType{utils::TypeId::REP_FREQUEN constexpr utils::TypeInfo FrequentHeartbeatRes::kType{utils::TypeId::REP_FREQUENT_HEARTBEAT_RES, "FrequentHeartbeatRes", nullptr}; +constexpr utils::TypeInfo SwapMainUUIDReq::kType{utils::TypeId::COORD_SWAP_UUID_REQ, "SwapUUIDReq", nullptr}; + +constexpr utils::TypeInfo 
SwapMainUUIDRes::kType{utils::TypeId::COORD_SWAP_UUID_RES, "SwapUUIDRes", nullptr}; + void FrequentHeartbeatReq::Save(const FrequentHeartbeatReq &self, memgraph::slk::Builder *builder) { memgraph::slk::Save(self, builder); } @@ -52,12 +75,16 @@ void FrequentHeartbeatRes::Load(FrequentHeartbeatRes *self, memgraph::slk::Reade memgraph::slk::Load(self, reader); } -void FrequentHeartbeatHandler(slk::Reader *req_reader, slk::Builder *res_builder) { - FrequentHeartbeatReq req; - FrequentHeartbeatReq::Load(&req, req_reader); - memgraph::slk::Load(&req, req_reader); - FrequentHeartbeatRes res{}; - memgraph::slk::Save(res, res_builder); +void SwapMainUUIDReq::Save(const SwapMainUUIDReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); } +void SwapMainUUIDReq::Load(SwapMainUUIDReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } + +void SwapMainUUIDRes::Save(const SwapMainUUIDRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} + +void SwapMainUUIDRes::Load(SwapMainUUIDRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } + } // namespace memgraph::replication_coordination_glue diff --git a/src/replication_coordination_glue/messages.hpp b/src/replication_coordination_glue/messages.hpp index 5e2ef0fdf..81ce59a12 100644 --- a/src/replication_coordination_glue/messages.hpp +++ b/src/replication_coordination_glue/messages.hpp @@ -13,6 +13,7 @@ #include "rpc/messages.hpp" #include "slk/serialization.hpp" +#include "utils/uuid.hpp" namespace memgraph::replication_coordination_glue { @@ -36,7 +37,34 @@ struct FrequentHeartbeatRes { using FrequentHeartbeatRpc = rpc::RequestResponse<FrequentHeartbeatReq, FrequentHeartbeatRes>; -void FrequentHeartbeatHandler(slk::Reader *req_reader, slk::Builder *res_builder); +struct SwapMainUUIDReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(SwapMainUUIDReq *self, memgraph::slk::Reader *reader); + static void Save(const SwapMainUUIDReq &self, memgraph::slk::Builder *builder); + + explicit SwapMainUUIDReq(const utils::UUID &uuid) : uuid(uuid) {} + + SwapMainUUIDReq() = default; + + utils::UUID uuid; +}; + +struct SwapMainUUIDRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(SwapMainUUIDRes *self, memgraph::slk::Reader *reader); + static void Save(const SwapMainUUIDRes &self, memgraph::slk::Builder *builder); + + explicit SwapMainUUIDRes(bool success) : success(success) {} + SwapMainUUIDRes() = default; + + bool success; +}; + +using SwapMainUUIDRpc = rpc::RequestResponse<SwapMainUUIDReq, SwapMainUUIDRes>; } // namespace memgraph::replication_coordination_glue @@ -46,4 +74,10 @@ void Load(memgraph::replication_coordination_glue::FrequentHeartbeatRes *self, m void Save(const memgraph::replication_coordination_glue::FrequentHeartbeatReq & /*self*/, memgraph::slk::Builder * /*builder*/); void Load(memgraph::replication_coordination_glue::FrequentHeartbeatReq * /*self*/, memgraph::slk::Reader * /*reader*/); + +// SwapMainUUIDRpc +void Save(const memgraph::replication_coordination_glue::SwapMainUUIDReq &self, memgraph::slk::Builder *builder); +void Load(memgraph::replication_coordination_glue::SwapMainUUIDReq *self, memgraph::slk::Reader *reader); +void Save(const memgraph::replication_coordination_glue::SwapMainUUIDRes &self, memgraph::slk::Builder *builder); +void 
Load(memgraph::replication_coordination_glue::SwapMainUUIDRes *self, memgraph::slk::Reader *reader); } // namespace memgraph::slk diff --git a/src/replication_handler/CMakeLists.txt b/src/replication_handler/CMakeLists.txt index a0cd3734c..2ba563ce2 100644 --- a/src/replication_handler/CMakeLists.txt +++ b/src/replication_handler/CMakeLists.txt @@ -7,8 +7,8 @@ target_sources(mg-replication_handler include/replication_handler/system_rpc.hpp PRIVATE - replication_handler.cpp system_replication.cpp + replication_handler.cpp system_rpc.cpp ) target_include_directories(mg-replication_handler PUBLIC include) diff --git a/src/replication_handler/include/replication_handler/replication_handler.hpp b/src/replication_handler/include/replication_handler/replication_handler.hpp index 1ae9ceb6d..663b30f54 100644 --- a/src/replication_handler/include/replication_handler/replication_handler.hpp +++ b/src/replication_handler/include/replication_handler/replication_handler.hpp @@ -22,10 +22,10 @@ inline std::optional<query::RegisterReplicaError> HandleRegisterReplicaStatus( utils::BasicResult<replication::RegisterReplicaError, replication::ReplicationClient *> &instance_client); #ifdef MG_ENTERPRISE -void StartReplicaClient(replication::ReplicationClient &client, system::System *system, dbms::DbmsHandler &dbms_handler, - auth::SynchedAuth &auth); +void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandler &dbms_handler, utils::UUID main_uuid, + system::System *system, auth::SynchedAuth &auth); #else -void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandler &dbms_handler); +void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandler &dbms_handler, utils::UUID main_uuid); #endif #ifdef MG_ENTERPRISE @@ -33,8 +33,8 @@ void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandle // When being called by interpreter no need to gain lock, it should already be under a system transaction // But concurrently the FrequentCheck is running and will need to lock before reading last_committed_system_timestamp_ template <bool REQUIRE_LOCK = false> -void SystemRestore(replication::ReplicationClient &client, system::System *system, dbms::DbmsHandler &dbms_handler, - auth::SynchedAuth &auth) { +void SystemRestore(replication::ReplicationClient &client, dbms::DbmsHandler &dbms_handler, + const utils::UUID &main_uuid, system::System *system, auth::SynchedAuth &auth) { // Check if system is up to date if (client.state_.WithLock( [](auto &state) { return state == memgraph::replication::ReplicationClient::State::READY; })) @@ -69,12 +69,12 @@ void SystemRestore(replication::ReplicationClient &client, system::System *syste // Handle only default database is no license if (!license::global_license_checker.IsEnterpriseValidFast()) { return client.rpc_client_.Stream<replication::SystemRecoveryRpc>( - db_info.last_committed_timestamp, std::move(db_info.configs), auth::Auth::Config{}, + main_uuid, db_info.last_committed_timestamp, std::move(db_info.configs), auth::Auth::Config{}, std::vector<auth::User>{}, std::vector<auth::Role>{}); } return auth.WithLock([&](auto &locked_auth) { return client.rpc_client_.Stream<replication::SystemRecoveryRpc>( - db_info.last_committed_timestamp, std::move(db_info.configs), locked_auth.GetConfig(), + main_uuid, db_info.last_committed_timestamp, std::move(db_info.configs), locked_auth.GetConfig(), locked_auth.AllUsers(), locked_auth.AllRoles()); }); }); @@ -109,28 +109,32 @@ struct ReplicationHandler : public 
memgraph::query::ReplicationQueryHandler { bool SetReplicationRoleMain() override; // as MAIN, become REPLICA - bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) override; + bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config, + const std::optional<utils::UUID> &main_uuid) override; // as MAIN, define and connect to REPLICAs - auto TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config) + auto TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config, bool send_swap_uuid) -> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> override; - auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config) + auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config, bool send_swap_uuid) -> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> override; // as MAIN, remove a REPLICA connection auto UnregisterReplica(std::string_view name) -> memgraph::query::UnregisterReplicaResult override; - bool DoReplicaToMainPromotion(); + bool DoReplicaToMainPromotion(const utils::UUID &main_uuid); // Helper pass-through (TODO: remove) auto GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole override; bool IsMain() const override; bool IsReplica() const override; + auto GetReplState() const -> const memgraph::replication::ReplicationState &; + auto GetReplState() -> memgraph::replication::ReplicationState &; + private: template <bool HandleFailure> - auto RegisterReplica_(const memgraph::replication::ReplicationClientConfig &config) + auto RegisterReplica_(const memgraph::replication::ReplicationClientConfig &config, bool send_swap_uuid) -> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> { MG_ASSERT(repl_state_.IsMain(), "Only main instance can register a replica!"); @@ -154,10 +158,19 @@ struct ReplicationHandler : public memgraph::query::ReplicationQueryHandler { if (!memgraph::dbms::allow_mt_repl && dbms_handler_.All().size() > 1) { spdlog::warn("Multi-tenant replication is currently not supported!"); } + const auto main_uuid = + std::get<memgraph::replication::RoleMainData>(dbms_handler_.ReplicationState().ReplicationData()).uuid_; + + if (send_swap_uuid) { + if (!memgraph::replication_coordination_glue::SendSwapMainUUIDRpc(maybe_client.GetValue()->rpc_client_, + main_uuid)) { + return memgraph::query::RegisterReplicaError::ERROR_ACCEPTING_MAIN; + } + } #ifdef MG_ENTERPRISE // Update system before enabling individual storage <-> replica clients - SystemRestore(*maybe_client.GetValue(), system_, dbms_handler_, auth_); + SystemRestore(*maybe_client.GetValue(), dbms_handler_, main_uuid, system_, auth_); #endif const auto dbms_error = HandleRegisterReplicaStatus(maybe_client); @@ -177,8 +190,9 @@ struct ReplicationHandler : public memgraph::query::ReplicationQueryHandler { if (storage->storage_mode_ != storage::StorageMode::IN_MEMORY_TRANSACTIONAL) return; all_clients_good &= storage->repl_storage_state_.replication_clients_.WithLock( - [storage, &instance_client_ptr, db_acc = std::move(db_acc)](auto &storage_clients) mutable { // NOLINT - auto client = std::make_unique<storage::ReplicationStorageClient>(*instance_client_ptr); + [storage, &instance_client_ptr, db_acc = std::move(db_acc), + main_uuid](auto &storage_clients) mutable { // NOLINT + auto client = std::make_unique<storage::ReplicationStorageClient>(*instance_client_ptr, main_uuid); // All good, start 
replica client client->Start(storage, std::move(db_acc)); // After start the storage <-> replica state should be READY or RECOVERING (if correctly started) @@ -201,9 +215,9 @@ struct ReplicationHandler : public memgraph::query::ReplicationQueryHandler { // No client error, start instance level client #ifdef MG_ENTERPRISE - StartReplicaClient(*instance_client_ptr, system_, dbms_handler_, auth_); + StartReplicaClient(*instance_client_ptr, dbms_handler_, main_uuid, system_, auth_); #else - StartReplicaClient(*instance_client_ptr, dbms_handler_); + StartReplicaClient(*instance_client_ptr, dbms_handler_, main_uuid); #endif return {}; } diff --git a/src/replication_handler/include/replication_handler/system_replication.hpp b/src/replication_handler/include/replication_handler/system_replication.hpp index e1d177fc6..27039d0ff 100644 --- a/src/replication_handler/include/replication_handler/system_replication.hpp +++ b/src/replication_handler/include/replication_handler/system_replication.hpp @@ -17,15 +17,23 @@ #include "system/state.hpp" namespace memgraph::replication { + +inline void LogWrongMain(const std::optional<utils::UUID> &current_main_uuid, const utils::UUID &main_req_id, + std::string_view rpc_req) { + spdlog::error("Received {} with main_id: {} != current_main_uuid: {}", rpc_req, std::string(main_req_id), + current_main_uuid.has_value() ? std::string(current_main_uuid.value()) : ""); +} + #ifdef MG_ENTERPRISE -void SystemHeartbeatHandler(uint64_t ts, slk::Reader *req_reader, slk::Builder *res_builder); +void SystemHeartbeatHandler(uint64_t ts, const std::optional<utils::UUID> &current_main_uuid, slk::Reader *req_reader, + slk::Builder *res_builder); void SystemRecoveryHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, - dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth, slk::Reader *req_reader, - slk::Builder *res_builder); + std::optional<utils::UUID> &current_main_uuid, dbms::DbmsHandler &dbms_handler, + auth::SynchedAuth &auth, slk::Reader *req_reader, slk::Builder *res_builder); void Register(replication::RoleReplicaData const &data, dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth); -bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data, auth::SynchedAuth &auth); +bool StartRpcServer(dbms::DbmsHandler &dbms_handler, replication::RoleReplicaData &data, auth::SynchedAuth &auth); #else -bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data); +bool StartRpcServer(dbms::DbmsHandler &dbms_handler, replication::RoleReplicaData &data); #endif } // namespace memgraph::replication diff --git a/src/replication_handler/include/replication_handler/system_rpc.hpp b/src/replication_handler/include/replication_handler/system_rpc.hpp index a2469fc5d..661994a24 100644 --- a/src/replication_handler/include/replication_handler/system_rpc.hpp +++ b/src/replication_handler/include/replication_handler/system_rpc.hpp @@ -27,6 +27,8 @@ struct SystemHeartbeatReq { static void Load(SystemHeartbeatReq *self, memgraph::slk::Reader *reader); static void Save(const SystemHeartbeatReq &self, memgraph::slk::Builder *builder); SystemHeartbeatReq() = default; + explicit SystemHeartbeatReq(const utils::UUID &main_uuid) : main_uuid(main_uuid) {} + utils::UUID main_uuid; }; struct SystemHeartbeatRes { @@ -50,14 +52,17 @@ struct SystemRecoveryReq { static void Load(SystemRecoveryReq *self, memgraph::slk::Reader *reader); static void Save(const SystemRecoveryReq &self, memgraph::slk::Builder *builder);
SystemRecoveryReq() = default; - SystemRecoveryReq(uint64_t forced_group_timestamp, std::vector<storage::SalientConfig> database_configs, - auth::Auth::Config auth_config, std::vector<auth::User> users, std::vector<auth::Role> roles) - : forced_group_timestamp{forced_group_timestamp}, + SystemRecoveryReq(const utils::UUID &main_uuid, uint64_t forced_group_timestamp, + std::vector<storage::SalientConfig> database_configs, auth::Auth::Config auth_config, + std::vector<auth::User> users, std::vector<auth::Role> roles) + : main_uuid(main_uuid), + forced_group_timestamp{forced_group_timestamp}, database_configs(std::move(database_configs)), auth_config(std::move(auth_config)), users{std::move(users)}, roles{std::move(roles)} {} + utils::UUID main_uuid; uint64_t forced_group_timestamp; std::vector<storage::SalientConfig> database_configs; auth::Auth::Config auth_config; diff --git a/src/replication_handler/replication_handler.cpp b/src/replication_handler/replication_handler.cpp index cf1800168..ed0a095c8 100644 --- a/src/replication_handler/replication_handler.cpp +++ b/src/replication_handler/replication_handler.cpp @@ -24,14 +24,18 @@ void RecoverReplication(memgraph::replication::ReplicationState &repl_state, mem */ // Startup replication state (if recovered at startup) - auto replica = [&dbms_handler, &auth](memgraph::replication::RoleReplicaData const &data) { - return memgraph::replication::StartRpcServer(dbms_handler, data, auth); + auto replica = [&dbms_handler, &auth](memgraph::replication::RoleReplicaData &data) { + return StartRpcServer(dbms_handler, data, auth); }; // Replication recovery and frequent check start auto main = [system, &dbms_handler, &auth](memgraph::replication::RoleMainData &mainData) { for (auto &client : mainData.registered_replicas_) { - memgraph::replication::SystemRestore(client, system, dbms_handler, auth); + if (client.try_set_uuid && + replication_coordination_glue::SendSwapMainUUIDRpc(client.rpc_client_, mainData.uuid_)) { + client.try_set_uuid = false; + } + SystemRestore(client, dbms_handler, mainData.uuid_, system, auth); } // DBMS here dbms_handler.ForEach([&mainData](memgraph::dbms::DatabaseAccess db_acc) { @@ -39,7 +43,7 @@ void RecoverReplication(memgraph::replication::ReplicationState &repl_state, mem }); for (auto &client : mainData.registered_replicas_) { - memgraph::replication::StartReplicaClient(client, system, dbms_handler, auth); + StartReplicaClient(client, dbms_handler, mainData.uuid_, system, auth); } // Warning @@ -62,7 +66,7 @@ void RecoverReplication(memgraph::replication::ReplicationState &repl_state, mem void RecoverReplication(memgraph::replication::ReplicationState &repl_state, memgraph::dbms::DbmsHandler &dbms_handler) { // Startup replication state (if recovered at startup) - auto replica = [&dbms_handler](memgraph::replication::RoleReplicaData const &data) { + auto replica = [&dbms_handler](memgraph::replication::RoleReplicaData &data) { return memgraph::replication::StartRpcServer(dbms_handler, data); }; @@ -71,7 +75,11 @@ void RecoverReplication(memgraph::replication::ReplicationState &repl_state, dbms::DbmsHandler::RecoverStorageReplication(dbms_handler.Get(), mainData); for (auto &client : mainData.registered_replicas_) { - memgraph::replication::StartReplicaClient(client, dbms_handler); + if (client.try_set_uuid && + replication_coordination_glue::SendSwapMainUUIDRpc(client.rpc_client_, mainData.uuid_)) { + client.try_set_uuid = false; + } + memgraph::replication::StartReplicaClient(client, dbms_handler, mainData.uuid_); } 
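Both recovery paths above lean on the same idempotent handshake before any replica client is started: MAIN offers its uuid, and try_set_uuid is cleared only once the replica acknowledges, so a failed RPC is simply retried on the next pass. A minimal standalone sketch of that pattern follows; the UUID, ReplicationClient and SendSwapMainUUIDRpc stand-ins here are simplified placeholders, not the real memgraph types.

#include <cstdio>

// Simplified stand-ins for utils::UUID and replication::ReplicationClient.
struct UUID { unsigned long long value{42}; };
struct ReplicationClient {
  bool try_set_uuid{true};  // stays true until the replica accepts a MAIN uuid
};

// Stand-in for replication_coordination_glue::SendSwapMainUUIDRpc: the real one
// streams a SwapMainUUIDRpc and reports whether the replica accepted the swap.
bool SendSwapMainUUIDRpc(ReplicationClient & /*client*/, const UUID & /*main_uuid*/) { return true; }

// The retry pattern used in RecoverReplication and StartReplicaClient above.
void TryAnnounceMainUuid(ReplicationClient &client, const UUID &main_uuid) {
  if (client.try_set_uuid && SendSwapMainUUIDRpc(client, main_uuid)) {
    client.try_set_uuid = false;  // acknowledged; stop re-sending on later ticks
  }
}

int main() {
  ReplicationClient client;
  TryAnnounceMainUuid(client, UUID{});
  std::printf("handshake pending: %s\n", client.try_set_uuid ? "yes" : "no");
}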
// Warning @@ -112,10 +120,11 @@ inline std::optional<query::RegisterReplicaError> HandleRegisterReplicaStatus( } #ifdef MG_ENTERPRISE -void StartReplicaClient(replication::ReplicationClient &client, system::System *system, dbms::DbmsHandler &dbms_handler, - auth::SynchedAuth &auth) { +void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandler &dbms_handler, utils::UUID main_uuid, + system::System *system, auth::SynchedAuth &auth) { #else -void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandler &dbms_handler) { +void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandler &dbms_handler, + utils::UUID main_uuid) { #endif // No client error, start instance level client auto const &endpoint = client.rpc_client_.Endpoint(); @@ -124,8 +133,12 @@ void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandle #ifdef MG_ENTERPRISE system = system, #endif - license = license::global_license_checker.IsEnterpriseValidFast()]( - bool reconnect, replication::ReplicationClient &client) mutable { + license = license::global_license_checker.IsEnterpriseValidFast(), + main_uuid](bool reconnect, replication::ReplicationClient &client) mutable { + if (client.try_set_uuid && + memgraph::replication_coordination_glue::SendSwapMainUUIDRpc(client.rpc_client_, main_uuid)) { + client.try_set_uuid = false; + } // Working connection // Check if system needs restoration if (reconnect) { @@ -138,7 +151,7 @@ void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandle client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); } #ifdef MG_ENTERPRISE - SystemRestore<true>(client, system, dbms_handler, auth); + SystemRestore<true>(client, dbms_handler, main_uuid, system, auth); #endif // Check if any database has been left behind dbms_handler.ForEach([&name = client.name_, reconnect](dbms::DatabaseAccess db_acc) { @@ -174,14 +187,15 @@ bool ReplicationHandler::SetReplicationRoleMain() { }; auto const replica_handler = [this](memgraph::replication::RoleReplicaData const &) { - return DoReplicaToMainPromotion(); + return DoReplicaToMainPromotion(utils::UUID{}); }; // TODO: under lock return std::visit(memgraph::utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData()); } -bool ReplicationHandler::SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) { +bool ReplicationHandler::SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config, + const std::optional<utils::UUID> &main_uuid) { // We don't want to restart the server if we're already a REPLICA if (repl_state_.IsReplica()) { return false; @@ -198,27 +212,26 @@ bool ReplicationHandler::SetReplicationRoleReplica(const memgraph::replication:: std::get<memgraph::replication::RoleMainData>(repl_state_.ReplicationData()).registered_replicas_.clear(); // Creates the server - repl_state_.SetReplicationRoleReplica(config); + repl_state_.SetReplicationRoleReplica(config, main_uuid); // Start - const auto success = - std::visit(memgraph::utils::Overloaded{[](memgraph::replication::RoleMainData const &) { - // ASSERT - return false; - }, - [this](memgraph::replication::RoleReplicaData const &data) { + const auto success = std::visit(memgraph::utils::Overloaded{[](memgraph::replication::RoleMainData &) { + // ASSERT + return false; + }, + [this](memgraph::replication::RoleReplicaData &data) { #ifdef MG_ENTERPRISE - return 
StartRpcServer(dbms_handler_, data, auth_); + return StartRpcServer(dbms_handler_, data, auth_); #else - return StartRpcServer(dbms_handler_, data); + return StartRpcServer(dbms_handler_, data); #endif - }}, - repl_state_.ReplicationData()); + }}, + repl_state_.ReplicationData()); // TODO Handle error (restore to main?) return success; } -bool ReplicationHandler::DoReplicaToMainPromotion() { +bool ReplicationHandler::DoReplicaToMainPromotion(const utils::UUID &main_uuid) { // STEP 1) bring down all REPLICA servers dbms_handler_.ForEach([](dbms::DatabaseAccess db_acc) { auto *storage = db_acc->storage(); @@ -228,7 +241,7 @@ bool ReplicationHandler::DoReplicaToMainPromotion() { // STEP 2) Change to MAIN // TODO: restore replication servers if false? - if (!repl_state_.SetReplicationRoleMain()) { + if (!repl_state_.SetReplicationRoleMain(main_uuid)) { // TODO: Handle recovery on failure??? return false; } @@ -244,14 +257,16 @@ bool ReplicationHandler::DoReplicaToMainPromotion() { }; // as MAIN, define and connect to REPLICAs -auto ReplicationHandler::TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config) +auto ReplicationHandler::TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config, + bool send_swap_uuid) -> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> { - return RegisterReplica_<false>(config); + return RegisterReplica_<false>(config, send_swap_uuid); } -auto ReplicationHandler::RegisterReplica(const memgraph::replication::ReplicationClientConfig &config) +auto ReplicationHandler::RegisterReplica(const memgraph::replication::ReplicationClientConfig &config, + bool send_swap_uuid) -> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> { - return RegisterReplica_<true>(config); + return RegisterReplica_<true>(config, send_swap_uuid); } auto ReplicationHandler::UnregisterReplica(std::string_view name) -> memgraph::query::UnregisterReplicaResult { @@ -284,6 +299,10 @@ auto ReplicationHandler::GetRole() const -> memgraph::replication_coordination_g return repl_state_.GetRole(); } +auto ReplicationHandler::GetReplState() const -> const memgraph::replication::ReplicationState & { return repl_state_; } + +auto ReplicationHandler::GetReplState() -> memgraph::replication::ReplicationState & { return repl_state_; } + bool ReplicationHandler::IsMain() const { return repl_state_.IsMain(); } bool ReplicationHandler::IsReplica() const { return repl_state_.IsReplica(); } diff --git a/src/replication_handler/system_replication.cpp b/src/replication_handler/system_replication.cpp index 4f818a567..dc9dd6f0c 100644 --- a/src/replication_handler/system_replication.cpp +++ b/src/replication_handler/system_replication.cpp @@ -21,7 +21,8 @@ namespace memgraph::replication { #ifdef MG_ENTERPRISE -void SystemHeartbeatHandler(const uint64_t ts, slk::Reader *req_reader, slk::Builder *res_builder) { +void SystemHeartbeatHandler(const uint64_t ts, const std::optional<utils::UUID> &current_main_uuid, + slk::Reader *req_reader, slk::Builder *res_builder) { replication::SystemHeartbeatRes res{0}; // Ignore if no license @@ -30,17 +31,23 @@ void SystemHeartbeatHandler(const uint64_t ts, slk::Reader *req_reader, slk::Bui memgraph::slk::Save(res, res_builder); return; } - replication::SystemHeartbeatReq req; replication::SystemHeartbeatReq::Load(&req, req_reader); + if (!current_main_uuid.has_value() || req.main_uuid != current_main_uuid) [[unlikely]] { + LogWrongMain(current_main_uuid, req.main_uuid,
replication::SystemHeartbeatRes::kType.name); + replication::SystemHeartbeatRes res(-1); + memgraph::slk::Save(res, res_builder); + return; + } + res = replication::SystemHeartbeatRes{ts}; memgraph::slk::Save(res, res_builder); } void SystemRecoveryHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, - dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth, slk::Reader *req_reader, - slk::Builder *res_builder) { + const std::optional<utils::UUID> &current_main_uuid, dbms::DbmsHandler &dbms_handler, + auth::SynchedAuth &auth, slk::Reader *req_reader, slk::Builder *res_builder) { using memgraph::replication::SystemRecoveryRes; SystemRecoveryRes res(SystemRecoveryRes::Result::FAILURE); @@ -49,6 +56,11 @@ void SystemRecoveryHandler(memgraph::system::ReplicaHandlerAccessToState &system memgraph::replication::SystemRecoveryReq req; memgraph::slk::Load(&req, req_reader); + if (!current_main_uuid.has_value() || req.main_uuid != current_main_uuid) [[unlikely]] { + LogWrongMain(current_main_uuid, req.main_uuid, SystemRecoveryReq::kType.name); + return; + } + /* * DBMS */ @@ -74,15 +86,16 @@ void Register(replication::RoleReplicaData const &data, dbms::DbmsHandler &dbms_ auto system_state_access = dbms_handler.system_->CreateSystemStateAccess(); // System + // TODO: remove, as this is not used data.server->rpc_server_.Register<replication::SystemHeartbeatRpc>( - [system_state_access](auto *req_reader, auto *res_builder) { + [&data, system_state_access](auto *req_reader, auto *res_builder) { spdlog::debug("Received SystemHeartbeatRpc"); - SystemHeartbeatHandler(system_state_access.LastCommitedTS(), req_reader, res_builder); + SystemHeartbeatHandler(system_state_access.LastCommitedTS(), data.uuid_, req_reader, res_builder); }); data.server->rpc_server_.Register<replication::SystemRecoveryRpc>( - [system_state_access, &dbms_handler, &auth](auto *req_reader, auto *res_builder) mutable { + [&data, system_state_access, &dbms_handler, &auth](auto *req_reader, auto *res_builder) mutable { spdlog::debug("Received SystemRecoveryRpc"); - SystemRecoveryHandler(system_state_access, dbms_handler, auth, req_reader, res_builder); + SystemRecoveryHandler(system_state_access, data.uuid_, dbms_handler, auth, req_reader, res_builder); }); // DBMS @@ -94,13 +107,12 @@ void Register(replication::RoleReplicaData const &data, dbms::DbmsHandler &dbms_ #endif #ifdef MG_ENTERPRISE -bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data, - auth::SynchedAuth &auth) { +bool StartRpcServer(dbms::DbmsHandler &dbms_handler, replication::RoleReplicaData &data, auth::SynchedAuth &auth) { #else -bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data) { +bool StartRpcServer(dbms::DbmsHandler &dbms_handler, replication::RoleReplicaData &data) { #endif // Register storage handlers - dbms::InMemoryReplicationHandlers::Register(&dbms_handler, *data.server); + dbms::InMemoryReplicationHandlers::Register(&dbms_handler, data); #ifdef MG_ENTERPRISE // Register system handlers Register(data, dbms_handler, auth); @@ -112,4 +124,5 @@ bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleRepl } return true; } + } // namespace memgraph::replication diff --git a/src/replication_handler/system_rpc.cpp b/src/replication_handler/system_rpc.cpp index 0a0bd5e05..0dd8767b3 100644 --- a/src/replication_handler/system_rpc.cpp +++ b/src/replication_handler/system_rpc.cpp @@ -29,15 +29,16 @@ void Load(memgraph::replication::SystemHeartbeatRes
*self, memgraph::slk::Reader } // Serialize code for SystemHeartbeatReq -void Save(const memgraph::replication::SystemHeartbeatReq & /*self*/, memgraph::slk::Builder * /*builder*/) { - /* Nothing to serialize */ +void Save(const memgraph::replication::SystemHeartbeatReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.main_uuid, builder); } -void Load(memgraph::replication::SystemHeartbeatReq * /*self*/, memgraph::slk::Reader * /*reader*/) { - /* Nothing to serialize */ +void Load(memgraph::replication::SystemHeartbeatReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->main_uuid, reader); } // Serialize code for SystemRecoveryReq void Save(const memgraph::replication::SystemRecoveryReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.main_uuid, builder); memgraph::slk::Save(self.forced_group_timestamp, builder); memgraph::slk::Save(self.database_configs, builder); memgraph::slk::Save(self.auth_config, builder); @@ -46,6 +47,7 @@ void Save(const memgraph::replication::SystemRecoveryReq &self, memgraph::slk::B } void Load(memgraph::replication::SystemRecoveryReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->main_uuid, reader); memgraph::slk::Load(&self->forced_group_timestamp, reader); memgraph::slk::Load(&self->database_configs, reader); memgraph::slk::Load(&self->auth_config, reader); diff --git a/src/rpc/client.hpp b/src/rpc/client.hpp index a9ae7202d..3a2fefd57 100644 --- a/src/rpc/client.hpp +++ b/src/rpc/client.hpp @@ -214,7 +214,6 @@ class Client { // Build and send the request. slk::Save(req_type.id, handler.GetBuilder()); slk::Save(rpc::current_version, handler.GetBuilder()); - TRequestResponse::Request::Save(request, handler.GetBuilder()); // Return the handler to the user. diff --git a/src/rpc/version.hpp b/src/rpc/version.hpp index b234a3ccc..e859cafd5 100644 --- a/src/rpc/version.hpp +++ b/src/rpc/version.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -28,6 +28,9 @@ constexpr auto v1 = Version{2023'10'30'0'2'13}; // for any TypeIds that get added. constexpr auto v2 = Version{2023'12'07'0'2'14}; -constexpr auto current_version = v2; +// To each RPC main uuid was added +constexpr auto v3 = Version{2024'02'02'0'2'14}; + +constexpr auto current_version = v3; } // namespace memgraph::rpc diff --git a/src/storage/v2/inmemory/replication/recovery.cpp b/src/storage/v2/inmemory/replication/recovery.cpp index d6f2b464c..921c1f5c0 100644 --- a/src/storage/v2/inmemory/replication/recovery.cpp +++ b/src/storage/v2/inmemory/replication/recovery.cpp @@ -18,6 +18,7 @@ #include "storage/v2/inmemory/storage.hpp" #include "storage/v2/replication/recovery.hpp" #include "utils/on_scope_exit.hpp" +#include "utils/uuid.hpp" #include "utils/variant_helpers.hpp" namespace memgraph::storage { @@ -26,7 +27,8 @@ namespace memgraph::storage { // contained in the internal buffer and the file. 
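// note (editorial sketch): every request type above now serializes main_uuid as
// its first field, which is an on-the-wire format change; that is why
// rpc/version.hpp bumps current_version to v3, letting the receiving side
// detect a peer that still speaks the v2 layout instead of misparsing it.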
class InMemoryCurrentWalHandler { public: - explicit InMemoryCurrentWalHandler(InMemoryStorage const *storage, rpc::Client &rpc_client); + explicit InMemoryCurrentWalHandler(const utils::UUID &main_uuid, InMemoryStorage const *storage, + rpc::Client &rpc_client); void AppendFilename(const std::string &filename); void AppendSize(size_t size); @@ -43,8 +45,9 @@ class InMemoryCurrentWalHandler { }; ////// CurrentWalHandler ////// -InMemoryCurrentWalHandler::InMemoryCurrentWalHandler(InMemoryStorage const *storage, rpc::Client &rpc_client) - : stream_(rpc_client.Stream<replication::CurrentWalRpc>(storage->uuid())) {} +InMemoryCurrentWalHandler::InMemoryCurrentWalHandler(const utils::UUID &main_uuid, InMemoryStorage const *storage, + rpc::Client &rpc_client) + : stream_(rpc_client.Stream<replication::CurrentWalRpc>(main_uuid, storage->uuid())) {} void InMemoryCurrentWalHandler::AppendFilename(const std::string &filename) { replication::Encoder encoder(stream_.GetBuilder()); @@ -69,10 +72,10 @@ void InMemoryCurrentWalHandler::AppendBufferData(const uint8_t *buffer, const si replication::CurrentWalRes InMemoryCurrentWalHandler::Finalize() { return stream_.AwaitResponse(); } ////// ReplicationClient Helpers ////// -replication::WalFilesRes TransferWalFiles(const utils::UUID &uuid, rpc::Client &client, +replication::WalFilesRes TransferWalFiles(const utils::UUID &main_uuid, const utils::UUID &uuid, rpc::Client &client, const std::vector<std::filesystem::path> &wal_files) { MG_ASSERT(!wal_files.empty(), "Wal files list is empty!"); - auto stream = client.Stream<replication::WalFilesRpc>(uuid, wal_files.size()); + auto stream = client.Stream<replication::WalFilesRpc>(main_uuid, uuid, wal_files.size()); replication::Encoder encoder(stream.GetBuilder()); for (const auto &wal : wal_files) { spdlog::debug("Sending wal file: {}", wal); @@ -81,16 +84,17 @@ replication::WalFilesRes TransferWalFiles(const utils::UUID &uuid, rpc::Client & return stream.AwaitResponse(); } -replication::SnapshotRes TransferSnapshot(const utils::UUID &uuid, rpc::Client &client, +replication::SnapshotRes TransferSnapshot(const utils::UUID &main_uuid, const utils::UUID &uuid, rpc::Client &client, const std::filesystem::path &path) { - auto stream = client.Stream<replication::SnapshotRpc>(uuid); + auto stream = client.Stream<replication::SnapshotRpc>(main_uuid, uuid); replication::Encoder encoder(stream.GetBuilder()); encoder.WriteFile(path); return stream.AwaitResponse(); } -uint64_t ReplicateCurrentWal(const InMemoryStorage *storage, rpc::Client &client, durability::WalFile const &wal_file) { - InMemoryCurrentWalHandler stream{storage, client}; +uint64_t ReplicateCurrentWal(const utils::UUID &main_uuid, const InMemoryStorage *storage, rpc::Client &client, + durability::WalFile const &wal_file) { + InMemoryCurrentWalHandler stream{main_uuid, storage, client}; stream.AppendFilename(wal_file.Path().filename()); utils::InputFile file; MG_ASSERT(file.Open(wal_file.Path()), "Failed to open current WAL file at {}!", wal_file.Path()); diff --git a/src/storage/v2/inmemory/replication/recovery.hpp b/src/storage/v2/inmemory/replication/recovery.hpp index 730822a62..07ebdd590 100644 --- a/src/storage/v2/inmemory/replication/recovery.hpp +++ b/src/storage/v2/inmemory/replication/recovery.hpp @@ -19,13 +19,14 @@ class InMemoryStorage; ////// ReplicationClient Helpers ////// -replication::WalFilesRes TransferWalFiles(const utils::UUID &uuid, rpc::Client &client, +replication::WalFilesRes TransferWalFiles(const utils::UUID &main_uuid, const 
utils::UUID &uuid, rpc::Client &client, const std::vector<std::filesystem::path> &wal_files); -replication::SnapshotRes TransferSnapshot(const utils::UUID &uuid, rpc::Client &client, +replication::SnapshotRes TransferSnapshot(const utils::UUID &main_uuid, const utils::UUID &uuid, rpc::Client &client, const std::filesystem::path &path); -uint64_t ReplicateCurrentWal(const InMemoryStorage *storage, rpc::Client &client, durability::WalFile const &wal_file); +uint64_t ReplicateCurrentWal(const utils::UUID &main_uuid, const InMemoryStorage *storage, rpc::Client &client, + durability::WalFile const &wal_file); auto GetRecoverySteps(uint64_t replica_commit, utils::FileRetainer::FileLocker *file_locker, const InMemoryStorage *storage) -> std::vector<RecoveryStep>; diff --git a/src/storage/v2/inmemory/storage.cpp b/src/storage/v2/inmemory/storage.cpp index c97d12072..0c7bde1a0 100644 --- a/src/storage/v2/inmemory/storage.cpp +++ b/src/storage/v2/inmemory/storage.cpp @@ -1847,6 +1847,7 @@ bool InMemoryStorage::AppendToWal(const Transaction &transaction, uint64_t final // A single transaction will always be contained in a single WAL file. auto current_commit_timestamp = transaction.commit_timestamp->load(std::memory_order_acquire); + //////// AF: only this calls InitializeTransaction repl_storage_state_.InitializeTransaction(wal_file_->SequenceNumber(), this, db_acc); auto append_deltas = [&](auto callback) { diff --git a/src/storage/v2/replication/replication_client.cpp b/src/storage/v2/replication/replication_client.cpp index bbe9a9bb1..0c5ef8125 100644 --- a/src/storage/v2/replication/replication_client.cpp +++ b/src/storage/v2/replication/replication_client.cpp @@ -14,6 +14,7 @@ #include "storage/v2/storage.hpp" #include "utils/exceptions.hpp" #include "utils/on_scope_exit.hpp" +#include "utils/uuid.hpp" #include "utils/variant_helpers.hpp" #include <algorithm> @@ -25,8 +26,9 @@ template <typename> namespace memgraph::storage { -ReplicationStorageClient::ReplicationStorageClient(::memgraph::replication::ReplicationClient &client) - : client_{client} {} +ReplicationStorageClient::ReplicationStorageClient(::memgraph::replication::ReplicationClient &client, + utils::UUID main_uuid) + : client_{client}, main_uuid_(main_uuid) {} void ReplicationStorageClient::UpdateReplicaState(Storage *storage, DatabaseAccessProtector db_acc) { uint64_t current_commit_timestamp{kTimestampInitialId}; @@ -34,14 +36,13 @@ void ReplicationStorageClient::UpdateReplicaState(Storage *storage, DatabaseAcce auto &replStorageState = storage->repl_storage_state_; auto hb_stream{client_.rpc_client_.Stream<replication::HeartbeatRpc>( - storage->uuid(), replStorageState.last_commit_timestamp_, std::string{replStorageState.epoch_.id()})}; - + main_uuid_, storage->uuid(), replStorageState.last_commit_timestamp_, std::string{replStorageState.epoch_.id()})}; const auto replica = hb_stream.AwaitResponse(); #ifdef MG_ENTERPRISE // Multi-tenancy is only supported in enterprise if (!replica.success) { // Replica is missing the current database client_.state_.WithLock([&](auto &state) { - spdlog::debug("Replica '{}' missing database '{}' - '{}'", client_.name_, storage->name(), + spdlog::debug("Replica '{}' can't respond or is missing database '{}' - '{}'", client_.name_, storage->name(), std::string{storage->uuid()}); state = memgraph::replication::ReplicationClient::State::BEHIND; }); @@ -95,7 +96,7 @@ TimestampInfo ReplicationStorageClient::GetTimestampInfo(Storage const *storage) info.current_number_of_timestamp_behind_master = 0; try { -
auto stream{client_.rpc_client_.Stream<replication::TimestampRpc>(storage->uuid())}; + auto stream{client_.rpc_client_.Stream<replication::TimestampRpc>(main_uuid_, storage->uuid())}; const auto response = stream.AwaitResponse(); const auto is_success = response.success; @@ -173,7 +174,7 @@ void ReplicationStorageClient::StartTransactionReplication(const uint64_t curren case READY: MG_ASSERT(!replica_stream_); try { - replica_stream_.emplace(storage, client_.rpc_client_, current_wal_seq_num); + replica_stream_.emplace(storage, client_.rpc_client_, current_wal_seq_num, main_uuid_); *locked_state = REPLICATING; } catch (const rpc::RpcFailedException &) { *locked_state = MAYBE_BEHIND; @@ -183,6 +184,9 @@ void ReplicationStorageClient::StartTransactionReplication(const uint64_t curren } } +//////// AF: you can't finalize transaction replication if you are not replicating +/////// AF: if there is no stream or it is Defunct then we need to set replica in MAYBE_BEHIND -> is that even used +/////// AF: bool ReplicationStorageClient::FinalizeTransactionReplication(Storage *storage, DatabaseAccessProtector db_acc) { // We can only check the state because it guarantees to be only // valid during a single transaction replication (if the assumption @@ -256,36 +260,38 @@ void ReplicationStorageClient::RecoverReplica(uint64_t replica_commit, memgraph: spdlog::trace("Recovering in step: {}", i++); try { rpc::Client &rpcClient = client_.rpc_client_; - std::visit(utils::Overloaded{ - [&replica_commit, mem_storage, &rpcClient](RecoverySnapshot const &snapshot) { - spdlog::debug("Sending the latest snapshot file: {}", snapshot); - auto response = TransferSnapshot(mem_storage->uuid(), rpcClient, snapshot); - replica_commit = response.current_commit_timestamp; - }, - [&replica_commit, mem_storage, &rpcClient](RecoveryWals const &wals) { - spdlog::debug("Sending the latest wal files"); - auto response = TransferWalFiles(mem_storage->uuid(), rpcClient, wals); - replica_commit = response.current_commit_timestamp; - spdlog::debug("Wal files successfully transferred."); - }, - [&replica_commit, mem_storage, &rpcClient](RecoveryCurrentWal const &current_wal) { - std::unique_lock transaction_guard(mem_storage->engine_lock_); - if (mem_storage->wal_file_ && - mem_storage->wal_file_->SequenceNumber() == current_wal.current_wal_seq_num) { - utils::OnScopeExit on_exit([mem_storage]() { mem_storage->wal_file_->EnableFlushing(); }); - mem_storage->wal_file_->DisableFlushing(); - transaction_guard.unlock(); - spdlog::debug("Sending current wal file"); - replica_commit = ReplicateCurrentWal(mem_storage, rpcClient, *mem_storage->wal_file_); - } else { - spdlog::debug("Cannot recover using current wal file"); - } - }, - [](auto const &in) { - static_assert(always_false_v<decltype(in)>, "Missing type from variant visitor"); - }, - }, - recovery_step); + std::visit( + utils::Overloaded{ + [&replica_commit, mem_storage, &rpcClient, main_uuid = main_uuid_](RecoverySnapshot const &snapshot) { + spdlog::debug("Sending the latest snapshot file: {}", snapshot); + auto response = TransferSnapshot(main_uuid, mem_storage->uuid(), rpcClient, snapshot); + replica_commit = response.current_commit_timestamp; + }, + [&replica_commit, mem_storage, &rpcClient, main_uuid = main_uuid_](RecoveryWals const &wals) { + spdlog::debug("Sending the latest wal files"); + auto response = TransferWalFiles(main_uuid, mem_storage->uuid(), rpcClient, wals); + replica_commit = response.current_commit_timestamp; + spdlog::debug("Wal files successfully transferred.");
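// note (editorial): the lambdas in this std::visit are stitched into a single
// visitor by utils::Overloaded; the trailing catch-all arm below uses
// static_assert(always_false_v<decltype(in)>, ...) so that any RecoveryStep
// alternative added later fails at compile time instead of being silently
// skipped during recovery.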
+ }, + [&replica_commit, mem_storage, &rpcClient, + main_uuid = main_uuid_](RecoveryCurrentWal const &current_wal) { + std::unique_lock transaction_guard(mem_storage->engine_lock_); + if (mem_storage->wal_file_ && + mem_storage->wal_file_->SequenceNumber() == current_wal.current_wal_seq_num) { + utils::OnScopeExit on_exit([mem_storage]() { mem_storage->wal_file_->EnableFlushing(); }); + mem_storage->wal_file_->DisableFlushing(); + transaction_guard.unlock(); + spdlog::debug("Sending current wal file"); + replica_commit = ReplicateCurrentWal(main_uuid, mem_storage, rpcClient, *mem_storage->wal_file_); + } else { + spdlog::debug("Cannot recover using current wal file"); + } + }, + [](auto const &in) { + static_assert(always_false_v<decltype(in)>, "Missing type from variant visitor"); + }, + }, + recovery_step); } catch (const rpc::RpcFailedException &) { replica_state_.WithLock([](auto &val) { val = replication::ReplicaState::MAYBE_BEHIND; }); LogRpcFailure(); @@ -314,10 +320,12 @@ void ReplicationStorageClient::RecoverReplica(uint64_t replica_commit, memgraph: } ////// ReplicaStream ////// -ReplicaStream::ReplicaStream(Storage *storage, rpc::Client &rpc_client, const uint64_t current_seq_num) +ReplicaStream::ReplicaStream(Storage *storage, rpc::Client &rpc_client, const uint64_t current_seq_num, + utils::UUID main_uuid) : storage_{storage}, stream_(rpc_client.Stream<replication::AppendDeltasRpc>( - storage->uuid(), storage->repl_storage_state_.last_commit_timestamp_.load(), current_seq_num)) { + main_uuid, storage->uuid(), storage->repl_storage_state_.last_commit_timestamp_.load(), current_seq_num)), + main_uuid_(main_uuid) { replication::Encoder encoder{stream_.GetBuilder()}; encoder.WriteString(storage->repl_storage_state_.epoch_.id()); } diff --git a/src/storage/v2/replication/replication_client.hpp b/src/storage/v2/replication/replication_client.hpp index fbcffe422..3352bab65 100644 --- a/src/storage/v2/replication/replication_client.hpp +++ b/src/storage/v2/replication/replication_client.hpp @@ -28,6 +28,7 @@ #include "utils/scheduler.hpp" #include "utils/synchronized.hpp" #include "utils/thread_pool.hpp" +#include "utils/uuid.hpp" #include <atomic> #include <concepts> @@ -48,7 +49,7 @@ class ReplicationStorageClient; // Handler used for transferring the current transaction.
class ReplicaStream { public: - explicit ReplicaStream(Storage *storage, rpc::Client &rpc_client, uint64_t current_seq_num); + explicit ReplicaStream(Storage *storage, rpc::Client &rpc_client, uint64_t current_seq_num, utils::UUID main_uuid); /// @throw rpc::RpcFailedException void AppendDelta(const Delta &delta, const Vertex &vertex, uint64_t final_commit_timestamp); @@ -72,6 +73,7 @@ class ReplicaStream { private: Storage *storage_; rpc::Client::StreamHandler<replication::AppendDeltasRpc> stream_; + utils::UUID main_uuid_; }; template <typename F> @@ -84,7 +86,7 @@ class ReplicationStorageClient { friend struct ::memgraph::replication::ReplicationClient; public: - explicit ReplicationStorageClient(::memgraph::replication::ReplicationClient &client); + explicit ReplicationStorageClient(::memgraph::replication::ReplicationClient &client, utils::UUID main_uuid); ReplicationStorageClient(ReplicationStorageClient const &) = delete; ReplicationStorageClient &operator=(ReplicationStorageClient const &) = delete; @@ -202,6 +204,8 @@ class ReplicationStorageClient { replica_stream_; // Currently active stream (nullopt if not in use), note: a single stream per rpc client mutable utils::Synchronized<replication::ReplicaState, utils::SpinLock> replica_state_{ replication::ReplicaState::MAYBE_BEHIND}; + + const utils::UUID main_uuid_; }; } // namespace memgraph::storage diff --git a/src/storage/v2/replication/rpc.cpp b/src/storage/v2/replication/rpc.cpp index 59d1a02b9..f523bb5d7 100644 --- a/src/storage/v2/replication/rpc.cpp +++ b/src/storage/v2/replication/rpc.cpp @@ -114,10 +114,12 @@ void Load(memgraph::storage::replication::TimestampRes *self, memgraph::slk::Rea // Serialize code for TimestampReq void Save(const memgraph::storage::replication::TimestampReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.main_uuid, builder); memgraph::slk::Save(self.uuid, builder); } void Load(memgraph::storage::replication::TimestampReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->main_uuid, reader); memgraph::slk::Load(&self->uuid, reader); } @@ -136,10 +138,12 @@ void Load(memgraph::storage::replication::CurrentWalRes *self, memgraph::slk::Re // Serialize code for CurrentWalReq void Save(const memgraph::storage::replication::CurrentWalReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.main_uuid, builder); memgraph::slk::Save(self.uuid, builder); } void Load(memgraph::storage::replication::CurrentWalReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->main_uuid, reader); memgraph::slk::Load(&self->uuid, reader); } @@ -158,11 +162,13 @@ void Load(memgraph::storage::replication::WalFilesRes *self, memgraph::slk::Read // Serialize code for WalFilesReq void Save(const memgraph::storage::replication::WalFilesReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.main_uuid, builder); memgraph::slk::Save(self.uuid, builder); memgraph::slk::Save(self.file_number, builder); } void Load(memgraph::storage::replication::WalFilesReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->main_uuid, reader); memgraph::slk::Load(&self->uuid, reader); memgraph::slk::Load(&self->file_number, reader); } @@ -182,10 +188,12 @@ void Load(memgraph::storage::replication::SnapshotRes *self, memgraph::slk::Read // Serialize code for SnapshotReq void Save(const memgraph::storage::replication::SnapshotReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.main_uuid, builder); 
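// note (editorial): main_uuid is deliberately written before the storage uuid
// here, and the matching Load below reads the fields back in the same order;
// SLK encodes no field names, only values in sequence, so every Save/Load pair
// in this file must stay an exact mirror of the other.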
memgraph::slk::Save(self.uuid, builder); } void Load(memgraph::storage::replication::SnapshotReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->main_uuid, reader); memgraph::slk::Load(&self->uuid, reader); } @@ -206,12 +214,14 @@ void Load(memgraph::storage::replication::HeartbeatRes *self, memgraph::slk::Rea // Serialize code for HeartbeatReq void Save(const memgraph::storage::replication::HeartbeatReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.main_uuid, builder); memgraph::slk::Save(self.uuid, builder); memgraph::slk::Save(self.main_commit_timestamp, builder); memgraph::slk::Save(self.epoch_id, builder); } void Load(memgraph::storage::replication::HeartbeatReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->main_uuid, reader); memgraph::slk::Load(&self->uuid, reader); memgraph::slk::Load(&self->main_commit_timestamp, reader); memgraph::slk::Load(&self->epoch_id, reader); @@ -232,12 +242,14 @@ void Load(memgraph::storage::replication::AppendDeltasRes *self, memgraph::slk:: // Serialize code for AppendDeltasReq void Save(const memgraph::storage::replication::AppendDeltasReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.main_uuid, builder); memgraph::slk::Save(self.uuid, builder); memgraph::slk::Save(self.previous_commit_timestamp, builder); memgraph::slk::Save(self.seq_num, builder); } void Load(memgraph::storage::replication::AppendDeltasReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->main_uuid, reader); memgraph::slk::Load(&self->uuid, reader); memgraph::slk::Load(&self->previous_commit_timestamp, reader); memgraph::slk::Load(&self->seq_num, reader); diff --git a/src/storage/v2/replication/rpc.hpp b/src/storage/v2/replication/rpc.hpp index 9c9f5c285..67f98d0ae 100644 --- a/src/storage/v2/replication/rpc.hpp +++ b/src/storage/v2/replication/rpc.hpp @@ -32,9 +32,11 @@ struct AppendDeltasReq { static void Load(AppendDeltasReq *self, memgraph::slk::Reader *reader); static void Save(const AppendDeltasReq &self, memgraph::slk::Builder *builder); AppendDeltasReq() = default; - AppendDeltasReq(const utils::UUID &uuid, uint64_t previous_commit_timestamp, uint64_t seq_num) - : uuid{uuid}, previous_commit_timestamp(previous_commit_timestamp), seq_num(seq_num) {} + AppendDeltasReq(const utils::UUID &main_uuid, const utils::UUID &uuid, uint64_t previous_commit_timestamp, + uint64_t seq_num) + : main_uuid{main_uuid}, uuid{uuid}, previous_commit_timestamp(previous_commit_timestamp), seq_num(seq_num) {} + utils::UUID main_uuid; utils::UUID uuid; uint64_t previous_commit_timestamp; uint64_t seq_num; @@ -63,9 +65,11 @@ struct HeartbeatReq { static void Load(HeartbeatReq *self, memgraph::slk::Reader *reader); static void Save(const HeartbeatReq &self, memgraph::slk::Builder *builder); HeartbeatReq() = default; - HeartbeatReq(const utils::UUID &uuid, uint64_t main_commit_timestamp, std::string epoch_id) - : uuid{uuid}, main_commit_timestamp(main_commit_timestamp), epoch_id(std::move(epoch_id)) {} + HeartbeatReq(const utils::UUID &main_uuid, const utils::UUID &uuid, uint64_t main_commit_timestamp, + std::string epoch_id) + : main_uuid(main_uuid), uuid{uuid}, main_commit_timestamp(main_commit_timestamp), epoch_id(std::move(epoch_id)) {} + utils::UUID main_uuid; utils::UUID uuid; uint64_t main_commit_timestamp; std::string epoch_id; @@ -95,8 +99,9 @@ struct SnapshotReq { static void Load(SnapshotReq *self, memgraph::slk::Reader *reader); static void Save(const SnapshotReq &self, 
memgraph::slk::Builder *builder); SnapshotReq() = default; - explicit SnapshotReq(const utils::UUID &uuid) : uuid{uuid} {} + explicit SnapshotReq(const utils::UUID &main_uuid, const utils::UUID &uuid) : main_uuid{main_uuid}, uuid{uuid} {} + utils::UUID main_uuid; utils::UUID uuid; }; @@ -123,8 +128,10 @@ struct WalFilesReq { static void Load(WalFilesReq *self, memgraph::slk::Reader *reader); static void Save(const WalFilesReq &self, memgraph::slk::Builder *builder); WalFilesReq() = default; - explicit WalFilesReq(const utils::UUID &uuid, uint64_t file_number) : uuid{uuid}, file_number(file_number) {} + explicit WalFilesReq(const utils::UUID &main_uuid, const utils::UUID &uuid, uint64_t file_number) + : main_uuid{main_uuid}, uuid{uuid}, file_number(file_number) {} + utils::UUID main_uuid; utils::UUID uuid; uint64_t file_number; }; @@ -152,8 +159,9 @@ struct CurrentWalReq { static void Load(CurrentWalReq *self, memgraph::slk::Reader *reader); static void Save(const CurrentWalReq &self, memgraph::slk::Builder *builder); CurrentWalReq() = default; - explicit CurrentWalReq(const utils::UUID &uuid) : uuid{uuid} {} + explicit CurrentWalReq(const utils::UUID &main_uuid, const utils::UUID &uuid) : main_uuid(main_uuid), uuid{uuid} {} + utils::UUID main_uuid; utils::UUID uuid; }; @@ -180,8 +188,9 @@ struct TimestampReq { static void Load(TimestampReq *self, memgraph::slk::Reader *reader); static void Save(const TimestampReq &self, memgraph::slk::Builder *builder); TimestampReq() = default; - explicit TimestampReq(const utils::UUID &uuid) : uuid{uuid} {} + explicit TimestampReq(const utils::UUID &main_uuid, const utils::UUID &uuid) : main_uuid(main_uuid), uuid{uuid} {} + utils::UUID main_uuid; utils::UUID uuid; }; diff --git a/src/system/action.cpp b/src/system/action.cpp index 00e1c1be9..d5ea29ea6 100644 --- a/src/system/action.cpp +++ b/src/system/action.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/src/system/include/system/action.hpp b/src/system/include/system/action.hpp index 77f4cb3e8..3cfe58ffd 100644 --- a/src/system/include/system/action.hpp +++ b/src/system/include/system/action.hpp @@ -27,7 +27,7 @@ struct ISystemAction { virtual void DoDurability() = 0; /// Prepare the RPC payload that will be sent to all replicas clients - virtual bool DoReplication(memgraph::replication::ReplicationClient &client, + virtual bool DoReplication(memgraph::replication::ReplicationClient &client, const utils::UUID &main_uuid, memgraph::replication::ReplicationEpoch const &epoch, Transaction const &system_tx) const = 0; diff --git a/src/system/include/system/transaction.hpp b/src/system/include/system/transaction.hpp index af03fe434..e30752eaa 100644 --- a/src/system/include/system/transaction.hpp +++ b/src/system/include/system/transaction.hpp @@ -99,7 +99,7 @@ struct DoReplication { auto sync_status = AllSyncReplicaStatus::AllCommitsConfirmed; for (auto &client : main_data_.registered_replicas_) { - bool completed = action.DoReplication(client, main_data_.epoch_, system_tx); + bool completed = action.DoReplication(client, main_data_.uuid_, main_data_.epoch_, system_tx); if (!completed && client.mode_ == replication_coordination_glue::ReplicationMode::SYNC) { sync_status = AllSyncReplicaStatus::SomeCommitsUnconfirmed; } diff --git a/src/utils/typeinfo.hpp b/src/utils/typeinfo.hpp index 6919e8e5c..1640a70f7 100644 --- a/src/utils/typeinfo.hpp +++ b/src/utils/typeinfo.hpp @@ -97,12 +97,16 @@ enum class TypeId : uint64_t { REP_UPDATE_AUTH_DATA_RES, REP_DROP_AUTH_DATA_REQ, REP_DROP_AUTH_DATA_RES, + REP_TRY_SET_MAIN_UUID_REQ, + REP_TRY_SET_MAIN_UUID_RES, // Coordinator COORD_FAILOVER_REQ, COORD_FAILOVER_RES, COORD_SET_REPL_MAIN_REQ, COORD_SET_REPL_MAIN_RES, + COORD_SWAP_UUID_REQ, + COORD_SWAP_UUID_RES, // AST AST_LABELIX = 3000, diff --git a/tests/e2e/high_availability_experimental/CMakeLists.txt b/tests/e2e/high_availability_experimental/CMakeLists.txt index e587d6fef..f22e24f43 100644 --- a/tests/e2e/high_availability_experimental/CMakeLists.txt +++ b/tests/e2e/high_availability_experimental/CMakeLists.txt @@ -3,6 +3,7 @@ find_package(gflags REQUIRED) copy_e2e_python_files(ha_experimental coordinator.py) copy_e2e_python_files(ha_experimental automatic_failover.py) copy_e2e_python_files(ha_experimental manual_setting_replicas.py) +copy_e2e_python_files(ha_experimental not_replicate_from_old_main.py) copy_e2e_python_files(ha_experimental common.py) copy_e2e_python_files(ha_experimental workloads.yaml) diff --git a/tests/e2e/high_availability_experimental/automatic_failover.py b/tests/e2e/high_availability_experimental/automatic_failover.py index 5c43ac62d..23b462f45 100644 --- a/tests/e2e/high_availability_experimental/automatic_failover.py +++ b/tests/e2e/high_availability_experimental/automatic_failover.py @@ -13,6 +13,7 @@ import os import shutil import sys import tempfile +import time import interactive_mg_runner import pytest @@ -131,6 +132,7 @@ def test_replication_works_on_failover(): mg_sleep_and_assert(expected_data_on_new_main, retrieve_data_show_replicas) interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") + expected_data_on_new_main = [ ("instance_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), ("instance_3", "127.0.0.1:10003", "sync", 0, 0, "ready"), @@ -141,8 
+143,8 @@ def test_replication_works_on_failover(): execute_and_fetch_all(new_main_cursor, "CREATE ();") # 6 - alive_replica_cursror = connect(host="localhost", port=7689).cursor() - res = execute_and_fetch_all(alive_replica_cursror, "MATCH (n) RETURN count(n) as count;")[0][0] + alive_replica_cursor = connect(host="localhost", port=7689).cursor() + res = execute_and_fetch_all(alive_replica_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] assert res == 1, "Vertex should be replicated" interactive_mg_runner.stop_all(MEMGRAPH_INSTANCES_DESCRIPTION) @@ -344,65 +346,60 @@ def test_automatic_failover_main_back_as_replica(): mg_sleep_and_assert([("replica",)], retrieve_data_show_repl_role_instance3) -def test_automatic_failover_main_back_as_main(): +def test_replica_instance_restarts_replication_works(): safe_execute(shutil.rmtree, TEMP_DIR) interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION) - interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_1") - interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_2") - interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") + cursor = connect(host="localhost", port=7690).cursor() - coord_cursor = connect(host="localhost", port=7690).cursor() + def show_repl_cluster(): + return sorted(list(execute_and_fetch_all(cursor, "SHOW REPLICATION CLUSTER;"))) - def retrieve_data_show_repl_cluster(): - return sorted(list(execute_and_fetch_all(coord_cursor, "SHOW REPLICATION CLUSTER;"))) - - expected_data_all_down = [ - ("instance_1", "127.0.0.1:10011", False, "unknown"), - ("instance_2", "127.0.0.1:10012", False, "unknown"), - ("instance_3", "127.0.0.1:10013", False, "unknown"), - ] - - mg_sleep_and_assert(expected_data_all_down, retrieve_data_show_repl_cluster) - - interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") - expected_data_main_back = [ - ("instance_1", "127.0.0.1:10011", False, "unknown"), - ("instance_2", "127.0.0.1:10012", False, "unknown"), - ("instance_3", "127.0.0.1:10013", True, "main"), - ] - mg_sleep_and_assert(expected_data_main_back, retrieve_data_show_repl_cluster) - - instance3_cursor = connect(host="localhost", port=7687).cursor() - - def retrieve_data_show_repl_role_instance3(): - return sorted(list(execute_and_fetch_all(instance3_cursor, "SHOW REPLICATION ROLE;"))) - - mg_sleep_and_assert([("main",)], retrieve_data_show_repl_role_instance3) - - interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_1") - interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_2") - - expected_data_replicas_back = [ + expected_data_up = [ ("instance_1", "127.0.0.1:10011", True, "replica"), ("instance_2", "127.0.0.1:10012", True, "replica"), ("instance_3", "127.0.0.1:10013", True, "main"), ] + mg_sleep_and_assert(expected_data_up, show_repl_cluster) - mg_sleep_and_assert(expected_data_replicas_back, retrieve_data_show_repl_cluster) + interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_1") + expected_data_down = [ + ("instance_1", "127.0.0.1:10011", False, "unknown"), + ("instance_2", "127.0.0.1:10012", True, "replica"), + ("instance_3", "127.0.0.1:10013", True, "main"), + ] + mg_sleep_and_assert(expected_data_down, show_repl_cluster) + + interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_1") + + mg_sleep_and_assert(expected_data_up, show_repl_cluster) + + expected_data_on_main_show_replicas = [ + ("instance_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), + ("instance_2", "127.0.0.1:10002", "sync", 0, 0, 
"ready"), + ] + instance3_cursor = connect(host="localhost", port=7687).cursor() instance1_cursor = connect(host="localhost", port=7688).cursor() - instance2_cursor = connect(host="localhost", port=7689).cursor() + + def retrieve_data_show_repl_role_instance1(): + return sorted(list(execute_and_fetch_all(instance3_cursor, "SHOW REPLICAS;"))) + + mg_sleep_and_assert(expected_data_on_main_show_replicas, retrieve_data_show_repl_role_instance1) def retrieve_data_show_repl_role_instance1(): return sorted(list(execute_and_fetch_all(instance1_cursor, "SHOW REPLICATION ROLE;"))) - def retrieve_data_show_repl_role_instance2(): - return sorted(list(execute_and_fetch_all(instance2_cursor, "SHOW REPLICATION ROLE;"))) + expected_data_replica = [("replica",)] + mg_sleep_and_assert(expected_data_replica, retrieve_data_show_repl_role_instance1) - mg_sleep_and_assert([("replica",)], retrieve_data_show_repl_role_instance1) - mg_sleep_and_assert([("replica",)], retrieve_data_show_repl_role_instance2) - mg_sleep_and_assert([("main",)], retrieve_data_show_repl_role_instance3) + execute_and_fetch_all(instance3_cursor, "CREATE ();") + + def retrieve_data_replica(): + return execute_and_fetch_all(instance1_cursor, "MATCH (n) RETURN count(n);")[0][0] + + expected_data_replica = 1 + mg_sleep_and_assert(expected_data_replica, retrieve_data_replica) if __name__ == "__main__": diff --git a/tests/e2e/high_availability_experimental/not_replicate_from_old_main.py b/tests/e2e/high_availability_experimental/not_replicate_from_old_main.py new file mode 100644 index 000000000..d6f6f7da4 --- /dev/null +++ b/tests/e2e/high_availability_experimental/not_replicate_from_old_main.py @@ -0,0 +1,117 @@ +# Copyright 2024 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. 
+ +import os +import sys + +import interactive_mg_runner +import pytest +from common import execute_and_fetch_all +from mg_utils import mg_sleep_and_assert + +interactive_mg_runner.SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) +interactive_mg_runner.PROJECT_DIR = os.path.normpath( + os.path.join(interactive_mg_runner.SCRIPT_DIR, "..", "..", "..", "..") +) +interactive_mg_runner.BUILD_DIR = os.path.normpath(os.path.join(interactive_mg_runner.PROJECT_DIR, "build")) +interactive_mg_runner.MEMGRAPH_BINARY = os.path.normpath(os.path.join(interactive_mg_runner.BUILD_DIR, "memgraph")) + +MEMGRAPH_FIRST_CLUSTER_DESCRIPTION = { + "shared_replica": { + "args": ["--bolt-port", "7688", "--log-level", "TRACE"], + "log_file": "replica2.log", + "setup_queries": ["SET REPLICATION ROLE TO REPLICA WITH PORT 10001;"], + }, + "main1": { + "args": ["--bolt-port", "7687", "--log-level", "TRACE"], + "log_file": "main.log", + "setup_queries": ["REGISTER REPLICA shared_replica SYNC TO '127.0.0.1:10001' ;"], + }, +} + + +MEMGRAPH_INSTANCES_DESCRIPTION = { + "replica": { + "args": ["--bolt-port", "7689", "--log-level", "TRACE"], + "log_file": "replica.log", + "setup_queries": ["SET REPLICATION ROLE TO REPLICA WITH PORT 10002;"], + }, + "main_2": { + "args": ["--bolt-port", "7690", "--log-level", "TRACE"], + "log_file": "main_2.log", + "setup_queries": [ + "REGISTER REPLICA shared_replica SYNC TO '127.0.0.1:10001' ;", + "REGISTER REPLICA replica SYNC TO '127.0.0.1:10002' ; ", + ], + }, +} + + +def test_replication_works_on_failover(connection): + # Goal of this test is to check that after changing `shared_replica` + # to be part of new cluster, `main` (old cluster) can't write any more to it + + # 1 + interactive_mg_runner.start_all_keep_others(MEMGRAPH_FIRST_CLUSTER_DESCRIPTION) + + # 2 + main_cursor = connection(7687, "main1").cursor() + expected_data_on_main = [ + ("shared_replica", "127.0.0.1:10001", "sync", 0, 0, "ready"), + ] + actual_data_on_main = sorted(list(execute_and_fetch_all(main_cursor, "SHOW REPLICAS;"))) + assert actual_data_on_main == expected_data_on_main + + # 3 + interactive_mg_runner.start_all_keep_others(MEMGRAPH_INSTANCES_DESCRIPTION) + + # 4 + new_main_cursor = connection(7690, "main_2").cursor() + + def retrieve_data_show_replicas(): + return sorted(list(execute_and_fetch_all(new_main_cursor, "SHOW REPLICAS;"))) + + expected_data_on_new_main = [ + ("replica", "127.0.0.1:10002", "sync", 0, 0, "ready"), + ("shared_replica", "127.0.0.1:10001", "sync", 0, 0, "ready"), + ] + mg_sleep_and_assert(expected_data_on_new_main, retrieve_data_show_replicas) + + # 5 + shared_replica_cursor = connection(7688, "shared_replica").cursor() + + with pytest.raises(Exception) as e: + execute_and_fetch_all(main_cursor, "CREATE ();") + assert ( + str(e.value) + == "Replication Exception: At least one SYNC replica has not confirmed committing last transaction. Check the status of the replicas using 'SHOW REPLICAS' query." 
+ ) + + res = execute_and_fetch_all(main_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] + assert res == 1, "Vertex should be created" + + res = execute_and_fetch_all(shared_replica_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] + assert res == 0, "Vertex shouldn't be replicated" + + # 7 + execute_and_fetch_all(new_main_cursor, "CREATE ();") + + res = execute_and_fetch_all(new_main_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] + assert res == 1, "Vertex should be created" + + res = execute_and_fetch_all(shared_replica_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] + assert res == 1, "Vertex should be replicated" + + interactive_mg_runner.stop_all() + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/high_availability_experimental/workloads.yaml b/tests/e2e/high_availability_experimental/workloads.yaml index 1d692084a..8b617dfb5 100644 --- a/tests/e2e/high_availability_experimental/workloads.yaml +++ b/tests/e2e/high_availability_experimental/workloads.yaml @@ -35,3 +35,7 @@ workloads: - name: "Disabled manual setting of replication cluster" binary: "tests/e2e/pytest_runner.sh" args: ["high_availability_experimental/manual_setting_replicas.py"] + + - name: "Not replicate from old main" + binary: "tests/e2e/pytest_runner.sh" + args: ["high_availability_experimental/not_replicate_from_old_main.py"] diff --git a/tests/e2e/interactive_mg_runner.py b/tests/e2e/interactive_mg_runner.py index 93bfc5fe6..06908747e 100755 --- a/tests/e2e/interactive_mg_runner.py +++ b/tests/e2e/interactive_mg_runner.py @@ -208,6 +208,11 @@ def start_all(context, procdir="", keep_directories=True): start_instance(context, key, procdir) +def start_all_keep_others(context, procdir="", keep_directories=True): + for key, _ in context.items(): + start_instance(context, key, procdir) + + def start(context, name, procdir=""): if name != "all": start_instance(context, name, procdir) diff --git a/tests/unit/replication_persistence_helper.cpp b/tests/unit/replication_persistence_helper.cpp index ef3ba254d..aeb58441e 100644 --- a/tests/unit/replication_persistence_helper.cpp +++ b/tests/unit/replication_persistence_helper.cpp @@ -13,6 +13,7 @@ #include "replication/state.hpp" #include "replication/status.hpp" #include "utils/logging.hpp" +#include "utils/uuid.hpp" #include <gtest/gtest.h> #include <fstream> @@ -48,6 +49,17 @@ TEST(ReplicationDurability, V2Main) { ASSERT_EQ(role_entry, deser); } +TEST(ReplicationDurability, V3Main) { + auto const role_entry = ReplicationRoleEntry{ + .version = DurabilityVersion::V3, + .role = MainRole{.epoch = ReplicationEpoch{"TEST_STRING"}, .main_uuid = memgraph::utils::UUID{}}}; + nlohmann::json j; + to_json(j, role_entry); + ReplicationRoleEntry deser; + from_json(j, deser); + ASSERT_EQ(role_entry, deser); +} + TEST(ReplicationDurability, V1Replica) { auto const role_entry = ReplicationRoleEntry{.version = DurabilityVersion::V1, @@ -74,6 +86,33 @@ TEST(ReplicationDurability, V2Replica) { ASSERT_EQ(role_entry, deser); } +TEST(ReplicationDurability, V3ReplicaNoMain) { + auto const role_entry = + ReplicationRoleEntry{.version = DurabilityVersion::V3, + .role = ReplicaRole{ + .config = ReplicationServerConfig{.ip_address = "000.123.456.789", .port = 2023}, + }}; + nlohmann::json j; + to_json(j, role_entry); + ReplicationRoleEntry deser; + from_json(j, deser); + ASSERT_EQ(role_entry, deser); +} + +TEST(ReplicationDurability, V3ReplicaMain) { + auto const role_entry = + ReplicationRoleEntry{.version = DurabilityVersion::V3,
.role = ReplicaRole{
+                               .config = ReplicationServerConfig{.ip_address = "000.123.456.789", .port = 2023},
+                               .main_uuid = memgraph::utils::UUID{},
+                           }};
+  nlohmann::json j;
+  to_json(j, role_entry);
+  ReplicationRoleEntry deser;
+  from_json(j, deser);
+  ASSERT_EQ(role_entry, deser);
+}
+
 TEST(ReplicationDurability, ReplicaEntrySync) {
   using namespace std::chrono_literals;
   using namespace std::string_literals;
diff --git a/tests/unit/storage_v2_replication.cpp b/tests/unit/storage_v2_replication.cpp
index b2adf3588..8c8536a84 100644
--- a/tests/unit/storage_v2_replication.cpp
+++ b/tests/unit/storage_v2_replication.cpp
@@ -142,17 +142,21 @@ TEST_F(ReplicationTest, BasicSynchronousReplicationTest) {
   MinMemgraph replica(repl_conf);
 
   auto replica_store_handler = replica.repl_handler;
-  replica_store_handler.SetReplicationRoleReplica(ReplicationServerConfig{
-      .ip_address = local_host,
-      .port = ports[0],
-  });
+  replica_store_handler.SetReplicationRoleReplica(
+      ReplicationServerConfig{
+          .ip_address = local_host,
+          .port = ports[0],
+      },
+      std::nullopt);
 
-  const auto &reg = main.repl_handler.TryRegisterReplica(ReplicationClientConfig{
-      .name = "REPLICA",
-      .mode = ReplicationMode::SYNC,
-      .ip_address = local_host,
-      .port = ports[0],
-  });
+  const auto &reg = main.repl_handler.TryRegisterReplica(
+      ReplicationClientConfig{
+          .name = "REPLICA",
+          .mode = ReplicationMode::SYNC,
+          .ip_address = local_host,
+          .port = ports[0],
+      },
+      true);
   ASSERT_FALSE(reg.HasError()) << (int)reg.GetError();
 
   // vertex create
@@ -435,30 +439,38 @@ TEST_F(ReplicationTest, MultipleSynchronousReplicationTest) {
   MinMemgraph replica1(repl_conf);
   MinMemgraph replica2(repl2_conf);
 
-  replica1.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{
-      .ip_address = local_host,
-      .port = ports[0],
-  });
-  replica2.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{
-      .ip_address = local_host,
-      .port = ports[1],
-  });
+  replica1.repl_handler.SetReplicationRoleReplica(
+      ReplicationServerConfig{
+          .ip_address = local_host,
+          .port = ports[0],
+      },
+      std::nullopt);
+  replica2.repl_handler.SetReplicationRoleReplica(
+      ReplicationServerConfig{
+          .ip_address = local_host,
+          .port = ports[1],
+      },
+      std::nullopt);
 
   ASSERT_FALSE(main.repl_handler
-                   .TryRegisterReplica(ReplicationClientConfig{
-                       .name = replicas[0],
-                       .mode = ReplicationMode::SYNC,
-                       .ip_address = local_host,
-                       .port = ports[0],
-                   })
+                   .TryRegisterReplica(
+                       ReplicationClientConfig{
+                           .name = replicas[0],
+                           .mode = ReplicationMode::SYNC,
+                           .ip_address = local_host,
+                           .port = ports[0],
+                       },
+                       true)
                    .HasError());
   ASSERT_FALSE(main.repl_handler
-                   .TryRegisterReplica(ReplicationClientConfig{
-                       .name = replicas[1],
-                       .mode = ReplicationMode::SYNC,
-                       .ip_address = local_host,
-                       .port = ports[1],
-                   })
+                   .TryRegisterReplica(
+                       ReplicationClientConfig{
+                           .name = replicas[1],
+                           .mode = ReplicationMode::SYNC,
+                           .ip_address = local_host,
+                           .port = ports[1],
+                       },
+                       true)
                    .HasError());
 
   const auto *vertex_label = "label";
@@ -585,17 +597,21 @@ TEST_F(ReplicationTest, RecoveryProcess) {
     MinMemgraph replica(repl_conf);
     auto replica_store_handler = replica.repl_handler;
 
-    replica_store_handler.SetReplicationRoleReplica(ReplicationServerConfig{
-        .ip_address = local_host,
-        .port = ports[0],
-    });
+    replica_store_handler.SetReplicationRoleReplica(
+        ReplicationServerConfig{
+            .ip_address = local_host,
+            .port = ports[0],
+        },
+        std::nullopt);
     ASSERT_FALSE(main.repl_handler
-                     .TryRegisterReplica(ReplicationClientConfig{
-                         .name = replicas[0],
-                         .mode =
ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[0], - }) + .TryRegisterReplica( + ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[0], + }, + true) .HasError()); ASSERT_EQ(main.db.storage()->GetReplicaState(replicas[0]), ReplicaState::RECOVERY); @@ -660,18 +676,22 @@ TEST_F(ReplicationTest, BasicAsynchronousReplicationTest) { MinMemgraph replica_async(repl_conf); auto replica_store_handler = replica_async.repl_handler; - replica_store_handler.SetReplicationRoleReplica(ReplicationServerConfig{ - .ip_address = local_host, - .port = ports[1], - }); + replica_store_handler.SetReplicationRoleReplica( + ReplicationServerConfig{ + .ip_address = local_host, + .port = ports[1], + }, + std::nullopt); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica(ReplicationClientConfig{ - .name = "REPLICA_ASYNC", - .mode = ReplicationMode::ASYNC, - .ip_address = local_host, - .port = ports[1], - }) + .TryRegisterReplica( + ReplicationClientConfig{ + .name = "REPLICA_ASYNC", + .mode = ReplicationMode::ASYNC, + .ip_address = local_host, + .port = ports[1], + }, + true) .HasError()); static constexpr size_t vertices_create_num = 10; @@ -706,33 +726,41 @@ TEST_F(ReplicationTest, EpochTest) { MinMemgraph main(main_conf); MinMemgraph replica1(repl_conf); - replica1.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{ - .ip_address = local_host, - .port = ports[0], - }); + replica1.repl_handler.SetReplicationRoleReplica( + ReplicationServerConfig{ + .ip_address = local_host, + .port = ports[0], + }, + std::nullopt); MinMemgraph replica2(repl2_conf); - replica2.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{ - .ip_address = local_host, - .port = 10001, - }); + replica2.repl_handler.SetReplicationRoleReplica( + ReplicationServerConfig{ + .ip_address = local_host, + .port = 10001, + }, + std::nullopt); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica(ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[0], - }) + .TryRegisterReplica( + ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[0], + }, + true) .HasError()); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica(ReplicationClientConfig{ - .name = replicas[1], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = 10001, - }) + .TryRegisterReplica( + ReplicationClientConfig{ + .name = replicas[1], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = 10001, + }, + true) .HasError()); std::optional<Gid> vertex_gid; @@ -761,12 +789,14 @@ TEST_F(ReplicationTest, EpochTest) { ASSERT_TRUE(replica1.repl_handler.SetReplicationRoleMain()); ASSERT_FALSE(replica1.repl_handler - .TryRegisterReplica(ReplicationClientConfig{ - .name = replicas[1], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = 10001, - }) + .TryRegisterReplica( + ReplicationClientConfig{ + .name = replicas[1], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = 10001, + }, + true) .HasError()); @@ -789,17 +819,21 @@ TEST_F(ReplicationTest, EpochTest) { ASSERT_FALSE(acc->Commit().HasError()); } - replica1.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{ - .ip_address = local_host, - .port = ports[0], - }); + replica1.repl_handler.SetReplicationRoleReplica( + ReplicationServerConfig{ + .ip_address = local_host, + .port = ports[0], + }, + std::nullopt); 
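// A hedged reading of the signature changes these tests exercise (inferred from this
// diff alone, not from the updated headers): both replication-handler calls gained a
// trailing argument tied to the main-UUID tracking added elsewhere in this patch.
//
//   handler.SetReplicationRoleReplica(
//       ReplicationServerConfig{.ip_address = local_host, .port = port},
//       std::nullopt);  // presumably an optional UUID of the current main, matching
//                       // the main_uuid field added to the V3 durability entries
//
//   handler.TryRegisterReplica(
//       ReplicationClientConfig{.name = name, .mode = mode,
//                               .ip_address = local_host, .port = port},
//       true);          // trailing flag added by this patch; its exact semantics
//                       // are not shown in this diff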
ASSERT_TRUE(main.repl_handler - .TryRegisterReplica(ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[0], - }) + .TryRegisterReplica( + ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[0], + }, + true) .HasError()); @@ -824,35 +858,43 @@ TEST_F(ReplicationTest, ReplicationInformation) { MinMemgraph replica1(repl_conf); uint16_t replica1_port = 10001; - replica1.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{ - .ip_address = local_host, - .port = replica1_port, - }); + replica1.repl_handler.SetReplicationRoleReplica( + ReplicationServerConfig{ + .ip_address = local_host, + .port = replica1_port, + }, + std::nullopt); uint16_t replica2_port = 10002; MinMemgraph replica2(repl2_conf); - replica2.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{ - .ip_address = local_host, - .port = replica2_port, - }); + replica2.repl_handler.SetReplicationRoleReplica( + ReplicationServerConfig{ + .ip_address = local_host, + .port = replica2_port, + }, + std::nullopt); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica(ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = replica1_port, - }) + .TryRegisterReplica( + ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = replica1_port, + }, + true) .HasError()); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica(ReplicationClientConfig{ - .name = replicas[1], - .mode = ReplicationMode::ASYNC, - .ip_address = local_host, - .port = replica2_port, - }) + .TryRegisterReplica( + ReplicationClientConfig{ + .name = replicas[1], + .mode = ReplicationMode::ASYNC, + .ip_address = local_host, + .port = replica2_port, + }, + true) .HasError()); @@ -881,33 +923,41 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingName) { MinMemgraph replica1(repl_conf); uint16_t replica1_port = 10001; - replica1.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{ - .ip_address = local_host, - .port = replica1_port, - }); + replica1.repl_handler.SetReplicationRoleReplica( + ReplicationServerConfig{ + .ip_address = local_host, + .port = replica1_port, + }, + std::nullopt); uint16_t replica2_port = 10002; MinMemgraph replica2(repl2_conf); - replica2.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{ - .ip_address = local_host, - .port = replica2_port, - }); + replica2.repl_handler.SetReplicationRoleReplica( + ReplicationServerConfig{ + .ip_address = local_host, + .port = replica2_port, + }, + std::nullopt); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica(ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = replica1_port, - }) + .TryRegisterReplica( + ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = replica1_port, + }, + true) .HasError()); ASSERT_TRUE(main.repl_handler - .TryRegisterReplica(ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::ASYNC, - .ip_address = local_host, - .port = replica2_port, - }) + .TryRegisterReplica( + ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::ASYNC, + .ip_address = local_host, + .port = replica2_port, + }, + true) .GetError() == RegisterReplicaError::NAME_EXISTS); } @@ -916,33 +966,41 @@ TEST_F(ReplicationTest, 
ReplicationReplicaWithExistingEndPoint) { MinMemgraph main(main_conf); MinMemgraph replica1(repl_conf); - replica1.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{ - .ip_address = local_host, - .port = common_port, - }); + replica1.repl_handler.SetReplicationRoleReplica( + ReplicationServerConfig{ + .ip_address = local_host, + .port = common_port, + }, + std::nullopt); MinMemgraph replica2(repl2_conf); - replica2.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{ - .ip_address = local_host, - .port = common_port, - }); + replica2.repl_handler.SetReplicationRoleReplica( + ReplicationServerConfig{ + .ip_address = local_host, + .port = common_port, + }, + std::nullopt); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica(ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = common_port, - }) + .TryRegisterReplica( + ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = common_port, + }, + true) .HasError()); ASSERT_TRUE(main.repl_handler - .TryRegisterReplica(ReplicationClientConfig{ - .name = replicas[1], - .mode = ReplicationMode::ASYNC, - .ip_address = local_host, - .port = common_port, - }) + .TryRegisterReplica( + ReplicationClientConfig{ + .name = replicas[1], + .mode = ReplicationMode::ASYNC, + .ip_address = local_host, + .port = common_port, + }, + true) .GetError() == RegisterReplicaError::ENDPOINT_EXISTS); } @@ -965,30 +1023,38 @@ TEST_F(ReplicationTest, RestoringReplicationAtStartupAfterDroppingReplica) { std::optional<MinMemgraph> main(main_config); MinMemgraph replica1(replica1_config); - replica1.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{ - .ip_address = local_host, - .port = ports[0], - }); + replica1.repl_handler.SetReplicationRoleReplica( + ReplicationServerConfig{ + .ip_address = local_host, + .port = ports[0], + }, + std::nullopt); MinMemgraph replica2(replica2_config); - replica2.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{ - .ip_address = local_host, - .port = ports[1], - }); + replica2.repl_handler.SetReplicationRoleReplica( + ReplicationServerConfig{ + .ip_address = local_host, + .port = ports[1], + }, + std::nullopt); - auto res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[0], - }); + auto res = main->repl_handler.TryRegisterReplica( + ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[0], + }, + true); ASSERT_FALSE(res.HasError()) << (int)res.GetError(); - res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{ - .name = replicas[1], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[1], - }); + res = main->repl_handler.TryRegisterReplica( + ReplicationClientConfig{ + .name = replicas[1], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[1], + }, + true); ASSERT_FALSE(res.HasError()) << (int)res.GetError(); auto replica_infos = main->db.storage()->ReplicasInfo(); @@ -1022,30 +1088,38 @@ TEST_F(ReplicationTest, RestoringReplicationAtStartup) { std::optional<MinMemgraph> main(main_config); MinMemgraph replica1(repl_conf); - replica1.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{ - .ip_address = local_host, - .port = ports[0], - }); + replica1.repl_handler.SetReplicationRoleReplica( + 
ReplicationServerConfig{
+          .ip_address = local_host,
+          .port = ports[0],
+      },
+      std::nullopt);
   MinMemgraph replica2(repl2_conf);
-  replica2.repl_handler.SetReplicationRoleReplica(ReplicationServerConfig{
-      .ip_address = local_host,
-      .port = ports[1],
-  });
-  auto res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{
-      .name = replicas[0],
-      .mode = ReplicationMode::SYNC,
-      .ip_address = local_host,
-      .port = ports[0],
-  });
+  replica2.repl_handler.SetReplicationRoleReplica(
+      ReplicationServerConfig{
+          .ip_address = local_host,
+          .port = ports[1],
+      },
+      std::nullopt);
+  auto res = main->repl_handler.TryRegisterReplica(
+      ReplicationClientConfig{
+          .name = replicas[0],
+          .mode = ReplicationMode::SYNC,
+          .ip_address = local_host,
+          .port = ports[0],
+      },
+      true);
   ASSERT_FALSE(res.HasError());
-  res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{
-      .name = replicas[1],
-      .mode = ReplicationMode::SYNC,
-      .ip_address = local_host,
-      .port = ports[1],
-  });
+  res = main->repl_handler.TryRegisterReplica(
+      ReplicationClientConfig{
+          .name = replicas[1],
+          .mode = ReplicationMode::SYNC,
+          .ip_address = local_host,
+          .port = ports[1],
+      },
+      true);
   ASSERT_FALSE(res.HasError());
 
   auto replica_infos = main->db.storage()->ReplicasInfo();
@@ -1083,11 +1157,13 @@ TEST_F(ReplicationTest, AddingInvalidReplica) {
   MinMemgraph main(main_conf);
 
   ASSERT_TRUE(main.repl_handler
-                  .TryRegisterReplica(ReplicationClientConfig{
-                      .name = "REPLICA",
-                      .mode = ReplicationMode::SYNC,
-                      .ip_address = local_host,
-                      .port = ports[0],
-                  })
-                  .GetError() == RegisterReplicaError::CONNECTION_FAILED);
+                  .TryRegisterReplica(
+                      ReplicationClientConfig{
+                          .name = "REPLICA",
+                          .mode = ReplicationMode::SYNC,
+                          .ip_address = local_host,
+                          .port = ports[0],
+                      },
+                      true)
+                  .GetError() == RegisterReplicaError::ERROR_ACCEPTING_MAIN);
 }

From 2fa8e001246c434f60ac24190e39c2e4171c73bd Mon Sep 17 00:00:00 2001
From: Aidar Samerkhanov <aidar.samerkhanov@memgraph.io>
Date: Thu, 8 Feb 2024 09:48:54 +0300
Subject: [PATCH 3/4] Fix accumulated path evaluation in builtin algorithms.
 (#1642)

Fix accumulated path evaluation in the DFS, BFS, WeightedShortestPath and
AllShortestPath algorithms.
---
 src/query/path.hpp                                 |  15 +-
 src/query/plan/operator.cpp                        | 154 ++++++++++++------
 .../tests/memgraph_V1/features/match.feature       |  20 +++
 .../features/memgraph_allshortest.feature          |  16 +-
 .../memgraph_V1/features/memgraph_bfs.feature      |  36 ++++
 .../features/memgraph_wshortest.feature            |  14 ++
 .../memgraph_V1/graphs/graph_edges.cypher          |   1 +
 7 files changed, 204 insertions(+), 52 deletions(-)

diff --git a/src/query/path.hpp b/src/query/path.hpp
index 43ecb8953..d9685bb4e 100644
--- a/src/query/path.hpp
+++ b/src/query/path.hpp
@@ -1,4 +1,4 @@
-// Copyright 2022 Memgraph Ltd.
+// Copyright 2024 Memgraph Ltd.
 //
 // Use of this software is governed by the Business Source License
 // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@@ -114,12 +114,17 @@ class Path {
   /** Expands the path with the given vertex. */
   void Expand(const VertexAccessor &vertex) {
     DMG_ASSERT(vertices_.size() == edges_.size(), "Illegal path construction order");
+    DMG_ASSERT(edges_.empty() || (!edges_.empty() && (edges_.back().To().Gid() == vertex.Gid() ||
+                                                      edges_.back().From().Gid() == vertex.Gid())),
+               "Illegal path construction order");
     vertices_.emplace_back(vertex);
   }
 
   /** Expands the path with the given edge.
*/ void Expand(const EdgeAccessor &edge) { DMG_ASSERT(vertices_.size() - 1 == edges_.size(), "Illegal path construction order"); + DMG_ASSERT(vertices_.back().Gid() == edge.From().Gid() || vertices_.back().Gid() == edge.To().Gid(), + "Illegal path construction order"); edges_.emplace_back(edge); } @@ -130,6 +135,14 @@ class Path { Expand(others...); } + void Shrink() { + DMG_ASSERT(!vertices_.empty(), "Vertices should not be empty in the path before shrink."); + vertices_.pop_back(); + if (!edges_.empty()) { + edges_.pop_back(); + } + } + /** Returns the number of expansions (edges) in this path. */ auto size() const { return edges_.size(); } diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 82269ca27..8dfaac81f 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -1057,7 +1057,8 @@ class ExpandVariableCursor : public Cursor { if (!self_.common_.existing_node) { frame[self_.common_.node_symbol] = start_vertex; return true; - } else if (CheckExistingNode(start_vertex, self_.common_.node_symbol, frame)) { + } + if (CheckExistingNode(start_vertex, self_.common_.node_symbol, frame)) { return true; } } @@ -1243,6 +1244,10 @@ class ExpandVariableCursor : public Cursor { MG_ASSERT(frame[self_.filter_lambda_.accumulated_path_symbol.value()].IsPath(), "Accumulated path must be path"); Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath(); + // Shrink the accumulated path including current level if necessary + while (accumulated_path.size() >= edges_on_frame.size()) { + accumulated_path.Shrink(); + } accumulated_path.Expand(current_edge.first); accumulated_path.Expand(current_vertex); } @@ -1260,10 +1265,9 @@ class ExpandVariableCursor : public Cursor { if (self_.common_.existing_node && !CheckExistingNode(current_vertex, self_.common_.node_symbol, frame)) continue; // We only yield true if we satisfy the lower bound. 
- if (static_cast<int64_t>(edges_on_frame.size()) >= lower_bound_) + if (static_cast<int64_t>(edges_on_frame.size()) >= lower_bound_) { return true; - else - continue; + } } } }; @@ -1527,8 +1531,8 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { : self_(self), input_cursor_(self_.input()->MakeCursor(mem)), processed_(mem), - to_visit_current_(mem), - to_visit_next_(mem) { + to_visit_next_(mem), + to_visit_current_(mem) { MG_ASSERT(!self_.common_.existing_node, "Single source shortest path algorithm " "should not be used when `existing_node` " @@ -1558,12 +1562,14 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { #endif frame[self_.filter_lambda_.inner_edge_symbol] = edge; frame[self_.filter_lambda_.inner_node_symbol] = vertex; + std::optional<Path> curr_acc_path = std::nullopt; if (self_.filter_lambda_.accumulated_path_symbol) { MG_ASSERT(frame[self_.filter_lambda_.accumulated_path_symbol.value()].IsPath(), "Accumulated path must have Path type"); Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath(); accumulated_path.Expand(edge); accumulated_path.Expand(vertex); + curr_acc_path = accumulated_path; } if (self_.filter_lambda_.expression) { @@ -1578,21 +1584,33 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { throw QueryRuntimeException("Expansion condition must evaluate to boolean or null."); } } - to_visit_next_.emplace_back(edge, vertex); + to_visit_next_.emplace_back(edge, vertex, std::move(curr_acc_path)); processed_.emplace(vertex, edge); }; + auto restore_frame_state_after_expansion = [this, &frame]() { + if (self_.filter_lambda_.accumulated_path_symbol) { + frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath().Shrink(); + } + }; + // populates the to_visit_next_ structure with expansions // from the given vertex. skips expansions that don't satisfy // the "where" condition. - auto expand_from_vertex = [this, &expand_pair](const auto &vertex) { + auto expand_from_vertex = [this, &expand_pair, &restore_frame_state_after_expansion](const auto &vertex) { if (self_.common_.direction != EdgeAtom::Direction::IN) { auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges; - for (const auto &edge : out_edges) expand_pair(edge, edge.To()); + for (const auto &edge : out_edges) { + expand_pair(edge, edge.To()); + restore_frame_state_after_expansion(); + } } if (self_.common_.direction != EdgeAtom::Direction::OUT) { auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges; - for (const auto &edge : in_edges) expand_pair(edge, edge.From()); + for (const auto &edge : in_edges) { + expand_pair(edge, edge.From()); + restore_frame_state_after_expansion(); + } } }; @@ -1638,14 +1656,14 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { } // take the next expansion from the queue - auto expansion = to_visit_current_.back(); + auto [curr_edge, curr_vertex, curr_acc_path] = to_visit_current_.back(); to_visit_current_.pop_back(); // create the frame value for the edges auto *pull_memory = context.evaluation_context.memory; utils::pmr::vector<TypedValue> edge_list(pull_memory); - edge_list.emplace_back(expansion.first); - auto last_vertex = expansion.second; + edge_list.emplace_back(curr_edge); + auto last_vertex = curr_vertex; while (true) { const EdgeAccessor &last_edge = edge_list.back().ValueEdge(); last_vertex = last_edge.From() == last_vertex ? 
last_edge.To() : last_edge.From(); @@ -1657,11 +1675,17 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { } // expand only if what we've just expanded is less then max depth - if (static_cast<int64_t>(edge_list.size()) < upper_bound_) expand_from_vertex(expansion.second); + if (static_cast<int64_t>(edge_list.size()) < upper_bound_) { + if (self_.filter_lambda_.accumulated_path_symbol) { + MG_ASSERT(curr_acc_path.has_value(), "Expected non-null accumulated path"); + frame[self_.filter_lambda_.accumulated_path_symbol.value()] = std::move(curr_acc_path.value()); + } + expand_from_vertex(curr_vertex); + } if (static_cast<int64_t>(edge_list.size()) < lower_bound_) continue; - frame[self_.common_.node_symbol] = expansion.second; + frame[self_.common_.node_symbol] = curr_vertex; // place edges on the frame in the correct order std::reverse(edge_list.begin(), edge_list.end()); @@ -1693,9 +1717,9 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { // edge because the root does not get expanded from anything. // contains visited vertices as well as those scheduled to be visited. utils::pmr::unordered_map<VertexAccessor, std::optional<EdgeAccessor>> processed_; - // edge/vertex pairs we have yet to visit, for current and next depth - utils::pmr::vector<std::pair<EdgeAccessor, VertexAccessor>> to_visit_current_; - utils::pmr::vector<std::pair<EdgeAccessor, VertexAccessor>> to_visit_next_; + // edge, vertex we have yet to visit, for current and next depth and their accumulated paths + utils::pmr::vector<std::tuple<EdgeAccessor, VertexAccessor, std::optional<Path>>> to_visit_next_; + utils::pmr::vector<std::tuple<EdgeAccessor, VertexAccessor, std::optional<Path>>> to_visit_current_; }; namespace { @@ -1768,6 +1792,7 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); + auto create_state = [this](const VertexAccessor &vertex, int64_t depth) { return std::make_pair(vertex, upper_bound_set_ ? 
depth : 0); }; @@ -1791,6 +1816,7 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { frame[self_.weight_lambda_->inner_node_symbol] = vertex; TypedValue next_weight = CalculateNextWeight(self_.weight_lambda_, total_weight, evaluator); + std::optional<Path> curr_acc_path = std::nullopt; if (self_.filter_lambda_.expression) { frame[self_.filter_lambda_.inner_edge_symbol] = edge; frame[self_.filter_lambda_.inner_node_symbol] = vertex; @@ -1800,6 +1826,7 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath(); accumulated_path.Expand(edge); accumulated_path.Expand(vertex); + curr_acc_path = accumulated_path; if (self_.filter_lambda_.accumulated_weight_symbol) { frame[self_.filter_lambda_.accumulated_weight_symbol.value()] = next_weight; @@ -1815,24 +1842,32 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { if (found_it != total_cost_.end() && (found_it->second.IsNull() || (found_it->second <= next_weight).ValueBool())) return; - pq_.emplace(next_weight, depth + 1, vertex, edge); + pq_.emplace(next_weight, depth + 1, vertex, edge, curr_acc_path); + }; + + auto restore_frame_state_after_expansion = [this, &frame]() { + if (self_.filter_lambda_.accumulated_path_symbol) { + frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath().Shrink(); + } }; // Populates the priority queue structure with expansions // from the given vertex. skips expansions that don't satisfy // the "where" condition. - auto expand_from_vertex = [this, &expand_pair](const VertexAccessor &vertex, const TypedValue &weight, - int64_t depth) { + auto expand_from_vertex = [this, &expand_pair, &restore_frame_state_after_expansion]( + const VertexAccessor &vertex, const TypedValue &weight, int64_t depth) { if (self_.common_.direction != EdgeAtom::Direction::IN) { auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : out_edges) { expand_pair(edge, edge.To(), weight, depth); + restore_frame_state_after_expansion(); } } if (self_.common_.direction != EdgeAtom::Direction::OUT) { auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : in_edges) { expand_pair(edge, edge.From(), weight, depth); + restore_frame_state_after_expansion(); } } }; @@ -1850,9 +1885,12 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { // Skip expansion for such nodes. if (node.IsNull()) continue; } + + std::optional<Path> curr_acc_path; if (self_.filter_lambda_.accumulated_path_symbol) { // Add initial vertex of path to the accumulated path - frame[self_.filter_lambda_.accumulated_path_symbol.value()] = Path(vertex); + curr_acc_path = Path(vertex); + frame[self_.filter_lambda_.accumulated_path_symbol.value()] = curr_acc_path.value(); } if (self_.upper_bound_) { upper_bound_ = EvaluateInt(&evaluator, self_.upper_bound_, "Max depth in weighted shortest path expansion"); @@ -1876,7 +1914,7 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { total_cost_.clear(); yielded_vertices_.clear(); - pq_.emplace(current_weight, 0, vertex, std::nullopt); + pq_.emplace(current_weight, 0, vertex, std::nullopt, curr_acc_path); // We are adding the starting vertex to the set of yielded vertices // because we don't want to yield paths that end with the starting // vertex. 
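// A minimal sketch (not part of the patch) of the Expand/Shrink discipline the hunks
// above introduce around every candidate expansion; v1, e12 and v2 are hypothetical
// accessors used purely for illustration:
//
//   Path acc(v1);     // accumulated path so far: (v1)
//   acc.Expand(e12);  // push the candidate edge ...
//   acc.Expand(v2);   // ... and its endpoint: (v1)-[e12]->(v2)
//   /* evaluate the filter lambda against the extended path */
//   acc.Shrink();     // pop v2 and e12, restoring (v1) for the next candidate edge
//
// Queue entries additionally snapshot the accumulated path (std::optional<Path>), so
// that when an entry is popped, the frame can be reset to the exact path that led to
// it before expanding any further.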
@@ -1885,7 +1923,7 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { while (!pq_.empty()) { AbortCheck(context); - auto [current_weight, current_depth, current_vertex, current_edge] = pq_.top(); + auto [current_weight, current_depth, current_vertex, current_edge, curr_acc_path] = pq_.top(); pq_.pop(); auto current_state = create_state(current_vertex, current_depth); @@ -1898,7 +1936,12 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { total_cost_.emplace(current_state, current_weight); // Expand only if what we've just expanded is less than max depth. - if (current_depth < upper_bound_) expand_from_vertex(current_vertex, current_weight, current_depth); + if (current_depth < upper_bound_) { + if (self_.filter_lambda_.accumulated_path_symbol) { + frame[self_.filter_lambda_.accumulated_path_symbol.value()] = std::move(curr_acc_path.value()); + } + expand_from_vertex(current_vertex, current_weight, current_depth); + } // If we yielded a path for a vertex already, make the expansion but // don't return the path again. @@ -1921,12 +1964,12 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { // Place destination node on the frame, handle existence flag. if (self_.common_.existing_node) { const auto &node = frame[self_.common_.node_symbol]; - if ((node != TypedValue(current_vertex, pull_memory)).ValueBool()) + if ((node != TypedValue(current_vertex, pull_memory)).ValueBool()) { continue; - else - // Prevent expanding other paths, because we found the - // shortest to existing node. - ClearQueue(); + } + // Prevent expanding other paths, because we found the + // shortest to existing node. + ClearQueue(); } else { frame[self_.common_.node_symbol] = current_vertex; } @@ -1979,8 +2022,9 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { // Priority queue comparator. Keep lowest weight on top of the queue. class PriorityQueueComparator { public: - bool operator()(const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &lhs, - const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &rhs) { + bool operator()( + const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>, std::optional<Path>> &lhs, + const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>, std::optional<Path>> &rhs) { const auto &lhs_weight = std::get<0>(lhs); const auto &rhs_weight = std::get<0>(rhs); // Null defines minimum value for all types @@ -1997,8 +2041,9 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { } }; - std::priority_queue<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>>, - utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>>>, + std::priority_queue<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>, std::optional<Path>>, + utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>, + std::optional<Path>>>, PriorityQueueComparator> pq_; @@ -2024,6 +2069,9 @@ class ExpandAllShortestPathsCursor : public query::plan::Cursor { ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); + + auto *memory = context.evaluation_context.memory; + auto create_state = [this](const VertexAccessor &vertex, int64_t depth) { return std::make_pair(vertex, upper_bound_set_ ? 
depth : 0); };
 
@@ -2041,6 +2089,7 @@ class ExpandAllShortestPathsCursor : public query::plan::Cursor {
       TypedValue next_weight = CalculateNextWeight(self_.weight_lambda_, total_weight, evaluator);
 
       // If filter expression exists, evaluate filter
+      std::optional<Path> curr_acc_path = std::nullopt;
       if (self_.filter_lambda_.expression) {
         frame[self_.filter_lambda_.inner_edge_symbol] = edge;
         frame[self_.filter_lambda_.inner_node_symbol] = next_vertex;
@@ -2050,6 +2099,7 @@
           Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath();
           accumulated_path.Expand(edge);
           accumulated_path.Expand(next_vertex);
+          curr_acc_path = accumulated_path;
 
           if (self_.filter_lambda_.accumulated_weight_symbol) {
             frame[self_.filter_lambda_.accumulated_weight_symbol.value()] = next_weight;
@@ -2076,14 +2126,20 @@
       }
 
       DirectedEdge directed_edge = {edge, direction, next_weight};
-      pq_.emplace(next_weight, depth + 1, next_vertex, directed_edge);
+      pq_.emplace(next_weight, depth + 1, next_vertex, directed_edge, curr_acc_path);
+    };
+
+    auto restore_frame_state_after_expansion = [this, &frame]() {
+      if (self_.filter_lambda_.accumulated_path_symbol) {
+        frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath().Shrink();
+      }
     };
 
     // Populates the priority queue structure with expansions
     // from the given vertex. skips expansions that don't satisfy
     // the "where" condition.
-    auto expand_from_vertex = [this, &expand_vertex, &context](const VertexAccessor &vertex, const TypedValue &weight,
-                                                               int64_t depth) {
+    auto expand_from_vertex = [this, &expand_vertex, &context, &restore_frame_state_after_expansion](
+                                  const VertexAccessor &vertex, const TypedValue &weight, int64_t depth) {
       if (self_.common_.direction != EdgeAtom::Direction::IN) {
         auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges;
         for (const auto &edge : out_edges) {
@@ -2096,6 +2152,7 @@
           }
 #endif
           expand_vertex(edge, EdgeAtom::Direction::OUT, weight, depth);
+          restore_frame_state_after_expansion();
         }
       }
       if (self_.common_.direction != EdgeAtom::Direction::OUT) {
@@ -2110,12 +2167,12 @@
           }
 #endif
           expand_vertex(edge, EdgeAtom::Direction::IN, weight, depth);
+          restore_frame_state_after_expansion();
         }
       }
     };
 
     std::optional<VertexAccessor> start_vertex;
-    auto *memory = context.evaluation_context.memory;
 
     auto create_path = [this, &frame, &memory]() {
       auto &current_level = traversal_stack_.back();
@@ -2167,11 +2224,11 @@
       return true;
     };
 
-    auto create_DFS_traversal_tree = [this, &context, &memory, &create_state, &expand_from_vertex]() {
+    auto create_DFS_traversal_tree = [this, &context, &memory, &frame, &create_state, &expand_from_vertex]() {
       while (!pq_.empty()) {
         AbortCheck(context);
 
-        const auto [current_weight, current_depth, current_vertex, directed_edge] = pq_.top();
+        auto [current_weight, current_depth, current_vertex, directed_edge, acc_path] = pq_.top();
         pq_.pop();
 
         const auto &[current_edge, direction, weight] = directed_edge;
@@ -2183,6 +2240,10 @@
         } else {
           total_cost_.emplace(current_state, current_weight);
           if (current_depth < upper_bound_) {
+            if
(self_.filter_lambda_.accumulated_path_symbol) { + DMG_ASSERT(acc_path.has_value(), "Path must be already filled in AllShortestPath DFS traversals"); + frame[self_.filter_lambda_.accumulated_path_symbol.value()] = std::move(acc_path.value()); + } expand_from_vertex(current_vertex, current_weight, current_depth); } } @@ -2315,8 +2376,8 @@ class ExpandAllShortestPathsCursor : public query::plan::Cursor { // Priority queue comparator. Keep lowest weight on top of the queue. class PriorityQueueComparator { public: - bool operator()(const std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge> &lhs, - const std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge> &rhs) { + bool operator()(const std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge, std::optional<Path>> &lhs, + const std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge, std::optional<Path>> &rhs) { const auto &lhs_weight = std::get<0>(lhs); const auto &rhs_weight = std::get<0>(rhs); // Null defines minimum value for all types @@ -2335,9 +2396,10 @@ class ExpandAllShortestPathsCursor : public query::plan::Cursor { // Priority queue - core element of the algorithm. // Stores: {weight, depth, next vertex, edge and direction} - std::priority_queue<std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge>, - utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge>>, - PriorityQueueComparator> + std::priority_queue< + std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge, std::optional<Path>>, + utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge, std::optional<Path>>>, + PriorityQueueComparator> pq_; void ClearQueue() { diff --git a/tests/gql_behave/tests/memgraph_V1/features/match.feature b/tests/gql_behave/tests/memgraph_V1/features/match.feature index 0d0477ad9..eaf8d3f44 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/match.feature +++ b/tests/gql_behave/tests/memgraph_V1/features/match.feature @@ -762,6 +762,16 @@ Feature: Match | path | | <(:label1 {id: 1})-[:type2 {id: 10}]->(:label3 {id: 3})> | + Scenario: Test DFS variable expand using IN edges with filter by edge type1 + Given graph "graph_edges" + When executing query: + """ + MATCH path=(:label3)<-[* (e, n, p | NOT(type(e)='type1' AND type(last(relationships(p))) = 'type1'))]-(:label1) RETURN path; + """ + Then the result should be: + | path | + | <(:label3 {id: 3})<-[:type2 {id: 10}]-(:label1 {id: 1})> | + Scenario: Test DFS variable expand with filter by edge type2 Given graph "graph_edges" When executing query: @@ -772,6 +782,16 @@ Feature: Match | path | | <(:label1 {id: 1})-[:type1 {id: 1}]->(:label2 {id: 2})-[:type1 {id: 2}]->(:label3 {id: 3})> | + Scenario: Test DFS variable expand using IN edges with filter by edge type2 + Given graph "graph_edges" + When executing query: + """ + MATCH path=(:label3)<-[* (e, n, p | NOT(type(e)='type2' AND type(last(relationships(p))) = 'type2'))]-(:label1) RETURN path; + """ + Then the result should be: + | path | + | <(:label3 {id: 3})<-[:type1 {id: 2}]-(:label2 {id: 2})<-[:type1 {id: 1}]-(:label1 {id: 1})> | + Scenario: Using path indentifier from CREATE in MERGE Given an empty graph And having executed: diff --git a/tests/gql_behave/tests/memgraph_V1/features/memgraph_allshortest.feature b/tests/gql_behave/tests/memgraph_V1/features/memgraph_allshortest.feature index 29dc0a5ef..03d041d6e 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/memgraph_allshortest.feature +++ 
b/tests/gql_behave/tests/memgraph_V1/features/memgraph_allshortest.feature @@ -205,11 +205,7 @@ Feature: All Shortest Path | 20.3 | Scenario: Test match AllShortest with accumulated path filtered by order of ids - Given an empty graph - And having executed: - """ - CREATE (:label1 {id: 1})-[:type1 {id:1}]->(:label2 {id: 2})-[:type1 {id: 2}]->(:label3 {id: 3})-[:type1 {id: 3}]->(:label4 {id: 4}); - """ + Given graph "graph_edges" When executing query: """ MATCH pth=(:label1)-[*ALLSHORTEST (r, n | r.id) total_weight (e,n,p | e.id > 0 and (nodes(p)[-1]).id > (nodes(p)[-2]).id)]->(:label4) RETURN pth, total_weight; @@ -218,6 +214,16 @@ Feature: All Shortest Path | pth | total_weight | | <(:label1{id:1})-[:type1{id:1}]->(:label2{id:2})-[:type1{id:2}]->(:label3{id:3})-[:type1{id:3}]->(:label4{id:4})> | 6 | + Scenario: Test match AllShortest using IN edges with accumulated path filtered by order of ids + Given graph "graph_edges" + When executing query: + """ + MATCH pth=(:label4)<-[*ALLSHORTEST (r, n | r.id) total_weight (e,n,p | e.id > 0 and (nodes(p)[-1]).id < (nodes(p)[-2]).id)]-(:label1) RETURN pth, total_weight; + """ + Then the result should be: + | pth | total_weight | + | <(:label4{id:4})<-[:type1{id:3}]-(:label3{id:3})<-[:type1{id:2}]-(:label2{id:2})<-[:type1{id:1}]-(:label1{id:1})> | 6 | + Scenario: Test match AllShortest with accumulated path filtered by edge type1 Given graph "graph_edges" When executing query: diff --git a/tests/gql_behave/tests/memgraph_V1/features/memgraph_bfs.feature b/tests/gql_behave/tests/memgraph_V1/features/memgraph_bfs.feature index 2736a6d71..23edc69cd 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/memgraph_bfs.feature +++ b/tests/gql_behave/tests/memgraph_V1/features/memgraph_bfs.feature @@ -136,6 +136,42 @@ Feature: Bfs | pth | | <(:label1{id:1})-[:type1{id:1}]->(:label2{id:2})-[:type1{id:2}]->(:label3{id:3})> | + Scenario: Test BFS variable expand using IN edges with filter by last edge type of accumulated path + Given graph "graph_edges" + When executing query: + """ + MATCH pth=(:label3)<-[*BFS (e,n,p | type(relationships(p)[-1]) = 'type1')]-(:label1) return pth; + """ + Then the result should be: + | pth | + | <(:label3 {id: 3})<-[:type1 {id: 2}]-(:label2 {id: 2})<-[:type1 {id: 1}]-(:label1 {id: 1})> | + + Scenario: Test BFS variable expand using IN edges with filter by number of vertices in the accumulated path + Given graph "graph_edges" + When executing query: + """ + MATCH p=(n)<-[*BFS (r, n, p | size(nodes(p)) > 0)]-(m {id:1}) return p; + """ + Then the result should be: + | p | + | <(:label2 {id: 2})<-[:type1 {id: 1}]-(:label1 {id: 1})> | + | <(:label3 {id: 3})<-[:type2 {id: 10}]-(:label1 {id: 1})> | + | <(:label4 {id: 4})<-[:type1 {id: 3}]-(:label3 {id: 3})<-[:type2 {id: 10}]-(:label1 {id: 1})> | + | <(:label5 {id: 5})<-[:type3 {id: 20}]-(:label1 {id: 1})> | + + Scenario: Test BFS variable expand using IN edges with filter by id of vertices but accumulated path is not used + Given graph "graph_edges" + When executing query: + """ + MATCH p=(n)<-[*BFS (r, n, p | r.id > 0)]-(m {id:1}) return p; + """ + Then the result should be: + | p | + | <(:label2 {id: 2})<-[:type1 {id: 1}]-(:label1 {id: 1})> | + | <(:label3 {id: 3})<-[:type2 {id: 10}]-(:label1 {id: 1})> | + | <(:label4 {id: 4})<-[:type1 {id: 3}]-(:label3 {id: 3})<-[:type2 {id: 10}]-(:label1 {id: 1})> | + | <(:label5 {id: 5})<-[:type3 {id: 20}]-(:label1 {id: 1})> | + Scenario: Test BFS variable expand with restict filter by last edge type of accumulated path Given an empty graph And 
having executed: diff --git a/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature b/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature index 819bc94b3..afd484696 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature +++ b/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature @@ -170,6 +170,20 @@ Feature: Weighted Shortest Path | pth | total_weight | | <(:label1{id:1})-[:type1{id:1}]->(:label2{id:2})-[:type1{id:2}]->(:label3{id:3})-[:type1{id:3}]->(:label4{id:4})> | 6 | + Scenario: Test match wShortest using IN edges with accumulated path filtered by order of ids + Given an empty graph + And having executed: + """ + CREATE (:label1 {id: 1})-[:type1 {id:1}]->(:label2 {id: 2})-[:type1 {id: 2}]->(:label3 {id: 3})-[:type1 {id: 3}]->(:label4 {id: 4}); + """ + When executing query: + """ + MATCH pth=(:label4)<-[*WSHORTEST (r, n | r.id) total_weight (e,n,p | e.id > 0 and (nodes(p)[-1]).id < (nodes(p)[-2]).id)]-(:label1) RETURN pth, total_weight; + """ + Then the result should be: + | pth | total_weight | + | <(:label4{id:4})<-[:type1{id:3}]-(:label3{id:3})<-[:type1{id:2}]-(:label2{id:2})<-[:type1{id:1}]-(:label1{id:1})> | 6 | + Scenario: Test match wShortest with accumulated path filtered by edge type1 Given graph "graph_edges" When executing query: diff --git a/tests/gql_behave/tests/memgraph_V1/graphs/graph_edges.cypher b/tests/gql_behave/tests/memgraph_V1/graphs/graph_edges.cypher index 06e7cdb5c..3657b855b 100644 --- a/tests/gql_behave/tests/memgraph_V1/graphs/graph_edges.cypher +++ b/tests/gql_behave/tests/memgraph_V1/graphs/graph_edges.cypher @@ -1,2 +1,3 @@ CREATE (:label1 {id: 1})-[:type1 {id:1}]->(:label2 {id: 2})-[:type1 {id: 2}]->(:label3 {id: 3})-[:type1 {id: 3}]->(:label4 {id: 4}); MATCH (n :label1), (m :label3) CREATE (n)-[:type2 {id: 10}]->(m); +MATCH (n :label1) CREATE (n)-[:type3 {id: 20}]->(:label5 { id: 5 }); From cf80687d1deed51c876883bdaa60d64d0655b5b7 Mon Sep 17 00:00:00 2001 From: Andi <andi8647@gmail.com> Date: Thu, 8 Feb 2024 10:11:33 +0100 Subject: [PATCH 4/4] HA: Organize Raft coordinator group (#1687) --- .clang-tidy | 2 + CMakeLists.txt | 2 +- src/coordination/CMakeLists.txt | 13 +- src/coordination/coordinator_data.cpp | 126 ++++--- src/coordination/coordinator_instance.cpp | 132 +++---- src/coordination/coordinator_log_store.cpp | 331 ++++++++++++++++++ src/coordination/coordinator_state.cpp | 12 +- .../coordinator_state_machine.cpp | 98 ++++++ .../coordinator_state_manager.cpp | 68 ++++ .../coordination/coordinator_client.hpp | 2 - .../include/coordination/coordinator_data.hpp | 21 +- .../coordination/coordinator_exceptions.hpp | 22 ++ .../coordination/coordinator_handlers.hpp | 4 +- .../coordination/coordinator_instance.hpp | 71 ++-- .../coordination/coordinator_state.hpp | 6 +- ...nstance_status.hpp => instance_status.hpp} | 9 +- .../coordination/replication_instance.hpp | 84 +++++ .../include/nuraft/coordinator_log_store.hpp | 128 +++++++ .../nuraft/coordinator_state_machine.hpp | 72 ++++ .../nuraft/coordinator_state_manager.hpp | 66 ++++ src/coordination/replication_instance.cpp | 98 ++++++ src/dbms/coordinator_handler.cpp | 11 +- src/dbms/coordinator_handler.hpp | 6 +- src/flags/replication.cpp | 6 +- src/flags/replication.hpp | 6 +- src/io/network/endpoint.cpp | 2 +- src/io/network/endpoint.hpp | 6 +- src/query/frontend/ast/ast.hpp | 10 +- .../frontend/ast/cypher_main_visitor.cpp | 25 +- .../frontend/ast/cypher_main_visitor.hpp | 7 +- .../opencypher/grammar/MemgraphCypher.g4 
|  12 +-
 .../opencypher/grammar/MemgraphCypherLexer.g4      |   3 +-
 .../frontend/stripped_lexer_constants.hpp          |   3 +-
 src/query/interpreter.cpp                          |  72 +++-
 src/query/interpreter.hpp                          |   8 +-
 src/query/metadata.cpp                             |   2 +
 src/query/metadata.hpp                             |   1 +
 tests/e2e/configuration/default_config.py          |   3 +-
 .../CMakeLists.txt                                 |   1 +
 .../automatic_failover.py                          | 167 +++++----
 .../coordinator.py                                 |  15 +-
 .../distributed_coordinators.py                    | 145 ++++++++
 .../workloads.yaml                                 |   6 +-
 tests/unit/cypher_main_visitor.cpp                 |  13 +
 44 files changed, 1581 insertions(+), 316 deletions(-)
 create mode 100644 src/coordination/coordinator_log_store.cpp
 create mode 100644 src/coordination/coordinator_state_machine.cpp
 create mode 100644 src/coordination/coordinator_state_manager.cpp
 rename src/coordination/include/coordination/{coordinator_instance_status.hpp => instance_status.hpp} (71%)
 create mode 100644 src/coordination/include/coordination/replication_instance.hpp
 create mode 100644 src/coordination/include/nuraft/coordinator_log_store.hpp
 create mode 100644 src/coordination/include/nuraft/coordinator_state_machine.hpp
 create mode 100644 src/coordination/include/nuraft/coordinator_state_manager.hpp
 create mode 100644 src/coordination/replication_instance.cpp
 create mode 100644 tests/e2e/high_availability_experimental/distributed_coordinators.py

diff --git a/.clang-tidy b/.clang-tidy
index 5773ea5cd..a30f9e592 100644
--- a/.clang-tidy
+++ b/.clang-tidy
@@ -6,6 +6,7 @@ Checks: '*,
   -altera-unroll-loops,
   -android-*,
   -cert-err58-cpp,
+  -cppcoreguidelines-avoid-do-while,
   -cppcoreguidelines-avoid-c-arrays,
   -cppcoreguidelines-avoid-goto,
   -cppcoreguidelines-avoid-magic-numbers,
@@ -60,6 +61,7 @@ Checks: '*,
   -readability-implicit-bool-conversion,
   -readability-magic-numbers,
   -readability-named-parameter,
+  -readability-identifier-length,
   -misc-no-recursion,
   -concurrency-mt-unsafe,
   -bugprone-easily-swappable-parameters'
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 266a3bedb..7245bf9f8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -275,7 +275,7 @@ option(MG_EXPERIMENTAL_HIGH_AVAILABILITY "Feature flag for experimental high ava
 
 if (NOT MG_ENTERPRISE AND MG_EXPERIMENTAL_HIGH_AVAILABILITY)
   set(MG_EXPERIMENTAL_HIGH_AVAILABILITY OFF)
-  message(FATAL_ERROR "MG_EXPERIMENTAL_HIGH_AVAILABILITY must be used with enterpise version of the code.")
+  message(FATAL_ERROR "MG_EXPERIMENTAL_HIGH_AVAILABILITY can only be used with the enterprise version of the code.")
 endif ()
 
 if (MG_EXPERIMENTAL_HIGH_AVAILABILITY)
diff --git a/src/coordination/CMakeLists.txt b/src/coordination/CMakeLists.txt
index d44cbcd26..d6ab23132 100644
--- a/src/coordination/CMakeLists.txt
+++ b/src/coordination/CMakeLists.txt
@@ -8,12 +8,18 @@ target_sources(mg-coordination
         include/coordination/coordinator_server.hpp
         include/coordination/coordinator_config.hpp
         include/coordination/coordinator_exceptions.hpp
-        include/coordination/coordinator_instance.hpp
         include/coordination/coordinator_slk.hpp
         include/coordination/coordinator_data.hpp
        include/coordination/constants.hpp
         include/coordination/coordinator_cluster_config.hpp
         include/coordination/coordinator_handlers.hpp
+        include/coordination/coordinator_instance.hpp
+        include/coordination/instance_status.hpp
+        include/coordination/replication_instance.hpp
+
+        include/nuraft/coordinator_log_store.hpp
+        include/nuraft/coordinator_state_machine.hpp
+        include/nuraft/coordinator_state_manager.hpp
 
         PRIVATE
         coordinator_client.cpp
@@ -23,6 +29,11 @@ target_sources(mg-coordination
         coordinator_data.cpp
         coordinator_instance.cpp
coordinator_handlers.cpp + replication_instance.cpp + + coordinator_log_store.cpp + coordinator_state_machine.cpp + coordinator_state_manager.cpp ) target_include_directories(mg-coordination PUBLIC include) diff --git a/src/coordination/coordinator_data.cpp b/src/coordination/coordinator_data.cpp index 3eb251003..3732958de 100644 --- a/src/coordination/coordinator_data.cpp +++ b/src/coordination/coordinator_data.cpp @@ -9,27 +9,29 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -#include "coordination/coordinator_instance.hpp" -#include "coordination/register_main_replica_coordinator_status.hpp" -#include "utils/uuid.hpp" #ifdef MG_ENTERPRISE #include "coordination/coordinator_data.hpp" +#include "coordination/register_main_replica_coordinator_status.hpp" +#include "coordination/replication_instance.hpp" +#include "utils/uuid.hpp" + #include <range/v3/view.hpp> #include <shared_mutex> -#include "libnuraft/nuraft.hxx" namespace memgraph::coordination { -CoordinatorData::CoordinatorData() { - auto find_instance = [](CoordinatorData *coord_data, std::string_view instance_name) -> CoordinatorInstance & { - auto instance = std::ranges::find_if( - coord_data->registered_instances_, - [instance_name](CoordinatorInstance const &instance) { return instance.InstanceName() == instance_name; }); +using nuraft::ptr; +using nuraft::srv_config; - MG_ASSERT(instance != coord_data->registered_instances_.end(), "Instance {} not found during callback!", - instance_name); +CoordinatorData::CoordinatorData() { + auto find_instance = [](CoordinatorData *coord_data, std::string_view instance_name) -> ReplicationInstance & { + auto instance = std::ranges::find_if( + coord_data->repl_instances_, + [instance_name](ReplicationInstance const &instance) { return instance.InstanceName() == instance_name; }); + + MG_ASSERT(instance != coord_data->repl_instances_.end(), "Instance {} not found during callback!", instance_name); return *instance; }; @@ -70,6 +72,11 @@ CoordinatorData::CoordinatorData() { auto &instance = find_instance(coord_data, instance_name); + if (instance.IsAlive()) { + instance.OnSuccessPing(); + return; + } + const auto &instance_uuid = instance.GetMainUUID(); MG_ASSERT(instance_uuid.has_value(), "Instance must have uuid set"); if (main_uuid_ == instance_uuid.value()) { @@ -110,48 +117,40 @@ CoordinatorData::CoordinatorData() { } auto CoordinatorData::TryFailover() -> void { - std::vector<CoordinatorInstance *> alive_registered_replica_instances{}; - std::ranges::transform(registered_instances_ | ranges::views::filter(&CoordinatorInstance::IsReplica) | - ranges::views::filter(&CoordinatorInstance::IsAlive), - std::back_inserter(alive_registered_replica_instances), - [](CoordinatorInstance &instance) { return &instance; }); + auto alive_replicas = repl_instances_ | ranges::views::filter(&ReplicationInstance::IsReplica) | + ranges::views::filter(&ReplicationInstance::IsAlive); - // TODO(antoniof) more complex logic of choosing replica instance - CoordinatorInstance *chosen_replica_instance = - !alive_registered_replica_instances.empty() ? 
alive_registered_replica_instances[0] : nullptr; - - if (nullptr == chosen_replica_instance) { + if (ranges::empty(alive_replicas)) { spdlog::warn("Failover failed since all replicas are down!"); return; } + // TODO: Smarter choice + auto chosen_replica_instance = ranges::begin(alive_replicas); + chosen_replica_instance->PauseFrequentCheck(); utils::OnScopeExit scope_exit{[&chosen_replica_instance] { chosen_replica_instance->ResumeFrequentCheck(); }}; - utils::UUID potential_new_main_uuid = utils::UUID{}; - spdlog::trace("Generated potential new main uuid"); + auto const potential_new_main_uuid = utils::UUID{}; - auto not_chosen_instance = [chosen_replica_instance](auto *instance) { - return *instance != *chosen_replica_instance; + auto const is_not_chosen_replica_instance = [&chosen_replica_instance](ReplicationInstance &instance) { + return instance != *chosen_replica_instance; }; + // If for some replicas swap fails, for others on successful ping we will revert back on next change // or we will do failover first again and then it will be consistent again - for (auto *other_replica_instance : alive_registered_replica_instances | ranges::views::filter(not_chosen_instance)) { - if (!other_replica_instance->SendSwapAndUpdateUUID(potential_new_main_uuid)) { + for (auto &other_replica_instance : alive_replicas | ranges::views::filter(is_not_chosen_replica_instance)) { + if (!other_replica_instance.SendSwapAndUpdateUUID(potential_new_main_uuid)) { spdlog::error(fmt::format("Failed to swap uuid for instance {} which is alive, aborting failover", - other_replica_instance->InstanceName())); + other_replica_instance.InstanceName())); return; } } std::vector<ReplClientInfo> repl_clients_info; - repl_clients_info.reserve(registered_instances_.size() - 1); - - std::ranges::transform(registered_instances_ | ranges::views::filter([chosen_replica_instance](const auto &instance) { - return *chosen_replica_instance != instance; - }), - std::back_inserter(repl_clients_info), - [](const CoordinatorInstance &instance) { return instance.ReplicationClientInfo(); }); + repl_clients_info.reserve(repl_instances_.size() - 1); + std::ranges::transform(repl_instances_ | ranges::views::filter(is_not_chosen_replica_instance), + std::back_inserter(repl_clients_info), &ReplicationInstance::ReplicationClientInfo); if (!chosen_replica_instance->PromoteToMain(potential_new_main_uuid, std::move(repl_clients_info), main_succ_cb_, main_fail_cb_)) { @@ -164,41 +163,53 @@ auto CoordinatorData::TryFailover() -> void { spdlog::info("Failover successful! 
Instance {} promoted to main.", chosen_replica_instance->InstanceName()); } -auto CoordinatorData::ShowInstances() const -> std::vector<CoordinatorInstanceStatus> { - std::vector<CoordinatorInstanceStatus> instances_status; - instances_status.reserve(registered_instances_.size()); +auto CoordinatorData::ShowInstances() const -> std::vector<InstanceStatus> { + auto const coord_instances = self_.GetAllCoordinators(); - auto const stringify_repl_role = [](CoordinatorInstance const &instance) -> std::string { + std::vector<InstanceStatus> instances_status; + instances_status.reserve(repl_instances_.size() + coord_instances.size()); + + auto const stringify_repl_role = [](ReplicationInstance const &instance) -> std::string { if (!instance.IsAlive()) return "unknown"; if (instance.IsMain()) return "main"; return "replica"; }; - auto const instance_to_status = - [&stringify_repl_role](CoordinatorInstance const &instance) -> CoordinatorInstanceStatus { + auto const repl_instance_to_status = [&stringify_repl_role](ReplicationInstance const &instance) -> InstanceStatus { return {.instance_name = instance.InstanceName(), - .socket_address = instance.SocketAddress(), - .replication_role = stringify_repl_role(instance), + .coord_socket_address = instance.SocketAddress(), + .cluster_role = stringify_repl_role(instance), .is_alive = instance.IsAlive()}; }; + auto const coord_instance_to_status = [](ptr<srv_config> const &instance) -> InstanceStatus { + return {.instance_name = "coordinator_" + std::to_string(instance->get_id()), + .raft_socket_address = instance->get_endpoint(), + .cluster_role = "coordinator", + .is_alive = true}; // TODO: (andi) Get this info from RAFT and test it or when we will move + // CoordinatorState to every instance, we can be smarter about this using our RPC. + }; + + std::ranges::transform(coord_instances, std::back_inserter(instances_status), coord_instance_to_status); + { auto lock = std::shared_lock{coord_data_lock_}; - std::ranges::transform(registered_instances_, std::back_inserter(instances_status), instance_to_status); + std::ranges::transform(repl_instances_, std::back_inserter(instances_status), repl_instance_to_status); } return instances_status; } +// TODO: (andi) Make sure you cannot put coordinator instance to the main auto CoordinatorData::SetInstanceToMain(std::string instance_name) -> SetInstanceToMainCoordinatorStatus { auto lock = std::lock_guard{coord_data_lock_}; - auto const is_new_main = [&instance_name](CoordinatorInstance const &instance) { + auto const is_new_main = [&instance_name](ReplicationInstance const &instance) { return instance.InstanceName() == instance_name; }; - auto new_main = std::ranges::find_if(registered_instances_, is_new_main); + auto new_main = std::ranges::find_if(repl_instances_, is_new_main); - if (new_main == registered_instances_.end()) { + if (new_main == repl_instances_.end()) { spdlog::error("Instance {} not registered. 
Please register it using REGISTER INSTANCE {}", instance_name, instance_name); return SetInstanceToMainCoordinatorStatus::NO_INSTANCE_WITH_NAME; @@ -208,16 +219,16 @@ auto CoordinatorData::SetInstanceToMain(std::string instance_name) -> SetInstanc utils::OnScopeExit scope_exit{[&new_main] { new_main->ResumeFrequentCheck(); }}; ReplicationClientsInfo repl_clients_info; - repl_clients_info.reserve(registered_instances_.size() - 1); + repl_clients_info.reserve(repl_instances_.size() - 1); - auto const is_not_new_main = [&instance_name](CoordinatorInstance const &instance) { + auto const is_not_new_main = [&instance_name](ReplicationInstance const &instance) { return instance.InstanceName() != instance_name; }; auto potential_new_main_uuid = utils::UUID{}; spdlog::trace("Generated potential new main uuid"); - for (auto &other_instance : registered_instances_ | ranges::views::filter(is_not_new_main)) { + for (auto &other_instance : repl_instances_ | ranges::views::filter(is_not_new_main)) { if (!other_instance.SendSwapAndUpdateUUID(potential_new_main_uuid)) { spdlog::error( fmt::format("Failed to swap uuid for instance {}, aborting failover", other_instance.InstanceName())); @@ -225,9 +236,9 @@ auto CoordinatorData::SetInstanceToMain(std::string instance_name) -> SetInstanc } } - std::ranges::transform(registered_instances_ | ranges::views::filter(is_not_new_main), + std::ranges::transform(repl_instances_ | ranges::views::filter(is_not_new_main), std::back_inserter(repl_clients_info), - [](const CoordinatorInstance &instance) { return instance.ReplicationClientInfo(); }); + [](const ReplicationInstance &instance) { return instance.ReplicationClientInfo(); }); if (!new_main->PromoteToMain(potential_new_main_uuid, std::move(repl_clients_info), main_succ_cb_, main_fail_cb_)) { return SetInstanceToMainCoordinatorStatus::COULD_NOT_PROMOTE_TO_MAIN; @@ -241,20 +252,20 @@ auto CoordinatorData::SetInstanceToMain(std::string instance_name) -> SetInstanc auto CoordinatorData::RegisterInstance(CoordinatorClientConfig config) -> RegisterInstanceCoordinatorStatus { auto lock = std::lock_guard{coord_data_lock_}; - if (std::ranges::any_of(registered_instances_, [&config](CoordinatorInstance const &instance) { + if (std::ranges::any_of(repl_instances_, [&config](ReplicationInstance const &instance) { return instance.InstanceName() == config.instance_name; })) { return RegisterInstanceCoordinatorStatus::NAME_EXISTS; } - if (std::ranges::any_of(registered_instances_, [&config](CoordinatorInstance const &instance) { + if (std::ranges::any_of(repl_instances_, [&config](ReplicationInstance const &instance) { return instance.SocketAddress() == config.SocketAddress(); })) { return RegisterInstanceCoordinatorStatus::ENDPOINT_EXISTS; } try { - registered_instances_.emplace_back(this, std::move(config), replica_succ_cb_, replica_fail_cb_); + repl_instances_.emplace_back(this, std::move(config), replica_succ_cb_, replica_fail_cb_); return RegisterInstanceCoordinatorStatus::SUCCESS; } catch (CoordinatorRegisterInstanceException const &) { @@ -262,5 +273,10 @@ auto CoordinatorData::RegisterInstance(CoordinatorClientConfig config) -> Regist } } +auto CoordinatorData::AddCoordinatorInstance(uint32_t raft_server_id, uint32_t raft_port, std::string raft_address) + -> void { + self_.AddCoordinatorInstance(raft_server_id, raft_port, std::move(raft_address)); +} + } // namespace memgraph::coordination #endif diff --git a/src/coordination/coordinator_instance.cpp b/src/coordination/coordinator_instance.cpp index 
a759a2505..7a0b0fbd0 100644 --- a/src/coordination/coordinator_instance.cpp +++ b/src/coordination/coordinator_instance.cpp @@ -13,83 +13,85 @@ #include "coordination/coordinator_instance.hpp" +#include "coordination/coordinator_exceptions.hpp" +#include "nuraft/coordinator_state_machine.hpp" +#include "nuraft/coordinator_state_manager.hpp" +#include "utils/counter.hpp" + namespace memgraph::coordination { -CoordinatorInstance::CoordinatorInstance(CoordinatorData *data, CoordinatorClientConfig config, - HealthCheckCallback succ_cb, HealthCheckCallback fail_cb) - : client_(data, std::move(config), std::move(succ_cb), std::move(fail_cb)), - replication_role_(replication_coordination_glue::ReplicationRole::REPLICA), - is_alive_(true) { - if (!client_.DemoteToReplica()) { - throw CoordinatorRegisterInstanceException("Failed to demote instance {} to replica", client_.InstanceName()); - } - client_.StartFrequentCheck(); -} +using nuraft::asio_service; +using nuraft::cmd_result; +using nuraft::cs_new; +using nuraft::ptr; +using nuraft::raft_params; +using nuraft::srv_config; +using raft_result = cmd_result<ptr<buffer>>; -auto CoordinatorInstance::OnSuccessPing() -> void { - last_response_time_ = std::chrono::system_clock::now(); - is_alive_ = true; -} +CoordinatorInstance::CoordinatorInstance() + : raft_server_id_(FLAGS_raft_server_id), raft_port_(FLAGS_raft_server_port), raft_address_("127.0.0.1") { + auto raft_endpoint = raft_address_ + ":" + std::to_string(raft_port_); + state_manager_ = cs_new<CoordinatorStateManager>(raft_server_id_, raft_endpoint); + state_machine_ = cs_new<CoordinatorStateMachine>(); + logger_ = nullptr; -auto CoordinatorInstance::OnFailPing() -> bool { - is_alive_ = - std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now() - last_response_time_).count() < - CoordinatorClusterConfig::alive_response_time_difference_sec_; - return is_alive_; -} + // ASIO options + asio_service::options asio_opts; + asio_opts.thread_pool_size_ = 1; // TODO: (andi) Improve this -auto CoordinatorInstance::InstanceName() const -> std::string { return client_.InstanceName(); } -auto CoordinatorInstance::SocketAddress() const -> std::string { return client_.SocketAddress(); } -auto CoordinatorInstance::IsAlive() const -> bool { return is_alive_; } + // RAFT parameters. Heartbeat every 100ms, election timeout between 200ms and 400ms. 
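+  // NuRaft interprets these values in milliseconds. The election timeout is deliberately kept
+  // at 2-4x the heartbeat interval: if its lower bound ever drops to or below the heartbeat
+  // interval, followers can time out between two heartbeats and call spurious elections even
+  // though the leader is healthy.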
+ raft_params params; + params.heart_beat_interval_ = 100; + params.election_timeout_lower_bound_ = 200; + params.election_timeout_upper_bound_ = 400; + // 5 logs are preserved before the last snapshot + params.reserved_log_items_ = 5; + // Create snapshot for every 5 log appends + params.snapshot_distance_ = 5; + params.client_req_timeout_ = 3000; + params.return_method_ = raft_params::blocking; -auto CoordinatorInstance::IsReplica() const -> bool { - return replication_role_ == replication_coordination_glue::ReplicationRole::REPLICA; -} -auto CoordinatorInstance::IsMain() const -> bool { - return replication_role_ == replication_coordination_glue::ReplicationRole::MAIN; -} + raft_server_ = + launcher_.init(state_machine_, state_manager_, logger_, static_cast<int>(raft_port_), asio_opts, params); -auto CoordinatorInstance::PromoteToMain(utils::UUID uuid, ReplicationClientsInfo repl_clients_info, - HealthCheckCallback main_succ_cb, HealthCheckCallback main_fail_cb) -> bool { - if (!client_.SendPromoteReplicaToMainRpc(uuid, std::move(repl_clients_info))) { - return false; + if (!raft_server_) { + throw RaftServerStartException("Failed to launch raft server on {}", raft_endpoint); } - replication_role_ = replication_coordination_glue::ReplicationRole::MAIN; - client_.SetCallbacks(std::move(main_succ_cb), std::move(main_fail_cb)); - - return true; -} - -auto CoordinatorInstance::DemoteToReplica(HealthCheckCallback replica_succ_cb, HealthCheckCallback replica_fail_cb) - -> bool { - if (!client_.DemoteToReplica()) { - return false; + auto maybe_stop = utils::ResettableCounter<20>(); + while (!raft_server_->is_initialized() && !maybe_stop()) { + std::this_thread::sleep_for(std::chrono::milliseconds(250)); } - replication_role_ = replication_coordination_glue::ReplicationRole::REPLICA; - client_.SetCallbacks(std::move(replica_succ_cb), std::move(replica_fail_cb)); - - return true; -} - -auto CoordinatorInstance::PauseFrequentCheck() -> void { client_.PauseFrequentCheck(); } -auto CoordinatorInstance::ResumeFrequentCheck() -> void { client_.ResumeFrequentCheck(); } - -auto CoordinatorInstance::ReplicationClientInfo() const -> CoordinatorClientConfig::ReplicationClientInfo { - return client_.ReplicationClientInfo(); -} - -auto CoordinatorInstance::GetClient() -> CoordinatorClient & { return client_; } -void CoordinatorInstance::SetNewMainUUID(const std::optional<utils::UUID> &main_uuid) { main_uuid_ = main_uuid; } -auto CoordinatorInstance::GetMainUUID() -> const std::optional<utils::UUID> & { return main_uuid_; } - -auto CoordinatorInstance::SendSwapAndUpdateUUID(const utils::UUID &main_uuid) -> bool { - if (!replication_coordination_glue::SendSwapMainUUIDRpc(client_.RpcClient(), main_uuid)) { - return false; + if (!raft_server_->is_initialized()) { + throw RaftServerStartException("Failed to initialize raft server on {}", raft_endpoint); } - SetNewMainUUID(main_uuid_); - return true; + + spdlog::info("Raft server started on {}", raft_endpoint); +} + +auto CoordinatorInstance::InstanceName() const -> std::string { + return "coordinator_" + std::to_string(raft_server_id_); +} + +auto CoordinatorInstance::RaftSocketAddress() const -> std::string { + return raft_address_ + ":" + std::to_string(raft_port_); +} + +auto CoordinatorInstance::AddCoordinatorInstance(uint32_t raft_server_id, uint32_t raft_port, std::string raft_address) + -> void { + auto const endpoint = raft_address + ":" + std::to_string(raft_port); + srv_config const srv_config_to_add(static_cast<int>(raft_server_id), endpoint); + if 
(!raft_server_->add_srv(srv_config_to_add)->get_accepted()) { + throw RaftAddServerException("Failed to add server {} to the cluster", endpoint); + } + spdlog::info("Request to add server {} to the cluster accepted", endpoint); +} + +auto CoordinatorInstance::GetAllCoordinators() const -> std::vector<ptr<srv_config>> { + std::vector<ptr<srv_config>> all_srv_configs; + raft_server_->get_srv_config_all(all_srv_configs); + return all_srv_configs; } } // namespace memgraph::coordination diff --git a/src/coordination/coordinator_log_store.cpp b/src/coordination/coordinator_log_store.cpp new file mode 100644 index 000000000..11b7be0dd --- /dev/null +++ b/src/coordination/coordinator_log_store.cpp @@ -0,0 +1,331 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#ifdef MG_ENTERPRISE + +#include "nuraft/coordinator_log_store.hpp" + +namespace memgraph::coordination { + +using nuraft::cs_new; +using nuraft::timer_helper; + +CoordinatorLogStore::CoordinatorLogStore() + : start_idx_(1), + raft_server_bwd_pointer_(nullptr), + disk_emul_delay(0), + disk_emul_thread_(nullptr), + disk_emul_thread_stop_signal_(false), + disk_emul_last_durable_index_(0) { + // Dummy entry for index 0. + ptr<buffer> buf = buffer::alloc(sz_ulong); + logs_[0] = cs_new<log_entry>(0, buf); +} + +CoordinatorLogStore::~CoordinatorLogStore() { + if (disk_emul_thread_) { + disk_emul_thread_stop_signal_ = true; + // disk_emul_ea_.invoke(); + if (disk_emul_thread_->joinable()) { + disk_emul_thread_->join(); + } + } +} + +ptr<log_entry> CoordinatorLogStore::MakeClone(const ptr<log_entry> &entry) { + // NOTE: + // Timestamp is used only when `replicate_log_timestamp_` option is on. + // Otherwise, log store does not need to store or load it. + ptr<log_entry> clone = cs_new<log_entry>(entry->get_term(), buffer::clone(entry->get_buf()), entry->get_val_type(), + entry->get_timestamp()); + return clone; +} + +ulong CoordinatorLogStore::next_slot() const { + std::lock_guard<std::mutex> l(logs_lock_); + // Exclude the dummy entry. + return start_idx_ + logs_.size() - 1; +} + +ulong CoordinatorLogStore::start_index() const { return start_idx_; } + +ptr<log_entry> CoordinatorLogStore::last_entry() const { + ulong next_idx = next_slot(); + std::lock_guard<std::mutex> l(logs_lock_); + auto entry = logs_.find(next_idx - 1); + if (entry == logs_.end()) { + entry = logs_.find(0); + } + + return MakeClone(entry->second); +} + +ulong CoordinatorLogStore::append(ptr<log_entry> &entry) { + ptr<log_entry> clone = MakeClone(entry); + + std::lock_guard<std::mutex> l(logs_lock_); + size_t idx = start_idx_ + logs_.size() - 1; + logs_[idx] = clone; + + if (disk_emul_delay) { + uint64_t cur_time = timer_helper::get_timeofday_us(); + disk_emul_logs_being_written_[cur_time + disk_emul_delay * 1000] = idx; + // disk_emul_ea_.invoke(); + } + + return idx; +} + +void CoordinatorLogStore::write_at(ulong index, ptr<log_entry> &entry) { + ptr<log_entry> clone = MakeClone(entry); + + // Discard all logs equal to or greater than `index. 
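+  // write_at() is typically invoked when the leader overwrites a follower's conflicting
+  // suffix. By the Raft log-matching property, once entries diverge at `index`, everything
+  // from `index` onwards is stale, so the whole tail is dropped before the new entry is
+  // stored. E.g. with entries {1, 2, 3, 4}, write_at(3, e) erases 3 and 4 and stores e at 3.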
+ std::lock_guard<std::mutex> l(logs_lock_); + auto itr = logs_.lower_bound(index); + while (itr != logs_.end()) { + itr = logs_.erase(itr); + } + logs_[index] = clone; + + if (disk_emul_delay) { + uint64_t cur_time = timer_helper::get_timeofday_us(); + disk_emul_logs_being_written_[cur_time + disk_emul_delay * 1000] = index; + + // Remove entries greater than `index`. + auto entry = disk_emul_logs_being_written_.begin(); + while (entry != disk_emul_logs_being_written_.end()) { + if (entry->second > index) { + entry = disk_emul_logs_being_written_.erase(entry); + } else { + entry++; + } + } + // disk_emul_ea_.invoke(); + } +} + +ptr<std::vector<ptr<log_entry>>> CoordinatorLogStore::log_entries(ulong start, ulong end) { + ptr<std::vector<ptr<log_entry>>> ret = cs_new<std::vector<ptr<log_entry>>>(); + + ret->resize(end - start); + ulong cc = 0; + for (ulong ii = start; ii < end; ++ii) { + ptr<log_entry> src = nullptr; + { + std::lock_guard<std::mutex> l(logs_lock_); + auto entry = logs_.find(ii); + if (entry == logs_.end()) { + entry = logs_.find(0); + assert(0); + } + src = entry->second; + } + (*ret)[cc++] = MakeClone(src); + } + return ret; +} + +// NOLINTNEXTLINE(google-default-arguments) +ptr<std::vector<ptr<log_entry>>> CoordinatorLogStore::log_entries_ext(ulong start, ulong end, + int64 batch_size_hint_in_bytes) { + ptr<std::vector<ptr<log_entry>>> ret = cs_new<std::vector<ptr<log_entry>>>(); + + if (batch_size_hint_in_bytes < 0) { + return ret; + } + + size_t accum_size = 0; + for (ulong ii = start; ii < end; ++ii) { + ptr<log_entry> src = nullptr; + { + std::lock_guard<std::mutex> l(logs_lock_); + auto entry = logs_.find(ii); + if (entry == logs_.end()) { + entry = logs_.find(0); + assert(0); + } + src = entry->second; + } + ret->push_back(MakeClone(src)); + accum_size += src->get_buf().size(); + if (batch_size_hint_in_bytes && accum_size >= (ulong)batch_size_hint_in_bytes) break; + } + return ret; +} + +ptr<log_entry> CoordinatorLogStore::entry_at(ulong index) { + ptr<log_entry> src = nullptr; + { + std::lock_guard<std::mutex> l(logs_lock_); + auto entry = logs_.find(index); + if (entry == logs_.end()) { + entry = logs_.find(0); + } + src = entry->second; + } + return MakeClone(src); +} + +ulong CoordinatorLogStore::term_at(ulong index) { + ulong term = 0; + { + std::lock_guard<std::mutex> l(logs_lock_); + auto entry = logs_.find(index); + if (entry == logs_.end()) { + entry = logs_.find(0); + } + term = entry->second->get_term(); + } + return term; +} + +ptr<buffer> CoordinatorLogStore::pack(ulong index, int32 cnt) { + std::vector<ptr<buffer>> logs; + + size_t size_total = 0; + for (ulong ii = index; ii < index + cnt; ++ii) { + ptr<log_entry> le = nullptr; + { + std::lock_guard<std::mutex> l(logs_lock_); + le = logs_[ii]; + } + assert(le.get()); + ptr<buffer> buf = le->serialize(); + size_total += buf->size(); + logs.push_back(buf); + } + + ptr<buffer> buf_out = buffer::alloc(sizeof(int32) + cnt * sizeof(int32) + size_total); + buf_out->pos(0); + buf_out->put((int32)cnt); + + for (auto &entry : logs) { + ptr<buffer> &bb = entry; + buf_out->put((int32)bb->size()); + buf_out->put(*bb); + } + return buf_out; +} + +void CoordinatorLogStore::apply_pack(ulong index, buffer &pack) { + pack.pos(0); + int32 num_logs = pack.get_int(); + + for (int32 ii = 0; ii < num_logs; ++ii) { + ulong cur_idx = index + ii; + int32 buf_size = pack.get_int(); + + ptr<buffer> buf_local = buffer::alloc(buf_size); + pack.get(buf_local); + + ptr<log_entry> le = log_entry::deserialize(*buf_local); + { + 
std::lock_guard<std::mutex> l(logs_lock_); + logs_[cur_idx] = le; + } + } + + { + std::lock_guard<std::mutex> l(logs_lock_); + auto entry = logs_.upper_bound(0); + if (entry != logs_.end()) { + start_idx_ = entry->first; + } else { + start_idx_ = 1; + } + } +} + +bool CoordinatorLogStore::compact(ulong last_log_index) { + std::lock_guard<std::mutex> l(logs_lock_); + for (ulong ii = start_idx_; ii <= last_log_index; ++ii) { + auto entry = logs_.find(ii); + if (entry != logs_.end()) { + logs_.erase(entry); + } + } + + // WARNING: + // Even though nothing has been erased, + // we should set `start_idx_` to new index. + if (start_idx_ <= last_log_index) { + start_idx_ = last_log_index + 1; + } + return true; +} + +bool CoordinatorLogStore::flush() { + disk_emul_last_durable_index_ = next_slot() - 1; + return true; +} + +ulong CoordinatorLogStore::last_durable_index() { + uint64_t last_log = next_slot() - 1; + if (!disk_emul_delay) { + return last_log; + } + + return disk_emul_last_durable_index_; +} + +void CoordinatorLogStore::DiskEmulLoop() { + // This thread mimics async disk writes. + + // uint32_t next_sleep_us = 100 * 1000; + while (!disk_emul_thread_stop_signal_) { + // disk_emul_ea_.wait_us(next_sleep_us); + // disk_emul_ea_.reset(); + if (disk_emul_thread_stop_signal_) break; + + uint64_t cur_time = timer_helper::get_timeofday_us(); + // next_sleep_us = 100 * 1000; + + bool call_notification = false; + { + std::lock_guard<std::mutex> l(logs_lock_); + // Remove all timestamps equal to or smaller than `cur_time`, + // and pick the greatest one among them. + auto entry = disk_emul_logs_being_written_.begin(); + while (entry != disk_emul_logs_being_written_.end()) { + if (entry->first <= cur_time) { + disk_emul_last_durable_index_ = entry->second; + entry = disk_emul_logs_being_written_.erase(entry); + call_notification = true; + } else { + break; + } + } + + entry = disk_emul_logs_being_written_.begin(); + if (entry != disk_emul_logs_being_written_.end()) { + // next_sleep_us = entry->first - cur_time; + } + } + + if (call_notification) { + raft_server_bwd_pointer_->notify_log_append_completion(true); + } + } +} + +void CoordinatorLogStore::Close() {} + +void CoordinatorLogStore::SetDiskDelay(raft_server *raft, size_t delay_ms) { + disk_emul_delay = delay_ms; + raft_server_bwd_pointer_ = raft; + + if (!disk_emul_thread_) { + disk_emul_thread_ = std::make_unique<std::thread>(&CoordinatorLogStore::DiskEmulLoop, this); + } +} + +} // namespace memgraph::coordination +#endif diff --git a/src/coordination/coordinator_state.cpp b/src/coordination/coordinator_state.cpp index 96ad1902e..8337fa9d8 100644 --- a/src/coordination/coordinator_state.cpp +++ b/src/coordination/coordinator_state.cpp @@ -25,7 +25,7 @@ namespace memgraph::coordination { CoordinatorState::CoordinatorState() { - MG_ASSERT(!(FLAGS_coordinator && FLAGS_coordinator_server_port), + MG_ASSERT(!(FLAGS_raft_server_id && FLAGS_coordinator_server_port), "Instance cannot be a coordinator and have registered coordinator server."); spdlog::info("Executing coordinator constructor"); @@ -68,7 +68,7 @@ auto CoordinatorState::SetInstanceToMain(std::string instance_name) -> SetInstan data_); } -auto CoordinatorState::ShowInstances() const -> std::vector<CoordinatorInstanceStatus> { +auto CoordinatorState::ShowInstances() const -> std::vector<InstanceStatus> { MG_ASSERT(std::holds_alternative<CoordinatorData>(data_), "Can't call show instances on data_, as variant holds wrong alternative"); return 
std::get<CoordinatorData>(data_).ShowInstances();
@@ -79,5 +79,13 @@ auto CoordinatorState::GetCoordinatorServer() const -> CoordinatorServer & {
                 "Cannot get coordinator server since variant holds wrong alternative");
   return *std::get<CoordinatorMainReplicaData>(data_).coordinator_server_;
 }
+
+auto CoordinatorState::AddCoordinatorInstance(uint32_t raft_server_id, uint32_t raft_port, std::string raft_address)
+    -> void {
+  MG_ASSERT(std::holds_alternative<CoordinatorData>(data_),
+            "Coordinator cannot add a coordinator instance since variant holds wrong alternative");
+  return std::get<CoordinatorData>(data_).AddCoordinatorInstance(raft_server_id, raft_port, raft_address);
+}
+
 }  // namespace memgraph::coordination
 #endif
diff --git a/src/coordination/coordinator_state_machine.cpp b/src/coordination/coordinator_state_machine.cpp
new file mode 100644
index 000000000..a278ab422
--- /dev/null
+++ b/src/coordination/coordinator_state_machine.cpp
@@ -0,0 +1,98 @@
+// Copyright 2024 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#ifdef MG_ENTERPRISE
+
+#include "nuraft/coordinator_state_machine.hpp"
+
+namespace memgraph::coordination {
+
+auto CoordinatorStateMachine::pre_commit(ulong const log_idx, buffer &data) -> ptr<buffer> {
+  buffer_serializer bs(data);
+  std::string str = bs.get_str();
+
+  spdlog::info("pre_commit {} : {}", log_idx, str);
+  return nullptr;
+}
+
+auto CoordinatorStateMachine::commit(ulong const log_idx, buffer &data) -> ptr<buffer> {
+  buffer_serializer bs(data);
+  std::string str = bs.get_str();
+
+  spdlog::info("commit {} : {}", log_idx, str);
+
+  last_committed_idx_ = log_idx;
+  return nullptr;
+}
+
+auto CoordinatorStateMachine::commit_config(ulong const log_idx, ptr<cluster_config> & /*new_conf*/) -> void {
+  last_committed_idx_ = log_idx;
+}
+
+auto CoordinatorStateMachine::rollback(ulong const log_idx, buffer &data) -> void {
+  buffer_serializer bs(data);
+  std::string str = bs.get_str();
+
+  spdlog::info("rollback {} : {}", log_idx, str);
+}
+
+auto CoordinatorStateMachine::read_logical_snp_obj(snapshot & /*snapshot*/, void *& /*user_snp_ctx*/, ulong /*obj_id*/,
+                                                   ptr<buffer> &data_out, bool &is_last_obj) -> int {
+  // Put dummy data.
+  data_out = buffer::alloc(sizeof(int32));
+  buffer_serializer bs(data_out);
+  bs.put_i32(0);
+
+  is_last_obj = true;
+  return 0;
+}
+
+auto CoordinatorStateMachine::save_logical_snp_obj(snapshot &s, ulong &obj_id, buffer & /*data*/, bool /*is_first_obj*/,
+                                                   bool /*is_last_obj*/) -> void {
+  spdlog::info("save snapshot {} term {} object ID {}", s.get_last_log_idx(), s.get_last_log_term(), obj_id);
+  // Request next object.
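+  // NuRaft streams a snapshot as a sequence of numbered objects; bumping `obj_id` both
+  // acknowledges that this object was saved and requests object obj_id + 1 from the sender.
+  // This dummy state machine keeps no real snapshot payload, so every object is accepted as-is.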
+ obj_id++; +} + +auto CoordinatorStateMachine::apply_snapshot(snapshot &s) -> bool { + spdlog::info("apply snapshot {} term {}", s.get_last_log_idx(), s.get_last_log_term()); + { + auto lock = std::lock_guard{last_snapshot_lock_}; + ptr<buffer> snp_buf = s.serialize(); + last_snapshot_ = snapshot::deserialize(*snp_buf); + } + return true; +} + +auto CoordinatorStateMachine::free_user_snp_ctx(void *&user_snp_ctx) -> void {} + +auto CoordinatorStateMachine::last_snapshot() -> ptr<snapshot> { + auto lock = std::lock_guard{last_snapshot_lock_}; + return last_snapshot_; +} + +auto CoordinatorStateMachine::last_commit_index() -> ulong { return last_committed_idx_; } + +auto CoordinatorStateMachine::create_snapshot(snapshot &s, async_result<bool>::handler_type &when_done) -> void { + spdlog::info("create snapshot {} term {}", s.get_last_log_idx(), s.get_last_log_term()); + // Clone snapshot from `s`. + { + auto lock = std::lock_guard{last_snapshot_lock_}; + ptr<buffer> snp_buf = s.serialize(); + last_snapshot_ = snapshot::deserialize(*snp_buf); + } + ptr<std::exception> except(nullptr); + bool ret = true; + when_done(ret, except); +} + +} // namespace memgraph::coordination +#endif diff --git a/src/coordination/coordinator_state_manager.cpp b/src/coordination/coordinator_state_manager.cpp new file mode 100644 index 000000000..b2fb81ea1 --- /dev/null +++ b/src/coordination/coordinator_state_manager.cpp @@ -0,0 +1,68 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#ifdef MG_ENTERPRISE + +#include "nuraft/coordinator_state_manager.hpp" + +namespace memgraph::coordination { + +using nuraft::cluster_config; +using nuraft::cs_new; +using nuraft::srv_config; +using nuraft::srv_state; +using nuraft::state_mgr; + +CoordinatorStateManager::CoordinatorStateManager(int srv_id, std::string const &endpoint) + : my_id_(srv_id), my_endpoint_(endpoint), cur_log_store_(cs_new<CoordinatorLogStore>()) { + my_srv_config_ = cs_new<srv_config>(srv_id, endpoint); + + // Initial cluster config: contains only one server (myself). + cluster_config_ = cs_new<cluster_config>(); + cluster_config_->get_servers().push_back(my_srv_config_); +} + +auto CoordinatorStateManager::load_config() -> ptr<cluster_config> { + // Just return in-memory data in this example. + // May require reading from disk here, if it has been written to disk. + return cluster_config_; +} + +auto CoordinatorStateManager::save_config(cluster_config const &config) -> void { + // Just keep in memory in this example. + // Need to write to disk here, if want to make it durable. + ptr<buffer> buf = config.serialize(); + cluster_config_ = cluster_config::deserialize(*buf); +} + +auto CoordinatorStateManager::save_state(srv_state const &state) -> void { + // Just keep in memory in this example. + // Need to write to disk here, if want to make it durable. + ptr<buffer> buf = state.serialize(); + saved_state_ = srv_state::deserialize(*buf); +} + +auto CoordinatorStateManager::read_state() -> ptr<srv_state> { + // Just return in-memory data in this example. 
+ // May require reading from disk here, if it has been written to disk. + return saved_state_; +} + +auto CoordinatorStateManager::load_log_store() -> ptr<log_store> { return cur_log_store_; } + +auto CoordinatorStateManager::server_id() -> int32 { return my_id_; } + +auto CoordinatorStateManager::system_exit(int const exit_code) -> void {} + +auto CoordinatorStateManager::GetSrvConfig() const -> ptr<srv_config> { return my_srv_config_; } + +} // namespace memgraph::coordination +#endif diff --git a/src/coordination/include/coordination/coordinator_client.hpp b/src/coordination/include/coordination/coordinator_client.hpp index 00695acd7..76ae49a9f 100644 --- a/src/coordination/include/coordination/coordinator_client.hpp +++ b/src/coordination/include/coordination/coordinator_client.hpp @@ -49,12 +49,10 @@ class CoordinatorClient { auto SendPromoteReplicaToMainRpc(const utils::UUID &uuid, ReplicationClientsInfo replication_clients_info) const -> bool; - auto SendSwapMainUUIDRpc(const utils::UUID &uuid) const -> bool; auto ReplicationClientInfo() const -> ReplClientInfo; - auto SetCallbacks(HealthCheckCallback succ_cb, HealthCheckCallback fail_cb) -> void; auto RpcClient() -> rpc::Client & { return rpc_client_; } diff --git a/src/coordination/include/coordination/coordinator_data.hpp b/src/coordination/include/coordination/coordinator_data.hpp index 73bebdf7e..9f4c60297 100644 --- a/src/coordination/include/coordination/coordinator_data.hpp +++ b/src/coordination/include/coordination/coordinator_data.hpp @@ -11,36 +11,45 @@ #pragma once -#include "utils/uuid.hpp" #ifdef MG_ENTERPRISE -#include <list> #include "coordination/coordinator_instance.hpp" -#include "coordination/coordinator_instance_status.hpp" #include "coordination/coordinator_server.hpp" +#include "coordination/instance_status.hpp" #include "coordination/register_main_replica_coordinator_status.hpp" +#include "coordination/replication_instance.hpp" #include "replication_coordination_glue/handler.hpp" #include "utils/rw_lock.hpp" #include "utils/thread_pool.hpp" +#include "utils/uuid.hpp" + +#include <list> namespace memgraph::coordination { class CoordinatorData { public: CoordinatorData(); + // TODO: (andi) Probably rename to RegisterReplicationInstance [[nodiscard]] auto RegisterInstance(CoordinatorClientConfig config) -> RegisterInstanceCoordinatorStatus; [[nodiscard]] auto SetInstanceToMain(std::string instance_name) -> SetInstanceToMainCoordinatorStatus; + auto ShowInstances() const -> std::vector<InstanceStatus>; + auto TryFailover() -> void; - auto ShowInstances() const -> std::vector<CoordinatorInstanceStatus>; + auto AddCoordinatorInstance(uint32_t raft_server_id, uint32_t raft_port, std::string raft_address) -> void; private: - mutable utils::RWLock coord_data_lock_{utils::RWLock::Priority::READ}; HealthCheckCallback main_succ_cb_, main_fail_cb_, replica_succ_cb_, replica_fail_cb_; + // NOTE: Must be std::list because we rely on pointer stability - std::list<CoordinatorInstance> registered_instances_; + std::list<ReplicationInstance> repl_instances_; + mutable utils::RWLock coord_data_lock_{utils::RWLock::Priority::READ}; + + CoordinatorInstance self_; + utils::UUID main_uuid_; }; diff --git a/src/coordination/include/coordination/coordinator_exceptions.hpp b/src/coordination/include/coordination/coordinator_exceptions.hpp index c9e2dff81..5b697e371 100644 --- a/src/coordination/include/coordination/coordinator_exceptions.hpp +++ b/src/coordination/include/coordination/coordinator_exceptions.hpp @@ -28,5 +28,27 @@ 
class CoordinatorRegisterInstanceException final : public utils::BasicException SPECIALIZE_GET_EXCEPTION_NAME(CoordinatorRegisterInstanceException) }; +class RaftServerStartException final : public utils::BasicException { + public: + explicit RaftServerStartException(std::string_view what) noexcept : BasicException(what) {} + + template <class... Args> + explicit RaftServerStartException(fmt::format_string<Args...> fmt, Args &&...args) noexcept + : RaftServerStartException(fmt::format(fmt, std::forward<Args>(args)...)) {} + + SPECIALIZE_GET_EXCEPTION_NAME(RaftServerStartException) +}; + +class RaftAddServerException final : public utils::BasicException { + public: + explicit RaftAddServerException(std::string_view what) noexcept : BasicException(what) {} + + template <class... Args> + explicit RaftAddServerException(fmt::format_string<Args...> fmt, Args &&...args) noexcept + : RaftAddServerException(fmt::format(fmt, std::forward<Args>(args)...)) {} + + SPECIALIZE_GET_EXCEPTION_NAME(RaftAddServerException) +}; + } // namespace memgraph::coordination #endif diff --git a/src/coordination/include/coordination/coordinator_handlers.hpp b/src/coordination/include/coordination/coordinator_handlers.hpp index 1f170bd61..4aa4656c3 100644 --- a/src/coordination/include/coordination/coordinator_handlers.hpp +++ b/src/coordination/include/coordination/coordinator_handlers.hpp @@ -31,8 +31,8 @@ class CoordinatorHandlers { slk::Builder *res_builder); static void DemoteMainToReplicaHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader, slk::Builder *res_builder); - static void SwapMainUUIDHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader, slk::Builder *res_builder); - + static void SwapMainUUIDHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader, + slk::Builder *res_builder); }; } // namespace memgraph::dbms diff --git a/src/coordination/include/coordination/coordinator_instance.hpp b/src/coordination/include/coordination/coordinator_instance.hpp index f3fd3deca..1c7af59ae 100644 --- a/src/coordination/include/coordination/coordinator_instance.hpp +++ b/src/coordination/include/coordination/coordinator_instance.hpp @@ -13,70 +13,45 @@ #ifdef MG_ENTERPRISE -#include "coordination/coordinator_client.hpp" -#include "coordination/coordinator_cluster_config.hpp" -#include "coordination/coordinator_exceptions.hpp" -#include "replication_coordination_glue/handler.hpp" -#include "replication_coordination_glue/role.hpp" +#include <flags/replication.hpp> + +#include <libnuraft/nuraft.hxx> namespace memgraph::coordination { -class CoordinatorData; +using nuraft::logger; +using nuraft::ptr; +using nuraft::raft_launcher; +using nuraft::raft_server; +using nuraft::srv_config; +using nuraft::state_machine; +using nuraft::state_mgr; class CoordinatorInstance { public: - CoordinatorInstance(CoordinatorData *data, CoordinatorClientConfig config, HealthCheckCallback succ_cb, - HealthCheckCallback fail_cb); - + CoordinatorInstance(); CoordinatorInstance(CoordinatorInstance const &other) = delete; CoordinatorInstance &operator=(CoordinatorInstance const &other) = delete; CoordinatorInstance(CoordinatorInstance &&other) noexcept = delete; CoordinatorInstance &operator=(CoordinatorInstance &&other) noexcept = delete; ~CoordinatorInstance() = default; - auto OnSuccessPing() -> void; - auto OnFailPing() -> bool; - - auto IsAlive() const -> bool; - auto InstanceName() const -> std::string; - auto SocketAddress() const -> 
std::string; - - auto IsReplica() const -> bool; - auto IsMain() const -> bool; - - auto PromoteToMain(utils::UUID main_uuid, ReplicationClientsInfo repl_clients_info, HealthCheckCallback main_succ_cb, - HealthCheckCallback main_fail_cb) -> bool; - auto DemoteToReplica(HealthCheckCallback replica_succ_cb, HealthCheckCallback replica_fail_cb) -> bool; - - auto PauseFrequentCheck() -> void; - auto ResumeFrequentCheck() -> void; - - auto ReplicationClientInfo() const -> ReplClientInfo; - - auto GetClient() -> CoordinatorClient &; - - void SetNewMainUUID(const std::optional<utils::UUID> &main_uuid = std::nullopt); - auto GetMainUUID() -> const std::optional<utils::UUID> &; - - auto SendSwapAndUpdateUUID(const utils::UUID &main_uuid) -> bool; + auto RaftSocketAddress() const -> std::string; + auto AddCoordinatorInstance(uint32_t raft_server_id, uint32_t raft_port, std::string raft_address) -> void; + auto GetAllCoordinators() const -> std::vector<ptr<srv_config>>; private: - CoordinatorClient client_; - replication_coordination_glue::ReplicationRole replication_role_; - std::chrono::system_clock::time_point last_response_time_{}; - // TODO this needs to be atomic? What if instance is alive and then we read it and it has changed - bool is_alive_{false}; - // for replica this is main uuid of current main - // for "main" main this same as in CoordinatorData - // it is set to nullopt when replica is down - // TLDR; when replica is down and comes back up we reset uuid of main replica is listening to - // so we need to send swap uuid again - std::optional<utils::UUID> main_uuid_; + ptr<state_machine> state_machine_; + ptr<state_mgr> state_manager_; + ptr<raft_server> raft_server_; + ptr<logger> logger_; + raft_launcher launcher_; - friend bool operator==(CoordinatorInstance const &first, CoordinatorInstance const &second) { - return first.client_ == second.client_ && first.replication_role_ == second.replication_role_; - } + // TODO: (andi) I think variables below can be abstracted + uint32_t raft_server_id_; + uint32_t raft_port_; + std::string raft_address_; }; } // namespace memgraph::coordination diff --git a/src/coordination/include/coordination/coordinator_state.hpp b/src/coordination/include/coordination/coordinator_state.hpp index 5f52f85e5..9ab33a04e 100644 --- a/src/coordination/include/coordination/coordinator_state.hpp +++ b/src/coordination/include/coordination/coordinator_state.hpp @@ -14,8 +14,8 @@ #ifdef MG_ENTERPRISE #include "coordination/coordinator_data.hpp" -#include "coordination/coordinator_instance_status.hpp" #include "coordination/coordinator_server.hpp" +#include "coordination/instance_status.hpp" #include "coordination/register_main_replica_coordinator_status.hpp" #include <variant> @@ -37,7 +37,9 @@ class CoordinatorState { [[nodiscard]] auto SetInstanceToMain(std::string instance_name) -> SetInstanceToMainCoordinatorStatus; - auto ShowInstances() const -> std::vector<CoordinatorInstanceStatus>; + auto ShowInstances() const -> std::vector<InstanceStatus>; + + auto AddCoordinatorInstance(uint32_t raft_server_id, uint32_t raft_port, std::string raft_address) -> void; // The client code must check that the server exists before calling this method. 
auto GetCoordinatorServer() const -> CoordinatorServer &;
diff --git a/src/coordination/include/coordination/coordinator_instance_status.hpp b/src/coordination/include/coordination/instance_status.hpp
similarity index 71%
rename from src/coordination/include/coordination/coordinator_instance_status.hpp
rename to src/coordination/include/coordination/instance_status.hpp
index 2a0a3a985..492410061 100644
--- a/src/coordination/include/coordination/coordinator_instance_status.hpp
+++ b/src/coordination/include/coordination/instance_status.hpp
@@ -19,10 +19,13 @@
 
 namespace memgraph::coordination {
 
-struct CoordinatorInstanceStatus {
+// TODO: (andi) For phase IV. For now some instances won't have raft_socket_address or coord_socket_address
+// populated; at the end, all instances will have everything.
+struct InstanceStatus {
   std::string instance_name;
-  std::string socket_address;
-  std::string replication_role;
+  std::string raft_socket_address;
+  std::string coord_socket_address;
+  std::string cluster_role;
   bool is_alive;
 };
 
diff --git a/src/coordination/include/coordination/replication_instance.hpp b/src/coordination/include/coordination/replication_instance.hpp
new file mode 100644
index 000000000..9d4765b47
--- /dev/null
+++ b/src/coordination/include/coordination/replication_instance.hpp
@@ -0,0 +1,84 @@
+// Copyright 2024 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#pragma once
+
+#ifdef MG_ENTERPRISE
+
+#include "coordination/coordinator_client.hpp"
+#include "coordination/coordinator_cluster_config.hpp"
+#include "coordination/coordinator_exceptions.hpp"
+#include "replication_coordination_glue/role.hpp"
+
+#include <libnuraft/nuraft.hxx>
+#include "utils/uuid.hpp"
+
+namespace memgraph::coordination {
+
+class CoordinatorData;
+
+class ReplicationInstance {
+ public:
+  ReplicationInstance(CoordinatorData *data, CoordinatorClientConfig config, HealthCheckCallback succ_cb,
+                      HealthCheckCallback fail_cb);
+
+  ReplicationInstance(ReplicationInstance const &other) = delete;
+  ReplicationInstance &operator=(ReplicationInstance const &other) = delete;
+  ReplicationInstance(ReplicationInstance &&other) noexcept = delete;
+  ReplicationInstance &operator=(ReplicationInstance &&other) noexcept = delete;
+  ~ReplicationInstance() = default;
+
+  auto OnSuccessPing() -> void;
+  auto OnFailPing() -> bool;
+
+  auto IsAlive() const -> bool;
+
+  auto InstanceName() const -> std::string;
+  auto SocketAddress() const -> std::string;
+
+  auto IsReplica() const -> bool;
+  auto IsMain() const -> bool;
+
+  auto PromoteToMain(utils::UUID uuid, ReplicationClientsInfo repl_clients_info, HealthCheckCallback main_succ_cb,
+                     HealthCheckCallback main_fail_cb) -> bool;
+  auto DemoteToReplica(HealthCheckCallback replica_succ_cb, HealthCheckCallback replica_fail_cb) -> bool;
+
+  auto PauseFrequentCheck() -> void;
+  auto ResumeFrequentCheck() -> void;
+
+  auto ReplicationClientInfo() const -> ReplClientInfo;
+
+  auto SendSwapAndUpdateUUID(const utils::UUID &main_uuid) -> bool;
+  auto GetClient() -> CoordinatorClient &;
+
+  void SetNewMainUUID(const std::optional<utils::UUID> &main_uuid = std::nullopt);
+  auto GetMainUUID() -> const std::optional<utils::UUID> &;
+
+ private:
+  CoordinatorClient client_;
+  replication_coordination_glue::ReplicationRole replication_role_;
+  std::chrono::system_clock::time_point last_response_time_{};
+  bool is_alive_{false};
+
+  // For a replica this is the UUID of the current main.
+  // For the main itself this is the same as the one stored in CoordinatorData.
+  // It is set to nullopt while the replica is down.
+  // TL;DR: when a replica goes down and comes back up, the UUID of the main it listens to is reset,
+  // so we need to send the swap-UUID request again.
+  std::optional<utils::UUID> main_uuid_;
+
+  friend bool operator==(ReplicationInstance const &first, ReplicationInstance const &second) {
+    return first.client_ == second.client_ && first.replication_role_ == second.replication_role_;
+  }
+};
+
+}  // namespace memgraph::coordination
+#endif
diff --git a/src/coordination/include/nuraft/coordinator_log_store.hpp b/src/coordination/include/nuraft/coordinator_log_store.hpp
new file mode 100644
index 000000000..ce1695d2f
--- /dev/null
+++ b/src/coordination/include/nuraft/coordinator_log_store.hpp
@@ -0,0 +1,128 @@
+// Copyright 2024 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+ +#pragma once + +#ifdef MG_ENTERPRISE + +#include <libnuraft/nuraft.hxx> + +namespace memgraph::coordination { + +using nuraft::buffer; +using nuraft::int32; +using nuraft::int64; +using nuraft::log_entry; +using nuraft::log_store; +using nuraft::ptr; +using nuraft::raft_server; + +class CoordinatorLogStore : public log_store { + public: + CoordinatorLogStore(); + CoordinatorLogStore(CoordinatorLogStore const &) = delete; + CoordinatorLogStore &operator=(CoordinatorLogStore const &) = delete; + CoordinatorLogStore(CoordinatorLogStore &&) = delete; + CoordinatorLogStore &operator=(CoordinatorLogStore &&) = delete; + ~CoordinatorLogStore() override; + + ulong next_slot() const override; + + ulong start_index() const override; + + ptr<log_entry> last_entry() const override; + + ulong append(ptr<log_entry> &entry) override; + + void write_at(ulong index, ptr<log_entry> &entry) override; + + ptr<std::vector<ptr<log_entry>>> log_entries(ulong start, ulong end) override; + + // NOLINTNEXTLINE + ptr<std::vector<ptr<log_entry>>> log_entries_ext(ulong start, ulong end, int64 batch_size_hint_in_bytes = 0) override; + + ptr<log_entry> entry_at(ulong index) override; + + ulong term_at(ulong index) override; + + ptr<buffer> pack(ulong index, int32 cnt) override; + + void apply_pack(ulong index, buffer &pack) override; + + bool compact(ulong last_log_index) override; + + bool flush() override; + + ulong last_durable_index() override; + + void Close(); + + void SetDiskDelay(raft_server *raft, size_t delay_ms); + + private: + static ptr<log_entry> MakeClone(ptr<log_entry> const &entry); + + void DiskEmulLoop(); + + /** + * Map of <log index, log data>. + */ + std::map<ulong, ptr<log_entry>> logs_; + + /** + * Lock for `logs_`. + */ + mutable std::mutex logs_lock_; + + /** + * The index of the first log. + */ + std::atomic<ulong> start_idx_; + + /** + * Backward pointer to Raft server. + */ + raft_server *raft_server_bwd_pointer_; + + // Testing purpose --------------- BEGIN + + /** + * If non-zero, this log store will emulate the disk write delay. + */ + std::atomic<size_t> disk_emul_delay; + + /** + * Map of <timestamp, log index>, emulating logs that is being written to disk. + * Log index will be regarded as "durable" after the corresponding timestamp. + */ + std::map<uint64_t, uint64_t> disk_emul_logs_being_written_; + + /** + * Thread that will update `last_durable_index_` and call + * `notify_log_append_completion` at proper time. + */ + std::unique_ptr<std::thread> disk_emul_thread_; + + /** + * Flag to terminate the thread. + */ + std::atomic<bool> disk_emul_thread_stop_signal_; + + /** + * Last written log index. + */ + std::atomic<uint64_t> disk_emul_last_durable_index_; + + // Testing purpose --------------- END +}; + +} // namespace memgraph::coordination +#endif diff --git a/src/coordination/include/nuraft/coordinator_state_machine.hpp b/src/coordination/include/nuraft/coordinator_state_machine.hpp new file mode 100644 index 000000000..fd7e92401 --- /dev/null +++ b/src/coordination/include/nuraft/coordinator_state_machine.hpp @@ -0,0 +1,72 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#ifdef MG_ENTERPRISE + +#include <spdlog/spdlog.h> +#include <libnuraft/nuraft.hxx> + +namespace memgraph::coordination { + +using nuraft::async_result; +using nuraft::buffer; +using nuraft::buffer_serializer; +using nuraft::cluster_config; +using nuraft::int32; +using nuraft::ptr; +using nuraft::snapshot; +using nuraft::state_machine; + +class CoordinatorStateMachine : public state_machine { + public: + CoordinatorStateMachine() = default; + CoordinatorStateMachine(CoordinatorStateMachine const &) = delete; + CoordinatorStateMachine &operator=(CoordinatorStateMachine const &) = delete; + CoordinatorStateMachine(CoordinatorStateMachine &&) = delete; + CoordinatorStateMachine &operator=(CoordinatorStateMachine &&) = delete; + ~CoordinatorStateMachine() override {} + + auto pre_commit(ulong log_idx, buffer &data) -> ptr<buffer> override; + + auto commit(ulong log_idx, buffer &data) -> ptr<buffer> override; + + auto commit_config(ulong log_idx, ptr<cluster_config> & /*new_conf*/) -> void override; + + auto rollback(ulong log_idx, buffer &data) -> void override; + + auto read_logical_snp_obj(snapshot & /*snapshot*/, void *& /*user_snp_ctx*/, ulong /*obj_id*/, ptr<buffer> &data_out, + bool &is_last_obj) -> int override; + + auto save_logical_snp_obj(snapshot &s, ulong &obj_id, buffer & /*data*/, bool /*is_first_obj*/, bool /*is_last_obj*/) + -> void override; + + auto apply_snapshot(snapshot &s) -> bool override; + + auto free_user_snp_ctx(void *&user_snp_ctx) -> void override; + + auto last_snapshot() -> ptr<snapshot> override; + + auto last_commit_index() -> ulong override; + + auto create_snapshot(snapshot &s, async_result<bool>::handler_type &when_done) -> void override; + + private: + std::atomic<uint64_t> last_committed_idx_{0}; + + ptr<snapshot> last_snapshot_; + + std::mutex last_snapshot_lock_; +}; + +} // namespace memgraph::coordination +#endif diff --git a/src/coordination/include/nuraft/coordinator_state_manager.hpp b/src/coordination/include/nuraft/coordinator_state_manager.hpp new file mode 100644 index 000000000..b6cb6599b --- /dev/null +++ b/src/coordination/include/nuraft/coordinator_state_manager.hpp @@ -0,0 +1,66 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#ifdef MG_ENTERPRISE + +#include "nuraft/coordinator_log_store.hpp" + +#include <spdlog/spdlog.h> +#include <libnuraft/nuraft.hxx> + +namespace memgraph::coordination { + +using nuraft::cluster_config; +using nuraft::cs_new; +using nuraft::srv_config; +using nuraft::srv_state; +using nuraft::state_mgr; + +class CoordinatorStateManager : public state_mgr { + public: + explicit CoordinatorStateManager(int srv_id, std::string const &endpoint); + + CoordinatorStateManager(CoordinatorStateManager const &) = delete; + CoordinatorStateManager &operator=(CoordinatorStateManager const &) = delete; + CoordinatorStateManager(CoordinatorStateManager &&) = delete; + CoordinatorStateManager &operator=(CoordinatorStateManager &&) = delete; + + ~CoordinatorStateManager() override = default; + + auto load_config() -> ptr<cluster_config> override; + + auto save_config(cluster_config const &config) -> void override; + + auto save_state(srv_state const &state) -> void override; + + auto read_state() -> ptr<srv_state> override; + + auto load_log_store() -> ptr<log_store> override; + + auto server_id() -> int32 override; + + auto system_exit(int exit_code) -> void override; + + auto GetSrvConfig() const -> ptr<srv_config>; + + private: + int my_id_; + std::string my_endpoint_; + ptr<CoordinatorLogStore> cur_log_store_; + ptr<srv_config> my_srv_config_; + ptr<cluster_config> cluster_config_; + ptr<srv_state> saved_state_; +}; + +} // namespace memgraph::coordination +#endif diff --git a/src/coordination/replication_instance.cpp b/src/coordination/replication_instance.cpp new file mode 100644 index 000000000..96a5c2a0e --- /dev/null +++ b/src/coordination/replication_instance.cpp @@ -0,0 +1,98 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+
+#ifdef MG_ENTERPRISE
+
+#include "coordination/replication_instance.hpp"
+
+#include "replication_coordination_glue/handler.hpp"
+
+namespace memgraph::coordination {
+
+ReplicationInstance::ReplicationInstance(CoordinatorData *data, CoordinatorClientConfig config,
+                                         HealthCheckCallback succ_cb, HealthCheckCallback fail_cb)
+    : client_(data, std::move(config), std::move(succ_cb), std::move(fail_cb)),
+      replication_role_(replication_coordination_glue::ReplicationRole::REPLICA),
+      is_alive_(true) {
+  if (!client_.DemoteToReplica()) {
+    throw CoordinatorRegisterInstanceException("Failed to demote instance {} to replica", client_.InstanceName());
+  }
+  client_.StartFrequentCheck();
+}
+
+auto ReplicationInstance::OnSuccessPing() -> void {
+  last_response_time_ = std::chrono::system_clock::now();
+  is_alive_ = true;
+}
+
+auto ReplicationInstance::OnFailPing() -> bool {
+  is_alive_ =
+      std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now() - last_response_time_).count() <
+      CoordinatorClusterConfig::alive_response_time_difference_sec_;
+  return is_alive_;
+}
+
+auto ReplicationInstance::InstanceName() const -> std::string { return client_.InstanceName(); }
+auto ReplicationInstance::SocketAddress() const -> std::string { return client_.SocketAddress(); }
+auto ReplicationInstance::IsAlive() const -> bool { return is_alive_; }
+
+auto ReplicationInstance::IsReplica() const -> bool {
+  return replication_role_ == replication_coordination_glue::ReplicationRole::REPLICA;
+}
+auto ReplicationInstance::IsMain() const -> bool {
+  return replication_role_ == replication_coordination_glue::ReplicationRole::MAIN;
+}
+
+auto ReplicationInstance::PromoteToMain(utils::UUID uuid, ReplicationClientsInfo repl_clients_info,
+                                        HealthCheckCallback main_succ_cb, HealthCheckCallback main_fail_cb) -> bool {
+  if (!client_.SendPromoteReplicaToMainRpc(uuid, std::move(repl_clients_info))) {
+    return false;
+  }
+
+  replication_role_ = replication_coordination_glue::ReplicationRole::MAIN;
+  client_.SetCallbacks(std::move(main_succ_cb), std::move(main_fail_cb));
+
+  return true;
+}
+
+auto ReplicationInstance::DemoteToReplica(HealthCheckCallback replica_succ_cb, HealthCheckCallback replica_fail_cb)
+    -> bool {
+  if (!client_.DemoteToReplica()) {
+    return false;
+  }
+
+  replication_role_ = replication_coordination_glue::ReplicationRole::REPLICA;
+  client_.SetCallbacks(std::move(replica_succ_cb), std::move(replica_fail_cb));
+
+  return true;
+}
+
+auto ReplicationInstance::PauseFrequentCheck() -> void { client_.PauseFrequentCheck(); }
+auto ReplicationInstance::ResumeFrequentCheck() -> void { client_.ResumeFrequentCheck(); }
+
+auto ReplicationInstance::ReplicationClientInfo() const -> CoordinatorClientConfig::ReplicationClientInfo {
+  return client_.ReplicationClientInfo();
+}
+
+auto ReplicationInstance::GetClient() -> CoordinatorClient & { return client_; }
+void ReplicationInstance::SetNewMainUUID(const std::optional<utils::UUID> &main_uuid) { main_uuid_ = main_uuid; }
+auto ReplicationInstance::GetMainUUID() -> const std::optional<utils::UUID> & { return main_uuid_; }
+
+auto ReplicationInstance::SendSwapAndUpdateUUID(const utils::UUID &main_uuid) -> bool {
+  if (!replication_coordination_glue::SendSwapMainUUIDRpc(client_.RpcClient(), main_uuid)) {
+    return false;
+  }
+  SetNewMainUUID(main_uuid);
+  return true;
+}
+
+}  // namespace memgraph::coordination
+#endif
diff --git a/src/dbms/coordinator_handler.cpp b/src/dbms/coordinator_handler.cpp
index 958de0f91..d1310dee2 100644
---
a/src/dbms/coordinator_handler.cpp +++ b/src/dbms/coordinator_handler.cpp @@ -9,12 +9,11 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -#include "coordination/register_main_replica_coordinator_status.hpp" #ifdef MG_ENTERPRISE #include "dbms/coordinator_handler.hpp" -#include "dbms/dbms_handler.hpp" +#include "coordination/register_main_replica_coordinator_status.hpp" namespace memgraph::dbms { @@ -31,9 +30,15 @@ auto CoordinatorHandler::SetInstanceToMain(std::string instance_name) return coordinator_state_.SetInstanceToMain(std::move(instance_name)); } -auto CoordinatorHandler::ShowInstances() const -> std::vector<coordination::CoordinatorInstanceStatus> { +auto CoordinatorHandler::ShowInstances() const -> std::vector<coordination::InstanceStatus> { return coordinator_state_.ShowInstances(); } + +auto CoordinatorHandler::AddCoordinatorInstance(uint32_t raft_server_id, uint32_t raft_port, std::string raft_address) + -> void { + coordinator_state_.AddCoordinatorInstance(raft_server_id, raft_port, std::move(raft_address)); +} + } // namespace memgraph::dbms #endif diff --git a/src/dbms/coordinator_handler.hpp b/src/dbms/coordinator_handler.hpp index 04cfe8032..a2a1f19dc 100644 --- a/src/dbms/coordinator_handler.hpp +++ b/src/dbms/coordinator_handler.hpp @@ -14,8 +14,8 @@ #ifdef MG_ENTERPRISE #include "coordination/coordinator_config.hpp" -#include "coordination/coordinator_instance_status.hpp" #include "coordination/coordinator_state.hpp" +#include "coordination/instance_status.hpp" #include "coordination/register_main_replica_coordinator_status.hpp" #include <vector> @@ -33,7 +33,9 @@ class CoordinatorHandler { auto SetInstanceToMain(std::string instance_name) -> coordination::SetInstanceToMainCoordinatorStatus; - auto ShowInstances() const -> std::vector<coordination::CoordinatorInstanceStatus>; + auto ShowInstances() const -> std::vector<coordination::InstanceStatus>; + + auto AddCoordinatorInstance(uint32_t raft_server_id, uint32_t raft_port, std::string raft_address) -> void; private: coordination::CoordinatorState &coordinator_state_; diff --git a/src/flags/replication.cpp b/src/flags/replication.cpp index 3cd5187f3..29c7bfbda 100644 --- a/src/flags/replication.cpp +++ b/src/flags/replication.cpp @@ -13,9 +13,11 @@ #ifdef MG_ENTERPRISE // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -DEFINE_bool(coordinator, false, "Controls whether the instance is a replication coordinator."); -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_uint32(coordinator_server_port, 0, "Port on which coordinator servers will be started."); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_uint32(raft_server_port, 0, "Port on which raft servers will be started."); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_uint32(raft_server_id, 0, "Unique ID of the raft server."); #endif // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/src/flags/replication.hpp b/src/flags/replication.hpp index 16f4c74d2..025079271 100644 --- a/src/flags/replication.hpp +++ b/src/flags/replication.hpp @@ -15,9 +15,11 @@ #ifdef MG_ENTERPRISE // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -DECLARE_bool(coordinator); -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DECLARE_uint32(coordinator_server_port); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DECLARE_uint32(raft_server_port); +// 
NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DECLARE_uint32(raft_server_id); #endif // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/src/io/network/endpoint.cpp b/src/io/network/endpoint.cpp index e9032e42e..44123db6b 100644 --- a/src/io/network/endpoint.cpp +++ b/src/io/network/endpoint.cpp @@ -39,7 +39,7 @@ Endpoint::IpFamily Endpoint::GetIpFamily(const std::string &address) { } std::optional<std::pair<std::string, uint16_t>> Endpoint::ParseSocketOrIpAddress( - const std::string &address, const std::optional<uint16_t> default_port = {}) { + const std::string &address, const std::optional<uint16_t> default_port) { /// expected address format: /// - "ip_address:port_number" /// - "ip_address" diff --git a/src/io/network/endpoint.hpp b/src/io/network/endpoint.hpp index 281be2162..16d70e080 100644 --- a/src/io/network/endpoint.hpp +++ b/src/io/network/endpoint.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -61,8 +61,8 @@ struct Endpoint { * it into an ip address and a port number; even if a default port is given, * it won't be used, as we expect that it is given in the address string. */ - static std::optional<std::pair<std::string, uint16_t>> ParseSocketOrIpAddress(const std::string &address, - std::optional<uint16_t> default_port); + static std::optional<std::pair<std::string, uint16_t>> ParseSocketOrIpAddress( + const std::string &address, std::optional<uint16_t> default_port = {}); /** * Tries to parse given string as either socket address or hostname. diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 0cbb790d0..6fe6b8c9e 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -3071,11 +3071,7 @@ class CoordinatorQuery : public memgraph::query::Query { static const utils::TypeInfo kType; const utils::TypeInfo &GetTypeInfo() const override { return kType; } - enum class Action { - REGISTER_INSTANCE, - SET_INSTANCE_TO_MAIN, - SHOW_REPLICATION_CLUSTER, - }; + enum class Action { REGISTER_INSTANCE, SET_INSTANCE_TO_MAIN, SHOW_INSTANCES, ADD_COORDINATOR_INSTANCE }; enum class SyncMode { SYNC, ASYNC }; @@ -3087,6 +3083,8 @@ class CoordinatorQuery : public memgraph::query::Query { std::string instance_name_; memgraph::query::Expression *replication_socket_address_{nullptr}; memgraph::query::Expression *coordinator_socket_address_{nullptr}; + memgraph::query::Expression *raft_socket_address_{nullptr}; + memgraph::query::Expression *raft_server_id_{nullptr}; memgraph::query::CoordinatorQuery::SyncMode sync_mode_; CoordinatorQuery *Clone(AstStorage *storage) const override { @@ -3098,6 +3096,8 @@ class CoordinatorQuery : public memgraph::query::Query { object->sync_mode_ = sync_mode_; object->coordinator_socket_address_ = coordinator_socket_address_ ? coordinator_socket_address_->Clone(storage) : nullptr; + object->raft_socket_address_ = raft_socket_address_ ? raft_socket_address_->Clone(storage) : nullptr; + object->raft_server_id_ = raft_server_id_ ? 
raft_server_id_->Clone(storage) : nullptr; return object; } diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 5735326ac..1de5e55ff 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -374,7 +374,6 @@ antlrcpp::Any CypherMainVisitor::visitRegisterReplica(MemgraphCypher::RegisterRe return replication_query; } -// License check is done in the interpreter. antlrcpp::Any CypherMainVisitor::visitRegisterInstanceOnCoordinator( MemgraphCypher::RegisterInstanceOnCoordinatorContext *ctx) { auto *coordinator_query = storage_->Create<CoordinatorQuery>(); @@ -400,10 +399,28 @@ antlrcpp::Any CypherMainVisitor::visitRegisterInstanceOnCoordinator( return coordinator_query; } -// License check is done in the interpreter -antlrcpp::Any CypherMainVisitor::visitShowReplicationCluster(MemgraphCypher::ShowReplicationClusterContext * /*ctx*/) { +antlrcpp::Any CypherMainVisitor::visitAddCoordinatorInstance(MemgraphCypher::AddCoordinatorInstanceContext *ctx) { auto *coordinator_query = storage_->Create<CoordinatorQuery>(); - coordinator_query->action_ = CoordinatorQuery::Action::SHOW_REPLICATION_CLUSTER; + + if (!ctx->raftSocketAddress()->literal()->StringLiteral()) { + throw SemanticException("Raft socket address should be a string literal!"); + } + + if (!ctx->raftServerId()->literal()->numberLiteral()) { + throw SemanticException("Raft server id should be a number literal!"); + } + + coordinator_query->action_ = CoordinatorQuery::Action::ADD_COORDINATOR_INSTANCE; + coordinator_query->raft_socket_address_ = std::any_cast<Expression *>(ctx->raftSocketAddress()->accept(this)); + coordinator_query->raft_server_id_ = std::any_cast<Expression *>(ctx->raftServerId()->accept(this)); + + return coordinator_query; +} + +// License check is done in the interpreter +antlrcpp::Any CypherMainVisitor::visitShowInstances(MemgraphCypher::ShowInstancesContext * /*ctx*/) { + auto *coordinator_query = storage_->Create<CoordinatorQuery>(); + coordinator_query->action_ = CoordinatorQuery::Action::SHOW_INSTANCES; return coordinator_query; } diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index e9da98f71..9007ec60a 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -251,7 +251,12 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { /** * @return CoordinatorQuery* */ - antlrcpp::Any visitShowReplicationCluster(MemgraphCypher::ShowReplicationClusterContext *ctx) override; + antlrcpp::Any visitAddCoordinatorInstance(MemgraphCypher::AddCoordinatorInstanceContext *ctx) override; + + /** + * @return CoordinatorQuery* + */ + antlrcpp::Any visitShowInstances(MemgraphCypher::ShowInstancesContext *ctx) override; /** * @return LockPathQuery* diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index a99eda3e9..0597967c7 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -20,6 +20,7 @@ options { tokenVocab=MemgraphCypherLexer; } import Cypher ; memgraphCypherKeyword : cypherKeyword + | ADD | ACTIVE | AFTER | ALTER @@ -64,6 +65,7 @@ memgraphCypherKeyword : cypherKeyword | HEADER | IDENTIFIED | INSTANCE + | INSTANCES | NODE_LABELS | NULLIF | IMPORT @@ -189,7 +191,8 @@ replicationQuery : setReplicationRole 
coordinatorQuery : registerInstanceOnCoordinator | setInstanceToMain - | showReplicationCluster + | showInstances + | addCoordinatorInstance ; triggerQuery : createTrigger @@ -374,7 +377,7 @@ setReplicationRole : SET REPLICATION ROLE TO ( MAIN | REPLICA ) showReplicationRole : SHOW REPLICATION ROLE ; -showReplicationCluster : SHOW REPLICATION CLUSTER ; +showInstances : SHOW INSTANCES ; instanceName : symbolicName ; @@ -382,6 +385,7 @@ socketAddress : literal ; coordinatorSocketAddress : literal ; replicationSocketAddress : literal ; +raftSocketAddress : literal ; registerReplica : REGISTER REPLICA instanceName ( SYNC | ASYNC ) TO socketAddress ; @@ -390,6 +394,10 @@ registerInstanceOnCoordinator : REGISTER INSTANCE instanceName ON coordinatorSoc setInstanceToMain : SET INSTANCE instanceName TO MAIN ; +raftServerId : literal ; + +addCoordinatorInstance : ADD COORDINATOR raftServerId ON raftSocketAddress ; + dropReplica : DROP REPLICA instanceName ; showReplicas : SHOW REPLICAS ; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index b0febc4af..b2d4de661 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -23,6 +23,7 @@ lexer grammar MemgraphCypherLexer ; import CypherLexer ; +ADD : A D D ; ACTIVE : A C T I V E ; AFTER : A F T E R ; ALTER : A L T E R ; @@ -39,7 +40,6 @@ BOOTSTRAP_SERVERS : B O O T S T R A P UNDERSCORE S E R V E R S ; CALL : C A L L ; CHECK : C H E C K ; CLEAR : C L E A R ; -CLUSTER : C L U S T E R ; COMMIT : C O M M I T ; COMMITTED : C O M M I T T E D ; CONFIG : C O N F I G ; @@ -80,6 +80,7 @@ INACTIVE : I N A C T I V E ; IN_MEMORY_ANALYTICAL : I N UNDERSCORE M E M O R Y UNDERSCORE A N A L Y T I C A L ; IN_MEMORY_TRANSACTIONAL : I N UNDERSCORE M E M O R Y UNDERSCORE T R A N S A C T I O N A L ; INSTANCE : I N S T A N C E ; +INSTANCES : I N S T A N C E S ; ISOLATION : I S O L A T I O N ; KAFKA : K A F K A ; LABELS : L A B E L S ; diff --git a/src/query/frontend/stripped_lexer_constants.hpp b/src/query/frontend/stripped_lexer_constants.hpp index bd6ab7971..17583153b 100644 --- a/src/query/frontend/stripped_lexer_constants.hpp +++ b/src/query/frontend/stripped_lexer_constants.hpp @@ -219,7 +219,8 @@ const trie::Trie kKeywords = {"union", "lock", "unlock", "build", - "instance"}; + "instance", + "coordinator"}; // Unicode codepoints that are allowed at the start of the unescaped name. 
const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts( diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index ea175a18e..2ddb8dd2a 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -508,6 +508,17 @@ class CoordQueryHandler final : public query::CoordinatorQueryHandler { } } + auto AddCoordinatorInstance(uint32_t raft_server_id, std::string const &raft_socket_address) -> void override { + auto const maybe_ip_and_port = io::network::Endpoint::ParseSocketOrIpAddress(raft_socket_address); + if (maybe_ip_and_port) { + auto const [ip, port] = *maybe_ip_and_port; + spdlog::info("Adding coordinator instance {} with raft socket address {}:{}.", raft_server_id, ip, port); + coordinator_handler_.AddCoordinatorInstance(raft_server_id, port, ip); + } else { + throw QueryRuntimeException("Invalid raft socket address {}.", raft_socket_address); + } + } + void SetInstanceToMain(const std::string &instance_name) override { auto status = coordinator_handler_.SetInstanceToMain(instance_name); switch (status) { @@ -526,7 +537,7 @@ class CoordQueryHandler final : public query::CoordinatorQueryHandler { } } - std::vector<coordination::CoordinatorInstanceStatus> ShowInstances() const override { + std::vector<coordination::InstanceStatus> ShowInstances() const override { return coordinator_handler_.ShowInstances(); } @@ -930,7 +941,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & switch (repl_query->action_) { case ReplicationQuery::Action::SET_REPLICATION_ROLE: { #ifdef MG_ENTERPRISE - if (FLAGS_coordinator) { + if (FLAGS_raft_server_id) { throw QueryRuntimeException("Coordinator can't set roles!"); } if (FLAGS_coordinator_server_port) { @@ -960,7 +971,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & } case ReplicationQuery::Action::SHOW_REPLICATION_ROLE: { #ifdef MG_ENTERPRISE - if (FLAGS_coordinator) { + if (FLAGS_raft_server_id) { throw QueryRuntimeException("Coordinator doesn't have a replication role!"); } #endif @@ -1017,8 +1028,8 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & } case ReplicationQuery::Action::SHOW_REPLICAS: { #ifdef MG_ENTERPRISE - if (FLAGS_coordinator) { - throw QueryRuntimeException("Coordinator cannot call SHOW REPLICAS! Use SHOW REPLICATION CLUSTER instead."); + if (FLAGS_raft_server_id) { + throw QueryRuntimeException("Coordinator cannot call SHOW REPLICAS! Use SHOW INSTANCES instead."); } #endif @@ -1079,6 +1090,37 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param const query::InterpreterConfig &config, std::vector<Notification> *notifications) { Callback callback; switch (coordinator_query->action_) { + case CoordinatorQuery::Action::ADD_COORDINATOR_INSTANCE: { + if (!license::global_license_checker.IsEnterpriseValidFast()) { + throw QueryException("Trying to use enterprise feature without a valid license."); + } + if constexpr (!coordination::allow_ha) { + throw QueryRuntimeException( + "High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to " + "be able to use this functionality."); + } + if (!FLAGS_raft_server_id) { + throw QueryRuntimeException("Only coordinator can add coordinator instance!"); + } + + // TODO: MemoryResource for EvaluationContext, it should probably be passed as + // the argument to Callback.
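AddCoordinatorInstance above relies on Endpoint::ParseSocketOrIpAddress to split the user-supplied raft socket address into an IP and a port before forwarding them to the coordinator handler. A simplified stand-in for that parse, under the assumption that only "ip:port" and bare-address-plus-default-port forms need handling (the real implementation also validates the IP family):

    #include <cstdint>
    #include <optional>
    #include <stdexcept>
    #include <string>
    #include <utility>

    // Simplified stand-in for Endpoint::ParseSocketOrIpAddress; not the real code.
    std::optional<std::pair<std::string, uint16_t>> ParseAddress(
        const std::string &address, std::optional<uint16_t> default_port = {}) {
      auto const colon = address.rfind(':');
      if (colon == std::string::npos) {
        // Bare address: usable only when the caller supplied a default port.
        if (!default_port) return std::nullopt;
        return std::make_pair(address, *default_port);
      }
      try {
        int const port = std::stoi(address.substr(colon + 1));
        if (port < 0 || port > 65535) return std::nullopt;
        return std::make_pair(address.substr(0, colon), static_cast<uint16_t>(port));
      } catch (const std::exception &) {
        return std::nullopt;
      }
    }

For example, ParseAddress("127.0.0.1:10111") yields {"127.0.0.1", 10111}, which is the shape the ADD COORDINATOR tests below depend on.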
+ EvaluationContext evaluation_context{.timestamp = QueryTimestamp(), .parameters = parameters}; + auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context}; + + auto raft_socket_address_tv = coordinator_query->raft_socket_address_->Accept(evaluator); + auto raft_server_id_tv = coordinator_query->raft_server_id_->Accept(evaluator); + callback.fn = [handler = CoordQueryHandler{*coordinator_state}, raft_socket_address_tv, + raft_server_id_tv]() mutable { + handler.AddCoordinatorInstance(raft_server_id_tv.ValueInt(), std::string(raft_socket_address_tv.ValueString())); + return std::vector<std::vector<TypedValue>>(); + }; + + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::ADD_COORDINATOR_INSTANCE, + fmt::format("Coordinator has added coordinator instance {} on socket address {}.", + raft_server_id_tv.ValueInt(), raft_socket_address_tv.ValueString())); + return callback; + } case CoordinatorQuery::Action::REGISTER_INSTANCE: { if (!license::global_license_checker.IsEnterpriseValidFast()) { throw QueryException("Trying to use enterprise feature without a valid license."); @@ -1089,7 +1131,7 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param "High availability is experimental feature. 
Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to " "be able to use this functionality."); } - if (!FLAGS_coordinator) { - throw QueryRuntimeException("Only coordinator can run SHOW REPLICATION CLUSTER."); + if (!FLAGS_raft_server_id) { + throw QueryRuntimeException("Only coordinator can run SHOW INSTANCES."); } - callback.header = {"name", "socket_address", "alive", "role"}; + callback.header = {"name", "raft_socket_address", "coordinator_socket_address", "alive", "role"}; callback.fn = [handler = CoordQueryHandler{*coordinator_state}, replica_nfields = callback.header.size()]() mutable { auto const instances = handler.ShowInstances(); @@ -1162,15 +1204,15 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param std::ranges::transform(instances, std::back_inserter(result), [](const auto &status) -> std::vector<TypedValue> { - return {TypedValue{status.instance_name}, TypedValue{status.socket_address}, - TypedValue{status.is_alive}, TypedValue{status.replication_role}}; + return {TypedValue{status.instance_name}, TypedValue{status.raft_socket_address}, + TypedValue{status.coord_socket_address}, TypedValue{status.is_alive}, + TypedValue{status.cluster_role}}; }); return result; }; return callback; } - return callback; } } #endif @@ -4175,7 +4217,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, } #ifdef MG_ENTERPRISE - if (FLAGS_coordinator && !utils::Downcast<CoordinatorQuery>(parsed_query.query) && + if (FLAGS_raft_server_id && !utils::Downcast<CoordinatorQuery>(parsed_query.query) && !utils::Downcast<SettingQuery>(parsed_query.query)) { throw QueryRuntimeException("Coordinator can run only coordinator queries!"); } diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index cf822d8b9..698c639fa 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -53,7 +53,7 @@ #include "utils/tsc.hpp" #ifdef MG_ENTERPRISE -#include "coordination/coordinator_instance_status.hpp" +#include "coordination/instance_status.hpp" #endif namespace memgraph::metrics { @@ -114,7 +114,11 @@ class CoordinatorQueryHandler { virtual void SetInstanceToMain(const std::string &instance_name) = 0; /// @throw QueryRuntimeException if an error ocurred. - virtual std::vector<coordination::CoordinatorInstanceStatus> ShowInstances() const = 0; + virtual std::vector<coordination::InstanceStatus> ShowInstances() const = 0; + + /// @throw QueryRuntimeException if an error occurred. 
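A note on the callback pattern used throughout HandleCoordinatorQuery above: the literal arguments are evaluated once with PrimitiveLiteralExpressionEvaluator, and the resulting TypedValues are captured by value in callback.fn, so the deferred callback stays self-contained after the prepared query (and the AST that the Expression pointers reference) has been torn down. A minimal sketch of that ownership pattern, with illustrative names that are not part of the codebase:

    #include <cstdint>
    #include <functional>
    #include <string>
    #include <utility>

    // Build a deferred action that owns copies of its (already evaluated)
    // arguments instead of holding pointers into the query AST.
    std::function<void()> MakeAddCoordinatorCallback(
        std::int64_t raft_server_id, std::string raft_socket_address,
        std::function<void(std::int64_t, const std::string &)> handler) {
      return [raft_server_id, raft_socket_address = std::move(raft_socket_address),
              handler = std::move(handler)]() {
        handler(raft_server_id, raft_socket_address);
      };
    }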
+ virtual auto AddCoordinatorInstance(uint32_t raft_server_id, std::string const &coordinator_socket_address) + -> void = 0; }; #endif diff --git a/src/query/metadata.cpp b/src/query/metadata.cpp index 56ef57431..59d65e077 100644 --- a/src/query/metadata.cpp +++ b/src/query/metadata.cpp @@ -69,6 +69,8 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { #ifdef MG_ENTERPRISE case NotificationCode::REGISTER_COORDINATOR_SERVER: return "RegisterCoordinatorServer"sv; + case NotificationCode::ADD_COORDINATOR_INSTANCE: + return "AddCoordinatorInstance"sv; #endif case NotificationCode::REPLICA_PORT_WARNING: return "ReplicaPortWarning"sv; diff --git a/src/query/metadata.hpp b/src/query/metadata.hpp index 8e82ad1e3..2f357a555 100644 --- a/src/query/metadata.hpp +++ b/src/query/metadata.hpp @@ -44,6 +44,7 @@ enum class NotificationCode : uint8_t { REGISTER_REPLICA, #ifdef MG_ENTERPRISE REGISTER_COORDINATOR_SERVER, + ADD_COORDINATOR_INSTANCE, #endif SET_REPLICA, START_STREAM, diff --git a/tests/e2e/configuration/default_config.py b/tests/e2e/configuration/default_config.py index 915a14d14..e0cdc082c 100644 --- a/tests/e2e/configuration/default_config.py +++ b/tests/e2e/configuration/default_config.py @@ -66,8 +66,9 @@ startup_config_dict = { "Time in seconds after which inactive Bolt sessions will be closed.", ), "cartesian_product_enabled": ("true", "true", "Enable cartesian product expansion."), - "coordinator": ("false", "false", "Controls whether the instance is a replication coordinator."), "coordinator_server_port": ("0", "0", "Port on which coordinator servers will be started."), + "raft_server_port": ("0", "0", "Port on which raft servers will be started."), + "raft_server_id": ("0", "0", "Unique ID of the raft server."), "data_directory": ("mg_data", "mg_data", "Path to directory in which to save all permanent data."), "data_recovery_on_startup": ( "false", diff --git a/tests/e2e/high_availability_experimental/CMakeLists.txt b/tests/e2e/high_availability_experimental/CMakeLists.txt index f22e24f43..424ebd08f 100644 --- a/tests/e2e/high_availability_experimental/CMakeLists.txt +++ b/tests/e2e/high_availability_experimental/CMakeLists.txt @@ -2,6 +2,7 @@ find_package(gflags REQUIRED) copy_e2e_python_files(ha_experimental coordinator.py) copy_e2e_python_files(ha_experimental automatic_failover.py) +copy_e2e_python_files(ha_experimental distributed_coordinators.py) copy_e2e_python_files(ha_experimental manual_setting_replicas.py) copy_e2e_python_files(ha_experimental not_replicate_from_old_main.py) copy_e2e_python_files(ha_experimental common.py) diff --git a/tests/e2e/high_availability_experimental/automatic_failover.py b/tests/e2e/high_availability_experimental/automatic_failover.py index 23b462f45..1148075a1 100644 --- a/tests/e2e/high_availability_experimental/automatic_failover.py +++ b/tests/e2e/high_availability_experimental/automatic_failover.py @@ -13,7 +13,6 @@ import os import shutil import sys import tempfile -import time import interactive_mg_runner import pytest @@ -70,7 +69,7 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { "setup_queries": [], }, "coordinator": { - "args": ["--bolt-port", "7690", "--log-level=TRACE", "--coordinator"], + "args": ["--bolt-port", "7690", "--log-level=TRACE", "--raft-server-id=1", "--raft-server-port=10111"], "log_file": "coordinator.log", "setup_queries": [ "REGISTER INSTANCE instance_1 ON '127.0.0.1:10011' WITH '127.0.0.1:10001';", @@ -111,12 +110,13 @@ def test_replication_works_on_failover(): coord_cursor = connect(host="localhost", 
port=7690).cursor() def retrieve_data_show_repl_cluster(): - return sorted(list(execute_and_fetch_all(coord_cursor, "SHOW REPLICATION CLUSTER;"))) + return sorted(list(execute_and_fetch_all(coord_cursor, "SHOW INSTANCES;"))) expected_data_on_coord = [ - ("instance_1", "127.0.0.1:10011", True, "main"), - ("instance_2", "127.0.0.1:10012", True, "replica"), - ("instance_3", "127.0.0.1:10013", False, "unknown"), + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", True, "main"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", False, "unknown"), ] mg_sleep_and_assert(expected_data_on_coord, retrieve_data_show_repl_cluster) @@ -132,7 +132,6 @@ def test_replication_works_on_failover(): mg_sleep_and_assert(expected_data_on_new_main, retrieve_data_show_replicas) interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") - expected_data_on_new_main = [ ("instance_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), ("instance_3", "127.0.0.1:10003", "sync", 0, 0, "ready"), @@ -143,13 +142,13 @@ def test_replication_works_on_failover(): execute_and_fetch_all(new_main_cursor, "CREATE ();") # 6 - alive_replica_cursor = connect(host="localhost", port=7689).cursor() - res = execute_and_fetch_all(alive_replica_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] + alive_replica_cursor = connect(host="localhost", port=7689).cursor() + res = execute_and_fetch_all(alive_replica_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] assert res == 1, "Vertex should be replicated" interactive_mg_runner.stop_all(MEMGRAPH_INSTANCES_DESCRIPTION) -def test_show_replication_cluster(): +def test_show_instances(): safe_execute(shutil.rmtree, TEMP_DIR) interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION) @@ -159,12 +158,13 @@ def test_show_instances(): coord_cursor = connect(host="localhost", port=7690).cursor() def show_repl_cluster(): - return sorted(list(execute_and_fetch_all(coord_cursor, "SHOW REPLICATION CLUSTER;"))) + return sorted(list(execute_and_fetch_all(coord_cursor, "SHOW INSTANCES;"))) expected_data = [ - ("instance_1", "127.0.0.1:10011", True, "replica"), - ("instance_2", "127.0.0.1:10012", True, "replica"), - ("instance_3", "127.0.0.1:10013", True, "main"), + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", True, "replica"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), ] mg_sleep_and_assert(expected_data, show_repl_cluster) @@ -184,18 +184,20 @@ def test_show_replication_cluster(): interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_1") expected_data = [ - ("instance_1", "127.0.0.1:10011", False, "unknown"), - ("instance_2", "127.0.0.1:10012", True, "replica"), - ("instance_3", "127.0.0.1:10013", True, "main"), + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", False, "unknown"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), ] mg_sleep_and_assert(expected_data, show_repl_cluster) interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_2") expected_data = [ - ("instance_1", "127.0.0.1:10011", False, "unknown"), - ("instance_2", "127.0.0.1:10012", False, "unknown"), - ("instance_3", "127.0.0.1:10013", True, "main"), + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", 
"127.0.0.1:10011", False, "unknown"), + ("instance_2", "", "127.0.0.1:10012", False, "unknown"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), ] mg_sleep_and_assert(expected_data, show_repl_cluster) @@ -217,12 +219,13 @@ def test_simple_automatic_failover(): coord_cursor = connect(host="localhost", port=7690).cursor() def retrieve_data_show_repl_cluster(): - return sorted(list(execute_and_fetch_all(coord_cursor, "SHOW REPLICATION CLUSTER;"))) + return sorted(list(execute_and_fetch_all(coord_cursor, "SHOW INSTANCES;"))) expected_data_on_coord = [ - ("instance_1", "127.0.0.1:10011", True, "main"), - ("instance_2", "127.0.0.1:10012", True, "replica"), - ("instance_3", "127.0.0.1:10013", False, "unknown"), + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", True, "main"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", False, "unknown"), ] mg_sleep_and_assert(expected_data_on_coord, retrieve_data_show_repl_cluster) @@ -280,21 +283,23 @@ def test_replica_instance_restarts(): cursor = connect(host="localhost", port=7690).cursor() def show_repl_cluster(): - return sorted(list(execute_and_fetch_all(cursor, "SHOW REPLICATION CLUSTER;"))) + return sorted(list(execute_and_fetch_all(cursor, "SHOW INSTANCES;"))) expected_data_up = [ - ("instance_1", "127.0.0.1:10011", True, "replica"), - ("instance_2", "127.0.0.1:10012", True, "replica"), - ("instance_3", "127.0.0.1:10013", True, "main"), + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", True, "replica"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), ] mg_sleep_and_assert(expected_data_up, show_repl_cluster) interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_1") expected_data_down = [ - ("instance_1", "127.0.0.1:10011", False, "unknown"), - ("instance_2", "127.0.0.1:10012", True, "replica"), - ("instance_3", "127.0.0.1:10013", True, "main"), + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", False, "unknown"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), ] mg_sleep_and_assert(expected_data_down, show_repl_cluster) @@ -320,19 +325,21 @@ def test_automatic_failover_main_back_as_replica(): coord_cursor = connect(host="localhost", port=7690).cursor() def retrieve_data_show_repl_cluster(): - return sorted(list(execute_and_fetch_all(coord_cursor, "SHOW REPLICATION CLUSTER;"))) + return sorted(list(execute_and_fetch_all(coord_cursor, "SHOW INSTANCES;"))) expected_data_after_failover = [ - ("instance_1", "127.0.0.1:10011", True, "main"), - ("instance_2", "127.0.0.1:10012", True, "replica"), - ("instance_3", "127.0.0.1:10013", False, "unknown"), + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", True, "main"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", False, "unknown"), ] mg_sleep_and_assert(expected_data_after_failover, retrieve_data_show_repl_cluster) expected_data_after_main_coming_back = [ - ("instance_1", "127.0.0.1:10011", True, "main"), - ("instance_2", "127.0.0.1:10012", True, "replica"), - ("instance_3", "127.0.0.1:10013", True, "replica"), + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", True, "main"), + ("instance_2", 
"", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", True, "replica"), ] interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") @@ -346,60 +353,68 @@ def test_automatic_failover_main_back_as_replica(): mg_sleep_and_assert([("replica",)], retrieve_data_show_repl_role_instance3) -def test_replica_instance_restarts_replication_works(): +def test_automatic_failover_main_back_as_main(): safe_execute(shutil.rmtree, TEMP_DIR) interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION) - cursor = connect(host="localhost", port=7690).cursor() - - def show_repl_cluster(): - return sorted(list(execute_and_fetch_all(cursor, "SHOW REPLICATION CLUSTER;"))) - - expected_data_up = [ - ("instance_1", "127.0.0.1:10011", True, "replica"), - ("instance_2", "127.0.0.1:10012", True, "replica"), - ("instance_3", "127.0.0.1:10013", True, "main"), - ] - mg_sleep_and_assert(expected_data_up, show_repl_cluster) - interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_1") + interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_2") + interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") - expected_data_down = [ - ("instance_1", "127.0.0.1:10011", False, "unknown"), - ("instance_2", "127.0.0.1:10012", True, "replica"), - ("instance_3", "127.0.0.1:10013", True, "main"), + coord_cursor = connect(host="localhost", port=7690).cursor() + + def retrieve_data_show_repl_cluster(): + return sorted(list(execute_and_fetch_all(coord_cursor, "SHOW INSTANCES;"))) + + expected_data_all_down = [ + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", False, "unknown"), + ("instance_2", "", "127.0.0.1:10012", False, "unknown"), + ("instance_3", "", "127.0.0.1:10013", False, "unknown"), ] - mg_sleep_and_assert(expected_data_down, show_repl_cluster) + + mg_sleep_and_assert(expected_data_all_down, retrieve_data_show_repl_cluster) + + interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") + expected_data_main_back = [ + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", False, "unknown"), + ("instance_2", "", "127.0.0.1:10012", False, "unknown"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), + ] + mg_sleep_and_assert(expected_data_main_back, retrieve_data_show_repl_cluster) + + instance3_cursor = connect(host="localhost", port=7687).cursor() + + def retrieve_data_show_repl_role_instance3(): + return sorted(list(execute_and_fetch_all(instance3_cursor, "SHOW REPLICATION ROLE;"))) + + mg_sleep_and_assert([("main",)], retrieve_data_show_repl_role_instance3) interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_1") + interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_2") - mg_sleep_and_assert(expected_data_up, show_repl_cluster) - - expected_data_on_main_show_replicas = [ - ("instance_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("instance_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), + expected_data_replicas_back = [ + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", True, "replica"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), ] - instance3_cursor = connect(host="localhost", port=7687).cursor() + + mg_sleep_and_assert(expected_data_replicas_back, retrieve_data_show_repl_cluster) + instance1_cursor = connect(host="localhost", port=7688).cursor() - - def 
retrieve_data_show_repl_role_instance1(): - return sorted(list(execute_and_fetch_all(instance3_cursor, "SHOW REPLICAS;"))) - - mg_sleep_and_assert(expected_data_on_main_show_replicas, retrieve_data_show_repl_role_instance1) + instance2_cursor = connect(host="localhost", port=7689).cursor() def retrieve_data_show_repl_role_instance1(): return sorted(list(execute_and_fetch_all(instance1_cursor, "SHOW REPLICATION ROLE;"))) - expected_data_replica = [("replica",)] - mg_sleep_and_assert(expected_data_replica, retrieve_data_show_repl_role_instance1) + def retrieve_data_show_repl_role_instance2(): + return sorted(list(execute_and_fetch_all(instance2_cursor, "SHOW REPLICATION ROLE;"))) - execute_and_fetch_all(instance3_cursor, "CREATE ();") - - def retrieve_data_replica(): - return execute_and_fetch_all(instance1_cursor, "MATCH (n) RETURN count(n);")[0][0] - - expected_data_replica = 1 - mg_sleep_and_assert(expected_data_replica, retrieve_data_replica) + mg_sleep_and_assert([("replica",)], retrieve_data_show_repl_role_instance1) + mg_sleep_and_assert([("replica",)], retrieve_data_show_repl_role_instance2) + mg_sleep_and_assert([("main",)], retrieve_data_show_repl_role_instance3) if __name__ == "__main__": diff --git a/tests/e2e/high_availability_experimental/coordinator.py b/tests/e2e/high_availability_experimental/coordinator.py index 9e34a4167..4330c2194 100644 --- a/tests/e2e/high_availability_experimental/coordinator.py +++ b/tests/e2e/high_availability_experimental/coordinator.py @@ -37,16 +37,17 @@ def test_coordinator_cannot_run_show_repl_role(): assert str(e.value) == "Coordinator can run only coordinator queries!" -def test_coordinator_show_replication_cluster(): +def test_coordinator_show_instances(): cursor = connect(host="localhost", port=7690).cursor() def retrieve_data(): - return sorted(list(execute_and_fetch_all(cursor, "SHOW REPLICATION CLUSTER;"))) + return sorted(list(execute_and_fetch_all(cursor, "SHOW INSTANCES;"))) expected_data = [ - ("instance_1", "127.0.0.1:10011", True, "replica"), - ("instance_2", "127.0.0.1:10012", True, "replica"), - ("instance_3", "127.0.0.1:10013", True, "main"), + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", True, "replica"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), ] mg_sleep_and_assert(expected_data, retrieve_data) @@ -65,8 +66,8 @@ def test_coordinator_cannot_call_show_replicas(): def test_main_and_replicas_cannot_call_show_repl_cluster(port): cursor = connect(host="localhost", port=port).cursor() with pytest.raises(Exception) as e: - execute_and_fetch_all(cursor, "SHOW REPLICATION CLUSTER;") - assert str(e.value) == "Only coordinator can run SHOW REPLICATION CLUSTER." + execute_and_fetch_all(cursor, "SHOW INSTANCES;") + assert str(e.value) == "Only coordinator can run SHOW INSTANCES." @pytest.mark.parametrize( diff --git a/tests/e2e/high_availability_experimental/distributed_coordinators.py b/tests/e2e/high_availability_experimental/distributed_coordinators.py new file mode 100644 index 000000000..8a9ebf3c2 --- /dev/null +++ b/tests/e2e/high_availability_experimental/distributed_coordinators.py @@ -0,0 +1,145 @@ +# Copyright 2024 Memgraph Ltd. 
+# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + +import os +import shutil +import sys +import tempfile + +import interactive_mg_runner +import pytest +from common import connect, execute_and_fetch_all, safe_execute +from mg_utils import mg_sleep_and_assert + +interactive_mg_runner.SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) +interactive_mg_runner.PROJECT_DIR = os.path.normpath( + os.path.join(interactive_mg_runner.SCRIPT_DIR, "..", "..", "..", "..") +) +interactive_mg_runner.BUILD_DIR = os.path.normpath(os.path.join(interactive_mg_runner.PROJECT_DIR, "build")) +interactive_mg_runner.MEMGRAPH_BINARY = os.path.normpath(os.path.join(interactive_mg_runner.BUILD_DIR, "memgraph")) + +TEMP_DIR = tempfile.TemporaryDirectory().name + +MEMGRAPH_INSTANCES_DESCRIPTION = { + "coordinator1": { + "args": [ + "--bolt-port", + "7687", + "--log-level=TRACE", + "--raft-server-id=1", + "--raft-server-port=10111", + ], + "log_file": "coordinator1.log", + "setup_queries": [], + }, + "coordinator2": { + "args": [ + "--bolt-port", + "7688", + "--log-level=TRACE", + "--raft-server-id=2", + "--raft-server-port=10112", + ], + "log_file": "coordinator2.log", + "setup_queries": [], + }, + "coordinator3": { + "args": [ + "--bolt-port", + "7689", + "--log-level=TRACE", + "--raft-server-id=3", + "--raft-server-port=10113", + ], + "log_file": "coordinator3.log", + "setup_queries": [ + "ADD COORDINATOR 1 ON '127.0.0.1:10111'", + "ADD COORDINATOR 2 ON '127.0.0.1:10112'", + ], + }, +} + + +def test_coordinators_communication(): + safe_execute(shutil.rmtree, TEMP_DIR) + interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION) + + coordinator3_cursor = connect(host="localhost", port=7689).cursor() + + def check_coordinator3(): + return sorted(list(execute_and_fetch_all(coordinator3_cursor, "SHOW INSTANCES"))) + + expected_cluster = [ + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("coordinator_2", "127.0.0.1:10112", "", True, "coordinator"), + ("coordinator_3", "127.0.0.1:10113", "", True, "coordinator"), + ] + mg_sleep_and_assert(expected_cluster, check_coordinator3) + + coordinator1_cursor = connect(host="localhost", port=7687).cursor() + + def check_coordinator1(): + return sorted(list(execute_and_fetch_all(coordinator1_cursor, "SHOW INSTANCES"))) + + mg_sleep_and_assert(expected_cluster, check_coordinator1) + + coordinator2_cursor = connect(host="localhost", port=7688).cursor() + + def check_coordinator2(): + return sorted(list(execute_and_fetch_all(coordinator2_cursor, "SHOW INSTANCES"))) + + mg_sleep_and_assert(expected_cluster, check_coordinator2) + + +def test_coordinators_communication_with_restarts(): + safe_execute(shutil.rmtree, TEMP_DIR) + interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION) + + expected_cluster = [ + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("coordinator_2", "127.0.0.1:10112", "", True, "coordinator"), + ("coordinator_3", "127.0.0.1:10113", "", True, "coordinator"), + ] + + coordinator1_cursor = connect(host="localhost", port=7687).cursor() + + def 
check_coordinator1(): + return sorted(list(execute_and_fetch_all(coordinator1_cursor, "SHOW INSTANCES"))) + + mg_sleep_and_assert(expected_cluster, check_coordinator1) + + coordinator2_cursor = connect(host="localhost", port=7688).cursor() + + def check_coordinator2(): + return sorted(list(execute_and_fetch_all(coordinator2_cursor, "SHOW INSTANCES"))) + + mg_sleep_and_assert(expected_cluster, check_coordinator2) + + interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "coordinator1") + interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "coordinator1") + coordinator1_cursor = connect(host="localhost", port=7687).cursor() + + mg_sleep_and_assert(expected_cluster, check_coordinator1) + + interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "coordinator1") + interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "coordinator2") + + interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "coordinator1") + interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "coordinator2") + coordinator1_cursor = connect(host="localhost", port=7687).cursor() + coordinator2_cursor = connect(host="localhost", port=7688).cursor() + + mg_sleep_and_assert(expected_cluster, check_coordinator1) + mg_sleep_and_assert(expected_cluster, check_coordinator2) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/high_availability_experimental/workloads.yaml b/tests/e2e/high_availability_experimental/workloads.yaml index 8b617dfb5..23fa3a5db 100644 --- a/tests/e2e/high_availability_experimental/workloads.yaml +++ b/tests/e2e/high_availability_experimental/workloads.yaml @@ -13,7 +13,7 @@ ha_cluster: &ha_cluster log_file: "replication-e2e-main.log" setup_queries: [] coordinator: - args: ["--bolt-port", "7690", "--log-level=TRACE", "--coordinator"] + args: ["--bolt-port", "7690", "--log-level=TRACE", "--raft-server-id=1", "--raft-server-port=10111"] log_file: "replication-e2e-coordinator.log" setup_queries: [ "REGISTER INSTANCE instance_1 ON '127.0.0.1:10011' WITH '127.0.0.1:10001';", @@ -36,6 +36,10 @@ workloads: binary: "tests/e2e/pytest_runner.sh" args: ["high_availability_experimental/manual_setting_replicas.py"] + - name: "Distributed coordinators" + binary: "tests/e2e/pytest_runner.sh" + args: ["high_availability_experimental/distributed_coordinators.py"] + - name: "Not replicate from old main" binary: "tests/e2e/pytest_runner.sh" args: ["high_availability_experimental/not_replicate_from_old_main.py"] diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index 1353a56dd..63cca3aa4 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -2632,6 +2632,19 @@ TEST_P(CypherMainVisitorTest, TestRegisterReplicationQuery) { ReplicationQuery::SyncMode::SYNC); } +#ifdef MG_ENTERPRISE +TEST_P(CypherMainVisitorTest, TestAddCoordinatorInstance) { + auto &ast_generator = *GetParam(); + + std::string const correct_query = R"(ADD COORDINATOR 1 ON "127.0.0.1:10111")"; + auto *parsed_query = dynamic_cast<CoordinatorQuery *>(ast_generator.ParseQuery(correct_query)); + + EXPECT_EQ(parsed_query->action_, CoordinatorQuery::Action::ADD_COORDINATOR_INSTANCE); + ast_generator.CheckLiteral(parsed_query->raft_socket_address_, TypedValue("127.0.0.1:10111")); + ast_generator.CheckLiteral(parsed_query->raft_server_id_, TypedValue(1)); +} +#endif + TEST_P(CypherMainVisitorTest, TestDeleteReplica) { auto &ast_generator = *GetParam();