diff --git a/src/communication/bolt/v1/session.hpp b/src/communication/bolt/v1/session.hpp index ea38192df..e90d88234 100644 --- a/src/communication/bolt/v1/session.hpp +++ b/src/communication/bolt/v1/session.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include "glog/logging.h" @@ -64,6 +65,10 @@ class Session { virtual bool Authenticate(const std::string &username, const std::string &password) = 0; + /** Return the name of the server that should be used for the Bolt INIT + * message. */ + virtual std::optional GetServerNameForInit() = 0; + /** * Executes the session after data has been read into the buffer. * Goes through the bolt states in order to execute commands from the client. diff --git a/src/communication/bolt/v1/states/init.hpp b/src/communication/bolt/v1/states/init.hpp index 473c953f5..c2a7aa7bd 100644 --- a/src/communication/bolt/v1/states/init.hpp +++ b/src/communication/bolt/v1/states/init.hpp @@ -93,9 +93,19 @@ State StateInitRun(Session &session) { } // Return success. - if (!session.encoder_.MessageSuccess()) { - DLOG(WARNING) << "Couldn't send success message to the client!"; - return State::Close; + { + bool success_sent = false; + auto server_name = session.GetServerNameForInit(); + if (server_name) { + success_sent = + session.encoder_.MessageSuccess({{"server", *server_name}}); + } else { + success_sent = session.encoder_.MessageSuccess(); + } + if (!success_sent) { + DLOG(WARNING) << "Couldn't send success message to the client!"; + return State::Close; + } } return State::Idle; diff --git a/src/memgraph_init.cpp b/src/memgraph_init.cpp index 2ae611a82..e5746d6aa 100644 --- a/src/memgraph_init.cpp +++ b/src/memgraph_init.cpp @@ -21,6 +21,9 @@ DEFINE_uint64(memory_warning_threshold, 1024, "Memory warning threshold, in MB. If Memgraph detects there is " "less available RAM it will log a warning. Set to 0 to " "disable."); +DEFINE_string(bolt_server_name_for_init, "", + "Server name which the database should send to the client in the " + "Bolt INIT message."); BoltSession::BoltSession(SessionData *data, const io::network::Endpoint &endpoint, @@ -130,6 +133,11 @@ bool BoltSession::Authenticate(const std::string &username, #endif } +std::optional BoltSession::GetServerNameForInit() { + if (FLAGS_bolt_server_name_for_init.empty()) return std::nullopt; + return FLAGS_bolt_server_name_for_init; +} + #ifdef MG_SINGLE_NODE_V2 BoltSession::TypedValueResultStream::TypedValueResultStream( TEncoder *encoder, const storage::Storage *db) diff --git a/src/memgraph_init.hpp b/src/memgraph_init.hpp index c8ac4de90..7d5540726 100644 --- a/src/memgraph_init.hpp +++ b/src/memgraph_init.hpp @@ -66,6 +66,8 @@ class BoltSession final bool Authenticate(const std::string &username, const std::string &password) override; + std::optional GetServerNameForInit() override; + private: /// Wrapper around TEncoder which converts TypedValue to Value /// before forwarding the calls to original TEncoder. diff --git a/tests/manual/ha_proxy.cpp b/tests/manual/ha_proxy.cpp index e986ee432..0e045809d 100644 --- a/tests/manual/ha_proxy.cpp +++ b/tests/manual/ha_proxy.cpp @@ -107,6 +107,10 @@ class BoltSession final return true; } + std::optional GetServerNameForInit() override { + return std::nullopt; + } + private: SessionData *session_data_; io::network::Endpoint endpoint_; diff --git a/tests/unit/bolt_session.cpp b/tests/unit/bolt_session.cpp index 187459495..a8203a149 100644 --- a/tests/unit/bolt_session.cpp +++ b/tests/unit/bolt_session.cpp @@ -55,6 +55,10 @@ class TestSession : public Session { return true; } + std::optional GetServerNameForInit() override { + return std::nullopt; + } + private: std::string query_; };