Fix handling of the ROUTE Bolt message (#475)

The fields of ROUTE message were not read from the input buffer, thus the
input buffer got corrupted. Sending a new message to the server would result
reading the remaining fields from the buffer, which means reading some values
instead of message signature. Because of this unmet expectation, Memgraph closed
the connection. With this fix, the fields of the ROUTE message are properly
read and ignored.
This commit is contained in:
János Benjamin Antal 2022-08-26 13:19:27 +02:00 committed by GitHub
parent d73d153978
commit 0bc298c3ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 708 additions and 434 deletions

View File

@ -7,6 +7,7 @@ set(communication_src_files
websocket/listener.cpp
websocket/session.cpp
bolt/v1/value.cpp
bolt/client.cpp
buffer.cpp
client.cpp
context.cpp

View File

@ -0,0 +1,262 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "communication/bolt/client.hpp"
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/value.hpp"
#include "utils/logging.hpp"
namespace {
constexpr uint8_t kBoltV43Version[4] = {0x00, 0x00, 0x03, 0x04};
constexpr uint8_t kEmptyBoltVersion[4] = {0x00, 0x00, 0x00, 0x00};
} // namespace
namespace memgraph::communication::bolt {
Client::Client(communication::ClientContext &context) : client_{&context} {}
void Client::Connect(const io::network::Endpoint &endpoint, const std::string &username, const std::string &password,
const std::string &client_name) {
if (!client_.Connect(endpoint)) {
throw ClientFatalException("Couldn't connect to {}!", endpoint);
}
if (!client_.Write(kPreamble, sizeof(kPreamble), true)) {
spdlog::error("Couldn't send preamble!");
throw ServerCommunicationException();
}
if (!client_.Write(kBoltV43Version, sizeof(kBoltV43Version), true)) {
spdlog::error("Couldn't send protocol version!");
throw ServerCommunicationException();
}
for (int i = 0; i < 3; ++i) {
if (!client_.Write(kEmptyBoltVersion, sizeof(kEmptyBoltVersion), i != 2)) {
spdlog::error("Couldn't send protocol version!");
throw ServerCommunicationException();
}
}
if (!client_.Read(sizeof(kBoltV43Version))) {
spdlog::error("Couldn't get negotiated protocol version!");
throw ServerCommunicationException();
}
if (memcmp(kBoltV43Version, client_.GetData(), sizeof(kBoltV43Version)) != 0) {
spdlog::error("Server negotiated unsupported protocol version!");
throw ClientFatalException("The server negotiated an usupported protocol version!");
}
client_.ShiftData(sizeof(kBoltV43Version));
if (!encoder_.MessageInit({{"user_agent", client_name},
{"scheme", "basic"},
{"principal", username},
{"credentials", password},
{"routing", {}}})) {
spdlog::error("Couldn't send init message!");
throw ServerCommunicationException();
}
Signature signature{};
Value metadata;
if (!ReadMessage(signature, metadata)) {
spdlog::error("Couldn't read init message response!");
throw ServerCommunicationException();
}
if (signature != Signature::Success) {
spdlog::error("Handshake failed!");
throw ClientFatalException("Handshake with the server failed!");
}
spdlog::debug("Metadata of init message response: {}", metadata);
}
QueryData Client::Execute(const std::string &query, const std::map<std::string, Value> &parameters) {
if (!client_.IsConnected()) {
throw ClientFatalException("You must first connect to the server before using the client!");
}
spdlog::debug("Sending run message with statement: '{}'; parameters: {}", query, parameters);
// It is super critical from performance point of view to send the pull message right after the run message. Otherwise
// the performance will degrade multiple magnitudes.
encoder_.MessageRun(query, parameters, {});
encoder_.MessagePull({});
spdlog::debug("Reading run message response");
Signature signature{};
Value fields;
if (!ReadMessage(signature, fields)) {
throw ServerCommunicationException();
}
if (fields.type() != Value::Type::Map) {
throw ServerMalformedDataException();
}
if (signature == Signature::Failure) {
HandleFailure<ClientQueryException>(fields.ValueMap());
}
if (signature != Signature::Success) {
throw ServerMalformedDataException();
}
spdlog::debug("Reading pull_all message response");
Marker marker{};
Value metadata;
std::vector<std::vector<Value>> records;
while (true) {
if (!GetMessage()) {
throw ServerCommunicationException();
}
if (!decoder_.ReadMessageHeader(&signature, &marker)) {
throw ServerCommunicationException();
}
if (signature == Signature::Record) {
Value record;
if (!decoder_.ReadValue(&record, Value::Type::List)) {
throw ServerCommunicationException();
}
records.emplace_back(std::move(record.ValueList()));
} else if (signature == Signature::Success) {
if (!decoder_.ReadValue(&metadata)) {
throw ServerCommunicationException();
}
break;
} else if (signature == Signature::Failure) {
Value data;
if (!decoder_.ReadValue(&data)) {
throw ServerCommunicationException();
}
HandleFailure<ClientQueryException>(data.ValueMap());
} else {
throw ServerMalformedDataException();
}
}
if (metadata.type() != Value::Type::Map) {
throw ServerMalformedDataException();
}
QueryData ret{{}, std::move(records), std::move(metadata.ValueMap())};
auto &header = fields.ValueMap();
if (header.find("fields") == header.end()) {
throw ServerMalformedDataException();
}
if (header["fields"].type() != Value::Type::List) {
throw ServerMalformedDataException();
}
auto &field_vector = header["fields"].ValueList();
for (auto &field_item : field_vector) {
if (field_item.type() != Value::Type::String) {
throw ServerMalformedDataException();
}
ret.fields.emplace_back(std::move(field_item.ValueString()));
}
return ret;
}
void Client::Reset() {
if (!client_.IsConnected()) {
throw ClientFatalException("You must first connect to the server before using the client!");
}
spdlog::debug("Sending reset message");
encoder_.MessageReset();
Signature signature{};
Value fields;
// In Execute the pull message is sent right after the run message without reading the answer for the run message.
// That means some of the messages sent might get ignored.
while (true) {
if (!ReadMessage(signature, fields)) {
throw ServerCommunicationException();
}
if (signature == Signature::Success) {
break;
}
if (signature != Signature::Ignored) {
throw ServerMalformedDataException();
}
}
}
std::optional<std::map<std::string, Value>> Client::Route(const std::map<std::string, Value> &routing,
const std::vector<Value> &bookmarks,
const std::optional<std::string> &db) {
if (!client_.IsConnected()) {
throw ClientFatalException("You must first connect to the server before using the client!");
}
spdlog::debug("Sending route message with routing: {}; bookmarks: {}; db: {}", routing, bookmarks,
db.has_value() ? *db : Value());
encoder_.MessageRoute(routing, bookmarks, db);
spdlog::debug("Reading route message response");
Signature signature{};
Value fields;
if (!ReadMessage(signature, fields)) {
throw ServerCommunicationException();
}
if (signature == Signature::Ignored) {
return std::nullopt;
}
if (signature == Signature::Failure) {
HandleFailure(fields.ValueMap());
}
if (signature != Signature::Success) {
throw ServerMalformedDataException{};
}
return fields.ValueMap();
}
void Client::Close() { client_.Close(); };
bool Client::GetMessage() {
client_.ClearData();
while (true) {
if (!client_.Read(kChunkHeaderSize)) return false;
size_t chunk_size = client_.GetData()[0];
chunk_size <<= 8U;
chunk_size += client_.GetData()[1];
if (chunk_size == 0) return true;
if (!client_.Read(chunk_size)) return false;
if (decoder_buffer_.GetChunk() != ChunkState::Whole) return false;
client_.ClearData();
}
return true;
}
bool Client::ReadMessage(Signature &signature, Value &ret) {
Marker marker{};
if (!GetMessage()) return false;
if (!decoder_.ReadMessageHeader(&signature, &marker)) return false;
return ReadMessageData(marker, ret);
}
bool Client::ReadMessageData(Marker marker, Value &ret) {
if (marker == Marker::TinyStruct) {
ret = Value();
return true;
}
if (marker == Marker::TinyStruct1) {
return decoder_.ReadValue(&ret);
}
return false;
}
} // namespace memgraph::communication::bolt

View File

@ -11,6 +11,12 @@
#pragma once
#include <map>
#include <optional>
#include <string>
#include <vector>
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/decoder/chunked_decoder_buffer.hpp"
#include "communication/bolt/v1/decoder/decoder.hpp"
#include "communication/bolt/v1/encoder/chunked_encoder_buffer.hpp"
@ -19,22 +25,17 @@
#include "communication/context.hpp"
#include "io/network/endpoint.hpp"
#include "utils/exceptions.hpp"
#include "utils/logging.hpp"
namespace memgraph::communication::bolt {
/// This exception is thrown whenever an error occurs during query execution
/// that isn't fatal (eg. mistyped query or some transient error occurred).
/// It should be handled by everyone who uses the client.
class ClientQueryException : public utils::BasicException {
class FailureResponseException : public utils::BasicException {
public:
using utils::BasicException::BasicException;
FailureResponseException() : utils::BasicException{"Couldn't execute query!"} {}
ClientQueryException() : utils::BasicException("Couldn't execute query!") {}
explicit FailureResponseException(const std::string &message) : utils::BasicException{message} {}
template <class... Args>
ClientQueryException(const std::string &code, Args &&...args)
: utils::BasicException(std::forward<Args>(args)...), code_(code) {}
FailureResponseException(const std::string &code, const std::string &message)
: utils::BasicException{message}, code_{code} {}
const std::string &code() const { return code_; }
@ -42,6 +43,14 @@ class ClientQueryException : public utils::BasicException {
std::string code_;
};
/// This exception is thrown whenever an error occurs during query execution
/// that isn't fatal (eg. mistyped query or some transient error occurred).
/// It should be handled by everyone who uses the client.
class ClientQueryException : public FailureResponseException {
public:
using FailureResponseException::FailureResponseException;
};
/// This exception is thrown whenever a fatal error occurs during query
/// execution and/or connecting to the server.
/// It should be handled by everyone who uses the client.
@ -76,12 +85,13 @@ struct QueryData {
/// server. It supports both SSL and plaintext connections.
class Client final {
public:
explicit Client(communication::ClientContext *context) : client_(context) {}
explicit Client(communication::ClientContext &context);
Client(const Client &) = delete;
Client(Client &&) = delete;
Client &operator=(const Client &) = delete;
Client &operator=(Client &&) = delete;
~Client() = default;
/// Method used to connect to the server. Before executing queries this method
/// should be called to set-up the connection to the server. After the
@ -89,50 +99,7 @@ class Client final {
/// established connection.
/// @throws ClientFatalException when we couldn't connect to the server
void Connect(const io::network::Endpoint &endpoint, const std::string &username, const std::string &password,
const std::string &client_name = "memgraph-bolt") {
if (!client_.Connect(endpoint)) {
throw ClientFatalException("Couldn't connect to {}!", endpoint);
}
if (!client_.Write(kPreamble, sizeof(kPreamble), true)) {
SPDLOG_ERROR("Couldn't send preamble!");
throw ServerCommunicationException();
}
for (int i = 0; i < 4; ++i) {
if (!client_.Write(kProtocol, sizeof(kProtocol), i != 3)) {
SPDLOG_ERROR("Couldn't send protocol version!");
throw ServerCommunicationException();
}
}
if (!client_.Read(sizeof(kProtocol))) {
SPDLOG_ERROR("Couldn't get negotiated protocol version!");
throw ServerCommunicationException();
}
if (memcmp(kProtocol, client_.GetData(), sizeof(kProtocol)) != 0) {
SPDLOG_ERROR("Server negotiated unsupported protocol version!");
throw ClientFatalException("The server negotiated an usupported protocol version!");
}
client_.ShiftData(sizeof(kProtocol));
if (!encoder_.MessageInit(client_name, {{"scheme", "basic"}, {"principal", username}, {"credentials", password}})) {
SPDLOG_ERROR("Couldn't send init message!");
throw ServerCommunicationException();
}
Signature signature;
Value metadata;
if (!ReadMessage(&signature, &metadata)) {
SPDLOG_ERROR("Couldn't read init message response!");
throw ServerCommunicationException();
}
if (signature != Signature::Success) {
SPDLOG_ERROR("Handshake failed!");
throw ClientFatalException("Handshake with the server failed!");
}
SPDLOG_INFO("Metadata of init message response: {}", metadata);
}
const std::string &client_name = "memgraph-bolt");
/// Function used to execute queries against the server. Before you can
/// execute queries you must connect the client to the server.
@ -140,168 +107,41 @@ class Client final {
/// executing the query (eg. mistyped query,
/// etc.)
/// @throws ClientFatalException when we couldn't communicate with the server
QueryData Execute(const std::string &query, const std::map<std::string, Value> &parameters) {
if (!client_.IsConnected()) {
throw ClientFatalException("You must first connect to the server before using the client!");
}
SPDLOG_INFO("Sending run message with statement: '{}'; parameters: {}", query, parameters);
encoder_.MessageRun(query, parameters);
encoder_.MessagePullAll();
SPDLOG_INFO("Reading run message response");
Signature signature;
Value fields;
if (!ReadMessage(&signature, &fields)) {
throw ServerCommunicationException();
}
if (fields.type() != Value::Type::Map) {
throw ServerMalformedDataException();
}
if (signature == Signature::Failure) {
HandleFailure();
auto &tmp = fields.ValueMap();
auto it = tmp.find("message");
if (it != tmp.end()) {
auto it_code = tmp.find("code");
if (it_code != tmp.end()) {
throw ClientQueryException(it_code->second.ValueString(), it->second.ValueString());
} else {
throw ClientQueryException("", it->second.ValueString());
}
}
throw ClientQueryException();
} else if (signature != Signature::Success) {
throw ServerMalformedDataException();
}
SPDLOG_INFO("Reading pull_all message response");
Marker marker;
Value metadata;
std::vector<std::vector<Value>> records;
while (true) {
if (!GetMessage()) {
throw ServerCommunicationException();
}
if (!decoder_.ReadMessageHeader(&signature, &marker)) {
throw ServerCommunicationException();
}
if (signature == Signature::Record) {
Value record;
if (!decoder_.ReadValue(&record, Value::Type::List)) {
throw ServerCommunicationException();
}
records.emplace_back(std::move(record.ValueList()));
} else if (signature == Signature::Success) {
if (!decoder_.ReadValue(&metadata)) {
throw ServerCommunicationException();
}
break;
} else if (signature == Signature::Failure) {
Value data;
if (!decoder_.ReadValue(&data)) {
throw ServerCommunicationException();
}
HandleFailure();
auto &tmp = data.ValueMap();
auto it = tmp.find("message");
if (it != tmp.end()) {
auto it_code = tmp.find("code");
if (it_code != tmp.end()) {
throw ClientQueryException(it_code->second.ValueString(), it->second.ValueString());
} else {
throw ClientQueryException("", it->second.ValueString());
}
}
throw ClientQueryException();
} else {
throw ServerMalformedDataException();
}
}
if (metadata.type() != Value::Type::Map) {
throw ServerMalformedDataException();
}
QueryData ret{{}, std::move(records), std::move(metadata.ValueMap())};
auto &header = fields.ValueMap();
if (header.find("fields") == header.end()) {
throw ServerMalformedDataException();
}
if (header["fields"].type() != Value::Type::List) {
throw ServerMalformedDataException();
}
auto &field_vector = header["fields"].ValueList();
for (auto &field_item : field_vector) {
if (field_item.type() != Value::Type::String) {
throw ServerMalformedDataException();
}
ret.fields.emplace_back(std::move(field_item.ValueString()));
}
return ret;
}
QueryData Execute(const std::string &query, const std::map<std::string, Value> &parameters);
/// Close the active client connection.
void Close() { client_.Close(); };
void Close();
/// Can be used to reset the active client connection. Reset is automatically sent after receiving a failure message
/// from the server, which result in throwing an FailureResponseException or any exception derived from it.
void Reset();
/// Can be used to send a route message.
std::optional<std::map<std::string, Value>> Route(const std::map<std::string, Value> &routing,
const std::vector<Value> &bookmarks,
const std::optional<std::string> &db);
private:
bool GetMessage() {
client_.ClearData();
while (true) {
if (!client_.Read(kChunkHeaderSize)) return false;
using ClientEncoder = ClientEncoder<ChunkedEncoderBuffer<communication::ClientOutputStream>>;
size_t chunk_size = client_.GetData()[0];
chunk_size <<= 8;
chunk_size += client_.GetData()[1];
if (chunk_size == 0) return true;
if (!client_.Read(chunk_size)) return false;
if (decoder_buffer_.GetChunk() != ChunkState::Whole) return false;
client_.ClearData();
}
return true;
}
bool ReadMessage(Signature *signature, Value *ret) {
Marker marker;
if (!GetMessage()) return false;
if (!decoder_.ReadMessageHeader(signature, &marker)) return false;
return ReadMessageData(marker, ret);
}
bool ReadMessageData(Marker marker, Value *ret) {
if (marker == Marker::TinyStruct) {
*ret = Value();
return true;
} else if (marker == Marker::TinyStruct1) {
return decoder_.ReadValue(ret);
}
return false;
}
void HandleFailure() {
if (!encoder_.MessageAckFailure()) {
throw ServerCommunicationException();
}
while (true) {
Signature signature;
Value data;
if (!ReadMessage(&signature, &data)) {
throw ServerCommunicationException();
}
if (signature == Signature::Success) {
break;
} else if (signature != Signature::Ignored) {
throw ServerMalformedDataException();
template <typename TException = FailureResponseException>
[[noreturn]] void HandleFailure(const std::map<std::string, Value> &response_map) {
Reset();
auto it = response_map.find("message");
if (it != response_map.end()) {
auto it_code = response_map.find("code");
if (it_code != response_map.end()) {
throw TException(it_code->second.ValueString(), it->second.ValueString());
}
throw TException("", it->second.ValueString());
}
throw TException();
}
bool GetMessage();
bool ReadMessage(Signature &signature, Value &ret);
bool ReadMessageData(Marker marker, Value &ret);
// client
communication::Client client_;
communication::ClientInputStream input_stream_{client_};
@ -313,6 +153,6 @@ class Client final {
// encoder objects
ChunkedEncoderBuffer<communication::ClientOutputStream> encoder_buffer_{output_stream_};
ClientEncoder<ChunkedEncoderBuffer<communication::ClientOutputStream>> encoder_{encoder_buffer_};
ClientEncoder encoder_{encoder_buffer_};
};
} // namespace memgraph::communication::bolt

View File

@ -16,7 +16,6 @@
namespace memgraph::communication::bolt {
inline constexpr uint8_t kPreamble[4] = {0x60, 0x60, 0xB0, 0x17};
inline constexpr uint8_t kProtocol[4] = {0x00, 0x00, 0x00, 0x01};
enum class Signature : uint8_t {
Noop = 0x00,

View File

@ -11,6 +11,11 @@
#pragma once
#include <map>
#include <optional>
#include <string>
#include <vector>
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/encoder/base_encoder.hpp"
@ -30,6 +35,7 @@ class ClientEncoder : private BaseEncoder<Buffer> {
using BaseEncoder<Buffer>::WriteList;
using BaseEncoder<Buffer>::WriteMap;
using BaseEncoder<Buffer>::WriteString;
using BaseEncoder<Buffer>::WriteNull;
using BaseEncoder<Buffer>::buffer_;
public:
@ -38,10 +44,9 @@ class ClientEncoder : private BaseEncoder<Buffer> {
/**
* Writes a Init message.
*
* From the Bolt v1 documentation:
* InitMessage (signature=0x01) {
* String clientName
* Map<String,Value> authToken
* From the Bolt v4.3 documentation:
* HelloMess (signature=0x01) {
* Map<String,Value> extra
* }
*
* @param client_name the name of the connected client
@ -49,11 +54,10 @@ class ClientEncoder : private BaseEncoder<Buffer> {
* @returns true if the data was successfully sent to the client
* when flushing, false otherwise
*/
bool MessageInit(const std::string client_name, const std::map<std::string, Value> &auth_token) {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct2));
bool MessageInit(const std::map<std::string, Value> &extra) {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct1));
WriteRAW(utils::UnderlyingCast(Signature::Init));
WriteString(client_name);
WriteMap(auth_token);
WriteMap(extra);
// Try to flush all remaining data in the buffer, but tell it that we will
// send more data (the end of message chunk).
if (!buffer_.Flush(true)) return false;
@ -64,10 +68,11 @@ class ClientEncoder : private BaseEncoder<Buffer> {
/**
* Writes a Run message.
*
* From the Bolt v1 documentation:
* From the Bolt v4.3 documentation:
* RunMessage (signature=0x10) {
* String statement
* Map<String,Value> parameters
* String statement
* Map<String,Value> parameters
* Map<String,Value> extra
* }
*
* @param statement the statement that should be executed
@ -75,11 +80,13 @@ class ClientEncoder : private BaseEncoder<Buffer> {
* @returns true if the data was successfully sent to the client
* when flushing, false otherwise
*/
bool MessageRun(const std::string &statement, const std::map<std::string, Value> &parameters, bool have_more = true) {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct2));
bool MessageRun(const std::string &statement, const std::map<std::string, Value> &parameters,
const std::map<std::string, Value> &extra, bool have_more = true) {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct3));
WriteRAW(utils::UnderlyingCast(Signature::Run));
WriteString(statement);
WriteMap(parameters);
WriteMap(extra);
// Try to flush all remaining data in the buffer, but tell it that we will
// send more data (the end of message chunk).
if (!buffer_.Flush(true)) return false;
@ -90,18 +97,20 @@ class ClientEncoder : private BaseEncoder<Buffer> {
}
/**
* Writes a DiscardAll message.
* Writes a Discard message.
*
* From the Bolt v1 documentation:
* From the Bolt v4.3 documentation:
* DiscardMessage (signature=0x2F) {
* Map<String,Value> extra
* }
*
* @returns true if the data was successfully sent to the client
* when flushing, false otherwise
*/
bool MessageDiscardAll() {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct));
bool MessageDiscard(const std::map<std::string, Value> &extra) {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct1));
WriteRAW(utils::UnderlyingCast(Signature::Discard));
WriteMap(extra);
// Try to flush all remaining data in the buffer, but tell it that we will
// send more data (the end of message chunk).
if (!buffer_.Flush(true)) return false;
@ -112,36 +121,18 @@ class ClientEncoder : private BaseEncoder<Buffer> {
/**
* Writes a PullAll message.
*
* From the Bolt v1 documentation:
* PullAllMessage (signature=0x3F) {
* From the Bolt v4.3 documentation:
* PullMessage (signature=0x3F) {
* Map<String,Value> extra
* }
*
* @returns true if the data was successfully sent to the client
* when flushing, false otherwise
*/
bool MessagePullAll() {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct));
bool MessagePull(const std::map<std::string, Value> &extra) {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct1));
WriteRAW(utils::UnderlyingCast(Signature::Pull));
// Try to flush all remaining data in the buffer, but tell it that we will
// send more data (the end of message chunk).
if (!buffer_.Flush(true)) return false;
// Flush an empty chunk to indicate that the message is done.
return buffer_.Flush();
}
/**
* Writes a AckFailure message.
*
* From the Bolt v1 documentation:
* AckFailureMessage (signature=0x0E) {
* }
*
* @returns true if the data was successfully sent to the client
* when flushing, false otherwise
*/
bool MessageAckFailure() {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct));
WriteRAW(utils::UnderlyingCast(Signature::AckFailure));
WriteMap(extra);
// Try to flush all remaining data in the buffer, but tell it that we will
// send more data (the end of message chunk).
if (!buffer_.Flush(true)) return false;
@ -152,7 +143,7 @@ class ClientEncoder : private BaseEncoder<Buffer> {
/**
* Writes a Reset message.
*
* From the Bolt v1 documentation:
* From the Bolt v4.3 documentation:
* ResetMessage (signature=0x0F) {
* }
*
@ -168,5 +159,36 @@ class ClientEncoder : private BaseEncoder<Buffer> {
// Flush an empty chunk to indicate that the message is done.
return buffer_.Flush();
}
/**
* Writes a Route message.
*
* From the Bolt v4.3 documentation:
* RouteMessage (signature=0x0F) {
* Map<String,Value> routing
* List<String> bookmarks
* String db
* }
*
* @returns true if the data was successfully sent to the client
* when flushing, false otherwise
*/
bool MessageRoute(const std::map<std::string, Value> &routing, const std::vector<Value> &bookmarks,
const std::optional<std::string> &db) {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct3));
WriteRAW(utils::UnderlyingCast(Signature::Route));
WriteMap(routing);
WriteList(bookmarks);
if (db.has_value()) {
WriteString(*db);
} else {
WriteNull();
}
// Try to flush all remaining data in the buffer, but tell it that we will
// send more data (the end of message chunk).
if (!buffer_.Flush(true)) return false;
// Flush an empty chunk to indicate that the message is done.
return buffer_.Flush();
}
};
} // namespace memgraph::communication::bolt

View File

@ -117,29 +117,6 @@ class Encoder : private BaseEncoder<Buffer> {
return buffer_.Flush();
}
/**
* Sends an Ignored message.
*
* From the bolt v1 documentation:
* IgnoredMessage (signature=0x7E) {
* Map<String,Value> metadata
* }
*
* @param metadata the metadata map object that should be sent
* @returns true if the data was successfully sent to the client,
* false otherwise
*/
bool MessageIgnored(const std::map<std::string, Value> &metadata) {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct1));
WriteRAW(utils::UnderlyingCast(Signature::Ignored));
WriteMap(metadata);
// Try to flush all remaining data in the buffer, but tell it that we will
// send more data (the end of message chunk).
if (!buffer_.Flush(true)) return false;
// Flush an empty chunk to indicate that the message is done.
return buffer_.Flush();
}
/**
* Sends an Ignored message.
*

View File

@ -15,6 +15,7 @@
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/states/handlers.hpp"
#include "communication/bolt/v1/value.hpp"
#include "utils/cast.hpp"
#include "utils/likely.hpp"
@ -30,8 +31,8 @@ namespace memgraph::communication::bolt {
*/
template <typename TSession>
State StateErrorRun(TSession &session, State state) {
Marker marker;
Signature signature;
Marker marker{};
Signature signature{};
if (!session.decoder_.ReadMessageHeader(&signature, &marker)) {
spdlog::trace("Missing header data!");
return State::Close;
@ -45,54 +46,49 @@ State StateErrorRun(TSession &session, State state) {
// Clear the data buffer if it has any leftover data.
session.encoder_buffer_.Clear();
if ((session.version_.major == 1 && signature == Signature::AckFailure) || signature == Signature::Reset) {
if (signature == Signature::AckFailure) {
spdlog::trace("AckFailure received");
} else {
spdlog::trace("Reset received");
}
if (session.version_.major == 1 && signature == Signature::AckFailure) {
spdlog::trace("AckFailure received");
if (!session.encoder_.MessageSuccess()) {
spdlog::trace("Couldn't send success message!");
return State::Close;
}
if (signature == Signature::Reset) {
session.Abort();
return State::Idle;
}
// We got AckFailure get back to right state.
MG_ASSERT(state == State::Error, "Shouldn't happen");
return State::Idle;
} else {
uint8_t value = utils::UnderlyingCast(marker);
// All bolt client messages have less than 15 parameters so if we receive
// anything than a TinyStruct it's an error.
if ((value & 0xF0) != utils::UnderlyingCast(Marker::TinyStruct)) {
spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!", value);
return State::Close;
}
// We need to clean up all parameters from this command.
value &= 0x0F; // The length is stored in the lower nibble.
Value dv;
for (int i = 0; i < value; ++i) {
if (!session.decoder_.ReadValue(&dv)) {
spdlog::trace("Couldn't clean up parameter {} / {}!", i, value);
return State::Close;
}
}
// Ignore this message.
if (!session.encoder_.MessageIgnored()) {
spdlog::trace("Couldn't send ignored message!");
return State::Close;
}
// Cleanup done, command ignored, stay in error state.
return state;
}
if (signature == Signature::Reset) {
spdlog::trace("Reset received");
return HandleReset(session, marker);
}
uint8_t value = utils::UnderlyingCast(marker);
// All bolt client messages have less than 15 parameters so if we receive
// anything than a TinyStruct it's an error.
if ((value & 0xF0U) != utils::UnderlyingCast(Marker::TinyStruct)) {
spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!", value);
return State::Close;
}
// We need to clean up all parameters from this command.
value &= 0x0FU; // The length is stored in the lower nibble.
Value dv;
for (int i = 0; i < value; ++i) {
if (!session.decoder_.ReadValue(&dv)) {
spdlog::trace("Couldn't clean up parameter {} / {}!", i, value);
return State::Close;
}
}
// Ignore this message.
if (!session.encoder_.MessageIgnored()) {
spdlog::trace("Couldn't send ignored message!");
return State::Close;
}
// Cleanup done, command ignored, stay in error state.
return state;
}
} // namespace memgraph::communication::bolt

View File

@ -74,7 +74,7 @@ State RunHandlerV4(Signature signature, TSession &session, State state, Marker m
}
case Signature::Route: {
if constexpr (bolt_minor >= 3) {
if (signature == Signature::Route) return HandleRoute<TSession>(session);
if (signature == Signature::Route) return HandleRoute<TSession>(session, marker);
} else {
spdlog::trace("Supported only in bolt v4.3");
return State::Close;

View File

@ -18,6 +18,7 @@
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/constants.hpp"
#include "communication/bolt/v1/exceptions.hpp"
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/value.hpp"
#include "communication/exceptions.hpp"
@ -136,7 +137,7 @@ template <bool is_pull, typename TSession>
State HandlePullDiscardV1(TSession &session, const State state, const Marker marker) {
const auto expected_marker = Marker::TinyStruct;
if (marker != expected_marker) {
spdlog::trace("Expected {} marker, but received 0x{:02X}!", "TinyStruct", utils::UnderlyingCast(marker));
spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
return State::Close;
}
@ -157,7 +158,7 @@ template <bool is_pull, typename TSession>
State HandlePullDiscardV4(TSession &session, const State state, const Marker marker) {
const auto expected_marker = Marker::TinyStruct1;
if (marker != expected_marker) {
spdlog::trace("Expected {} marker, but received 0x{:02X}!", "TinyStruct1", utils::UnderlyingCast(marker));
spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
return State::Close;
}
@ -216,7 +217,8 @@ State HandleRunV1(TSession &session, const State state, const Marker marker) {
session.version_.major == 1 ? "TinyStruct2" : "TinyStruct3", utils::UnderlyingCast(marker));
return State::Close;
}
Value query, params;
Value query;
Value params;
if (!session.decoder_.ReadValue(&query, Value::Type::String)) {
spdlog::trace("Couldn't read query string!");
return State::Close;
@ -234,10 +236,12 @@ template <typename TSession>
State HandleRunV4(TSession &session, const State state, const Marker marker) {
const auto expected_marker = Marker::TinyStruct3;
if (marker != expected_marker) {
spdlog::trace("Expected {} marker, but received 0x{:02X}!", "TinyStruct3", utils::UnderlyingCast(marker));
spdlog::trace("Expected TinyStruct3 marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
return State::Close;
}
Value query, params, extra;
Value query;
Value params;
Value extra;
if (!session.decoder_.ReadValue(&query, Value::Type::String)) {
spdlog::trace("Couldn't read query string!");
return State::Close;
@ -292,9 +296,6 @@ State HandleReset(TSession &session, const Marker marker) {
return State::Close;
}
// Clear all pending data and send a success message.
session.encoder_buffer_.Clear();
if (!session.encoder_.MessageSuccess()) {
spdlog::trace("Couldn't send success message!");
return State::Close;
@ -403,12 +404,33 @@ State HandleGoodbye() {
}
template <typename TSession>
State HandleRoute(TSession &session) {
// Route message is not implemented since it is neo4j specific, therefore we
// will receive it an inform user that there is no implementation.
State HandleRoute(TSession &session, const Marker marker) {
// Route message is not implemented since it is Neo4j specific, therefore we will receive it and inform user that
// there is no implementation. Before that, we have to read out the fields from the buffer to leave it in a clean
// state.
if (marker != Marker::TinyStruct3) {
spdlog::trace("Expected TinyStruct3 marker, but received 0x{:02x}!", utils::UnderlyingCast(marker));
return State::Close;
}
Value routing;
if (!session.decoder_.ReadValue(&routing, Value::Type::Map)) {
spdlog::trace("Couldn't read routing field!");
return State::Close;
}
Value bookmarks;
if (!session.decoder_.ReadValue(&bookmarks, Value::Type::List)) {
spdlog::trace("Couldn't read bookmarks field!");
return State::Close;
}
Value db;
if (!session.decoder_.ReadValue(&db)) {
spdlog::trace("Couldn't read db field!");
return State::Close;
}
session.encoder_buffer_.Clear();
bool fail_sent =
session.encoder_.MessageFailure({{"code", 66}, {"message", "Route message not supported in Memgraph!"}});
session.encoder_.MessageFailure({{"code", "66"}, {"message", "Route message is not supported in Memgraph!"}});
if (!fail_sent) {
spdlog::trace("Couldn't send failure message!");
return State::Close;

View File

@ -22,13 +22,13 @@ static void ReturnFunctionArgument(struct mgp_list *args, mgp_func_context *ctx,
mgp_value *value{nullptr};
auto err_code = mgp_list_at(args, 0, &value);
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory);
static_cast<void>(mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory));
return;
}
err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
static_cast<void>(mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory));
return;
}
}
@ -38,13 +38,13 @@ static void ReturnOptionalArgument(struct mgp_list *args, mgp_func_context *ctx,
mgp_value *value{nullptr};
auto err_code = mgp_list_at(args, 0, &value);
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory);
static_cast<void>(mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory));
return;
}
err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
static_cast<void>(mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory));
return;
}
}
@ -57,14 +57,14 @@ double GetElementFromArg(struct mgp_list *args, int index) {
double result;
int is_int;
mgp_value_is_int(value, &is_int);
static_cast<void>(mgp_value_is_int(value, &is_int));
if (is_int) {
int64_t result_int;
mgp_value_get_int(value, &result_int);
static_cast<void>(mgp_value_get_int(value, &result_int));
result = static_cast<double>(result_int);
} else {
mgp_value_get_double(value, &result);
static_cast<void>(mgp_value_get_double(value, &result));
}
return result;
}
@ -77,30 +77,30 @@ static void AddTwoNumbers(struct mgp_list *args, mgp_func_context *ctx, mgp_func
first = GetElementFromArg(args, 0);
second = GetElementFromArg(args, 1);
} catch (...) {
mgp_func_result_set_error_msg(result, "Unable to fetch the result!", memory);
static_cast<void>(mgp_func_result_set_error_msg(result, "Unable to fetch the result!", memory));
return;
}
mgp_value *value{nullptr};
auto summation = first + second;
mgp_value_make_double(summation, memory, &value);
static_cast<void>(mgp_value_make_double(summation, memory, &value));
memgraph::utils::OnScopeExit delete_summation_value([&value] { mgp_value_destroy(value); });
auto err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
static_cast<void>(mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory));
}
}
static void ReturnNull(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
struct mgp_memory *memory) {
mgp_value *value{nullptr};
mgp_value_make_null(memory, &value);
static_cast<void>(mgp_value_make_null(memory, &value));
memgraph::utils::OnScopeExit delete_null([&value] { mgp_value_destroy(value); });
auto err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory);
static_cast<void>(mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory));
}
}
} // namespace
@ -116,7 +116,7 @@ extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *mem
}
mgp_type *type_any{nullptr};
mgp_type_any(&type_any);
static_cast<void>(mgp_type_any(&type_any));
err_code = mgp_func_add_arg(func, "argument", type_any);
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
return 1;
@ -131,11 +131,11 @@ extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *mem
}
mgp_value *default_value{nullptr};
mgp_value_make_int(42, memory, &default_value);
static_cast<void>(mgp_value_make_int(42, memory, &default_value));
memgraph::utils::OnScopeExit delete_summation_value([&default_value] { mgp_value_destroy(default_value); });
mgp_type *type_int{nullptr};
mgp_type_int(&type_int);
static_cast<void>(mgp_type_int(&type_int));
err_code = mgp_func_add_opt_arg(func, "opt_argument", type_int, default_value);
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
return 1;
@ -150,7 +150,7 @@ extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *mem
}
mgp_type *type_number{nullptr};
mgp_type_number(&type_number);
static_cast<void>(mgp_type_number(&type_number));
err_code = mgp_func_add_arg(func, "first", type_number);
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
return 1;

