From f05fcd91c3ee3cae2fb196b442fd7dda9c6fbfe9 Mon Sep 17 00:00:00 2001 From: Matej Ferencevic <matej.ferencevic@memgraph.io> Date: Sat, 15 Apr 2017 15:14:12 +0200 Subject: [PATCH] Refactored bolt session to use new decoder. Summary: Bolt buffer is now a template. Communication worker now has a new interface. Fixed network tests to use new interface. Fixed bolt tests to use new interface. Added more functions to bolt decoder. Reviewers: dgleich, buda Reviewed By: buda Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D256 --- CMakeLists.txt | 2 - src/communication/bolt/v1/bolt_exception.hpp | 11 - src/communication/bolt/v1/codes.hpp | 23 +- src/communication/bolt/v1/constants.hpp | 5 + src/communication/bolt/v1/decoder/buffer.hpp | 14 +- .../v1/decoder/chunked_decoder_buffer.hpp | 42 +- src/communication/bolt/v1/decoder/decoder.hpp | 56 +- .../bolt/v1/encoder/base_encoder.hpp | 1 - .../v1/encoder/chunked_encoder_buffer.hpp | 79 ++- src/communication/bolt/v1/encoder/encoder.hpp | 65 +- .../bolt/v1/encoder/result_stream.hpp | 6 +- src/communication/bolt/v1/messaging/codes.hpp | 38 -- src/communication/bolt/v1/packing/codes.hpp | 55 -- src/communication/bolt/v1/packing/types.hpp | 39 -- src/communication/bolt/v1/session.hpp | 118 +++- src/communication/bolt/v1/state.hpp | 39 +- src/communication/bolt/v1/states/error.hpp | 78 ++- src/communication/bolt/v1/states/executor.hpp | 240 ++++--- .../bolt/v1/states/handshake.hpp | 27 +- src/communication/bolt/v1/states/init.hpp | 63 +- .../bolt/v1/transport/bolt_decoder.cpp | 108 --- .../bolt/v1/transport/bolt_decoder.hpp | 41 -- .../bolt/v1/transport/buffer.cpp | 10 - .../bolt/v1/transport/buffer.hpp | 26 - .../bolt/v1/transport/chunked_decoder.hpp | 63 -- .../bolt/v1/transport/stream_error.hpp | 11 - .../v1/transport/streamed_bolt_decoder.hpp | 307 --------- src/communication/worker.hpp | 13 +- src/io/network/stream_reader.hpp | 13 +- src/utils/exceptions/not_yet_implemented.hpp | 2 +- src/utils/exceptions/stacktrace_exception.hpp | 5 +- tests/concurrent/network_common.hpp | 37 +- tests/concurrent/network_read_hang.cpp | 14 +- tests/unit/bolt_buffer.cpp | 2 +- tests/unit/bolt_chunked_decoder.cpp | 59 -- tests/unit/bolt_chunked_decoder_buffer.cpp | 21 +- tests/unit/bolt_common.hpp | 16 +- tests/unit/bolt_result_stream.cpp | 3 +- tests/unit/bolt_session.cpp | 620 ++++++++++++++++-- 39 files changed, 1260 insertions(+), 1112 deletions(-) delete mode 100644 src/communication/bolt/v1/bolt_exception.hpp delete mode 100644 src/communication/bolt/v1/messaging/codes.hpp delete mode 100644 src/communication/bolt/v1/packing/codes.hpp delete mode 100644 src/communication/bolt/v1/packing/types.hpp delete mode 100644 src/communication/bolt/v1/transport/bolt_decoder.cpp delete mode 100644 src/communication/bolt/v1/transport/bolt_decoder.hpp delete mode 100644 src/communication/bolt/v1/transport/buffer.cpp delete mode 100644 src/communication/bolt/v1/transport/buffer.hpp delete mode 100644 src/communication/bolt/v1/transport/chunked_decoder.hpp delete mode 100644 src/communication/bolt/v1/transport/stream_error.hpp delete mode 100644 src/communication/bolt/v1/transport/streamed_bolt_decoder.hpp delete mode 100644 tests/unit/bolt_chunked_decoder.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 553f1e9f8..e8361d968 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -328,8 +328,6 @@ set(memgraph_src_files ${src_dir}/utils/string/join.cpp ${src_dir}/utils/string/file.cpp ${src_dir}/utils/numerics/saturate.cpp - ${src_dir}/communication/bolt/v1/transport/bolt_decoder.cpp - ${src_dir}/communication/bolt/v1/transport/buffer.cpp ${src_dir}/io/network/addrinfo.cpp ${src_dir}/io/network/network_endpoint.cpp ${src_dir}/io/network/socket.cpp diff --git a/src/communication/bolt/v1/bolt_exception.hpp b/src/communication/bolt/v1/bolt_exception.hpp deleted file mode 100644 index 883c4986d..000000000 --- a/src/communication/bolt/v1/bolt_exception.hpp +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#include "utils/exceptions/basic_exception.hpp" - -namespace communication::bolt { - -class BoltException : public BasicException { - public: - using BasicException::BasicException; -}; -} diff --git a/src/communication/bolt/v1/codes.hpp b/src/communication/bolt/v1/codes.hpp index 6fdae975c..e36117954 100644 --- a/src/communication/bolt/v1/codes.hpp +++ b/src/communication/bolt/v1/codes.hpp @@ -1,6 +1,7 @@ #pragma once #include <cstdint> +#include "utils/underlying_cast.hpp" namespace communication::bolt { @@ -30,6 +31,15 @@ enum class Marker : uint8_t { TinyMap = 0xA0, TinyStruct = 0xB0, + // TinyStructX represents the value of TinyStruct + X + // This is defined to make decoding easier. To check if a marker is equal + // to TinyStruct + 1 you should use something like: + // underyling_cast(marker) == underyling_cast(Marker::TinyStruct) + 1 + // This way you can just use: + // marker == Marker::TinyStruct1 + TinyStruct1 = 0xB1, + TinyStruct2 = 0xB2, + Null = 0xC0, Float64 = 0xC1, @@ -58,9 +68,12 @@ enum class Marker : uint8_t { }; static constexpr uint8_t MarkerString = 0, MarkerList = 1, MarkerMap = 2; -static constexpr Marker MarkerTiny[3] = {Marker::TinyString, Marker::TinyList, Marker::TinyMap}; -static constexpr Marker Marker8[3] = {Marker::String8, Marker::List8, Marker::Map8}; -static constexpr Marker Marker16[3] = {Marker::String16, Marker::List16, Marker::Map16}; -static constexpr Marker Marker32[3] = {Marker::String32, Marker::List32, Marker::Map32}; - +static constexpr Marker MarkerTiny[3] = {Marker::TinyString, Marker::TinyList, + Marker::TinyMap}; +static constexpr Marker Marker8[3] = {Marker::String8, Marker::List8, + Marker::Map8}; +static constexpr Marker Marker16[3] = {Marker::String16, Marker::List16, + Marker::Map16}; +static constexpr Marker Marker32[3] = {Marker::String32, Marker::List32, + Marker::Map32}; } diff --git a/src/communication/bolt/v1/constants.hpp b/src/communication/bolt/v1/constants.hpp index 866de247a..3d8cb3369 100644 --- a/src/communication/bolt/v1/constants.hpp +++ b/src/communication/bolt/v1/constants.hpp @@ -10,4 +10,9 @@ static constexpr size_t MAX_CHUNK_SIZE = 65535; static constexpr size_t CHUNK_END_MARKER_SIZE = 2; static constexpr size_t WHOLE_CHUNK_SIZE = CHUNK_HEADER_SIZE + MAX_CHUNK_SIZE + CHUNK_END_MARKER_SIZE; + +/** + * Handshake size defined in the Bolt protocol. + */ +static constexpr size_t HANDSHAKE_SIZE = 20; } diff --git a/src/communication/bolt/v1/decoder/buffer.hpp b/src/communication/bolt/v1/decoder/buffer.hpp index 90b70c4f4..72b8e6d18 100644 --- a/src/communication/bolt/v1/decoder/buffer.hpp +++ b/src/communication/bolt/v1/decoder/buffer.hpp @@ -21,7 +21,11 @@ namespace communication::bolt { * Allocating, writing and written stores data in the buffer. The stored * data can then be read using the pointer returned with the data function. * The current implementation stores data in a single fixed length buffer. + * + * @tparam Size the size of the internal byte array, defaults to the maximum + * size of a chunk in the Bolt protocol */ +template <size_t Size = WHOLE_CHUNK_SIZE> class Buffer : public Loggable { private: using StreamBufferT = io::network::StreamBuffer; @@ -36,7 +40,7 @@ class Buffer : public Loggable { * available memory. */ StreamBufferT Allocate() { - return StreamBufferT{&data_[size_], WHOLE_CHUNK_SIZE - size_}; + return StreamBufferT{&data_[size_], Size - size_}; } /** @@ -51,7 +55,7 @@ class Buffer : public Loggable { */ void Written(size_t len) { size_ += len; - debug_assert(size_ <= WHOLE_CHUNK_SIZE, "Written more than storage has space!"); + debug_assert(size_ <= Size, "Written more than storage has space!"); } /** @@ -70,9 +74,7 @@ class Buffer : public Loggable { /** * This method clears the buffer. */ - void Clear() { - size_ = 0; - } + void Clear() { size_ = 0; } /** * This function returns a pointer to the internal buffer. It is used for @@ -86,7 +88,7 @@ class Buffer : public Loggable { size_t size() { return size_; } private: - uint8_t data_[WHOLE_CHUNK_SIZE]; + uint8_t data_[Size]; size_t size_{0}; }; } diff --git a/src/communication/bolt/v1/decoder/chunked_decoder_buffer.hpp b/src/communication/bolt/v1/decoder/chunked_decoder_buffer.hpp index 460e08291..27cc66812 100644 --- a/src/communication/bolt/v1/decoder/chunked_decoder_buffer.hpp +++ b/src/communication/bolt/v1/decoder/chunked_decoder_buffer.hpp @@ -12,6 +12,22 @@ namespace communication::bolt { +/** + * This class is used as the return value of the GetChunk function of the + * ChunkedDecoderBuffer. It represents the 3 situations that can happen when + * reading a chunk. + */ +enum class ChunkState : uint8_t { + // The chunk isn't complete, we have to read more data + Partial, + + // The chunk is invalid, it's tail isn't 0x00 0x00 + Invalid, + + // The chunk is whole and correct and has been loaded into the buffer + Whole +}; + /** * @brief ChunkedDecoderBuffer * @@ -27,7 +43,8 @@ class ChunkedDecoderBuffer : public Loggable { using StreamBufferT = io::network::StreamBuffer; public: - ChunkedDecoderBuffer(Buffer &buffer) : Loggable("ChunkedDecoderBuffer"), buffer_(buffer) {} + ChunkedDecoderBuffer(Buffer<> &buffer) + : Loggable("ChunkedDecoderBuffer"), buffer_(buffer) {} /** * Reads data from the internal buffer. @@ -55,27 +72,27 @@ class ChunkedDecoderBuffer : public Loggable { * @returns true if a chunk was successfully copied into the internal * buffer, false otherwise */ - bool GetChunk() { + ChunkState GetChunk() { uint8_t *data = buffer_.data(); size_t size = buffer_.size(); if (size < 2) { logger.trace("Size < 2"); - return false; + return ChunkState::Partial; } size_t chunk_size = data[0]; chunk_size <<= 8; chunk_size += data[1]; if (size < chunk_size + 4) { - logger.trace("Chunk size is {} but only have {} data bytes.", chunk_size, size); - return false; + logger.trace("Chunk size is {} but only have {} data bytes.", chunk_size, + size); + return ChunkState::Partial; } if (data[chunk_size + 2] != 0 || data[chunk_size + 3] != 0) { logger.trace("Invalid chunk!"); buffer_.Shift(chunk_size + 4); - // TODO: raise an exception! - return false; + return ChunkState::Invalid; } pos_ = 0; @@ -83,11 +100,18 @@ class ChunkedDecoderBuffer : public Loggable { memcpy(data_, data + 2, size - 4); buffer_.Shift(chunk_size + 4); - return true; + return ChunkState::Whole; } + /** + * Gets the size of currently available data in the loaded chunk. + * + * @returns size of available data + */ + size_t Size() { return size_ - pos_; } + private: - Buffer &buffer_; + Buffer<> &buffer_; uint8_t data_[MAX_CHUNK_SIZE]; size_t size_{0}; size_t pos_{0}; diff --git a/src/communication/bolt/v1/decoder/decoder.hpp b/src/communication/bolt/v1/decoder/decoder.hpp index 5bb1bf92e..60d200413 100644 --- a/src/communication/bolt/v1/decoder/decoder.hpp +++ b/src/communication/bolt/v1/decoder/decoder.hpp @@ -46,8 +46,7 @@ template <typename Buffer> class Decoder : public Loggable { public: Decoder(Buffer &buffer) - : Loggable("communication::bolt::Decoder"), - buffer_(buffer) {} + : Loggable("communication::bolt::Decoder"), buffer_(buffer) {} /** * Reads a TypedValue from the available data in the buffer. @@ -136,6 +135,31 @@ class Decoder : public Loggable { return true; } + /** + * Reads a Message header from the available data in the buffer. + * + * @param signature pointer to a Signature where the signature should be + * stored + * @param marker pointer to a Signature where the marker should be stored + * @returns true if data has been written into the data pointers, + * false otherwise + */ + bool ReadMessageHeader(Signature *signature, Marker *marker) { + uint8_t values[2]; + + logger.trace("[ReadMessageHeader] Start"); + + if (!buffer_.Read(values, 2)) { + logger.debug("[ReadMessageHeader] Marker data missing!"); + return false; + } + + *marker = (Marker)values[0]; + *signature = (Signature)values[1]; + logger.trace("[ReadMessageHeader] Success"); + return true; + } + /** * Reads a Vertex from the available data in the buffer. * This function tries to read a Vertex from the available data. @@ -306,7 +330,7 @@ class Decoder : public Loggable { logger.trace("[ReadInt] Found an Int8"); int8_t tmp; if (!buffer_.Read(reinterpret_cast<uint8_t *>(&tmp), sizeof(tmp))) { - logger.debug( "[ReadInt] Int8 missing data!"); + logger.debug("[ReadInt] Int8 missing data!"); return false; } ret = tmp; @@ -314,7 +338,7 @@ class Decoder : public Loggable { logger.trace("[ReadInt] Found an Int16"); int16_t tmp; if (!buffer_.Read(reinterpret_cast<uint8_t *>(&tmp), sizeof(tmp))) { - logger.debug( "[ReadInt] Int16 missing data!"); + logger.debug("[ReadInt] Int16 missing data!"); return false; } ret = bswap(tmp); @@ -322,19 +346,20 @@ class Decoder : public Loggable { logger.trace("[ReadInt] Found an Int32"); int32_t tmp; if (!buffer_.Read(reinterpret_cast<uint8_t *>(&tmp), sizeof(tmp))) { - logger.debug( "[ReadInt] Int32 missing data!"); + logger.debug("[ReadInt] Int32 missing data!"); return false; } ret = bswap(tmp); } else if (marker == Marker::Int64) { logger.trace("[ReadInt] Found an Int64"); if (!buffer_.Read(reinterpret_cast<uint8_t *>(&ret), sizeof(ret))) { - logger.debug( "[ReadInt] Int64 missing data!"); + logger.debug("[ReadInt] Int64 missing data!"); return false; } ret = bswap(ret); } else { - logger.debug("[ReadInt] Received invalid marker ({})!", underlying_cast(marker)); + logger.debug("[ReadInt] Received invalid marker ({})!", + underlying_cast(marker)); return false; } if (success) { @@ -350,7 +375,7 @@ class Decoder : public Loggable { logger.trace("[ReadDouble] Start"); debug_assert(marker == Marker::Float64, "Received invalid marker!"); if (!buffer_.Read(reinterpret_cast<uint8_t *>(&value), sizeof(value))) { - logger.debug( "[ReadDouble] Missing data!"); + logger.debug("[ReadDouble] Missing data!"); return false; } value = bswap(value); @@ -369,7 +394,7 @@ class Decoder : public Loggable { logger.trace("[ReadTypeSize] Found a Type8"); uint8_t tmp; if (!buffer_.Read(reinterpret_cast<uint8_t *>(&tmp), sizeof(tmp))) { - logger.debug( "[ReadTypeSize] Type8 missing data!"); + logger.debug("[ReadTypeSize] Type8 missing data!"); return -1; } return tmp; @@ -377,7 +402,7 @@ class Decoder : public Loggable { logger.trace("[ReadTypeSize] Found a Type16"); uint16_t tmp; if (!buffer_.Read(reinterpret_cast<uint8_t *>(&tmp), sizeof(tmp))) { - logger.debug( "[ReadTypeSize] Type16 missing data!"); + logger.debug("[ReadTypeSize] Type16 missing data!"); return -1; } tmp = bswap(tmp); @@ -386,13 +411,14 @@ class Decoder : public Loggable { logger.trace("[ReadTypeSize] Found a Type32"); uint32_t tmp; if (!buffer_.Read(reinterpret_cast<uint8_t *>(&tmp), sizeof(tmp))) { - logger.debug( "[ReadTypeSize] Type32 missing data!"); + logger.debug("[ReadTypeSize] Type32 missing data!"); return -1; } tmp = bswap(tmp); return tmp; } else { - logger.debug("[ReadTypeSize] Received invalid marker ({})!", underlying_cast(marker)); + logger.debug("[ReadTypeSize] Received invalid marker ({})!", + underlying_cast(marker)); return -1; } } @@ -409,7 +435,8 @@ class Decoder : public Loggable { logger.debug("[ReadString] Missing data!"); return false; } - *data = query::TypedValue(std::string(reinterpret_cast<char *>(ret.get()), size)); + *data = query::TypedValue( + std::string(reinterpret_cast<char *>(ret.get()), size)); logger.trace("[ReadString] Success"); return true; } @@ -462,7 +489,8 @@ class Decoder : public Loggable { ret.insert(std::make_pair(str, tv)); } if (ret.size() != size) { - logger.debug("[ReadMap] The client sent multiple objects with same indexes!"); + logger.debug( + "[ReadMap] The client sent multiple objects with same indexes!"); return false; } diff --git a/src/communication/bolt/v1/encoder/base_encoder.hpp b/src/communication/bolt/v1/encoder/base_encoder.hpp index 824e6ff53..312cbcc04 100644 --- a/src/communication/bolt/v1/encoder/base_encoder.hpp +++ b/src/communication/bolt/v1/encoder/base_encoder.hpp @@ -6,7 +6,6 @@ #include "logging/logger.hpp" #include "query/typed_value.hpp" #include "utils/bswap.hpp" -#include "utils/underlying_cast.hpp" #include <string> diff --git a/src/communication/bolt/v1/encoder/chunked_encoder_buffer.hpp b/src/communication/bolt/v1/encoder/chunked_encoder_buffer.hpp index c12554171..1a0be2493 100644 --- a/src/communication/bolt/v1/encoder/chunked_encoder_buffer.hpp +++ b/src/communication/bolt/v1/encoder/chunked_encoder_buffer.hpp @@ -5,7 +5,6 @@ #include <memory> #include <vector> -#include "communication/bolt/v1/bolt_exception.hpp" #include "communication/bolt/v1/constants.hpp" #include "logging/loggable.hpp" #include "utils/bswap.hpp" @@ -39,7 +38,8 @@ namespace communication::bolt { template <class Socket> class ChunkedEncoderBuffer : public Loggable { public: - ChunkedEncoderBuffer(Socket &socket) : Loggable("Chunked Encoder Buffer"), socket_(socket) {} + ChunkedEncoderBuffer(Socket &socket) + : Loggable("Chunked Encoder Buffer"), socket_(socket) {} /** * Writes n values into the buffer. If n is bigger than whole chunk size @@ -51,7 +51,7 @@ class ChunkedEncoderBuffer : public Loggable { void Write(const uint8_t *values, size_t n) { while (n > 0) { // Define number of bytes which will be copied into chunk because - // chunk is a fixed lenght array. + // chunk is a fixed length array. auto size = n < MAX_CHUNK_SIZE + CHUNK_HEADER_SIZE - pos_ ? n : MAX_CHUNK_SIZE + CHUNK_HEADER_SIZE - pos_; @@ -90,13 +90,16 @@ class ChunkedEncoderBuffer : public Loggable { debug_assert(pos_ <= WHOLE_CHUNK_SIZE, "Internal variable pos_ is bigger than the whole chunk size."); - // 3. Copy whole chunk into the buffer. + // 3. Remember first chunk size. + if (first_chunk_size_ == -1) first_chunk_size_ = pos_; + + // 4. Copy whole chunk into the buffer. size_ += pos_; buffer_.reserve(size_); std::copy(chunk_.begin(), chunk_.begin() + pos_, std::back_inserter(buffer_)); - // 4. Cleanup. + // 5. Cleanup. // * pos_ has to be reset to the size of chunk header (reserved // space for the chunk size) pos_ = CHUNK_HEADER_SIZE; @@ -104,24 +107,72 @@ class ChunkedEncoderBuffer : public Loggable { /** * Sends the whole buffer(message) to the client. + * @returns true if the data was successfully sent to the client + * false otherwise */ - void Flush() { + bool Flush() { // Call chunk if is hasn't been called. if (pos_ > CHUNK_HEADER_SIZE) Chunk(); // Early return if buffer is empty because there is nothing to write. - if (size_ == 0) return; + if (size_ == 0) return true; // Flush the whole buffer. - bool written = socket_.Write(buffer_.data(), size_); - if (!written) throw BoltException("Socket write failed!"); + if (!socket_.Write(buffer_.data() + offset_, size_ - offset_)) return false; logger.trace("Flushed {} bytes.", size_); // Cleanup. + Clear(); + return true; + } + + /** + * Sends only the first message chunk in the buffer to the client. + * @returns true if the data was successfully sent to the client + * false otherwise + */ + bool FlushFirstChunk() { + // Call chunk if is hasn't been called. + if (pos_ > CHUNK_HEADER_SIZE) Chunk(); + + // Early return if buffer is empty because there is nothing to write. + if (size_ == 0) return false; + + // Early return if there is no first chunk + if (first_chunk_size_ == -1) return false; + + // Flush the first chunk + if (!socket_.Write(buffer_.data(), first_chunk_size_)) return false; + logger.trace("Flushed {} bytes.", first_chunk_size_); + + // Cleanup. + // Here we use offset as a method of deleting from the front of the + // data vector. Because the first chunk will always be relatively + // small comparing to the rest of the data it is more optimal just to + // skip the first part of the data than to shift everything in the + // vector buffer. + offset_ = first_chunk_size_; + first_chunk_size_ = -1; + return true; + } + + /** + * Clears the internal buffers. + */ + void Clear() { buffer_.clear(); size_ = 0; + first_chunk_size_ = -1; + offset_ = 0; } + /** + * Returns a boolean indicating whether there is data in the buffer. + * @returns true if there is data in the buffer, + * false otherwise + */ + bool HasData() { return buffer_.size() > 0 || size_ > 0; } + private: /** * A client socket. @@ -143,6 +194,16 @@ class ChunkedEncoderBuffer : public Loggable { */ size_t size_{0}; + /** + * Size of first chunk in the buffer. + */ + int32_t first_chunk_size_{-1}; + + /** + * Offset from the start of the buffer. + */ + size_t offset_{0}; + /** * Current position in chunk array. */ diff --git a/src/communication/bolt/v1/encoder/encoder.hpp b/src/communication/bolt/v1/encoder/encoder.hpp index 44b68dc9d..80f1d8d75 100644 --- a/src/communication/bolt/v1/encoder/encoder.hpp +++ b/src/communication/bolt/v1/encoder/encoder.hpp @@ -1,5 +1,6 @@ #pragma once +#include "communication/bolt/v1/codes.hpp" #include "communication/bolt/v1/encoder/base_encoder.hpp" namespace communication::bolt { @@ -21,8 +22,7 @@ class Encoder : private BaseEncoder<Buffer> { using BaseEncoder<Buffer>::buffer_; public: - Encoder(Buffer &buffer) - : BaseEncoder<Buffer>(buffer) { + Encoder(Buffer &buffer) : BaseEncoder<Buffer>(buffer) { logger = logging::log->logger("communication::bolt::Encoder"); } @@ -40,8 +40,8 @@ class Encoder : private BaseEncoder<Buffer> { * @param values the fields list object that should be sent */ void MessageRecord(const std::vector<query::TypedValue> &values) { - // 0xB1 = struct 1; 0x71 = record signature - WriteRAW("\xB1\x71", 2); + WriteRAW(underlying_cast(Marker::TinyStruct1)); + WriteRAW(underlying_cast(Signature::Record)); WriteList(values); buffer_.Chunk(); } @@ -56,26 +56,34 @@ class Encoder : private BaseEncoder<Buffer> { * * @param metadata the metadata map object that should be sent * @param flush should method flush the socket + * @returns true if the data was successfully sent to the client + * when flushing, false otherwise */ - void MessageSuccess(const std::map<std::string, query::TypedValue> &metadata, + bool MessageSuccess(const std::map<std::string, query::TypedValue> &metadata, bool flush = true) { - // 0xB1 = struct 1; 0x70 = success signature - WriteRAW("\xB1\x70", 2); + WriteRAW(underlying_cast(Marker::TinyStruct1)); + WriteRAW(underlying_cast(Signature::Success)); WriteMap(metadata); - if (flush) - buffer_.Flush(); - else + if (flush) { + return buffer_.Flush(); + } else { buffer_.Chunk(); + // Chunk always succeeds, so return true + return true; + } } /** * Sends a Success message. * * This function sends a success message without additional metadata. + * + * @returns true if the data was successfully sent to the client, + * false otherwise */ - void MessageSuccess() { + bool MessageSuccess() { std::map<std::string, query::TypedValue> metadata; - MessageSuccess(metadata); + return MessageSuccess(metadata); } /** @@ -87,12 +95,15 @@ class Encoder : private BaseEncoder<Buffer> { * } * * @param metadata the metadata map object that should be sent + * @returns true if the data was successfully sent to the client, + * false otherwise */ - void MessageFailure(const std::map<std::string, query::TypedValue> &metadata) { - // 0xB1 = struct 1; 0x7F = failure signature - WriteRAW("\xB1\x7F", 2); + bool MessageFailure( + const std::map<std::string, query::TypedValue> &metadata) { + WriteRAW(underlying_cast(Marker::TinyStruct1)); + WriteRAW(underlying_cast(Signature::Failure)); WriteMap(metadata); - buffer_.Flush(); + return buffer_.Flush(); } /** @@ -104,23 +115,29 @@ class Encoder : private BaseEncoder<Buffer> { * } * * @param metadata the metadata map object that should be sent + * @returns true if the data was successfully sent to the client, + * false otherwise */ - void MessageIgnored(const std::map<std::string, query::TypedValue> &metadata) { - // 0xB1 = struct 1; 0x7E = ignored signature - WriteRAW("\xB1\x7E", 2); + bool MessageIgnored( + const std::map<std::string, query::TypedValue> &metadata) { + WriteRAW(underlying_cast(Marker::TinyStruct1)); + WriteRAW(underlying_cast(Signature::Ignored)); WriteMap(metadata); - buffer_.Flush(); + return buffer_.Flush(); } /** * Sends an Ignored message. * * This function sends an ignored message without additional metadata. + * + * @returns true if the data was successfully sent to the client, + * false otherwise */ - void MessageIgnored() { - // 0xB0 = struct 0; 0x7E = ignored signature - WriteRAW("\xB0\x7E", 2); - buffer_.Flush(); + bool MessageIgnored() { + WriteRAW(underlying_cast(Marker::TinyStruct)); + WriteRAW(underlying_cast(Signature::Ignored)); + return buffer_.Flush(); } }; } diff --git a/src/communication/bolt/v1/encoder/result_stream.hpp b/src/communication/bolt/v1/encoder/result_stream.hpp index d48e7c619..0650e816f 100644 --- a/src/communication/bolt/v1/encoder/result_stream.hpp +++ b/src/communication/bolt/v1/encoder/result_stream.hpp @@ -31,8 +31,10 @@ class ResultStream { std::map<std::string, query::TypedValue> data; for (auto &i : fields) vec.push_back(query::TypedValue(i)); data.insert(std::make_pair(std::string("fields"), query::TypedValue(vec))); - // this call will automaticaly send the data to the client - encoder_.MessageSuccess(data); + // this message shouldn't send directly to the client because if an error + // happened the client will receive two messages (success and failure) + // instead of only one + encoder_.MessageSuccess(data, false); } /** diff --git a/src/communication/bolt/v1/messaging/codes.hpp b/src/communication/bolt/v1/messaging/codes.hpp deleted file mode 100644 index 04bc0bff4..000000000 --- a/src/communication/bolt/v1/messaging/codes.hpp +++ /dev/null @@ -1,38 +0,0 @@ -#pragma once - -#include "utils/types/byte.hpp" -#include "utils/underlying_cast.hpp" - -namespace communication::bolt { - -enum class MessageCode : byte { - Init = 0x01, - AckFailure = 0x0E, - Reset = 0x0F, - - Run = 0x10, - DiscardAll = 0x2F, - PullAll = 0x3F, - - Record = 0x71, - Success = 0x70, - Ignored = 0x7E, - Failure = 0x7F -}; - -inline bool operator==(byte value, MessageCode code) { - return value == underlying_cast(code); -} - -inline bool operator==(MessageCode code, byte value) { - return operator==(value, code); -} - -inline bool operator!=(byte value, MessageCode code) { - return !operator==(value, code); -} - -inline bool operator!=(MessageCode code, byte value) { - return operator!=(value, code); -} -} diff --git a/src/communication/bolt/v1/packing/codes.hpp b/src/communication/bolt/v1/packing/codes.hpp deleted file mode 100644 index b21aa3b98..000000000 --- a/src/communication/bolt/v1/packing/codes.hpp +++ /dev/null @@ -1,55 +0,0 @@ -#pragma once - -#include <cstdint> - -namespace communication::bolt::pack { - -enum Code : uint8_t { - TinyString = 0x80, - TinyList = 0x90, - TinyMap = 0xA0, - - TinyStruct = 0xB0, - StructOne = 0xB1, - StructTwo = 0xB2, - - Null = 0xC0, - - Float64 = 0xC1, - - False = 0xC2, - True = 0xC3, - - Int8 = 0xC8, - Int16 = 0xC9, - Int32 = 0xCA, - Int64 = 0xCB, - - Bytes8 = 0xCC, - Bytes16 = 0xCD, - Bytes32 = 0xCE, - - String8 = 0xD0, - String16 = 0xD1, - String32 = 0xD2, - - List8 = 0xD4, - List16 = 0xD5, - List32 = 0xD6, - - Map8 = 0xD8, - Map16 = 0xD9, - Map32 = 0xDA, - MapStream = 0xDB, - - Node = 0x4E, - Relationship = 0x52, - Path = 0x50, - - Struct8 = 0xDC, - Struct16 = 0xDD, - EndOfStream = 0xDF, -}; - -enum Rule : uint8_t { MaxInitStructSize = 0x02 }; -} diff --git a/src/communication/bolt/v1/packing/types.hpp b/src/communication/bolt/v1/packing/types.hpp deleted file mode 100644 index 9838af1c0..000000000 --- a/src/communication/bolt/v1/packing/types.hpp +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -namespace communication::bolt { - -enum class PackType { - /** denotes absence of a value */ - Null, - - /** denotes a type with two possible values (t/f) */ - Boolean, - - /** 64-bit signed integral number */ - Integer, - - /** 64-bit floating point number */ - Float, - - /** binary data */ - Bytes, - - /** unicode string */ - String, - - /** collection of values */ - List, - - /** collection of zero or more key/value pairs */ - Map, - - /** zero or more packstream values */ - Struct, - - /** denotes stream value end */ - EndOfStream, - - /** reserved for future use */ - Reserved -}; -} diff --git a/src/communication/bolt/v1/session.hpp b/src/communication/bolt/v1/session.hpp index d8461ee34..9b545b6bc 100644 --- a/src/communication/bolt/v1/session.hpp +++ b/src/communication/bolt/v1/session.hpp @@ -6,15 +6,19 @@ #include "dbms/dbms.hpp" #include "query/engine.hpp" +#include "communication/bolt/v1/constants.hpp" #include "communication/bolt/v1/state.hpp" #include "communication/bolt/v1/states/error.hpp" #include "communication/bolt/v1/states/executor.hpp" #include "communication/bolt/v1/states/handshake.hpp" #include "communication/bolt/v1/states/init.hpp" +#include "communication/bolt/v1/decoder/chunked_decoder_buffer.hpp" +#include "communication/bolt/v1/decoder/decoder.hpp" #include "communication/bolt/v1/encoder/encoder.hpp" #include "communication/bolt/v1/encoder/result_stream.hpp" -#include "communication/bolt/v1/transport/bolt_decoder.hpp" + +#include "io/network/stream_buffer.hpp" #include "logging/loggable.hpp" @@ -29,10 +33,11 @@ namespace communication::bolt { */ template <typename Socket> class Session : public Loggable { - public: - using Decoder = BoltDecoder; + private: using OutputStream = ResultStream<Encoder<ChunkedEncoderBuffer<Socket>>>; + using StreamBuffer = io::network::StreamBuffer; + public: Session(Socket &&socket, Dbms &dbms, QueryEngine<OutputStream> &query_engine) : Loggable("communication::bolt::Session"), socket_(std::move(socket)), @@ -40,16 +45,17 @@ class Session : public Loggable { query_engine_(query_engine), encoder_buffer_(socket_), encoder_(encoder_buffer_), - output_stream_(encoder_) { + output_stream_(encoder_), + decoder_buffer_(buffer_), + decoder_(decoder_buffer_), + state_(State::Handshake) { event_.data.ptr = this; - // start with a handshake state - state_ = HANDSHAKE; } /** * @return is the session in a valid state */ - bool Alive() const { return state_ != NULLSTATE; } + bool Alive() const { return state_ != State::Close; } /** * @return the socket id @@ -57,48 +63,83 @@ class Session : public Loggable { int Id() const { return socket_.id(); } /** - * Reads the data from a client and goes through the bolt states in - * order to execute command from the client. - * - * @param data pointer on bytes received from a client - * @param len length of data received from a client + * Executes the session after data has been read into the buffer. + * Goes through the bolt states in order to execute commands from the client. */ - void Execute(const uint8_t *data, size_t len) { - // mark the end of the message - auto end = data + len; - - while (true) { - auto size = end - data; - + void Execute() { + // while there is data in the buffers + while (buffer_.size() > 0 || decoder_buffer_.Size() > 0) { if (LIKELY(connected_)) { - logger.debug("Decoding chunk of size {}", size); - if (!decoder_.decode(data, size)) return; + logger.debug("Decoding chunk of size {}", buffer_.size()); + auto chunk_state = decoder_buffer_.GetChunk(); + if (chunk_state == ChunkState::Partial) { + logger.trace("Chunk isn't complete!"); + return; + } else if (chunk_state == ChunkState::Invalid) { + logger.trace("Chunk is invalid!"); + ClientFailureInvalidData(); + return; + } + // if chunk_state == ChunkState::Whole then we continue with + // execution of the select below + } else if (buffer_.size() < HANDSHAKE_SIZE) { + logger.debug("Received partial handshake of size {}", buffer_.size()); + return; + } else if (buffer_.size() > HANDSHAKE_SIZE) { + logger.debug("Received too large handshake of size {}", buffer_.size()); + ClientFailureInvalidData(); + return; } else { - logger.debug("Decoding handshake of size {}", size); - decoder_.handshake(data, size); + logger.debug("Decoding handshake of size {}", buffer_.size()); } switch (state_) { - case HANDSHAKE: + case State::Handshake: state_ = StateHandshakeRun<Session<Socket>>(*this); break; - case INIT: + case State::Init: state_ = StateInitRun<Session<Socket>>(*this); break; - case EXECUTOR: + case State::Executor: state_ = StateExecutorRun<Session<Socket>>(*this); break; - case ERROR: + case State::Error: state_ = StateErrorRun<Session<Socket>>(*this); break; - case NULLSTATE: + case State::Close: + // This state is handled below break; } - decoder_.reset(); + // State::Close is handled here because we always want to check for + // it after the above select. If any of the states above return a + // State::Close then the connection should be terminated immediately. + if (state_ == State::Close) { + ClientFailureInvalidData(); + return; + } + + logger.trace("Buffer size: {}", buffer_.size()); + logger.trace("Decoder buffer size: {}", decoder_buffer_.Size()); } } + /** + * Allocates data from the internal buffer. + * Used in the underlying network stack to asynchronously read data + * from the client. + * @returns a StreamBuffer to the allocated internal data buffer + */ + StreamBuffer Allocate() { return buffer_.Allocate(); } + + /** + * Notifies the internal buffer of written data. + * Used in the underlying network stack to notify the internal buffer + * how many bytes of data have been written. + * @param len how many data was written to the buffer + */ + void Written(size_t len) { buffer_.Written(len); } + /** * Closes the session (client socket). */ @@ -112,12 +153,29 @@ class Session : public Loggable { Socket socket_; Dbms &dbms_; QueryEngine<OutputStream> &query_engine_; + ChunkedEncoderBuffer<Socket> encoder_buffer_; Encoder<ChunkedEncoderBuffer<Socket>> encoder_; OutputStream output_stream_; - Decoder decoder_; + + Buffer<> buffer_; + ChunkedDecoderBuffer decoder_buffer_; + Decoder<ChunkedDecoderBuffer> decoder_; + io::network::Epoll::Event event_; bool connected_{false}; State state_; + + private: + void ClientFailureInvalidData() { + // set the state to Close + state_ = State::Close; + // don't care about the return status because this is always + // called when we are about to close the connection to the client + encoder_.MessageFailure({{"code", "Memgraph.InvalidData"}, + {"message", "The client has sent invalid data!"}}); + // close the connection + Close(); + } }; } diff --git a/src/communication/bolt/v1/state.hpp b/src/communication/bolt/v1/state.hpp index aed4bd281..8dcb20c27 100644 --- a/src/communication/bolt/v1/state.hpp +++ b/src/communication/bolt/v1/state.hpp @@ -3,15 +3,36 @@ namespace communication::bolt { /** - * TODO (mferencevic): change to a class enum & document (explain states in - * more details) + * This class represents states in execution of the Bolt protocol. + * It is used only internally in the Session. All functions that run + * these states can be found in the states/ subdirectory. */ -enum State { - HANDSHAKE, - INIT, - EXECUTOR, - ERROR, - NULLSTATE -}; +enum class State : uint8_t { + /** + * This state negotiates a handshake with the client. + */ + Handshake, + /** + * This state initializes the Bolt session. + */ + Init, + + /** + * This state executes commands from the Bolt protocol. + */ + Executor, + + /** + * This state handles errors. + */ + Error, + + /** + * This is a 'virtual' state (it doesn't have a run function) which tells + * the session that the client has sent malformed data and that the + * session should be closed. + */ + Close +}; } diff --git a/src/communication/bolt/v1/states/error.hpp b/src/communication/bolt/v1/states/error.hpp index 5879d3400..bfb44e58f 100644 --- a/src/communication/bolt/v1/states/error.hpp +++ b/src/communication/bolt/v1/states/error.hpp @@ -1,38 +1,72 @@ #pragma once -#include "communication/bolt/v1/messaging/codes.hpp" +#include "communication/bolt/v1/codes.hpp" #include "communication/bolt/v1/state.hpp" #include "logging/default.hpp" namespace communication::bolt { /** - * TODO (mferencevic): finish & document + * Error state run function + * This function handles a Bolt session when it is in an error state. + * The error state is exited upon receiving an ACK_FAILURE or RESET message. + * @param session the session that should be used for the run */ template <typename Session> State StateErrorRun(Session &session) { static Logger logger = logging::log->logger("State ERROR"); - session.decoder_.read_byte(); - auto message_type = session.decoder_.read_byte(); - - logger.trace("Message type byte is: {:02X}", message_type); - - if (message_type == MessageCode::PullAll) { - session.encoder_.MessageIgnored(); - return ERROR; - } else if (message_type == MessageCode::AckFailure) { - // TODO reset current statement? is it even necessary? - logger.trace("AckFailure received"); - session.encoder_.MessageSuccess(); - return EXECUTOR; - } else if (message_type == MessageCode::Reset) { - // TODO rollback current transaction - // discard all records waiting to be sent - session.encoder_.MessageSuccess(); - return EXECUTOR; + Marker marker; + Signature signature; + if (!session.decoder_.ReadMessageHeader(&signature, &marker)) { + logger.debug("Missing header data!"); + return State::Close; + } + + logger.trace("Message signature is: 0x{:02X}", underlying_cast(signature)); + + // clear the data buffer if it has any leftover data + session.encoder_buffer_.Clear(); + + if (signature == Signature::AckFailure || signature == Signature::Reset) { + if (signature == Signature::AckFailure) + logger.trace("AckFailure received"); + else + logger.trace("Reset received"); + + if (!session.encoder_.MessageSuccess()) { + logger.debug("Couldn't send success message!"); + return State::Close; + } + return State::Executor; + } else { + uint8_t value = underlying_cast(marker); + + // all bolt client messages have less than 15 parameters + // so if we receive anything than a TinyStruct it's an error + if ((value & 0xF0) != underlying_cast(Marker::TinyStruct)) { + logger.debug("Expected TinyStruct marker, but received 0x{:02X}!", value); + return State::Close; + } + + // we need to clean up all parameters from this command + value &= 0x0F; // the length is stored in the lower nibble + query::TypedValue tv; + for (int i = 0; i < value; ++i) { + if (!session.decoder_.ReadTypedValue(&tv)) { + logger.debug("Couldn't clean up parameter {} / {}!", i, value); + return State::Close; + } + } + + // ignore this message + if (!session.encoder_.MessageIgnored()) { + logger.debug("Couldn't send ignored message!"); + return State::Close; + } + + // cleanup done, command ignored, stay in error state + return State::Error; } - session.encoder_.MessageIgnored(); - return ERROR; } } diff --git a/src/communication/bolt/v1/states/executor.hpp b/src/communication/bolt/v1/states/executor.hpp index 9be74160d..3e084bca0 100644 --- a/src/communication/bolt/v1/states/executor.hpp +++ b/src/communication/bolt/v1/states/executor.hpp @@ -2,131 +2,209 @@ #include <string> -#include "communication/bolt/v1/bolt_exception.hpp" -#include "communication/bolt/v1/messaging/codes.hpp" +#include "communication/bolt/v1/codes.hpp" #include "communication/bolt/v1/state.hpp" #include "logging/default.hpp" #include "query/exceptions.hpp" namespace communication::bolt { -struct Query { - Query(std::string &&statement) : statement(statement) {} - std::string statement; -}; - -template <typename Session> -void StateExecutorFailure( - Session &session, Logger &logger, - const std::map<std::string, query::TypedValue> &metadata) { - try { - session.encoder_.MessageFailure(metadata); - } catch (const BoltException &e) { - logger.debug("MessageFailure failed because: {}", e.what()); - session.Close(); - } -} - /** - * TODO (mferencevic): finish & document + * Executor state run function + * This function executes an initialized Bolt session. + * It executes: RUN, PULL_ALL, DISCARD_ALL & RESET. + * @param session the session that should be used for the run */ template <typename Session> State StateExecutorRun(Session &session) { // initialize logger static Logger logger = logging::log->logger("State EXECUTOR"); - // just read one byte that represents the struct type, we can skip the - // information contained in this byte - session.decoder_.read_byte(); - auto message_type = session.decoder_.read_byte(); + Marker marker; + Signature signature; + if (!session.decoder_.ReadMessageHeader(&signature, &marker)) { + logger.debug("Missing header data!"); + return State::Close; + } - if (message_type == MessageCode::Run) { - Query query(session.decoder_.read_string()); + if (signature == Signature::Run) { + if (marker != Marker::TinyStruct2) { + logger.debug("Expected TinyStruct2 marker, but received 0x{:02X}!", + underlying_cast(marker)); + return State::Close; + } + + query::TypedValue query, params; + if (!session.decoder_.ReadTypedValue(&query, + query::TypedValue::Type::String)) { + logger.debug("Couldn't read query string!"); + return State::Close; + } + + if (!session.decoder_.ReadTypedValue(¶ms, + query::TypedValue::Type::Map)) { + logger.debug("Couldn't read parameters!"); + return State::Close; + } - // TODO (mferencevic): implement proper exception handling auto db_accessor = session.dbms_.active(); logger.debug("[ActiveDB] '{}'", db_accessor->name()); try { - logger.trace("[Run] '{}'", query.statement); + logger.trace("[Run] '{}'", query.Value<std::string>()); auto is_successfully_executed = session.query_engine_.Run( - query.statement, *db_accessor, session.output_stream_); + query.Value<std::string>(), *db_accessor, session.output_stream_); if (!is_successfully_executed) { + // abort transaction db_accessor->abort(); - StateExecutorFailure<Session>( - session, logger, + + // clear any leftover messages in the buffer + session.encoder_buffer_.Clear(); + + // send failure message + bool exec_fail_sent = session.encoder_.MessageFailure( {{"code", "Memgraph.QueryExecutionFail"}, {"message", "Query execution has failed (probably there is no " "element or there are some problems with concurrent " "access -> client has to resolve problems with " "concurrent access)"}}); - return ERROR; + + if (!exec_fail_sent) { + logger.debug("Couldn't send failure message!"); + return State::Close; + } else { + logger.debug("Query execution failed!"); + return State::Error; + } } else { db_accessor->commit(); + // The query engine has already stored all query data in the buffer. + // We should only send the first chunk now which is the success + // message which contains header data. The rest of this data (records + // and summary) will be sent after a PULL_ALL command from the client. + if (!session.encoder_buffer_.FlushFirstChunk()) { + logger.debug("Couldn't flush header data from the buffer!"); + return State::Close; + } + return State::Executor; } - return EXECUTOR; - // !! QUERY ENGINE -> RUN METHOD -> EXCEPTION HANDLING !! - } catch (const query::SyntaxException &e) { + } catch (const BasicException &e) { + // clear header success message + session.encoder_buffer_.Clear(); db_accessor->abort(); - StateExecutorFailure<Session>( - session, logger, - {{"code", "Memgraph.SyntaxException"}, {"message", "Syntax error"}}); - return ERROR; - } catch (const query::QueryEngineException &e) { - db_accessor->abort(); - StateExecutorFailure<Session>( - session, logger, - {{"code", "Memgraph.QueryEngineException"}, - {"message", "Query engine was unable to execute the query"}}); - return ERROR; + bool fail_sent = session.encoder_.MessageFailure( + {{"code", "Memgraph.Exception"}, {"message", e.what()}}); + logger.debug("Error message: {}", e.what()); + if (!fail_sent) { + logger.debug("Couldn't send failure message!"); + return State::Close; + } + return State::Error; + } catch (const StacktraceException &e) { + // clear header success message + session.encoder_buffer_.Clear(); db_accessor->abort(); - StateExecutorFailure<Session>(session, logger, - {{"code", "Memgraph.StacktraceException"}, - {"message", "Unknown exception"}}); - return ERROR; - } catch (const BoltException &e) { - db_accessor->abort(); - logger.debug("Failed because: {}", e.what()); - session.Close(); + bool fail_sent = session.encoder_.MessageFailure( + {{"code", "Memgraph.Exception"}, {"message", e.what()}}); + logger.debug("Error message: {}", e.what()); + logger.debug("Error trace: {}", e.trace()); + if (!fail_sent) { + logger.debug("Couldn't send failure message!"); + return State::Close; + } + return State::Error; + } catch (std::exception &e) { + // clear header success message + session.encoder_buffer_.Clear(); db_accessor->abort(); - StateExecutorFailure<Session>( - session, logger, - {{"code", "Memgraph.Exception"}, {"message", "Unknown exception"}}); - return ERROR; + bool fail_sent = session.encoder_.MessageFailure( + {{"code", "Memgraph.Exception"}, + {"message", + "An unknown exception occured, please contact your database " + "administrator."}}); + logger.debug("Unknown exception!!!"); + if (!fail_sent) { + logger.debug("Couldn't send failure message!"); + return State::Close; + } + return State::Error; } - // TODO (mferencevic): finish the error handling, cover all exceptions - // which can be raised from query engine - // * [abort, MessageFailure, return ERROR] should be extracted into - // separate function (or something equivalent) - // - // !! QUERY ENGINE -> RUN METHOD -> EXCEPTION HANDLING !! - } else if (message_type == MessageCode::PullAll) { + } else if (signature == Signature::PullAll) { logger.trace("[PullAll]"); - session.encoder_buffer_.Flush(); - } else if (message_type == MessageCode::DiscardAll) { + if (marker != Marker::TinyStruct) { + logger.debug("Expected TinyStruct marker, but received 0x{:02X}!", + underlying_cast(marker)); + return State::Close; + } + if (!session.encoder_buffer_.HasData()) { + // the buffer doesn't have data, return a failure message + bool data_fail_sent = session.encoder_.MessageFailure( + {{"code", "Memgraph.Exception"}, + {"message", + "There is no data to " + "send, you have to execute a RUN command before a PULL_ALL!"}}); + if (!data_fail_sent) { + logger.debug("Couldn't send failure message!"); + return State::Close; + } + return State::Error; + } + // flush pending data to the client, the success message is streamed + // from the query engine, it contains statistics from the query run + if (!session.encoder_buffer_.Flush()) { + logger.debug("Couldn't flush data from the buffer!"); + return State::Close; + } + return State::Executor; + + } else if (signature == Signature::DiscardAll) { logger.trace("[DiscardAll]"); + if (marker != Marker::TinyStruct) { + logger.debug("Expected TinyStruct marker, but received 0x{:02X}!", + underlying_cast(marker)); + return State::Close; + } + // clear all pending data and send a success message + session.encoder_buffer_.Clear(); + if (!session.encoder_.MessageSuccess()) { + logger.debug("Couldn't send success message!"); + return State::Close; + } + return State::Executor; + + } else if (signature == Signature::Reset) { + // IMPORTANT: This implementation of the Bolt RESET command isn't fully + // compliant to the protocol definition. In the protocol it is defined + // that this command should immediately stop any running commands and + // reset the session to a clean state. That means that we should always + // make a look-ahead for the RESET command before processing anything. + // Our implementation, for now, does everything in a blocking fashion + // so we cannot simply "kill" a transaction while it is running. So + // now this command only resets the session to a clean state. It + // does not IGNORE running and pending commands as it should. + if (marker != Marker::TinyStruct) { + logger.debug("Expected TinyStruct marker, but received 0x{:02X}!", + underlying_cast(marker)); + return State::Close; + } + // clear all pending data and send a success message + session.encoder_buffer_.Clear(); + if (!session.encoder_.MessageSuccess()) { + logger.debug("Couldn't send success message!"); + return State::Close; + } + return State::Executor; - // TODO: discard state - // TODO: write_success, send - session.encoder_.MessageSuccess(); - } else if (message_type == MessageCode::Reset) { - // TODO: rollback current transaction - // discard all records waiting to be sent - session.encoder_.MessageSuccess(); - return EXECUTOR; } else { - logger.error("Unrecognized message recieved"); - logger.debug("Invalid message type 0x{:02X}", message_type); - - return ERROR; + logger.debug("Unrecognized signature recieved (0x{:02X})!", + underlying_cast(signature)); + return State::Close; } - - return EXECUTOR; } } diff --git a/src/communication/bolt/v1/states/handshake.hpp b/src/communication/bolt/v1/states/handshake.hpp index 5667f1af9..900e9f73e 100644 --- a/src/communication/bolt/v1/states/handshake.hpp +++ b/src/communication/bolt/v1/states/handshake.hpp @@ -1,31 +1,42 @@ #pragma once #include "communication/bolt/v1/state.hpp" -#include "communication/bolt/v1/transport/bolt_decoder.hpp" #include "logging/default.hpp" namespace communication::bolt { -static constexpr uint32_t preamble = 0x6060B017; -static constexpr byte protocol[4] = {0x00, 0x00, 0x00, 0x01}; +static constexpr uint8_t preamble[4] = {0x60, 0x60, 0xB0, 0x17}; +static constexpr uint8_t protocol[4] = {0x00, 0x00, 0x00, 0x01}; /** - * TODO (mferencevic): finish & document + * Handshake state run function + * This function runs everything to make a Bolt handshake with the client. + * @param session the session that should be used for the run */ template <typename Session> State StateHandshakeRun(Session &session) { static Logger logger = logging::log->logger("State HANDSHAKE"); - if (UNLIKELY(session.decoder_.read_uint32() != preamble)) return NULLSTATE; + auto precmp = memcmp(session.buffer_.data(), preamble, sizeof(preamble)); + if (UNLIKELY(precmp != 0)) { + logger.debug("Received a wrong preamble!"); + return State::Close; + } // TODO so far we only support version 1 of the protocol so it doesn't // make sense to check which version the client prefers // this will change in the future + if (!session.socket_.Write(protocol, sizeof(protocol))) { + logger.debug("Couldn't write handshake response!"); + return State::Close; + } session.connected_ = true; - // TODO: check for success - session.socket_.Write(protocol, sizeof protocol); - return INIT; + // Delete data from buffer. It is guaranteed that there will be exactly + // 20 bytes in the buffer so we can use buffer_.size() here. + session.buffer_.Shift(session.buffer_.size()); + + return State::Init; } } diff --git a/src/communication/bolt/v1/states/init.hpp b/src/communication/bolt/v1/states/init.hpp index 6025c3d0a..8b04ec30f 100644 --- a/src/communication/bolt/v1/states/init.hpp +++ b/src/communication/bolt/v1/states/init.hpp @@ -1,53 +1,62 @@ #pragma once +#include "communication/bolt/v1/codes.hpp" #include "communication/bolt/v1/encoder/result_stream.hpp" -#include "communication/bolt/v1/messaging/codes.hpp" -#include "communication/bolt/v1/packing/codes.hpp" #include "communication/bolt/v1/state.hpp" -#include "communication/bolt/v1/transport/bolt_decoder.hpp" #include "logging/default.hpp" #include "utils/likely.hpp" namespace communication::bolt { /** - * TODO (mferencevic): finish & document + * Init state run function + * This function runs everything to initialize a Bolt session with the client. + * @param session the session that should be used for the run */ template <typename Session> State StateInitRun(Session &session) { static Logger logger = logging::log->logger("State INIT"); logger.debug("Parsing message"); - auto struct_type = session.decoder_.read_byte(); - - if (UNLIKELY((struct_type & 0x0F) > pack::Rule::MaxInitStructSize)) { - logger.debug("{}", struct_type); - logger.debug( - "Expected struct marker of max size 0x{:02} instead of 0x{:02X}", - (unsigned)pack::Rule::MaxInitStructSize, (unsigned)struct_type); - return NULLSTATE; + Marker marker; + Signature signature; + if (!session.decoder_.ReadMessageHeader(&signature, &marker)) { + logger.debug("Missing header data!"); + return State::Close; } - auto message_type = session.decoder_.read_byte(); - - if (UNLIKELY(message_type != MessageCode::Init)) { - logger.debug("Expected Init (0x01) instead of (0x{:02X})", - (unsigned)message_type); - return NULLSTATE; + if (UNLIKELY(signature != Signature::Init)) { + logger.debug("Expected Init signature, but received 0x{:02X}!", + underlying_cast(signature)); + return State::Close; + } + if (UNLIKELY(marker != Marker::TinyStruct2)) { + logger.debug("Expected TinyStruct2 marker, but received 0x{:02X}!", + underlying_cast(marker)); + return State::Close; } - auto client_name = session.decoder_.read_string(); - - if (struct_type == pack::Code::StructTwo) { - // TODO process authentication tokens + query::TypedValue client_name; + if (!session.decoder_.ReadTypedValue(&client_name, + query::TypedValue::Type::String)) { + logger.debug("Couldn't read client name!"); + return State::Close; } - logger.debug("Executing state"); - logger.debug("Client connected '{}'", client_name); + query::TypedValue metadata; + if (!session.decoder_.ReadTypedValue(&metadata, + query::TypedValue::Type::Map)) { + logger.debug("Couldn't read metadata!"); + return State::Close; + } - // TODO: write_success, chunk, send - session.encoder_.MessageSuccess(); + logger.debug("Client connected '{}'", client_name.Value<std::string>()); - return EXECUTOR; + if (!session.encoder_.MessageSuccess()) { + logger.debug("Couldn't send success message to the client!"); + return State::Close; + } + + return State::Executor; } } diff --git a/src/communication/bolt/v1/transport/bolt_decoder.cpp b/src/communication/bolt/v1/transport/bolt_decoder.cpp deleted file mode 100644 index a94892096..000000000 --- a/src/communication/bolt/v1/transport/bolt_decoder.cpp +++ /dev/null @@ -1,108 +0,0 @@ -#include "communication/bolt/v1/transport/bolt_decoder.hpp" -#include "communication/bolt/v1/packing/codes.hpp" - -#include "logging/default.hpp" -#include "utils/bswap.hpp" - -namespace communication::bolt { - -void BoltDecoder::handshake(const byte *&data, size_t len) { - buffer.write(data, len); - data += len; -} - -bool BoltDecoder::decode(const byte *&data, size_t len) { - return decoder(data, len); -} - -bool BoltDecoder::empty() const { return pos == buffer.size(); } - -void BoltDecoder::reset() { - buffer.clear(); - pos = 0; -} - -byte BoltDecoder::peek() const { return buffer[pos]; } - -byte BoltDecoder::read_byte() { return buffer[pos++]; } - -void BoltDecoder::read_bytes(void *dest, size_t n) { - std::memcpy(dest, buffer.data() + pos, n); - pos += n; -} - -template <class T> -T parse(const void *data) { - // reinterpret bytes as the target value - auto value = reinterpret_cast<const T *>(data); - - // swap values to little endian - return bswap(*value); -} - -template <class T> -T parse(Buffer &buffer, size_t &pos) { - // get a pointer to the data we're converting - auto ptr = buffer.data() + pos; - - // skip sizeof bytes that we're going to read - pos += sizeof(T); - - // read and convert the value - return parse<T>(ptr); -} - -int16_t BoltDecoder::read_int16() { return parse<int16_t>(buffer, pos); } - -uint16_t BoltDecoder::read_uint16() { return parse<uint16_t>(buffer, pos); } - -int32_t BoltDecoder::read_int32() { return parse<int32_t>(buffer, pos); } - -uint32_t BoltDecoder::read_uint32() { return parse<uint32_t>(buffer, pos); } - -int64_t BoltDecoder::read_int64() { return parse<int64_t>(buffer, pos); } - -uint64_t BoltDecoder::read_uint64() { return parse<uint64_t>(buffer, pos); } - -double BoltDecoder::read_float64() { - auto v = parse<int64_t>(buffer, pos); - return *reinterpret_cast<const double *>(&v); -} - -std::string BoltDecoder::read_string() { - auto marker = read_byte(); - - std::string res; - uint32_t size; - - // if the first 4 bits equal to 1000 (0x8), this is a tiny string - if ((marker & 0xF0) == pack::TinyString) { - // size is stored in the lower 4 bits of the marker byte - size = marker & 0x0F; - } - // if the marker is 0xD0, size is an 8-bit unsigned integer - else if (marker == pack::String8) { - size = read_byte(); - } - // if the marker is 0xD1, size is a 16-bit big-endian unsigned integer - else if (marker == pack::String16) { - size = read_uint16(); - } - // if the marker is 0xD2, size is a 32-bit big-endian unsigned integer - else if (marker == pack::String32) { - size = read_uint32(); - } else { - // TODO error? - return res; - } - - if (size == 0) return res; - - res.append(reinterpret_cast<const char *>(raw()), size); - pos += size; - - return res; -} - -const byte *BoltDecoder::raw() const { return buffer.data() + pos; } -} diff --git a/src/communication/bolt/v1/transport/bolt_decoder.hpp b/src/communication/bolt/v1/transport/bolt_decoder.hpp deleted file mode 100644 index d8401bc96..000000000 --- a/src/communication/bolt/v1/transport/bolt_decoder.hpp +++ /dev/null @@ -1,41 +0,0 @@ -#pragma once - -#include "communication/bolt/v1/transport/buffer.hpp" -#include "communication/bolt/v1/transport/chunked_decoder.hpp" -#include "utils/types/byte.hpp" - -namespace communication::bolt { - -class BoltDecoder { - public: - void handshake(const byte *&data, size_t len); - bool decode(const byte *&data, size_t len); - - bool empty() const; - void reset(); - - byte peek() const; - byte read_byte(); - void read_bytes(void *dest, size_t n); - - int16_t read_int16(); - uint16_t read_uint16(); - - int32_t read_int32(); - uint32_t read_uint32(); - - int64_t read_int64(); - uint64_t read_uint64(); - - double read_float64(); - - std::string read_string(); - - private: - Buffer buffer; - ChunkedDecoder<Buffer> decoder{buffer}; - size_t pos{0}; - - const byte *raw() const; -}; -} diff --git a/src/communication/bolt/v1/transport/buffer.cpp b/src/communication/bolt/v1/transport/buffer.cpp deleted file mode 100644 index b47bec8eb..000000000 --- a/src/communication/bolt/v1/transport/buffer.cpp +++ /dev/null @@ -1,10 +0,0 @@ -#include "communication/bolt/v1/transport/buffer.hpp" - -namespace communication::bolt { - -void Buffer::write(const byte* data, size_t len) { - buffer.insert(buffer.end(), data, data + len); -} - -void Buffer::clear() { buffer.clear(); } -} diff --git a/src/communication/bolt/v1/transport/buffer.hpp b/src/communication/bolt/v1/transport/buffer.hpp deleted file mode 100644 index 62d3828c4..000000000 --- a/src/communication/bolt/v1/transport/buffer.hpp +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once - -#include <cstdint> -#include <cstdlib> -#include <vector> - -#include "utils/types/byte.hpp" - -namespace communication::bolt { - -class Buffer { - public: - void write(const byte* data, size_t len); - - void clear(); - - size_t size() const { return buffer.size(); } - - byte operator[](size_t idx) const { return buffer[idx]; } - - const byte* data() const { return buffer.data(); } - - private: - std::vector<byte> buffer; -}; -} diff --git a/src/communication/bolt/v1/transport/chunked_decoder.hpp b/src/communication/bolt/v1/transport/chunked_decoder.hpp deleted file mode 100644 index 2d02f16f2..000000000 --- a/src/communication/bolt/v1/transport/chunked_decoder.hpp +++ /dev/null @@ -1,63 +0,0 @@ -#pragma once - -#include <cstring> -#include <functional> - -#include "logging/default.hpp" -#include "utils/exceptions/stacktrace_exception.hpp" -#include "utils/likely.hpp" -#include "utils/types/byte.hpp" - -namespace communication::bolt { - -template <class Stream> -class ChunkedDecoder { - public: - class DecoderError : public StacktraceException { - public: - using StacktraceException::StacktraceException; - }; - - ChunkedDecoder(Stream &stream) : stream(stream) {} - - /* Decode chunked data - * - * Chunk format looks like: - * - * |Header| Data ||Header| Data || ... || End | - * | 2B | size bytes || 2B | size bytes || ... ||00 00| - */ - bool decode(const byte *&chunk, size_t n) { - while (n > 0) { - // get size from first two bytes in the chunk - auto size = get_size(chunk); - - if (UNLIKELY(size + 2 > n)) - throw DecoderError("Chunk size larger than available data."); - - // advance chunk to pass those two bytes - chunk += 2; - n -= 2; - - // if chunk size is 0, we're done! - if (size == 0) return true; - - stream.get().write(chunk, size); - - chunk += size; - n -= size; - } - - return false; - } - - bool operator()(const byte *&chunk, size_t n) { return decode(chunk, n); } - - private: - std::reference_wrapper<Stream> stream; - - size_t get_size(const byte *chunk) { - return size_t(chunk[0]) << 8 | chunk[1]; - } -}; -} diff --git a/src/communication/bolt/v1/transport/stream_error.hpp b/src/communication/bolt/v1/transport/stream_error.hpp deleted file mode 100644 index 2f5e987e8..000000000 --- a/src/communication/bolt/v1/transport/stream_error.hpp +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#include "utils/exceptions/stacktrace_exception.hpp" - -namespace communication::bolt { - -class StreamError : StacktraceException { - public: - using StacktraceException::StacktraceException; -}; -} diff --git a/src/communication/bolt/v1/transport/streamed_bolt_decoder.hpp b/src/communication/bolt/v1/transport/streamed_bolt_decoder.hpp deleted file mode 100644 index d39928be8..000000000 --- a/src/communication/bolt/v1/transport/streamed_bolt_decoder.hpp +++ /dev/null @@ -1,307 +0,0 @@ -#pragma once - -#include <string> - -#include "communication/bolt/v1/packing/codes.hpp" -#include "query/exception/decoder_exception.hpp" -#include "utils/bswap.hpp" -#include "utils/types/byte.hpp" - -namespace communication::bolt { - -// BoltDecoder for streams. Meant for use in SnapshotDecoder. -// This should be recoded to recieve the current caller so that decoder can -// based on a current type call it. -template <class STREAM> -class StreamedBoltDecoder { - static constexpr int64_t plus_2_to_the_31 = 2147483648L; - static constexpr int64_t plus_2_to_the_15 = 32768L; - static constexpr int64_t plus_2_to_the_7 = 128L; - static constexpr int64_t minus_2_to_the_4 = -16L; - static constexpr int64_t minus_2_to_the_7 = -128L; - static constexpr int64_t minus_2_to_the_15 = -32768L; - static constexpr int64_t minus_2_to_the_31 = -2147483648L; - - public: - StreamedBoltDecoder(STREAM &stream) : stream(stream) {} - - // Returns mark of a data. - size_t mark() { return peek_byte(); } - - // Calls handle with current primitive data. Throws DecoderException if it - // isn't a primitive. - template <class H, class T> - T accept_primitive(H &handle) { - switch (byte()) { - case pack::False: { - return handle.handle(false); - } - case pack::True: { - return handle.handle(true); - } - case pack::Float64: { - return handle.handle(read_double()); - } - default: { return handle.handle(integer()); } - }; - } - - // Reads map header. Throws DecoderException if it isn't map header. - size_t map_header() { - auto marker = byte(); - - size_t size; - - if ((marker & 0xF0) == pack::TinyMap) { - size = marker & 0x0F; - - } else if (marker == pack::Map8) { - size = byte(); - - } else if (marker == pack::Map16) { - size = read<uint16_t>(); - - } else if (marker == pack::Map32) { - size = read<uint32_t>(); - - } else { - // Error - throw DecoderException( - "StreamedBoltDecoder: Tryed to read map header but found ", marker); - } - - return size; - } - - bool is_list() { - auto marker = peek_byte(); - - if ((marker & 0xF0) == pack::TinyList) { - return true; - - } else if (marker == pack::List8) { - return true; - - } else if (marker == pack::List16) { - return true; - - } else if (marker == pack::List32) { - return true; - } else { - return false; - } - } - - // Reads list header. Throws DecoderException if it isn't list header. - size_t list_header() { - auto marker = byte(); - - if ((marker & 0xF0) == pack::TinyList) { - return marker & 0x0F; - - } else if (marker == pack::List8) { - return byte(); - - } else if (marker == pack::List16) { - return read<uint16_t>(); - - } else if (marker == pack::List32) { - return read<uint32_t>(); - - } else { - // Error - throw DecoderException( - "StreamedBoltDecoder: Tryed to read list header but found ", marker); - } - } - - bool is_bool() { - auto marker = peek_byte(); - - if (marker == pack::True) { - return true; - } else if (marker == pack::False) { - return true; - } else { - return false; - } - } - - // Reads bool.Throws DecoderException if it isn't bool. - bool read_bool() { - auto marker = byte(); - - if (marker == pack::True) { - return true; - } else if (marker == pack::False) { - return false; - } else { - throw DecoderException( - "StreamedBoltDecoder: Tryed to read bool header but found ", marker); - } - } - - bool is_integer() { - auto marker = peek_byte(); - - if (marker >= minus_2_to_the_4 && marker < plus_2_to_the_7) { - return true; - - } else if (marker == pack::Int8) { - return true; - - } else if (marker == pack::Int16) { - return true; - - } else if (marker == pack::Int32) { - return true; - - } else if (marker == pack::Int64) { - return true; - - } else { - return false; - } - } - - // Reads integer.Throws DecoderException if it isn't integer. - int64_t integer() { - auto marker = byte(); - - if (marker >= minus_2_to_the_4 && marker < plus_2_to_the_7) { - return marker; - - } else if (marker == pack::Int8) { - return byte(); - - } else if (marker == pack::Int16) { - return read<int16_t>(); - - } else if (marker == pack::Int32) { - return read<int32_t>(); - - } else if (marker == pack::Int64) { - return read<int64_t>(); - - } else { - throw DecoderException( - "StreamedBoltDecoder: Tryed to read integer but found ", marker); - } - } - - bool is_double() { - auto marker = peek_byte(); - - return marker == pack::Float64; - } - - // Reads double.Throws DecoderException if it isn't double. - double read_double() { - auto marker = byte(); - if (marker == pack::Float64) { - auto tmp = read<int64_t>(); - return *reinterpret_cast<const double *>(&tmp); - } else { - throw DecoderException( - "StreamedBoltDecoder: Tryed to read double but found ", marker); - } - } - - bool is_string() { - auto marker = peek_byte(); - - // if the first 4 bits equal to 1000 (0x8), this is a tiny string - if ((marker & 0xF0) == pack::TinyString) { - return true; - } - // if the marker is 0xD0, size is an 8-bit unsigned integer - else if (marker == pack::String8) { - return true; - } - // if the marker is 0xD1, size is a 16-bit big-endian unsigned integer - else if (marker == pack::String16) { - return true; - } - // if the marker is 0xD2, size is a 32-bit big-endian unsigned integer - else if (marker == pack::String32) { - return true; - - } else { - return false; - } - } - - // Reads string into res. Throws DecoderException if it isn't string. - void string(std::string &res) { - if (!string_try(res)) { - throw DecoderException( - "StreamedBoltDecoder: Tryed to read string but found ", - std::to_string(peek_byte())); - } - } - // Try-s to read string. Retunrns true on success. If it didn't succed - // stream remains unchanged - bool string_try(std::string &res) { - auto marker = peek_byte(); - - uint32_t size; - - // if the first 4 bits equal to 1000 (0x8), this is a tiny string - if ((marker & 0xF0) == pack::TinyString) { - byte(); - // size is stored in the lower 4 bits of the marker byte - size = marker & 0x0F; - } - // if the marker is 0xD0, size is an 8-bit unsigned integer - else if (marker == pack::String8) { - byte(); - size = byte(); - } - // if the marker is 0xD1, size is a 16-bit big-endian unsigned integer - else if (marker == pack::String16) { - byte(); - size = read<uint16_t>(); - } - // if the marker is 0xD2, size is a 32-bit big-endian unsigned integer - else if (marker == pack::String32) { - byte(); - size = read<uint32_t>(); - } else { - // Error - return false; - } - - if (size > 0) { - res.resize(size); - stream.read(&res.front(), size); - } else { - res.clear(); - } - - return true; - } - - private: - // Reads T from stream. It doens't care for alligment so this is valid only - // for primitives. - template <class T> - T read() { - buffer.resize(sizeof(T)); - - // Load value - stream.read(&buffer.front(), sizeof(T)); - - // reinterpret bytes as the target value - auto value = reinterpret_cast<const T *>(&buffer.front()); - - // swap values to little endian - return bswap(*value); - } - - ::byte byte() { return stream.get(); } - ::byte peek_byte() { return stream.peek(); } - - STREAM &stream; - std::string buffer; -}; -}; diff --git a/src/communication/worker.hpp b/src/communication/worker.hpp index 2998f174f..ff7ea6b61 100644 --- a/src/communication/worker.hpp +++ b/src/communication/worker.hpp @@ -63,17 +63,11 @@ class Worker void OnWaitTimeout() {} - StreamBuffer OnAlloc(Session &) { - /* logger.trace("[on_alloc] Allocating {}B", sizeof buf); */ - - return StreamBuffer{buf_, sizeof buf_}; - } - - void OnRead(Session &session, StreamBuffer &buf) { - logger_.trace("[on_read] Received {}B", buf.len); + void OnRead(Session &session) { + logger_.trace("OnRead"); try { - session.Execute(buf.data, buf.len); + session.Execute(); } catch (const std::exception &e) { logger_.error("Error occured while executing statement."); logger_.error("{}", e.what()); @@ -96,7 +90,6 @@ class Worker // TODO: Do something about it } - uint8_t buf_[65536]; std::thread thread_; void Start(std::atomic<bool> &alive) { diff --git a/src/io/network/stream_reader.hpp b/src/io/network/stream_reader.hpp index ec1a4c05f..2910039ca 100644 --- a/src/io/network/stream_reader.hpp +++ b/src/io/network/stream_reader.hpp @@ -53,13 +53,13 @@ class StreamReader : public StreamListener<Derived, Stream> { } // allocate the buffer to fill the data - auto buf = this->derived().OnAlloc(stream); + auto buf = stream.Allocate(); // read from the buffer at most buf.len bytes - buf.len = stream.socket_.Read(buf.data, buf.len); + int len = stream.socket_.Read(buf.data, buf.len); // check for read errors - if (buf.len == -1) { + if (len == -1) { // this means we have read all available data if (LIKELY(errno == EAGAIN || errno == EWOULDBLOCK)) { return; @@ -71,13 +71,16 @@ class StreamReader : public StreamListener<Derived, Stream> { } // end of file, the client has closed the connection - if (UNLIKELY(buf.len == 0)) { + if (UNLIKELY(len == 0)) { logger_.trace("Calling OnClose because the socket is closed!"); this->derived().OnClose(stream); return; } - this->derived().OnRead(stream, buf); + // notify the stream that it has new data + stream.Written(len); + + this->derived().OnRead(stream); } private: diff --git a/src/utils/exceptions/not_yet_implemented.hpp b/src/utils/exceptions/not_yet_implemented.hpp index 23391462b..954f8a95d 100644 --- a/src/utils/exceptions/not_yet_implemented.hpp +++ b/src/utils/exceptions/not_yet_implemented.hpp @@ -6,5 +6,5 @@ class NotYetImplemented : public StacktraceException { public: using StacktraceException::StacktraceException; - NotYetImplemented() : StacktraceException("") {} + NotYetImplemented() : StacktraceException("Not yet implemented!") {} }; diff --git a/src/utils/exceptions/stacktrace_exception.hpp b/src/utils/exceptions/stacktrace_exception.hpp index 48c425cbf..7491f792a 100644 --- a/src/utils/exceptions/stacktrace_exception.hpp +++ b/src/utils/exceptions/stacktrace_exception.hpp @@ -11,7 +11,7 @@ class StacktraceException : public std::exception { public: StacktraceException(const std::string &message) noexcept : message_(message) { Stacktrace stacktrace; - message_.append(stacktrace.dump()); + stacktrace_ = stacktrace.dump(); } template <class... Args> @@ -25,6 +25,9 @@ class StacktraceException : public std::exception { const char *what() const noexcept override { return message_.c_str(); } + const char *trace() const noexcept { return stacktrace_.c_str(); } + private: std::string message_; + std::string stacktrace_; }; diff --git a/tests/concurrent/network_common.hpp b/tests/concurrent/network_common.hpp index ac399fff4..e6367829f 100644 --- a/tests/concurrent/network_common.hpp +++ b/tests/concurrent/network_common.hpp @@ -11,6 +11,7 @@ #include "logging/default.hpp" #include "logging/streams/stdout.hpp" +#include "communication/bolt/v1/decoder/buffer.hpp" #include "communication/server.hpp" #include "dbms/dbms.hpp" #include "io/network/epoll.hpp" @@ -38,23 +39,26 @@ class TestSession { int Id() const { return socket_.id(); } - void Execute(const byte* data, size_t len) { - if (size_ == 0) { - size_ = data[0]; - size_ <<= 8; - size_ += data[1]; - data += 2; - len -= 2; - } - memcpy(buffer_ + have_, data, len); - have_ += len; - if (have_ < size_) return; + void Execute() { + if (buffer_.size() < 2) return; + const uint8_t *data = buffer_.data(); + size_t size = data[0]; + size <<= 8; + size += data[1]; + if (buffer_.size() < size + 2) return; for (int i = 0; i < REPLY; ++i) - ASSERT_TRUE(this->socket_.Write(buffer_, size_)); + ASSERT_TRUE(this->socket_.Write(data + 2, size)); - have_ = 0; - size_ = 0; + buffer_.Shift(size + 2); + } + + io::network::StreamBuffer Allocate() { + return buffer_.Allocate(); + } + + void Written(size_t len) { + buffer_.Written(len); } void Close() { @@ -62,9 +66,7 @@ class TestSession { this->socket_.Close(); } - char buffer_[SIZE * 2]; - uint32_t have_, size_; - + communication::bolt::Buffer<SIZE * 2> buffer_; Logger logger_; socket_t socket_; io::network::Epoll::Event event_; @@ -87,6 +89,7 @@ void client_run(int num, const char* interface, const char* port, endpoint_t endpoint(interface, port); socket_t socket; ASSERT_TRUE(socket.Connect(endpoint)); + ASSERT_TRUE(socket.SetTimeout(2, 0)); logger.trace("Socket create: {}", socket.id()); for (int len = lo; len <= hi; len += 100) { have = 0; diff --git a/tests/concurrent/network_read_hang.cpp b/tests/concurrent/network_read_hang.cpp index 1de6c00c3..53ed0c5f5 100644 --- a/tests/concurrent/network_read_hang.cpp +++ b/tests/concurrent/network_read_hang.cpp @@ -15,6 +15,7 @@ #include "logging/default.hpp" #include "logging/streams/stdout.hpp" +#include "communication/bolt/v1/decoder/buffer.hpp" #include "communication/server.hpp" #include "dbms/dbms.hpp" #include "io/network/epoll.hpp" @@ -40,8 +41,16 @@ class TestSession { int Id() const { return socket_.id(); } - void Execute(const byte* data, size_t len) { - this->socket_.Write(data, len); + void Execute() { + this->socket_.Write(buffer_.data(), buffer_.size()); + } + + io::network::StreamBuffer Allocate() { + return buffer_.Allocate(); + } + + void Written(size_t len) { + buffer_.Written(len); } void Close() { @@ -49,6 +58,7 @@ class TestSession { } socket_t socket_; + communication::bolt::Buffer<> buffer_; io::network::Epoll::Event event_; }; diff --git a/tests/unit/bolt_buffer.cpp b/tests/unit/bolt_buffer.cpp index c303996b1..120fa25dd 100644 --- a/tests/unit/bolt_buffer.cpp +++ b/tests/unit/bolt_buffer.cpp @@ -4,7 +4,7 @@ constexpr const int SIZE = 4096; uint8_t data[SIZE]; -using BufferT = communication::bolt::Buffer; +using BufferT = communication::bolt::Buffer<>; using StreamBufferT = io::network::StreamBuffer; TEST(BoltBuffer, AllocateAndWritten) { diff --git a/tests/unit/bolt_chunked_decoder.cpp b/tests/unit/bolt_chunked_decoder.cpp deleted file mode 100644 index ec9b903ac..000000000 --- a/tests/unit/bolt_chunked_decoder.cpp +++ /dev/null @@ -1,59 +0,0 @@ -#include <array> -#include <cassert> -#include <cstring> -#include <deque> -#include <iostream> -#include <vector> - -#include "communication/bolt/v1/transport/chunked_decoder.hpp" -#include "gtest/gtest.h" - -/** - * DummyStream which is going to be used to test output data. - */ -struct DummyStream { - /** - * TODO (mferencevic): apply google style guide once decoder will be - * refactored + document - */ - void write(const uint8_t *values, size_t n) { - data.insert(data.end(), values, values + n); - } - std::vector<uint8_t> data; -}; -using DecoderT = communication::bolt::ChunkedDecoder<DummyStream>; - -TEST(ChunkedDecoderTest, WriteString) { - DummyStream stream; - DecoderT decoder(stream); - - std::vector<uint8_t> chunks[] = { - {0x00, 0x08, 'A', ' ', 'q', 'u', 'i', 'c', 'k', ' ', 0x00, 0x06, 'b', 'r', - 'o', 'w', 'n', ' '}, - {0x00, 0x0A, 'f', 'o', 'x', ' ', 'j', 'u', 'm', 'p', 's', ' '}, - {0x00, 0x07, 'o', 'v', 'e', 'r', ' ', 'a', ' '}, - {0x00, 0x08, 'l', 'a', 'z', 'y', ' ', 'd', 'o', 'g', 0x00, 0x00}}; - static constexpr size_t N = std::extent<decltype(chunks)>::value; - - for (size_t i = 0; i < N; ++i) { - auto &chunk = chunks[i]; - logging::info("Chunk size: {}", chunk.size()); - - const uint8_t *start = chunk.data(); - auto finished = decoder.decode(start, chunk.size()); - - // break early if finished - if (finished) break; - } - - // check validity - std::string decoded = "A quick brown fox jumps over a lazy dog"; - ASSERT_EQ(decoded.size(), stream.data.size()); - for (size_t i = 0; i < decoded.size(); ++i) - ASSERT_EQ(decoded[i], stream.data[i]); -} - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tests/unit/bolt_chunked_decoder_buffer.cpp b/tests/unit/bolt_chunked_decoder_buffer.cpp index 8e827b6f5..c5abfdffc 100644 --- a/tests/unit/bolt_chunked_decoder_buffer.cpp +++ b/tests/unit/bolt_chunked_decoder_buffer.cpp @@ -5,9 +5,10 @@ constexpr const int SIZE = 131072; uint8_t data[SIZE]; -using BufferT = communication::bolt::Buffer; +using BufferT = communication::bolt::Buffer<>; using StreamBufferT = io::network::StreamBuffer; using DecoderBufferT = communication::bolt::ChunkedDecoderBuffer; +using ChunkStateT = communication::bolt::ChunkState; TEST(BoltBuffer, CorrectChunk) { uint8_t tmp[2000]; @@ -20,7 +21,7 @@ TEST(BoltBuffer, CorrectChunk) { sb.data[1002] = 0; sb.data[1003] = 0; buffer.Written(1004); - ASSERT_EQ(decoder_buffer.GetChunk(), true); + ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole); ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true); for (int i = 0; i < 1000; ++i) @@ -40,7 +41,7 @@ TEST(BoltBuffer, CorrectChunkTrailingData) { sb.data[1002] = 0; sb.data[1003] = 0; buffer.Written(2004); - ASSERT_EQ(decoder_buffer.GetChunk(), true); + ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole); ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true); for (int i = 0; i < 1000; ++i) @@ -62,7 +63,7 @@ TEST(BoltBuffer, InvalidChunk) { sb.data[1002] = 1; sb.data[1003] = 1; buffer.Written(2004); - ASSERT_EQ(decoder_buffer.GetChunk(), false); + ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Invalid); ASSERT_EQ(buffer.size(), 1000); @@ -79,19 +80,19 @@ TEST(BoltBuffer, GraduallyPopulatedChunk) { sb.data[0] = 0x03; sb.data[1] = 0xe8; buffer.Written(2); - ASSERT_EQ(decoder_buffer.GetChunk(), false); + ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial); for (int i = 0; i < 5; ++i) { sb = buffer.Allocate(); memcpy(sb.data, data + 200 * i, 200); buffer.Written(200); - ASSERT_EQ(decoder_buffer.GetChunk(), false); + ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial); } sb = buffer.Allocate(); sb.data[0] = 0; sb.data[1] = 0; buffer.Written(2); - ASSERT_EQ(decoder_buffer.GetChunk(), true); + ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole); ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true); for (int i = 0; i < 1000; ++i) @@ -108,13 +109,13 @@ TEST(BoltBuffer, GraduallyPopulatedChunkTrailingData) { sb.data[0] = 0x03; sb.data[1] = 0xe8; buffer.Written(2); - ASSERT_EQ(decoder_buffer.GetChunk(), false); + ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial); for (int i = 0; i < 5; ++i) { sb = buffer.Allocate(); memcpy(sb.data, data + 200 * i, 200); buffer.Written(200); - ASSERT_EQ(decoder_buffer.GetChunk(), false); + ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial); } sb = buffer.Allocate(); @@ -125,7 +126,7 @@ TEST(BoltBuffer, GraduallyPopulatedChunkTrailingData) { memcpy(sb.data, data, 1000); buffer.Written(1000); - ASSERT_EQ(decoder_buffer.GetChunk(), true); + ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole); ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true); for (int i = 0; i < 1000; ++i) diff --git a/tests/unit/bolt_common.hpp b/tests/unit/bolt_common.hpp index 2d400e833..4c382489e 100644 --- a/tests/unit/bolt_common.hpp +++ b/tests/unit/bolt_common.hpp @@ -29,19 +29,25 @@ class TestSocket { int id() const { return socket; } - int Write(const std::string &str) { return Write(str.c_str(), str.size()); } - int Write(const char *data, size_t len) { + bool Write(const std::string &str) { return Write(str.c_str(), str.size()); } + bool Write(const char *data, size_t len) { return Write(reinterpret_cast<const uint8_t *>(data), len); } - int Write(const uint8_t *data, size_t len) { + bool Write(const uint8_t *data, size_t len) { + if (!write_success_) return false; for (size_t i = 0; i < len; ++i) output.push_back(data[i]); - return len; + return true; + } + + void SetWriteSuccess(bool success) { + write_success_ = success; } std::vector<uint8_t> output; protected: int socket; + bool write_success_{true}; }; /** @@ -53,7 +59,7 @@ class TestBuffer { void Write(const uint8_t *data, size_t n) { socket_.Write(data, n); } void Chunk() {} - void Flush() {} + bool Flush() { return true; } private: TestSocket &socket_; diff --git a/tests/unit/bolt_result_stream.cpp b/tests/unit/bolt_result_stream.cpp index 330207429..6dbc40e7b 100644 --- a/tests/unit/bolt_result_stream.cpp +++ b/tests/unit/bolt_result_stream.cpp @@ -31,7 +31,8 @@ TEST(Bolt, ResultStream) { for (int i = 0; i < 10; ++i) headers.push_back(std::string(2, (char)('a' + i))); - result_stream.Header(headers); // this method automatically calls Flush + result_stream.Header(headers); + buffer.FlushFirstChunk(); PrintOutput(output); CheckOutput(output, header_output, 45); diff --git a/tests/unit/bolt_session.cpp b/tests/unit/bolt_session.cpp index 1dfea8f76..fc878a66e 100644 --- a/tests/unit/bolt_session.cpp +++ b/tests/unit/bolt_session.cpp @@ -2,68 +2,604 @@ #include "communication/bolt/v1/encoder/result_stream.hpp" #include "communication/bolt/v1/session.hpp" +#include "config/config.hpp" #include "query/engine.hpp" + +// Shortcuts for writing variable initializations in tests +#define INIT_VARS Dbms dbms;\ + TestSocket socket(10);\ + QueryEngine<ResultStreamT> query_engine;\ + SessionT session(std::move(socket), dbms, query_engine);\ + std::vector<uint8_t> &output = session.socket_.output; + + using ResultStreamT = communication::bolt::ResultStream<communication::bolt::Encoder< communication::bolt::ChunkedEncoderBuffer<TestSocket>>>; using SessionT = communication::bolt::Session<TestSocket>; +using StateT = communication::bolt::State; -/** - * TODO (mferencevic): document - */ -const uint8_t handshake_req[] = - "\x60\x60\xb0\x17\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - "\x00\x00"; -const uint8_t handshake_resp[] = "\x00\x00\x00\x01"; -const uint8_t init_req[] = - "\x00\x3f\xb2\x01\xd0\x15\x6c\x69\x62\x6e\x65\x6f\x34\x6a\x2d\x63\x6c\x69" - "\x65\x6e\x74\x2f\x31\x2e\x32\x2e\x31\xa3\x86\x73\x63\x68\x65\x6d\x65\x85" - "\x62\x61\x73\x69\x63\x89\x70\x72\x69\x6e\x63\x69\x70\x61\x6c\x80\x8b\x63" - "\x72\x65\x64\x65\x6e\x74\x69\x61\x6c\x73\x80\x00\x00"; -const uint8_t init_resp[] = "\x00\x03\xb1\x70\xa0\x00\x00"; -const uint8_t run_req[] = - "\x00\x26\xb2\x10\xd0\x21\x43\x52\x45\x41\x54\x45\x20\x28\x6e\x20\x7b\x6e" - "\x61\x6d\x65\x3a\x20\x32\x39\x33\x38\x33\x7d\x29\x20\x52\x45\x54\x55\x52" - "\x4e\x20\x6e\xa0\x00\x00"; +// Sample testdata that has correct inputs and outputs. +const uint8_t handshake_req[] = { + 0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; +const uint8_t handshake_resp[] = {0x00, 0x00, 0x00, 0x01}; +const uint8_t init_req[] = { + 0xb2, 0x01, 0xd0, 0x15, 0x6c, 0x69, 0x62, 0x6e, 0x65, 0x6f, 0x34, 0x6a, + 0x2d, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2f, 0x31, 0x2e, 0x32, 0x2e, + 0x31, 0xa3, 0x86, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x65, 0x85, 0x62, 0x61, + 0x73, 0x69, 0x63, 0x89, 0x70, 0x72, 0x69, 0x6e, 0x63, 0x69, 0x70, 0x61, + 0x6c, 0x80, 0x8b, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, + 0x6c, 0x73, 0x80}; +const uint8_t init_resp[] = {0x00, 0x03, 0xb1, 0x70, 0xa0, 0x00, 0x00}; +const uint8_t run_req_header[] = {0xb2, 0x10, 0xd1}; +const uint8_t pullall_req[] = {0xb0, 0x3f}; +const uint8_t discardall_req[] = {0xb0, 0x2f}; +const uint8_t reset_req[] = {0xb0, 0x0f}; +const uint8_t ackfailure_req[] = {0xb0, 0x0e}; +const uint8_t success_resp[] = {0x00, 0x03, 0xb1, 0x70, 0xa0, 0x00, 0x00}; +const uint8_t ignored_resp[] = {0x00, 0x02, 0xb0, 0x7e, 0x00, 0x00}; -TEST(Bolt, Session) { - Dbms dbms; - TestSocket socket(10); - QueryEngine<ResultStreamT> query_engine; - SessionT session(std::move(socket), dbms, query_engine); - std::vector<uint8_t> &output = session.socket_.output; - // execute handshake - session.Execute(handshake_req, 20); - ASSERT_EQ(session.state_, communication::bolt::INIT); +// Write bolt chunk header (length) +void WriteChunkHeader(SessionT &session, uint16_t len) { + len = bswap(len); + auto buff = session.Allocate(); + memcpy(buff.data, reinterpret_cast<uint8_t *>(&len), sizeof(len)); + session.Written(sizeof(len)); +} + +// Write bolt chunk tail (two zeros) +void WriteChunkTail(SessionT &session) { + WriteChunkHeader(session, 0); +} + +// Check that the server responded with a failure message +void CheckFailureMessage(std::vector<uint8_t> &output) { + ASSERT_GE(output.size(), 6); + // skip the first two bytes because they are the chunk header + ASSERT_EQ(output[2], 0xB1); // tiny struct 1 + ASSERT_EQ(output[3], 0x7F); // signature failure +} + +// Execute and check a correct handshake +void ExecuteHandshake(SessionT &session, std::vector<uint8_t> &output) { + auto buff = session.Allocate(); + memcpy(buff.data, handshake_req, 20); + session.Written(20); + session.Execute(); + + ASSERT_EQ(session.state_, StateT::Init); + ASSERT_TRUE(session.socket_.IsOpen()); PrintOutput(output); CheckOutput(output, handshake_resp, 4); +} - // execute init - session.Execute(init_req, 67); - ASSERT_EQ(session.state_, communication::bolt::EXECUTOR); +// Write bolt chunk and execute command +void ExecuteCommand(SessionT &session, const uint8_t *data, size_t len, bool chunk = true) { + if (chunk) WriteChunkHeader(session, len); + auto buff = session.Allocate(); + memcpy(buff.data, data, len); + session.Written(len); + if (chunk) WriteChunkTail(session); + session.Execute(); +} + +// Execute and check a correct init +void ExecuteInit(SessionT &session, std::vector<uint8_t> &output) { + ExecuteCommand(session, init_req, sizeof(init_req)); + ASSERT_EQ(session.state_, StateT::Executor); + ASSERT_TRUE(session.socket_.IsOpen()); PrintOutput(output); CheckOutput(output, init_resp, 7); - - // execute run - session.Execute(run_req, 42); - - // TODO (mferencevic): query engine doesn't currently work, - // we should test the query output here and the next state - // ASSERT_EQ(session.state, bolt::EXECUTOR); - // PrintOutput(output); - // CheckOutput(output, run_resp, len); - - // TODO (mferencevic): add more tests - - session.Close(); } +// Write bolt encoded run request +void WriteRunRequest(SessionT &session, const char *str) { + // write chunk header + auto len = strlen(str); + WriteChunkHeader(session, 3 + 2 + len + 1); + + // write string header + auto buff = session.Allocate(); + memcpy(buff.data, run_req_header, 3); + session.Written(3); + + // write string length + WriteChunkHeader(session, len); + + // write string + buff = session.Allocate(); + memcpy(buff.data, str, len); + session.Written(len); + + // write empty map for parameters + buff = session.Allocate(); + buff.data[0] = 0xA0; // TinyMap0 + session.Written(1); + + // write chunk tail + WriteChunkTail(session); +} + + +TEST(BoltSession, HandshakeWrongPreamble) { + INIT_VARS; + + auto buff = session.Allocate(); + // copy 0x00000001 four times + for (int i = 0; i < 4; ++i) + memcpy(buff.data + i * 4, handshake_req + 4, 4); + session.Written(20); + session.Execute(); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + PrintOutput(output); + CheckFailureMessage(output); +} + +TEST(BoltSession, HandshakeInTwoPackets) { + INIT_VARS; + + auto buff = session.Allocate(); + memcpy(buff.data, handshake_req, 10); + session.Written(10); + session.Execute(); + + ASSERT_EQ(session.state_, StateT::Handshake); + ASSERT_TRUE(session.socket_.IsOpen()); + + memcpy(buff.data + 10, handshake_req + 10, 10); + session.Written(10); + session.Execute(); + + ASSERT_EQ(session.state_, StateT::Init); + ASSERT_TRUE(session.socket_.IsOpen()); + PrintOutput(output); + CheckOutput(output, handshake_resp, 4); +} + +TEST(BoltSession, HandshakeTooLarge) { + INIT_VARS; + + auto buff = session.Allocate(); + memcpy(buff.data, handshake_req, 20); + memcpy(buff.data + 20, handshake_req, 20); + session.Written(40); + session.Execute(); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + PrintOutput(output); + CheckFailureMessage(output); +} + +TEST(BoltSession, HandshakeWriteFail) { + INIT_VARS; + session.socket_.SetWriteSuccess(false); + ExecuteCommand(session, handshake_req, sizeof(handshake_req), false); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + ASSERT_EQ(output.size(), 0); +} + +TEST(BoltSession, HandshakeOK) { + INIT_VARS; + ExecuteHandshake(session, output); +} + + +TEST(BoltSession, InitWrongSignature) { + INIT_VARS; + ExecuteHandshake(session, output); + ExecuteCommand(session, run_req_header, sizeof(run_req_header)); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + CheckFailureMessage(output); +} + +TEST(BoltSession, InitWrongMarker) { + INIT_VARS; + ExecuteHandshake(session, output); + + // wrong marker, good signature + uint8_t data[2] = {0x00, init_req[1]}; + ExecuteCommand(session, data, 2); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + CheckFailureMessage(output); +} + +TEST(BoltSession, InitMissingData) { + // test lengths, they test the following situations: + // missing header data, missing client name, missing metadata + int len[] = {1, 2, 25}; + + for (int i = 0; i < 3; ++i) { + INIT_VARS; + ExecuteHandshake(session, output); + ExecuteCommand(session, init_req, len[i]); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + CheckFailureMessage(output); + } +} + +TEST(BoltSession, InitWriteFail) { + INIT_VARS; + ExecuteHandshake(session, output); + session.socket_.SetWriteSuccess(false); + ExecuteCommand(session, init_req, sizeof(init_req)); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + ASSERT_EQ(output.size(), 0); +} + +TEST(BoltSession, InitOK) { + INIT_VARS; + ExecuteHandshake(session, output); + ExecuteInit(session, output); +} + + +TEST(BoltSession, ExecuteRunWrongMarker) { + INIT_VARS; + + ExecuteHandshake(session, output); + ExecuteInit(session, output); + + // wrong marker, good signature + uint8_t data[2] = {0x00, run_req_header[1]}; + ExecuteCommand(session, data, sizeof(data)); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + CheckFailureMessage(output); +} + +TEST(BoltSession, ExecuteRunMissingData) { + // test lengths, they test the following situations: + // missing header data, missing query data, missing parameters + int len[] = {1, 2, 37}; + + for (int i = 0; i < 3; ++i) { + INIT_VARS; + ExecuteHandshake(session, output); + ExecuteInit(session, output); + ExecuteCommand(session, run_req_header, len[i]); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + CheckFailureMessage(output); + } +} + +TEST(BoltSession, ExecuteRunBasicException) { + // first test with socket write success, then with socket write fail + for (int i = 0; i < 2; ++i) { + INIT_VARS; + + ExecuteHandshake(session, output); + ExecuteInit(session, output); + + session.socket_.SetWriteSuccess(i == 0); + WriteRunRequest(session, "MATCH (omnom"); + session.Execute(); + + if (i == 0) { + ASSERT_EQ(session.state_, StateT::Error); + ASSERT_TRUE(session.socket_.IsOpen()); + CheckFailureMessage(output); + } else { + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + ASSERT_EQ(output.size(), 0); + } + } +} + +TEST(BoltSession, ExecutePullAllDiscardAllResetWrongMarker) { + // This test first tests PULL_ALL then DISCARD_ALL and then RESET + // It tests for missing data in the message header + const uint8_t *dataset[3] = {pullall_req, discardall_req, reset_req}; + + for (int i = 0; i < 3; ++i) { + INIT_VARS; + + ExecuteHandshake(session, output); + ExecuteInit(session, output); + + // wrong marker, good signature + uint8_t data[2] = {0x00, dataset[i][1]}; + ExecuteCommand(session, data, sizeof(data)); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + CheckFailureMessage(output); + } +} + +TEST(BoltSession, ExecutePullAllBufferEmpty) { + // first test with socket write success, then with socket write fail + for (int i = 0; i < 2; ++i) { + INIT_VARS; + + ExecuteHandshake(session, output); + ExecuteInit(session, output); + + session.socket_.SetWriteSuccess(i == 0); + ExecuteCommand(session, pullall_req, sizeof(pullall_req)); + + if (i == 0) { + ASSERT_EQ(session.state_, StateT::Error); + ASSERT_TRUE(session.socket_.IsOpen()); + CheckFailureMessage(output); + } else { + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + ASSERT_EQ(output.size(), 0); + } + } +} + +TEST(BoltSession, ExecutePullAllDiscardAllReset) { + // This test first tests PULL_ALL then DISCARD_ALL and then RESET + // It tests a good message + const uint8_t *dataset[3] = {pullall_req, discardall_req, reset_req}; + + for (int i = 0; i < 3; ++i) { + // first test with socket write success, then with socket write fail + for (int j = 0; j < 2; ++j) { + INIT_VARS; + + ExecuteHandshake(session, output); + ExecuteInit(session, output); + WriteRunRequest(session, "CREATE (n) RETURN n"); + session.Execute(); + + if (j == 1) output.clear(); + + session.socket_.SetWriteSuccess(j == 0); + ExecuteCommand(session, dataset[i], 2); + + if (j == 0) { + ASSERT_EQ(session.state_, StateT::Executor); + ASSERT_TRUE(session.socket_.IsOpen()); + ASSERT_FALSE(session.encoder_buffer_.HasData()); + PrintOutput(output); + } else { + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + ASSERT_EQ(output.size(), 0); + } + } + } +} + +TEST(BoltSession, ExecuteInvalidMessage) { + INIT_VARS; + + ExecuteHandshake(session, output); + ExecuteInit(session, output); + ExecuteCommand(session, init_req, sizeof(init_req)); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + CheckFailureMessage(output); +} + +TEST(BoltSession, ErrorIgnoreMessage) { + // first test with socket write success, then with socket write fail + for (int i = 0; i < 2; ++i) { + INIT_VARS; + + ExecuteHandshake(session, output); + ExecuteInit(session, output); + + WriteRunRequest(session, "MATCH (omnom"); + session.Execute(); + + output.clear(); + + session.socket_.SetWriteSuccess(i == 0); + ExecuteCommand(session, init_req, sizeof(init_req)); + + // assert that all data from the init message was cleaned up + ASSERT_EQ(session.decoder_buffer_.Size(), 0); + + if (i == 0) { + ASSERT_EQ(session.state_, StateT::Error); + ASSERT_TRUE(session.socket_.IsOpen()); + CheckOutput(output, ignored_resp, sizeof(ignored_resp)); + } else { + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + ASSERT_EQ(output.size(), 0); + } + } +} + +TEST(BoltSession, ErrorCantCleanup) { + INIT_VARS; + + ExecuteHandshake(session, output); + ExecuteInit(session, output); + + WriteRunRequest(session, "MATCH (omnom"); + session.Execute(); + + output.clear(); + + // there is data missing in the request, cleanup should fail + ExecuteCommand(session, init_req, sizeof(init_req) - 10); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + CheckFailureMessage(output); +} + +TEST(BoltSession, ErrorWrongMarker) { + INIT_VARS; + + ExecuteHandshake(session, output); + ExecuteInit(session, output); + + WriteRunRequest(session, "MATCH (omnom"); + session.Execute(); + + output.clear(); + + // wrong marker, good signature + uint8_t data[2] = {0x00, init_req[1]}; + ExecuteCommand(session, data, sizeof(data)); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + CheckFailureMessage(output); +} + +TEST(BoltSession, ErrorOK) { + // test ACK_FAILURE and RESET + const uint8_t *dataset[] = {ackfailure_req, reset_req}; + + for (int i = 0; i < 2; ++i) { + // first test with socket write success, then with socket write fail + for (int j = 0; j < 2; ++j) { + INIT_VARS; + + ExecuteHandshake(session, output); + ExecuteInit(session, output); + + WriteRunRequest(session, "MATCH (omnom"); + session.Execute(); + + output.clear(); + + session.socket_.SetWriteSuccess(j == 0); + ExecuteCommand(session, dataset[i], 2); + + // assert that all data from the init message was cleaned up + ASSERT_EQ(session.decoder_buffer_.Size(), 0); + + if (j == 0) { + ASSERT_EQ(session.state_, StateT::Executor); + ASSERT_TRUE(session.socket_.IsOpen()); + CheckOutput(output, success_resp, sizeof(success_resp)); + } else { + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + ASSERT_EQ(output.size(), 0); + } + } + } +} + +TEST(BoltSession, ErrorMissingData) { + INIT_VARS; + + ExecuteHandshake(session, output); + ExecuteInit(session, output); + + WriteRunRequest(session, "MATCH (omnom"); + session.Execute(); + + output.clear(); + + // some marker, missing signature + uint8_t data[1] = {0x00}; + ExecuteCommand(session, data, sizeof(data)); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + CheckFailureMessage(output); +} + +TEST(BoltSession, MultipleChunksInOneExecute) { + INIT_VARS; + + ExecuteHandshake(session, output); + ExecuteInit(session, output); + + WriteRunRequest(session, "CREATE (n) RETURN n"); + ExecuteCommand(session, pullall_req, sizeof(pullall_req)); + + ASSERT_EQ(session.state_, StateT::Executor); + ASSERT_TRUE(session.socket_.IsOpen()); + PrintOutput(output); + + // Count chunks in output + int len, num = 0; + while(output.size() > 0) { + len = (output[0] << 8) + output[1]; + output.erase(output.begin(), output.begin() + len + 4); + ++num; + } + + // there should be 3 chunks in the output + // the first is a success with the query headers + // the second is a record message + // and the last is a success message with query run metadata + ASSERT_EQ(num, 3); +} + +TEST(BoltSession, PartialChunk) { + INIT_VARS; + ExecuteHandshake(session, output); + ExecuteInit(session, output); + + WriteChunkHeader(session, sizeof(discardall_req)); + auto buff = session.Allocate(); + memcpy(buff.data, discardall_req, sizeof(discardall_req)); + session.Written(2); + + // missing chunk tail + session.Execute(); + + ASSERT_EQ(session.state_, StateT::Executor); + ASSERT_TRUE(session.socket_.IsOpen()); + ASSERT_EQ(output.size(), 0); + + WriteChunkTail(session); + + session.Execute(); + + ASSERT_EQ(session.state_, StateT::Executor); + ASSERT_TRUE(session.socket_.IsOpen()); + ASSERT_GT(output.size(), 0); + PrintOutput(output); +} + +TEST(BoltSession, InvalidChunk) { + INIT_VARS; + ExecuteHandshake(session, output); + ExecuteInit(session, output); + + // this will write 0x00 0x02 0x00 0x02 0x00 0x02 + // that is a chunk of good size, but it's invalid because the last + // two bytes are 0x00 0x02 and they should be 0x00 0x00 + for (int i = 0; i < 3; ++i) WriteChunkHeader(session, 2); + session.Execute(); + + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); + CheckFailureMessage(output); +} + + int main(int argc, char **argv) { logging::init_sync(); logging::log->pipe(std::make_unique<Stdout>()); + // Set the interpret to true to avoid calling the compiler which only + // supports a limited set of queries. + CONFIG(config::INTERPRET) = "true"; ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); }