From 0bc298c3adae2b0778441742a89317f66d51a5eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A1nos=20Benjamin=20Antal?= Date: Fri, 26 Aug 2022 13:19:27 +0200 Subject: [PATCH] Fix handling of the `ROUTE` Bolt message (#475) The fields of ROUTE message were not read from the input buffer, thus the input buffer got corrupted. Sending a new message to the server would result reading the remaining fields from the buffer, which means reading some values instead of message signature. Because of this unmet expectation, Memgraph closed the connection. With this fix, the fields of the ROUTE message are properly read and ignored. --- src/communication/CMakeLists.txt | 1 + src/communication/bolt/client.cpp | 262 ++++++++++++++++++ src/communication/bolt/client.hpp | 258 ++++------------- src/communication/bolt/v1/codes.hpp | 1 - .../bolt/v1/encoder/client_encoder.hpp | 106 ++++--- src/communication/bolt/v1/encoder/encoder.hpp | 23 -- src/communication/bolt/v1/states/error.hpp | 78 +++--- .../bolt/v1/states/executing.hpp | 2 +- src/communication/bolt/v1/states/handlers.hpp | 46 ++- .../e2e/magic_functions/functions/c_read.cpp | 32 +-- .../e2e/magic_functions/functions/c_write.cpp | 22 +- .../memory/procedures/global_memory_limit.c | 17 +- .../procedures/global_memory_limit_proc.c | 19 +- tests/integration/audit/tester.cpp | 2 +- tests/integration/auth/checker.cpp | 2 +- tests/integration/auth/tester.cpp | 2 +- tests/integration/ldap/tester.cpp | 2 +- tests/integration/mg_import_csv/tester.cpp | 2 +- tests/integration/transactions/tester.cpp | 49 ++-- .../clients/bfs_pokec_client.cpp | 2 +- .../clients/card_fraud_client.cpp | 2 +- .../clients/long_running_common.hpp | 2 +- .../macro_benchmark/clients/pokec_client.cpp | 2 +- .../macro_benchmark/clients/query_client.cpp | 2 +- tests/manual/bolt_client.cpp | 2 +- tests/mgbench/client.cpp | 2 +- tests/stress/long_running.cpp | 4 +- tests/unit/bolt_common.hpp | 16 +- tests/unit/bolt_session.cpp | 182 ++++++++++-- 29 files changed, 708 insertions(+), 434 deletions(-) create mode 100644 src/communication/bolt/client.cpp diff --git a/src/communication/CMakeLists.txt b/src/communication/CMakeLists.txt index ad316aad0..a50e062d1 100644 --- a/src/communication/CMakeLists.txt +++ b/src/communication/CMakeLists.txt @@ -7,6 +7,7 @@ set(communication_src_files websocket/listener.cpp websocket/session.cpp bolt/v1/value.cpp + bolt/client.cpp buffer.cpp client.cpp context.cpp diff --git a/src/communication/bolt/client.cpp b/src/communication/bolt/client.cpp new file mode 100644 index 000000000..024166aa1 --- /dev/null +++ b/src/communication/bolt/client.cpp @@ -0,0 +1,262 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "communication/bolt/client.hpp" + +#include "communication/bolt/v1/codes.hpp" +#include "communication/bolt/v1/value.hpp" +#include "utils/logging.hpp" + +namespace { +constexpr uint8_t kBoltV43Version[4] = {0x00, 0x00, 0x03, 0x04}; +constexpr uint8_t kEmptyBoltVersion[4] = {0x00, 0x00, 0x00, 0x00}; +} // namespace +namespace memgraph::communication::bolt { + +Client::Client(communication::ClientContext &context) : client_{&context} {} + +void Client::Connect(const io::network::Endpoint &endpoint, const std::string &username, const std::string &password, + const std::string &client_name) { + if (!client_.Connect(endpoint)) { + throw ClientFatalException("Couldn't connect to {}!", endpoint); + } + + if (!client_.Write(kPreamble, sizeof(kPreamble), true)) { + spdlog::error("Couldn't send preamble!"); + throw ServerCommunicationException(); + } + + if (!client_.Write(kBoltV43Version, sizeof(kBoltV43Version), true)) { + spdlog::error("Couldn't send protocol version!"); + throw ServerCommunicationException(); + } + + for (int i = 0; i < 3; ++i) { + if (!client_.Write(kEmptyBoltVersion, sizeof(kEmptyBoltVersion), i != 2)) { + spdlog::error("Couldn't send protocol version!"); + throw ServerCommunicationException(); + } + } + + if (!client_.Read(sizeof(kBoltV43Version))) { + spdlog::error("Couldn't get negotiated protocol version!"); + throw ServerCommunicationException(); + } + + if (memcmp(kBoltV43Version, client_.GetData(), sizeof(kBoltV43Version)) != 0) { + spdlog::error("Server negotiated unsupported protocol version!"); + throw ClientFatalException("The server negotiated an usupported protocol version!"); + } + client_.ShiftData(sizeof(kBoltV43Version)); + + if (!encoder_.MessageInit({{"user_agent", client_name}, + {"scheme", "basic"}, + {"principal", username}, + {"credentials", password}, + {"routing", {}}})) { + spdlog::error("Couldn't send init message!"); + throw ServerCommunicationException(); + } + + Signature signature{}; + Value metadata; + if (!ReadMessage(signature, metadata)) { + spdlog::error("Couldn't read init message response!"); + throw ServerCommunicationException(); + } + if (signature != Signature::Success) { + spdlog::error("Handshake failed!"); + throw ClientFatalException("Handshake with the server failed!"); + } + + spdlog::debug("Metadata of init message response: {}", metadata); +} + +QueryData Client::Execute(const std::string &query, const std::map ¶meters) { + if (!client_.IsConnected()) { + throw ClientFatalException("You must first connect to the server before using the client!"); + } + + spdlog::debug("Sending run message with statement: '{}'; parameters: {}", query, parameters); + + // It is super critical from performance point of view to send the pull message right after the run message. Otherwise + // the performance will degrade multiple magnitudes. + encoder_.MessageRun(query, parameters, {}); + encoder_.MessagePull({}); + + spdlog::debug("Reading run message response"); + Signature signature{}; + Value fields; + if (!ReadMessage(signature, fields)) { + throw ServerCommunicationException(); + } + if (fields.type() != Value::Type::Map) { + throw ServerMalformedDataException(); + } + + if (signature == Signature::Failure) { + HandleFailure(fields.ValueMap()); + } + if (signature != Signature::Success) { + throw ServerMalformedDataException(); + } + + spdlog::debug("Reading pull_all message response"); + Marker marker{}; + Value metadata; + std::vector> records; + while (true) { + if (!GetMessage()) { + throw ServerCommunicationException(); + } + if (!decoder_.ReadMessageHeader(&signature, &marker)) { + throw ServerCommunicationException(); + } + if (signature == Signature::Record) { + Value record; + if (!decoder_.ReadValue(&record, Value::Type::List)) { + throw ServerCommunicationException(); + } + records.emplace_back(std::move(record.ValueList())); + } else if (signature == Signature::Success) { + if (!decoder_.ReadValue(&metadata)) { + throw ServerCommunicationException(); + } + break; + } else if (signature == Signature::Failure) { + Value data; + if (!decoder_.ReadValue(&data)) { + throw ServerCommunicationException(); + } + HandleFailure(data.ValueMap()); + } else { + throw ServerMalformedDataException(); + } + } + + if (metadata.type() != Value::Type::Map) { + throw ServerMalformedDataException(); + } + + QueryData ret{{}, std::move(records), std::move(metadata.ValueMap())}; + + auto &header = fields.ValueMap(); + if (header.find("fields") == header.end()) { + throw ServerMalformedDataException(); + } + if (header["fields"].type() != Value::Type::List) { + throw ServerMalformedDataException(); + } + auto &field_vector = header["fields"].ValueList(); + + for (auto &field_item : field_vector) { + if (field_item.type() != Value::Type::String) { + throw ServerMalformedDataException(); + } + ret.fields.emplace_back(std::move(field_item.ValueString())); + } + + return ret; +} + +void Client::Reset() { + if (!client_.IsConnected()) { + throw ClientFatalException("You must first connect to the server before using the client!"); + } + + spdlog::debug("Sending reset message"); + + encoder_.MessageReset(); + + Signature signature{}; + Value fields; + // In Execute the pull message is sent right after the run message without reading the answer for the run message. + // That means some of the messages sent might get ignored. + while (true) { + if (!ReadMessage(signature, fields)) { + throw ServerCommunicationException(); + } + if (signature == Signature::Success) { + break; + } + if (signature != Signature::Ignored) { + throw ServerMalformedDataException(); + } + } +} + +std::optional> Client::Route(const std::map &routing, + const std::vector &bookmarks, + const std::optional &db) { + if (!client_.IsConnected()) { + throw ClientFatalException("You must first connect to the server before using the client!"); + } + + spdlog::debug("Sending route message with routing: {}; bookmarks: {}; db: {}", routing, bookmarks, + db.has_value() ? *db : Value()); + + encoder_.MessageRoute(routing, bookmarks, db); + + spdlog::debug("Reading route message response"); + Signature signature{}; + Value fields; + if (!ReadMessage(signature, fields)) { + throw ServerCommunicationException(); + } + if (signature == Signature::Ignored) { + return std::nullopt; + } + if (signature == Signature::Failure) { + HandleFailure(fields.ValueMap()); + } + if (signature != Signature::Success) { + throw ServerMalformedDataException{}; + } + return fields.ValueMap(); +} + +void Client::Close() { client_.Close(); }; + +bool Client::GetMessage() { + client_.ClearData(); + while (true) { + if (!client_.Read(kChunkHeaderSize)) return false; + + size_t chunk_size = client_.GetData()[0]; + chunk_size <<= 8U; + chunk_size += client_.GetData()[1]; + if (chunk_size == 0) return true; + + if (!client_.Read(chunk_size)) return false; + if (decoder_buffer_.GetChunk() != ChunkState::Whole) return false; + client_.ClearData(); + } + return true; +} + +bool Client::ReadMessage(Signature &signature, Value &ret) { + Marker marker{}; + if (!GetMessage()) return false; + if (!decoder_.ReadMessageHeader(&signature, &marker)) return false; + return ReadMessageData(marker, ret); +} + +bool Client::ReadMessageData(Marker marker, Value &ret) { + if (marker == Marker::TinyStruct) { + ret = Value(); + return true; + } + if (marker == Marker::TinyStruct1) { + return decoder_.ReadValue(&ret); + } + return false; +} +} // namespace memgraph::communication::bolt diff --git a/src/communication/bolt/client.hpp b/src/communication/bolt/client.hpp index c427a3470..90b388299 100644 --- a/src/communication/bolt/client.hpp +++ b/src/communication/bolt/client.hpp @@ -11,6 +11,12 @@ #pragma once +#include +#include +#include +#include + +#include "communication/bolt/v1/codes.hpp" #include "communication/bolt/v1/decoder/chunked_decoder_buffer.hpp" #include "communication/bolt/v1/decoder/decoder.hpp" #include "communication/bolt/v1/encoder/chunked_encoder_buffer.hpp" @@ -19,22 +25,17 @@ #include "communication/context.hpp" #include "io/network/endpoint.hpp" #include "utils/exceptions.hpp" -#include "utils/logging.hpp" namespace memgraph::communication::bolt { -/// This exception is thrown whenever an error occurs during query execution -/// that isn't fatal (eg. mistyped query or some transient error occurred). -/// It should be handled by everyone who uses the client. -class ClientQueryException : public utils::BasicException { +class FailureResponseException : public utils::BasicException { public: - using utils::BasicException::BasicException; + FailureResponseException() : utils::BasicException{"Couldn't execute query!"} {} - ClientQueryException() : utils::BasicException("Couldn't execute query!") {} + explicit FailureResponseException(const std::string &message) : utils::BasicException{message} {} - template - ClientQueryException(const std::string &code, Args &&...args) - : utils::BasicException(std::forward(args)...), code_(code) {} + FailureResponseException(const std::string &code, const std::string &message) + : utils::BasicException{message}, code_{code} {} const std::string &code() const { return code_; } @@ -42,6 +43,14 @@ class ClientQueryException : public utils::BasicException { std::string code_; }; +/// This exception is thrown whenever an error occurs during query execution +/// that isn't fatal (eg. mistyped query or some transient error occurred). +/// It should be handled by everyone who uses the client. +class ClientQueryException : public FailureResponseException { + public: + using FailureResponseException::FailureResponseException; +}; + /// This exception is thrown whenever a fatal error occurs during query /// execution and/or connecting to the server. /// It should be handled by everyone who uses the client. @@ -76,12 +85,13 @@ struct QueryData { /// server. It supports both SSL and plaintext connections. class Client final { public: - explicit Client(communication::ClientContext *context) : client_(context) {} + explicit Client(communication::ClientContext &context); Client(const Client &) = delete; Client(Client &&) = delete; Client &operator=(const Client &) = delete; Client &operator=(Client &&) = delete; + ~Client() = default; /// Method used to connect to the server. Before executing queries this method /// should be called to set-up the connection to the server. After the @@ -89,50 +99,7 @@ class Client final { /// established connection. /// @throws ClientFatalException when we couldn't connect to the server void Connect(const io::network::Endpoint &endpoint, const std::string &username, const std::string &password, - const std::string &client_name = "memgraph-bolt") { - if (!client_.Connect(endpoint)) { - throw ClientFatalException("Couldn't connect to {}!", endpoint); - } - - if (!client_.Write(kPreamble, sizeof(kPreamble), true)) { - SPDLOG_ERROR("Couldn't send preamble!"); - throw ServerCommunicationException(); - } - for (int i = 0; i < 4; ++i) { - if (!client_.Write(kProtocol, sizeof(kProtocol), i != 3)) { - SPDLOG_ERROR("Couldn't send protocol version!"); - throw ServerCommunicationException(); - } - } - - if (!client_.Read(sizeof(kProtocol))) { - SPDLOG_ERROR("Couldn't get negotiated protocol version!"); - throw ServerCommunicationException(); - } - if (memcmp(kProtocol, client_.GetData(), sizeof(kProtocol)) != 0) { - SPDLOG_ERROR("Server negotiated unsupported protocol version!"); - throw ClientFatalException("The server negotiated an usupported protocol version!"); - } - client_.ShiftData(sizeof(kProtocol)); - - if (!encoder_.MessageInit(client_name, {{"scheme", "basic"}, {"principal", username}, {"credentials", password}})) { - SPDLOG_ERROR("Couldn't send init message!"); - throw ServerCommunicationException(); - } - - Signature signature; - Value metadata; - if (!ReadMessage(&signature, &metadata)) { - SPDLOG_ERROR("Couldn't read init message response!"); - throw ServerCommunicationException(); - } - if (signature != Signature::Success) { - SPDLOG_ERROR("Handshake failed!"); - throw ClientFatalException("Handshake with the server failed!"); - } - - SPDLOG_INFO("Metadata of init message response: {}", metadata); - } + const std::string &client_name = "memgraph-bolt"); /// Function used to execute queries against the server. Before you can /// execute queries you must connect the client to the server. @@ -140,168 +107,41 @@ class Client final { /// executing the query (eg. mistyped query, /// etc.) /// @throws ClientFatalException when we couldn't communicate with the server - QueryData Execute(const std::string &query, const std::map ¶meters) { - if (!client_.IsConnected()) { - throw ClientFatalException("You must first connect to the server before using the client!"); - } - - SPDLOG_INFO("Sending run message with statement: '{}'; parameters: {}", query, parameters); - - encoder_.MessageRun(query, parameters); - encoder_.MessagePullAll(); - - SPDLOG_INFO("Reading run message response"); - Signature signature; - Value fields; - if (!ReadMessage(&signature, &fields)) { - throw ServerCommunicationException(); - } - if (fields.type() != Value::Type::Map) { - throw ServerMalformedDataException(); - } - - if (signature == Signature::Failure) { - HandleFailure(); - auto &tmp = fields.ValueMap(); - auto it = tmp.find("message"); - if (it != tmp.end()) { - auto it_code = tmp.find("code"); - if (it_code != tmp.end()) { - throw ClientQueryException(it_code->second.ValueString(), it->second.ValueString()); - } else { - throw ClientQueryException("", it->second.ValueString()); - } - } - throw ClientQueryException(); - } else if (signature != Signature::Success) { - throw ServerMalformedDataException(); - } - - SPDLOG_INFO("Reading pull_all message response"); - Marker marker; - Value metadata; - std::vector> records; - while (true) { - if (!GetMessage()) { - throw ServerCommunicationException(); - } - if (!decoder_.ReadMessageHeader(&signature, &marker)) { - throw ServerCommunicationException(); - } - if (signature == Signature::Record) { - Value record; - if (!decoder_.ReadValue(&record, Value::Type::List)) { - throw ServerCommunicationException(); - } - records.emplace_back(std::move(record.ValueList())); - } else if (signature == Signature::Success) { - if (!decoder_.ReadValue(&metadata)) { - throw ServerCommunicationException(); - } - break; - } else if (signature == Signature::Failure) { - Value data; - if (!decoder_.ReadValue(&data)) { - throw ServerCommunicationException(); - } - HandleFailure(); - auto &tmp = data.ValueMap(); - auto it = tmp.find("message"); - if (it != tmp.end()) { - auto it_code = tmp.find("code"); - if (it_code != tmp.end()) { - throw ClientQueryException(it_code->second.ValueString(), it->second.ValueString()); - } else { - throw ClientQueryException("", it->second.ValueString()); - } - } - throw ClientQueryException(); - } else { - throw ServerMalformedDataException(); - } - } - - if (metadata.type() != Value::Type::Map) { - throw ServerMalformedDataException(); - } - - QueryData ret{{}, std::move(records), std::move(metadata.ValueMap())}; - - auto &header = fields.ValueMap(); - if (header.find("fields") == header.end()) { - throw ServerMalformedDataException(); - } - if (header["fields"].type() != Value::Type::List) { - throw ServerMalformedDataException(); - } - auto &field_vector = header["fields"].ValueList(); - - for (auto &field_item : field_vector) { - if (field_item.type() != Value::Type::String) { - throw ServerMalformedDataException(); - } - ret.fields.emplace_back(std::move(field_item.ValueString())); - } - - return ret; - } + QueryData Execute(const std::string &query, const std::map ¶meters); /// Close the active client connection. - void Close() { client_.Close(); }; + void Close(); + + /// Can be used to reset the active client connection. Reset is automatically sent after receiving a failure message + /// from the server, which result in throwing an FailureResponseException or any exception derived from it. + void Reset(); + + /// Can be used to send a route message. + std::optional> Route(const std::map &routing, + const std::vector &bookmarks, + const std::optional &db); private: - bool GetMessage() { - client_.ClearData(); - while (true) { - if (!client_.Read(kChunkHeaderSize)) return false; + using ClientEncoder = ClientEncoder>; - size_t chunk_size = client_.GetData()[0]; - chunk_size <<= 8; - chunk_size += client_.GetData()[1]; - if (chunk_size == 0) return true; - - if (!client_.Read(chunk_size)) return false; - if (decoder_buffer_.GetChunk() != ChunkState::Whole) return false; - client_.ClearData(); - } - return true; - } - - bool ReadMessage(Signature *signature, Value *ret) { - Marker marker; - if (!GetMessage()) return false; - if (!decoder_.ReadMessageHeader(signature, &marker)) return false; - return ReadMessageData(marker, ret); - } - - bool ReadMessageData(Marker marker, Value *ret) { - if (marker == Marker::TinyStruct) { - *ret = Value(); - return true; - } else if (marker == Marker::TinyStruct1) { - return decoder_.ReadValue(ret); - } - return false; - } - - void HandleFailure() { - if (!encoder_.MessageAckFailure()) { - throw ServerCommunicationException(); - } - while (true) { - Signature signature; - Value data; - if (!ReadMessage(&signature, &data)) { - throw ServerCommunicationException(); - } - if (signature == Signature::Success) { - break; - } else if (signature != Signature::Ignored) { - throw ServerMalformedDataException(); + template + [[noreturn]] void HandleFailure(const std::map &response_map) { + Reset(); + auto it = response_map.find("message"); + if (it != response_map.end()) { + auto it_code = response_map.find("code"); + if (it_code != response_map.end()) { + throw TException(it_code->second.ValueString(), it->second.ValueString()); } + throw TException("", it->second.ValueString()); } + throw TException(); } + bool GetMessage(); + bool ReadMessage(Signature &signature, Value &ret); + bool ReadMessageData(Marker marker, Value &ret); + // client communication::Client client_; communication::ClientInputStream input_stream_{client_}; @@ -313,6 +153,6 @@ class Client final { // encoder objects ChunkedEncoderBuffer encoder_buffer_{output_stream_}; - ClientEncoder> encoder_{encoder_buffer_}; + ClientEncoder encoder_{encoder_buffer_}; }; } // namespace memgraph::communication::bolt diff --git a/src/communication/bolt/v1/codes.hpp b/src/communication/bolt/v1/codes.hpp index adbefbe7f..91c30f831 100644 --- a/src/communication/bolt/v1/codes.hpp +++ b/src/communication/bolt/v1/codes.hpp @@ -16,7 +16,6 @@ namespace memgraph::communication::bolt { inline constexpr uint8_t kPreamble[4] = {0x60, 0x60, 0xB0, 0x17}; -inline constexpr uint8_t kProtocol[4] = {0x00, 0x00, 0x00, 0x01}; enum class Signature : uint8_t { Noop = 0x00, diff --git a/src/communication/bolt/v1/encoder/client_encoder.hpp b/src/communication/bolt/v1/encoder/client_encoder.hpp index 122daf345..3ccf0c910 100644 --- a/src/communication/bolt/v1/encoder/client_encoder.hpp +++ b/src/communication/bolt/v1/encoder/client_encoder.hpp @@ -11,6 +11,11 @@ #pragma once +#include +#include +#include +#include + #include "communication/bolt/v1/codes.hpp" #include "communication/bolt/v1/encoder/base_encoder.hpp" @@ -30,6 +35,7 @@ class ClientEncoder : private BaseEncoder { using BaseEncoder::WriteList; using BaseEncoder::WriteMap; using BaseEncoder::WriteString; + using BaseEncoder::WriteNull; using BaseEncoder::buffer_; public: @@ -38,10 +44,9 @@ class ClientEncoder : private BaseEncoder { /** * Writes a Init message. * - * From the Bolt v1 documentation: - * InitMessage (signature=0x01) { - * String clientName - * Map authToken + * From the Bolt v4.3 documentation: + * HelloMess (signature=0x01) { + * Map extra * } * * @param client_name the name of the connected client @@ -49,11 +54,10 @@ class ClientEncoder : private BaseEncoder { * @returns true if the data was successfully sent to the client * when flushing, false otherwise */ - bool MessageInit(const std::string client_name, const std::map &auth_token) { - WriteRAW(utils::UnderlyingCast(Marker::TinyStruct2)); + bool MessageInit(const std::map &extra) { + WriteRAW(utils::UnderlyingCast(Marker::TinyStruct1)); WriteRAW(utils::UnderlyingCast(Signature::Init)); - WriteString(client_name); - WriteMap(auth_token); + WriteMap(extra); // Try to flush all remaining data in the buffer, but tell it that we will // send more data (the end of message chunk). if (!buffer_.Flush(true)) return false; @@ -64,10 +68,11 @@ class ClientEncoder : private BaseEncoder { /** * Writes a Run message. * - * From the Bolt v1 documentation: + * From the Bolt v4.3 documentation: * RunMessage (signature=0x10) { - * String statement - * Map parameters + * String statement + * Map parameters + * Map extra * } * * @param statement the statement that should be executed @@ -75,11 +80,13 @@ class ClientEncoder : private BaseEncoder { * @returns true if the data was successfully sent to the client * when flushing, false otherwise */ - bool MessageRun(const std::string &statement, const std::map ¶meters, bool have_more = true) { - WriteRAW(utils::UnderlyingCast(Marker::TinyStruct2)); + bool MessageRun(const std::string &statement, const std::map ¶meters, + const std::map &extra, bool have_more = true) { + WriteRAW(utils::UnderlyingCast(Marker::TinyStruct3)); WriteRAW(utils::UnderlyingCast(Signature::Run)); WriteString(statement); WriteMap(parameters); + WriteMap(extra); // Try to flush all remaining data in the buffer, but tell it that we will // send more data (the end of message chunk). if (!buffer_.Flush(true)) return false; @@ -90,18 +97,20 @@ class ClientEncoder : private BaseEncoder { } /** - * Writes a DiscardAll message. + * Writes a Discard message. * - * From the Bolt v1 documentation: + * From the Bolt v4.3 documentation: * DiscardMessage (signature=0x2F) { + * Map extra * } * * @returns true if the data was successfully sent to the client * when flushing, false otherwise */ - bool MessageDiscardAll() { - WriteRAW(utils::UnderlyingCast(Marker::TinyStruct)); + bool MessageDiscard(const std::map &extra) { + WriteRAW(utils::UnderlyingCast(Marker::TinyStruct1)); WriteRAW(utils::UnderlyingCast(Signature::Discard)); + WriteMap(extra); // Try to flush all remaining data in the buffer, but tell it that we will // send more data (the end of message chunk). if (!buffer_.Flush(true)) return false; @@ -112,36 +121,18 @@ class ClientEncoder : private BaseEncoder { /** * Writes a PullAll message. * - * From the Bolt v1 documentation: - * PullAllMessage (signature=0x3F) { + * From the Bolt v4.3 documentation: + * PullMessage (signature=0x3F) { + * Map extra * } * * @returns true if the data was successfully sent to the client * when flushing, false otherwise */ - bool MessagePullAll() { - WriteRAW(utils::UnderlyingCast(Marker::TinyStruct)); + bool MessagePull(const std::map &extra) { + WriteRAW(utils::UnderlyingCast(Marker::TinyStruct1)); WriteRAW(utils::UnderlyingCast(Signature::Pull)); - // Try to flush all remaining data in the buffer, but tell it that we will - // send more data (the end of message chunk). - if (!buffer_.Flush(true)) return false; - // Flush an empty chunk to indicate that the message is done. - return buffer_.Flush(); - } - - /** - * Writes a AckFailure message. - * - * From the Bolt v1 documentation: - * AckFailureMessage (signature=0x0E) { - * } - * - * @returns true if the data was successfully sent to the client - * when flushing, false otherwise - */ - bool MessageAckFailure() { - WriteRAW(utils::UnderlyingCast(Marker::TinyStruct)); - WriteRAW(utils::UnderlyingCast(Signature::AckFailure)); + WriteMap(extra); // Try to flush all remaining data in the buffer, but tell it that we will // send more data (the end of message chunk). if (!buffer_.Flush(true)) return false; @@ -152,7 +143,7 @@ class ClientEncoder : private BaseEncoder { /** * Writes a Reset message. * - * From the Bolt v1 documentation: + * From the Bolt v4.3 documentation: * ResetMessage (signature=0x0F) { * } * @@ -168,5 +159,36 @@ class ClientEncoder : private BaseEncoder { // Flush an empty chunk to indicate that the message is done. return buffer_.Flush(); } + + /** + * Writes a Route message. + * + * From the Bolt v4.3 documentation: + * RouteMessage (signature=0x0F) { + * Map routing + * List bookmarks + * String db + * } + * + * @returns true if the data was successfully sent to the client + * when flushing, false otherwise + */ + bool MessageRoute(const std::map &routing, const std::vector &bookmarks, + const std::optional &db) { + WriteRAW(utils::UnderlyingCast(Marker::TinyStruct3)); + WriteRAW(utils::UnderlyingCast(Signature::Route)); + WriteMap(routing); + WriteList(bookmarks); + if (db.has_value()) { + WriteString(*db); + } else { + WriteNull(); + } + // Try to flush all remaining data in the buffer, but tell it that we will + // send more data (the end of message chunk). + if (!buffer_.Flush(true)) return false; + // Flush an empty chunk to indicate that the message is done. + return buffer_.Flush(); + } }; } // namespace memgraph::communication::bolt diff --git a/src/communication/bolt/v1/encoder/encoder.hpp b/src/communication/bolt/v1/encoder/encoder.hpp index 33c8ae06b..6ec84051a 100644 --- a/src/communication/bolt/v1/encoder/encoder.hpp +++ b/src/communication/bolt/v1/encoder/encoder.hpp @@ -117,29 +117,6 @@ class Encoder : private BaseEncoder { return buffer_.Flush(); } - /** - * Sends an Ignored message. - * - * From the bolt v1 documentation: - * IgnoredMessage (signature=0x7E) { - * Map metadata - * } - * - * @param metadata the metadata map object that should be sent - * @returns true if the data was successfully sent to the client, - * false otherwise - */ - bool MessageIgnored(const std::map &metadata) { - WriteRAW(utils::UnderlyingCast(Marker::TinyStruct1)); - WriteRAW(utils::UnderlyingCast(Signature::Ignored)); - WriteMap(metadata); - // Try to flush all remaining data in the buffer, but tell it that we will - // send more data (the end of message chunk). - if (!buffer_.Flush(true)) return false; - // Flush an empty chunk to indicate that the message is done. - return buffer_.Flush(); - } - /** * Sends an Ignored message. * diff --git a/src/communication/bolt/v1/states/error.hpp b/src/communication/bolt/v1/states/error.hpp index 3b37923b7..4c5f4006f 100644 --- a/src/communication/bolt/v1/states/error.hpp +++ b/src/communication/bolt/v1/states/error.hpp @@ -15,6 +15,7 @@ #include "communication/bolt/v1/codes.hpp" #include "communication/bolt/v1/state.hpp" +#include "communication/bolt/v1/states/handlers.hpp" #include "communication/bolt/v1/value.hpp" #include "utils/cast.hpp" #include "utils/likely.hpp" @@ -30,8 +31,8 @@ namespace memgraph::communication::bolt { */ template State StateErrorRun(TSession &session, State state) { - Marker marker; - Signature signature; + Marker marker{}; + Signature signature{}; if (!session.decoder_.ReadMessageHeader(&signature, &marker)) { spdlog::trace("Missing header data!"); return State::Close; @@ -45,54 +46,49 @@ State StateErrorRun(TSession &session, State state) { // Clear the data buffer if it has any leftover data. session.encoder_buffer_.Clear(); - if ((session.version_.major == 1 && signature == Signature::AckFailure) || signature == Signature::Reset) { - if (signature == Signature::AckFailure) { - spdlog::trace("AckFailure received"); - } else { - spdlog::trace("Reset received"); - } + if (session.version_.major == 1 && signature == Signature::AckFailure) { + spdlog::trace("AckFailure received"); if (!session.encoder_.MessageSuccess()) { spdlog::trace("Couldn't send success message!"); return State::Close; } - if (signature == Signature::Reset) { - session.Abort(); - return State::Idle; - } - // We got AckFailure get back to right state. MG_ASSERT(state == State::Error, "Shouldn't happen"); return State::Idle; - } else { - uint8_t value = utils::UnderlyingCast(marker); - - // All bolt client messages have less than 15 parameters so if we receive - // anything than a TinyStruct it's an error. - if ((value & 0xF0) != utils::UnderlyingCast(Marker::TinyStruct)) { - spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!", value); - return State::Close; - } - - // We need to clean up all parameters from this command. - value &= 0x0F; // The length is stored in the lower nibble. - Value dv; - for (int i = 0; i < value; ++i) { - if (!session.decoder_.ReadValue(&dv)) { - spdlog::trace("Couldn't clean up parameter {} / {}!", i, value); - return State::Close; - } - } - - // Ignore this message. - if (!session.encoder_.MessageIgnored()) { - spdlog::trace("Couldn't send ignored message!"); - return State::Close; - } - - // Cleanup done, command ignored, stay in error state. - return state; } + if (signature == Signature::Reset) { + spdlog::trace("Reset received"); + return HandleReset(session, marker); + } + + uint8_t value = utils::UnderlyingCast(marker); + + // All bolt client messages have less than 15 parameters so if we receive + // anything than a TinyStruct it's an error. + if ((value & 0xF0U) != utils::UnderlyingCast(Marker::TinyStruct)) { + spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!", value); + return State::Close; + } + + // We need to clean up all parameters from this command. + value &= 0x0FU; // The length is stored in the lower nibble. + Value dv; + for (int i = 0; i < value; ++i) { + if (!session.decoder_.ReadValue(&dv)) { + spdlog::trace("Couldn't clean up parameter {} / {}!", i, value); + return State::Close; + } + } + + // Ignore this message. + if (!session.encoder_.MessageIgnored()) { + spdlog::trace("Couldn't send ignored message!"); + return State::Close; + } + + // Cleanup done, command ignored, stay in error state. + return state; } } // namespace memgraph::communication::bolt diff --git a/src/communication/bolt/v1/states/executing.hpp b/src/communication/bolt/v1/states/executing.hpp index 54504985b..ed298deda 100644 --- a/src/communication/bolt/v1/states/executing.hpp +++ b/src/communication/bolt/v1/states/executing.hpp @@ -74,7 +74,7 @@ State RunHandlerV4(Signature signature, TSession &session, State state, Marker m } case Signature::Route: { if constexpr (bolt_minor >= 3) { - if (signature == Signature::Route) return HandleRoute(session); + if (signature == Signature::Route) return HandleRoute(session, marker); } else { spdlog::trace("Supported only in bolt v4.3"); return State::Close; diff --git a/src/communication/bolt/v1/states/handlers.hpp b/src/communication/bolt/v1/states/handlers.hpp index f89ceae74..08b984c28 100644 --- a/src/communication/bolt/v1/states/handlers.hpp +++ b/src/communication/bolt/v1/states/handlers.hpp @@ -18,6 +18,7 @@ #include "communication/bolt/v1/codes.hpp" #include "communication/bolt/v1/constants.hpp" +#include "communication/bolt/v1/exceptions.hpp" #include "communication/bolt/v1/state.hpp" #include "communication/bolt/v1/value.hpp" #include "communication/exceptions.hpp" @@ -136,7 +137,7 @@ template State HandlePullDiscardV1(TSession &session, const State state, const Marker marker) { const auto expected_marker = Marker::TinyStruct; if (marker != expected_marker) { - spdlog::trace("Expected {} marker, but received 0x{:02X}!", "TinyStruct", utils::UnderlyingCast(marker)); + spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!", utils::UnderlyingCast(marker)); return State::Close; } @@ -157,7 +158,7 @@ template State HandlePullDiscardV4(TSession &session, const State state, const Marker marker) { const auto expected_marker = Marker::TinyStruct1; if (marker != expected_marker) { - spdlog::trace("Expected {} marker, but received 0x{:02X}!", "TinyStruct1", utils::UnderlyingCast(marker)); + spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02X}!", utils::UnderlyingCast(marker)); return State::Close; } @@ -216,7 +217,8 @@ State HandleRunV1(TSession &session, const State state, const Marker marker) { session.version_.major == 1 ? "TinyStruct2" : "TinyStruct3", utils::UnderlyingCast(marker)); return State::Close; } - Value query, params; + Value query; + Value params; if (!session.decoder_.ReadValue(&query, Value::Type::String)) { spdlog::trace("Couldn't read query string!"); return State::Close; @@ -234,10 +236,12 @@ template State HandleRunV4(TSession &session, const State state, const Marker marker) { const auto expected_marker = Marker::TinyStruct3; if (marker != expected_marker) { - spdlog::trace("Expected {} marker, but received 0x{:02X}!", "TinyStruct3", utils::UnderlyingCast(marker)); + spdlog::trace("Expected TinyStruct3 marker, but received 0x{:02X}!", utils::UnderlyingCast(marker)); return State::Close; } - Value query, params, extra; + Value query; + Value params; + Value extra; if (!session.decoder_.ReadValue(&query, Value::Type::String)) { spdlog::trace("Couldn't read query string!"); return State::Close; @@ -292,9 +296,6 @@ State HandleReset(TSession &session, const Marker marker) { return State::Close; } - // Clear all pending data and send a success message. - session.encoder_buffer_.Clear(); - if (!session.encoder_.MessageSuccess()) { spdlog::trace("Couldn't send success message!"); return State::Close; @@ -403,12 +404,33 @@ State HandleGoodbye() { } template -State HandleRoute(TSession &session) { - // Route message is not implemented since it is neo4j specific, therefore we - // will receive it an inform user that there is no implementation. +State HandleRoute(TSession &session, const Marker marker) { + // Route message is not implemented since it is Neo4j specific, therefore we will receive it and inform user that + // there is no implementation. Before that, we have to read out the fields from the buffer to leave it in a clean + // state. + if (marker != Marker::TinyStruct3) { + spdlog::trace("Expected TinyStruct3 marker, but received 0x{:02x}!", utils::UnderlyingCast(marker)); + return State::Close; + } + Value routing; + if (!session.decoder_.ReadValue(&routing, Value::Type::Map)) { + spdlog::trace("Couldn't read routing field!"); + return State::Close; + } + + Value bookmarks; + if (!session.decoder_.ReadValue(&bookmarks, Value::Type::List)) { + spdlog::trace("Couldn't read bookmarks field!"); + return State::Close; + } + Value db; + if (!session.decoder_.ReadValue(&db)) { + spdlog::trace("Couldn't read db field!"); + return State::Close; + } session.encoder_buffer_.Clear(); bool fail_sent = - session.encoder_.MessageFailure({{"code", 66}, {"message", "Route message not supported in Memgraph!"}}); + session.encoder_.MessageFailure({{"code", "66"}, {"message", "Route message is not supported in Memgraph!"}}); if (!fail_sent) { spdlog::trace("Couldn't send failure message!"); return State::Close; diff --git a/tests/e2e/magic_functions/functions/c_read.cpp b/tests/e2e/magic_functions/functions/c_read.cpp index 75c023e58..f2842f35f 100644 --- a/tests/e2e/magic_functions/functions/c_read.cpp +++ b/tests/e2e/magic_functions/functions/c_read.cpp @@ -22,13 +22,13 @@ static void ReturnFunctionArgument(struct mgp_list *args, mgp_func_context *ctx, mgp_value *value{nullptr}; auto err_code = mgp_list_at(args, 0, &value); if (err_code != mgp_error::MGP_ERROR_NO_ERROR) { - mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory); + static_cast(mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory)); return; } err_code = mgp_func_result_set_value(result, value, memory); if (err_code != mgp_error::MGP_ERROR_NO_ERROR) { - mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory); + static_cast(mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory)); return; } } @@ -38,13 +38,13 @@ static void ReturnOptionalArgument(struct mgp_list *args, mgp_func_context *ctx, mgp_value *value{nullptr}; auto err_code = mgp_list_at(args, 0, &value); if (err_code != mgp_error::MGP_ERROR_NO_ERROR) { - mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory); + static_cast(mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory)); return; } err_code = mgp_func_result_set_value(result, value, memory); if (err_code != mgp_error::MGP_ERROR_NO_ERROR) { - mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory); + static_cast(mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory)); return; } } @@ -57,14 +57,14 @@ double GetElementFromArg(struct mgp_list *args, int index) { double result; int is_int; - mgp_value_is_int(value, &is_int); + static_cast(mgp_value_is_int(value, &is_int)); if (is_int) { int64_t result_int; - mgp_value_get_int(value, &result_int); + static_cast(mgp_value_get_int(value, &result_int)); result = static_cast(result_int); } else { - mgp_value_get_double(value, &result); + static_cast(mgp_value_get_double(value, &result)); } return result; } @@ -77,30 +77,30 @@ static void AddTwoNumbers(struct mgp_list *args, mgp_func_context *ctx, mgp_func first = GetElementFromArg(args, 0); second = GetElementFromArg(args, 1); } catch (...) { - mgp_func_result_set_error_msg(result, "Unable to fetch the result!", memory); + static_cast(mgp_func_result_set_error_msg(result, "Unable to fetch the result!", memory)); return; } mgp_value *value{nullptr}; auto summation = first + second; - mgp_value_make_double(summation, memory, &value); + static_cast(mgp_value_make_double(summation, memory, &value)); memgraph::utils::OnScopeExit delete_summation_value([&value] { mgp_value_destroy(value); }); auto err_code = mgp_func_result_set_value(result, value, memory); if (err_code != mgp_error::MGP_ERROR_NO_ERROR) { - mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory); + static_cast(mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory)); } } static void ReturnNull(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result, struct mgp_memory *memory) { mgp_value *value{nullptr}; - mgp_value_make_null(memory, &value); + static_cast(mgp_value_make_null(memory, &value)); memgraph::utils::OnScopeExit delete_null([&value] { mgp_value_destroy(value); }); auto err_code = mgp_func_result_set_value(result, value, memory); if (err_code != mgp_error::MGP_ERROR_NO_ERROR) { - mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory); + static_cast(mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory)); } } } // namespace @@ -116,7 +116,7 @@ extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *mem } mgp_type *type_any{nullptr}; - mgp_type_any(&type_any); + static_cast(mgp_type_any(&type_any)); err_code = mgp_func_add_arg(func, "argument", type_any); if (err_code != mgp_error::MGP_ERROR_NO_ERROR) { return 1; @@ -131,11 +131,11 @@ extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *mem } mgp_value *default_value{nullptr}; - mgp_value_make_int(42, memory, &default_value); + static_cast(mgp_value_make_int(42, memory, &default_value)); memgraph::utils::OnScopeExit delete_summation_value([&default_value] { mgp_value_destroy(default_value); }); mgp_type *type_int{nullptr}; - mgp_type_int(&type_int); + static_cast(mgp_type_int(&type_int)); err_code = mgp_func_add_opt_arg(func, "opt_argument", type_int, default_value); if (err_code != mgp_error::MGP_ERROR_NO_ERROR) { return 1; @@ -150,7 +150,7 @@ extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *mem } mgp_type *type_number{nullptr}; - mgp_type_number(&type_number); + static_cast(mgp_type_number(&type_number)); err_code = mgp_func_add_arg(func, "first", type_number); if (err_code != mgp_error::MGP_ERROR_NO_ERROR) { return 1; diff --git a/tests/e2e/magic_functions/functions/c_write.cpp b/tests/e2e/magic_functions/functions/c_write.cpp index 76a2ad08e..0dedde51f 100644 --- a/tests/e2e/magic_functions/functions/c_write.cpp +++ b/tests/e2e/magic_functions/functions/c_write.cpp @@ -15,25 +15,25 @@ static void TryToWrite(struct mgp_list *args, mgp_func_context *ctx, mgp_func_re struct mgp_memory *memory) { mgp_value *value{nullptr}; mgp_vertex *vertex{nullptr}; - mgp_list_at(args, 0, &value); - mgp_value_get_vertex(value, &vertex); + static_cast(mgp_list_at(args, 0, &value)); + static_cast(mgp_value_get_vertex(value, &vertex)); const char *name; - mgp_list_at(args, 1, &value); - mgp_value_get_string(value, &name); + static_cast(mgp_list_at(args, 1, &value)); + static_cast(mgp_value_get_string(value, &name)); - mgp_list_at(args, 2, &value); + static_cast(mgp_list_at(args, 2, &value)); // Setting a property should set an error auto err_code = mgp_vertex_set_property(vertex, name, value); if (err_code != mgp_error::MGP_ERROR_NO_ERROR) { - mgp_func_result_set_error_msg(result, "Cannot set property in the function!", memory); + static_cast(mgp_func_result_set_error_msg(result, "Cannot set property in the function!", memory)); return; } err_code = mgp_func_result_set_value(result, value, memory); if (err_code != mgp_error::MGP_ERROR_NO_ERROR) { - mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory); + static_cast(mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory)); return; } } @@ -49,23 +49,23 @@ extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *mem } mgp_type *type_vertex{nullptr}; - mgp_type_node(&type_vertex); + static_cast(mgp_type_node(&type_vertex)); err_code = mgp_func_add_arg(func, "argument", type_vertex); if (err_code != mgp_error::MGP_ERROR_NO_ERROR) { return 1; } mgp_type *type_string{nullptr}; - mgp_type_string(&type_string); + static_cast(mgp_type_string(&type_string)); err_code = mgp_func_add_arg(func, "name", type_string); if (err_code != mgp_error::MGP_ERROR_NO_ERROR) { return 1; } mgp_type *any_type{nullptr}; - mgp_type_any(&any_type); + static_cast(mgp_type_any(&any_type)); mgp_type *nullable_type{nullptr}; - mgp_type_nullable(any_type, &nullable_type); + static_cast(mgp_type_nullable(any_type, &nullable_type)); err_code = mgp_func_add_arg(func, "value", nullable_type); if (err_code != mgp_error::MGP_ERROR_NO_ERROR) { return 1; diff --git a/tests/e2e/memory/procedures/global_memory_limit.c b/tests/e2e/memory/procedures/global_memory_limit.c index 8d4362afc..d4496fd6e 100644 --- a/tests/e2e/memory/procedures/global_memory_limit.c +++ b/tests/e2e/memory/procedures/global_memory_limit.c @@ -1,10 +1,21 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + #include "mg_procedure.h" int *gVal = NULL; void set_error(struct mgp_result *result) { mgp_result_set_error_msg(result, "Something went wrong"); } -static void procedure(const struct mgp_list *args, const struct mgp_graph *graph, struct mgp_result *result, +static void procedure(struct mgp_list *args, struct mgp_graph *graph, struct mgp_result *result, struct mgp_memory *memory) { struct mgp_result_record *record = NULL; const enum mgp_error new_record_err = mgp_result_new_record(result, &record); @@ -21,14 +32,14 @@ static void procedure(const struct mgp_list *args, const struct mgp_graph *graph int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) { const size_t one_gb = 1 << 30; - const enum mgp_error alloc_err = mgp_global_alloc(one_gb, &gVal); + const enum mgp_error alloc_err = mgp_global_alloc(one_gb, (void **)(&gVal)); if (alloc_err != MGP_ERROR_NO_ERROR) return 1; struct mgp_proc *proc = NULL; const enum mgp_error proc_err = mgp_module_add_read_procedure(module, "procedure", procedure, &proc); if (proc_err != MGP_ERROR_NO_ERROR) return 1; - const struct mgp_type *string_type = NULL; + struct mgp_type *string_type = NULL; const enum mgp_error string_type_err = mgp_type_string(&string_type); if (string_type_err != MGP_ERROR_NO_ERROR) return 1; if (mgp_proc_add_result(proc, "result", string_type) != MGP_ERROR_NO_ERROR) return 1; diff --git a/tests/e2e/memory/procedures/global_memory_limit_proc.c b/tests/e2e/memory/procedures/global_memory_limit_proc.c index f54202ead..393ced426 100644 --- a/tests/e2e/memory/procedures/global_memory_limit_proc.c +++ b/tests/e2e/memory/procedures/global_memory_limit_proc.c @@ -1,3 +1,14 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + #include "mg_procedure.h" int *gVal = NULL; @@ -6,7 +17,7 @@ void set_error(struct mgp_result *result) { mgp_result_set_error_msg(result, "So void set_out_of_memory_error(struct mgp_result *result) { mgp_result_set_error_msg(result, "Out of memory"); } -static void error(const struct mgp_list *args, const struct mgp_graph *graph, struct mgp_result *result, +static void error(struct mgp_list *args, struct mgp_graph *graph, struct mgp_result *result, struct mgp_memory *memory) { const size_t one_gb = 1 << 30; if (gVal) { @@ -14,7 +25,7 @@ static void error(const struct mgp_list *args, const struct mgp_graph *graph, st gVal = NULL; } if (!gVal) { - const enum mgp_error err = mgp_global_alloc(one_gb, &gVal); + const enum mgp_error err = mgp_global_alloc(one_gb, (void **)(&gVal)); if (err == MGP_ERROR_UNABLE_TO_ALLOCATE) return set_out_of_memory_error(result); if (err != MGP_ERROR_NO_ERROR) return set_error(result); } @@ -29,11 +40,11 @@ static void error(const struct mgp_list *args, const struct mgp_graph *graph, st if (result_inserted != MGP_ERROR_NO_ERROR) return set_error(result); } -static void success(const struct mgp_list *args, const struct mgp_graph *graph, struct mgp_result *result, +static void success(struct mgp_list *args, struct mgp_graph *graph, struct mgp_result *result, struct mgp_memory *memory) { const size_t bytes = 1024; if (!gVal) { - const enum mgp_error err = mgp_global_alloc(bytes, &gVal); + const enum mgp_error err = mgp_global_alloc(bytes, (void **)(&gVal)); if (err == MGP_ERROR_UNABLE_TO_ALLOCATE) return set_out_of_memory_error(result); if (err != MGP_ERROR_NO_ERROR) return set_error(result); } diff --git a/tests/integration/audit/tester.cpp b/tests/integration/audit/tester.cpp index 7753bd59c..973227253 100644 --- a/tests/integration/audit/tester.cpp +++ b/tests/integration/audit/tester.cpp @@ -86,7 +86,7 @@ int main(int argc, char **argv) { memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port); memgraph::communication::ClientContext context(FLAGS_use_ssl); - memgraph::communication::bolt::Client client(&context); + memgraph::communication::bolt::Client client(context); client.Connect(endpoint, FLAGS_username, FLAGS_password); client.Execute(FLAGS_query, JsonToValue(nlohmann::json::parse(FLAGS_params_json)).ValueMap()); diff --git a/tests/integration/auth/checker.cpp b/tests/integration/auth/checker.cpp index 52cb33c07..946f1f43c 100644 --- a/tests/integration/auth/checker.cpp +++ b/tests/integration/auth/checker.cpp @@ -33,7 +33,7 @@ int main(int argc, char **argv) { memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port); memgraph::communication::ClientContext context(FLAGS_use_ssl); - memgraph::communication::bolt::Client client(&context); + memgraph::communication::bolt::Client client(context); client.Connect(endpoint, FLAGS_username, FLAGS_password); diff --git a/tests/integration/auth/tester.cpp b/tests/integration/auth/tester.cpp index 904ab14a5..3ef7392a9 100644 --- a/tests/integration/auth/tester.cpp +++ b/tests/integration/auth/tester.cpp @@ -38,7 +38,7 @@ int main(int argc, char **argv) { memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port); memgraph::communication::ClientContext context(FLAGS_use_ssl); - memgraph::communication::bolt::Client client(&context); + memgraph::communication::bolt::Client client(context); client.Connect(endpoint, FLAGS_username, FLAGS_password); diff --git a/tests/integration/ldap/tester.cpp b/tests/integration/ldap/tester.cpp index 5fe5ae925..8f79938c7 100644 --- a/tests/integration/ldap/tester.cpp +++ b/tests/integration/ldap/tester.cpp @@ -37,7 +37,7 @@ int main(int argc, char **argv) { memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port); memgraph::communication::ClientContext context(FLAGS_use_ssl); - memgraph::communication::bolt::Client client(&context); + memgraph::communication::bolt::Client client(context); { std::string what; diff --git a/tests/integration/mg_import_csv/tester.cpp b/tests/integration/mg_import_csv/tester.cpp index 8197d5a71..d2ed636d8 100644 --- a/tests/integration/mg_import_csv/tester.cpp +++ b/tests/integration/mg_import_csv/tester.cpp @@ -36,7 +36,7 @@ int main(int argc, char **argv) { memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port); memgraph::communication::ClientContext context(FLAGS_use_ssl); - memgraph::communication::bolt::Client client(&context); + memgraph::communication::bolt::Client client(context); client.Connect(endpoint, FLAGS_username, FLAGS_password); auto ret = client.Execute("DUMP DATABASE", {}); diff --git a/tests/integration/transactions/tester.cpp b/tests/integration/transactions/tester.cpp index 189b8c92b..06f89202c 100644 --- a/tests/integration/transactions/tester.cpp +++ b/tests/integration/transactions/tester.cpp @@ -65,7 +65,7 @@ class BoltClient : public ::testing::Test { memgraph::io::network::Endpoint endpoint_{memgraph::io::network::ResolveHostname(FLAGS_address), static_cast(FLAGS_port)}; memgraph::communication::ClientContext context_{FLAGS_use_ssl}; - Client client_{&context_}; + Client client_{context_}; }; const std::string kNoCurrentTransactionToCommit = "No current transaction to commit."; @@ -100,20 +100,20 @@ TEST_F(BoltClient, DoubleRollbackWithoutTransaction) { TEST_F(BoltClient, DoubleBegin) { EXPECT_TRUE(Execute("begin")); EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException); - EXPECT_TRUE(TransactionActive()); + EXPECT_FALSE(TransactionActive()); } TEST_F(BoltClient, DoubleBeginAndCommit) { EXPECT_TRUE(Execute("begin")); EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException); - EXPECT_TRUE(Execute("commit")); + EXPECT_THROW(Execute("commit", kNoCurrentTransactionToCommit), ClientQueryException); EXPECT_FALSE(TransactionActive()); } TEST_F(BoltClient, DoubleBeginAndRollback) { EXPECT_TRUE(Execute("begin")); EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException); - EXPECT_TRUE(Execute("rollback")); + EXPECT_THROW(Execute("rollback", kNoCurrentTransactionToRollback), ClientQueryException); EXPECT_FALSE(TransactionActive()); } @@ -157,30 +157,29 @@ TEST_F(BoltClient, BeginAndCorrectQueriesAndBegin) { EXPECT_TRUE(Execute("create (n)")); ASSERT_EQ(GetCount(), count + 1); EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException); - EXPECT_EQ(GetCount(), count + 1); - EXPECT_TRUE(TransactionActive()); + EXPECT_EQ(GetCount(), count); + EXPECT_FALSE(TransactionActive()); } TEST_F(BoltClient, BeginAndWrongQueryAndRollback) { EXPECT_TRUE(Execute("begin")); EXPECT_THROW(Execute("asdasd"), ClientQueryException); - EXPECT_TRUE(Execute("rollback")); + EXPECT_THROW(Execute("rollback", kNoCurrentTransactionToRollback), ClientQueryException); EXPECT_FALSE(TransactionActive()); } TEST_F(BoltClient, BeginAndWrongQueryAndCommit) { EXPECT_TRUE(Execute("begin")); EXPECT_THROW(Execute("asdasd"), ClientQueryException); - EXPECT_THROW(Execute("commit", kCommitInvalid), ClientQueryException); - EXPECT_TRUE(Execute("rollback")); + EXPECT_THROW(Execute("commit", kNoCurrentTransactionToCommit), ClientQueryException); EXPECT_FALSE(TransactionActive()); } TEST_F(BoltClient, BeginAndWrongQueryAndBegin) { EXPECT_TRUE(Execute("begin")); EXPECT_THROW(Execute("asdasd"), ClientQueryException); - EXPECT_THROW(Execute("commit", kCommitInvalid), ClientQueryException); - EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException); + EXPECT_THROW(Execute("commit", kNoCurrentTransactionToCommit), ClientQueryException); + EXPECT_TRUE(Execute("begin")); EXPECT_TRUE(TransactionActive()); } @@ -230,7 +229,7 @@ TEST_F(BoltClient, CorrectQueryAndBeginAndBegin) { EXPECT_TRUE(Execute("match (n) return count(n)")); EXPECT_TRUE(Execute("begin")); EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException); - EXPECT_TRUE(TransactionActive()); + EXPECT_FALSE(TransactionActive()); } TEST_F(BoltClient, WrongQueryAndBeginAndCommit) { @@ -251,7 +250,7 @@ TEST_F(BoltClient, WrongQueryAndBeginAndBegin) { EXPECT_THROW(Execute("asdasd"), ClientQueryException); EXPECT_TRUE(Execute("begin")); EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException); - EXPECT_TRUE(TransactionActive()); + EXPECT_FALSE(TransactionActive()); } TEST_F(BoltClient, CorrectQueriesAndBeginAndCommit) { @@ -278,7 +277,7 @@ TEST_F(BoltClient, CorrectQueriesAndBeginAndBegin) { } EXPECT_TRUE(Execute("begin")); EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException); - EXPECT_TRUE(TransactionActive()); + EXPECT_FALSE(TransactionActive()); } TEST_F(BoltClient, WrongQueriesAndBeginAndCommit) { @@ -305,7 +304,7 @@ TEST_F(BoltClient, WrongQueriesAndBeginAndBegin) { } EXPECT_TRUE(Execute("begin")); EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException); - EXPECT_TRUE(TransactionActive()); + EXPECT_FALSE(TransactionActive()); } TEST_F(BoltClient, CorrectQueriesAndBeginAndCorrectQueriesAndCommit) { @@ -341,7 +340,7 @@ TEST_F(BoltClient, CorrectQueriesAndBeginAndCorrectQueriesAndBegin) { EXPECT_TRUE(Execute("match (n) return count(n)")); } EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException); - EXPECT_TRUE(TransactionActive()); + EXPECT_FALSE(TransactionActive()); } TEST_F(BoltClient, WrongQueriesAndBeginAndCorrectQueriesAndCommit) { @@ -377,7 +376,7 @@ TEST_F(BoltClient, WrongQueriesAndBeginAndCorrectQueriesAndBegin) { EXPECT_TRUE(Execute("match (n) return count(n)")); } EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException); - EXPECT_TRUE(TransactionActive()); + EXPECT_FALSE(TransactionActive()); } TEST_F(BoltClient, CorrectQueriesAndBeginAndWrongQueriesAndCommit) { @@ -388,8 +387,8 @@ TEST_F(BoltClient, CorrectQueriesAndBeginAndWrongQueriesAndCommit) { for (int i = 0; i < 3; ++i) { EXPECT_THROW(Execute("asdasd"), ClientQueryException); } - EXPECT_THROW(Execute("commit", kCommitInvalid), ClientQueryException); - EXPECT_TRUE(TransactionActive()); + EXPECT_THROW(Execute("commit", kNoCurrentTransactionToCommit), ClientQueryException); + EXPECT_FALSE(TransactionActive()); } TEST_F(BoltClient, CorrectQueriesAndBeginAndWrongQueriesAndRollback) { @@ -400,7 +399,7 @@ TEST_F(BoltClient, CorrectQueriesAndBeginAndWrongQueriesAndRollback) { for (int i = 0; i < 3; ++i) { EXPECT_THROW(Execute("asdasd"), ClientQueryException); } - EXPECT_TRUE(Execute("rollback")); + EXPECT_THROW(Execute("rollback", kNoCurrentTransactionToRollback), ClientQueryException); EXPECT_FALSE(TransactionActive()); } @@ -412,7 +411,7 @@ TEST_F(BoltClient, CorrectQueriesAndBeginAndWrongQueriesAndBegin) { for (int i = 0; i < 3; ++i) { EXPECT_THROW(Execute("asdasd"), ClientQueryException); } - EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException); + EXPECT_TRUE(Execute("begin")); EXPECT_TRUE(TransactionActive()); } @@ -424,8 +423,8 @@ TEST_F(BoltClient, WrongQueriesAndBeginAndWrongQueriesAndCommit) { for (int i = 0; i < 3; ++i) { EXPECT_THROW(Execute("asdasd"), ClientQueryException); } - EXPECT_THROW(Execute("commit", kCommitInvalid), ClientQueryException); - EXPECT_TRUE(TransactionActive()); + EXPECT_THROW(Execute("commit", kNoCurrentTransactionToCommit), ClientQueryException); + EXPECT_FALSE(TransactionActive()); } TEST_F(BoltClient, WrongQueriesAndBeginAndWrongQueriesAndRollback) { @@ -436,7 +435,7 @@ TEST_F(BoltClient, WrongQueriesAndBeginAndWrongQueriesAndRollback) { for (int i = 0; i < 3; ++i) { EXPECT_THROW(Execute("asdasd"), ClientQueryException); } - EXPECT_TRUE(Execute("rollback")); + EXPECT_THROW(Execute("rollback", kNoCurrentTransactionToRollback), ClientQueryException); EXPECT_FALSE(TransactionActive()); } @@ -448,7 +447,7 @@ TEST_F(BoltClient, WrongQueriesAndBeginAndWrongQueriesAndBegin) { for (int i = 0; i < 3; ++i) { EXPECT_THROW(Execute("asdasd"), ClientQueryException); } - EXPECT_THROW(Execute("begin", kNestedTransactions), ClientQueryException); + EXPECT_TRUE(Execute("begin")); EXPECT_TRUE(TransactionActive()); } diff --git a/tests/macro_benchmark/clients/bfs_pokec_client.cpp b/tests/macro_benchmark/clients/bfs_pokec_client.cpp index defe9a385..04decf32a 100644 --- a/tests/macro_benchmark/clients/bfs_pokec_client.cpp +++ b/tests/macro_benchmark/clients/bfs_pokec_client.cpp @@ -117,7 +117,7 @@ int main(int argc, char **argv) { Endpoint endpoint(FLAGS_address, FLAGS_port); ClientContext context(FLAGS_use_ssl); - Client client(&context); + Client client(context); client.Connect(endpoint, FLAGS_username, FLAGS_password); std::vector> clients; diff --git a/tests/macro_benchmark/clients/card_fraud_client.cpp b/tests/macro_benchmark/clients/card_fraud_client.cpp index 0cf990933..c04ff20ad 100644 --- a/tests/macro_benchmark/clients/card_fraud_client.cpp +++ b/tests/macro_benchmark/clients/card_fraud_client.cpp @@ -317,7 +317,7 @@ int main(int argc, char **argv) { Endpoint endpoint(FLAGS_address, FLAGS_port); ClientContext context(FLAGS_use_ssl); - Client client(&context); + Client client(context); client.Connect(endpoint, FLAGS_username, FLAGS_password); num_pos.store(NumNodesWithLabel(client, "Pos")); diff --git a/tests/macro_benchmark/clients/long_running_common.hpp b/tests/macro_benchmark/clients/long_running_common.hpp index 4509fa566..f76857dc1 100644 --- a/tests/macro_benchmark/clients/long_running_common.hpp +++ b/tests/macro_benchmark/clients/long_running_common.hpp @@ -118,7 +118,7 @@ class TestClient { private: memgraph::communication::ClientContext context_{FLAGS_use_ssl}; - Client client_{&context_}; + Client client_{context_}; }; void RunMultithreadedTest(std::vector> &clients) { diff --git a/tests/macro_benchmark/clients/pokec_client.cpp b/tests/macro_benchmark/clients/pokec_client.cpp index a6f7419a8..ba6f96941 100644 --- a/tests/macro_benchmark/clients/pokec_client.cpp +++ b/tests/macro_benchmark/clients/pokec_client.cpp @@ -261,7 +261,7 @@ int main(int argc, char **argv) { auto independent_nodes_ids = [&] { Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port); ClientContext context(FLAGS_use_ssl); - Client client(&context); + Client client(context); client.Connect(endpoint, FLAGS_username, FLAGS_password); return IndependentSet(client, INDEPENDENT_LABEL); }(); diff --git a/tests/macro_benchmark/clients/query_client.cpp b/tests/macro_benchmark/clients/query_client.cpp index 51e32f9bb..9bc65d238 100644 --- a/tests/macro_benchmark/clients/query_client.cpp +++ b/tests/macro_benchmark/clients/query_client.cpp @@ -67,7 +67,7 @@ void ExecuteQueries(const std::vector &queries, std::ostream &ostre threads.push_back(std::thread([&]() { Endpoint endpoint(FLAGS_address, FLAGS_port); ClientContext context(FLAGS_use_ssl); - Client client(&context); + Client client(context); client.Connect(endpoint, FLAGS_username, FLAGS_password); std::string str; diff --git a/tests/manual/bolt_client.cpp b/tests/manual/bolt_client.cpp index d978cfca5..7948c1b95 100644 --- a/tests/manual/bolt_client.cpp +++ b/tests/manual/bolt_client.cpp @@ -32,7 +32,7 @@ int main(int argc, char **argv) { memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port); memgraph::communication::ClientContext context(FLAGS_use_ssl); - memgraph::communication::bolt::Client client(&context); + memgraph::communication::bolt::Client client(context); client.Connect(endpoint, FLAGS_username, FLAGS_password); diff --git a/tests/mgbench/client.cpp b/tests/mgbench/client.cpp index e4b63d477..87bdda6c9 100644 --- a/tests/mgbench/client.cpp +++ b/tests/mgbench/client.cpp @@ -177,7 +177,7 @@ void Execute( threads.push_back(std::thread([&, worker]() { memgraph::io::network::Endpoint endpoint(FLAGS_address, FLAGS_port); memgraph::communication::ClientContext context(FLAGS_use_ssl); - memgraph::communication::bolt::Client client(&context); + memgraph::communication::bolt::Client client(context); client.Connect(endpoint, FLAGS_username, FLAGS_password); ready.fetch_add(1, std::memory_order_acq_rel); diff --git a/tests/stress/long_running.cpp b/tests/stress/long_running.cpp index 11f5c5feb..ddf93549d 100644 --- a/tests/stress/long_running.cpp +++ b/tests/stress/long_running.cpp @@ -66,7 +66,7 @@ class GraphSession { } EndpointT endpoint(FLAGS_address, FLAGS_port); - client_ = std::make_unique(&context_); + client_ = std::make_unique(context_); client_->Connect(endpoint, FLAGS_username, FLAGS_password); } @@ -387,7 +387,7 @@ int main(int argc, char **argv) { // create client EndpointT endpoint(FLAGS_address, FLAGS_port); ClientContextT context(FLAGS_use_ssl); - ClientT client(&context); + ClientT client(context); client.Connect(endpoint, FLAGS_username, FLAGS_password); // cleanup and create indexes diff --git a/tests/unit/bolt_common.hpp b/tests/unit/bolt_common.hpp index c6495ee4a..5770fcbfa 100644 --- a/tests/unit/bolt_common.hpp +++ b/tests/unit/bolt_common.hpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -98,15 +98,19 @@ void PrintOutput(std::vector &output) { * TODO (mferencevic): document */ void CheckOutput(std::vector &output, const uint8_t *data, uint64_t len, bool clear = true) { - if (clear) + if (clear) { ASSERT_EQ(len, output.size()); - else + } else { ASSERT_LE(len, output.size()); - for (size_t i = 0; i < len; ++i) EXPECT_EQ(output[i], data[i]); - if (clear) + } + for (size_t i = 0; i < len; ++i) { + EXPECT_EQ(output[i], data[i]) << i; + } + if (clear) { output.clear(); - else + } else { output.erase(output.begin(), output.begin() + len); + } } /** diff --git a/tests/unit/bolt_session.cpp b/tests/unit/bolt_session.cpp index 9b18d807f..7d8de88b4 100644 --- a/tests/unit/bolt_session.cpp +++ b/tests/unit/bolt_session.cpp @@ -9,6 +9,8 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. +#include + #include #include "bolt_common.hpp" @@ -151,7 +153,10 @@ inline constexpr uint8_t noop[] = {0x00, 0x00}; } // namespace v4_1 namespace v4_3 { -inline constexpr uint8_t route[]{0xb0, 0x60}; +inline constexpr uint8_t handshake_req[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x03, 0x04, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; +inline constexpr uint8_t handshake_resp[] = {0x00, 0x00, 0x03, 0x04}; +inline constexpr uint8_t route[]{0xb3, 0x66, 0xa0, 0x90, 0xc0}; } // namespace v4_3 // Write bolt chunk header (length) @@ -705,67 +710,89 @@ TEST(BoltSession, ErrorWrongMarker) { } TEST(BoltSession, ErrorOK) { - // v1 { + SCOPED_TRACE("v1"); // test ACK_FAILURE and RESET const uint8_t *dataset[] = {ackfailure_req, reset_req}; for (int i = 0; i < 2; ++i) { + SCOPED_TRACE("i: " + std::to_string(i)); // first test with socket write success, then with socket write fail for (int j = 0; j < 2; ++j) { + SCOPED_TRACE("j: " + std::to_string(j)); + const auto write_success = j == 0; INIT_VARS; ExecuteHandshake(input_stream, session, output); - ExecuteInit(input_stream, session, output); + ASSERT_EQ(session.version_.major, 1U); + ExecuteInit(input_stream, session, output); WriteRunRequest(input_stream, kInvalidQuery); session.Execute(); output.clear(); - output_stream.SetWriteSuccess(j == 0); - if (j == 0) { + output_stream.SetWriteSuccess(write_success); + if (write_success) { ExecuteCommand(input_stream, session, dataset[i], 2); } else { ASSERT_THROW(ExecuteCommand(input_stream, session, dataset[i], 2), SessionException); } // assert that all data from the init message was cleaned up - ASSERT_EQ(session.decoder_buffer_.Size(), 0); + EXPECT_EQ(session.decoder_buffer_.Size(), 0); - if (j == 0) { - ASSERT_EQ(session.state_, State::Idle); + if (write_success) { + EXPECT_EQ(session.state_, State::Idle); CheckOutput(output, success_resp, sizeof(success_resp)); } else { - ASSERT_EQ(session.state_, State::Close); - ASSERT_EQ(output.size(), 0); + EXPECT_EQ(session.state_, State::Close); + EXPECT_EQ(output.size(), 0); } } } } - // v4+ { + SCOPED_TRACE("v4"); const uint8_t *dataset[] = {ackfailure_req, v4::reset_req}; for (int i = 0; i < 2; ++i) { - INIT_VARS; + SCOPED_TRACE("i: " + std::to_string(i)); + // first test with socket write success, then with socket write fail + for (int j = 0; j < 2; ++j) { + SCOPED_TRACE("j: " + std::to_string(j)); + const auto write_success = j == 0; + const auto is_reset = i == 1; + INIT_VARS; - ExecuteHandshake(input_stream, session, output, v4::handshake_req, v4::handshake_resp); - ExecuteInit(input_stream, session, output, true); + ExecuteHandshake(input_stream, session, output, v4::handshake_req, v4::handshake_resp); + ASSERT_EQ(session.version_.major, 4U); + ExecuteInit(input_stream, session, output, true); - WriteRunRequest(input_stream, kInvalidQuery, true); - session.Execute(); + WriteRunRequest(input_stream, kInvalidQuery, true); + session.Execute(); - output.clear(); + output.clear(); + output_stream.SetWriteSuccess(write_success); - ExecuteCommand(input_stream, session, dataset[i], 2); + // ACK_FAILURE does not exist in v3+, ingored message is sent + if (write_success) { + ExecuteCommand(input_stream, session, dataset[i], 2); + } else { + ASSERT_THROW(ExecuteCommand(input_stream, session, dataset[i], 2), SessionException); + } - // ACK_FAILURE does not exist in v4+ - if (i == 0) { - ASSERT_EQ(session.state_, State::Error); - } else { - ASSERT_EQ(session.state_, State::Idle); - CheckOutput(output, success_resp, sizeof(success_resp)); + if (write_success) { + if (is_reset) { + EXPECT_EQ(session.state_, State::Idle); + CheckOutput(output, success_resp, sizeof(success_resp)); + } else { + ASSERT_EQ(session.state_, State::Error); + CheckOutput(output, ignored_resp, sizeof(ignored_resp)); + } + } else { + EXPECT_EQ(session.state_, State::Close); + } } } } @@ -950,18 +977,100 @@ TEST(BoltSession, Noop) { TEST(BoltSession, Route) { // Memgraph does not support route message, but it handles it { + SCOPED_TRACE("v1"); INIT_VARS; ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); ASSERT_THROW(ExecuteCommand(input_stream, session, v4_3::route, sizeof(v4_3::route)), SessionException); + EXPECT_EQ(session.state_, State::Close); } { + SCOPED_TRACE("v4"); INIT_VARS; - ExecuteHandshake(input_stream, session, output, v4::handshake_req, v4::handshake_resp); + ExecuteHandshake(input_stream, session, output, v4_3::handshake_req, v4_3::handshake_resp); ExecuteInit(input_stream, session, output, true); - ASSERT_THROW(ExecuteCommand(input_stream, session, v4_3::route, sizeof(v4_3::route)), SessionException); + ASSERT_NO_THROW(ExecuteCommand(input_stream, session, v4_3::route, sizeof(v4_3::route))); + static constexpr uint8_t expected_resp[] = { + 0x00 /*two bytes of chunk header, chunk contains 64 bytes of data*/, + 0x40, + 0xb1 /*TinyStruct1*/, + 0x7f /*Failure*/, + 0xa2 /*TinyMap with 2 items*/, + 0x84 /*TinyString with 4 chars*/, + 'c', + 'o', + 'd', + 'e', + 0x82 /*TinyString with 2 chars*/, + '6', + '6', + 0x87 /*TinyString with 7 chars*/, + 'm', + 'e', + 's', + 's', + 'a', + 'g', + 'e', + 0xd0 /*String*/, + 0x2b /*With 43 chars*/, + 'R', + 'o', + 'u', + 't', + 'e', + ' ', + 'm', + 'e', + 's', + 's', + 'a', + 'g', + 'e', + ' ', + 'i', + 's', + ' ', + 'n', + 'o', + 't', + ' ', + 's', + 'u', + 'p', + 'p', + 'o', + 'r', + 't', + 'e', + 'd', + ' ', + 'i', + 'n', + ' ', + 'M', + 'e', + 'm', + 'g', + 'r', + 'a', + 'p', + 'h', + '!', + 0x00 /*Terminating zeros*/, + 0x00, + }; + EXPECT_EQ(input_stream.size(), 0U); + CheckOutput(output, expected_resp, sizeof(expected_resp)); + EXPECT_EQ(session.state_, State::Error); + + SCOPED_TRACE("Try to reset connection after ROUTE failed"); + ASSERT_NO_THROW(ExecuteCommand(input_stream, session, v4::reset_req, sizeof(v4::reset_req))); + EXPECT_EQ(input_stream.size(), 0U); + CheckOutput(output, success_resp, sizeof(success_resp)); + EXPECT_EQ(session.state_, State::Idle); } } @@ -992,3 +1101,24 @@ TEST(BoltSession, Rollback) { ASSERT_THROW(ExecuteCommand(input_stream, session, v4::rollback, sizeof(v4::rollback)), SessionException); } } + +TEST(BoltSession, ResetInIdle) { + { + SCOPED_TRACE("v1"); + INIT_VARS; + + ExecuteHandshake(input_stream, session, output); + ExecuteInit(input_stream, session, output); + ASSERT_NO_THROW(ExecuteCommand(input_stream, session, reset_req, sizeof(reset_req))); + EXPECT_EQ(session.state_, State::Idle); + } + { + SCOPED_TRACE("v4"); + INIT_VARS; + + ExecuteHandshake(input_stream, session, output, v4_3::handshake_req, v4_3::handshake_resp); + ExecuteInit(input_stream, session, output, true); + ASSERT_NO_THROW(ExecuteCommand(input_stream, session, v4::reset_req, sizeof(v4::reset_req))); + EXPECT_EQ(session.state_, State::Idle); + } +}