memgraph/src/communication/bolt/client.hpp
Matej Ferencevic c507e74384 First version of bolt cpp client.
Reviewers: buda, mislav.bradac

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D689
2017-08-24 17:44:44 +02:00

280 lines
8.0 KiB
C++

#pragma once
#include <glog/logging.h>
#include "communication/bolt/v1/decoder/buffer.hpp"
#include "communication/bolt/v1/decoder/chunked_decoder_buffer.hpp"
#include "communication/bolt/v1/decoder/decoder.hpp"
#include "communication/bolt/v1/encoder/chunked_encoder_buffer.hpp"
#include "communication/bolt/v1/encoder/client_encoder.hpp"
#include "query/typed_value.hpp"
#include "utils/exceptions.hpp"
namespace communication::bolt {
class ClientException : public utils::BasicException {
using utils::BasicException::BasicException;
};
/**
 * Signals an I/O failure on the client's socket (a write or read that did
 * not succeed).
 */
class ClientSocketException : public ClientException {
 public:
  /// Builds the exception with the default socket-failure description.
  ClientSocketException()
      : ClientException("Couldn't write/read data to/from the socket!") {}

  /// Allows a custom message, same as the base class.
  using ClientException::ClientException;
};
/**
 * Signals that a server response did not have the structure the client
 * expected (wrong message signature, wrong value type, missing fields, ...).
 */
class ClientInvalidDataException : public ClientException {
 public:
  /// Builds the exception with the default invalid-data description.
  ClientInvalidDataException()
      : ClientException("The server sent invalid data!") {}

  /// Allows a custom message, same as the base class.
  using ClientException::ClientException;
};
/**
 * Signals that the server reported a failure while executing a query. The
 * client passes the server's error text as the message when it has one.
 */
class ClientQueryException : public ClientException {
 public:
  /// Builds the exception with the generic query-failure description.
  ClientQueryException() : ClientException("Couldn't execute query!") {}

