Polish Bolt client and mg_client

Reviewers: mculinovic, teon.banek

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1681
This commit is contained in:
Matej Ferencevic 2018-10-19 12:28:38 +02:00
parent 078ab75145
commit 0013cbb173
17 changed files with 119 additions and 94 deletions

View File

@ -4,6 +4,7 @@
- build_debug/memgraph_distributed
- build_release/memgraph
- build_release/memgraph_distributed
- build_release/tools/src/mg_client
- build_release/tools/src/mg_import_csv
- build_release/tools/src/mg_statsd
- config

View File

@ -13,26 +13,49 @@
namespace communication::bolt {
class ClientFatalException : public utils::BasicException {
public:
using utils::BasicException::BasicException;
ClientFatalException()
: utils::BasicException(
"Something went wrong while communicating with the server!") {}
};
/// This exception is thrown whenever an error occurs during query execution
/// that isn't fatal (eg. mistyped query or some transient error occurred).
/// It should be handled by everyone who uses the client.
class ClientQueryException : public utils::BasicException {
public:
using utils::BasicException::BasicException;
ClientQueryException() : utils::BasicException("Couldn't execute query!") {}
};
/// This exception is thrown whenever a fatal error occurs during query
/// execution and/or connecting to the server.
/// It should be handled by everyone who uses the client.
class ClientFatalException : public utils::BasicException {
public:
using utils::BasicException::BasicException;
};
// Internal exception used whenever a communication error occurs. You should
// only handle the `ClientFatalException`.
class ServerCommunicationException : public ClientFatalException {
public:
ServerCommunicationException()
: ClientFatalException("Couldn't communicate with the server!") {}
};
// Internal exception used whenever a malformed data error occurs. You should
// only handle the `ClientFatalException`.
class ServerMalformedDataException : public ClientFatalException {
public:
ServerMalformedDataException()
: ClientFatalException("The server sent malformed data!") {}
};
/// Structure that is used to return results from an executed query.
struct QueryData {
std::vector<std::string> fields;
std::vector<std::vector<Value>> records;
std::map<std::string, Value> metadata;
};
/// Bolt client.
/// It has methods used to connect to the server and execute queries against the
/// server. It supports both SSL and plaintext connections.
class Client final {
public:
explicit Client(communication::ClientContext *context) : client_(context) {}
@ -42,60 +65,74 @@ class Client final {
Client &operator=(const Client &) = delete;
Client &operator=(Client &&) = delete;
bool Connect(const io::network::Endpoint &endpoint,
/// Method used to connect to the server. Before executing queries this method
/// should be called to set-up the connection to the server. After the
/// connection is set-up, multiple queries may be executed through a single
/// established connection.
/// @throws ClientFatalException when we couldn't connect to the server
void Connect(const io::network::Endpoint &endpoint,
const std::string &username, const std::string &password,
const std::string &client_name = "memgraph-bolt/0.0.1") {
const std::string &client_name = "memgraph-bolt") {
if (!client_.Connect(endpoint)) {
LOG(ERROR) << "Couldn't connect to " << endpoint;
return false;
throw ClientFatalException("Couldn't connect to {}!", endpoint);
}
if (!client_.Write(kPreamble, sizeof(kPreamble), true)) {
LOG(ERROR) << "Couldn't send preamble!";
return false;
DLOG(ERROR) << "Couldn't send preamble!";
throw ServerCommunicationException();
}
for (int i = 0; i < 4; ++i) {
if (!client_.Write(kProtocol, sizeof(kProtocol), i != 3)) {
LOG(ERROR) << "Couldn't send protocol version!";
return false;
DLOG(ERROR) << "Couldn't send protocol version!";
throw ServerCommunicationException();
}
}
if (!client_.Read(sizeof(kProtocol))) {
LOG(ERROR) << "Couldn't get negotiated protocol version!";
return false;
DLOG(ERROR) << "Couldn't get negotiated protocol version!";
throw ServerCommunicationException();
}
if (memcmp(kProtocol, client_.GetData(), sizeof(kProtocol)) != 0) {
LOG(ERROR) << "Server negotiated unsupported protocol version!";
return false;
DLOG(ERROR) << "Server negotiated unsupported protocol version!";
throw ClientFatalException(
"The server negotiated an usupported protocol version!");
}
client_.ShiftData(sizeof(kProtocol));
if (!encoder_.MessageInit(client_name, {{"scheme", "basic"},
{"principal", username},
{"credentials", password}})) {
LOG(ERROR) << "Couldn't send init message!";
return false;
DLOG(ERROR) << "Couldn't send init message!";
throw ServerCommunicationException();
}
Signature signature;
Value metadata;
if (!ReadMessage(&signature, &metadata)) {
LOG(ERROR) << "Couldn't read init message response!";
return false;
DLOG(ERROR) << "Couldn't read init message response!";
throw ServerCommunicationException();
}
if (signature != Signature::Success) {
LOG(ERROR) << "Handshake failed!";
return false;
DLOG(ERROR) << "Handshake failed!";
throw ClientFatalException("Handshake with the server failed!");
}
DLOG(INFO) << "Metadata of init message response: " << metadata;
return true;
}
/// Function used to execute queries against the server. Before you can
/// execute queries you must connect the client to the server.
/// @throws ClientQueryException when there is some transient error while
/// executing the query (eg. mistyped query,
/// etc.)
/// @throws ClientFatalException when we couldn't communicate with the server
QueryData Execute(const std::string &query,
const std::map<std::string, Value> &parameters) {
if (!client_.IsConnected()) {
throw ClientFatalException(
"You must first connect to the server before using the client!");
}
DLOG(INFO) << "Sending run message with statement: '" << query
<< "'; parameters: " << parameters;
@ -106,10 +143,10 @@ class Client final {
Signature signature;
Value fields;
if (!ReadMessage(&signature, &fields)) {
throw ClientFatalException();
throw ServerCommunicationException();
}
if (fields.type() != Value::Type::Map) {
throw ClientFatalException();
throw ServerMalformedDataException();
}
if (signature == Signature::Failure) {
@ -121,7 +158,7 @@ class Client final {
}
throw ClientQueryException();
} else if (signature != Signature::Success) {
throw ClientFatalException();
throw ServerMalformedDataException();
}
DLOG(INFO) << "Reading pull_all message response";
@ -130,26 +167,26 @@ class Client final {
std::vector<std::vector<Value>> records;
while (true) {
if (!GetMessage()) {
throw ClientFatalException();
throw ServerCommunicationException();
}
if (!decoder_.ReadMessageHeader(&signature, &marker)) {
throw ClientFatalException();
throw ServerCommunicationException();
}
if (signature == Signature::Record) {
Value record;
if (!decoder_.ReadValue(&record, Value::Type::List)) {
throw ClientFatalException();
throw ServerCommunicationException();
}
records.emplace_back(std::move(record.ValueList()));
} else if (signature == Signature::Success) {
if (!decoder_.ReadValue(&metadata)) {
throw ClientFatalException();
throw ServerCommunicationException();
}
break;
} else if (signature == Signature::Failure) {
Value data;
if (!decoder_.ReadValue(&data)) {
throw ClientFatalException();
throw ServerCommunicationException();
}
HandleFailure();
auto &tmp = data.ValueMap();
@ -159,28 +196,28 @@ class Client final {
}
throw ClientQueryException();
} else {
throw ClientFatalException();
throw ServerMalformedDataException();
}
}
if (metadata.type() != Value::Type::Map) {
throw ClientFatalException();
throw ServerMalformedDataException();
}
QueryData ret{{}, std::move(records), std::move(metadata.ValueMap())};
auto &header = fields.ValueMap();
if (header.find("fields") == header.end()) {
throw ClientFatalException();
throw ServerMalformedDataException();
}
if (header["fields"].type() != Value::Type::List) {
throw ClientFatalException();
throw ServerMalformedDataException();
}
auto &field_vector = header["fields"].ValueList();
for (auto &field_item : field_vector) {
if (field_item.type() != Value::Type::String) {
throw ClientFatalException();
throw ServerMalformedDataException();
}
ret.fields.emplace_back(std::move(field_item.ValueString()));
}
@ -188,6 +225,7 @@ class Client final {
return ret;
}
/// Close the active client connection.
void Close() { client_.Close(); };
private:
@ -227,18 +265,18 @@ class Client final {
void HandleFailure() {
if (!encoder_.MessageAckFailure()) {
throw ClientFatalException();
throw ServerCommunicationException();
}
while (true) {
Signature signature;
Value data;
if (!ReadMessage(&signature, &data)) {
throw ClientFatalException();
throw ServerCommunicationException();
}
if (signature == Signature::Success) {
break;
} else if (signature != Signature::Ignored) {
throw ClientFatalException();
throw ServerMalformedDataException();
}
}
}

View File

@ -32,7 +32,7 @@ bool Client::Connect(const io::network::Endpoint &endpoint) {
// Create a new SSL object that will be used for SSL communication.
ssl_ = SSL_new(context_->context());
if (ssl_ == nullptr) {
LOG(ERROR) << "Couldn't create client SSL object!";
DLOG(ERROR) << "Couldn't create client SSL object!";
socket_.Close();
return false;
}
@ -43,7 +43,7 @@ bool Client::Connect(const io::network::Endpoint &endpoint) {
// handle that in our socket destructor).
bio_ = BIO_new_socket(socket_.fd(), BIO_NOCLOSE);
if (bio_ == nullptr) {
LOG(ERROR) << "Couldn't create client BIO object!";
DLOG(ERROR) << "Couldn't create client BIO object!";
socket_.Close();
return false;
}
@ -59,7 +59,7 @@ bool Client::Connect(const io::network::Endpoint &endpoint) {
// Perform the TLS handshake.
auto ret = SSL_connect(ssl_);
if (ret != 1) {
LOG(WARNING) << "Couldn't connect to SSL server: " << SslGetLastError();
DLOG(WARNING) << "Couldn't connect to SSL server: " << SslGetLastError();
socket_.Close();
return false;
}
@ -70,6 +70,8 @@ bool Client::Connect(const io::network::Endpoint &endpoint) {
bool Client::ErrorStatus() { return socket_.ErrorStatus(); }
bool Client::IsConnected() { return socket_.IsOpen(); }
void Client::Shutdown() { socket_.Shutdown(); }
void Client::Close() {
@ -111,7 +113,7 @@ bool Client::Read(size_t len) {
continue;
} else {
// This is a fatal error.
LOG(ERROR) << "Received an unexpected SSL error: " << err;
DLOG(ERROR) << "Received an unexpected SSL error: " << err;
return false;
}
} else if (got == 0) {

View File

@ -42,6 +42,11 @@ class Client final {
*/
bool ErrorStatus();
/**
* This function returns `true` if the socket is connected to a remote host.
*/
bool IsConnected();
/**
* This function shuts down the socket.
*/

View File

@ -27,10 +27,7 @@ int main(int argc, char **argv) {
communication::ClientContext context(FLAGS_use_ssl);
communication::bolt::Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":"
<< FLAGS_port;
}
client.Connect(endpoint, FLAGS_username, FLAGS_password);
try {
auto ret = client.Execute("SHOW PRIVILEGES FOR user", {});

View File

@ -32,10 +32,7 @@ int main(int argc, char **argv) {
communication::ClientContext context(FLAGS_use_ssl);
communication::bolt::Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":"
<< FLAGS_port;
}
client.Connect(endpoint, FLAGS_username, FLAGS_password);
for (int i = 1; i < argc; ++i) {
std::string query(argv[i]);

View File

@ -28,10 +28,7 @@ int main(int argc, char **argv) {
communication::ClientContext context(FLAGS_use_ssl);
communication::bolt::Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":"
<< FLAGS_port;
}
client.Connect(endpoint, FLAGS_username, FLAGS_password);
client.Execute("UNWIND range(0, 10000) AS x CREATE ()", {});

View File

@ -54,10 +54,7 @@ int main(int argc, char **argv) {
communication::ClientContext context(FLAGS_use_ssl);
communication::bolt::Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":"
<< FLAGS_port;
}
client.Connect(endpoint, FLAGS_username, FLAGS_password);
if (FLAGS_step == "start") {
ExecuteQuery(client,

View File

@ -20,9 +20,7 @@ using namespace communication::bolt;
class BoltClient : public ::testing::Test {
protected:
virtual void SetUp() {
if (!client_.Connect(endpoint_, FLAGS_username, FLAGS_password)) {
throw utils::BasicException("Couldn't connect to database!");
}
client_.Connect(endpoint_, FLAGS_username, FLAGS_password);
}
virtual void TearDown() {}

View File

@ -109,9 +109,7 @@ int main(int argc, char **argv) {
Endpoint endpoint(FLAGS_address, FLAGS_port);
ClientContext context(FLAGS_use_ssl);
Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to " << endpoint;
}
client.Connect(endpoint, FLAGS_username, FLAGS_password);
std::vector<std::unique_ptr<TestClient>> clients;
for (auto i = 0; i < FLAGS_num_workers; ++i) {

View File

@ -340,9 +340,7 @@ int main(int argc, char **argv) {
Endpoint endpoint(FLAGS_address, FLAGS_port);
ClientContext context(FLAGS_use_ssl);
Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to " << endpoint;
}
client.Connect(endpoint, FLAGS_username, FLAGS_password);
num_pos.store(NumNodesWithLabel(client, "Pos"));
num_cards.store(NumNodesWithLabel(client, "Card"));

View File

@ -43,9 +43,7 @@ class TestClient {
public:
TestClient() {
Endpoint endpoint(FLAGS_address, FLAGS_port);
if (!client_.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to " << endpoint;
}
client_.Connect(endpoint, FLAGS_username, FLAGS_password);
}
virtual ~TestClient() {}

View File

@ -281,9 +281,7 @@ int main(int argc, char **argv) {
Endpoint endpoint(io::network::ResolveHostname(FLAGS_address), FLAGS_port);
ClientContext context(FLAGS_use_ssl);
Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to " << endpoint;
}
client.Connect(endpoint, FLAGS_username, FLAGS_password);
return IndependentSet(client, INDEPENDENT_LABEL);
}();

View File

@ -61,9 +61,7 @@ void ExecuteQueries(const std::vector<std::string> &queries,
Endpoint endpoint(FLAGS_address, FLAGS_port);
ClientContext context(FLAGS_use_ssl);
Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to " << endpoint;
}
client.Connect(endpoint, FLAGS_username, FLAGS_password);
std::string str;
while (true) {

View File

@ -25,7 +25,7 @@ int main(int argc, char **argv) {
communication::ClientContext context(FLAGS_use_ssl);
communication::bolt::Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) return 1;
client.Connect(endpoint, FLAGS_username, FLAGS_password);
std::cout << "Memgraph bolt client is connected and running." << std::endl;

View File

@ -59,10 +59,7 @@ class GraphSession {
EndpointT endpoint(FLAGS_address, FLAGS_port);
client_ = std::make_unique<ClientT>(&context_);
if (!client_->Connect(endpoint, FLAGS_username, FLAGS_password)) {
throw utils::BasicException("Couldn't connect to server!");
}
client_->Connect(endpoint, FLAGS_username, FLAGS_password);
}
private:
@ -381,9 +378,7 @@ int main(int argc, char **argv) {
EndpointT endpoint(FLAGS_address, FLAGS_port);
ClientContextT context(FLAGS_use_ssl);
ClientT client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
throw utils::BasicException("Couldn't connect to server!");
}
client.Connect(endpoint, FLAGS_username, FLAGS_password);
// cleanup and create indexes
client.Execute("MATCH (n) DETACH DELETE n", {});

View File

@ -629,12 +629,16 @@ int main(int argc, char **argv) {
communication::ClientContext context(FLAGS_use_ssl);
communication::bolt::Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, password)) {
// Error message is logged in client.Connect method
std::string bolt_client_version =
fmt::format("mg_client/{}", gflags::VersionString());
try {
client.Connect(endpoint, FLAGS_username, password, bolt_client_version);
} catch (const communication::bolt::ClientFatalException &e) {
EchoFailure("Connection failure", e.what());
return 1;
}
EchoInfo("mg-client");
EchoInfo(fmt::format("mg_client {}", gflags::VersionString()));
EchoInfo("Type :help for shell usage");
EchoInfo("Quit the shell by typing Ctrl-D(eof) or :quit");
EchoInfo(fmt::format("Connected to 'memgraph://{}'", endpoint));
@ -682,9 +686,13 @@ int main(int argc, char **argv) {
client.Close();
while (num_retries > 0) {
--num_retries;
if (client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
try {
client.Connect(endpoint, FLAGS_username, FLAGS_password,
bolt_client_version);
is_connected = true;
break;
} catch (const communication::bolt::ClientFatalException &e) {
EchoFailure("Connection failure", e.what());
}
std::this_thread::sleep_for(std::chrono::seconds(1));
}