View File

@ -15,25 +15,25 @@ static void TryToWrite(struct mgp_list *args, mgp_func_context *ctx, mgp_func_re
struct mgp_memory *memory) {
mgp_value *value{nullptr};
mgp_vertex *vertex{nullptr};
mgp_list_at(args, 0, &value);
mgp_value_get_vertex(value, &vertex);
static_cast<void>(mgp_list_at(args, 0, &value));
static_cast<void>(mgp_value_get_vertex(value, &vertex));
const char *name;
mgp_list_at(args, 1, &value);
mgp_value_get_string(value, &name);
static_cast<void>(mgp_list_at(args, 1, &value));
static_cast<void>(mgp_value_get_string(value, &name));
mgp_list_at(args, 2, &value);
static_cast<void>(mgp_list_at(args, 2, &value));
// Setting a property should set an error
auto err_code = mgp_vertex_set_property(vertex, name, value);
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Cannot set property in the function!", memory);
static_cast<void>(mgp_func_result_set_error_msg(result, "Cannot set property in the function!", memory));
return;
}
err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
static_cast<void>(mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory));
return;
}
}
@ -49,23 +49,23 @@ extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *mem
}
mgp_type *type_vertex{nullptr};
mgp_type_node(&type_vertex);
static_cast<void>(mgp_type_node(&type_vertex));
err_code = mgp_func_add_arg(func, "argument", type_vertex);
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
return 1;
}
mgp_type *type_string{nullptr};
mgp_type_string(&type_string);
static_cast<void>(mgp_type_string(&type_string));
err_code = mgp_func_add_arg(func, "name", type_string);
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
return 1;
}
mgp_type *any_type{nullptr};
mgp_type_any(&any_type);
static_cast<void>(mgp_type_any(&any_type));
mgp_type *nullable_type{nullptr};
mgp_type_nullable(any_type, &nullable_type);
static_cast<void>(mgp_type_nullable(any_type, &nullable_type));
err_code = mgp_func_add_arg(func, "value", nullable_type);
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
return 1;

