Refactored bolt session to use new decoder.

Summary:
Bolt buffer is now a template.

Communication worker now has a new interface.

Fixed network tests to use new interface.

Fixed bolt tests to use new interface.

Added more functions to bolt decoder.

Reviewers: dgleich, buda

Reviewed By: buda

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D256
This commit is contained in:
Matej Ferencevic 2017-04-15 15:14:12 +02:00
parent db30e72069
commit f05fcd91c3
39 changed files with 1260 additions and 1112 deletions

View File

@ -328,8 +328,6 @@ set(memgraph_src_files
${src_dir}/utils/string/join.cpp
${src_dir}/utils/string/file.cpp
${src_dir}/utils/numerics/saturate.cpp
${src_dir}/communication/bolt/v1/transport/bolt_decoder.cpp
${src_dir}/communication/bolt/v1/transport/buffer.cpp
${src_dir}/io/network/addrinfo.cpp
${src_dir}/io/network/network_endpoint.cpp
${src_dir}/io/network/socket.cpp

View File

@ -1,11 +0,0 @@
#pragma once
#include "utils/exceptions/basic_exception.hpp"
namespace communication::bolt {
class BoltException : public BasicException {
public:
using BasicException::BasicException;
};
}

View File

@ -1,6 +1,7 @@
#pragma once
#include <cstdint>
#include "utils/underlying_cast.hpp"
namespace communication::bolt {
@ -30,6 +31,15 @@ enum class Marker : uint8_t {
TinyMap = 0xA0,
TinyStruct = 0xB0,
// TinyStructX represents the value of TinyStruct + X
// This is defined to make decoding easier. To check if a marker is equal
// to TinyStruct + 1 you should use something like:
// underyling_cast(marker) == underyling_cast(Marker::TinyStruct) + 1
// This way you can just use:
// marker == Marker::TinyStruct1
TinyStruct1 = 0xB1,
TinyStruct2 = 0xB2,
Null = 0xC0,
Float64 = 0xC1,
@ -58,9 +68,12 @@ enum class Marker : uint8_t {
};
static constexpr uint8_t MarkerString = 0, MarkerList = 1, MarkerMap = 2;
static constexpr Marker MarkerTiny[3] = {Marker::TinyString, Marker::TinyList, Marker::TinyMap};
static constexpr Marker Marker8[3] = {Marker::String8, Marker::List8, Marker::Map8};
static constexpr Marker Marker16[3] = {Marker::String16, Marker::List16, Marker::Map16};
static constexpr Marker Marker32[3] = {Marker::String32, Marker::List32, Marker::Map32};
static constexpr Marker MarkerTiny[3] = {Marker::TinyString, Marker::TinyList,
Marker::TinyMap};
static constexpr Marker Marker8[3] = {Marker::String8, Marker::List8,
Marker::Map8};
static constexpr Marker Marker16[3] = {Marker::String16, Marker::List16,
Marker::Map16};
static constexpr Marker Marker32[3] = {Marker::String32, Marker::List32,
Marker::Map32};
}

View File

@ -10,4 +10,9 @@ static constexpr size_t MAX_CHUNK_SIZE = 65535;
static constexpr size_t CHUNK_END_MARKER_SIZE = 2;
static constexpr size_t WHOLE_CHUNK_SIZE =
CHUNK_HEADER_SIZE + MAX_CHUNK_SIZE + CHUNK_END_MARKER_SIZE;
/**
* Handshake size defined in the Bolt protocol.
*/
static constexpr size_t HANDSHAKE_SIZE = 20;
}

View File

@ -21,7 +21,11 @@ namespace communication::bolt {
* Allocating, writing and written stores data in the buffer. The stored
* data can then be read using the pointer returned with the data function.
* The current implementation stores data in a single fixed length buffer.
*
* @tparam Size the size of the internal byte array, defaults to the maximum
* size of a chunk in the Bolt protocol
*/
template <size_t Size = WHOLE_CHUNK_SIZE>
class Buffer : public Loggable {
private:
using StreamBufferT = io::network::StreamBuffer;
@ -36,7 +40,7 @@ class Buffer : public Loggable {
* available memory.
*/
StreamBufferT Allocate() {
return StreamBufferT{&data_[size_], WHOLE_CHUNK_SIZE - size_};
return StreamBufferT{&data_[size_], Size - size_};
}
/**
@ -51,7 +55,7 @@ class Buffer : public Loggable {
*/
void Written(size_t len) {
size_ += len;
debug_assert(size_ <= WHOLE_CHUNK_SIZE, "Written more than storage has space!");
debug_assert(size_ <= Size, "Written more than storage has space!");
}
/**
@ -70,9 +74,7 @@ class Buffer : public Loggable {
/**
* This method clears the buffer.
*/
void Clear() {
size_ = 0;
}
void Clear() { size_ = 0; }
/**
* This function returns a pointer to the internal buffer. It is used for
@ -86,7 +88,7 @@ class Buffer : public Loggable {
size_t size() { return size_; }
private:
uint8_t data_[WHOLE_CHUNK_SIZE];
uint8_t data_[Size];
size_t size_{0};
};
}

View File

@ -12,6 +12,22 @@
namespace communication::bolt {
/**
* This class is used as the return value of the GetChunk function of the
* ChunkedDecoderBuffer. It represents the 3 situations that can happen when
* reading a chunk.
*/
enum class ChunkState : uint8_t {
// The chunk isn't complete, we have to read more data
Partial,
// The chunk is invalid, it's tail isn't 0x00 0x00
Invalid,
// The chunk is whole and correct and has been loaded into the buffer
Whole
};
/**
* @brief ChunkedDecoderBuffer
*
@ -27,7 +43,8 @@ class ChunkedDecoderBuffer : public Loggable {
using StreamBufferT = io::network::StreamBuffer;
public:
ChunkedDecoderBuffer(Buffer &buffer) : Loggable("ChunkedDecoderBuffer"), buffer_(buffer) {}
ChunkedDecoderBuffer(Buffer<> &buffer)
: Loggable("ChunkedDecoderBuffer"), buffer_(buffer) {}
/**
* Reads data from the internal buffer.
@ -55,27 +72,27 @@ class ChunkedDecoderBuffer : public Loggable {
* @returns true if a chunk was successfully copied into the internal
* buffer, false otherwise
*/
bool GetChunk() {
ChunkState GetChunk() {
uint8_t *data = buffer_.data();
size_t size = buffer_.size();
if (size < 2) {
logger.trace("Size < 2");
return false;
return ChunkState::Partial;
}
size_t chunk_size = data[0];
chunk_size <<= 8;
chunk_size += data[1];
if (size < chunk_size + 4) {
logger.trace("Chunk size is {} but only have {} data bytes.", chunk_size, size);
return false;
logger.trace("Chunk size is {} but only have {} data bytes.", chunk_size,
size);
return ChunkState::Partial;
}
if (data[chunk_size + 2] != 0 || data[chunk_size + 3] != 0) {
logger.trace("Invalid chunk!");
buffer_.Shift(chunk_size + 4);
// TODO: raise an exception!
return false;
return ChunkState::Invalid;
}
pos_ = 0;
@ -83,11 +100,18 @@ class ChunkedDecoderBuffer : public Loggable {
memcpy(data_, data + 2, size - 4);
buffer_.Shift(chunk_size + 4);
return true;
return ChunkState::Whole;
}
/**
* Gets the size of currently available data in the loaded chunk.
*
* @returns size of available data
*/
size_t Size() { return size_ - pos_; }
private:
Buffer &buffer_;
Buffer<> &buffer_;
uint8_t data_[MAX_CHUNK_SIZE];
size_t size_{0};
size_t pos_{0};

View File

@ -46,8 +46,7 @@ template <typename Buffer>
class Decoder : public Loggable {
public:
Decoder(Buffer &buffer)
: Loggable("communication::bolt::Decoder"),
buffer_(buffer) {}
: Loggable("communication::bolt::Decoder"), buffer_(buffer) {}
/**
* Reads a TypedValue from the available data in the buffer.
@ -136,6 +135,31 @@ class Decoder : public Loggable {
return true;
}
/**
* Reads a Message header from the available data in the buffer.
*
* @param signature pointer to a Signature where the signature should be
* stored
* @param marker pointer to a Signature where the marker should be stored
* @returns true if data has been written into the data pointers,
* false otherwise
*/
bool ReadMessageHeader(Signature *signature, Marker *marker) {
uint8_t values[2];
logger.trace("[ReadMessageHeader] Start");
if (!buffer_.Read(values, 2)) {
logger.debug("[ReadMessageHeader] Marker data missing!");
return false;
}
*marker = (Marker)values[0];
*signature = (Signature)values[1];
logger.trace("[ReadMessageHeader] Success");
return true;
}
/**
* Reads a Vertex from the available data in the buffer.
* This function tries to read a Vertex from the available data.
@ -306,7 +330,7 @@ class Decoder : public Loggable {
logger.trace("[ReadInt] Found an Int8");
int8_t tmp;
if (!buffer_.Read(reinterpret_cast<uint8_t *>(&tmp), sizeof(tmp))) {
logger.debug( "[ReadInt] Int8 missing data!");
logger.debug("[ReadInt] Int8 missing data!");
return false;
}
ret = tmp;
@ -314,7 +338,7 @@ class Decoder : public Loggable {
logger.trace("[ReadInt] Found an Int16");
int16_t tmp;
if (!buffer_.Read(reinterpret_cast<uint8_t *>(&tmp), sizeof(tmp))) {
logger.debug( "[ReadInt] Int16 missing data!");
logger.debug("[ReadInt] Int16 missing data!");
return false;
}
ret = bswap(tmp);
@ -322,19 +346,20 @@ class Decoder : public Loggable {
logger.trace("[ReadInt] Found an Int32");
int32_t tmp;
if (!buffer_.Read(reinterpret_cast<uint8_t *>(&tmp), sizeof(tmp))) {
logger.debug( "[ReadInt] Int32 missing data!");
logger.debug("[ReadInt] Int32 missing data!");
return false;
}
ret = bswap(tmp);
} else if (marker == Marker::Int64) {
logger.trace("[ReadInt] Found an Int64");
if (!buffer_.Read(reinterpret_cast<uint8_t *>(&ret), sizeof(ret))) {
logger.debug( "[ReadInt] Int64 missing data!");
logger.debug("[ReadInt] Int64 missing data!");
return false;
}
ret = bswap(ret);
} else {
logger.debug("[ReadInt] Received invalid marker ({})!", underlying_cast(marker));
logger.debug("[ReadInt] Received invalid marker ({})!",
underlying_cast(marker));
return false;
}
if (success) {
@ -350,7 +375,7 @@ class Decoder : public Loggable {
logger.trace("[ReadDouble] Start");
debug_assert(marker == Marker::Float64, "Received invalid marker!");
if (!buffer_.Read(reinterpret_cast<uint8_t *>(&value), sizeof(value))) {
logger.debug( "[ReadDouble] Missing data!");
logger.debug("[ReadDouble] Missing data!");
return false;
}
value = bswap(value);
@ -369,7 +394,7 @@ class Decoder : public Loggable {
logger.trace("[ReadTypeSize] Found a Type8");
uint8_t tmp;
if (!buffer_.Read(reinterpret_cast<uint8_t *>(&tmp), sizeof(tmp))) {
logger.debug( "[ReadTypeSize] Type8 missing data!");
logger.debug("[ReadTypeSize] Type8 missing data!");
return -1;
}
return tmp;
@ -377,7 +402,7 @@ class Decoder : public Loggable {
logger.trace("[ReadTypeSize] Found a Type16");
uint16_t tmp;
if (!buffer_.Read(reinterpret_cast<uint8_t *>(&tmp), sizeof(tmp))) {
logger.debug( "[ReadTypeSize] Type16 missing data!");
logger.debug("[ReadTypeSize] Type16 missing data!");
return -1;
}
tmp = bswap(tmp);
@ -386,13 +411,14 @@ class Decoder : public Loggable {
logger.trace("[ReadTypeSize] Found a Type32");
uint32_t tmp;
if (!buffer_.Read(reinterpret_cast<uint8_t *>(&tmp), sizeof(tmp))) {
logger.debug( "[ReadTypeSize] Type32 missing data!");
logger.debug("[ReadTypeSize] Type32 missing data!");
return -1;
}
tmp = bswap(tmp);
return tmp;
} else {
logger.debug("[ReadTypeSize] Received invalid marker ({})!", underlying_cast(marker));
logger.debug("[ReadTypeSize] Received invalid marker ({})!",
underlying_cast(marker));
return -1;
}
}
@ -409,7 +435,8 @@ class Decoder : public Loggable {
logger.debug("[ReadString] Missing data!");
return false;
}
*data = query::TypedValue(std::string(reinterpret_cast<char *>(ret.get()), size));
*data = query::TypedValue(
std::string(reinterpret_cast<char *>(ret.get()), size));
logger.trace("[ReadString] Success");
return true;
}
@ -462,7 +489,8 @@ class Decoder : public Loggable {
ret.insert(std::make_pair(str, tv));
}
if (ret.size() != size) {
logger.debug("[ReadMap] The client sent multiple objects with same indexes!");
logger.debug(
"[ReadMap] The client sent multiple objects with same indexes!");
return false;
}

View File

@ -6,7 +6,6 @@
#include "logging/logger.hpp"
#include "query/typed_value.hpp"
#include "utils/bswap.hpp"
#include "utils/underlying_cast.hpp"
#include <string>

View File

@ -5,7 +5,6 @@
#include <memory>
#include <vector>
#include "communication/bolt/v1/bolt_exception.hpp"
#include "communication/bolt/v1/constants.hpp"
#include "logging/loggable.hpp"
#include "utils/bswap.hpp"
@ -39,7 +38,8 @@ namespace communication::bolt {
template <class Socket>
class ChunkedEncoderBuffer : public Loggable {
public:
ChunkedEncoderBuffer(Socket &socket) : Loggable("Chunked Encoder Buffer"), socket_(socket) {}
ChunkedEncoderBuffer(Socket &socket)
: Loggable("Chunked Encoder Buffer"), socket_(socket) {}
/**
* Writes n values into the buffer. If n is bigger than whole chunk size
@ -51,7 +51,7 @@ class ChunkedEncoderBuffer : public Loggable {
void Write(const uint8_t *values, size_t n) {
while (n > 0) {
// Define number of bytes which will be copied into chunk because
// chunk is a fixed lenght array.
// chunk is a fixed length array.
auto size = n < MAX_CHUNK_SIZE + CHUNK_HEADER_SIZE - pos_
? n
: MAX_CHUNK_SIZE + CHUNK_HEADER_SIZE - pos_;
@ -90,13 +90,16 @@ class ChunkedEncoderBuffer : public Loggable {
debug_assert(pos_ <= WHOLE_CHUNK_SIZE,
"Internal variable pos_ is bigger than the whole chunk size.");
// 3. Copy whole chunk into the buffer.
// 3. Remember first chunk size.
if (first_chunk_size_ == -1) first_chunk_size_ = pos_;
// 4. Copy whole chunk into the buffer.
size_ += pos_;
buffer_.reserve(size_);
std::copy(chunk_.begin(), chunk_.begin() + pos_,
std::back_inserter(buffer_));
// 4. Cleanup.
// 5. Cleanup.
// * pos_ has to be reset to the size of chunk header (reserved
// space for the chunk size)
pos_ = CHUNK_HEADER_SIZE;
@ -104,24 +107,72 @@ class ChunkedEncoderBuffer : public Loggable {
/**
* Sends the whole buffer(message) to the client.
* @returns true if the data was successfully sent to the client
* false otherwise
*/
void Flush() {
bool Flush() {
// Call chunk if is hasn't been called.
if (pos_ > CHUNK_HEADER_SIZE) Chunk();
// Early return if buffer is empty because there is nothing to write.
if (size_ == 0) return;
if (size_ == 0) return true;
// Flush the whole buffer.
bool written = socket_.Write(buffer_.data(), size_);
if (!written) throw BoltException("Socket write failed!");
if (!socket_.Write(buffer_.data() + offset_, size_ - offset_)) return false;
logger.trace("Flushed {} bytes.", size_);
// Cleanup.
Clear();
return true;
}
/**
* Sends only the first message chunk in the buffer to the client.
* @returns true if the data was successfully sent to the client
* false otherwise
*/
bool FlushFirstChunk() {
// Call chunk if is hasn't been called.
if (pos_ > CHUNK_HEADER_SIZE) Chunk();
// Early return if buffer is empty because there is nothing to write.
if (size_ == 0) return false;
// Early return if there is no first chunk
if (first_chunk_size_ == -1) return false;
// Flush the first chunk
if (!socket_.Write(buffer_.data(), first_chunk_size_)) return false;
logger.trace("Flushed {} bytes.", first_chunk_size_);
// Cleanup.
// Here we use offset as a method of deleting from the front of the
// data vector. Because the first chunk will always be relatively
// small comparing to the rest of the data it is more optimal just to
// skip the first part of the data than to shift everything in the
// vector buffer.
offset_ = first_chunk_size_;
first_chunk_size_ = -1;
return true;
}
/**
* Clears the internal buffers.
*/
void Clear() {
buffer_.clear();
size_ = 0;
first_chunk_size_ = -1;
offset_ = 0;
}
/**
* Returns a boolean indicating whether there is data in the buffer.
* @returns true if there is data in the buffer,
* false otherwise
*/
bool HasData() { return buffer_.size() > 0 || size_ > 0; }
private:
/**
* A client socket.
@ -143,6 +194,16 @@ class ChunkedEncoderBuffer : public Loggable {
*/
size_t size_{0};
/**
* Size of first chunk in the buffer.
*/
int32_t first_chunk_size_{-1};
/**
* Offset from the start of the buffer.
*/
size_t offset_{0};
/**
* Current position in chunk array.
*/

View File

@ -1,5 +1,6 @@
#pragma once
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/encoder/base_encoder.hpp"
namespace communication::bolt {
@ -21,8 +22,7 @@ class Encoder : private BaseEncoder<Buffer> {
using BaseEncoder<Buffer>::buffer_;
public:
Encoder(Buffer &buffer)
: BaseEncoder<Buffer>(buffer) {
Encoder(Buffer &buffer) : BaseEncoder<Buffer>(buffer) {
logger = logging::log->logger("communication::bolt::Encoder");
}
@ -40,8 +40,8 @@ class Encoder : private BaseEncoder<Buffer> {
* @param values the fields list object that should be sent
*/
void MessageRecord(const std::vector<query::TypedValue> &values) {
// 0xB1 = struct 1; 0x71 = record signature
WriteRAW("\xB1\x71", 2);
WriteRAW(underlying_cast(Marker::TinyStruct1));
WriteRAW(underlying_cast(Signature::Record));
WriteList(values);
buffer_.Chunk();
}
@ -56,26 +56,34 @@ class Encoder : private BaseEncoder<Buffer> {
*
* @param metadata the metadata map object that should be sent
* @param flush should method flush the socket
* @returns true if the data was successfully sent to the client
* when flushing, false otherwise
*/
void MessageSuccess(const std::map<std::string, query::TypedValue> &metadata,
bool MessageSuccess(const std::map<std::string, query::TypedValue> &metadata,
bool flush = true) {
// 0xB1 = struct 1; 0x70 = success signature
WriteRAW("\xB1\x70", 2);
WriteRAW(underlying_cast(Marker::TinyStruct1));
WriteRAW(underlying_cast(Signature::Success));
WriteMap(metadata);
if (flush)
buffer_.Flush();
else
if (flush) {
return buffer_.Flush();
} else {
buffer_.Chunk();
// Chunk always succeeds, so return true
return true;
}
}
/**
* Sends a Success message.
*
* This function sends a success message without additional metadata.
*
* @returns true if the data was successfully sent to the client,
* false otherwise
*/
void MessageSuccess() {
bool MessageSuccess() {
std::map<std::string, query::TypedValue> metadata;
MessageSuccess(metadata);
return MessageSuccess(metadata);
}
/**
@ -87,12 +95,15 @@ class Encoder : private BaseEncoder<Buffer> {
* }
*
* @param metadata the metadata map object that should be sent
* @returns true if the data was successfully sent to the client,
* false otherwise
*/
void MessageFailure(const std::map<std::string, query::TypedValue> &metadata) {
// 0xB1 = struct 1; 0x7F = failure signature
WriteRAW("\xB1\x7F", 2);
bool MessageFailure(
const std::map<std::string, query::TypedValue> &metadata) {
WriteRAW(underlying_cast(Marker::TinyStruct1));
WriteRAW(underlying_cast(Signature::Failure));
WriteMap(metadata);
buffer_.Flush();
return buffer_.Flush();
}
/**
@ -104,23 +115,29 @@ class Encoder : private BaseEncoder<Buffer> {
* }
*
* @param metadata the metadata map object that should be sent
* @returns true if the data was successfully sent to the client,
* false otherwise
*/
void MessageIgnored(const std::map<std::string, query::TypedValue> &metadata) {
// 0xB1 = struct 1; 0x7E = ignored signature
WriteRAW("\xB1\x7E", 2);
bool MessageIgnored(
const std::map<std::string, query::TypedValue> &metadata) {
WriteRAW(underlying_cast(Marker::TinyStruct1));
WriteRAW(underlying_cast(Signature::Ignored));
WriteMap(metadata);
buffer_.Flush();
return buffer_.Flush();
}
/**
* Sends an Ignored message.
*
* This function sends an ignored message without additional metadata.
*
* @returns true if the data was successfully sent to the client,
* false otherwise
*/
void MessageIgnored() {
// 0xB0 = struct 0; 0x7E = ignored signature
WriteRAW("\xB0\x7E", 2);
buffer_.Flush();
bool MessageIgnored() {
WriteRAW(underlying_cast(Marker::TinyStruct));
WriteRAW(underlying_cast(Signature::Ignored));
return buffer_.Flush();
}
};
}

View File

@ -31,8 +31,10 @@ class ResultStream {
std::map<std::string, query::TypedValue> data;
for (auto &i : fields) vec.push_back(query::TypedValue(i));
data.insert(std::make_pair(std::string("fields"), query::TypedValue(vec)));
// this call will automaticaly send the data to the client
encoder_.MessageSuccess(data);
// this message shouldn't send directly to the client because if an error
// happened the client will receive two messages (success and failure)
// instead of only one
encoder_.MessageSuccess(data, false);
}
/**

View File

@ -1,38 +0,0 @@
#pragma once
#include "utils/types/byte.hpp"
#include "utils/underlying_cast.hpp"
namespace communication::bolt {
enum class MessageCode : byte {
Init = 0x01,
AckFailure = 0x0E,
Reset = 0x0F,
Run = 0x10,
DiscardAll = 0x2F,
PullAll = 0x3F,
Record = 0x71,
Success = 0x70,
Ignored = 0x7E,
Failure = 0x7F
};
inline bool operator==(byte value, MessageCode code) {
return value == underlying_cast(code);
}
inline bool operator==(MessageCode code, byte value) {
return operator==(value, code);
}
inline bool operator!=(byte value, MessageCode code) {
return !operator==(value, code);
}
inline bool operator!=(MessageCode code, byte value) {
return operator!=(value, code);
}
}

View File

@ -1,55 +0,0 @@
#pragma once
#include <cstdint>
namespace communication::bolt::pack {
enum Code : uint8_t {
TinyString = 0x80,
TinyList = 0x90,
TinyMap = 0xA0,
TinyStruct = 0xB0,
StructOne = 0xB1,
StructTwo = 0xB2,
Null = 0xC0,
Float64 = 0xC1,
False = 0xC2,
True = 0xC3,
Int8 = 0xC8,
Int16 = 0xC9,
Int32 = 0xCA,
Int64 = 0xCB,
Bytes8 = 0xCC,
Bytes16 = 0xCD,
Bytes32 = 0xCE,
String8 = 0xD0,
String16 = 0xD1,
String32 = 0xD2,
List8 = 0xD4,
List16 = 0xD5,
List32 = 0xD6,
Map8 = 0xD8,
Map16 = 0xD9,
Map32 = 0xDA,
MapStream = 0xDB,
Node = 0x4E,
Relationship = 0x52,
Path = 0x50,
Struct8 = 0xDC,
Struct16 = 0xDD,
EndOfStream = 0xDF,
};
enum Rule : uint8_t { MaxInitStructSize = 0x02 };
}

View File

@ -1,39 +0,0 @@
#pragma once
namespace communication::bolt {
enum class PackType {
/** denotes absence of a value */
Null,
/** denotes a type with two possible values (t/f) */
Boolean,
/** 64-bit signed integral number */
Integer,
/** 64-bit floating point number */
Float,
/** binary data */
Bytes,
/** unicode string */
String,
/** collection of values */
List,
/** collection of zero or more key/value pairs */
Map,
/** zero or more packstream values */
Struct,
/** denotes stream value end */
EndOfStream,
/** reserved for future use */
Reserved
};
}

View File

@ -6,15 +6,19 @@
#include "dbms/dbms.hpp"
#include "query/engine.hpp"
#include "communication/bolt/v1/constants.hpp"
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/states/error.hpp"
#include "communication/bolt/v1/states/executor.hpp"
#include "communication/bolt/v1/states/handshake.hpp"
#include "communication/bolt/v1/states/init.hpp"
#include "communication/bolt/v1/decoder/chunked_decoder_buffer.hpp"
#include "communication/bolt/v1/decoder/decoder.hpp"
#include "communication/bolt/v1/encoder/encoder.hpp"
#include "communication/bolt/v1/encoder/result_stream.hpp"
#include "communication/bolt/v1/transport/bolt_decoder.hpp"
#include "io/network/stream_buffer.hpp"
#include "logging/loggable.hpp"
@ -29,10 +33,11 @@ namespace communication::bolt {
*/
template <typename Socket>
class Session : public Loggable {
public:
using Decoder = BoltDecoder;
private:
using OutputStream = ResultStream<Encoder<ChunkedEncoderBuffer<Socket>>>;
using StreamBuffer = io::network::StreamBuffer;
public:
Session(Socket &&socket, Dbms &dbms, QueryEngine<OutputStream> &query_engine)
: Loggable("communication::bolt::Session"),
socket_(std::move(socket)),
@ -40,16 +45,17 @@ class Session : public Loggable {
query_engine_(query_engine),
encoder_buffer_(socket_),
encoder_(encoder_buffer_),
output_stream_(encoder_) {
output_stream_(encoder_),
decoder_buffer_(buffer_),
decoder_(decoder_buffer_),
state_(State::Handshake) {
event_.data.ptr = this;
// start with a handshake state
state_ = HANDSHAKE;
}
/**
* @return is the session in a valid state
*/
bool Alive() const { return state_ != NULLSTATE; }
bool Alive() const { return state_ != State::Close; }
/**
* @return the socket id
@ -57,48 +63,83 @@ class Session : public Loggable {
int Id() const { return socket_.id(); }
/**
* Reads the data from a client and goes through the bolt states in
* order to execute command from the client.
*
* @param data pointer on bytes received from a client
* @param len length of data received from a client
* Executes the session after data has been read into the buffer.
* Goes through the bolt states in order to execute commands from the client.
*/
void Execute(const uint8_t *data, size_t len) {
// mark the end of the message
auto end = data + len;
while (true) {
auto size = end - data;
void Execute() {
// while there is data in the buffers
while (buffer_.size() > 0 || decoder_buffer_.Size() > 0) {
if (LIKELY(connected_)) {
logger.debug("Decoding chunk of size {}", size);
if (!decoder_.decode(data, size)) return;
logger.debug("Decoding chunk of size {}", buffer_.size());
auto chunk_state = decoder_buffer_.GetChunk();
if (chunk_state == ChunkState::Partial) {
logger.trace("Chunk isn't complete!");
return;
} else if (chunk_state == ChunkState::Invalid) {
logger.trace("Chunk is invalid!");
ClientFailureInvalidData();
return;
}
// if chunk_state == ChunkState::Whole then we continue with
// execution of the select below
} else if (buffer_.size() < HANDSHAKE_SIZE) {
logger.debug("Received partial handshake of size {}", buffer_.size());
return;
} else if (buffer_.size() > HANDSHAKE_SIZE) {
logger.debug("Received too large handshake of size {}", buffer_.size());
ClientFailureInvalidData();
return;
} else {
logger.debug("Decoding handshake of size {}", size);
decoder_.handshake(data, size);
logger.debug("Decoding handshake of size {}", buffer_.size());
}
switch (state_) {
case HANDSHAKE:
case State::Handshake:
state_ = StateHandshakeRun<Session<Socket>>(*this);
break;
case INIT:
case State::Init:
state_ = StateInitRun<Session<Socket>>(*this);
break;
case EXECUTOR:
case State::Executor:
state_ = StateExecutorRun<Session<Socket>>(*this);
break;
case ERROR:
case State::Error:
state_ = StateErrorRun<Session<Socket>>(*this);
break;
case NULLSTATE:
case State::Close:
// This state is handled below
break;
}
decoder_.reset();
// State::Close is handled here because we always want to check for
// it after the above select. If any of the states above return a
// State::Close then the connection should be terminated immediately.
if (state_ == State::Close) {
ClientFailureInvalidData();
return;
}
logger.trace("Buffer size: {}", buffer_.size());
logger.trace("Decoder buffer size: {}", decoder_buffer_.Size());
}
}
/**
* Allocates data from the internal buffer.
* Used in the underlying network stack to asynchronously read data
* from the client.
* @returns a StreamBuffer to the allocated internal data buffer
*/
StreamBuffer Allocate() { return buffer_.Allocate(); }
/**
* Notifies the internal buffer of written data.
* Used in the underlying network stack to notify the internal buffer
* how many bytes of data have been written.
* @param len how many data was written to the buffer
*/
void Written(size_t len) { buffer_.Written(len); }
/**
* Closes the session (client socket).
*/
@ -112,12 +153,29 @@ class Session : public Loggable {
Socket socket_;
Dbms &dbms_;
QueryEngine<OutputStream> &query_engine_;
ChunkedEncoderBuffer<Socket> encoder_buffer_;
Encoder<ChunkedEncoderBuffer<Socket>> encoder_;
OutputStream output_stream_;
Decoder decoder_;
Buffer<> buffer_;
ChunkedDecoderBuffer decoder_buffer_;
Decoder<ChunkedDecoderBuffer> decoder_;
io::network::Epoll::Event event_;
bool connected_{false};
State state_;
private:
void ClientFailureInvalidData() {
// set the state to Close
state_ = State::Close;
// don't care about the return status because this is always
// called when we are about to close the connection to the client
encoder_.MessageFailure({{"code", "Memgraph.InvalidData"},
{"message", "The client has sent invalid data!"}});
// close the connection
Close();
}
};
}

View File

@ -3,15 +3,36 @@
namespace communication::bolt {
/**
* TODO (mferencevic): change to a class enum & document (explain states in
* more details)
* This class represents states in execution of the Bolt protocol.
* It is used only internally in the Session. All functions that run
* these states can be found in the states/ subdirectory.
*/
enum State {
HANDSHAKE,
INIT,
EXECUTOR,
ERROR,
NULLSTATE
};
enum class State : uint8_t {
/**
* This state negotiates a handshake with the client.
*/
Handshake,
/**
* This state initializes the Bolt session.
*/
Init,
/**
* This state executes commands from the Bolt protocol.
*/
Executor,
/**
* This state handles errors.
*/
Error,
/**
* This is a 'virtual' state (it doesn't have a run function) which tells
* the session that the client has sent malformed data and that the
* session should be closed.
*/
Close
};
}

View File

@ -1,38 +1,72 @@
#pragma once
#include "communication/bolt/v1/messaging/codes.hpp"
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/state.hpp"
#include "logging/default.hpp"
namespace communication::bolt {
/**
* TODO (mferencevic): finish & document
* Error state run function
* This function handles a Bolt session when it is in an error state.
* The error state is exited upon receiving an ACK_FAILURE or RESET message.
* @param session the session that should be used for the run
*/
template <typename Session>
State StateErrorRun(Session &session) {
static Logger logger = logging::log->logger("State ERROR");
session.decoder_.read_byte();
auto message_type = session.decoder_.read_byte();
logger.trace("Message type byte is: {:02X}", message_type);
if (message_type == MessageCode::PullAll) {
session.encoder_.MessageIgnored();
return ERROR;
} else if (message_type == MessageCode::AckFailure) {
// TODO reset current statement? is it even necessary?
logger.trace("AckFailure received");
session.encoder_.MessageSuccess();
return EXECUTOR;
} else if (message_type == MessageCode::Reset) {
// TODO rollback current transaction
// discard all records waiting to be sent
session.encoder_.MessageSuccess();
return EXECUTOR;
Marker marker;
Signature signature;
if (!session.decoder_.ReadMessageHeader(&signature, &marker)) {
logger.debug("Missing header data!");
return State::Close;
}
logger.trace("Message signature is: 0x{:02X}", underlying_cast(signature));
// clear the data buffer if it has any leftover data
session.encoder_buffer_.Clear();
if (signature == Signature::AckFailure || signature == Signature::Reset) {
if (signature == Signature::AckFailure)
logger.trace("AckFailure received");
else
logger.trace("Reset received");
if (!session.encoder_.MessageSuccess()) {
logger.debug("Couldn't send success message!");
return State::Close;
}
return State::Executor;
} else {
uint8_t value = underlying_cast(marker);
// all bolt client messages have less than 15 parameters
// so if we receive anything than a TinyStruct it's an error
if ((value & 0xF0) != underlying_cast(Marker::TinyStruct)) {
logger.debug("Expected TinyStruct marker, but received 0x{:02X}!", value);
return State::Close;
}
// we need to clean up all parameters from this command
value &= 0x0F; // the length is stored in the lower nibble
query::TypedValue tv;
for (int i = 0; i < value; ++i) {
if (!session.decoder_.ReadTypedValue(&tv)) {
logger.debug("Couldn't clean up parameter {} / {}!", i, value);
return State::Close;
}
}
// ignore this message
if (!session.encoder_.MessageIgnored()) {
logger.debug("Couldn't send ignored message!");
return State::Close;
}
// cleanup done, command ignored, stay in error state
return State::Error;
}
session.encoder_.MessageIgnored();
return ERROR;
}
}

View File

@ -2,131 +2,209 @@
#include <string>
#include "communication/bolt/v1/bolt_exception.hpp"
#include "communication/bolt/v1/messaging/codes.hpp"
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/state.hpp"
#include "logging/default.hpp"
#include "query/exceptions.hpp"
namespace communication::bolt {
struct Query {
Query(std::string &&statement) : statement(statement) {}
std::string statement;
};
template <typename Session>
void StateExecutorFailure(
Session &session, Logger &logger,
const std::map<std::string, query::TypedValue> &metadata) {
try {
session.encoder_.MessageFailure(metadata);
} catch (const BoltException &e) {
logger.debug("MessageFailure failed because: {}", e.what());
session.Close();
}
}
/**
* TODO (mferencevic): finish & document
* Executor state run function
* This function executes an initialized Bolt session.
* It executes: RUN, PULL_ALL, DISCARD_ALL & RESET.
* @param session the session that should be used for the run
*/
template <typename Session>
State StateExecutorRun(Session &session) {
// initialize logger
static Logger logger = logging::log->logger("State EXECUTOR");
// just read one byte that represents the struct type, we can skip the
// information contained in this byte
session.decoder_.read_byte();
auto message_type = session.decoder_.read_byte();
Marker marker;
Signature signature;
if (!session.decoder_.ReadMessageHeader(&signature, &marker)) {
logger.debug("Missing header data!");
return State::Close;
}
if (message_type == MessageCode::Run) {
Query query(session.decoder_.read_string());
if (signature == Signature::Run) {
if (marker != Marker::TinyStruct2) {
logger.debug("Expected TinyStruct2 marker, but received 0x{:02X}!",
underlying_cast(marker));
return State::Close;
}
query::TypedValue query, params;
if (!session.decoder_.ReadTypedValue(&query,
query::TypedValue::Type::String)) {
logger.debug("Couldn't read query string!");
return State::Close;
}
if (!session.decoder_.ReadTypedValue(&params,
query::TypedValue::Type::Map)) {
logger.debug("Couldn't read parameters!");
return State::Close;
}
// TODO (mferencevic): implement proper exception handling
auto db_accessor = session.dbms_.active();
logger.debug("[ActiveDB] '{}'", db_accessor->name());
try {
logger.trace("[Run] '{}'", query.statement);
logger.trace("[Run] '{}'", query.Value<std::string>());
auto is_successfully_executed = session.query_engine_.Run(
query.statement, *db_accessor, session.output_stream_);
query.Value<std::string>(), *db_accessor, session.output_stream_);
if (!is_successfully_executed) {
// abort transaction
db_accessor->abort();
StateExecutorFailure<Session>(
session, logger,
// clear any leftover messages in the buffer
session.encoder_buffer_.Clear();
// send failure message
bool exec_fail_sent = session.encoder_.MessageFailure(
{{"code", "Memgraph.QueryExecutionFail"},
{"message",
"Query execution has failed (probably there is no "
"element or there are some problems with concurrent "
"access -> client has to resolve problems with "
"concurrent access)"}});
return ERROR;
if (!exec_fail_sent) {
logger.debug("Couldn't send failure message!");
return State::Close;
} else {
logger.debug("Query execution failed!");
return State::Error;
}
} else {
db_accessor->commit();
// The query engine has already stored all query data in the buffer.
// We should only send the first chunk now which is the success
// message which contains header data. The rest of this data (records
// and summary) will be sent after a PULL_ALL command from the client.
if (!session.encoder_buffer_.FlushFirstChunk()) {
logger.debug("Couldn't flush header data from the buffer!");
return State::Close;
}
return State::Executor;
}
return EXECUTOR;
// !! QUERY ENGINE -> RUN METHOD -> EXCEPTION HANDLING !!
} catch (const query::SyntaxException &e) {
} catch (const BasicException &e) {
// clear header success message
session.encoder_buffer_.Clear();
db_accessor->abort();
StateExecutorFailure<Session>(
session, logger,
{{"code", "Memgraph.SyntaxException"}, {"message", "Syntax error"}});
return ERROR;
} catch (const query::QueryEngineException &e) {
db_accessor->abort();
StateExecutorFailure<Session>(
session, logger,
{{"code", "Memgraph.QueryEngineException"},
{"message", "Query engine was unable to execute the query"}});
return ERROR;
bool fail_sent = session.encoder_.MessageFailure(
{{"code", "Memgraph.Exception"}, {"message", e.what()}});
logger.debug("Error message: {}", e.what());
if (!fail_sent) {
logger.debug("Couldn't send failure message!");
return State::Close;
}
return State::Error;
} catch (const StacktraceException &e) {
// clear header success message
session.encoder_buffer_.Clear();
db_accessor->abort();
StateExecutorFailure<Session>(session, logger,
{{"code", "Memgraph.StacktraceException"},
{"message", "Unknown exception"}});
return ERROR;
} catch (const BoltException &e) {
db_accessor->abort();
logger.debug("Failed because: {}", e.what());
session.Close();
bool fail_sent = session.encoder_.MessageFailure(
{{"code", "Memgraph.Exception"}, {"message", e.what()}});
logger.debug("Error message: {}", e.what());
logger.debug("Error trace: {}", e.trace());
if (!fail_sent) {
logger.debug("Couldn't send failure message!");
return State::Close;
}
return State::Error;
} catch (std::exception &e) {
// clear header success message
session.encoder_buffer_.Clear();
db_accessor->abort();
StateExecutorFailure<Session>(
session, logger,
{{"code", "Memgraph.Exception"}, {"message", "Unknown exception"}});
return ERROR;
bool fail_sent = session.encoder_.MessageFailure(
{{"code", "Memgraph.Exception"},
{"message",
"An unknown exception occured, please contact your database "
"administrator."}});
logger.debug("Unknown exception!!!");
if (!fail_sent) {
logger.debug("Couldn't send failure message!");
return State::Close;
}
return State::Error;
}
// TODO (mferencevic): finish the error handling, cover all exceptions
// which can be raised from query engine
// * [abort, MessageFailure, return ERROR] should be extracted into
// separate function (or something equivalent)
//
// !! QUERY ENGINE -> RUN METHOD -> EXCEPTION HANDLING !!
} else if (message_type == MessageCode::PullAll) {
} else if (signature == Signature::PullAll) {
logger.trace("[PullAll]");
session.encoder_buffer_.Flush();
} else if (message_type == MessageCode::DiscardAll) {
if (marker != Marker::TinyStruct) {
logger.debug("Expected TinyStruct marker, but received 0x{:02X}!",
underlying_cast(marker));
return State::Close;
}
if (!session.encoder_buffer_.HasData()) {
// the buffer doesn't have data, return a failure message
bool data_fail_sent = session.encoder_.MessageFailure(
{{"code", "Memgraph.Exception"},
{"message",
"There is no data to "
"send, you have to execute a RUN command before a PULL_ALL!"}});
if (!data_fail_sent) {
logger.debug("Couldn't send failure message!");
return State::Close;
}
return State::Error;
}
// flush pending data to the client, the success message is streamed
// from the query engine, it contains statistics from the query run
if (!session.encoder_buffer_.Flush()) {
logger.debug("Couldn't flush data from the buffer!");
return State::Close;
}
return State::Executor;
} else if (signature == Signature::DiscardAll) {
logger.trace("[DiscardAll]");
if (marker != Marker::TinyStruct) {
logger.debug("Expected TinyStruct marker, but received 0x{:02X}!",
underlying_cast(marker));
return State::Close;
}
// clear all pending data and send a success message
session.encoder_buffer_.Clear();
if (!session.encoder_.MessageSuccess()) {
logger.debug("Couldn't send success message!");
return State::Close;
}
return State::Executor;
} else if (signature == Signature::Reset) {
// IMPORTANT: This implementation of the Bolt RESET command isn't fully
// compliant to the protocol definition. In the protocol it is defined
// that this command should immediately stop any running commands and
// reset the session to a clean state. That means that we should always
// make a look-ahead for the RESET command before processing anything.
// Our implementation, for now, does everything in a blocking fashion
// so we cannot simply "kill" a transaction while it is running. So
// now this command only resets the session to a clean state. It
// does not IGNORE running and pending commands as it should.
if (marker != Marker::TinyStruct) {
logger.debug("Expected TinyStruct marker, but received 0x{:02X}!",
underlying_cast(marker));
return State::Close;
}
// clear all pending data and send a success message
session.encoder_buffer_.Clear();
if (!session.encoder_.MessageSuccess()) {
logger.debug("Couldn't send success message!");
return State::Close;
}
return State::Executor;
// TODO: discard state
// TODO: write_success, send
session.encoder_.MessageSuccess();
} else if (message_type == MessageCode::Reset) {
// TODO: rollback current transaction
// discard all records waiting to be sent
session.encoder_.MessageSuccess();
return EXECUTOR;
} else {
logger.error("Unrecognized message recieved");
logger.debug("Invalid message type 0x{:02X}", message_type);
return ERROR;
logger.debug("Unrecognized signature recieved (0x{:02X})!",
underlying_cast(signature));
return State::Close;
}
return EXECUTOR;
}
}

View File

@ -1,31 +1,42 @@
#pragma once
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/transport/bolt_decoder.hpp"
#include "logging/default.hpp"
namespace communication::bolt {
static constexpr uint32_t preamble = 0x6060B017;
static constexpr byte protocol[4] = {0x00, 0x00, 0x00, 0x01};
static constexpr uint8_t preamble[4] = {0x60, 0x60, 0xB0, 0x17};
static constexpr uint8_t protocol[4] = {0x00, 0x00, 0x00, 0x01};
/**
* TODO (mferencevic): finish & document
* Handshake state run function
* This function runs everything to make a Bolt handshake with the client.
* @param session the session that should be used for the run
*/
template <typename Session>
State StateHandshakeRun(Session &session) {
static Logger logger = logging::log->logger("State HANDSHAKE");
if (UNLIKELY(session.decoder_.read_uint32() != preamble)) return NULLSTATE;
auto precmp = memcmp(session.buffer_.data(), preamble, sizeof(preamble));
if (UNLIKELY(precmp != 0)) {
logger.debug("Received a wrong preamble!");
return State::Close;
}
// TODO so far we only support version 1 of the protocol so it doesn't
// make sense to check which version the client prefers
// this will change in the future
if (!session.socket_.Write(protocol, sizeof(protocol))) {
logger.debug("Couldn't write handshake response!");
return State::Close;
}
session.connected_ = true;
// TODO: check for success
session.socket_.Write(protocol, sizeof protocol);
return INIT;
// Delete data from buffer. It is guaranteed that there will be exactly
// 20 bytes in the buffer so we can use buffer_.size() here.
session.buffer_.Shift(session.buffer_.size());
return State::Init;
}
}

View File

@ -1,53 +1,62 @@
#pragma once
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/encoder/result_stream.hpp"
#include "communication/bolt/v1/messaging/codes.hpp"
#include "communication/bolt/v1/packing/codes.hpp"
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/transport/bolt_decoder.hpp"
#include "logging/default.hpp"
#include "utils/likely.hpp"
namespace communication::bolt {
/**
* TODO (mferencevic): finish & document
* Init state run function
* This function runs everything to initialize a Bolt session with the client.
* @param session the session that should be used for the run
*/
template <typename Session>
State StateInitRun(Session &session) {
static Logger logger = logging::log->logger("State INIT");
logger.debug("Parsing message");
auto struct_type = session.decoder_.read_byte();
if (UNLIKELY((struct_type & 0x0F) > pack::Rule::MaxInitStructSize)) {
logger.debug("{}", struct_type);
logger.debug(
"Expected struct marker of max size 0x{:02} instead of 0x{:02X}",
(unsigned)pack::Rule::MaxInitStructSize, (unsigned)struct_type);
return NULLSTATE;
Marker marker;
Signature signature;
if (!session.decoder_.ReadMessageHeader(&signature, &marker)) {
logger.debug("Missing header data!");
return State::Close;
}
auto message_type = session.decoder_.read_byte();
if (UNLIKELY(message_type != MessageCode::Init)) {
logger.debug("Expected Init (0x01) instead of (0x{:02X})",
(unsigned)message_type);
return NULLSTATE;
if (UNLIKELY(signature != Signature::Init)) {
logger.debug("Expected Init signature, but received 0x{:02X}!",
underlying_cast(signature));
return State::Close;
}
if (UNLIKELY(marker != Marker::TinyStruct2)) {
logger.debug("Expected TinyStruct2 marker, but received 0x{:02X}!",
underlying_cast(marker));
return State::Close;
}
auto client_name = session.decoder_.read_string();
if (struct_type == pack::Code::StructTwo) {
// TODO process authentication tokens
query::TypedValue client_name;
if (!session.decoder_.ReadTypedValue(&client_name,
query::TypedValue::Type::String)) {
logger.debug("Couldn't read client name!");
return State::Close;
}
logger.debug("Executing state");
logger.debug("Client connected '{}'", client_name);
query::TypedValue metadata;
if (!session.decoder_.ReadTypedValue(&metadata,
query::TypedValue::Type::Map)) {
logger.debug("Couldn't read metadata!");
return State::Close;
}
// TODO: write_success, chunk, send
session.encoder_.MessageSuccess();
logger.debug("Client connected '{}'", client_name.Value<std::string>());
return EXECUTOR;
if (!session.encoder_.MessageSuccess()) {
logger.debug("Couldn't send success message to the client!");
return State::Close;
}
return State::Executor;
}
}

View File

@ -1,108 +0,0 @@
#include "communication/bolt/v1/transport/bolt_decoder.hpp"
#include "communication/bolt/v1/packing/codes.hpp"
#include "logging/default.hpp"
#include "utils/bswap.hpp"
namespace communication::bolt {
void BoltDecoder::handshake(const byte *&data, size_t len) {
buffer.write(data, len);
data += len;
}
bool BoltDecoder::decode(const byte *&data, size_t len) {
return decoder(data, len);
}
bool BoltDecoder::empty() const { return pos == buffer.size(); }
void BoltDecoder::reset() {
buffer.clear();
pos = 0;
}
byte BoltDecoder::peek() const { return buffer[pos]; }
byte BoltDecoder::read_byte() { return buffer[pos++]; }
void BoltDecoder::read_bytes(void *dest, size_t n) {
std::memcpy(dest, buffer.data() + pos, n);
pos += n;
}
template <class T>
T parse(const void *data) {
// reinterpret bytes as the target value
auto value = reinterpret_cast<const T *>(data);
// swap values to little endian
return bswap(*value);
}
template <class T>
T parse(Buffer &buffer, size_t &pos) {
// get a pointer to the data we're converting
auto ptr = buffer.data() + pos;
// skip sizeof bytes that we're going to read
pos += sizeof(T);
// read and convert the value
return parse<T>(ptr);
}
int16_t BoltDecoder::read_int16() { return parse<int16_t>(buffer, pos); }
uint16_t BoltDecoder::read_uint16() { return parse<uint16_t>(buffer, pos); }
int32_t BoltDecoder::read_int32() { return parse<int32_t>(buffer, pos); }
uint32_t BoltDecoder::read_uint32() { return parse<uint32_t>(buffer, pos); }
int64_t BoltDecoder::read_int64() { return parse<int64_t>(buffer, pos); }
uint64_t BoltDecoder::read_uint64() { return parse<uint64_t>(buffer, pos); }
double BoltDecoder::read_float64() {
auto v = parse<int64_t>(buffer, pos);
return *reinterpret_cast<const double *>(&v);
}
std::string BoltDecoder::read_string() {
auto marker = read_byte();
std::string res;
uint32_t size;
// if the first 4 bits equal to 1000 (0x8), this is a tiny string
if ((marker & 0xF0) == pack::TinyString) {
// size is stored in the lower 4 bits of the marker byte
size = marker & 0x0F;
}
// if the marker is 0xD0, size is an 8-bit unsigned integer
else if (marker == pack::String8) {
size = read_byte();
}
// if the marker is 0xD1, size is a 16-bit big-endian unsigned integer
else if (marker == pack::String16) {
size = read_uint16();
}
// if the marker is 0xD2, size is a 32-bit big-endian unsigned integer
else if (marker == pack::String32) {
size = read_uint32();
} else {
// TODO error?
return res;
}
if (size == 0) return res;
res.append(reinterpret_cast<const char *>(raw()), size);
pos += size;
return res;
}
const byte *BoltDecoder::raw() const { return buffer.data() + pos; }
}

View File

@ -1,41 +0,0 @@
#pragma once
#include "communication/bolt/v1/transport/buffer.hpp"
#include "communication/bolt/v1/transport/chunked_decoder.hpp"
#include "utils/types/byte.hpp"
namespace communication::bolt {
class BoltDecoder {
public:
void handshake(const byte *&data, size_t len);
bool decode(const byte *&data, size_t len);
bool empty() const;
void reset();
byte peek() const;
byte read_byte();
void read_bytes(void *dest, size_t n);
int16_t read_int16();
uint16_t read_uint16();
int32_t read_int32();
uint32_t read_uint32();
int64_t read_int64();
uint64_t read_uint64();
double read_float64();
std::string read_string();
private:
Buffer buffer;
ChunkedDecoder<Buffer> decoder{buffer};
size_t pos{0};
const byte *raw() const;
};
}

View File

@ -1,10 +0,0 @@
#include "communication/bolt/v1/transport/buffer.hpp"
namespace communication::bolt {
void Buffer::write(const byte* data, size_t len) {
buffer.insert(buffer.end(), data, data + len);
}
void Buffer::clear() { buffer.clear(); }
}

View File

@ -1,26 +0,0 @@
#pragma once
#include <cstdint>
#include <cstdlib>
#include <vector>
#include "utils/types/byte.hpp"
namespace communication::bolt {
class Buffer {
public:
void write(const byte* data, size_t len);
void clear();
size_t size() const { return buffer.size(); }
byte operator[](size_t idx) const { return buffer[idx]; }
const byte* data() const { return buffer.data(); }
private:
std::vector<byte> buffer;
};
}

View File

@ -1,63 +0,0 @@
#pragma once
#include <cstring>
#include <functional>
#include "logging/default.hpp"
#include "utils/exceptions/stacktrace_exception.hpp"
#include "utils/likely.hpp"
#include "utils/types/byte.hpp"
namespace communication::bolt {
template <class Stream>
class ChunkedDecoder {
public:
class DecoderError : public StacktraceException {
public:
using StacktraceException::StacktraceException;
};
ChunkedDecoder(Stream &stream) : stream(stream) {}
/* Decode chunked data
*
* Chunk format looks like:
*
* |Header| Data ||Header| Data || ... || End |
* | 2B | size bytes || 2B | size bytes || ... ||00 00|
*/
bool decode(const byte *&chunk, size_t n) {
while (n > 0) {
// get size from first two bytes in the chunk
auto size = get_size(chunk);
if (UNLIKELY(size + 2 > n))
throw DecoderError("Chunk size larger than available data.");
// advance chunk to pass those two bytes
chunk += 2;
n -= 2;
// if chunk size is 0, we're done!
if (size == 0) return true;
stream.get().write(chunk, size);
chunk += size;
n -= size;
}
return false;
}
bool operator()(const byte *&chunk, size_t n) { return decode(chunk, n); }
private:
std::reference_wrapper<Stream> stream;
size_t get_size(const byte *chunk) {
return size_t(chunk[0]) << 8 | chunk[1];
}
};
}

View File

@ -1,11 +0,0 @@
#pragma once
#include "utils/exceptions/stacktrace_exception.hpp"
namespace communication::bolt {
class StreamError : StacktraceException {
public:
using StacktraceException::StacktraceException;
};
}

View File

@ -1,307 +0,0 @@
#pragma once
#include <string>
#include "communication/bolt/v1/packing/codes.hpp"
#include "query/exception/decoder_exception.hpp"
#include "utils/bswap.hpp"
#include "utils/types/byte.hpp"
namespace communication::bolt {
// BoltDecoder for streams. Meant for use in SnapshotDecoder.
// This should be recoded to recieve the current caller so that decoder can
// based on a current type call it.
template <class STREAM>
class StreamedBoltDecoder {
static constexpr int64_t plus_2_to_the_31 = 2147483648L;
static constexpr int64_t plus_2_to_the_15 = 32768L;
static constexpr int64_t plus_2_to_the_7 = 128L;
static constexpr int64_t minus_2_to_the_4 = -16L;
static constexpr int64_t minus_2_to_the_7 = -128L;
static constexpr int64_t minus_2_to_the_15 = -32768L;
static constexpr int64_t minus_2_to_the_31 = -2147483648L;
public:
StreamedBoltDecoder(STREAM &stream) : stream(stream) {}
// Returns mark of a data.
size_t mark() { return peek_byte(); }
// Calls handle with current primitive data. Throws DecoderException if it
// isn't a primitive.
template <class H, class T>
T accept_primitive(H &handle) {
switch (byte()) {
case pack::False: {
return handle.handle(false);
}
case pack::True: {
return handle.handle(true);
}
case pack::Float64: {
return handle.handle(read_double());
}
default: { return handle.handle(integer()); }
};
}
// Reads map header. Throws DecoderException if it isn't map header.
size_t map_header() {
auto marker = byte();
size_t size;
if ((marker & 0xF0) == pack::TinyMap) {
size = marker & 0x0F;
} else if (marker == pack::Map8) {
size = byte();
} else if (marker == pack::Map16) {
size = read<uint16_t>();
} else if (marker == pack::Map32) {
size = read<uint32_t>();
} else {
// Error
throw DecoderException(
"StreamedBoltDecoder: Tryed to read map header but found ", marker);
}
return size;
}
bool is_list() {
auto marker = peek_byte();
if ((marker & 0xF0) == pack::TinyList) {
return true;
} else if (marker == pack::List8) {
return true;
} else if (marker == pack::List16) {
return true;
} else if (marker == pack::List32) {
return true;
} else {
return false;
}
}
// Reads list header. Throws DecoderException if it isn't list header.
size_t list_header() {
auto marker = byte();
if ((marker & 0xF0) == pack::TinyList) {
return marker & 0x0F;
} else if (marker == pack::List8) {
return byte();
} else if (marker == pack::List16) {
return read<uint16_t>();
} else if (marker == pack::List32) {
return read<uint32_t>();
} else {
// Error
throw DecoderException(
"StreamedBoltDecoder: Tryed to read list header but found ", marker);
}
}
bool is_bool() {
auto marker = peek_byte();
if (marker == pack::True) {
return true;
} else if (marker == pack::False) {
return true;
} else {
return false;
}
}
// Reads bool.Throws DecoderException if it isn't bool.
bool read_bool() {
auto marker = byte();
if (marker == pack::True) {
return true;
} else if (marker == pack::False) {
return false;
} else {
throw DecoderException(
"StreamedBoltDecoder: Tryed to read bool header but found ", marker);
}
}
bool is_integer() {
auto marker = peek_byte();
if (marker >= minus_2_to_the_4 && marker < plus_2_to_the_7) {
return true;
} else if (marker == pack::Int8) {
return true;
} else if (marker == pack::Int16) {
return true;
} else if (marker == pack::Int32) {
return true;
} else if (marker == pack::Int64) {
return true;
} else {
return false;
}
}
// Reads integer.Throws DecoderException if it isn't integer.
int64_t integer() {
auto marker = byte();
if (marker >= minus_2_to_the_4 && marker < plus_2_to_the_7) {
return marker;
} else if (marker == pack::Int8) {
return byte();
} else if (marker == pack::Int16) {
return read<int16_t>();
} else if (marker == pack::Int32) {
return read<int32_t>();
} else if (marker == pack::Int64) {
return read<int64_t>();
} else {
throw DecoderException(
"StreamedBoltDecoder: Tryed to read integer but found ", marker);
}
}
bool is_double() {
auto marker = peek_byte();
return marker == pack::Float64;
}
// Reads double.Throws DecoderException if it isn't double.
double read_double() {
auto marker = byte();
if (marker == pack::Float64) {
auto tmp = read<int64_t>();
return *reinterpret_cast<const double *>(&tmp);
} else {
throw DecoderException(
"StreamedBoltDecoder: Tryed to read double but found ", marker);
}
}
bool is_string() {
auto marker = peek_byte();
// if the first 4 bits equal to 1000 (0x8), this is a tiny string
if ((marker & 0xF0) == pack::TinyString) {
return true;
}
// if the marker is 0xD0, size is an 8-bit unsigned integer
else if (marker == pack::String8) {
return true;
}
// if the marker is 0xD1, size is a 16-bit big-endian unsigned integer
else if (marker == pack::String16) {
return true;
}
// if the marker is 0xD2, size is a 32-bit big-endian unsigned integer
else if (marker == pack::String32) {
return true;
} else {
return false;
}
}
// Reads string into res. Throws DecoderException if it isn't string.
void string(std::string &res) {
if (!string_try(res)) {
throw DecoderException(
"StreamedBoltDecoder: Tryed to read string but found ",
std::to_string(peek_byte()));
}
}
// Try-s to read string. Retunrns true on success. If it didn't succed
// stream remains unchanged
bool string_try(std::string &res) {
auto marker = peek_byte();
uint32_t size;
// if the first 4 bits equal to 1000 (0x8), this is a tiny string
if ((marker & 0xF0) == pack::TinyString) {
byte();
// size is stored in the lower 4 bits of the marker byte
size = marker & 0x0F;
}
// if the marker is 0xD0, size is an 8-bit unsigned integer
else if (marker == pack::String8) {
byte();
size = byte();
}
// if the marker is 0xD1, size is a 16-bit big-endian unsigned integer
else if (marker == pack::String16) {
byte();
size = read<uint16_t>();
}
// if the marker is 0xD2, size is a 32-bit big-endian unsigned integer
else if (marker == pack::String32) {
byte();
size = read<uint32_t>();
} else {
// Error
return false;
}
if (size > 0) {
res.resize(size);
stream.read(&res.front(), size);
} else {
res.clear();
}
return true;
}
private:
// Reads T from stream. It doens't care for alligment so this is valid only
// for primitives.
template <class T>
T read() {
buffer.resize(sizeof(T));
// Load value
stream.read(&buffer.front(), sizeof(T));
// reinterpret bytes as the target value
auto value = reinterpret_cast<const T *>(&buffer.front());
// swap values to little endian
return bswap(*value);
}
::byte byte() { return stream.get(); }
::byte peek_byte() { return stream.peek(); }
STREAM &stream;
std::string buffer;
};
};

View File

@ -63,17 +63,11 @@ class Worker
void OnWaitTimeout() {}
StreamBuffer OnAlloc(Session &) {
/* logger.trace("[on_alloc] Allocating {}B", sizeof buf); */
return StreamBuffer{buf_, sizeof buf_};
}
void OnRead(Session &session, StreamBuffer &buf) {
logger_.trace("[on_read] Received {}B", buf.len);
void OnRead(Session &session) {
logger_.trace("OnRead");
try {
session.Execute(buf.data, buf.len);
session.Execute();
} catch (const std::exception &e) {
logger_.error("Error occured while executing statement.");
logger_.error("{}", e.what());
@ -96,7 +90,6 @@ class Worker
// TODO: Do something about it
}
uint8_t buf_[65536];
std::thread thread_;
void Start(std::atomic<bool> &alive) {

View File

@ -53,13 +53,13 @@ class StreamReader : public StreamListener<Derived, Stream> {
}
// allocate the buffer to fill the data
auto buf = this->derived().OnAlloc(stream);
auto buf = stream.Allocate();
// read from the buffer at most buf.len bytes
buf.len = stream.socket_.Read(buf.data, buf.len);
int len = stream.socket_.Read(buf.data, buf.len);
// check for read errors
if (buf.len == -1) {
if (len == -1) {
// this means we have read all available data
if (LIKELY(errno == EAGAIN || errno == EWOULDBLOCK)) {
return;
@ -71,13 +71,16 @@ class StreamReader : public StreamListener<Derived, Stream> {
}
// end of file, the client has closed the connection
if (UNLIKELY(buf.len == 0)) {
if (UNLIKELY(len == 0)) {
logger_.trace("Calling OnClose because the socket is closed!");
this->derived().OnClose(stream);
return;
}
this->derived().OnRead(stream, buf);
// notify the stream that it has new data
stream.Written(len);
this->derived().OnRead(stream);
}
private:

View File

@ -6,5 +6,5 @@ class NotYetImplemented : public StacktraceException {
public:
using StacktraceException::StacktraceException;
NotYetImplemented() : StacktraceException("") {}
NotYetImplemented() : StacktraceException("Not yet implemented!") {}
};

View File

@ -11,7 +11,7 @@ class StacktraceException : public std::exception {
public:
StacktraceException(const std::string &message) noexcept : message_(message) {
Stacktrace stacktrace;
message_.append(stacktrace.dump());
stacktrace_ = stacktrace.dump();
}
template <class... Args>
@ -25,6 +25,9 @@ class StacktraceException : public std::exception {
const char *what() const noexcept override { return message_.c_str(); }
const char *trace() const noexcept { return stacktrace_.c_str(); }
private:
std::string message_;
std::string stacktrace_;
};

View File

@ -11,6 +11,7 @@
#include "logging/default.hpp"
#include "logging/streams/stdout.hpp"
#include "communication/bolt/v1/decoder/buffer.hpp"
#include "communication/server.hpp"
#include "dbms/dbms.hpp"
#include "io/network/epoll.hpp"
@ -38,23 +39,26 @@ class TestSession {
int Id() const { return socket_.id(); }
void Execute(const byte* data, size_t len) {
if (size_ == 0) {
size_ = data[0];
size_ <<= 8;
size_ += data[1];
data += 2;
len -= 2;
}
memcpy(buffer_ + have_, data, len);
have_ += len;
if (have_ < size_) return;
void Execute() {
if (buffer_.size() < 2) return;
const uint8_t *data = buffer_.data();
size_t size = data[0];
size <<= 8;
size += data[1];
if (buffer_.size() < size + 2) return;
for (int i = 0; i < REPLY; ++i)
ASSERT_TRUE(this->socket_.Write(buffer_, size_));
ASSERT_TRUE(this->socket_.Write(data + 2, size));
have_ = 0;
size_ = 0;
buffer_.Shift(size + 2);
}
io::network::StreamBuffer Allocate() {
return buffer_.Allocate();
}
void Written(size_t len) {
buffer_.Written(len);
}
void Close() {
@ -62,9 +66,7 @@ class TestSession {
this->socket_.Close();
}
char buffer_[SIZE * 2];
uint32_t have_, size_;
communication::bolt::Buffer<SIZE * 2> buffer_;
Logger logger_;
socket_t socket_;
io::network::Epoll::Event event_;
@ -87,6 +89,7 @@ void client_run(int num, const char* interface, const char* port,
endpoint_t endpoint(interface, port);
socket_t socket;
ASSERT_TRUE(socket.Connect(endpoint));
ASSERT_TRUE(socket.SetTimeout(2, 0));
logger.trace("Socket create: {}", socket.id());
for (int len = lo; len <= hi; len += 100) {
have = 0;

View File

@ -15,6 +15,7 @@
#include "logging/default.hpp"
#include "logging/streams/stdout.hpp"
#include "communication/bolt/v1/decoder/buffer.hpp"
#include "communication/server.hpp"
#include "dbms/dbms.hpp"
#include "io/network/epoll.hpp"
@ -40,8 +41,16 @@ class TestSession {
int Id() const { return socket_.id(); }
void Execute(const byte* data, size_t len) {
this->socket_.Write(data, len);
void Execute() {
this->socket_.Write(buffer_.data(), buffer_.size());
}
io::network::StreamBuffer Allocate() {
return buffer_.Allocate();
}
void Written(size_t len) {
buffer_.Written(len);
}
void Close() {
@ -49,6 +58,7 @@ class TestSession {
}
socket_t socket_;
communication::bolt::Buffer<> buffer_;
io::network::Epoll::Event event_;
};

View File

@ -4,7 +4,7 @@
constexpr const int SIZE = 4096;
uint8_t data[SIZE];
using BufferT = communication::bolt::Buffer;
using BufferT = communication::bolt::Buffer<>;
using StreamBufferT = io::network::StreamBuffer;
TEST(BoltBuffer, AllocateAndWritten) {

View File

@ -1,59 +0,0 @@
#include <array>
#include <cassert>
#include <cstring>
#include <deque>
#include <iostream>
#include <vector>
#include "communication/bolt/v1/transport/chunked_decoder.hpp"
#include "gtest/gtest.h"
/**
* DummyStream which is going to be used to test output data.
*/
struct DummyStream {
/**
* TODO (mferencevic): apply google style guide once decoder will be
* refactored + document
*/
void write(const uint8_t *values, size_t n) {
data.insert(data.end(), values, values + n);
}
std::vector<uint8_t> data;
};
using DecoderT = communication::bolt::ChunkedDecoder<DummyStream>;
TEST(ChunkedDecoderTest, WriteString) {
DummyStream stream;
DecoderT decoder(stream);
std::vector<uint8_t> chunks[] = {
{0x00, 0x08, 'A', ' ', 'q', 'u', 'i', 'c', 'k', ' ', 0x00, 0x06, 'b', 'r',
'o', 'w', 'n', ' '},
{0x00, 0x0A, 'f', 'o', 'x', ' ', 'j', 'u', 'm', 'p', 's', ' '},
{0x00, 0x07, 'o', 'v', 'e', 'r', ' ', 'a', ' '},
{0x00, 0x08, 'l', 'a', 'z', 'y', ' ', 'd', 'o', 'g', 0x00, 0x00}};
static constexpr size_t N = std::extent<decltype(chunks)>::value;
for (size_t i = 0; i < N; ++i) {
auto &chunk = chunks[i];
logging::info("Chunk size: {}", chunk.size());
const uint8_t *start = chunk.data();
auto finished = decoder.decode(start, chunk.size());
// break early if finished
if (finished) break;
}
// check validity
std::string decoded = "A quick brown fox jumps over a lazy dog";
ASSERT_EQ(decoded.size(), stream.data.size());
for (size_t i = 0; i < decoded.size(); ++i)
ASSERT_EQ(decoded[i], stream.data[i]);
}
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -5,9 +5,10 @@
constexpr const int SIZE = 131072;
uint8_t data[SIZE];
using BufferT = communication::bolt::Buffer;
using BufferT = communication::bolt::Buffer<>;
using StreamBufferT = io::network::StreamBuffer;
using DecoderBufferT = communication::bolt::ChunkedDecoderBuffer;
using ChunkStateT = communication::bolt::ChunkState;
TEST(BoltBuffer, CorrectChunk) {
uint8_t tmp[2000];
@ -20,7 +21,7 @@ TEST(BoltBuffer, CorrectChunk) {
sb.data[1002] = 0; sb.data[1003] = 0;
buffer.Written(1004);
ASSERT_EQ(decoder_buffer.GetChunk(), true);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
for (int i = 0; i < 1000; ++i)
@ -40,7 +41,7 @@ TEST(BoltBuffer, CorrectChunkTrailingData) {
sb.data[1002] = 0; sb.data[1003] = 0;
buffer.Written(2004);
ASSERT_EQ(decoder_buffer.GetChunk(), true);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
for (int i = 0; i < 1000; ++i)
@ -62,7 +63,7 @@ TEST(BoltBuffer, InvalidChunk) {
sb.data[1002] = 1; sb.data[1003] = 1;
buffer.Written(2004);
ASSERT_EQ(decoder_buffer.GetChunk(), false);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Invalid);
ASSERT_EQ(buffer.size(), 1000);
@ -79,19 +80,19 @@ TEST(BoltBuffer, GraduallyPopulatedChunk) {
sb.data[0] = 0x03; sb.data[1] = 0xe8;
buffer.Written(2);
ASSERT_EQ(decoder_buffer.GetChunk(), false);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial);
for (int i = 0; i < 5; ++i) {
sb = buffer.Allocate();
memcpy(sb.data, data + 200 * i, 200);
buffer.Written(200);
ASSERT_EQ(decoder_buffer.GetChunk(), false);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial);
}
sb = buffer.Allocate();
sb.data[0] = 0; sb.data[1] = 0;
buffer.Written(2);
ASSERT_EQ(decoder_buffer.GetChunk(), true);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
for (int i = 0; i < 1000; ++i)
@ -108,13 +109,13 @@ TEST(BoltBuffer, GraduallyPopulatedChunkTrailingData) {
sb.data[0] = 0x03; sb.data[1] = 0xe8;
buffer.Written(2);
ASSERT_EQ(decoder_buffer.GetChunk(), false);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial);
for (int i = 0; i < 5; ++i) {
sb = buffer.Allocate();
memcpy(sb.data, data + 200 * i, 200);
buffer.Written(200);
ASSERT_EQ(decoder_buffer.GetChunk(), false);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial);
}
sb = buffer.Allocate();
@ -125,7 +126,7 @@ TEST(BoltBuffer, GraduallyPopulatedChunkTrailingData) {
memcpy(sb.data, data, 1000);
buffer.Written(1000);
ASSERT_EQ(decoder_buffer.GetChunk(), true);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
for (int i = 0; i < 1000; ++i)

View File

@ -29,19 +29,25 @@ class TestSocket {
int id() const { return socket; }
int Write(const std::string &str) { return Write(str.c_str(), str.size()); }
int Write(const char *data, size_t len) {
bool Write(const std::string &str) { return Write(str.c_str(), str.size()); }
bool Write(const char *data, size_t len) {
return Write(reinterpret_cast<const uint8_t *>(data), len);
}
int Write(const uint8_t *data, size_t len) {
bool Write(const uint8_t *data, size_t len) {
if (!write_success_) return false;
for (size_t i = 0; i < len; ++i) output.push_back(data[i]);
return len;
return true;
}
void SetWriteSuccess(bool success) {
write_success_ = success;
}
std::vector<uint8_t> output;
protected:
int socket;
bool write_success_{true};
};
/**
@ -53,7 +59,7 @@ class TestBuffer {
void Write(const uint8_t *data, size_t n) { socket_.Write(data, n); }
void Chunk() {}
void Flush() {}
bool Flush() { return true; }
private:
TestSocket &socket_;

View File

@ -31,7 +31,8 @@ TEST(Bolt, ResultStream) {
for (int i = 0; i < 10; ++i)
headers.push_back(std::string(2, (char)('a' + i)));
result_stream.Header(headers); // this method automatically calls Flush
result_stream.Header(headers);
buffer.FlushFirstChunk();
PrintOutput(output);
CheckOutput(output, header_output, 45);

View File

@ -2,68 +2,604 @@
#include "communication/bolt/v1/encoder/result_stream.hpp"
#include "communication/bolt/v1/session.hpp"
#include "config/config.hpp"
#include "query/engine.hpp"
// Shortcuts for writing variable initializations in tests
#define INIT_VARS Dbms dbms;\
TestSocket socket(10);\
QueryEngine<ResultStreamT> query_engine;\
SessionT session(std::move(socket), dbms, query_engine);\
std::vector<uint8_t> &output = session.socket_.output;
using ResultStreamT =
communication::bolt::ResultStream<communication::bolt::Encoder<
communication::bolt::ChunkedEncoderBuffer<TestSocket>>>;
using SessionT = communication::bolt::Session<TestSocket>;
using StateT = communication::bolt::State;
/**
* TODO (mferencevic): document
*/
const uint8_t handshake_req[] =
"\x60\x60\xb0\x17\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
"\x00\x00";
const uint8_t handshake_resp[] = "\x00\x00\x00\x01";
const uint8_t init_req[] =
"\x00\x3f\xb2\x01\xd0\x15\x6c\x69\x62\x6e\x65\x6f\x34\x6a\x2d\x63\x6c\x69"
"\x65\x6e\x74\x2f\x31\x2e\x32\x2e\x31\xa3\x86\x73\x63\x68\x65\x6d\x65\x85"
"\x62\x61\x73\x69\x63\x89\x70\x72\x69\x6e\x63\x69\x70\x61\x6c\x80\x8b\x63"
"\x72\x65\x64\x65\x6e\x74\x69\x61\x6c\x73\x80\x00\x00";
const uint8_t init_resp[] = "\x00\x03\xb1\x70\xa0\x00\x00";
const uint8_t run_req[] =
"\x00\x26\xb2\x10\xd0\x21\x43\x52\x45\x41\x54\x45\x20\x28\x6e\x20\x7b\x6e"
"\x61\x6d\x65\x3a\x20\x32\x39\x33\x38\x33\x7d\x29\x20\x52\x45\x54\x55\x52"
"\x4e\x20\x6e\xa0\x00\x00";
// Sample testdata that has correct inputs and outputs.
const uint8_t handshake_req[] = {
0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
const uint8_t handshake_resp[] = {0x00, 0x00, 0x00, 0x01};
const uint8_t init_req[] = {
0xb2, 0x01, 0xd0, 0x15, 0x6c, 0x69, 0x62, 0x6e, 0x65, 0x6f, 0x34, 0x6a,
0x2d, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2f, 0x31, 0x2e, 0x32, 0x2e,
0x31, 0xa3, 0x86, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x65, 0x85, 0x62, 0x61,
0x73, 0x69, 0x63, 0x89, 0x70, 0x72, 0x69, 0x6e, 0x63, 0x69, 0x70, 0x61,
0x6c, 0x80, 0x8b, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61,
0x6c, 0x73, 0x80};
const uint8_t init_resp[] = {0x00, 0x03, 0xb1, 0x70, 0xa0, 0x00, 0x00};
const uint8_t run_req_header[] = {0xb2, 0x10, 0xd1};
const uint8_t pullall_req[] = {0xb0, 0x3f};
const uint8_t discardall_req[] = {0xb0, 0x2f};
const uint8_t reset_req[] = {0xb0, 0x0f};
const uint8_t ackfailure_req[] = {0xb0, 0x0e};
const uint8_t success_resp[] = {0x00, 0x03, 0xb1, 0x70, 0xa0, 0x00, 0x00};
const uint8_t ignored_resp[] = {0x00, 0x02, 0xb0, 0x7e, 0x00, 0x00};
TEST(Bolt, Session) {
Dbms dbms;
TestSocket socket(10);
QueryEngine<ResultStreamT> query_engine;
SessionT session(std::move(socket), dbms, query_engine);
std::vector<uint8_t> &output = session.socket_.output;
// execute handshake
session.Execute(handshake_req, 20);
ASSERT_EQ(session.state_, communication::bolt::INIT);
// Write bolt chunk header (length)
void WriteChunkHeader(SessionT &session, uint16_t len) {
len = bswap(len);
auto buff = session.Allocate();
memcpy(buff.data, reinterpret_cast<uint8_t *>(&len), sizeof(len));
session.Written(sizeof(len));
}
// Write bolt chunk tail (two zeros)
void WriteChunkTail(SessionT &session) {
WriteChunkHeader(session, 0);
}
// Check that the server responded with a failure message
void CheckFailureMessage(std::vector<uint8_t> &output) {
ASSERT_GE(output.size(), 6);
// skip the first two bytes because they are the chunk header
ASSERT_EQ(output[2], 0xB1); // tiny struct 1
ASSERT_EQ(output[3], 0x7F); // signature failure
}
// Execute and check a correct handshake
void ExecuteHandshake(SessionT &session, std::vector<uint8_t> &output) {
auto buff = session.Allocate();
memcpy(buff.data, handshake_req, 20);
session.Written(20);
session.Execute();
ASSERT_EQ(session.state_, StateT::Init);
ASSERT_TRUE(session.socket_.IsOpen());
PrintOutput(output);
CheckOutput(output, handshake_resp, 4);
}
// execute init
session.Execute(init_req, 67);
ASSERT_EQ(session.state_, communication::bolt::EXECUTOR);
// Write bolt chunk and execute command
void ExecuteCommand(SessionT &session, const uint8_t *data, size_t len, bool chunk = true) {
if (chunk) WriteChunkHeader(session, len);
auto buff = session.Allocate();
memcpy(buff.data, data, len);
session.Written(len);
if (chunk) WriteChunkTail(session);
session.Execute();
}
// Execute and check a correct init
void ExecuteInit(SessionT &session, std::vector<uint8_t> &output) {
ExecuteCommand(session, init_req, sizeof(init_req));
ASSERT_EQ(session.state_, StateT::Executor);
ASSERT_TRUE(session.socket_.IsOpen());
PrintOutput(output);
CheckOutput(output, init_resp, 7);
// execute run
session.Execute(run_req, 42);
// TODO (mferencevic): query engine doesn't currently work,
// we should test the query output here and the next state
// ASSERT_EQ(session.state, bolt::EXECUTOR);
// PrintOutput(output);
// CheckOutput(output, run_resp, len);
// TODO (mferencevic): add more tests
session.Close();
}
// Write bolt encoded run request
void WriteRunRequest(SessionT &session, const char *str) {
// write chunk header
auto len = strlen(str);
WriteChunkHeader(session, 3 + 2 + len + 1);
// write string header
auto buff = session.Allocate();
memcpy(buff.data, run_req_header, 3);
session.Written(3);
// write string length
WriteChunkHeader(session, len);
// write string
buff = session.Allocate();
memcpy(buff.data, str, len);
session.Written(len);
// write empty map for parameters
buff = session.Allocate();
buff.data[0] = 0xA0; // TinyMap0
session.Written(1);
// write chunk tail
WriteChunkTail(session);
}
TEST(BoltSession, HandshakeWrongPreamble) {
INIT_VARS;
auto buff = session.Allocate();
// copy 0x00000001 four times
for (int i = 0; i < 4; ++i)
memcpy(buff.data + i * 4, handshake_req + 4, 4);
session.Written(20);
session.Execute();
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
PrintOutput(output);
CheckFailureMessage(output);
}
TEST(BoltSession, HandshakeInTwoPackets) {
INIT_VARS;
auto buff = session.Allocate();
memcpy(buff.data, handshake_req, 10);
session.Written(10);
session.Execute();
ASSERT_EQ(session.state_, StateT::Handshake);
ASSERT_TRUE(session.socket_.IsOpen());
memcpy(buff.data + 10, handshake_req + 10, 10);
session.Written(10);
session.Execute();
ASSERT_EQ(session.state_, StateT::Init);
ASSERT_TRUE(session.socket_.IsOpen());
PrintOutput(output);
CheckOutput(output, handshake_resp, 4);
}
TEST(BoltSession, HandshakeTooLarge) {
INIT_VARS;
auto buff = session.Allocate();
memcpy(buff.data, handshake_req, 20);
memcpy(buff.data + 20, handshake_req, 20);
session.Written(40);
session.Execute();
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
PrintOutput(output);
CheckFailureMessage(output);
}
TEST(BoltSession, HandshakeWriteFail) {
INIT_VARS;
session.socket_.SetWriteSuccess(false);
ExecuteCommand(session, handshake_req, sizeof(handshake_req), false);
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
}
TEST(BoltSession, HandshakeOK) {
INIT_VARS;
ExecuteHandshake(session, output);
}
TEST(BoltSession, InitWrongSignature) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteCommand(session, run_req_header, sizeof(run_req_header));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
TEST(BoltSession, InitWrongMarker) {
INIT_VARS;
ExecuteHandshake(session, output);
// wrong marker, good signature
uint8_t data[2] = {0x00, init_req[1]};
ExecuteCommand(session, data, 2);
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
TEST(BoltSession, InitMissingData) {
// test lengths, they test the following situations:
// missing header data, missing client name, missing metadata
int len[] = {1, 2, 25};
for (int i = 0; i < 3; ++i) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteCommand(session, init_req, len[i]);
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
}
TEST(BoltSession, InitWriteFail) {
INIT_VARS;
ExecuteHandshake(session, output);
session.socket_.SetWriteSuccess(false);
ExecuteCommand(session, init_req, sizeof(init_req));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
}
TEST(BoltSession, InitOK) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
}
TEST(BoltSession, ExecuteRunWrongMarker) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
// wrong marker, good signature
uint8_t data[2] = {0x00, run_req_header[1]};
ExecuteCommand(session, data, sizeof(data));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
TEST(BoltSession, ExecuteRunMissingData) {
// test lengths, they test the following situations:
// missing header data, missing query data, missing parameters
int len[] = {1, 2, 37};
for (int i = 0; i < 3; ++i) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
ExecuteCommand(session, run_req_header, len[i]);
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
}
TEST(BoltSession, ExecuteRunBasicException) {
// first test with socket write success, then with socket write fail
for (int i = 0; i < 2; ++i) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
session.socket_.SetWriteSuccess(i == 0);
WriteRunRequest(session, "MATCH (omnom");
session.Execute();
if (i == 0) {
ASSERT_EQ(session.state_, StateT::Error);
ASSERT_TRUE(session.socket_.IsOpen());
CheckFailureMessage(output);
} else {
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
}
}
}
TEST(BoltSession, ExecutePullAllDiscardAllResetWrongMarker) {
// This test first tests PULL_ALL then DISCARD_ALL and then RESET
// It tests for missing data in the message header
const uint8_t *dataset[3] = {pullall_req, discardall_req, reset_req};
for (int i = 0; i < 3; ++i) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
// wrong marker, good signature
uint8_t data[2] = {0x00, dataset[i][1]};
ExecuteCommand(session, data, sizeof(data));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
}
TEST(BoltSession, ExecutePullAllBufferEmpty) {
// first test with socket write success, then with socket write fail
for (int i = 0; i < 2; ++i) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
session.socket_.SetWriteSuccess(i == 0);
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
if (i == 0) {
ASSERT_EQ(session.state_, StateT::Error);
ASSERT_TRUE(session.socket_.IsOpen());
CheckFailureMessage(output);
} else {
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
}
}
}
TEST(BoltSession, ExecutePullAllDiscardAllReset) {
// This test first tests PULL_ALL then DISCARD_ALL and then RESET
// It tests a good message
const uint8_t *dataset[3] = {pullall_req, discardall_req, reset_req};
for (int i = 0; i < 3; ++i) {
// first test with socket write success, then with socket write fail
for (int j = 0; j < 2; ++j) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
WriteRunRequest(session, "CREATE (n) RETURN n");
session.Execute();
if (j == 1) output.clear();
session.socket_.SetWriteSuccess(j == 0);
ExecuteCommand(session, dataset[i], 2);
if (j == 0) {
ASSERT_EQ(session.state_, StateT::Executor);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_FALSE(session.encoder_buffer_.HasData());
PrintOutput(output);
} else {
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
}
}
}
}
TEST(BoltSession, ExecuteInvalidMessage) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
ExecuteCommand(session, init_req, sizeof(init_req));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
TEST(BoltSession, ErrorIgnoreMessage) {
// first test with socket write success, then with socket write fail
for (int i = 0; i < 2; ++i) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
WriteRunRequest(session, "MATCH (omnom");
session.Execute();
output.clear();
session.socket_.SetWriteSuccess(i == 0);
ExecuteCommand(session, init_req, sizeof(init_req));
// assert that all data from the init message was cleaned up
ASSERT_EQ(session.decoder_buffer_.Size(), 0);
if (i == 0) {
ASSERT_EQ(session.state_, StateT::Error);
ASSERT_TRUE(session.socket_.IsOpen());
CheckOutput(output, ignored_resp, sizeof(ignored_resp));
} else {
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
}
}
}
TEST(BoltSession, ErrorCantCleanup) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
WriteRunRequest(session, "MATCH (omnom");
session.Execute();
output.clear();
// there is data missing in the request, cleanup should fail
ExecuteCommand(session, init_req, sizeof(init_req) - 10);
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
TEST(BoltSession, ErrorWrongMarker) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
WriteRunRequest(session, "MATCH (omnom");
session.Execute();
output.clear();
// wrong marker, good signature
uint8_t data[2] = {0x00, init_req[1]};
ExecuteCommand(session, data, sizeof(data));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
TEST(BoltSession, ErrorOK) {
// test ACK_FAILURE and RESET
const uint8_t *dataset[] = {ackfailure_req, reset_req};
for (int i = 0; i < 2; ++i) {
// first test with socket write success, then with socket write fail
for (int j = 0; j < 2; ++j) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
WriteRunRequest(session, "MATCH (omnom");
session.Execute();
output.clear();
session.socket_.SetWriteSuccess(j == 0);
ExecuteCommand(session, dataset[i], 2);
// assert that all data from the init message was cleaned up
ASSERT_EQ(session.decoder_buffer_.Size(), 0);
if (j == 0) {
ASSERT_EQ(session.state_, StateT::Executor);
ASSERT_TRUE(session.socket_.IsOpen());
CheckOutput(output, success_resp, sizeof(success_resp));
} else {
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
}
}
}
}
TEST(BoltSession, ErrorMissingData) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
WriteRunRequest(session, "MATCH (omnom");
session.Execute();
output.clear();
// some marker, missing signature
uint8_t data[1] = {0x00};
ExecuteCommand(session, data, sizeof(data));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
TEST(BoltSession, MultipleChunksInOneExecute) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
WriteRunRequest(session, "CREATE (n) RETURN n");
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
ASSERT_EQ(session.state_, StateT::Executor);
ASSERT_TRUE(session.socket_.IsOpen());
PrintOutput(output);
// Count chunks in output
int len, num = 0;
while(output.size() > 0) {
len = (output[0] << 8) + output[1];
output.erase(output.begin(), output.begin() + len + 4);
++num;
}
// there should be 3 chunks in the output
// the first is a success with the query headers
// the second is a record message
// and the last is a success message with query run metadata
ASSERT_EQ(num, 3);
}
TEST(BoltSession, PartialChunk) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
WriteChunkHeader(session, sizeof(discardall_req));
auto buff = session.Allocate();
memcpy(buff.data, discardall_req, sizeof(discardall_req));
session.Written(2);
// missing chunk tail
session.Execute();
ASSERT_EQ(session.state_, StateT::Executor);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
WriteChunkTail(session);
session.Execute();
ASSERT_EQ(session.state_, StateT::Executor);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_GT(output.size(), 0);
PrintOutput(output);
}
TEST(BoltSession, InvalidChunk) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
// this will write 0x00 0x02 0x00 0x02 0x00 0x02
// that is a chunk of good size, but it's invalid because the last
// two bytes are 0x00 0x02 and they should be 0x00 0x00
for (int i = 0; i < 3; ++i) WriteChunkHeader(session, 2);
session.Execute();
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
int main(int argc, char **argv) {
logging::init_sync();
logging::log->pipe(std::make_unique<Stdout>());
// Set the interpret to true to avoid calling the compiler which only
// supports a limited set of queries.
CONFIG(config::INTERPRET) = "true";
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}