diff --git a/apollo_archives.yaml b/apollo_archives.yaml index aa9720e8a..c84b47ec4 100644 --- a/apollo_archives.yaml +++ b/apollo_archives.yaml @@ -4,6 +4,7 @@ - build_debug/memgraph_distributed - build_release/memgraph - build_release/memgraph_distributed + - build_release/tools/src/mg_client - build_release/tools/src/mg_import_csv - build_release/tools/src/mg_statsd - config diff --git a/src/communication/bolt/client.hpp b/src/communication/bolt/client.hpp index a5bfc06c0..c2cfe47ea 100644 --- a/src/communication/bolt/client.hpp +++ b/src/communication/bolt/client.hpp @@ -13,26 +13,49 @@ namespace communication::bolt { -class ClientFatalException : public utils::BasicException { - public: - using utils::BasicException::BasicException; - ClientFatalException() - : utils::BasicException( - "Something went wrong while communicating with the server!") {} -}; - +/// This exception is thrown whenever an error occurs during query execution +/// that isn't fatal (eg. mistyped query or some transient error occurred). +/// It should be handled by everyone who uses the client. class ClientQueryException : public utils::BasicException { public: using utils::BasicException::BasicException; ClientQueryException() : utils::BasicException("Couldn't execute query!") {} }; +/// This exception is thrown whenever a fatal error occurs during query +/// execution and/or connecting to the server. +/// It should be handled by everyone who uses the client. +class ClientFatalException : public utils::BasicException { + public: + using utils::BasicException::BasicException; +}; + +// Internal exception used whenever a communication error occurs. You should +// only handle the `ClientFatalException`. +class ServerCommunicationException : public ClientFatalException { + public: + ServerCommunicationException() + : ClientFatalException("Couldn't communicate with the server!") {} +}; + +// Internal exception used whenever a malformed data error occurs. You should +// only handle the `ClientFatalException`. +class ServerMalformedDataException : public ClientFatalException { + public: + ServerMalformedDataException() + : ClientFatalException("The server sent malformed data!") {} +}; + +/// Structure that is used to return results from an executed query. struct QueryData { std::vector<std::string> fields; std::vector<std::vector<Value>> records; std::map<std::string, Value> metadata; }; +/// Bolt client. +/// It has methods used to connect to the server and execute queries against the +/// server. It supports both SSL and plaintext connections. class Client final { public: explicit Client(communication::ClientContext *context) : client_(context) {} @@ -42,60 +65,74 @@ class Client final { Client &operator=(const Client &) = delete; Client &operator=(Client &&) = delete; - bool Connect(const io::network::Endpoint &endpoint, + /// Method used to connect to the server. Before executing queries this method + /// should be called to set-up the connection to the server. After the + /// connection is set-up, multiple queries may be executed through a single + /// established connection. + /// @throws ClientFatalException when we couldn't connect to the server + void Connect(const io::network::Endpoint &endpoint, const std::string &username, const std::string &password, - const std::string &client_name = "memgraph-bolt/0.0.1") { + const std::string &client_name = "memgraph-bolt") { if (!client_.Connect(endpoint)) { - LOG(ERROR) << "Couldn't connect to " << endpoint; - return false; + throw ClientFatalException("Couldn't connect to {}!", endpoint); } if (!client_.Write(kPreamble, sizeof(kPreamble), true)) { - LOG(ERROR) << "Couldn't send preamble!"; - return false; + DLOG(ERROR) << "Couldn't send preamble!"; + throw ServerCommunicationException(); } for (int i = 0; i < 4; ++i) { if (!client_.Write(kProtocol, sizeof(kProtocol), i != 3)) { - LOG(ERROR) << "Couldn't send protocol version!"; - return false; + DLOG(ERROR) << "Couldn't send protocol version!"; + throw ServerCommunicationException(); } } if (!client_.Read(sizeof(kProtocol))) { - LOG(ERROR) << "Couldn't get negotiated protocol version!"; - return false; + DLOG(ERROR) << "Couldn't get negotiated protocol version!"; + throw ServerCommunicationException(); } if (memcmp(kProtocol, client_.GetData(), sizeof(kProtocol)) != 0) { - LOG(ERROR) << "Server negotiated unsupported protocol version!"; - return false; + DLOG(ERROR) << "Server negotiated unsupported protocol version!"; + throw ClientFatalException( + "The server negotiated an usupported protocol version!"); } client_.ShiftData(sizeof(kProtocol)); if (!encoder_.MessageInit(client_name, {{"scheme", "basic"}, {"principal", username}, {"credentials", password}})) { - LOG(ERROR) << "Couldn't send init message!"; - return false; + DLOG(ERROR) << "Couldn't send init message!"; + throw ServerCommunicationException(); } Signature signature; Value metadata; if (!ReadMessage(&signature, &metadata)) { - LOG(ERROR) << "Couldn't read init message response!"; - return false; + DLOG(ERROR) << "Couldn't read init message response!"; + throw ServerCommunicationException(); } if (signature != Signature::Success) { - LOG(ERROR) << "Handshake failed!"; - return false; + DLOG(ERROR) << "Handshake failed!"; + throw ClientFatalException("Handshake with the server failed!"); } DLOG(INFO) << "Metadata of init message response: " << metadata; - - return true; } + /// Function used to execute queries against the server. Before you can + /// execute queries you must connect the client to the server. + /// @throws ClientQueryException when there is some transient error while + /// executing the query (eg. mistyped query, + /// etc.) + /// @throws ClientFatalException when we couldn't communicate with the server QueryData Execute(const std::string &query, const std::map<std::string, Value> ¶meters) { + if (!client_.IsConnected()) { + throw ClientFatalException( + "You must first connect to the server before using the client!"); + } + DLOG(INFO) << "Sending run message with statement: '" << query << "'; parameters: " << parameters; @@ -106,10 +143,10 @@ class Client final { Signature signature; Value fields; if (!ReadMessage(&signature, &fields)) { - throw ClientFatalException(); + throw ServerCommunicationException(); } if (fields.type() != Value::Type::Map) { - throw ClientFatalException(); + throw ServerMalformedDataException(); } if (signature == Signature::Failure) { @@ -121,7 +158,7 @@ class Client final { } throw ClientQueryException(); } else if (signature != Signature::Success) { - throw ClientFatalException(); + throw ServerMalformedDataException(); } DLOG(INFO) << "Reading pull_all message response"; @@ -130,26 +167,26 @@ class Client final { std::vector<std::vector<Value>> records; while (true) { if (!GetMessage()) { - throw ClientFatalException(); + throw ServerCommunicationException(); } if (!decoder_.ReadMessageHeader(&signature, &marker)) { - throw ClientFatalException(); + throw ServerCommunicationException(); } if (signature == Signature::Record) { Value record; if (!decoder_.ReadValue(&record, Value::Type::List)) { - throw ClientFatalException(); + throw ServerCommunicationException(); } records.emplace_back(std::move(record.ValueList())); } else if (signature == Signature::Success) { if (!decoder_.ReadValue(&metadata)) { - throw ClientFatalException(); + throw ServerCommunicationException(); } break; } else if (signature == Signature::Failure) { Value data; if (!decoder_.ReadValue(&data)) { - throw ClientFatalException(); + throw ServerCommunicationException(); } HandleFailure(); auto &tmp = data.ValueMap(); @@ -159,28 +196,28 @@ class Client final { } throw ClientQueryException(); } else { - throw ClientFatalException(); + throw ServerMalformedDataException(); } } if (metadata.type() != Value::Type::Map) { - throw ClientFatalException(); + throw ServerMalformedDataException(); } QueryData ret{{}, std::move(records), std::move(metadata.ValueMap())}; auto &header = fields.ValueMap(); if (header.find("fields") == header.end()) { - throw ClientFatalException(); + throw ServerMalformedDataException(); } if (header["fields"].type() != Value::Type::List) { - throw ClientFatalException(); + throw ServerMalformedDataException(); } auto &field_vector = header["fields"].ValueList(); for (auto &field_item : field_vector) { if (field_item.type() != Value::Type::String) { - throw ClientFatalException(); + throw ServerMalformedDataException(); } ret.fields.emplace_back(std::move(field_item.ValueString())); } @@ -188,6 +225,7 @@ class Client final { return ret; } + /// Close the active client connection. void Close() { client_.Close(); }; private: @@ -227,18 +265,18 @@ class Client final { void HandleFailure() { if (!encoder_.MessageAckFailure()) { - throw ClientFatalException(); + throw ServerCommunicationException(); } while (true) { Signature signature; Value data; if (!ReadMessage(&signature, &data)) { - throw ClientFatalException(); + throw ServerCommunicationException(); } if (signature == Signature::Success) { break; } else if (signature != Signature::Ignored) { - throw ClientFatalException(); + throw ServerMalformedDataException(); } } } diff --git a/src/communication/client.cpp b/src/communication/client.cpp index cb137f731..9f5c50530 100644 --- a/src/communication/client.cpp +++ b/src/communication/client.cpp @@ -32,7 +32,7 @@ bool Client::Connect(const io::network::Endpoint &endpoint) { // Create a new SSL object that will be used for SSL communication. ssl_ = SSL_new(context_->context()); if (ssl_ == nullptr) { - LOG(ERROR) << "Couldn't create client SSL object!"; + DLOG(ERROR) << "Couldn't create client SSL object!"; socket_.Close(); return false; } @@ -43,7 +43,7 @@ bool Client::Connect(const io::network::Endpoint &endpoint) { // handle that in our socket destructor). bio_ = BIO_new_socket(socket_.fd(), BIO_NOCLOSE); if (bio_ == nullptr) { - LOG(ERROR) << "Couldn't create client BIO object!"; + DLOG(ERROR) << "Couldn't create client BIO object!"; socket_.Close(); return false; } @@ -59,7 +59,7 @@ bool Client::Connect(const io::network::Endpoint &endpoint) { // Perform the TLS handshake. auto ret = SSL_connect(ssl_); if (ret != 1) { - LOG(WARNING) << "Couldn't connect to SSL server: " << SslGetLastError(); + DLOG(WARNING) << "Couldn't connect to SSL server: " << SslGetLastError(); socket_.Close(); return false; } @@ -70,6 +70,8 @@ bool Client::Connect(const io::network::Endpoint &endpoint) { bool Client::ErrorStatus() { return socket_.ErrorStatus(); } +bool Client::IsConnected() { return socket_.IsOpen(); } + void Client::Shutdown() { socket_.Shutdown(); } void Client::Close() { @@ -111,7 +113,7 @@ bool Client::Read(size_t len) { continue; } else { // This is a fatal error. - LOG(ERROR) << "Received an unexpected SSL error: " << err; + DLOG(ERROR) << "Received an unexpected SSL error: " << err; return false; } } else if (got == 0) { diff --git a/src/communication/client.hpp b/src/communication/client.hpp index 5f6e35906..b7412cb78 100644 --- a/src/communication/client.hpp +++ b/src/communication/client.hpp @@ -42,6 +42,11 @@ class Client final { */ bool ErrorStatus(); + /** + * This function returns `true` if the socket is connected to a remote host. + */ + bool IsConnected(); + /** * This function shuts down the socket. */ diff --git a/tests/integration/auth/checker.cpp b/tests/integration/auth/checker.cpp index 8d6294d89..8f6f22e15 100644 --- a/tests/integration/auth/checker.cpp +++ b/tests/integration/auth/checker.cpp @@ -27,10 +27,7 @@ int main(int argc, char **argv) { communication::ClientContext context(FLAGS_use_ssl); communication::bolt::Client client(&context); - if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { - LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":" - << FLAGS_port; - } + client.Connect(endpoint, FLAGS_username, FLAGS_password); try { auto ret = client.Execute("SHOW PRIVILEGES FOR user", {}); diff --git a/tests/integration/auth/tester.cpp b/tests/integration/auth/tester.cpp index e7ee501df..408f88258 100644 --- a/tests/integration/auth/tester.cpp +++ b/tests/integration/auth/tester.cpp @@ -32,10 +32,7 @@ int main(int argc, char **argv) { communication::ClientContext context(FLAGS_use_ssl); communication::bolt::Client client(&context); - if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { - LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":" - << FLAGS_port; - } + client.Connect(endpoint, FLAGS_username, FLAGS_password); for (int i = 1; i < argc; ++i) { std::string query(argv[i]); diff --git a/tests/integration/distributed/tester.cpp b/tests/integration/distributed/tester.cpp index 0a9d2f246..026ed91a7 100644 --- a/tests/integration/distributed/tester.cpp +++ b/tests/integration/distributed/tester.cpp @@ -28,10 +28,7 @@ int main(int argc, char **argv) { communication::ClientContext context(FLAGS_use_ssl); communication::bolt::Client client(&context); - if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { - LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":" - << FLAGS_port; - } + client.Connect(endpoint, FLAGS_username, FLAGS_password); client.Execute("UNWIND range(0, 10000) AS x CREATE ()", {}); diff --git a/tests/integration/kafka/tester.cpp b/tests/integration/kafka/tester.cpp index 70f6f95b9..d96e47256 100644 --- a/tests/integration/kafka/tester.cpp +++ b/tests/integration/kafka/tester.cpp @@ -54,10 +54,7 @@ int main(int argc, char **argv) { communication::ClientContext context(FLAGS_use_ssl); communication::bolt::Client client(&context); - if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { - LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":" - << FLAGS_port; - } + client.Connect(endpoint, FLAGS_username, FLAGS_password); if (FLAGS_step == "start") { ExecuteQuery(client, diff --git a/tests/integration/transactions/tester.cpp b/tests/integration/transactions/tester.cpp index d26b461e6..478977840 100644 --- a/tests/integration/transactions/tester.cpp +++ b/tests/integration/transactions/tester.cpp @@ -20,9 +20,7 @@ using namespace communication::bolt; class BoltClient : public ::testing::Test { protected: virtual void SetUp() { - if (!client_.Connect(endpoint_, FLAGS_username, FLAGS_password)) { - throw utils::BasicException("Couldn't connect to database!"); - } + client_.Connect(endpoint_, FLAGS_username, FLAGS_password); } virtual void TearDown() {} diff --git a/tests/macro_benchmark/clients/bfs_pokec_client.cpp b/tests/macro_benchmark/clients/bfs_pokec_client.cpp index 8c54f2085..dba43ac22 100644 --- a/tests/macro_benchmark/clients/bfs_pokec_client.cpp +++ b/tests/macro_benchmark/clients/bfs_pokec_client.cpp @@ -109,9 +109,7 @@ int main(int argc, char **argv) { Endpoint endpoint(FLAGS_address, FLAGS_port); ClientContext context(FLAGS_use_ssl); Client client(&context); - if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { - LOG(FATAL) << "Couldn't connect to " << endpoint; - } + client.Connect(endpoint, FLAGS_username, FLAGS_password); std::vector<std::unique_ptr<TestClient>> clients; for (auto i = 0; i < FLAGS_num_workers; ++i) { diff --git a/tests/macro_benchmark/clients/card_fraud_client.cpp b/tests/macro_benchmark/clients/card_fraud_client.cpp index 1de65a906..7559b700e 100644 --- a/tests/macro_benchmark/clients/card_fraud_client.cpp +++ b/tests/macro_benchmark/clients/card_fraud_client.cpp @@ -340,9 +340,7 @@ int main(int argc, char **argv) { Endpoint endpoint(FLAGS_address, FLAGS_port); ClientContext context(FLAGS_use_ssl); Client client(&context); - if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { - LOG(FATAL) << "Couldn't connect to " << endpoint; - } + client.Connect(endpoint, FLAGS_username, FLAGS_password); num_pos.store(NumNodesWithLabel(client, "Pos")); num_cards.store(NumNodesWithLabel(client, "Card")); diff --git a/tests/macro_benchmark/clients/long_running_common.hpp b/tests/macro_benchmark/clients/long_running_common.hpp index d090a3827..078d3cca7 100644 --- a/tests/macro_benchmark/clients/long_running_common.hpp +++ b/tests/macro_benchmark/clients/long_running_common.hpp @@ -43,9 +43,7 @@ class TestClient { public: TestClient() { Endpoint endpoint(FLAGS_address, FLAGS_port); - if (!client_.Connect(endpoint, FLAGS_username, FLAGS_password)) { - LOG(FATAL) << "Couldn't connect to " << endpoint; - } + client_.Connect(endpoint, FLAGS_username, FLAGS_password); } virtual ~TestClient() {} diff --git a/tests/macro_benchmark/clients/pokec_client.cpp b/tests/macro_benchmark/clients/pokec_client.cpp index 8b84c728c..822219645 100644 --- a/tests/macro_benchmark/clients/pokec_client.cpp +++ b/tests/macro_benchmark/clients/pokec_client.cpp @@ -281,9 +281,7 @@ int main(int argc, char **argv) { Endpoint endpoint(io::network::ResolveHostname(FLAGS_address), FLAGS_port); ClientContext context(FLAGS_use_ssl); Client client(&context); - if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { - LOG(FATAL) << "Couldn't connect to " << endpoint; - } + client.Connect(endpoint, FLAGS_username, FLAGS_password); return IndependentSet(client, INDEPENDENT_LABEL); }(); diff --git a/tests/macro_benchmark/clients/query_client.cpp b/tests/macro_benchmark/clients/query_client.cpp index 9aab2e2e2..1081eb597 100644 --- a/tests/macro_benchmark/clients/query_client.cpp +++ b/tests/macro_benchmark/clients/query_client.cpp @@ -61,9 +61,7 @@ void ExecuteQueries(const std::vector<std::string> &queries, Endpoint endpoint(FLAGS_address, FLAGS_port); ClientContext context(FLAGS_use_ssl); Client client(&context); - if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { - LOG(FATAL) << "Couldn't connect to " << endpoint; - } + client.Connect(endpoint, FLAGS_username, FLAGS_password); std::string str; while (true) { diff --git a/tests/manual/bolt_client.cpp b/tests/manual/bolt_client.cpp index 57f62847a..f8597a679 100644 --- a/tests/manual/bolt_client.cpp +++ b/tests/manual/bolt_client.cpp @@ -25,7 +25,7 @@ int main(int argc, char **argv) { communication::ClientContext context(FLAGS_use_ssl); communication::bolt::Client client(&context); - if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) return 1; + client.Connect(endpoint, FLAGS_username, FLAGS_password); std::cout << "Memgraph bolt client is connected and running." << std::endl; diff --git a/tests/stress/long_running.cpp b/tests/stress/long_running.cpp index 37953035d..ad88cfe81 100644 --- a/tests/stress/long_running.cpp +++ b/tests/stress/long_running.cpp @@ -59,10 +59,7 @@ class GraphSession { EndpointT endpoint(FLAGS_address, FLAGS_port); client_ = std::make_unique<ClientT>(&context_); - - if (!client_->Connect(endpoint, FLAGS_username, FLAGS_password)) { - throw utils::BasicException("Couldn't connect to server!"); - } + client_->Connect(endpoint, FLAGS_username, FLAGS_password); } private: @@ -381,9 +378,7 @@ int main(int argc, char **argv) { EndpointT endpoint(FLAGS_address, FLAGS_port); ClientContextT context(FLAGS_use_ssl); ClientT client(&context); - if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { - throw utils::BasicException("Couldn't connect to server!"); - } + client.Connect(endpoint, FLAGS_username, FLAGS_password); // cleanup and create indexes client.Execute("MATCH (n) DETACH DELETE n", {}); diff --git a/tools/src/mg_client/main.cpp b/tools/src/mg_client/main.cpp index 293e7fc18..7895ebaea 100644 --- a/tools/src/mg_client/main.cpp +++ b/tools/src/mg_client/main.cpp @@ -629,12 +629,16 @@ int main(int argc, char **argv) { communication::ClientContext context(FLAGS_use_ssl); communication::bolt::Client client(&context); - if (!client.Connect(endpoint, FLAGS_username, password)) { - // Error message is logged in client.Connect method + std::string bolt_client_version = + fmt::format("mg_client/{}", gflags::VersionString()); + try { + client.Connect(endpoint, FLAGS_username, password, bolt_client_version); + } catch (const communication::bolt::ClientFatalException &e) { + EchoFailure("Connection failure", e.what()); return 1; } - EchoInfo("mg-client"); + EchoInfo(fmt::format("mg_client {}", gflags::VersionString())); EchoInfo("Type :help for shell usage"); EchoInfo("Quit the shell by typing Ctrl-D(eof) or :quit"); EchoInfo(fmt::format("Connected to 'memgraph://{}'", endpoint)); @@ -682,9 +686,13 @@ int main(int argc, char **argv) { client.Close(); while (num_retries > 0) { --num_retries; - if (client.Connect(endpoint, FLAGS_username, FLAGS_password)) { + try { + client.Connect(endpoint, FLAGS_username, FLAGS_password, + bolt_client_version); is_connected = true; break; + } catch (const communication::bolt::ClientFatalException &e) { + EchoFailure("Connection failure", e.what()); } std::this_thread::sleep_for(std::chrono::seconds(1)); }