diff --git a/src/communication/bolt/v1/encoder/chunked_buffer.hpp b/src/communication/bolt/v1/encoder/chunked_buffer.hpp index 85f663fe6..0f24a329e 100644 --- a/src/communication/bolt/v1/encoder/chunked_buffer.hpp +++ b/src/communication/bolt/v1/encoder/chunked_buffer.hpp @@ -1,79 +1,157 @@ #pragma once +#include <algorithm> #include <cstring> #include <memory> #include <vector> -#include <algorithm> -#include "logging/default.hpp" -#include "utils/types/byte.hpp" +#include "logging/loggable.hpp" #include "utils/bswap.hpp" namespace communication::bolt { -// maximum chunk size = 65536 bytes data -static constexpr size_t CHUNK_SIZE = 65536; +// TODO: implement a better flushing strategy + optimize memory allocations! +// TODO: see how bolt splits message over more TCP packets +// -> test for more TCP packets! /** - * Bolt chunked buffer. - * Has methods for writing and flushing data. + * Sizes related to the chunk defined in Bolt protocol. + */ +static constexpr size_t CHUNK_HEADER_SIZE = 2; +static constexpr size_t MAX_CHUNK_SIZE = 65535; +static constexpr size_t CHUNK_END_MARKER_SIZE = 2; +static constexpr size_t WHOLE_CHUNK_SIZE = + CHUNK_HEADER_SIZE + MAX_CHUNK_SIZE + CHUNK_END_MARKER_SIZE; + +/** + * @brief ChunkedBuffer + * + * Has methods for writing and flushing data into the message buffer. + * * Writing data stores data in the internal buffer and flushing data sends - * the currently stored data to the Socket with prepended data length and - * appended chunk tail (0x00 0x00). + * the currently stored data to the Socket. Chunking prepends data length and + * appends chunk end marker (0x00 0x00). + * + * | chunk header | --- chunk --- | end marker | ---------- another chunk ... | + * | ------------- whole chunk ----------------| ---------- another chunk ... | + * + * | ------------------------ message --------------------------------------- | + * | ------------------------ buffer --------------------------------------- | + * + * The current implementation stores the whole message into a single buffer + * which is std::vector. * * @tparam Socket the output socket that should be used */ template <class Socket> -class ChunkedBuffer { +class ChunkedBuffer : public Loggable { public: - ChunkedBuffer(Socket &socket) : socket_(socket), logger_(logging::log->logger("Chunked Buffer")) {} + ChunkedBuffer(Socket &socket) : Loggable("Chunked Buffer"), socket_(socket) {} - void Write(const uint8_t* values, size_t n) { - logger_.trace("Write {} bytes", n); + /** + * Writes n values into the buffer. If n is bigger than whole chunk size + * values are automatically chunked. + * + * @param values data array of bytes + * @param n is the number of bytes + */ + void Write(const uint8_t *values, size_t n) { + while (n > 0) { + // Define number of bytes which will be copied into chunk because + // chunk is a fixed lenght array. + auto size = n < MAX_CHUNK_SIZE + CHUNK_HEADER_SIZE - pos_ + ? n + : MAX_CHUNK_SIZE + CHUNK_HEADER_SIZE - pos_; - // total size of the buffer is now bigger for n - size_ += n; + // Copy size values to chunk array. + std::memcpy(chunk_.data() + pos_, values, size); - // reserve enough space for the new data - buffer_.reserve(size_); + // Update positions. Position pointer and incomming size have to be + // updated because all incomming values have to be processed. + pos_ += size; + n -= size; - // copy new data - std::copy(values, values + n, std::back_inserter(buffer_)); + // If chunk is full copy it into the message buffer and make space for + // other incomming values that are left in the values array. + if (pos_ == CHUNK_HEADER_SIZE + MAX_CHUNK_SIZE) Chunk(); + } } + /** + * Wrap the data from chunk array (append header and end marker) and put + * the whole chunk into the buffer. + */ + void Chunk() { + // 1. Write the size of the chunk (CHUNK HEADER). + uint16_t size = pos_ - CHUNK_HEADER_SIZE; + // Write the higher byte. + chunk_[0] = size >> 8; + // Write the lower byte. + chunk_[1] = size & 0xFF; + + // 2. Write last two bytes in the whole chunk (CHUNK END MARKER). + // The last two bytes are always 0x00 0x00. + chunk_[pos_++] = 0x00; + chunk_[pos_++] = 0x00; + + debug_assert(pos_ <= WHOLE_CHUNK_SIZE, + "Internal variable pos_ is bigger than the whole chunk size."); + + // 3. Copy whole chunk into the buffer. + size_ += pos_; + buffer_.reserve(size_); + std::copy(chunk_.begin(), chunk_.begin() + pos_, + std::back_inserter(buffer_)); + + // 4. Cleanup. + // * pos_ has to be reset to the size of chunk header (reserved + // space for the chunk size) + pos_ = CHUNK_HEADER_SIZE; + } + + /** + * Sends the whole buffer(message) to the client. + */ void Flush() { - size_t size = buffer_.size(), n = 0, pos = 0; - uint16_t head; + // Call chunk if is hasn't been called. + if (pos_ > CHUNK_HEADER_SIZE) Chunk(); - while (size > 0) { - head = n = std::min(CHUNK_SIZE, size); - head = bswap(head); + // Early return if buffer is empty because there is nothing to write. + if (size_ == 0) return; - logger_.trace("Flushing chunk of {} bytes", n); + // Flush the whole buffer. + socket_.Write(buffer_.data(), size_); + logger.trace("Flushed {} bytes.", size_); - // TODO: implement better flushing strategy! - socket_.Write(reinterpret_cast<const uint8_t *>(&head), sizeof(head)); - socket_.Write(buffer_.data() + pos, n); - - head = 0; - socket_.Write(reinterpret_cast<const uint8_t *>(&head), sizeof(head)); - - size -= n; - pos += n; - } - - // GC - // TODO: impelement a better strategy + // Cleanup. buffer_.clear(); - - // clear size size_ = 0; } private: - Socket& socket_; - Logger logger_; + /** + * A client socket. + */ + Socket &socket_; + + /** + * Buffer for a single chunk. + */ + std::array<uint8_t, WHOLE_CHUNK_SIZE> chunk_; + + /** + * Buffer for the message which will be sent to a client. + */ std::vector<uint8_t> buffer_; + + /** + * Size of the message. + */ size_t size_{0}; + + /** + * Current position in chunk array. + */ + size_t pos_{CHUNK_HEADER_SIZE}; }; } diff --git a/src/communication/bolt/v1/encoder/encoder.hpp b/src/communication/bolt/v1/encoder/encoder.hpp index cc4620386..98a7d6fc7 100644 --- a/src/communication/bolt/v1/encoder/encoder.hpp +++ b/src/communication/bolt/v1/encoder/encoder.hpp @@ -2,6 +2,7 @@ #include "database/graph_db_accessor.hpp" #include "logging/default.hpp" +#include "logging/logger.hpp" #include "query/backend/cpp/typed_value.hpp" #include "utils/bswap.hpp" @@ -10,11 +11,10 @@ namespace communication::bolt { static constexpr uint8_t TSTRING = 0, TLIST = 1, TMAP = 2; -static constexpr uint8_t type_tiny_marker[3] = { 0x80, 0x90, 0xA0 }; -static constexpr uint8_t type_8_marker[3] = { 0xD0, 0xD4, 0xD8 }; -static constexpr uint8_t type_16_marker[3] = { 0xD1, 0xD5, 0xD9 }; -static constexpr uint8_t type_32_marker[3] = { 0xD2, 0xD6, 0xDA }; - +static constexpr uint8_t type_tiny_marker[3] = {0x80, 0x90, 0xA0}; +static constexpr uint8_t type_8_marker[3] = {0xD0, 0xD4, 0xD8}; +static constexpr uint8_t type_16_marker[3] = {0xD1, 0xD5, 0xD9}; +static constexpr uint8_t type_32_marker[3] = {0xD2, 0xD6, 0xDA}; /** * Bolt Encoder. @@ -22,16 +22,18 @@ static constexpr uint8_t type_32_marker[3] = { 0xD2, 0xD6, 0xDA }; * Supported messages are: Record, Success, Failure and Ignored. * * @tparam Buffer the output buffer that should be used - * @tparam Socket the output socket that should be used */ -template <typename Buffer, typename Socket> -class Encoder { - +template <typename Buffer> +class Encoder : public Loggable { public: - Encoder(Socket& socket) : socket_(socket), buffer_(socket), logger_(logging::log->logger("communication::bolt::Encoder")) {} + Encoder(Buffer &buffer) + : Loggable("communication::bolt::Encoder"), buffer_(buffer) {} /** - * Sends a Record message. + * Writes a Record message. This method only stores data in the Buffer. + * It doesn't send the values out to the Socket (Chunk is called at the + * end of this method). To send the values Flush method has to be called + * after this method. * * From the Bolt v1 documentation: * RecordMessage (signature=0x71) { @@ -40,11 +42,11 @@ class Encoder { * * @param values the fields list object that should be sent */ - void MessageRecord(const std::vector<TypedValue>& values) { + void MessageRecord(const std::vector<TypedValue> &values) { // 0xB1 = struct 1; 0x71 = record signature WriteRAW("\xB1\x71", 2); WriteList(values); - buffer_.Flush(); + buffer_.Chunk(); } /** @@ -56,12 +58,17 @@ class Encoder { * } * * @param metadata the metadata map object that should be sent + * @param flush should method flush the socket */ - void MessageSuccess(const std::map<std::string, TypedValue>& metadata) { + void MessageSuccess(const std::map<std::string, TypedValue> &metadata, + bool flush = true) { // 0xB1 = struct 1; 0x70 = success signature WriteRAW("\xB1\x70", 2); WriteMap(metadata); - buffer_.Flush(); + if (flush) + buffer_.Flush(); + else + buffer_.Chunk(); } /** @@ -84,7 +91,7 @@ class Encoder { * * @param metadata the metadata map object that should be sent */ - void MessageFailure(const std::map<std::string, TypedValue>& metadata) { + void MessageFailure(const std::map<std::string, TypedValue> &metadata) { // 0xB1 = struct 1; 0x7F = failure signature WriteRAW("\xB1\x7F", 2); WriteMap(metadata); @@ -101,7 +108,7 @@ class Encoder { * * @param metadata the metadata map object that should be sent */ - void MessageIgnored(const std::map<std::string, TypedValue>& metadata) { + void MessageIgnored(const std::map<std::string, TypedValue> &metadata) { // 0xB1 = struct 1; 0x7E = ignored signature WriteRAW("\xB1\x7E", 2); WriteMap(metadata); @@ -119,24 +126,16 @@ class Encoder { buffer_.Flush(); } - private: - Socket& socket_; - Buffer buffer_; - Logger logger_; + Buffer &buffer_; + void WriteRAW(const uint8_t *data, uint64_t len) { buffer_.Write(data, len); } - void WriteRAW(const uint8_t* data, uint64_t len) { - buffer_.Write(data, len); + void WriteRAW(const char *data, uint64_t len) { + WriteRAW((const uint8_t *)data, len); } - void WriteRAW(const char* data, uint64_t len) { - WriteRAW((const uint8_t*) data, len); - } - - void WriteRAW(const uint8_t data) { - WriteRAW(&data, 1); - } + void WriteRAW(const uint8_t data) { WriteRAW(&data, 1); } template <class T> void WriteValue(T value) { @@ -149,7 +148,7 @@ class Encoder { WriteRAW(0xC0); } - void WriteBool(const bool& value) { + void WriteBool(const bool &value) { if (value) { // 0xC3 = true marker WriteRAW(0xC3); @@ -159,7 +158,7 @@ class Encoder { } } - void WriteInt(const int64_t& value) { + void WriteInt(const int64_t &value) { if (value >= -16L && value < 128L) { WriteRAW(static_cast<uint8_t>(value)); } else if (value >= -128L && value < -16L) { @@ -181,7 +180,7 @@ class Encoder { } } - void WriteDouble(const double& value) { + void WriteDouble(const double &value) { // 0xC1 = float64 marker WriteRAW(0xC1); WriteValue(*reinterpret_cast<const int64_t *>(&value)); @@ -211,25 +210,25 @@ class Encoder { } } - void WriteString(const std::string& value) { + void WriteString(const std::string &value) { WriteTypeSize(value.size(), TSTRING); WriteRAW(value.c_str(), value.size()); } - void WriteList(const std::vector<TypedValue>& value) { + void WriteList(const std::vector<TypedValue> &value) { WriteTypeSize(value.size(), TLIST); - for (auto& x: value) WriteTypedValue(x); + for (auto &x : value) WriteTypedValue(x); } - void WriteMap(const std::map<std::string, TypedValue>& value) { + void WriteMap(const std::map<std::string, TypedValue> &value) { WriteTypeSize(value.size(), TMAP); - for (auto& x: value) { + for (auto &x : value) { WriteString(x.first); WriteTypedValue(x.second); } } - void WriteVertex(const VertexAccessor& vertex) { + void WriteVertex(const VertexAccessor &vertex) { // 0xB3 = struct 3; 0x4E = vertex signature WriteRAW("\xB3\x4E", 2); @@ -240,51 +239,50 @@ class Encoder { WriteInt(0); // write labels - const auto& labels = vertex.labels(); + const auto &labels = vertex.labels(); WriteTypeSize(labels.size(), TLIST); - for (const auto& label : labels) + for (const auto &label : labels) WriteString(vertex.db_accessor().label_name(label)); // write properties - const auto& props = vertex.Properties(); + const auto &props = vertex.Properties(); WriteTypeSize(props.size(), TMAP); - for (const auto& prop : props) { + for (const auto &prop : props) { WriteString(vertex.db_accessor().property_name(prop.first)); WriteTypedValue(prop.second); } } - - void WriteEdge(const EdgeAccessor& edge) { + void WriteEdge(const EdgeAccessor &edge) { // 0xB5 = struct 5; 0x52 = edge signature WriteRAW("\xB5\x52", 2); // IMPORTANT: here we write a hardcoded 0 because we don't - // use internal IDs, but need to give something to Bolt - // note that OpenCypher has no id(x) function, so the client - // should not be able to do anything with this value anyway - WriteInt(0); - WriteInt(0); - WriteInt(0); + // use internal IDs, but need to give something to Bolt + // note that OpenCypher has no id(x) function, so the client + // should not be able to do anything with this value anyway + WriteInt(0); + WriteInt(0); + WriteInt(0); - // write type - WriteString(edge.db_accessor().edge_type_name(edge.edge_type())); + // write type + WriteString(edge.db_accessor().edge_type_name(edge.edge_type())); - // write properties - const auto& props = edge.Properties(); - WriteTypeSize(props.size(), TMAP); - for (const auto& prop : props) { + // write properties + const auto &props = edge.Properties(); + WriteTypeSize(props.size(), TMAP); + for (const auto &prop : props) { WriteString(edge.db_accessor().property_name(prop.first)); WriteTypedValue(prop.second); - } + } } void WritePath() { // TODO: this isn't implemented in the backend! } - void WriteTypedValue(const TypedValue& value) { - switch(value.type()) { + void WriteTypedValue(const TypedValue &value) { + switch (value.type()) { case TypedValue::Type::Null: WriteNull(); break; diff --git a/src/communication/bolt/v1/encoder/result_stream.hpp b/src/communication/bolt/v1/encoder/result_stream.hpp index 6a0ca4377..f48c0faaa 100644 --- a/src/communication/bolt/v1/encoder/result_stream.hpp +++ b/src/communication/bolt/v1/encoder/result_stream.hpp @@ -1,7 +1,7 @@ #pragma once -#include "communication/bolt/v1/encoder/encoder.hpp" #include "communication/bolt/v1/encoder/chunked_buffer.hpp" +#include "communication/bolt/v1/encoder/encoder.hpp" #include "query/backend/cpp/typed_value.hpp" #include "logging/default.hpp" @@ -13,34 +13,25 @@ namespace communication::bolt { * functionalities used by the compiler and query plans (which * should not use any lower level API). * - * @tparam Socket Socket used. + * @tparam Encoder Encoder used. */ -template <typename Socket> +template <typename Encoder> class ResultStream { - private: - using encoder_t = Encoder<ChunkedBuffer<Socket>, Socket>; public: - - // TODO add logging to this class - ResultStream(encoder_t &encoder) : - encoder_(encoder) {} + ResultStream(Encoder &encoder) : encoder_(encoder) {} /** * Writes a header. Typically a header is something like: - * [ - * "Header1", - * "Header2", - * "Header3" - * ] + * [ "Header1", "Header2", "Header3" ] * * @param fields the header fields that should be sent. */ void Header(const std::vector<std::string> &fields) { std::vector<TypedValue> vec; std::map<std::string, TypedValue> data; - for (auto& i : fields) - vec.push_back(TypedValue(i)); + for (auto &i : fields) vec.push_back(TypedValue(i)); data.insert(std::make_pair(std::string("fields"), TypedValue(vec))); + // this call will automaticaly send the data to the client encoder_.MessageSuccess(data); } @@ -73,10 +64,12 @@ class ResultStream { * @param summary the summary map object that should be sent */ void Summary(const std::map<std::string, TypedValue> &summary) { - encoder_.MessageSuccess(summary); + // at this point message should not flush the socket so + // here is false because chunk has to be called instead of flush + encoder_.MessageSuccess(summary, false); } -private: - encoder_t& encoder_; + private: + Encoder &encoder_; }; } diff --git a/src/communication/bolt/v1/session.hpp b/src/communication/bolt/v1/session.hpp index a2fffc015..903598efe 100644 --- a/src/communication/bolt/v1/session.hpp +++ b/src/communication/bolt/v1/session.hpp @@ -7,107 +7,117 @@ #include "query/engine.hpp" #include "communication/bolt/v1/state.hpp" +#include "communication/bolt/v1/states/error.hpp" +#include "communication/bolt/v1/states/executor.hpp" #include "communication/bolt/v1/states/handshake.hpp" #include "communication/bolt/v1/states/init.hpp" -#include "communication/bolt/v1/states/executor.hpp" -#include "communication/bolt/v1/states/error.hpp" #include "communication/bolt/v1/encoder/encoder.hpp" #include "communication/bolt/v1/encoder/result_stream.hpp" #include "communication/bolt/v1/transport/bolt_decoder.hpp" -#include "logging/default.hpp" +#include "logging/loggable.hpp" namespace communication::bolt { -template<typename Socket> -class Session { +/** + * Bolt Session + * + * This class is responsible for handling a single client connection. + * + * @tparam Socket type of socket (could be a network socket or test socket) + */ +template <typename Socket> +class Session : public Loggable { public: using Decoder = BoltDecoder; - using OutputStream = ResultStream<Socket>; + using OutputStream = ResultStream<Encoder<ChunkedBuffer<Socket>>>; Session(Socket &&socket, Dbms &dbms, QueryEngine<OutputStream> &query_engine) - : socket(std::move(socket)), - dbms(dbms), query_engine(query_engine), - encoder(this->socket), output_stream(encoder), - logger(logging::log->logger("Session")) { - event.data.ptr = this; + : Loggable("communication::bolt::Session"), + socket_(std::move(socket)), + dbms_(dbms), + query_engine_(query_engine), + encoder_buffer_(socket_), + encoder_(encoder_buffer_), + output_stream_(encoder_) { + event_.data.ptr = this; // start with a handshake state - state = HANDSHAKE; + state_ = HANDSHAKE; } - bool alive() const { return state != NULLSTATE; } + /** + * @return is the session in a valid state + */ + bool Alive() const { return state_ != NULLSTATE; } - int id() const { return socket.id(); } + /** + * @return the socket id + */ + int Id() const { return socket_.id(); } - void execute(const byte *data, size_t len) { + /** + * Reads the data from a client and goes through the bolt states in + * order to execute command from the client. + * + * @param data pointer on bytes received from a client + * @param len length of data received from a client + */ + void Execute(const byte *data, size_t len) { // mark the end of the message auto end = data + len; while (true) { auto size = end - data; - if (LIKELY(connected)) { + if (LIKELY(connected_)) { logger.debug("Decoding chunk of size {}", size); - auto finished = decoder.decode(data, size); - - if (!finished) return; + if (!decoder_.decode(data, size)) return; } else { logger.debug("Decoding handshake of size {}", size); - decoder.handshake(data, size); + decoder_.handshake(data, size); } - switch(state) { + switch (state_) { case HANDSHAKE: - logger.debug("Current state: DEBUG"); - state = state_handshake_run<Socket>(decoder, this->socket, &connected); + state_ = StateHandshakeRun<Session<Socket>>(*this); break; case INIT: - logger.debug("Current state: INIT"); - // TODO: swap around parameters so that inputs are first and outputs are last! - state = state_init_run<Socket>(encoder, decoder); + state_ = StateInitRun<Session<Socket>>(*this); break; case EXECUTOR: - logger.debug("Current state: EXECUTOR"); - // TODO: swap around parameters so that inputs are first and outputs are last! - state = state_executor_run<Socket>(output_stream, encoder, decoder, dbms, query_engine); + state_ = StateExecutorRun<Session<Socket>>(*this); break; case ERROR: - logger.debug("Current state: ERROR"); - // TODO: swap around parameters so that inputs are first and outputs are last! - state = state_error_run<Socket>(output_stream, encoder, decoder); + state_ = StateErrorRun<Session<Socket>>(*this); break; case NULLSTATE: break; } - decoder.reset(); + decoder_.reset(); } } - void close() { + /** + * Closes the session (client socket). + */ + void Close() { logger.debug("Closing session"); - this->socket.Close(); + this->socket_.Close(); } - // TODO: these members should be private + GraphDbAccessor ActiveDb() { return dbms_.active(); } - Socket socket; - io::network::Epoll::Event event; - - Dbms &dbms; - QueryEngine<OutputStream> &query_engine; - - GraphDbAccessor active_db() { return dbms.active(); } - - Decoder decoder; - Encoder<ChunkedBuffer<Socket>, Socket> encoder; - OutputStream output_stream; - - bool connected{false}; - State state; - - protected: - Logger logger; + Socket socket_; + Dbms &dbms_; + QueryEngine<OutputStream> &query_engine_; + ChunkedBuffer<Socket> encoder_buffer_; + Encoder<ChunkedBuffer<Socket>> encoder_; + OutputStream output_stream_; + Decoder decoder_; + io::network::Epoll::Event event_; + bool connected_{false}; + State state_; }; } diff --git a/src/communication/bolt/v1/state.hpp b/src/communication/bolt/v1/state.hpp index 51a78d827..aed4bd281 100644 --- a/src/communication/bolt/v1/state.hpp +++ b/src/communication/bolt/v1/state.hpp @@ -2,6 +2,10 @@ namespace communication::bolt { +/** + * TODO (mferencevic): change to a class enum & document (explain states in + * more details) + */ enum State { HANDSHAKE, INIT, diff --git a/src/communication/bolt/v1/states/error.hpp b/src/communication/bolt/v1/states/error.hpp index 34b52ded2..5879d3400 100644 --- a/src/communication/bolt/v1/states/error.hpp +++ b/src/communication/bolt/v1/states/error.hpp @@ -1,48 +1,38 @@ #pragma once +#include "communication/bolt/v1/messaging/codes.hpp" #include "communication/bolt/v1/state.hpp" -#include "communication/bolt/v1/transport/bolt_decoder.hpp" -#include "communication/bolt/v1/encoder/result_stream.hpp" - #include "logging/default.hpp" namespace communication::bolt { -template<typename Socket> -State state_error_run(ResultStream<Socket> &output_stream, Encoder<ChunkedBuffer<Socket>, Socket>& encoder, BoltDecoder &decoder) { - Logger logger = logging::log->logger("State ERROR"); - logger.trace("Run"); +/** + * TODO (mferencevic): finish & document + */ +template <typename Session> +State StateErrorRun(Session &session) { + static Logger logger = logging::log->logger("State ERROR"); - decoder.read_byte(); - auto message_type = decoder.read_byte(); + session.decoder_.read_byte(); + auto message_type = session.decoder_.read_byte(); logger.trace("Message type byte is: {:02X}", message_type); if (message_type == MessageCode::PullAll) { - // TODO: write_ignored, chunk, send - encoder.MessageIgnored(); + session.encoder_.MessageIgnored(); return ERROR; } else if (message_type == MessageCode::AckFailure) { // TODO reset current statement? is it even necessary? logger.trace("AckFailure received"); - - // TODO: write_success, chunk, send - encoder.MessageSuccess(); - + session.encoder_.MessageSuccess(); return EXECUTOR; } else if (message_type == MessageCode::Reset) { // TODO rollback current transaction // discard all records waiting to be sent - - // TODO: write_success, chunk, send - encoder.MessageSuccess(); - + session.encoder_.MessageSuccess(); return EXECUTOR; } - - // TODO: write_ignored, chunk, send - encoder.MessageIgnored(); - + session.encoder_.MessageIgnored(); return ERROR; } } diff --git a/src/communication/bolt/v1/states/executor.hpp b/src/communication/bolt/v1/states/executor.hpp index ceba1ed33..b315b0d68 100644 --- a/src/communication/bolt/v1/states/executor.hpp +++ b/src/communication/bolt/v1/states/executor.hpp @@ -2,49 +2,46 @@ #include <string> -#include "dbms/dbms.hpp" -#include "query/engine.hpp" -#include "communication/bolt/v1/states/executor.hpp" #include "communication/bolt/v1/messaging/codes.hpp" #include "communication/bolt/v1/state.hpp" - #include "logging/default.hpp" +#include "query/exceptions.hpp" namespace communication::bolt { struct Query { + Query(std::string &&statement) : statement(statement) {} std::string statement; }; -template<typename Socket> -State state_executor_run(ResultStream<Socket> &output_stream, Encoder<ChunkedBuffer<Socket>, Socket>& encoder, BoltDecoder &decoder, Dbms &dmbs, QueryEngine<ResultStream<Socket>> &query_engine){ - Logger logger = logging::log->logger("State EXECUTOR"); +/** + * TODO (mferencevic): finish & document + */ +template <typename Session> +State StateExecutorRun(Session &session) { + // initialize logger + static Logger logger = logging::log->logger("State EXECUTOR"); + // just read one byte that represents the struct type, we can skip the // information contained in this byte - decoder.read_byte(); - - logger.debug("Run"); - - auto message_type = decoder.read_byte(); + session.decoder_.read_byte(); + auto message_type = session.decoder_.read_byte(); if (message_type == MessageCode::Run) { - Query query; + Query query(session.decoder_.read_string()); - query.statement = decoder.read_string(); + // TODO (mferencevic): implement proper exception handling + auto db_accessor = session.dbms_.active(); + logger.debug("[ActiveDB] '{}'", db_accessor->name()); - // TODO (mferencevic): refactor bolt exception handling try { logger.trace("[Run] '{}'", query.statement); - - auto db_accessor = dmbs.active(); - logger.debug("[ActiveDB] '{}'", db_accessor->name()); - - auto is_successfully_executed = - query_engine.Run(query.statement, *db_accessor, output_stream); + auto is_successfully_executed = session.query_engine_.Run( + query.statement, *db_accessor, session.output_stream_); if (!is_successfully_executed) { - // TODO: write_failure, send - encoder.MessageFailure( + db_accessor->abort(); + session.encoder_.MessageFailure( {{"code", "Memgraph.QueryExecutionFail"}, {"message", "Query execution has failed (probably there is no " @@ -52,49 +49,50 @@ State state_executor_run(ResultStream<Socket> &output_stream, Encoder<ChunkedBuf "access -> client has to resolve problems with " "concurrent access)"}}); return ERROR; + } else { + db_accessor->commit(); } return EXECUTOR; - // TODO: RETURN success MAYBE + // !! QUERY ENGINE -> RUN METHOD -> EXCEPTION HANDLING !! } catch (const query::SyntaxException &e) { - // TODO: write_failure, send - encoder.MessageFailure( + db_accessor->abort(); + session.encoder_.MessageFailure( {{"code", "Memgraph.SyntaxException"}, {"message", "Syntax error"}}); return ERROR; - // } catch (const backend::cpp::GeneratorException &e) { - // output_stream.write_failure( - // {{"code", "Memgraph.GeneratorException"}, - // {"message", "Unsupported query"}}); - // output_stream.send(); - // return ERROR; } catch (const query::QueryEngineException &e) { - // TODO: write_failure, send - encoder.MessageFailure( + db_accessor->abort(); + session.encoder_.MessageFailure( {{"code", "Memgraph.QueryEngineException"}, {"message", "Query engine was unable to execute the query"}}); return ERROR; } catch (const StacktraceException &e) { - // TODO: write_failure, send - encoder.MessageFailure( - {{"code", "Memgraph.StacktraceException"}, - {"message", "Unknown exception"}}); + db_accessor->abort(); + session.encoder_.MessageFailure({{"code", "Memgraph.StacktraceException"}, + {"message", "Unknown exception"}}); return ERROR; } catch (std::exception &e) { - // TODO: write_failure, send - encoder.MessageFailure( + db_accessor->abort(); + session.encoder_.MessageFailure( {{"code", "Memgraph.Exception"}, {"message", "Unknown exception"}}); return ERROR; } + // TODO (mferencevic): finish the error handling, cover all exceptions + // which can be raised from query engine + // * [abort, MessageFailure, return ERROR] should be extracted into + // separate function (or something equivalent) + // + // !! QUERY ENGINE -> RUN METHOD -> EXCEPTION HANDLING !! + } else if (message_type == MessageCode::PullAll) { logger.trace("[PullAll]"); - // TODO: all query output should not be immediately flushed from the buffer, it should wait the PullAll command to start flushing!! - //output_stream.send(); + session.encoder_buffer_.Flush(); } else if (message_type == MessageCode::DiscardAll) { logger.trace("[DiscardAll]"); // TODO: discard state // TODO: write_success, send - encoder.MessageSuccess(); + session.encoder_.MessageSuccess(); } else if (message_type == MessageCode::Reset) { // TODO: rollback current transaction // discard all records waiting to be sent @@ -108,5 +106,4 @@ State state_executor_run(ResultStream<Socket> &output_stream, Encoder<ChunkedBuf return EXECUTOR; } - } diff --git a/src/communication/bolt/v1/states/handshake.hpp b/src/communication/bolt/v1/states/handshake.hpp index 0007ea127..5667f1af9 100644 --- a/src/communication/bolt/v1/states/handshake.hpp +++ b/src/communication/bolt/v1/states/handshake.hpp @@ -1,31 +1,30 @@ #pragma once -#include "dbms/dbms.hpp" #include "communication/bolt/v1/state.hpp" #include "communication/bolt/v1/transport/bolt_decoder.hpp" - #include "logging/default.hpp" namespace communication::bolt { static constexpr uint32_t preamble = 0x6060B017; - static constexpr byte protocol[4] = {0x00, 0x00, 0x00, 0x01}; -template<typename Socket> -State state_handshake_run(BoltDecoder &decoder, Socket &socket_, bool *connected) { - Logger logger = logging::log->logger("State HANDSHAKE"); - logger.debug("run"); +/** + * TODO (mferencevic): finish & document + */ +template <typename Session> +State StateHandshakeRun(Session &session) { + static Logger logger = logging::log->logger("State HANDSHAKE"); - if (UNLIKELY(decoder.read_uint32() != preamble)) return NULLSTATE; + if (UNLIKELY(session.decoder_.read_uint32() != preamble)) return NULLSTATE; // TODO so far we only support version 1 of the protocol so it doesn't // make sense to check which version the client prefers // this will change in the future - *connected = true; + session.connected_ = true; // TODO: check for success - socket_.Write(protocol, sizeof protocol); + session.socket_.Write(protocol, sizeof protocol); return INIT; } diff --git a/src/communication/bolt/v1/states/init.hpp b/src/communication/bolt/v1/states/init.hpp index 68fa50a78..6025c3d0a 100644 --- a/src/communication/bolt/v1/states/init.hpp +++ b/src/communication/bolt/v1/states/init.hpp @@ -1,43 +1,42 @@ #pragma once +#include "communication/bolt/v1/encoder/result_stream.hpp" +#include "communication/bolt/v1/messaging/codes.hpp" #include "communication/bolt/v1/packing/codes.hpp" #include "communication/bolt/v1/state.hpp" #include "communication/bolt/v1/transport/bolt_decoder.hpp" -#include "communication/bolt/v1/encoder/result_stream.hpp" -#include "communication/bolt/v1/messaging/codes.hpp" - #include "logging/default.hpp" #include "utils/likely.hpp" namespace communication::bolt { -template<typename Socket> -State state_init_run(Encoder<ChunkedBuffer<Socket>, Socket> &encoder, BoltDecoder &decoder) { - Logger logger = logging::log->logger("State INIT"); +/** + * TODO (mferencevic): finish & document + */ +template <typename Session> +State StateInitRun(Session &session) { + static Logger logger = logging::log->logger("State INIT"); logger.debug("Parsing message"); - auto struct_type = decoder.read_byte(); + auto struct_type = session.decoder_.read_byte(); if (UNLIKELY((struct_type & 0x0F) > pack::Rule::MaxInitStructSize)) { logger.debug("{}", struct_type); - logger.debug( "Expected struct marker of max size 0x{:02} instead of 0x{:02X}", (unsigned)pack::Rule::MaxInitStructSize, (unsigned)struct_type); - return NULLSTATE; } - auto message_type = decoder.read_byte(); + auto message_type = session.decoder_.read_byte(); if (UNLIKELY(message_type != MessageCode::Init)) { logger.debug("Expected Init (0x01) instead of (0x{:02X})", (unsigned)message_type); - return NULLSTATE; } - auto client_name = decoder.read_string(); + auto client_name = session.decoder_.read_string(); if (struct_type == pack::Code::StructTwo) { // TODO process authentication tokens @@ -47,9 +46,8 @@ State state_init_run(Encoder<ChunkedBuffer<Socket>, Socket> &encoder, BoltDecode logger.debug("Client connected '{}'", client_name); // TODO: write_success, chunk, send - encoder.MessageSuccess(); + session.encoder_.MessageSuccess(); return EXECUTOR; } - } diff --git a/src/communication/result_stream_faker.hpp b/src/communication/result_stream_faker.hpp index 2bb709f80..45526aaa5 100644 --- a/src/communication/result_stream_faker.hpp +++ b/src/communication/result_stream_faker.hpp @@ -12,15 +12,16 @@ */ class ResultStreamFaker { public: - void Header(const std::vector<std::string> &fields) { - debug_assert(current_state_ == State::Start, "Headers can only be written in the beginning"); + debug_assert(current_state_ == State::Start, + "Headers can only be written in the beginning"); header_ = fields; current_state_ = State::WritingResults; } void Result(const std::vector<TypedValue> &values) { - debug_assert(current_state_ == State::WritingResults, "Can't accept results before header nor after summary"); + debug_assert(current_state_ == State::WritingResults, + "Can't accept results before header nor after summary"); results_.push_back(values); } @@ -43,16 +44,11 @@ class ResultStreamFaker { } private: - /** * Possible states of the Mocker. Used for checking if calls to * the Mocker as in acceptable order. */ - enum class State { - Start, - WritingResults, - Done - }; + enum class State { Start, WritingResults, Done }; // the current state State current_state_ = State::Start; diff --git a/src/communication/server.hpp b/src/communication/server.hpp index 3ce87b0b8..0d7b5a709 100644 --- a/src/communication/server.hpp +++ b/src/communication/server.hpp @@ -15,6 +15,10 @@ namespace communication { +/** + * TODO (mferencevic): document methods + */ + /** * Communication server. * Listens for incomming connections on the server port and assings them in a @@ -37,9 +41,9 @@ class Server public: Server(Socket &&socket, Dbms &dbms, QueryEngine<OutputStream> &query_engine) - : dbms_(dbms), + : socket_(std::forward<Socket>(socket)), + dbms_(dbms), query_engine_(query_engine), - socket_(std::forward<Socket>(socket)), logger_(logging::log->logger("communication::Server")) { event_.data.fd = socket_; diff --git a/src/communication/worker.hpp b/src/communication/worker.hpp index 5ed410756..585a8342b 100644 --- a/src/communication/worker.hpp +++ b/src/communication/worker.hpp @@ -17,6 +17,10 @@ namespace communication { +/** + * TODO (mferencevic): document methods + */ + /** * Communication worker. * Listens for incomming data on connections and accepts new connections. @@ -69,7 +73,7 @@ class Worker logger_.trace("[on_read] Received {}B", buf.len); try { - session.execute(reinterpret_cast<const byte *>(buf.ptr), buf.len); + session.Execute(reinterpret_cast<const byte *>(buf.ptr), buf.len); } catch (const std::exception &e) { logger_.error("Error occured while executing statement."); logger_.error("{}", e.what()); @@ -80,7 +84,7 @@ class Worker void OnClose(Session &session) { logger_.trace("Client closed the connection"); // TODO: remove socket from epoll object - session.close(); + session.Close(); delete &session; } diff --git a/src/io/network/stream_listener.hpp b/src/io/network/stream_listener.hpp index 1bf521e7e..53449cc63 100644 --- a/src/io/network/stream_listener.hpp +++ b/src/io/network/stream_listener.hpp @@ -12,7 +12,7 @@ class StreamListener : public EventListener<Derived, max_events, wait_timeout> { void Add(Stream &stream) { // add the stream to the event listener - this->listener_.Add(stream.socket, &stream.event); + this->listener_.Add(stream.socket_, &stream.event_); } void OnCloseEvent(Epoll::Event &event) { diff --git a/src/io/network/stream_reader.hpp b/src/io/network/stream_reader.hpp index 5c62f2593..375cef70a 100644 --- a/src/io/network/stream_reader.hpp +++ b/src/io/network/stream_reader.hpp @@ -40,7 +40,7 @@ class StreamReader : public StreamListener<Derived, Stream> { // we want to listen to an incoming event which is edge triggered and // we also want to listen on the hangup event - stream.event.events = EPOLLIN | EPOLLRDHUP; + stream.event_.events = EPOLLIN | EPOLLRDHUP; // add the connection to the event listener this->Add(stream); @@ -52,7 +52,7 @@ class StreamReader : public StreamListener<Derived, Stream> { logger_.trace("On data"); while (true) { - if (UNLIKELY(!stream.alive())) { + if (UNLIKELY(!stream.Alive())) { logger_.trace("Calling OnClose because the stream isn't alive!"); this->derived().OnClose(stream); break; @@ -62,7 +62,7 @@ class StreamReader : public StreamListener<Derived, Stream> { auto buf = this->derived().OnAlloc(stream); // read from the buffer at most buf.len bytes - buf.len = stream.socket.Read(buf.ptr, buf.len); + buf.len = stream.socket_.Read(buf.ptr, buf.len); // check for read errors if (buf.len == -1) { diff --git a/src/memgraph_bolt.cpp b/src/memgraph_bolt.cpp index fa0cbdbcf..63b6a9d43 100644 --- a/src/memgraph_bolt.cpp +++ b/src/memgraph_bolt.cpp @@ -22,8 +22,10 @@ using endpoint_t = io::network::NetworkEndpoint; using socket_t = io::network::Socket; using session_t = communication::bolt::Session<socket_t>; -using result_stream_t = communication::bolt::ResultStream<socket_t>; -using bolt_server_t = communication::Server<session_t, result_stream_t, socket_t>; +using result_stream_t = communication::bolt::ResultStream< + communication::bolt::Encoder<communication::bolt::ChunkedBuffer<socket_t>>>; +using bolt_server_t = + communication::Server<session_t, result_stream_t, socket_t>; static bolt_server_t *serverptr; diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index f1060aaf4..86fd25463 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -41,12 +41,12 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor, // generate frame based on symbol table max_position Frame frame(symbol_table.max_position()); + std::vector<std::string> header; if (auto produce = dynamic_cast<plan::Produce *>(logical_plan.get())) { // top level node in the operator tree is a produce (return) // so stream out results // generate header - std::vector<std::string> header; for (auto named_expression : produce->named_expressions()) header.push_back(named_expression->name_); stream.Header(header); @@ -66,6 +66,7 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor, } else if (dynamic_cast<plan::CreateNode *>(logical_plan.get()) || dynamic_cast<plan::CreateExpand *>(logical_plan.get()) || dynamic_cast<Delete *>(logical_plan.get())) { + stream.Header(header); auto cursor = logical_plan.get()->MakeCursor(db_accessor); while (cursor->Pull(frame, symbol_table)) continue; diff --git a/tests/concurrent/network_common.hpp b/tests/concurrent/network_common.hpp index d4d7bcfc1..ac399fff4 100644 --- a/tests/concurrent/network_common.hpp +++ b/tests/concurrent/network_common.hpp @@ -30,15 +30,15 @@ class TestSession { TestSession(socket_t&& socket, Dbms& dbms, QueryEngine<TestOutputStream>& query_engine) : logger_(logging::log->logger("TestSession")), - socket(std::move(socket)) { - event.data.ptr = this; + socket_(std::move(socket)) { + event_.data.ptr = this; } - bool alive() { return socket.IsOpen(); } + bool Alive() { return socket_.IsOpen(); } - int id() const { return socket.id(); } + int Id() const { return socket_.id(); } - void execute(const byte* data, size_t len) { + void Execute(const byte* data, size_t len) { if (size_ == 0) { size_ = data[0]; size_ <<= 8; @@ -51,23 +51,23 @@ class TestSession { if (have_ < size_) return; for (int i = 0; i < REPLY; ++i) - ASSERT_TRUE(this->socket.Write(buffer_, size_)); + ASSERT_TRUE(this->socket_.Write(buffer_, size_)); have_ = 0; size_ = 0; } - void close() { + void Close() { logger_.trace("Close session!"); - this->socket.Close(); + this->socket_.Close(); } char buffer_[SIZE * 2]; uint32_t have_, size_; Logger logger_; - socket_t socket; - io::network::Epoll::Event event; + socket_t socket_; + io::network::Epoll::Event event_; }; using test_server_t = diff --git a/tests/integration/hardcoded_query/using.hpp b/tests/integration/hardcoded_query/using.hpp index f42ed18c9..610377f11 100644 --- a/tests/integration/hardcoded_query/using.hpp +++ b/tests/integration/hardcoded_query/using.hpp @@ -5,7 +5,8 @@ #ifndef HARDCODED_OUTPUT_STREAM #include "communication/bolt/v1/encoder/result_stream.hpp" #include "io/network/socket.hpp" -using Stream = communication::bolt::ResultStream<io::network::Socket>; +using Stream = communication::bolt::ResultStream<communication::bolt::Encoder< + communication::bolt::ChunkedBuffer<io::network::Socket>>>; #else #include "../stream/print_record_stream.hpp" using Stream = PrintRecordStream; diff --git a/tests/unit/bolt_chunked_buffer.cpp b/tests/unit/bolt_chunked_buffer.cpp index eeb0c3da1..73ea11dcb 100644 --- a/tests/unit/bolt_chunked_buffer.cpp +++ b/tests/unit/bolt_chunked_buffer.cpp @@ -1,44 +1,115 @@ -#define NDEBUG #include "bolt_common.hpp" - #include "communication/bolt/v1/encoder/chunked_buffer.hpp" +// aliases +using SocketT = TestSocket; +using BufferT = communication::bolt::ChunkedBuffer<SocketT>; -constexpr const int SIZE = 131072; -uint8_t data[SIZE]; +// "alias" constants +static constexpr auto CHS = communication::bolt::CHUNK_HEADER_SIZE; +static constexpr auto CEMS = communication::bolt::CHUNK_END_MARKER_SIZE; +static constexpr auto MCS = communication::bolt::MAX_CHUNK_SIZE; +static constexpr auto WCS = communication::bolt::WHOLE_CHUNK_SIZE; +/** + * Verifies a single chunk. The chunk should be constructed from header + * (chunk size), data and end marker. The header is two bytes long number + * written in big endian format. Data is array of elements which max size is + * 0xFFFF. The end marker is always two bytes long array of two zeros. + * + * @param data pointer on data array (array of bytes) + * @param size of data array + * @param element expected element in all positions of chunk data array + * (all data bytes in tested chunk should be equal to element) + */ +void VerifyChunkOfOnes(uint8_t *data, int size, uint8_t element) { + // first two bytes are size (big endian) + uint8_t lower_byte = size & 0xFF; + uint8_t higher_byte = (size & 0xFF00) >> 8; + ASSERT_EQ(*data, higher_byte); + ASSERT_EQ(*(data + 1), lower_byte); -void verify_output(std::vector<uint8_t>& output, const uint8_t* data, uint64_t size) { - uint64_t len = 0, pos = 0; - uint8_t tail[2] = { 0, 0 }; - uint16_t head; - while (size > 0) { - head = len = std::min(size, communication::bolt::CHUNK_SIZE); - head = bswap(head); - check_output(output, reinterpret_cast<uint8_t *>(&head), sizeof(head), false); - check_output(output, data + pos, len, false); - check_output(output, tail, 2, false); - size -= len; - pos += len; + // in the data array should be size number of ones + // the header is skipped + for (auto i = CHS; i < size + CHS; ++i) { + ASSERT_EQ(*(data + i), element); } - check_output(output, nullptr, 0, true); + + // last two bytes should be zeros + // next to header and data + ASSERT_EQ(*(data + CHS + size), 0x00); + ASSERT_EQ(*(data + CHS + size + 1), 0x00); } -TEST(Bolt, ChunkedBuffer) { - TestSocket socket(10); - communication::bolt::ChunkedBuffer<TestSocket> chunked_buffer(socket); - std::vector<uint8_t>& output = socket.output; +TEST(BoltChunkedBuffer, OneSmallChunk) { + // initialize array of 100 ones (small chunk) + int size = 100; + uint8_t element = '1'; + std::vector<uint8_t> data(100, element); - for (int i = 0; i <= SIZE; i += 16) { - chunked_buffer.Write(data, i); - chunked_buffer.Flush(); - verify_output(output, data, i); - } + // initialize tested buffer + SocketT socket(10); + BufferT buffer(socket); + + // write into buffer + buffer.Write(data.data(), size); + buffer.Flush(); + + // check the output array + // the array should look like: [0, 100, 1, 1, ... , 1, 0, 0] + VerifyChunkOfOnes(socket.output.data(), size, element); } +TEST(BoltChunkedBuffer, TwoSmallChunks) { + // initialize the small arrays + int size1 = 100; + uint8_t element1 = '1'; + std::vector<uint8_t> data1(size1, element1); + int size2 = 200; + uint8_t element2 = '2'; + std::vector<uint8_t> data2(size2, element2); -int main(int argc, char** argv) { - initialize_data(data, SIZE); + // initialize tested buffer + SocketT socket(10); + BufferT buffer(socket); + + // write into buffer + buffer.Write(data1.data(), size1); + buffer.Chunk(); + buffer.Write(data2.data(), size2); + buffer.Flush(); + + // check the output array + // the output array should look like this: [0, 100, 1, 1, ... , 1, 0, 0] + + // [0, 100, 2, 2, ...... , 2, 0, 0] + auto data = socket.output.data(); + VerifyChunkOfOnes(data, size1, element1); + VerifyChunkOfOnes(data + CHS + size1 + CEMS, size2, element2); +} + +TEST(BoltChunkedBuffer, OneAndAHalfOfMaxChunk) { + // initialize a big chunk + int size = 100000; + uint8_t element = '1'; + std::vector<uint8_t> data(size, element); + + // initialize tested buffer + SocketT socket(10); + BufferT buffer(socket); + + // write into buffer + buffer.Write(data.data(), size); + buffer.Flush(); + + // check the output array + // the output array should look like this: + // [0xFF, 0xFF, 1, 1, ... , 1, 0, 0, 0x86, 0xA1, 1, 1, ... , 1, 0, 0] + auto output = socket.output.data(); + VerifyChunkOfOnes(output, MCS, element); + VerifyChunkOfOnes(output + WCS, size - MCS, element); +} + +int main(int argc, char **argv) { logging::init_sync(); logging::log->pipe(std::make_unique<Stdout>()); ::testing::InitGoogleTest(&argc, argv); diff --git a/tests/unit/bolt_chunked_decoder.cpp b/tests/unit/bolt_chunked_decoder.cpp new file mode 100644 index 000000000..ec9b903ac --- /dev/null +++ b/tests/unit/bolt_chunked_decoder.cpp @@ -0,0 +1,59 @@ +#include <array> +#include <cassert> +#include <cstring> +#include <deque> +#include <iostream> +#include <vector> + +#include "communication/bolt/v1/transport/chunked_decoder.hpp" +#include "gtest/gtest.h" + +/** + * DummyStream which is going to be used to test output data. + */ +struct DummyStream { + /** + * TODO (mferencevic): apply google style guide once decoder will be + * refactored + document + */ + void write(const uint8_t *values, size_t n) { + data.insert(data.end(), values, values + n); + } + std::vector<uint8_t> data; +}; +using DecoderT = communication::bolt::ChunkedDecoder<DummyStream>; + +TEST(ChunkedDecoderTest, WriteString) { + DummyStream stream; + DecoderT decoder(stream); + + std::vector<uint8_t> chunks[] = { + {0x00, 0x08, 'A', ' ', 'q', 'u', 'i', 'c', 'k', ' ', 0x00, 0x06, 'b', 'r', + 'o', 'w', 'n', ' '}, + {0x00, 0x0A, 'f', 'o', 'x', ' ', 'j', 'u', 'm', 'p', 's', ' '}, + {0x00, 0x07, 'o', 'v', 'e', 'r', ' ', 'a', ' '}, + {0x00, 0x08, 'l', 'a', 'z', 'y', ' ', 'd', 'o', 'g', 0x00, 0x00}}; + static constexpr size_t N = std::extent<decltype(chunks)>::value; + + for (size_t i = 0; i < N; ++i) { + auto &chunk = chunks[i]; + logging::info("Chunk size: {}", chunk.size()); + + const uint8_t *start = chunk.data(); + auto finished = decoder.decode(start, chunk.size()); + + // break early if finished + if (finished) break; + } + + // check validity + std::string decoded = "A quick brown fox jumps over a lazy dog"; + ASSERT_EQ(decoded.size(), stream.data.size()); + for (size_t i = 0; i < decoded.size(); ++i) + ASSERT_EQ(decoded[i], stream.data[i]); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tests/unit/bolt_common.hpp b/tests/unit/bolt_common.hpp index 9bbce0c38..2d400e833 100644 --- a/tests/unit/bolt_common.hpp +++ b/tests/unit/bolt_common.hpp @@ -4,21 +4,21 @@ #include <iostream> #include <vector> +#include "dbms/dbms.hpp" #include "gtest/gtest.h" - #include "logging/default.hpp" #include "logging/streams/stdout.hpp" -#include "dbms/dbms.hpp" - - +/** + * TODO (mferencevic): document + */ class TestSocket { public: TestSocket(int socket) : socket(socket) {} - TestSocket(const TestSocket& s) : socket(s.id()){}; - TestSocket(TestSocket&& other) { *this = std::forward<TestSocket>(other); } + TestSocket(const TestSocket &s) : socket(s.id()){}; + TestSocket(TestSocket &&other) { *this = std::forward<TestSocket>(other); } - TestSocket& operator=(TestSocket&& other) { + TestSocket &operator=(TestSocket &&other) { this->socket = other.socket; other.socket = -1; return *this; @@ -29,12 +29,12 @@ class TestSocket { int id() const { return socket; } - int Write(const std::string& str) { return Write(str.c_str(), str.size()); } - int Write(const char* data, size_t len) { - return Write(reinterpret_cast<const uint8_t*>(data), len); + int Write(const std::string &str) { return Write(str.c_str(), str.size()); } + int Write(const char *data, size_t len) { + return Write(reinterpret_cast<const uint8_t *>(data), len); } - int Write(const uint8_t* data, size_t len) { - for (int i = 0; i < len; ++i) output.push_back(data[i]); + int Write(const uint8_t *data, size_t len) { + for (size_t i = 0; i < len; ++i) output.push_back(data[i]); return len; } @@ -44,29 +44,56 @@ class TestSocket { int socket; }; -void print_output(std::vector<uint8_t>& output) { +/** + * TODO (mferencevic): document + */ +class TestBuffer { + public: + TestBuffer(TestSocket &socket) : socket_(socket) {} + + void Write(const uint8_t *data, size_t n) { socket_.Write(data, n); } + void Chunk() {} + void Flush() {} + + private: + TestSocket &socket_; +}; + +/** + * TODO (mferencevic): document + */ +void PrintOutput(std::vector<uint8_t> &output) { fprintf(stderr, "output: "); - for (int i = 0; i < output.size(); ++i) { + for (size_t i = 0; i < output.size(); ++i) { fprintf(stderr, "%02X ", output[i]); } fprintf(stderr, "\n"); } -void check_output(std::vector<uint8_t>& output, const uint8_t* data, - uint64_t len, bool clear = true) { - if (clear) ASSERT_EQ(len, output.size()); - else ASSERT_LE(len, output.size()); - for (int i = 0; i < len; ++i) - EXPECT_EQ(output[i], data[i]); - if (clear) output.clear(); - else output.erase(output.begin(), output.begin() + len); +/** + * TODO (mferencevic): document + */ +void CheckOutput(std::vector<uint8_t> &output, const uint8_t *data, + uint64_t len, bool clear = true) { + if (clear) + ASSERT_EQ(len, output.size()); + else + ASSERT_LE(len, output.size()); + for (size_t i = 0; i < len; ++i) EXPECT_EQ(output[i], data[i]); + if (clear) + output.clear(); + else + output.erase(output.begin(), output.begin() + len); } -void initialize_data(uint8_t* data, size_t size) { +/** + * TODO (mferencevic): document + */ +void InitializeData(uint8_t *data, size_t size) { std::random_device rd; std::mt19937 gen(rd()); std::uniform_int_distribution<> dis(0, 255); - for (int i = 0; i < size; ++i) { + for (size_t i = 0; i < size; ++i) { data[i] = dis(gen); } } diff --git a/tests/unit/bolt_encoder.cpp b/tests/unit/bolt_encoder.cpp index 1d263c252..c1669fccb 100644 --- a/tests/unit/bolt_encoder.cpp +++ b/tests/unit/bolt_encoder.cpp @@ -5,79 +5,90 @@ #include "database/graph_db_accessor.hpp" #include "query/backend/cpp/typed_value.hpp" +/** + * TODO (mferencevic): document + */ -class TestBuffer { - public: - TestBuffer(TestSocket& socket) : socket_(socket) {} +// clang-format off +const int64_t int_input[] = { + 0, -1, -8, -16, 1, 63, 127, -128, -20, -17, -32768, -12345, -129, 128, + 12345, 32767, -2147483648L, -12345678L, -32769L, 32768L, 12345678L, + 2147483647L, -9223372036854775807L, -12345678912345L, -2147483649L, + 2147483648L, 12345678912345L, 9223372036854775807}; - void Write(const uint8_t* data, size_t n) { - socket_.Write(data, n); - } +const uint8_t int_output[][10] = { + "\x00", "\xFF", "\xF8", "\xF0", "\x01", "\x3F", "\x7F", "\xC8\x80", + "\xC8\xEC", "\xC8\xEF", "\xC9\x80\x00", "\xC9\xCF\xC7", "\xC9\xFF\x7F", + "\xC9\x00\x80", "\xC9\x30\x39", "\xC9\x7F\xFF", "\xCA\x80\x00\x00\x00", + "\xCA\xFF\x43\x9E\xB2", "\xCA\xFF\xFF\x7F\xFF", "\xCA\x00\x00\x80\x00", + "\xCA\x00\xBC\x61\x4E", "\xCA\x7F\xFF\xFF\xFF", + "\xCB\x80\x00\x00\x00\x00\x00\x00\x01", + "\xCB\xFF\xFF\xF4\xC5\x8C\x31\xA4\xA7", + "\xCB\xFF\xFF\xFF\xFF\x7F\xFF\xFF\xFF", + "\xCB\x00\x00\x00\x00\x80\x00\x00\x00", + "\xCB\x00\x00\x0B\x3A\x73\xCE\x5B\x59", + "\xCB\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF"}; +// clang-format on +const uint32_t int_output_len[] = {1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, + 3, 3, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9}; - void Flush() {} - - private: - TestSocket& socket_; -}; - - -const int64_t int_input[] = { 0, -1, -8, -16, 1, 63, 127, -128, -20, -17, -32768, -12345, -129, 128, 12345, 32767, -2147483648L, -12345678L, -32769L, 32768L, 12345678L, 2147483647L, -9223372036854775807L, -12345678912345L, -2147483649L, 2147483648L, 12345678912345L, 9223372036854775807 }; -const uint8_t int_output[][10] = { "\x00", "\xFF", "\xF8", "\xF0", "\x01", "\x3F", "\x7F", "\xC8\x80", "\xC8\xEC", "\xC8\xEF", "\xC9\x80\x00", "\xC9\xCF\xC7", "\xC9\xFF\x7F", "\xC9\x00\x80", "\xC9\x30\x39", "\xC9\x7F\xFF", "\xCA\x80\x00\x00\x00", "\xCA\xFF\x43\x9E\xB2", "\xCA\xFF\xFF\x7F\xFF", "\xCA\x00\x00\x80\x00", "\xCA\x00\xBC\x61\x4E", "\xCA\x7F\xFF\xFF\xFF", "\xCB\x80\x00\x00\x00\x00\x00\x00\x01", "\xCB\xFF\xFF\xF4\xC5\x8C\x31\xA4\xA7", "\xCB\xFF\xFF\xFF\xFF\x7F\xFF\xFF\xFF", "\xCB\x00\x00\x00\x00\x80\x00\x00\x00", "\xCB\x00\x00\x0B\x3A\x73\xCE\x5B\x59", "\xCB\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF" }; -const uint32_t int_output_len[] = { 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9 }; - - -const double double_input[] = { 5.834, 108.199, 43677.9882, 254524.5851 }; -const uint8_t double_output[][10] = { "\xC1\x40\x17\x56\x04\x18\x93\x74\xBC", "\xC1\x40\x5B\x0C\xBC\x6A\x7E\xF9\xDB", "\xC1\x40\xE5\x53\xBF\x9F\x55\x9B\x3D", "\xC1\x41\x0F\x11\xE4\xAE\x48\xE8\xA7"}; - - -const uint8_t vertexedge_output[] = "\xB1\x71\x93\xB3\x4E\x00\x92\x86\x6C\x61\x62\x65\x6C\x31\x86\x6C\x61\x62\x65\x6C\x32\xA2\x85\x70\x72\x6F\x70\x31\x0C\x85\x70\x72\x6F\x70\x32\xC9\x00\xC8\xB3\x4E\x00\x90\xA0\xB5\x52\x00\x00\x00\x88\x65\x64\x67\x65\x74\x79\x70\x65\xA2\x85\x70\x72\x6F\x70\x33\x2A\x85\x70\x72\x6F\x70\x34\xC9\x04\xD2"; +const double double_input[] = {5.834, 108.199, 43677.9882, 254524.5851}; +const uint8_t double_output[][10] = {"\xC1\x40\x17\x56\x04\x18\x93\x74\xBC", + "\xC1\x40\x5B\x0C\xBC\x6A\x7E\xF9\xDB", + "\xC1\x40\xE5\x53\xBF\x9F\x55\x9B\x3D", + "\xC1\x41\x0F\x11\xE4\xAE\x48\xE8\xA7"}; +const uint8_t vertexedge_output[] = + "\xB1\x71\x93\xB3\x4E\x00\x92\x86\x6C\x61\x62\x65\x6C\x31\x86\x6C\x61\x62" + "\x65\x6C\x32\xA2\x85\x70\x72\x6F\x70\x31\x0C\x85\x70\x72\x6F\x70\x32\xC9" + "\x00\xC8\xB3\x4E\x00\x90\xA0\xB5\x52\x00\x00\x00\x88\x65\x64\x67\x65\x74" + "\x79\x70\x65\xA2\x85\x70\x72\x6F\x70\x33\x2A\x85\x70\x72\x6F\x70\x34\xC9" + "\x04\xD2"; constexpr const int SIZE = 131072; uint8_t data[SIZE]; -const uint64_t sizes[] = { 0, 1, 5, 15, 16, 120, 255, 256, 12345, 65535, 65536, 100000 }; +const uint64_t sizes[] = {0, 1, 5, 15, 16, 120, + 255, 256, 12345, 65535, 65536, 100000}; const uint64_t sizes_num = 12; - constexpr const int STRING = 0, LIST = 1, MAP = 2; -const uint8_t type_tiny_magic[] = { 0x80, 0x90, 0xA0 }; -const uint8_t type_8_magic[] = { 0xD0, 0xD4, 0xD8 }; -const uint8_t type_16_magic[] = { 0xD1, 0xD5, 0xD9 }; -const uint8_t type_32_magic[] = { 0xD2, 0xD6, 0xDA }; +const uint8_t type_tiny_magic[] = {0x80, 0x90, 0xA0}; +const uint8_t type_8_magic[] = {0xD0, 0xD4, 0xD8}; +const uint8_t type_16_magic[] = {0xD1, 0xD5, 0xD9}; +const uint8_t type_32_magic[] = {0xD2, 0xD6, 0xDA}; -void check_type_size(std::vector<uint8_t>& v, int typ, uint64_t size) { +void CheckTypeSize(std::vector<uint8_t> &v, int typ, uint64_t size) { if (size <= 15) { uint8_t len = size; len &= 0x0F; len += type_tiny_magic[typ]; - check_output(v, &len, 1, false); + CheckOutput(v, &len, 1, false); } else if (size <= 255) { uint8_t len = size; - check_output(v, &type_8_magic[typ], 1, false); - check_output(v, &len, 1, false); + CheckOutput(v, &type_8_magic[typ], 1, false); + CheckOutput(v, &len, 1, false); } else if (size <= 65536) { uint16_t len = size; len = bswap(len); - check_output(v, &type_16_magic[typ], 1, false); - check_output(v, reinterpret_cast<const uint8_t*> (&len), 2, false); + CheckOutput(v, &type_16_magic[typ], 1, false); + CheckOutput(v, reinterpret_cast<const uint8_t *>(&len), 2, false); } else { uint32_t len = size; len = bswap(len); - check_output(v, &type_32_magic[typ], 1, false); - check_output(v, reinterpret_cast<const uint8_t*> (&len), 4, false); + CheckOutput(v, &type_32_magic[typ], 1, false); + CheckOutput(v, reinterpret_cast<const uint8_t *>(&len), 4, false); } } -void check_record_header(std::vector<uint8_t>& v, uint64_t size) { - check_output(v, (const uint8_t*) "\xB1\x71", 2, false); - check_type_size(v, LIST, size); +void CheckRecordHeader(std::vector<uint8_t> &v, uint64_t size) { + CheckOutput(v, (const uint8_t *)"\xB1\x71", 2, false); + CheckTypeSize(v, LIST, size); } - TestSocket socket(10); -communication::bolt::Encoder<TestBuffer, TestSocket> bolt_encoder(socket); -std::vector<uint8_t>& output = socket.output; - +TestBuffer encoder_buffer(socket); +communication::bolt::Encoder<TestBuffer> bolt_encoder(encoder_buffer); +std::vector<uint8_t> &output = socket.output; TEST(BoltEncoder, NullAndBool) { std::vector<TypedValue> vals; @@ -85,65 +96,62 @@ TEST(BoltEncoder, NullAndBool) { vals.push_back(TypedValue(true)); vals.push_back(TypedValue(false)); bolt_encoder.MessageRecord(vals); - check_record_header(output, 3); - check_output(output, (const uint8_t*) "\xC0\xC3\xC2", 3); + CheckRecordHeader(output, 3); + CheckOutput(output, (const uint8_t *)"\xC0\xC3\xC2", 3); } TEST(BoltEncoder, Int) { int N = 28; std::vector<TypedValue> vals; - for (int i = 0; i < N; ++i) - vals.push_back(TypedValue(int_input[i])); + for (int i = 0; i < N; ++i) vals.push_back(TypedValue(int_input[i])); bolt_encoder.MessageRecord(vals); - check_record_header(output, N); + CheckRecordHeader(output, N); for (int i = 0; i < N; ++i) - check_output(output, int_output[i], int_output_len[i], false); - check_output(output, nullptr, 0); + CheckOutput(output, int_output[i], int_output_len[i], false); + CheckOutput(output, nullptr, 0); } TEST(BoltEncoder, Double) { int N = 4; std::vector<TypedValue> vals; - for (int i = 0; i < N; ++i) - vals.push_back(TypedValue(double_input[i])); + for (int i = 0; i < N; ++i) vals.push_back(TypedValue(double_input[i])); bolt_encoder.MessageRecord(vals); - check_record_header(output, N); - for (int i = 0; i < N; ++i) - check_output(output, double_output[i], 9, false); - check_output(output, nullptr, 0); + CheckRecordHeader(output, N); + for (int i = 0; i < N; ++i) CheckOutput(output, double_output[i], 9, false); + CheckOutput(output, nullptr, 0); } TEST(BoltEncoder, String) { std::vector<TypedValue> vals; - for (int i = 0; i < sizes_num; ++i) - vals.push_back(TypedValue(std::string((const char*) data, sizes[i]))); + for (uint64_t i = 0; i < sizes_num; ++i) + vals.push_back(TypedValue(std::string((const char *)data, sizes[i]))); bolt_encoder.MessageRecord(vals); - check_record_header(output, vals.size()); - for (int i = 0; i < sizes_num; ++i) { - check_type_size(output, STRING, sizes[i]); - check_output(output, data, sizes[i], false); + CheckRecordHeader(output, vals.size()); + for (uint64_t i = 0; i < sizes_num; ++i) { + CheckTypeSize(output, STRING, sizes[i]); + CheckOutput(output, data, sizes[i], false); } - check_output(output, nullptr, 0); + CheckOutput(output, nullptr, 0); } TEST(BoltEncoder, List) { std::vector<TypedValue> vals; - for (int i = 0; i < sizes_num; ++i) { + for (uint64_t i = 0; i < sizes_num; ++i) { std::vector<TypedValue> val; - for (int j = 0; j < sizes[i]; ++j) - val.push_back(TypedValue(std::string((const char*) &data[j], 1))); + for (uint64_t j = 0; j < sizes[i]; ++j) + val.push_back(TypedValue(std::string((const char *)&data[j], 1))); vals.push_back(TypedValue(val)); } bolt_encoder.MessageRecord(vals); - check_record_header(output, vals.size()); - for (int i = 0; i < sizes_num; ++i) { - check_type_size(output, LIST, sizes[i]); - for (int j = 0; j < sizes[i]; ++j) { - check_type_size(output, STRING, 1); - check_output(output, &data[j], 1, false); + CheckRecordHeader(output, vals.size()); + for (uint64_t i = 0; i < sizes_num; ++i) { + CheckTypeSize(output, LIST, sizes[i]); + for (uint64_t j = 0; j < sizes[i]; ++j) { + CheckTypeSize(output, STRING, 1); + CheckOutput(output, &data[j], 1, false); } } - check_output(output, nullptr, 0); + CheckOutput(output, nullptr, 0); } TEST(BoltEncoder, Map) { @@ -152,25 +160,25 @@ TEST(BoltEncoder, Map) { for (int i = 0; i < sizes_num; ++i) { std::map<std::string, TypedValue> val; for (int j = 0; j < sizes[i]; ++j) { - sprintf((char*) buff, "%05X", j); - std::string tmp((char*) buff, 5); + sprintf((char *)buff, "%05X", j); + std::string tmp((char *)buff, 5); val.insert(std::make_pair(tmp, TypedValue(tmp))); } vals.push_back(TypedValue(val)); } bolt_encoder.MessageRecord(vals); - check_record_header(output, vals.size()); + CheckRecordHeader(output, vals.size()); for (int i = 0; i < sizes_num; ++i) { - check_type_size(output, MAP, sizes[i]); + CheckTypeSize(output, MAP, sizes[i]); for (int j = 0; j < sizes[i]; ++j) { - sprintf((char*) buff, "%05X", j); - check_type_size(output, STRING, 5); - check_output(output, buff, 5, false); - check_type_size(output, STRING, 5); - check_output(output, buff, 5, false); + sprintf((char *)buff, "%05X", j); + CheckTypeSize(output, STRING, 5); + CheckOutput(output, buff, 5, false); + CheckTypeSize(output, STRING, 5); + CheckOutput(output, buff, 5, false); } } - check_output(output, nullptr, 0); + CheckOutput(output, nullptr, 0); } TEST(BoltEncoder, VertexAndEdge) { @@ -201,7 +209,7 @@ TEST(BoltEncoder, VertexAndEdge) { vals.push_back(TypedValue(va2)); vals.push_back(TypedValue(ea)); bolt_encoder.MessageRecord(vals); - check_output(output, vertexedge_output, 74); + CheckOutput(output, vertexedge_output, 74); } TEST(BoltEncoder, BoltV1ExampleMessages) { @@ -211,7 +219,7 @@ TEST(BoltEncoder, BoltV1ExampleMessages) { std::vector<TypedValue> rvals; for (int i = 1; i < 4; ++i) rvals.push_back(TypedValue(i)); bolt_encoder.MessageRecord(rvals); - check_output(output, (const uint8_t*) "\xB1\x71\x93\x01\x02\x03", 6); + CheckOutput(output, (const uint8_t *)"\xB1\x71\x93\x01\x02\x03", 6); // success message std::string sv1("name"), sv2("age"), sk("fields"); @@ -222,26 +230,30 @@ TEST(BoltEncoder, BoltV1ExampleMessages) { std::map<std::string, TypedValue> svals; svals.insert(std::make_pair(sk, slist)); bolt_encoder.MessageSuccess(svals); - check_output(output, (const uint8_t*) "\xB1\x70\xA1\x86\x66\x69\x65\x6C\x64\x73\x92\x84\x6E\x61\x6D\x65\x83\x61\x67\x65", 20); + CheckOutput(output, + (const uint8_t *) "\xB1\x70\xA1\x86\x66\x69\x65\x6C\x64\x73\x92\x84\x6E\x61\x6D\x65\x83\x61\x67\x65", + 20); // failure message - std::string fv1("Neo.ClientError.Statement.SyntaxError"), fv2("Invalid syntax."); + std::string fv1("Neo.ClientError.Statement.SyntaxError"), + fv2("Invalid syntax."); std::string fk1("code"), fk2("message"); TypedValue ftv1(fv1), ftv2(fv2); std::map<std::string, TypedValue> fvals; fvals.insert(std::make_pair(fk1, ftv1)); fvals.insert(std::make_pair(fk2, ftv2)); bolt_encoder.MessageFailure(fvals); - check_output(output, (const uint8_t*) "\xB1\x7F\xA2\x84\x63\x6F\x64\x65\xD0\x25\x4E\x65\x6F\x2E\x43\x6C\x69\x65\x6E\x74\x45\x72\x72\x6F\x72\x2E\x53\x74\x61\x74\x65\x6D\x65\x6E\x74\x2E\x53\x79\x6E\x74\x61\x78\x45\x72\x72\x6F\x72\x87\x6D\x65\x73\x73\x61\x67\x65\x8F\x49\x6E\x76\x61\x6C\x69\x64\x20\x73\x79\x6E\x74\x61\x78\x2E", 71); + CheckOutput(output, + (const uint8_t *) "\xB1\x7F\xA2\x84\x63\x6F\x64\x65\xD0\x25\x4E\x65\x6F\x2E\x43\x6C\x69\x65\x6E\x74\x45\x72\x72\x6F\x72\x2E\x53\x74\x61\x74\x65\x6D\x65\x6E\x74\x2E\x53\x79\x6E\x74\x61\x78\x45\x72\x72\x6F\x72\x87\x6D\x65\x73\x73\x61\x67\x65\x8F\x49\x6E\x76\x61\x6C\x69\x64\x20\x73\x79\x6E\x74\x61\x78\x2E", + 71); // ignored message bolt_encoder.MessageIgnored(); - check_output(output, (const uint8_t*) "\xB0\x7E", 2); + CheckOutput(output, (const uint8_t *)"\xB0\x7E", 2); } - -int main(int argc, char** argv) { - initialize_data(data, SIZE); +int main(int argc, char **argv) { + InitializeData(data, SIZE); logging::init_sync(); logging::log->pipe(std::make_unique<Stdout>()); ::testing::InitGoogleTest(&argc, argv); diff --git a/tests/unit/bolt_result_stream.cpp b/tests/unit/bolt_result_stream.cpp index ed18c274c..002fa3c32 100644 --- a/tests/unit/bolt_result_stream.cpp +++ b/tests/unit/bolt_result_stream.cpp @@ -5,44 +5,54 @@ #include "communication/bolt/v1/encoder/result_stream.hpp" #include "query/backend/cpp/typed_value.hpp" +using BufferT = communication::bolt::ChunkedBuffer<TestSocket>; +using EncoderT = communication::bolt::Encoder<BufferT>; +using ResultStreamT = communication::bolt::ResultStream<EncoderT>; -using buffer_t = communication::bolt::ChunkedBuffer<TestSocket>; -using encoder_t = communication::bolt::Encoder<buffer_t, TestSocket>; -using result_stream_t = communication::bolt::ResultStream<TestSocket>; - - -const uint8_t header_output[] = "\x00\x29\xB1\x70\xA1\x86\x66\x69\x65\x6C\x64\x73\x9A\x82\x61\x61\x82\x62\x62\x82\x63\x63\x82\x64\x64\x82\x65\x65\x82\x66\x66\x82\x67\x67\x82\x68\x68\x82\x69\x69\x82\x6A\x6A\x00\x00"; -const uint8_t result_output[] = "\x00\x0A\xB1\x71\x92\x05\x85\x68\x65\x6C\x6C\x6F\x00\x00"; -const uint8_t summary_output[] = "\x00\x0C\xB1\x70\xA1\x87\x63\x68\x61\x6E\x67\x65\x64\x0A\x00\x00"; +/** + * TODO (mferencevic): document + */ +const uint8_t header_output[] = + "\x00\x29\xB1\x70\xA1\x86\x66\x69\x65\x6C\x64\x73\x9A\x82\x61\x61\x82\x62" + "\x62\x82\x63\x63\x82\x64\x64\x82\x65\x65\x82\x66\x66\x82\x67\x67\x82\x68" + "\x68\x82\x69\x69\x82\x6A\x6A\x00\x00"; +const uint8_t result_output[] = + "\x00\x0A\xB1\x71\x92\x05\x85\x68\x65\x6C\x6C\x6F\x00\x00"; +const uint8_t summary_output[] = + "\x00\x0C\xB1\x70\xA1\x87\x63\x68\x61\x6E\x67\x65\x64\x0A\x00\x00"; TEST(Bolt, ResultStream) { TestSocket socket(10); - encoder_t encoder(socket); - result_stream_t result_stream(encoder); - std::vector<uint8_t>& output = socket.output; + BufferT buffer(socket); + EncoderT encoder(buffer); + ResultStreamT result_stream(encoder); + std::vector<uint8_t> &output = socket.output; std::vector<std::string> headers; - for (int i = 0; i < 10; ++i) headers.push_back(std::string(2, (char)('a' + i))); + for (int i = 0; i < 10; ++i) + headers.push_back(std::string(2, (char)('a' + i))); - result_stream.Header(headers); - print_output(output); - check_output(output, header_output, 45); + result_stream.Header(headers); // this method automatically calls Flush + PrintOutput(output); + CheckOutput(output, header_output, 45); - std::vector<TypedValue> result{TypedValue(5), TypedValue(std::string("hello"))}; + std::vector<TypedValue> result{TypedValue(5), + TypedValue(std::string("hello"))}; result_stream.Result(result); - print_output(output); - check_output(output, result_output, 14); + buffer.Flush(); + PrintOutput(output); + CheckOutput(output, result_output, 14); std::map<std::string, TypedValue> summary; summary.insert(std::make_pair(std::string("changed"), TypedValue(10))); result_stream.Summary(summary); - print_output(output); - check_output(output, summary_output, 16); + buffer.Flush(); + PrintOutput(output); + CheckOutput(output, summary_output, 16); } - -int main(int argc, char** argv) { +int main(int argc, char **argv) { logging::init_sync(); logging::log->pipe(std::make_unique<Stdout>()); ::testing::InitGoogleTest(&argc, argv); diff --git a/tests/unit/bolt_session.cpp b/tests/unit/bolt_session.cpp index 772b5fa5f..b98c3e635 100644 --- a/tests/unit/bolt_session.cpp +++ b/tests/unit/bolt_session.cpp @@ -4,9 +4,14 @@ #include "communication/bolt/v1/session.hpp" #include "query/engine.hpp" -using result_stream_t = communication::bolt::ResultStream<TestSocket>; -using session_t = communication::bolt::Session<TestSocket>; +using ResultStreamT = + communication::bolt::ResultStream<communication::bolt::Encoder< + communication::bolt::ChunkedBuffer<TestSocket>>>; +using SessionT = communication::bolt::Session<TestSocket>; +/** + * TODO (mferencevic): document + */ const uint8_t handshake_req[] = "\x60\x60\xb0\x17\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" @@ -23,40 +28,40 @@ const uint8_t run_req[] = "\x61\x6d\x65\x3a\x20\x32\x39\x33\x38\x33\x7d\x29\x20\x52\x45\x54\x55\x52" "\x4e\x20\x6e\xa0\x00\x00"; - TEST(Bolt, Session) { Dbms dbms; TestSocket socket(10); - QueryEngine<result_stream_t> query_engine; - session_t session(std::move(socket), dbms, query_engine); - std::vector<uint8_t>& output = session.socket.output; + QueryEngine<ResultStreamT> query_engine; + SessionT session(std::move(socket), dbms, query_engine); + std::vector<uint8_t> &output = session.socket_.output; // execute handshake - session.execute(handshake_req, 20); - ASSERT_EQ(session.state, communication::bolt::INIT); - print_output(output); - check_output(output, handshake_resp, 4); + session.Execute(handshake_req, 20); + ASSERT_EQ(session.state_, communication::bolt::INIT); + PrintOutput(output); + CheckOutput(output, handshake_resp, 4); // execute init - session.execute(init_req, 67); - ASSERT_EQ(session.state, communication::bolt::EXECUTOR); - print_output(output); - check_output(output, init_resp, 7); + session.Execute(init_req, 67); + ASSERT_EQ(session.state_, communication::bolt::EXECUTOR); + PrintOutput(output); + CheckOutput(output, init_resp, 7); // execute run - session.execute(run_req, 42); - // TODO: query engine doesn't currently work, + session.Execute(run_req, 42); + + // TODO (mferencevic): query engine doesn't currently work, // we should test the query output here and the next state // ASSERT_EQ(session.state, bolt::EXECUTOR); - // print_output(output); - // check_output(output, run_resp, len); + // PrintOutput(output); + // CheckOutput(output, run_resp, len); - // TODO: add more tests + // TODO (mferencevic): add more tests - session.close(); + session.Close(); } -int main(int argc, char** argv) { +int main(int argc, char **argv) { logging::init_sync(); logging::log->pipe(std::make_unique<Stdout>()); ::testing::InitGoogleTest(&argc, argv); diff --git a/tests/unit/chunked_decoder.cpp b/tests/unit/chunked_decoder.cpp deleted file mode 100644 index 3bb66a666..000000000 --- a/tests/unit/chunked_decoder.cpp +++ /dev/null @@ -1,61 +0,0 @@ -#include <array> -#include <cassert> -#include <cstring> -#include <deque> -#include <iostream> -#include <vector> - -#include "gtest/gtest.h" - -#include "communication/bolt/v1/transport/chunked_decoder.hpp" - -using byte = unsigned char; - -void print_hex(byte x) { printf("%02X ", static_cast<byte>(x)); } - -struct DummyStream { - void write(const byte *values, size_t n) { - data.insert(data.end(), values, values + n); - } - - std::vector<byte> data; -}; - -using Decoder = communication::bolt::ChunkedDecoder<DummyStream>; - -std::vector<byte> chunks[] = { - {0x00, 0x08, 'A', ' ', 'q', 'u', 'i', 'c', 'k', ' ', 0x00, 0x06, 'b', 'r', - 'o', 'w', 'n', ' '}, - {0x00, 0x0A, 'f', 'o', 'x', ' ', 'j', 'u', 'm', 'p', 's', ' '}, - {0x00, 0x07, 'o', 'v', 'e', 'r', ' ', 'a', ' '}, - {0x00, 0x08, 'l', 'a', 'z', 'y', ' ', 'd', 'o', 'g', 0x00, 0x00}}; - -static constexpr size_t N = std::extent<decltype(chunks)>::value; - -std::string decoded = "A quick brown fox jumps over a lazy dog"; - -TEST(ChunkedDecoderTest, WriteString) { - DummyStream stream; - Decoder decoder(stream); - - for (size_t i = 0; i < N; ++i) { - auto &chunk = chunks[i]; - logging::info("Chunk size: {}", chunk.size()); - - const byte *start = chunk.data(); - auto finished = decoder.decode(start, chunk.size()); - - // break early if finished - if (finished) break; - } - - // check validity - ASSERT_EQ(decoded.size(), stream.data.size()); - for (size_t i = 0; i < decoded.size(); ++i) - ASSERT_EQ(decoded[i], stream.data[i]); -} - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -}