Polish Bolt client and mg_client
Reviewers: mculinovic, teon.banek Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1681
This commit is contained in:
parent
078ab75145
commit
0013cbb173
@ -4,6 +4,7 @@
|
||||
- build_debug/memgraph_distributed
|
||||
- build_release/memgraph
|
||||
- build_release/memgraph_distributed
|
||||
- build_release/tools/src/mg_client
|
||||
- build_release/tools/src/mg_import_csv
|
||||
- build_release/tools/src/mg_statsd
|
||||
- config
|
||||
|
@ -13,26 +13,49 @@
|
||||
|
||||
namespace communication::bolt {
|
||||
|
||||
class ClientFatalException : public utils::BasicException {
|
||||
public:
|
||||
using utils::BasicException::BasicException;
|
||||
ClientFatalException()
|
||||
: utils::BasicException(
|
||||
"Something went wrong while communicating with the server!") {}
|
||||
};
|
||||
|
||||
/// This exception is thrown whenever an error occurs during query execution
|
||||
/// that isn't fatal (eg. mistyped query or some transient error occurred).
|
||||
/// It should be handled by everyone who uses the client.
|
||||
class ClientQueryException : public utils::BasicException {
|
||||
public:
|
||||
using utils::BasicException::BasicException;
|
||||
ClientQueryException() : utils::BasicException("Couldn't execute query!") {}
|
||||
};
|
||||
|
||||
/// This exception is thrown whenever a fatal error occurs during query
|
||||
/// execution and/or connecting to the server.
|
||||
/// It should be handled by everyone who uses the client.
|
||||
class ClientFatalException : public utils::BasicException {
|
||||
public:
|
||||
using utils::BasicException::BasicException;
|
||||
};
|
||||
|
||||
// Internal exception used whenever a communication error occurs. You should
|
||||
// only handle the `ClientFatalException`.
|
||||
class ServerCommunicationException : public ClientFatalException {
|
||||
public:
|
||||
ServerCommunicationException()
|
||||
: ClientFatalException("Couldn't communicate with the server!") {}
|
||||
};
|
||||
|
||||
// Internal exception used whenever a malformed data error occurs. You should
|
||||
// only handle the `ClientFatalException`.
|
||||
class ServerMalformedDataException : public ClientFatalException {
|
||||
public:
|
||||
ServerMalformedDataException()
|
||||
: ClientFatalException("The server sent malformed data!") {}
|
||||
};
|
||||
|
||||
/// Structure that is used to return results from an executed query.
|
||||
struct QueryData {
|
||||
std::vector<std::string> fields;
|
||||
std::vector<std::vector<Value>> records;
|
||||
std::map<std::string, Value> metadata;
|
||||
};
|
||||
|
||||
/// Bolt client.
|
||||
/// It has methods used to connect to the server and execute queries against the
|
||||
/// server. It supports both SSL and plaintext connections.
|
||||
class Client final {
|
||||
public:
|
||||
explicit Client(communication::ClientContext *context) : client_(context) {}
|
||||
@ -42,60 +65,74 @@ class Client final {
|
||||
Client &operator=(const Client &) = delete;
|
||||
Client &operator=(Client &&) = delete;
|
||||
|
||||
bool Connect(const io::network::Endpoint &endpoint,
|
||||
/// Method used to connect to the server. Before executing queries this method
|
||||
/// should be called to set-up the connection to the server. After the
|
||||
/// connection is set-up, multiple queries may be executed through a single
|
||||
/// established connection.
|
||||
/// @throws ClientFatalException when we couldn't connect to the server
|
||||
void Connect(const io::network::Endpoint &endpoint,
|
||||
const std::string &username, const std::string &password,
|
||||
const std::string &client_name = "memgraph-bolt/0.0.1") {
|
||||
const std::string &client_name = "memgraph-bolt") {
|
||||
if (!client_.Connect(endpoint)) {
|
||||
LOG(ERROR) << "Couldn't connect to " << endpoint;
|
||||
return false;
|
||||
throw ClientFatalException("Couldn't connect to {}!", endpoint);
|
||||
}
|
||||
|
||||
if (!client_.Write(kPreamble, sizeof(kPreamble), true)) {
|
||||
LOG(ERROR) << "Couldn't send preamble!";
|
||||
return false;
|
||||
DLOG(ERROR) << "Couldn't send preamble!";
|
||||
throw ServerCommunicationException();
|
||||
}
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
if (!client_.Write(kProtocol, sizeof(kProtocol), i != 3)) {
|
||||
LOG(ERROR) << "Couldn't send protocol version!";
|
||||
return false;
|
||||
DLOG(ERROR) << "Couldn't send protocol version!";
|
||||
throw ServerCommunicationException();
|
||||
}
|
||||
}
|
||||
|
||||
if (!client_.Read(sizeof(kProtocol))) {
|
||||
LOG(ERROR) << "Couldn't get negotiated protocol version!";
|
||||
return false;
|
||||
DLOG(ERROR) << "Couldn't get negotiated protocol version!";
|
||||
throw ServerCommunicationException();
|
||||
}
|
||||
if (memcmp(kProtocol, client_.GetData(), sizeof(kProtocol)) != 0) {
|
||||
LOG(ERROR) << "Server negotiated unsupported protocol version!";
|
||||
return false;
|
||||
DLOG(ERROR) << "Server negotiated unsupported protocol version!";
|
||||
throw ClientFatalException(
|
||||
"The server negotiated an usupported protocol version!");
|
||||
}
|
||||
client_.ShiftData(sizeof(kProtocol));
|
||||
|
||||
if (!encoder_.MessageInit(client_name, {{"scheme", "basic"},
|
||||
{"principal", username},
|
||||
{"credentials", password}})) {
|
||||
LOG(ERROR) << "Couldn't send init message!";
|
||||
return false;
|
||||
DLOG(ERROR) << "Couldn't send init message!";
|
||||
throw ServerCommunicationException();
|
||||
}
|
||||
|
||||
Signature signature;
|
||||
Value metadata;
|
||||
if (!ReadMessage(&signature, &metadata)) {
|
||||
LOG(ERROR) << "Couldn't read init message response!";
|
||||
return false;
|
||||
DLOG(ERROR) << "Couldn't read init message response!";
|
||||
throw ServerCommunicationException();
|
||||
}
|
||||
if (signature != Signature::Success) {
|
||||
LOG(ERROR) << "Handshake failed!";
|
||||
return false;
|
||||
DLOG(ERROR) << "Handshake failed!";
|
||||
throw ClientFatalException("Handshake with the server failed!");
|
||||
}
|
||||
|
||||
DLOG(INFO) << "Metadata of init message response: " << metadata;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Function used to execute queries against the server. Before you can
|
||||
/// execute queries you must connect the client to the server.
|
||||
/// @throws ClientQueryException when there is some transient error while
|
||||
/// executing the query (eg. mistyped query,
|
||||
/// etc.)
|
||||
/// @throws ClientFatalException when we couldn't communicate with the server
|
||||
QueryData Execute(const std::string &query,
|
||||
const std::map<std::string, Value> ¶meters) {
|
||||
if (!client_.IsConnected()) {
|
||||
throw ClientFatalException(
|
||||
"You must first connect to the server before using the client!");
|
||||
}
|
||||
|
||||
DLOG(INFO) << "Sending run message with statement: '" << query
|
||||
<< "'; parameters: " << parameters;
|
||||
|
||||
@ -106,10 +143,10 @@ class Client final {
|
||||
Signature signature;
|
||||
Value fields;
|
||||
if (!ReadMessage(&signature, &fields)) {
|
||||
throw ClientFatalException();
|
||||
throw ServerCommunicationException();
|
||||
}
|
||||
if (fields.type() != Value::Type::Map) {
|
||||
throw ClientFatalException();
|
||||
throw ServerMalformedDataException();
|
||||
}
|
||||
|
||||
if (signature == Signature::Failure) {
|
||||
@ -121,7 +158,7 @@ class Client final {
|
||||
}
|
||||
throw ClientQueryException();
|
||||
} else if (signature != Signature::Success) {
|
||||
throw ClientFatalException();
|
||||
throw ServerMalformedDataException();
|
||||
}
|
||||
|
||||
DLOG(INFO) << "Reading pull_all message response";
|
||||
@ -130,26 +167,26 @@ class Client final {
|
||||
std::vector<std::vector<Value>> records;
|
||||
while (true) {
|
||||
if (!GetMessage()) {
|
||||
throw ClientFatalException();
|
||||
throw ServerCommunicationException();
|
||||
}
|
||||
if (!decoder_.ReadMessageHeader(&signature, &marker)) {
|
||||
throw ClientFatalException();
|
||||
throw ServerCommunicationException();
|
||||
}
|
||||
if (signature == Signature::Record) {
|
||||
Value record;
|
||||
if (!decoder_.ReadValue(&record, Value::Type::List)) {
|
||||
throw ClientFatalException();
|
||||
throw ServerCommunicationException();
|
||||
}
|
||||
records.emplace_back(std::move(record.ValueList()));
|
||||
} else if (signature == Signature::Success) {
|
||||
if (!decoder_.ReadValue(&metadata)) {
|
||||
throw ClientFatalException();
|
||||
throw ServerCommunicationException();
|
||||
}
|
||||
break;
|
||||
} else if (signature == Signature::Failure) {
|
||||
Value data;
|
||||
if (!decoder_.ReadValue(&data)) {
|
||||
throw ClientFatalException();
|
||||
throw ServerCommunicationException();
|
||||
}
|
||||
HandleFailure();
|
||||
auto &tmp = data.ValueMap();
|
||||
@ -159,28 +196,28 @@ class Client final {
|
||||
}
|
||||
throw ClientQueryException();
|
||||
} else {
|
||||
throw ClientFatalException();
|
||||
throw ServerMalformedDataException();
|
||||
}
|
||||
}
|
||||
|
||||
if (metadata.type() != Value::Type::Map) {
|
||||
throw ClientFatalException();
|
||||
throw ServerMalformedDataException();
|
||||
}
|
||||
|
||||
QueryData ret{{}, std::move(records), std::move(metadata.ValueMap())};
|
||||
|
||||
auto &header = fields.ValueMap();
|
||||
if (header.find("fields") == header.end()) {
|
||||
throw ClientFatalException();
|
||||
throw ServerMalformedDataException();
|
||||
}
|
||||
if (header["fields"].type() != Value::Type::List) {
|
||||
throw ClientFatalException();
|
||||
throw ServerMalformedDataException();
|
||||
}
|
||||
auto &field_vector = header["fields"].ValueList();
|
||||
|
||||
for (auto &field_item : field_vector) {
|
||||
if (field_item.type() != Value::Type::String) {
|
||||
throw ClientFatalException();
|
||||
throw ServerMalformedDataException();
|
||||
}
|
||||
ret.fields.emplace_back(std::move(field_item.ValueString()));
|
||||
}
|
||||
@ -188,6 +225,7 @@ class Client final {
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Close the active client connection.
|
||||
void Close() { client_.Close(); };
|
||||
|
||||
private:
|
||||
@ -227,18 +265,18 @@ class Client final {
|
||||
|
||||
void HandleFailure() {
|
||||
if (!encoder_.MessageAckFailure()) {
|
||||
throw ClientFatalException();
|
||||
throw ServerCommunicationException();
|
||||
}
|
||||
while (true) {
|
||||
Signature signature;
|
||||
Value data;
|
||||
if (!ReadMessage(&signature, &data)) {
|
||||
throw ClientFatalException();
|
||||
throw ServerCommunicationException();
|
||||
}
|
||||
if (signature == Signature::Success) {
|
||||
break;
|
||||
} else if (signature != Signature::Ignored) {
|
||||
throw ClientFatalException();
|
||||
throw ServerMalformedDataException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ bool Client::Connect(const io::network::Endpoint &endpoint) {
|
||||
// Create a new SSL object that will be used for SSL communication.
|
||||
ssl_ = SSL_new(context_->context());
|
||||
if (ssl_ == nullptr) {
|
||||
LOG(ERROR) << "Couldn't create client SSL object!";
|
||||
DLOG(ERROR) << "Couldn't create client SSL object!";
|
||||
socket_.Close();
|
||||
return false;
|
||||
}
|
||||
@ -43,7 +43,7 @@ bool Client::Connect(const io::network::Endpoint &endpoint) {
|
||||
// handle that in our socket destructor).
|
||||
bio_ = BIO_new_socket(socket_.fd(), BIO_NOCLOSE);
|
||||
if (bio_ == nullptr) {
|
||||
LOG(ERROR) << "Couldn't create client BIO object!";
|
||||
DLOG(ERROR) << "Couldn't create client BIO object!";
|
||||
socket_.Close();
|
||||
return false;
|
||||
}
|
||||
@ -59,7 +59,7 @@ bool Client::Connect(const io::network::Endpoint &endpoint) {
|
||||
// Perform the TLS handshake.
|
||||
auto ret = SSL_connect(ssl_);
|
||||
if (ret != 1) {
|
||||
LOG(WARNING) << "Couldn't connect to SSL server: " << SslGetLastError();
|
||||
DLOG(WARNING) << "Couldn't connect to SSL server: " << SslGetLastError();
|
||||
socket_.Close();
|
||||
return false;
|
||||
}
|
||||
@ -70,6 +70,8 @@ bool Client::Connect(const io::network::Endpoint &endpoint) {
|
||||
|
||||
bool Client::ErrorStatus() { return socket_.ErrorStatus(); }
|
||||
|
||||
bool Client::IsConnected() { return socket_.IsOpen(); }
|
||||
|
||||
void Client::Shutdown() { socket_.Shutdown(); }
|
||||
|
||||
void Client::Close() {
|
||||
@ -111,7 +113,7 @@ bool Client::Read(size_t len) {
|
||||
continue;
|
||||
} else {
|
||||
// This is a fatal error.
|
||||
LOG(ERROR) << "Received an unexpected SSL error: " << err;
|
||||
DLOG(ERROR) << "Received an unexpected SSL error: " << err;
|
||||
return false;
|
||||
}
|
||||
} else if (got == 0) {
|
||||
|
@ -42,6 +42,11 @@ class Client final {
|
||||
*/
|
||||
bool ErrorStatus();
|
||||
|
||||
/**
|
||||
* This function returns `true` if the socket is connected to a remote host.
|
||||
*/
|
||||
bool IsConnected();
|
||||
|
||||
/**
|
||||
* This function shuts down the socket.
|
||||
*/
|
||||
|
@ -27,10 +27,7 @@ int main(int argc, char **argv) {
|
||||
communication::ClientContext context(FLAGS_use_ssl);
|
||||
communication::bolt::Client client(&context);
|
||||
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":"
|
||||
<< FLAGS_port;
|
||||
}
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
|
||||
try {
|
||||
auto ret = client.Execute("SHOW PRIVILEGES FOR user", {});
|
||||
|
@ -32,10 +32,7 @@ int main(int argc, char **argv) {
|
||||
communication::ClientContext context(FLAGS_use_ssl);
|
||||
communication::bolt::Client client(&context);
|
||||
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":"
|
||||
<< FLAGS_port;
|
||||
}
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
std::string query(argv[i]);
|
||||
|
@ -28,10 +28,7 @@ int main(int argc, char **argv) {
|
||||
communication::ClientContext context(FLAGS_use_ssl);
|
||||
communication::bolt::Client client(&context);
|
||||
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":"
|
||||
<< FLAGS_port;
|
||||
}
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
|
||||
client.Execute("UNWIND range(0, 10000) AS x CREATE ()", {});
|
||||
|
||||
|
@ -54,10 +54,7 @@ int main(int argc, char **argv) {
|
||||
communication::ClientContext context(FLAGS_use_ssl);
|
||||
communication::bolt::Client client(&context);
|
||||
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":"
|
||||
<< FLAGS_port;
|
||||
}
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
|
||||
if (FLAGS_step == "start") {
|
||||
ExecuteQuery(client,
|
||||
|
@ -20,9 +20,7 @@ using namespace communication::bolt;
|
||||
class BoltClient : public ::testing::Test {
|
||||
protected:
|
||||
virtual void SetUp() {
|
||||
if (!client_.Connect(endpoint_, FLAGS_username, FLAGS_password)) {
|
||||
throw utils::BasicException("Couldn't connect to database!");
|
||||
}
|
||||
client_.Connect(endpoint_, FLAGS_username, FLAGS_password);
|
||||
}
|
||||
|
||||
virtual void TearDown() {}
|
||||
|
@ -109,9 +109,7 @@ int main(int argc, char **argv) {
|
||||
Endpoint endpoint(FLAGS_address, FLAGS_port);
|
||||
ClientContext context(FLAGS_use_ssl);
|
||||
Client client(&context);
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
LOG(FATAL) << "Couldn't connect to " << endpoint;
|
||||
}
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
|
||||
std::vector<std::unique_ptr<TestClient>> clients;
|
||||
for (auto i = 0; i < FLAGS_num_workers; ++i) {
|
||||
|
@ -340,9 +340,7 @@ int main(int argc, char **argv) {
|
||||
Endpoint endpoint(FLAGS_address, FLAGS_port);
|
||||
ClientContext context(FLAGS_use_ssl);
|
||||
Client client(&context);
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
LOG(FATAL) << "Couldn't connect to " << endpoint;
|
||||
}
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
|
||||
num_pos.store(NumNodesWithLabel(client, "Pos"));
|
||||
num_cards.store(NumNodesWithLabel(client, "Card"));
|
||||
|
@ -43,9 +43,7 @@ class TestClient {
|
||||
public:
|
||||
TestClient() {
|
||||
Endpoint endpoint(FLAGS_address, FLAGS_port);
|
||||
if (!client_.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
LOG(FATAL) << "Couldn't connect to " << endpoint;
|
||||
}
|
||||
client_.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
}
|
||||
|
||||
virtual ~TestClient() {}
|
||||
|
@ -281,9 +281,7 @@ int main(int argc, char **argv) {
|
||||
Endpoint endpoint(io::network::ResolveHostname(FLAGS_address), FLAGS_port);
|
||||
ClientContext context(FLAGS_use_ssl);
|
||||
Client client(&context);
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
LOG(FATAL) << "Couldn't connect to " << endpoint;
|
||||
}
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
return IndependentSet(client, INDEPENDENT_LABEL);
|
||||
}();
|
||||
|
||||
|
@ -61,9 +61,7 @@ void ExecuteQueries(const std::vector<std::string> &queries,
|
||||
Endpoint endpoint(FLAGS_address, FLAGS_port);
|
||||
ClientContext context(FLAGS_use_ssl);
|
||||
Client client(&context);
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
LOG(FATAL) << "Couldn't connect to " << endpoint;
|
||||
}
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
|
||||
std::string str;
|
||||
while (true) {
|
||||
|
@ -25,7 +25,7 @@ int main(int argc, char **argv) {
|
||||
communication::ClientContext context(FLAGS_use_ssl);
|
||||
communication::bolt::Client client(&context);
|
||||
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) return 1;
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
|
||||
std::cout << "Memgraph bolt client is connected and running." << std::endl;
|
||||
|
||||
|
@ -59,10 +59,7 @@ class GraphSession {
|
||||
|
||||
EndpointT endpoint(FLAGS_address, FLAGS_port);
|
||||
client_ = std::make_unique<ClientT>(&context_);
|
||||
|
||||
if (!client_->Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
throw utils::BasicException("Couldn't connect to server!");
|
||||
}
|
||||
client_->Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -381,9 +378,7 @@ int main(int argc, char **argv) {
|
||||
EndpointT endpoint(FLAGS_address, FLAGS_port);
|
||||
ClientContextT context(FLAGS_use_ssl);
|
||||
ClientT client(&context);
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
throw utils::BasicException("Couldn't connect to server!");
|
||||
}
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
|
||||
// cleanup and create indexes
|
||||
client.Execute("MATCH (n) DETACH DELETE n", {});
|
||||
|
@ -629,12 +629,16 @@ int main(int argc, char **argv) {
|
||||
communication::ClientContext context(FLAGS_use_ssl);
|
||||
communication::bolt::Client client(&context);
|
||||
|
||||
if (!client.Connect(endpoint, FLAGS_username, password)) {
|
||||
// Error message is logged in client.Connect method
|
||||
std::string bolt_client_version =
|
||||
fmt::format("mg_client/{}", gflags::VersionString());
|
||||
try {
|
||||
client.Connect(endpoint, FLAGS_username, password, bolt_client_version);
|
||||
} catch (const communication::bolt::ClientFatalException &e) {
|
||||
EchoFailure("Connection failure", e.what());
|
||||
return 1;
|
||||
}
|
||||
|
||||
EchoInfo("mg-client");
|
||||
EchoInfo(fmt::format("mg_client {}", gflags::VersionString()));
|
||||
EchoInfo("Type :help for shell usage");
|
||||
EchoInfo("Quit the shell by typing Ctrl-D(eof) or :quit");
|
||||
EchoInfo(fmt::format("Connected to 'memgraph://{}'", endpoint));
|
||||
@ -682,9 +686,13 @@ int main(int argc, char **argv) {
|
||||
client.Close();
|
||||
while (num_retries > 0) {
|
||||
--num_retries;
|
||||
if (client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
try {
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password,
|
||||
bolt_client_version);
|
||||
is_connected = true;
|
||||
break;
|
||||
} catch (const communication::bolt::ClientFatalException &e) {
|
||||
EchoFailure("Connection failure", e.what());
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::seconds(1));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user