View File

@ -1,10 +1,21 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "mg_procedure.h"
int *gVal = NULL;
void set_error(struct mgp_result *result) { mgp_result_set_error_msg(result, "Something went wrong"); }
static void procedure(const struct mgp_list *args, const struct mgp_graph *graph, struct mgp_result *result,
static void procedure(struct mgp_list *args, struct mgp_graph *graph, struct mgp_result *result,
struct mgp_memory *memory) {
struct mgp_result_record *record = NULL;
const enum mgp_error new_record_err = mgp_result_new_record(result, &record);
@ -21,14 +32,14 @@ static void procedure(const struct mgp_list *args, const struct mgp_graph *graph
int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
const size_t one_gb = 1 << 30;
const enum mgp_error alloc_err = mgp_global_alloc(one_gb, &gVal);
const enum mgp_error alloc_err = mgp_global_alloc(one_gb, (void **)(&gVal));
if (alloc_err != MGP_ERROR_NO_ERROR) return 1;
struct mgp_proc *proc = NULL;
const enum mgp_error proc_err = mgp_module_add_read_procedure(module, "procedure", procedure, &proc);
if (proc_err != MGP_ERROR_NO_ERROR) return 1;
const struct mgp_type *string_type = NULL;
struct mgp_type *string_type = NULL;
const enum mgp_error string_type_err = mgp_type_string(&string_type);
if (string_type_err != MGP_ERROR_NO_ERROR) return 1;
if (mgp_proc_add_result(proc, "result", string_type) != MGP_ERROR_NO_ERROR) return 1;

View File

@ -1,3 +1,14 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "mg_procedure.h"
int *gVal = NULL;
@ -6,7 +17,7 @@ void set_error(struct mgp_result *result) { mgp_result_set_error_msg(result, "So
void set_out_of_memory_error(struct mgp_result *result) { mgp_result_set_error_msg(result, "Out of memory"); }
static void error(const struct mgp_list *args, const struct mgp_graph *graph, struct mgp_result *result,
static void error(struct mgp_list *args, struct mgp_graph *graph, struct mgp_result *result,
struct mgp_memory *memory) {
const size_t one_gb = 1 << 30;
if (gVal) {
@ -14,7 +25,7 @@ static void error(const struct mgp_list *args, const struct mgp_graph *graph, st
gVal = NULL;
}
if (!gVal) {
const enum mgp_error err = mgp_global_alloc(one_gb, &gVal);
const enum mgp_error err = mgp_global_alloc(one_gb, (void **)(&gVal));
if (err == MGP_ERROR_UNABLE_TO_ALLOCATE) return set_out_of_memory_error(result);
if (err != MGP_ERROR_NO_ERROR) return set_error(result);
}
@ -29,11 +40,11 @@ static void error(const struct mgp_list *args, const struct mgp_graph *graph, st
if (result_inserted != MGP_ERROR_NO_ERROR) return set_error(result);
}
static void success(const struct mgp_list *args, const struct mgp_graph *graph, struct mgp_result *result,
static void success(struct mgp_list *args, struct mgp_graph *graph, struct mgp_result *result,
struct mgp_memory *memory) {
const size_t bytes = 1024;
if (!gVal) {
const enum mgp_error err = mgp_global_alloc(bytes, &gVal);
const enum mgp_error err = mgp_global_alloc(bytes, (void **)(&gVal));
if (err == MGP_ERROR_UNABLE_TO_ALLOCATE) return set_out_of_memory_error(result);
if (err != MGP_ERROR_NO_ERROR) return set_error(result);
}

View File

@ -86,7 +86,7 @@ int main(int argc, char **argv) {
memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port);
memgraph::communication::ClientContext context(FLAGS_use_ssl);
memgraph::communication::bolt::Client client(&context);
memgraph::communication::bolt::Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);
client.Execute(FLAGS_query, JsonToValue(nlohmann::json::parse(FLAGS_params_json)).ValueMap());

View File

@ -33,7 +33,7 @@ int main(int argc, char **argv) {
memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port);
memgraph::communication::ClientContext context(FLAGS_use_ssl);
memgraph::communication::bolt::Client client(&context);
memgraph::communication::bolt::Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);

View File

@ -38,7 +38,7 @@ int main(int argc, char **argv) {
memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port);
memgraph::communication::ClientContext context(FLAGS_use_ssl);
memgraph::communication::bolt::Client client(&context);
memgraph::communication::bolt::Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);

View File

@ -37,7 +37,7 @@ int main(int argc, char **argv) {
memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port);
memgraph::communication::ClientContext context(FLAGS_use_ssl);
memgraph::communication::bolt::Client client(&context);
memgraph::communication::bolt::Client client(context);
{
std::string what;

View File

@ -36,7 +36,7 @@ int main(int argc, char **argv) {
memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port);
memgraph::communication::ClientContext context(FLAGS_use_ssl);
memgraph::communication::bolt::Client client(&context);
memgraph::communication::bolt::Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);
auto ret = client.Execute("DUMP DATABASE", {});

View File

@ -65,7 +65,7 @@ class BoltClient : public ::testing::Test {
memgraph::io::network::Endpoint endpoint_{memgraph::io::network::ResolveHostname(FLAGS_address),
static_cast<uint16_t>(FLAGS_port)};
memgraph::communication::ClientContext context_{FLAGS_use_ssl};
Client client_{&context_};
Client client_{context_};
};
const std::string kNoCurrentTransactionToCommit = "No current transaction to commit.";
@ -100,20 +100,20 @@ TEST_F(BoltClient, DoubleRollbackWithoutTransaction) {
TEST_F(BoltClient, DoubleBegin) {
EXPECT_TRUE(Execute("begin"));
EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException);
EXPECT_TRUE(TransactionActive());
EXPECT_FALSE(TransactionActive());
}
TEST_F(BoltClient, DoubleBeginAndCommit) {
EXPECT_TRUE(Execute("begin"));
EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException);
EXPECT_TRUE(Execute("commit"));
EXPECT_THROW(Execute("commit", kNoCurrentTransactionToCommit), ClientQueryException);
EXPECT_FALSE(TransactionActive());
}
TEST_F(BoltClient, DoubleBeginAndRollback) {
EXPECT_TRUE(Execute("begin"));
EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException);
EXPECT_TRUE(Execute("rollback"));
EXPECT_THROW(Execute("rollback", kNoCurrentTransactionToRollback), ClientQueryException);
EXPECT_FALSE(TransactionActive());
}
@ -157,30 +157,29 @@ TEST_F(BoltClient, BeginAndCorrectQueriesAndBegin) {
EXPECT_TRUE(Execute("create (n)"));
ASSERT_EQ(GetCount(), count + 1);
EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException);
EXPECT_EQ(GetCount(), count + 1);
EXPECT_TRUE(TransactionActive());
EXPECT_EQ(GetCount(), count);
EXPECT_FALSE(TransactionActive());
}
TEST_F(BoltClient, BeginAndWrongQueryAndRollback) {
EXPECT_TRUE(Execute("begin"));
EXPECT_THROW(Execute("asdasd"), ClientQueryException);
EXPECT_TRUE(Execute("rollback"));
EXPECT_THROW(Execute("rollback", kNoCurrentTransactionToRollback), ClientQueryException);
EXPECT_FALSE(TransactionActive());
}
TEST_F(BoltClient, BeginAndWrongQueryAndCommit) {
EXPECT_TRUE(Execute("begin"));
EXPECT_THROW(Execute("asdasd"), ClientQueryException);
EXPECT_THROW(Execute("commit", kCommitInvalid), ClientQueryException);
EXPECT_TRUE(Execute("rollback"));
EXPECT_THROW(Execute("commit", kNoCurrentTransactionToCommit), ClientQueryException);
EXPECT_FALSE(TransactionActive());
}
TEST_F(BoltClient, BeginAndWrongQueryAndBegin) {
EXPECT_TRUE(Execute("begin"));
EXPECT_THROW(Execute("asdasd"), ClientQueryException);
EXPECT_THROW(Execute("commit", kCommitInvalid), ClientQueryException);
EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException);
EXPECT_THROW(Execute("commit", kNoCurrentTransactionToCommit), ClientQueryException);
EXPECT_TRUE(Execute("begin"));
EXPECT_TRUE(TransactionActive());
}
@ -230,7 +229,7 @@ TEST_F(BoltClient, CorrectQueryAndBeginAndBegin) {
EXPECT_TRUE(Execute("match (n) return count(n)"));
EXPECT_TRUE(Execute("begin"));
EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException);
EXPECT_TRUE(TransactionActive());
EXPECT_FALSE(TransactionActive());
}
TEST_F(BoltClient, WrongQueryAndBeginAndCommit) {
@ -251,7 +250,7 @@ TEST_F(BoltClient, WrongQueryAndBeginAndBegin) {
EXPECT_THROW(Execute("asdasd"), ClientQueryException);
EXPECT_TRUE(Execute("begin"));
EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException);
EXPECT_TRUE(TransactionActive());
EXPECT_FALSE(TransactionActive());
}
TEST_F(BoltClient, CorrectQueriesAndBeginAndCommit) {
@ -278,7 +277,7 @@ TEST_F(BoltClient, CorrectQueriesAndBeginAndBegin) {
}
EXPECT_TRUE(Execute("begin"));
EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException);
EXPECT_TRUE(TransactionActive());
EXPECT_FALSE(TransactionActive());
}
TEST_F(BoltClient, WrongQueriesAndBeginAndCommit) {
@ -305,7 +304,7 @@ TEST_F(BoltClient, WrongQueriesAndBeginAndBegin) {
}
EXPECT_TRUE(Execute("begin"));
EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException);
EXPECT_TRUE(TransactionActive());
EXPECT_FALSE(TransactionActive());
}
TEST_F(BoltClient, CorrectQueriesAndBeginAndCorrectQueriesAndCommit) {
@ -341,7 +340,7 @@ TEST_F(BoltClient, CorrectQueriesAndBeginAndCorrectQueriesAndBegin) {
EXPECT_TRUE(Execute("match (n) return count(n)"));
}
EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException);
EXPECT_TRUE(TransactionActive());
EXPECT_FALSE(TransactionActive());
}
TEST_F(BoltClient, WrongQueriesAndBeginAndCorrectQueriesAndCommit) {
@ -377,7 +376,7 @@ TEST_F(BoltClient, WrongQueriesAndBeginAndCorrectQueriesAndBegin) {
EXPECT_TRUE(Execute("match (n) return count(n)"));
}
EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException);
EXPECT_TRUE(TransactionActive());
EXPECT_FALSE(TransactionActive());
}
TEST_F(BoltClient, CorrectQueriesAndBeginAndWrongQueriesAndCommit) {
@ -388,8 +387,8 @@ TEST_F(BoltClient, CorrectQueriesAndBeginAndWrongQueriesAndCommit) {
for (int i = 0; i < 3; ++i) {
EXPECT_THROW(Execute("asdasd"), ClientQueryException);
}
EXPECT_THROW(Execute("commit", kCommitInvalid), ClientQueryException);
EXPECT_TRUE(TransactionActive());
EXPECT_THROW(Execute("commit", kNoCurrentTransactionToCommit), ClientQueryException);
EXPECT_FALSE(TransactionActive());
}
TEST_F(BoltClient, CorrectQueriesAndBeginAndWrongQueriesAndRollback) {
@ -400,7 +399,7 @@ TEST_F(BoltClient, CorrectQueriesAndBeginAndWrongQueriesAndRollback) {
for (int i = 0; i < 3; ++i) {
EXPECT_THROW(Execute("asdasd"), ClientQueryException);
}
EXPECT_TRUE(Execute("rollback"));
EXPECT_THROW(Execute("rollback", kNoCurrentTransactionToRollback), ClientQueryException);
EXPECT_FALSE(TransactionActive());
}
@ -412,7 +411,7 @@ TEST_F(BoltClient, CorrectQueriesAndBeginAndWrongQueriesAndBegin) {
for (int i = 0; i < 3; ++i) {
EXPECT_THROW(Execute("asdasd"), ClientQueryException);
}
EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException);
EXPECT_TRUE(Execute("begin"));
EXPECT_TRUE(TransactionActive());
}
@ -424,8 +423,8 @@ TEST_F(BoltClient, WrongQueriesAndBeginAndWrongQueriesAndCommit) {
for (int i = 0; i < 3; ++i) {
EXPECT_THROW(Execute("asdasd"), ClientQueryException);
}
EXPECT_THROW(Execute("commit", kCommitInvalid), ClientQueryException);
EXPECT_TRUE(TransactionActive());
EXPECT_THROW(Execute("commit", kNoCurrentTransactionToCommit), ClientQueryException);
EXPECT_FALSE(TransactionActive());
}
TEST_F(BoltClient, WrongQueriesAndBeginAndWrongQueriesAndRollback) {
@ -436,7 +435,7 @@ TEST_F(BoltClient, WrongQueriesAndBeginAndWrongQueriesAndRollback) {
for (int i = 0; i < 3; ++i) {
EXPECT_THROW(Execute("asdasd"), ClientQueryException);
}
EXPECT_TRUE(Execute("rollback"));
EXPECT_THROW(Execute("rollback", kNoCurrentTransactionToRollback), ClientQueryException);
EXPECT_FALSE(TransactionActive());
}
@ -448,7 +447,7 @@ TEST_F(BoltClient, WrongQueriesAndBeginAndWrongQueriesAndBegin) {
for (int i = 0; i < 3; ++i) {
EXPECT_THROW(Execute("asdasd"), ClientQueryException);
}
EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException);
EXPECT_TRUE(Execute("begin"));
EXPECT_TRUE(TransactionActive());
}

View File

@ -117,7 +117,7 @@ int main(int argc, char **argv) {
Endpoint endpoint(FLAGS_address, FLAGS_port);
ClientContext context(FLAGS_use_ssl);
Client client(&context);
Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);
std::vector<std::unique_ptr<TestClient>> clients;

View File

@ -317,7 +317,7 @@ int main(int argc, char **argv) {
Endpoint endpoint(FLAGS_address, FLAGS_port);
ClientContext context(FLAGS_use_ssl);
Client client(&context);
Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);
num_pos.store(NumNodesWithLabel(client, "Pos"));

View File

@ -118,7 +118,7 @@ class TestClient {
private:
memgraph::communication::ClientContext context_{FLAGS_use_ssl};
Client client_{&context_};
Client client_{context_};
};
void RunMultithreadedTest(std::vector<std::unique_ptr<TestClient>> &clients) {

View File

@ -261,7 +261,7 @@ int main(int argc, char **argv) {
auto independent_nodes_ids = [&] {
Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port);
ClientContext context(FLAGS_use_ssl);
Client client(&context);
Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);
return IndependentSet(client, INDEPENDENT_LABEL);
}();

View File

@ -67,7 +67,7 @@ void ExecuteQueries(const std::vector<std::string> &queries, std::ostream &ostre
threads.push_back(std::thread([&]() {
Endpoint endpoint(FLAGS_address, FLAGS_port);
ClientContext context(FLAGS_use_ssl);
Client client(&context);
Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);
std::string str;

View File

@ -32,7 +32,7 @@ int main(int argc, char **argv) {
memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port);
memgraph::communication::ClientContext context(FLAGS_use_ssl);
memgraph::communication::bolt::Client client(&context);
memgraph::communication::bolt::Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);

View File

@ -177,7 +177,7 @@ void Execute(
threads.push_back(std::thread([&, worker]() {
memgraph::io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);
memgraph::communication::ClientContext context(FLAGS_use_ssl);
memgraph::communication::bolt::Client client(&context);
memgraph::communication::bolt::Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);
ready.fetch_add(1, std::memory_order_acq_rel);

View File

@ -66,7 +66,7 @@ class GraphSession {
}
EndpointT endpoint(FLAGS_address, FLAGS_port);
client_ = std::make_unique<ClientT>(&context_);
client_ = std::make_unique<ClientT>(context_);
client_->Connect(endpoint, FLAGS_username, FLAGS_password);
}
@ -387,7 +387,7 @@ int main(int argc, char **argv) {
// create client
EndpointT endpoint(FLAGS_address, FLAGS_port);
ClientContextT context(FLAGS_use_ssl);
ClientT client(&context);
ClientT client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);
// cleanup and create indexes

View File

@ -1,4 +1,4 @@
// Copyright 2021 Memgraph Ltd.
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -98,15 +98,19 @@ void PrintOutput(std::vector<uint8_t> &output) {
* TODO (mferencevic): document
*/
void CheckOutput(std::vector<uint8_t> &output, const uint8_t *data, uint64_t len, bool clear = true) {
if (clear)
if (clear) {
ASSERT_EQ(len, output.size());
else
} else {
ASSERT_LE(len, output.size());
for (size_t i = 0; i < len; ++i) EXPECT_EQ(output[i], data[i]);
if (clear)
}
for (size_t i = 0; i < len; ++i) {
EXPECT_EQ(output[i], data[i]) << i;
}
if (clear) {
output.clear();
else
} else {
output.erase(output.begin(), output.begin() + len);
}
}
/**

View File

@ -9,6 +9,8 @@
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <string>
#include <gflags/gflags.h>
#include "bolt_common.hpp"
@ -151,7 +153,10 @@ inline constexpr uint8_t noop[] = {0x00, 0x00};
} // namespace v4_1
namespace v4_3 {
inline constexpr uint8_t route[]{0xb0, 0x60};
inline constexpr uint8_t handshake_req[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x03, 0x04, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
inline constexpr uint8_t handshake_resp[] = {0x00, 0x00, 0x03, 0x04};
inline constexpr uint8_t route[]{0xb3, 0x66, 0xa0, 0x90, 0xc0};
} // namespace v4_3
// Write bolt chunk header (length)
@ -705,67 +710,89 @@ TEST(BoltSession, ErrorWrongMarker) {
}
TEST(BoltSession, ErrorOK) {
// v1
{
SCOPED_TRACE("v1");
// test ACK_FAILURE and RESET
const uint8_t *dataset[] = {ackfailure_req, reset_req};
for (int i = 0; i < 2; ++i) {
SCOPED_TRACE("i: " + std::to_string(i));
// first test with socket write success, then with socket write fail
for (int j = 0; j < 2; ++j) {
SCOPED_TRACE("j: " + std::to_string(j));
const auto write_success = j == 0;
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
ExecuteInit(input_stream, session, output);
ASSERT_EQ(session.version_.major, 1U);
ExecuteInit(input_stream, session, output);
WriteRunRequest(input_stream, kInvalidQuery);
session.Execute();
output.clear();
output_stream.SetWriteSuccess(j == 0);
if (j == 0) {
output_stream.SetWriteSuccess(write_success);
if (write_success) {
ExecuteCommand(input_stream, session, dataset[i], 2);
} else {
ASSERT_THROW(ExecuteCommand(input_stream, session, dataset[i], 2), SessionException);
}
// assert that all data from the init message was cleaned up
ASSERT_EQ(session.decoder_buffer_.Size(), 0);
EXPECT_EQ(session.decoder_buffer_.Size(), 0);
if (j == 0) {
ASSERT_EQ(session.state_, State::Idle);
if (write_success) {
EXPECT_EQ(session.state_, State::Idle);
CheckOutput(output, success_resp, sizeof(success_resp));
} else {
ASSERT_EQ(session.state_, State::Close);
ASSERT_EQ(output.size(), 0);
EXPECT_EQ(session.state_, State::Close);
EXPECT_EQ(output.size(), 0);
}
}
}
}
// v4+
{
SCOPED_TRACE("v4");
const uint8_t *dataset[] = {ackfailure_req, v4::reset_req};
for (int i = 0; i < 2; ++i) {
INIT_VARS;
SCOPED_TRACE("i: " + std::to_string(i));
// first test with socket write success, then with socket write fail
for (int j = 0; j < 2; ++j) {
SCOPED_TRACE("j: " + std::to_string(j));
const auto write_success = j == 0;
const auto is_reset = i == 1;
INIT_VARS;
ExecuteHandshake(input_stream, session, output, v4::handshake_req, v4::handshake_resp);
ExecuteInit(input_stream, session, output, true);
ExecuteHandshake(input_stream, session, output, v4::handshake_req, v4::handshake_resp);
ASSERT_EQ(session.version_.major, 4U);
ExecuteInit(input_stream, session, output, true);
WriteRunRequest(input_stream, kInvalidQuery, true);
session.Execute();
WriteRunRequest(input_stream, kInvalidQuery, true);
session.Execute();
output.clear();
output.clear();
output_stream.SetWriteSuccess(write_success);
ExecuteCommand(input_stream, session, dataset[i], 2);
// ACK_FAILURE does not exist in v3+, ingored message is sent
if (write_success) {
ExecuteCommand(input_stream, session, dataset[i], 2);
} else {
ASSERT_THROW(ExecuteCommand(input_stream, session, dataset[i], 2), SessionException);
}
// ACK_FAILURE does not exist in v4+
if (i == 0) {
ASSERT_EQ(session.state_, State::Error);
} else {
ASSERT_EQ(session.state_, State::Idle);
CheckOutput(output, success_resp, sizeof(success_resp));
if (write_success) {
if (is_reset) {
EXPECT_EQ(session.state_, State::Idle);
CheckOutput(output, success_resp, sizeof(success_resp));
} else {
ASSERT_EQ(session.state_, State::Error);
CheckOutput(output, ignored_resp, sizeof(ignored_resp));
}
} else {
EXPECT_EQ(session.state_, State::Close);
}
}
}
}
@ -950,18 +977,100 @@ TEST(BoltSession, Noop) {
TEST(BoltSession, Route) {
// Memgraph does not support route message, but it handles it
{
SCOPED_TRACE("v1");
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
ExecuteInit(input_stream, session, output);
ASSERT_THROW(ExecuteCommand(input_stream, session, v4_3::route, sizeof(v4_3::route)), SessionException);
EXPECT_EQ(session.state_, State::Close);
}
{
SCOPED_TRACE("v4");
INIT_VARS;
ExecuteHandshake(input_stream, session, output, v4::handshake_req, v4::handshake_resp);
ExecuteHandshake(input_stream, session, output, v4_3::handshake_req, v4_3::handshake_resp);
ExecuteInit(input_stream, session, output, true);
ASSERT_THROW(ExecuteCommand(input_stream, session, v4_3::route, sizeof(v4_3::route)), SessionException);
ASSERT_NO_THROW(ExecuteCommand(input_stream, session, v4_3::route, sizeof(v4_3::route)));
static constexpr uint8_t expected_resp[] = {
0x00 /*two bytes of chunk header, chunk contains 64 bytes of data*/,
0x40,
0xb1 /*TinyStruct1*/,
0x7f /*Failure*/,
0xa2 /*TinyMap with 2 items*/,
0x84 /*TinyString with 4 chars*/,
'c',
'o',
'd',
'e',
0x82 /*TinyString with 2 chars*/,
'6',
'6',
0x87 /*TinyString with 7 chars*/,
'm',
'e',
's',
's',
'a',
'g',
'e',
0xd0 /*String*/,
0x2b /*With 43 chars*/,
'R',
'o',
'u',
't',
'e',
' ',
'm',
'e',
's',
's',
'a',
'g',
'e',
' ',
'i',
's',
' ',
'n',
'o',
't',
' ',
's',
'u',
'p',
'p',
'o',
'r',
't',
'e',
'd',
' ',
'i',
'n',
' ',
'M',
'e',
'm',
'g',
'r',
'a',
'p',
'h',
'!',
0x00 /*Terminating zeros*/,
0x00,
};
EXPECT_EQ(input_stream.size(), 0U);
CheckOutput(output, expected_resp, sizeof(expected_resp));
EXPECT_EQ(session.state_, State::Error);
SCOPED_TRACE("Try to reset connection after ROUTE failed");
ASSERT_NO_THROW(ExecuteCommand(input_stream, session, v4::reset_req, sizeof(v4::reset_req)));
EXPECT_EQ(input_stream.size(), 0U);
CheckOutput(output, success_resp, sizeof(success_resp));
EXPECT_EQ(session.state_, State::Idle);
}
}
@ -992,3 +1101,24 @@ TEST(BoltSession, Rollback) {
ASSERT_THROW(ExecuteCommand(input_stream, session, v4::rollback, sizeof(v4::rollback)), SessionException);
}
}
TEST(BoltSession, ResetInIdle) {
{
SCOPED_TRACE("v1");
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
ExecuteInit(input_stream, session, output);
ASSERT_NO_THROW(ExecuteCommand(input_stream, session, reset_req, sizeof(reset_req)));
EXPECT_EQ(session.state_, State::Idle);
}
{
SCOPED_TRACE("v4");
INIT_VARS;
ExecuteHandshake(input_stream, session, output, v4_3::handshake_req, v4_3::handshake_resp);
ExecuteInit(input_stream, session, output, true);
ASSERT_NO_THROW(ExecuteCommand(input_stream, session, v4::reset_req, sizeof(v4::reset_req)));
EXPECT_EQ(session.state_, State::Idle);
}
}