  /// Allows a custom message (e.g. the server's error text).
  using ClientException::ClientException;
};
/**
 * Complete result of one `Client::Execute` call.
 */
struct QueryData {
  // Column names, taken from the "fields" entry of the RUN response.
  std::vector<std::string> fields{};
  // One entry per RECORD message, each holding that record's values.
  std::vector<std::vector<DecodedValue>> records{};
  // Metadata map from the SUCCESS message that ends the PULL_ALL response.
  std::map<std::string, DecodedValue> metadata{};
};
template <typename Socket>
class Client {
public:
Client(Socket &&socket, std::string &username, std::string &password,
std::string client_name = "")
: socket_(std::move(socket)) {
DLOG(INFO) << "Sending handshake";
if (!socket_.Write(kPreamble, sizeof(kPreamble))) {
throw ClientSocketException();
}
for (int i = 0; i < 4; ++i) {
if (!socket_.Write(kProtocol, sizeof(kProtocol))) {
throw ClientSocketException();
}
}
DLOG(INFO) << "Reading handshake response";
if (!GetDataByLen(4)) {
throw ClientSocketException();
}
if (memcmp(kProtocol, buffer_.data(), sizeof(kProtocol)) != 0) {
throw ClientException("Server negotiated unsupported protocol version!");
}
buffer_.Shift(sizeof(kProtocol));
if (client_name == "") {
client_name = "memgraph-bolt/0.0.1";
}
DLOG(INFO) << "Sending init message";
if (!encoder_.MessageInit(client_name, {{"scheme", "basic"},
{"principal", username},
{"secret", password}})) {
throw ClientSocketException();
}
DLOG(INFO) << "Reading init message response";
if (!GetDataByChunk()) {
throw ClientSocketException();
}
Signature signature;
DecodedValue metadata;
if (!ReadMessage(&signature, &metadata)) {
throw ClientException("Couldn't read init message response!");
}
if (signature != Signature::Success) {
throw ClientInvalidDataException();
}
}
QueryData Execute(const std::string &query,
const std::map<std::string, DecodedValue> &parameters) {
DLOG(INFO) << "Sending run message with statement: '" << query
<< "'; parameters: " << parameters;
std::map<std::string, query::TypedValue> params_tv(parameters.begin(),
parameters.end());
encoder_.MessageRun(query, params_tv, false);
encoder_.MessagePullAll();
DLOG(INFO) << "Reading run message response";
if (!GetDataByChunk()) {
throw ClientSocketException();
}
Signature signature;
DecodedValue fields;
if (!ReadMessage(&signature, &fields)) {
throw ClientInvalidDataException();
}
if (fields.type() != DecodedValue::Type::Map) {
throw ClientInvalidDataException();
}
if (signature == Signature::Failure) {
HandleFailure();
auto &tmp = fields.ValueMap();
auto it = tmp.find("message");
if (it != tmp.end()) {
throw ClientQueryException(it->second.ValueString());
}
throw ClientQueryException();
} else if (signature != Signature::Success) {
throw ClientInvalidDataException();
}
DLOG(INFO) << "Reading pull_all message response";
Marker marker;
DecodedValue metadata;
std::vector<std::vector<DecodedValue>> records;
while (true) {
if (!GetDataByChunk()) {
throw ClientInvalidDataException();
}
if (!decoder_.ReadMessageHeader(&signature, &marker)) {
throw ClientInvalidDataException();
}
if (signature == Signature::Record) {
DecodedValue record;
if (!decoder_.ReadValue(&record, DecodedValue::Type::List)) {
throw ClientInvalidDataException();
}
records.push_back(record.ValueList());
} else if (signature == Signature::Success) {
if (!decoder_.ReadValue(&metadata)) {
throw ClientInvalidDataException();
}
break;
} else if (signature == Signature::Failure) {
DecodedValue data;
if (!decoder_.ReadValue(&data)) {
throw ClientInvalidDataException();
}
HandleFailure();
auto &tmp = data.ValueMap();
auto it = tmp.find("message");
if (it != tmp.end()) {
throw ClientQueryException(it->second.ValueString());
}
throw ClientQueryException();
} else {
throw ClientInvalidDataException();
}
}
if (metadata.type() != DecodedValue::Type::Map) {
throw ClientInvalidDataException();
}
QueryData ret{{}, records, metadata.ValueMap()};
auto &header = fields.ValueMap();
if (header.find("fields") == header.end()) {
throw ClientInvalidDataException();
}
if (header["fields"].type() != DecodedValue::Type::List) {
throw ClientInvalidDataException();
}
auto &field_vector = header["fields"].ValueList();
for (auto &field_item : field_vector) {
if (field_item.type() != DecodedValue::Type::String) {
throw ClientInvalidDataException();
}
ret.fields.push_back(field_item.ValueString());
}
return ret;
}
void Close() { socket_.Close(); };
~Client() { Close(); }
private:
bool GetDataByLen(uint64_t len) {
while (buffer_.size() < len) {
auto buff = buffer_.Allocate();
int ret = socket_.Read(buff.data, buff.len);
if (ret == -1) return false;
buffer_.Written(ret);
}
return true;
}
bool GetDataByChunk() {
// If there is more data in the buffer then don't read data.
if (decoder_buffer_.Size() > 0) return true;
ChunkState state;
while ((state = decoder_buffer_.GetChunk()) == ChunkState::Partial) {
auto buff = buffer_.Allocate();
int ret = socket_.Read(buff.data, buff.len);
if (ret == -1) return false;
buffer_.Written(ret);
}
if (state == ChunkState::Whole) {
return true;
}
return false;
}
bool ReadMessage(Signature *signature, DecodedValue *ret) {
Marker marker;
if (!decoder_.ReadMessageHeader(signature, &marker)) {
return false;
}
return ReadMessageData(marker, ret);
}
bool ReadMessageData(Marker marker, DecodedValue *ret) {
if (marker == Marker::TinyStruct) {
*ret = DecodedValue();
return true;
} else if (marker == Marker::TinyStruct1) {
return decoder_.ReadValue(ret);
}
return false;
}
void HandleFailure() {
if (!encoder_.MessageAckFailure()) {
throw ClientSocketException();
}
while (true) {
Signature signature;
DecodedValue data;
if (!GetDataByChunk()) {
throw ClientInvalidDataException();
}
if (!ReadMessage(&signature, &data)) {
throw ClientInvalidDataException();
}
if (signature == Signature::Success) {
break;
} else if (signature != Signature::Ignored) {
throw ClientInvalidDataException();
}
}
}
// socket
Socket socket_;
// decoder objects
Buffer<> buffer_;
ChunkedDecoderBuffer decoder_buffer_{buffer_};
Decoder<ChunkedDecoderBuffer> decoder_{decoder_buffer_};
// encoder objects
ChunkedEncoderBuffer<Socket> encoder_buffer_{socket_};
ClientEncoder<ChunkedEncoderBuffer<Socket>> encoder_{encoder_buffer_};
};
}