diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7ff9a85e7..cef839da0 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -11,6 +11,7 @@ add_subdirectory(communication) add_subdirectory(stats) add_subdirectory(auth) add_subdirectory(rpc) +add_subdirectory(slk) # ---------------------------------------------------------------------------- # Memgraph Single Node diff --git a/src/communication/CMakeLists.txt b/src/communication/CMakeLists.txt index aeffefcab..94aaec60c 100644 --- a/src/communication/CMakeLists.txt +++ b/src/communication/CMakeLists.txt @@ -25,5 +25,6 @@ add_custom_target(generate_communication_rpc_capnp DEPENDS ${communication_rpc_c add_library(mg-comm-rpc STATIC ${communication_rpc_src_files}) target_link_libraries(mg-comm-rpc Threads::Threads mg-communication mg-utils mg-io mg-rpc fmt glog gflags) target_link_libraries(mg-comm-rpc capnp kj) +target_link_libraries(mg-comm-rpc mg-slk) add_dependencies(mg-comm-rpc generate_communication_rpc_capnp) diff --git a/src/communication/rpc/streams.hpp b/src/communication/rpc/streams.hpp deleted file mode 100644 index 07199c076..000000000 --- a/src/communication/rpc/streams.hpp +++ /dev/null @@ -1,45 +0,0 @@ -#pragma once - -#include <cstdint> - -#include <glog/logging.h> - -namespace slk { - -// TODO (mferencevic): Implementations of the `Builder` and `Reader` are just -// mock implementations for now. They will be finished when they will be -// integrated into the RPC layer. 
- -class Builder { - public: - void Save(const uint8_t *data, uint64_t size) { - CHECK(size_ + size <= 262144); - memcpy(data_ + size_, data, size); - size_ += size; - } - - uint8_t *data() { return data_; } - uint64_t size() { return size_; } - - private: - uint8_t data_[262144]; - uint64_t size_{0}; -}; - -class Reader { - public: - Reader(const uint8_t *data, uint64_t size) : data_(data), size_(size) {} - - void Load(uint8_t *data, uint64_t size) { - CHECK(offset_ <= size_); - memcpy(data, data_ + offset_, size); - offset_ += size; - } - - private: - const uint8_t *data_; - uint64_t size_; - uint64_t offset_{0}; -}; - -} // namespace slk diff --git a/src/database/distributed/counters_rpc_messages.lcp b/src/database/distributed/counters_rpc_messages.lcp index 20a2466d2..3493a89ff 100644 --- a/src/database/distributed/counters_rpc_messages.lcp +++ b/src/database/distributed/counters_rpc_messages.lcp @@ -4,8 +4,8 @@ #include <string> #include "communication/rpc/messages.hpp" -#include "communication/rpc/serialization.hpp" #include "database/distributed/counters_rpc_messages.capnp.h" +#include "slk/serialization.hpp" cpp<# (lcp:namespace database) diff --git a/src/distributed/dynamic_worker_rpc_messages.lcp b/src/distributed/dynamic_worker_rpc_messages.lcp index 0242fb12d..00094955f 100644 --- a/src/distributed/dynamic_worker_rpc_messages.lcp +++ b/src/distributed/dynamic_worker_rpc_messages.lcp @@ -5,8 +5,8 @@ #include <string> #include "communication/rpc/messages.hpp" -#include "communication/rpc/serialization.hpp" #include "distributed/dynamic_worker_rpc_messages.capnp.h" +#include "slk/serialization.hpp" cpp<# (lcp:namespace distributed) diff --git a/src/distributed/storage_gc_rpc_messages.lcp b/src/distributed/storage_gc_rpc_messages.lcp index 1353f2311..877f7a97e 100644 --- a/src/distributed/storage_gc_rpc_messages.lcp +++ b/src/distributed/storage_gc_rpc_messages.lcp @@ -2,9 +2,9 @@ #pragma once #include "communication/rpc/messages.hpp" -#include 
"communication/rpc/serialization.hpp" #include "distributed/storage_gc_rpc_messages.capnp.h" #include "io/network/endpoint.hpp" +#include "slk/serialization.hpp" #include "transactions/transaction.hpp" cpp<# diff --git a/src/distributed/token_sharing_rpc_messages.lcp b/src/distributed/token_sharing_rpc_messages.lcp index a2ab4d042..3e88cab2e 100644 --- a/src/distributed/token_sharing_rpc_messages.lcp +++ b/src/distributed/token_sharing_rpc_messages.lcp @@ -5,8 +5,8 @@ #include <string> #include "communication/rpc/messages.hpp" -#include "communication/rpc/serialization.hpp" #include "distributed/token_sharing_rpc_messages.capnp.h" +#include "slk/serialization.hpp" cpp<# (lcp:namespace distributed) diff --git a/src/durability/distributed/serialization.hpp b/src/durability/distributed/serialization.hpp index 177741777..0365b82ca 100644 --- a/src/durability/distributed/serialization.hpp +++ b/src/durability/distributed/serialization.hpp @@ -1,9 +1,9 @@ #pragma once -#include "communication/rpc/serialization.hpp" #include "durability/distributed/recovery.hpp" #include "durability/distributed/serialization.capnp.h" #include "rpc/serialization.hpp" +#include "slk/serialization.hpp" namespace durability { diff --git a/src/io/network/serialization.hpp b/src/io/network/serialization.hpp index 43b492175..222c9bfcc 100644 --- a/src/io/network/serialization.hpp +++ b/src/io/network/serialization.hpp @@ -1,9 +1,8 @@ #pragma once -// TODO: SLK serialization should be its own thing -#include "communication/rpc/serialization.hpp" #include "io/network/endpoint.capnp.h" #include "io/network/endpoint.hpp" +#include "slk/serialization.hpp" namespace io::network { diff --git a/src/lisp/lcp.lisp b/src/lisp/lcp.lisp index 88540ec20..bce2eee7d 100644 --- a/src/lisp/lcp.lisp +++ b/src/lisp/lcp.lisp @@ -1559,7 +1559,7 @@ file." 
(with-open-file (out hpp-file :direction :output :if-exists :append) (terpri out) (write-line "// SLK serialization declarations" out) - (write-line "#include \"communication/rpc/serialization.hpp\"" out) + (write-line "#include \"slk/serialization.hpp\"" out) (with-namespaced-output (out open-namespace) (open-namespace '("slk")) (dolist (type-for-slk types-for-slk) diff --git a/src/raft/storage_info_rpc_messages.lcp b/src/raft/storage_info_rpc_messages.lcp index 9a20f0e32..96cd8a8ce 100644 --- a/src/raft/storage_info_rpc_messages.lcp +++ b/src/raft/storage_info_rpc_messages.lcp @@ -5,9 +5,9 @@ #include <string> #include "communication/rpc/messages.hpp" -#include "communication/rpc/serialization.hpp" #include "raft/storage_info_rpc_messages.capnp.h" #include "rpc/serialization.hpp" +#include "slk/serialization.hpp" cpp<# (lcp:namespace raft) diff --git a/src/slk/CMakeLists.txt b/src/slk/CMakeLists.txt new file mode 100644 index 000000000..14b526677 --- /dev/null +++ b/src/slk/CMakeLists.txt @@ -0,0 +1,6 @@ +set(slk_src_files + streams.cpp) + +add_library(mg-slk STATIC ${slk_src_files}) +target_link_libraries(mg-slk glog gflags) +target_link_libraries(mg-slk mg-utils) diff --git a/src/communication/rpc/serialization.hpp b/src/slk/serialization.hpp similarity index 92% rename from src/communication/rpc/serialization.hpp rename to src/slk/serialization.hpp index 644369ac6..eefd1a8ba 100644 --- a/src/communication/rpc/serialization.hpp +++ b/src/slk/serialization.hpp @@ -15,7 +15,7 @@ #include <utility> #include <vector> -#include "communication/rpc/streams.hpp" +#include "slk/streams.hpp" #include "utils/exceptions.hpp" // The namespace name stands for SaveLoadKit. 
It should be not mistaken for the @@ -133,30 +133,15 @@ MAKE_PRIMITIVE_LOAD(double) inline void Save(const std::string &obj, Builder *builder) { uint64_t size = obj.size(); - builder->Save(reinterpret_cast<const uint8_t *>(&size), sizeof(uint64_t)); + Save(size, builder); builder->Save(reinterpret_cast<const uint8_t *>(obj.data()), size); } inline void Load(std::string *obj, Reader *reader) { - const int kMaxStackBuffer = 8192; uint64_t size = 0; - reader->Load(reinterpret_cast<uint8_t *>(&size), sizeof(uint64_t)); - if (size < kMaxStackBuffer) { - // Here we use a temporary buffer on the stack to prevent temporary - // allocations. Most of strings that are decoded are small so it makes no - // sense to allocate a temporary buffer every time we decode a string. This - // way we allocate a temporary buffer only when the string is large. This - // wouldn't be necessary if we had full C++17 support. In C++17 we could - // preallocate the `buff[size]` in the destination string `*obj = - // std::string('\0', size)` and just call `reader->Load(obj->data())`. 
- char buff[kMaxStackBuffer]; - reader->Load(reinterpret_cast<uint8_t *>(buff), size); - *obj = std::string(buff, size); - } else { - auto buff = std::unique_ptr<char[]>(new char[size]); - reader->Load(reinterpret_cast<uint8_t *>(buff.get()), size); - *obj = std::string(buff.get(), size); - } + Load(&size, reader); + *obj = std::string(size, '\0'); + reader->Load(reinterpret_cast<uint8_t *>(obj->data()), size); } template <typename T> diff --git a/src/slk/streams.cpp b/src/slk/streams.cpp new file mode 100644 index 000000000..ed7cb822f --- /dev/null +++ b/src/slk/streams.cpp @@ -0,0 +1,135 @@ +#include "slk/streams.hpp" + +#include <cstring> + +#include <glog/logging.h> + +namespace slk { + +Builder::Builder(std::function<void(const uint8_t *, size_t, bool)> write_func) + : write_func_(write_func) {} + +void Builder::Save(const uint8_t *data, uint64_t size) { + size_t offset = 0; + while (size > 0) { + FlushSegment(false); + + size_t to_write = size; + if (to_write > kSegmentMaxDataSize - pos_) { + to_write = kSegmentMaxDataSize - pos_; + } + + memcpy(segment_ + sizeof(SegmentSize) + pos_, data + offset, to_write); + + size -= to_write; + pos_ += to_write; + + offset += to_write; + } +} + +void Builder::Finalize() { FlushSegment(true); } + +void Builder::FlushSegment(bool final_segment) { + if (!final_segment && pos_ < kSegmentMaxDataSize) return; + CHECK(pos_ > 0) << "Trying to flush out a segment that has no data in it!"; + + size_t total_size = sizeof(SegmentSize) + pos_; + + SegmentSize size = pos_; + memcpy(segment_, &size, sizeof(SegmentSize)); + + if (final_segment) { + SegmentSize footer = 0; + memcpy(segment_ + total_size, &footer, sizeof(SegmentSize)); + total_size += sizeof(SegmentSize); + } + + write_func_(segment_, total_size, !final_segment); + + pos_ = 0; +} + +Reader::Reader(const uint8_t *data, size_t size) : data_(data), size_(size) {} + +void Reader::Load(uint8_t *data, uint64_t size) { + size_t offset = 0; + while (size > 0) { + GetSegment(); + 
size_t to_read = size; + if (to_read > have_) { + to_read = have_; + } + memcpy(data + offset, data_ + pos_, to_read); + pos_ += to_read; + have_ -= to_read; + offset += to_read; + size -= to_read; + } +} + +void Reader::Finalize() { GetSegment(true); } + +void Reader::GetSegment(bool should_be_final) { + if (have_ != 0) { + if (should_be_final) { + throw SlkReaderException( + "There is still leftover data in the SLK stream!"); + } + return; + } + + // Load new segment. + SegmentSize len = 0; + if (pos_ + sizeof(SegmentSize) > size_) { + throw SlkReaderException("Size data missing in SLK stream!"); + } + memcpy(&len, data_ + pos_, sizeof(SegmentSize)); + pos_ += sizeof(SegmentSize); + + if (should_be_final && len != 0) { + throw SlkReaderException( + "Got a non-empty SLK segment when expecting the final segment!"); + } + if (!should_be_final && len == 0) { + throw SlkReaderException( + "Got an empty SLK segment when expecting a non-empty segment!"); + } + + if (pos_ + len > size_) { + throw SlkReaderException("There isn't enough data in the SLK stream!"); + } + have_ = len; +} + +StreamInfo CheckStreamComplete(const uint8_t *data, size_t size) { + size_t found_segments = 0; + size_t data_size = 0; + + size_t pos = 0; + while (true) { + SegmentSize len = 0; + if (pos + sizeof(SegmentSize) > size) { + return {StreamStatus::PARTIAL, pos + kSegmentMaxTotalSize, data_size}; + } + memcpy(&len, data + pos, sizeof(SegmentSize)); + pos += sizeof(SegmentSize); + if (len == 0) { + break; + } + + if (pos + len > size) { + return {StreamStatus::PARTIAL, pos + kSegmentMaxTotalSize, data_size}; + } + pos += len; + + ++found_segments; + data_size += len; + } + if (found_segments < 1) { + return {StreamStatus::INVALID, 0, 0}; + } + return {StreamStatus::COMPLETE, pos, data_size}; +} + +} // namespace slk diff --git a/src/slk/streams.hpp b/src/slk/streams.hpp new file mode 100644 index 000000000..a9154ccd6 --- /dev/null +++ b/src/slk/streams.hpp @@ -0,0 +1,112 @@ +#pragma once + 
+#include <cstdint> +#include <functional> +#include <limits> + +#include "utils/exceptions.hpp" + +namespace slk { + +using SegmentSize = uint32_t; + +// The maximum allowed size of a segment is set to `kSegmentMaxDataSize`. The +// value of 256 KiB was chosen so that the segment buffer will always fit on the +// stack (it mustn't be too large) and that it isn't too small so that most SLK +// messages fit into a single segment. +const uint64_t kSegmentMaxDataSize = 262144; +const uint64_t kSegmentMaxTotalSize = + kSegmentMaxDataSize + sizeof(SegmentSize) + sizeof(SegmentSize); + +static_assert( + kSegmentMaxDataSize <= std::numeric_limits<SegmentSize>::max(), + "The SLK segment can't be larger than the type used to store its size!"); + +/// SLK splits binary data into segments. Segments are used to avoid the need to +/// have all of the encoded data in memory at once during the building process. +/// That enables streaming during the building process and makes the whole +/// process make zero memory allocations because only one static buffer is used. +/// During the reading process you must have all of the data in memory. +/// +/// SLK segments are just chunks of binary data. They start with a `size` field +/// and then are followed by the binary data itself. The segments have a maximum +/// size of `kSegmentMaxDataSize`. The `size` field itself has a size of +/// `sizeof(SegmentSize)`. A segment of size 0 indicates that we have reached +/// the end of a stream and that there is no more data to be read/written. + +/// Builder used to create a SLK segment stream. +class Builder { + public: + Builder(std::function<void(const uint8_t *, size_t, bool)> write_func); + + /// Function used internally by SLK to serialize the data. + void Save(const uint8_t *data, uint64_t size); + + /// Function that should be called after all `slk::Save` operations are done. 
+ void Finalize(); + + private: + void FlushSegment(bool final_segment); + + std::function<void(const uint8_t *, size_t, bool)> write_func_; + size_t pos_{0}; + uint8_t segment_[kSegmentMaxTotalSize]; +}; + +/// Exception that will be thrown if segments can't be decoded from the byte +/// stream. +class SlkReaderException : public utils::BasicException { + public: + using utils::BasicException::BasicException; +}; + +/// Reader used to read data from a SLK segment stream. +class Reader { + public: + Reader(const uint8_t *data, size_t size); + + /// Function used internally by SLK to deserialize the data. + void Load(uint8_t *data, uint64_t size); + + /// Function that should be called after all `slk::Load` operations are done. + void Finalize(); + + private: + void GetSegment(bool should_be_final = false); + + const uint8_t *data_; + size_t size_; + + size_t pos_{0}; + size_t have_{0}; +}; + +/// Stream status that is returned by the `CheckStreamComplete` function. +enum class StreamStatus { + PARTIAL, + COMPLETE, + INVALID, +}; + +/// Stream information returned by the `CheckStreamComplete` function. +struct StreamInfo { + StreamStatus status; + size_t stream_size; + size_t encoded_data_size; +}; + +/// This function checks the binary stream to see whether it contains a fully +/// received SLK segment stream. The function returns a `StreamInfo` structure +/// that has three members. The `status` member indicates in which state is the +/// received data (partially received, completely received or invalid), the +/// `stream_size` member indicates the size of the data stream (see NOTE) and +/// the `encoded_data_size` member indicates the size of the SLK encoded data in +/// the stream (so far). +/// NOTE: If the stream is partial, the size of the data stream returned will +/// not be the exact size of the received data. It will be a maximum expected +/// size of the data stream. 
It is used to indicate to the network stack how +/// much data it should receive before it makes sense to retry decoding of the +/// segment data. +StreamInfo CheckStreamComplete(const uint8_t *data, size_t size); + +} // namespace slk diff --git a/src/stats/stats_rpc_messages.lcp b/src/stats/stats_rpc_messages.lcp index 7c774b0b5..759aae035 100644 --- a/src/stats/stats_rpc_messages.lcp +++ b/src/stats/stats_rpc_messages.lcp @@ -2,7 +2,7 @@ #pragma once #include "communication/rpc/messages.hpp" -#include "communication/rpc/serialization.hpp" +#include "slk/serialization.hpp" #include "stats/stats_rpc_messages.capnp.h" #include "rpc/serialization.hpp" #include "utils/timestamp.hpp" diff --git a/src/storage/common/types/slk.hpp b/src/storage/common/types/slk.hpp index 0745dddbe..f74f33af8 100644 --- a/src/storage/common/types/slk.hpp +++ b/src/storage/common/types/slk.hpp @@ -1,7 +1,6 @@ #pragma once -#include "communication/rpc/serialization.hpp" -#include "communication/rpc/streams.hpp" +#include "slk/serialization.hpp" #include "storage/common/types/property_value.hpp" #include "storage/common/types/property_value_store.hpp" #include "storage/common/types/types.hpp" diff --git a/src/storage/distributed/rpc/serialization.hpp b/src/storage/distributed/rpc/serialization.hpp index f4fc32f03..09a9cc589 100644 --- a/src/storage/distributed/rpc/serialization.hpp +++ b/src/storage/distributed/rpc/serialization.hpp @@ -1,6 +1,6 @@ #pragma once -#include "communication/rpc/serialization.hpp" +#include "slk/serialization.hpp" #include "storage/common/types/property_value.hpp" #include "storage/common/types/property_value_store.hpp" #include "storage/common/types/slk.hpp" diff --git a/src/storage/single_node_ha/rpc/serialization.hpp b/src/storage/single_node_ha/rpc/serialization.hpp index 80969ae3f..e88473624 100644 --- a/src/storage/single_node_ha/rpc/serialization.hpp +++ b/src/storage/single_node_ha/rpc/serialization.hpp @@ -1,6 +1,6 @@ #pragma once -#include 
"communication/rpc/serialization.hpp" +#include "slk/serialization.hpp" #include "storage/common/types/property_value.hpp" #include "storage/common/types/slk.hpp" #include "storage/common/types/types.hpp" diff --git a/src/transactions/distributed/engine_rpc_messages.lcp b/src/transactions/distributed/engine_rpc_messages.lcp index e439bcc96..12208096c 100644 --- a/src/transactions/distributed/engine_rpc_messages.lcp +++ b/src/transactions/distributed/engine_rpc_messages.lcp @@ -2,7 +2,7 @@ #pragma once #include "communication/rpc/messages.hpp" -#include "communication/rpc/serialization.hpp" +#include "slk/serialization.hpp" #include "transactions/commit_log.hpp" #include "transactions/distributed/engine_rpc_messages.capnp.h" #include "transactions/snapshot.hpp" diff --git a/tests/benchmark/serialization.cpp b/tests/benchmark/serialization.cpp index 02c39a4ec..2bf57e825 100644 --- a/tests/benchmark/serialization.cpp +++ b/tests/benchmark/serialization.cpp @@ -6,9 +6,9 @@ #include <capnp/serialize.h> #include <kj/std/iostream.h> -#include "communication/rpc/serialization.hpp" -#include "query/frontend/semantic/symbol.hpp" #include "query/distributed/frontend/semantic/symbol_serialization.hpp" +#include "query/frontend/semantic/symbol.hpp" +#include "slk/serialization.hpp" class SymbolVectorFixture : public benchmark::Fixture { protected: @@ -96,16 +96,16 @@ BENCHMARK_REGISTER_F(SymbolVectorFixture, CapnpDeserial) ->Range(4, 1 << 12) ->Unit(benchmark::kNanosecond); -void SymbolVectorToCustom(const std::vector<query::Symbol> &symbols, - slk::Builder *builder) { +void SymbolVectorToSlk(const std::vector<query::Symbol> &symbols, + slk::Builder *builder) { slk::Save(symbols.size(), builder); for (int i = 0; i < symbols.size(); ++i) { slk::Save(symbols[i], builder); } } -void CustomToSymbolVector(std::vector<query::Symbol> *symbols, - slk::Reader *reader) { +void SlkToSymbolVector(std::vector<query::Symbol> *symbols, + slk::Reader *reader) { uint64_t size = 0; 
slk::Load(&size, reader); symbols->resize(size); @@ -114,32 +114,39 @@ void CustomToSymbolVector(std::vector<query::Symbol> *symbols, } } -BENCHMARK_DEFINE_F(SymbolVectorFixture, CustomSerial)(benchmark::State &state) { +BENCHMARK_DEFINE_F(SymbolVectorFixture, SlkSerial)(benchmark::State &state) { while (state.KeepRunning()) { - slk::Builder builder; - SymbolVectorToCustom(symbols_, &builder); + slk::Builder builder([](const uint8_t *, size_t, bool) {}); + SymbolVectorToSlk(symbols_, &builder); + builder.Finalize(); } state.SetItemsProcessed(state.iterations()); } -BENCHMARK_DEFINE_F(SymbolVectorFixture, CustomDeserial) +BENCHMARK_DEFINE_F(SymbolVectorFixture, SlkDeserial) (benchmark::State &state) { - slk::Builder builder; - SymbolVectorToCustom(symbols_, &builder); + std::vector<uint8_t> encoded; + slk::Builder builder( + [&encoded](const uint8_t *data, size_t size, bool have_more) { + for (size_t i = 0; i < size; ++i) encoded.push_back(data[i]); + }); + SymbolVectorToSlk(symbols_, &builder); + builder.Finalize(); + while (state.KeepRunning()) { - slk::Reader reader(builder.data(), builder.size()); + slk::Reader reader(encoded.data(), encoded.size()); std::vector<query::Symbol> symbols; - CustomToSymbolVector(&symbols, &reader); + SlkToSymbolVector(&symbols, &reader); } state.SetItemsProcessed(state.iterations()); } -BENCHMARK_REGISTER_F(SymbolVectorFixture, CustomSerial) +BENCHMARK_REGISTER_F(SymbolVectorFixture, SlkSerial) ->RangeMultiplier(4) ->Range(4, 1 << 12) ->Unit(benchmark::kNanosecond); -BENCHMARK_REGISTER_F(SymbolVectorFixture, CustomDeserial) +BENCHMARK_REGISTER_F(SymbolVectorFixture, SlkDeserial) ->RangeMultiplier(4) ->Range(4, 1 << 12) ->Unit(benchmark::kNanosecond); diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 87e89d9ec..f8e8df5f8 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -250,9 +250,11 @@ target_link_libraries(${test_prefix}skiplist_suffix mg-single-node kvstore_dummy 
add_unit_test(slk_advanced.cpp) target_link_libraries(${test_prefix}slk_advanced mg-distributed kvstore_dummy_lib) -# TODO (mferencevic): remove glog, gflags and mg-single-node add_unit_test(slk_core.cpp) -target_link_libraries(${test_prefix}slk_core glog gflags) +target_link_libraries(${test_prefix}slk_core mg-slk glog gflags fmt) + +add_unit_test(slk_streams.cpp) +target_link_libraries(${test_prefix}slk_streams mg-slk glog gflags fmt) add_unit_test(small_vector.cpp) target_link_libraries(${test_prefix}small_vector mg-utils) diff --git a/tests/unit/ast_serialization.cpp b/tests/unit/ast_serialization.cpp index 811d9cee8..69b83137d 100644 --- a/tests/unit/ast_serialization.cpp +++ b/tests/unit/ast_serialization.cpp @@ -10,13 +10,15 @@ #include <gmock/gmock.h> #include <gtest/gtest.h> -#include "communication/rpc/serialization.hpp" -#include "query/frontend/ast/ast.hpp" #include "query/distributed/frontend/ast/ast_serialization.hpp" +#include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/cypher_main_visitor.hpp" #include "query/frontend/opencypher/parser.hpp" #include "query/frontend/stripped.hpp" #include "query/typed_value.hpp" +#include "slk/serialization.hpp" + +#include "slk_common.hpp" namespace { @@ -117,12 +119,13 @@ class SlkAstGenerator : public Base { CypherMainVisitor visitor(context_, &tmp_storage); visitor.visit(parser.tree()); - slk::Builder builder; - { SaveAstPointer(visitor.query(), &builder); } + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + { SaveAstPointer(visitor.query(), builder); } { - slk::Reader reader(builder.data(), builder.size()); - return LoadAstPointer<Query>(&storage_, &reader); + auto reader = loopback.GetReader(); + return LoadAstPointer<Query>(&storage_, reader); } } diff --git a/tests/unit/slk_advanced.cpp b/tests/unit/slk_advanced.cpp index 2234d0cbe..502588a50 100644 --- a/tests/unit/slk_advanced.cpp +++ b/tests/unit/slk_advanced.cpp @@ -2,6 +2,8 @@ #include "storage/common/types/slk.hpp" 
+#include "slk_common.hpp" + TEST(SlkAdvanced, PropertyValueList) { std::vector<PropertyValue> original{"hello world!", 5, 1.123423, true, PropertyValue()}; @@ -11,12 +13,13 @@ TEST(SlkAdvanced, PropertyValueList) { ASSERT_EQ(original[3].type(), PropertyValue::Type::Bool); ASSERT_EQ(original[4].type(), PropertyValue::Type::Null); - slk::Builder builder; - slk::Save(original, &builder); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); std::vector<PropertyValue> decoded; - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(original, decoded); } @@ -33,12 +36,13 @@ TEST(SlkAdvanced, PropertyValueMap) { ASSERT_EQ(original["truth"].type(), PropertyValue::Type::Bool); ASSERT_EQ(original["nothing"].type(), PropertyValue::Type::Null); - slk::Builder builder; - slk::Save(original, &builder); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); std::map<std::string, PropertyValue> decoded; - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(original, decoded); } @@ -66,12 +70,13 @@ TEST(SlkAdvanced, PropertyValueComplex) { PropertyValue original({vec_v, map_v}); ASSERT_EQ(original.type(), PropertyValue::Type::List); - slk::Builder builder; - slk::Save(original, &builder); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); PropertyValue decoded; - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(original, decoded); } diff --git a/tests/unit/slk_common.hpp b/tests/unit/slk_common.hpp new file mode 100644 index 000000000..3c0d568d1 --- /dev/null +++ b/tests/unit/slk_common.hpp @@ -0,0 +1,78 @@ +#pragma once 
+ +#include <cstdint> +#include <iostream> +#include <memory> +#include <vector> + +#include <fmt/format.h> +#include <glog/logging.h> +#include <gtest/gtest.h> + +#include "slk/streams.hpp" + +namespace slk { + +/// Class used for SLK tests. It creates a `slk::Builder` that can be written +/// to. After you have written the data to the builder, you can get a +/// `slk::Reader` and try to decode the encoded data. +class Loopback { + public: + ~Loopback() { + CHECK(builder_) << "You haven't created a builder!"; + CHECK(reader_) << "You haven't created a reader!"; + reader_->Finalize(); + } + + slk::Builder *GetBuilder() { + CHECK(!builder_) << "You have already allocated a builder!"; + builder_ = std::make_unique<slk::Builder>( + [this](const uint8_t *data, size_t size, bool have_more) { + Write(data, size, have_more); + }); + return builder_.get(); + } + + slk::Reader *GetReader() { + CHECK(builder_) << "You must first get a builder before getting a reader!"; + CHECK(!reader_) << "You have already allocated a reader!"; + builder_->Finalize(); + auto ret = slk::CheckStreamComplete(data_.data(), data_.size()); + CHECK(ret.status == slk::StreamStatus::COMPLETE); + CHECK(ret.stream_size == data_.size()); + size_ = ret.encoded_data_size; + Dump(); + reader_ = std::make_unique<slk::Reader>(data_.data(), data_.size()); + return reader_.get(); + } + + size_t size() { return size_; } + + private: + void Write(const uint8_t *data, size_t size, bool have_more) { + for (size_t i = 0; i < size; ++i) { + data_.push_back(data[i]); + } + } + + void Dump() { + std::string dump; + for (size_t i = 0; i < data_.size(); ++i) { + dump += fmt::format("{:02x}", data_[i]); + if (i != data_.size() - 1) { + dump += " "; + } + } + // This stores the encoded SLK stream into the test XML output. To get the + // data you have to specify to the test (during runtime) that it should + // create an XML output. 
+ ::testing::Test::RecordProperty("slk_stream", dump); + } + + std::vector<uint8_t> data_; + std::unique_ptr<slk::Builder> builder_; + std::unique_ptr<slk::Reader> reader_; + size_t size_{0}; +}; + +} // namespace slk diff --git a/tests/unit/slk_core.cpp b/tests/unit/slk_core.cpp index 73dd8807e..e8fa5d9c0 100644 --- a/tests/unit/slk_core.cpp +++ b/tests/unit/slk_core.cpp @@ -1,18 +1,21 @@ #include <gtest/gtest.h> -#include "communication/rpc/serialization.hpp" +#include "slk/serialization.hpp" + +#include "slk_common.hpp" #define CREATE_PRIMITIVE_TEST(primitive_type, original_value, decoded_value) \ { \ ASSERT_NE(original_value, decoded_value); \ primitive_type original = original_value; \ - slk::Builder builder; \ - slk::Save(original, &builder); \ - ASSERT_EQ(builder.size(), sizeof(primitive_type)); \ + slk::Loopback loopback; \ + auto builder = loopback.GetBuilder(); \ + slk::Save(original, builder); \ primitive_type decoded = decoded_value; \ - slk::Reader reader(builder.data(), builder.size()); \ - slk::Load(&decoded, &reader); \ + auto reader = loopback.GetReader(); \ + slk::Load(&decoded, reader); \ ASSERT_EQ(original, decoded); \ + ASSERT_EQ(loopback.size(), sizeof(primitive_type)); \ } TEST(SlkCore, Primitive) { @@ -31,248 +34,267 @@ TEST(SlkCore, Primitive) { TEST(SlkCore, String) { std::string original = "hello world"; - slk::Builder builder; - slk::Save(original, &builder); - ASSERT_EQ(builder.size(), sizeof(uint64_t) + original.size()); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); std::string decoded; - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(original, decoded); + ASSERT_EQ(loopback.size(), sizeof(uint64_t) + original.size()); } TEST(SlkCore, VectorPrimitive) { std::vector<int> original{1, 2, 3, 4, 5}; - slk::Builder builder; - slk::Save(original, &builder); - 
ASSERT_EQ(builder.size(), sizeof(uint64_t) + original.size() * sizeof(int)); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); std::vector<int> decoded; - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(original, decoded); + ASSERT_EQ(loopback.size(), sizeof(uint64_t) + original.size() * sizeof(int)); } TEST(SlkCore, VectorString) { std::vector<std::string> original{"hai hai hai", "nandare!"}; - slk::Builder builder; - slk::Save(original, &builder); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); uint64_t size = sizeof(uint64_t); for (const auto &item : original) { size += sizeof(uint64_t) + item.size(); } - ASSERT_EQ(builder.size(), size); std::vector<std::string> decoded; - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(original, decoded); + ASSERT_EQ(loopback.size(), size); } TEST(SlkCore, SetPrimitive) { std::set<int> original{1, 2, 3, 4, 5}; - slk::Builder builder; - slk::Save(original, &builder); - ASSERT_EQ(builder.size(), sizeof(uint64_t) + original.size() * sizeof(int)); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); std::set<int> decoded; - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(original, decoded); + ASSERT_EQ(loopback.size(), sizeof(uint64_t) + original.size() * sizeof(int)); } TEST(SlkCore, SetString) { std::set<std::string> original{"hai hai hai", "nandare!"}; - slk::Builder builder; - slk::Save(original, &builder); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); uint64_t size = sizeof(uint64_t); for (const 
auto &item : original) { size += sizeof(uint64_t) + item.size(); } - ASSERT_EQ(builder.size(), size); std::set<std::string> decoded; - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(original, decoded); + ASSERT_EQ(loopback.size(), size); } TEST(SlkCore, MapPrimitive) { std::map<int, int> original{{1, 2}, {3, 4}, {5, 6}}; - slk::Builder builder; - slk::Save(original, &builder); - ASSERT_EQ(builder.size(), - sizeof(uint64_t) + original.size() * sizeof(int) * 2); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); std::map<int, int> decoded; - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(original, decoded); + ASSERT_EQ(loopback.size(), + sizeof(uint64_t) + original.size() * sizeof(int) * 2); } TEST(SlkCore, MapString) { std::map<std::string, std::string> original{{"hai hai hai", "nandare!"}}; - slk::Builder builder; - slk::Save(original, &builder); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); uint64_t size = sizeof(uint64_t); for (const auto &item : original) { size += sizeof(uint64_t) + item.first.size(); size += sizeof(uint64_t) + item.second.size(); } - ASSERT_EQ(builder.size(), size); std::map<std::string, std::string> decoded; - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(original, decoded); + ASSERT_EQ(loopback.size(), size); } TEST(SlkCore, UnorderedMapPrimitive) { std::unordered_map<int, int> original{{1, 2}, {3, 4}, {5, 6}}; - slk::Builder builder; - slk::Save(original, &builder); - ASSERT_EQ(builder.size(), - sizeof(uint64_t) + original.size() * sizeof(int) * 2); + slk::Loopback loopback; + auto builder = 
loopback.GetBuilder(); + slk::Save(original, builder); std::unordered_map<int, int> decoded; - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(original, decoded); + ASSERT_EQ(loopback.size(), + sizeof(uint64_t) + original.size() * sizeof(int) * 2); } TEST(SlkCore, UnorderedMapString) { std::unordered_map<std::string, std::string> original{ {"hai hai hai", "nandare!"}}; - slk::Builder builder; - slk::Save(original, &builder); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); uint64_t size = sizeof(uint64_t); for (const auto &item : original) { size += sizeof(uint64_t) + item.first.size(); size += sizeof(uint64_t) + item.second.size(); } - ASSERT_EQ(builder.size(), size); std::unordered_map<std::string, std::string> decoded; - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(original, decoded); + ASSERT_EQ(loopback.size(), size); } TEST(SlkCore, UniquePtrEmpty) { std::unique_ptr<std::string> original; - slk::Builder builder; - slk::Save(original, &builder); - ASSERT_EQ(builder.size(), sizeof(bool)); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); std::unique_ptr<std::string> decoded = std::make_unique<std::string>("nandare!"); ASSERT_NE(decoded.get(), nullptr); - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(decoded.get(), nullptr); + ASSERT_EQ(loopback.size(), sizeof(bool)); } TEST(SlkCore, UniquePtrFull) { std::unique_ptr<std::string> original = std::make_unique<std::string>("nandare!"); - slk::Builder builder; - slk::Save(original, &builder); - ASSERT_EQ(builder.size(), - sizeof(bool) + sizeof(uint64_t) + 
original.get()->size()); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); std::unique_ptr<std::string> decoded; ASSERT_EQ(decoded.get(), nullptr); - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_NE(decoded.get(), nullptr); ASSERT_EQ(*original.get(), *decoded.get()); + ASSERT_EQ(loopback.size(), + sizeof(bool) + sizeof(uint64_t) + original.get()->size()); } TEST(SlkCore, OptionalPrimitiveEmpty) { std::optional<int> original; - slk::Builder builder; - slk::Save(original, &builder); - ASSERT_EQ(builder.size(), sizeof(bool)); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); std::optional<int> decoded = 5; ASSERT_NE(decoded, std::nullopt); - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(decoded, std::nullopt); + ASSERT_EQ(loopback.size(), sizeof(bool)); } TEST(SlkCore, OptionalPrimitiveFull) { std::optional<int> original = 5; - slk::Builder builder; - slk::Save(original, &builder); - ASSERT_EQ(builder.size(), sizeof(bool) + sizeof(int)); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); std::optional<int> decoded; ASSERT_EQ(decoded, std::nullopt); - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_NE(decoded, std::nullopt); ASSERT_EQ(*original, *decoded); + ASSERT_EQ(loopback.size(), sizeof(bool) + sizeof(int)); } TEST(SlkCore, OptionalStringEmpty) { std::optional<std::string> original; - slk::Builder builder; - slk::Save(original, &builder); - ASSERT_EQ(builder.size(), sizeof(bool)); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); 
std::optional<std::string> decoded = "nandare!"; ASSERT_NE(decoded, std::nullopt); - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(decoded, std::nullopt); + ASSERT_EQ(loopback.size(), sizeof(bool)); } TEST(SlkCore, OptionalStringFull) { std::optional<std::string> original = "nandare!"; - slk::Builder builder; - slk::Save(original, &builder); - ASSERT_EQ(builder.size(), sizeof(bool) + sizeof(uint64_t) + original->size()); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); std::optional<std::string> decoded; ASSERT_EQ(decoded, std::nullopt); - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_NE(decoded, std::nullopt); ASSERT_EQ(*original, *decoded); + ASSERT_EQ(loopback.size(), + sizeof(bool) + sizeof(uint64_t) + original->size()); } TEST(SlkCore, Pair) { std::pair<std::string, int> original{"nandare!", 5}; - slk::Builder builder; - slk::Save(original, &builder); - ASSERT_EQ(builder.size(), - sizeof(uint64_t) + original.first.size() + sizeof(int)); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); std::pair<std::string, int> decoded; - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(original, decoded); + ASSERT_EQ(loopback.size(), + sizeof(uint64_t) + original.first.size() + sizeof(int)); } TEST(SlkCore, SharedPtrEmpty) { std::shared_ptr<std::string> original; std::vector<std::string *> saved; - slk::Builder builder; - slk::Save(original, &builder, &saved); - ASSERT_EQ(builder.size(), sizeof(bool)); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder, &saved); std::shared_ptr<std::string> 
decoded = std::make_shared<std::string>("nandare!"); std::vector<std::shared_ptr<std::string>> loaded; ASSERT_NE(decoded.get(), nullptr); - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader, &loaded); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader, &loaded); ASSERT_EQ(decoded.get(), nullptr); ASSERT_EQ(saved.size(), 0); ASSERT_EQ(loaded.size(), 0); + ASSERT_EQ(loopback.size(), sizeof(bool)); } TEST(SlkCore, SharedPtrFull) { std::shared_ptr<std::string> original = std::make_shared<std::string>("nandare!"); std::vector<std::string *> saved; - slk::Builder builder; - slk::Save(original, &builder, &saved); - ASSERT_EQ(builder.size(), - sizeof(bool) * 2 + sizeof(uint64_t) + original.get()->size()); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder, &saved); std::shared_ptr<std::string> decoded; std::vector<std::shared_ptr<std::string>> loaded; ASSERT_EQ(decoded.get(), nullptr); - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader, &loaded); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader, &loaded); ASSERT_NE(decoded.get(), nullptr); ASSERT_EQ(*original.get(), *decoded.get()); ASSERT_EQ(saved.size(), 1); ASSERT_EQ(loaded.size(), 1); + ASSERT_EQ(loopback.size(), + sizeof(bool) * 2 + sizeof(uint64_t) + original.get()->size()); } TEST(SlkCore, SharedPtrMultiple) { @@ -282,31 +304,23 @@ TEST(SlkCore, SharedPtrMultiple) { std::make_shared<std::string>("hai hai hai"); std::vector<std::string *> saved; - slk::Builder builder; - slk::Save(ptr1, &builder, &saved); - slk::Save(ptr2, &builder, &saved); - slk::Save(ptr3, &builder, &saved); - slk::Save(ptr1, &builder, &saved); - slk::Save(ptr3, &builder, &saved); - - // clang-format off - ASSERT_EQ(builder.size(), - sizeof(bool) * 2 + sizeof(uint64_t) + ptr1.get()->size() + - sizeof(bool) + - sizeof(bool) * 2 + sizeof(uint64_t) + ptr3.get()->size() + - sizeof(bool) * 2 + sizeof(uint64_t) 
+ - sizeof(bool) * 2 + sizeof(uint64_t)); - // clang-format on + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(ptr1, builder, &saved); + slk::Save(ptr2, builder, &saved); + slk::Save(ptr3, builder, &saved); + slk::Save(ptr1, builder, &saved); + slk::Save(ptr3, builder, &saved); std::shared_ptr<std::string> dec1, dec2, dec3, dec4, dec5; std::vector<std::shared_ptr<std::string>> loaded; - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&dec1, &reader, &loaded); - slk::Load(&dec2, &reader, &loaded); - slk::Load(&dec3, &reader, &loaded); - slk::Load(&dec4, &reader, &loaded); - slk::Load(&dec5, &reader, &loaded); + auto reader = loopback.GetReader(); + slk::Load(&dec1, reader, &loaded); + slk::Load(&dec2, reader, &loaded); + slk::Load(&dec3, reader, &loaded); + slk::Load(&dec4, reader, &loaded); + slk::Load(&dec5, reader, &loaded); ASSERT_EQ(saved.size(), 2); ASSERT_EQ(loaded.size(), 2); @@ -326,6 +340,37 @@ TEST(SlkCore, SharedPtrMultiple) { ASSERT_EQ(dec4.get(), dec1.get()); ASSERT_EQ(dec5.get(), dec3.get()); + + // clang-format off + ASSERT_EQ(loopback.size(), + sizeof(bool) * 2 + sizeof(uint64_t) + ptr1.get()->size() + + sizeof(bool) + + sizeof(bool) * 2 + sizeof(uint64_t) + ptr3.get()->size() + + sizeof(bool) * 2 + sizeof(uint64_t) + + sizeof(bool) * 2 + sizeof(uint64_t)); + // clang-format on +} + +TEST(SlkCore, SharedPtrInvalid) { + std::shared_ptr<std::string> ptr = std::make_shared<std::string>("nandare!"); + std::vector<std::string *> saved; + + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(ptr, builder, &saved); + // Here we mess with the `saved` vector to cause an invalid index to be + // written to the SLK stream so that we can check the error handling in the + // `Load` function later. + saved.insert(saved.begin(), nullptr); + // Save the pointer again with an invalid index. 
+ slk::Save(ptr, builder, &saved); + + std::shared_ptr<std::string> dec1, dec2; + std::vector<std::shared_ptr<std::string>> loaded; + + auto reader = loopback.GetReader(); + slk::Load(&dec1, reader, &loaded); + ASSERT_THROW(slk::Load(&dec2, reader, &loaded), slk::SlkDecodeException); } TEST(SlkCore, Complex) { @@ -335,24 +380,25 @@ TEST(SlkCore, Complex) { original.get()->push_back(std::nullopt); original.get()->push_back("hai hai hai"); - slk::Builder builder; - slk::Save(original, &builder); + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); + + std::unique_ptr<std::vector<std::optional<std::string>>> decoded; + ASSERT_EQ(decoded.get(), nullptr); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); + ASSERT_NE(decoded.get(), nullptr); + ASSERT_EQ(*original.get(), *decoded.get()); // clang-format off - ASSERT_EQ(builder.size(), + ASSERT_EQ(loopback.size(), sizeof(bool) + sizeof(uint64_t) + sizeof(bool) + sizeof(uint64_t) + (*original.get())[0]->size() + sizeof(bool) + sizeof(bool) + sizeof(uint64_t) + (*original.get())[2]->size()); // clang-format on - - std::unique_ptr<std::vector<std::optional<std::string>>> decoded; - ASSERT_EQ(decoded.get(), nullptr); - slk::Reader reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); - ASSERT_NE(decoded.get(), nullptr); - ASSERT_EQ(*original.get(), *decoded.get()); } struct Foo { @@ -381,20 +427,21 @@ TEST(SlkCore, VectorStruct) { original.push_back({"hai hai hai", 5}); original.push_back({"nandare!", std::nullopt}); - slk::Builder builder; - slk::Save(original, &builder); - - // clang-format off - ASSERT_EQ(builder.size(), - sizeof(uint64_t) + - sizeof(uint64_t) + original[0].name.size() + sizeof(bool) + sizeof(int) + - sizeof(uint64_t) + original[1].name.size() + sizeof(bool)); - // clang-format on + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save(original, builder); std::vector<Foo> decoded; - slk::Reader 
reader(builder.data(), builder.size()); - slk::Load(&decoded, &reader); + auto reader = loopback.GetReader(); + slk::Load(&decoded, reader); ASSERT_EQ(original, decoded); + + // clang-format off + ASSERT_EQ(loopback.size(), + sizeof(uint64_t) + + sizeof(uint64_t) + original[0].name.size() + sizeof(bool) + +sizeof(int) + sizeof(uint64_t) + original[1].name.size() + sizeof(bool)); + // clang-format on } TEST(SlkCore, VectorSharedPtr) { @@ -412,18 +459,19 @@ TEST(SlkCore, VectorSharedPtr) { original.push_back(ptr1); original.push_back(ptr3); - slk::Builder builder; + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); slk::Save<std::shared_ptr<std::string>>( - original, &builder, [&saved](const auto &item, auto *builder) { + original, builder, [&saved](const auto &item, auto *builder) { Save(item, builder, &saved); }); std::vector<std::shared_ptr<std::string>> decoded; std::vector<std::shared_ptr<std::string>> loaded; - slk::Reader reader(builder.data(), builder.size()); + auto reader = loopback.GetReader(); slk::Load<std::shared_ptr<std::string>>( - &decoded, &reader, + &decoded, reader, [&loaded](auto *item, auto *reader) { Load(item, reader, &loaded); }); ASSERT_EQ(decoded.size(), original.size()); @@ -448,18 +496,19 @@ TEST(SlkCore, OptionalSharedPtr) { std::make_shared<std::string>("nandare!"); std::vector<std::string *> saved; - slk::Builder builder; + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); slk::Save<std::shared_ptr<std::string>>( - original, &builder, [&saved](const auto &item, auto *builder) { + original, builder, [&saved](const auto &item, auto *builder) { Save(item, builder, &saved); }); std::optional<std::shared_ptr<std::string>> decoded; std::vector<std::shared_ptr<std::string>> loaded; - slk::Reader reader(builder.data(), builder.size()); + auto reader = loopback.GetReader(); slk::Load<std::shared_ptr<std::string>>( - &decoded, &reader, + &decoded, reader, [&loaded](auto *item, auto *reader) { Load(item, reader, 
&loaded); }); ASSERT_NE(decoded, std::nullopt); @@ -469,3 +518,28 @@ TEST(SlkCore, OptionalSharedPtr) { ASSERT_EQ(*decoded->get(), *original->get()); } + +TEST(SlkCore, OptionalSharedPtrEmpty) { + std::optional<std::shared_ptr<std::string>> original; + std::vector<std::string *> saved; + + slk::Loopback loopback; + auto builder = loopback.GetBuilder(); + slk::Save<std::shared_ptr<std::string>>( + original, builder, [&saved](const auto &item, auto *builder) { + Save(item, builder, &saved); + }); + + std::optional<std::shared_ptr<std::string>> decoded; + std::vector<std::shared_ptr<std::string>> loaded; + + auto reader = loopback.GetReader(); + slk::Load<std::shared_ptr<std::string>>( + &decoded, reader, + [&loaded](auto *item, auto *reader) { Load(item, reader, &loaded); }); + + ASSERT_EQ(decoded, std::nullopt); + + ASSERT_EQ(saved.size(), 0); + ASSERT_EQ(loaded.size(), 0); +} diff --git a/tests/unit/slk_streams.cpp b/tests/unit/slk_streams.cpp new file mode 100644 index 000000000..74eb98b7d --- /dev/null +++ b/tests/unit/slk_streams.cpp @@ -0,0 +1,446 @@ +#include <gtest/gtest.h> + +#include <cstring> +#include <memory> +#include <random> +#include <vector> + +#include "slk/streams.hpp" + +class BinaryData { + public: + BinaryData(const uint8_t *data, size_t size) + : data_(new uint8_t[size]), size_(size) { + memcpy(data_.get(), data, size); + } + + BinaryData(std::unique_ptr<uint8_t[]> data, size_t size) + : data_(std::move(data)), size_(size) {} + + const uint8_t *data() const { return data_.get(); } + size_t size() const { return size_; } + + bool operator==(const BinaryData &other) const { + if (size_ != other.size_) return false; + for (size_t i = 0; i < size_; ++i) { + if (data_[i] != other.data_[i]) return false; + } + return true; + } + + private: + std::unique_ptr<uint8_t[]> data_; + size_t size_; +}; + +BinaryData operator+(const BinaryData &a, const BinaryData &b) { + std::unique_ptr<uint8_t[]> data(new uint8_t[a.size() + b.size()]); + memcpy(data.get(), 
a.data(), a.size()); + memcpy(data.get() + a.size(), b.data(), b.size()); + return BinaryData(std::move(data), a.size() + b.size()); +} + +BinaryData GetRandomData(size_t size) { + std::mt19937 gen(std::random_device{}()); + std::uniform_int_distribution<int> dis(0, 255); /* char types are UB here */ + std::unique_ptr<uint8_t[]> ret(new uint8_t[size]); + auto data = ret.get(); + for (size_t i = 0; i < size; ++i) { + data[i] = static_cast<uint8_t>(dis(gen)); + } + return BinaryData(std::move(ret), size); +} + +std::vector<BinaryData> BufferToBinaryData(const uint8_t *data, size_t size, + std::vector<size_t> sizes) { + std::vector<BinaryData> ret; + ret.reserve(sizes.size()); + size_t pos = 0; + for (size_t i = 0; i < sizes.size(); ++i) { + EXPECT_GE(size, pos + sizes[i]); + ret.push_back({data + pos, sizes[i]}); + pos += sizes[i]; + } + return ret; +} + +BinaryData SizeToBinaryData(slk::SegmentSize size) { + return BinaryData(reinterpret_cast<const uint8_t *>(&size), + sizeof(slk::SegmentSize)); +} + +TEST(Builder, SingleSegment) { + std::vector<uint8_t> buffer; + slk::Builder builder( + [&buffer](const uint8_t *data, size_t size, bool have_more) { + for (size_t i = 0; i < size; ++i) buffer.push_back(data[i]); + }); + + auto input = GetRandomData(5); + builder.Save(input.data(), input.size()); + builder.Finalize(); + + ASSERT_EQ(buffer.size(), input.size() + 2 * sizeof(slk::SegmentSize)); + + auto splits = BufferToBinaryData( + buffer.data(), buffer.size(), + {sizeof(slk::SegmentSize), input.size(), sizeof(slk::SegmentSize)}); + + auto header_expected = SizeToBinaryData(input.size()); + ASSERT_EQ(splits[0], header_expected); + + ASSERT_EQ(splits[1], input); + + auto footer_expected = SizeToBinaryData(0); + ASSERT_EQ(splits[2], footer_expected); +} + +TEST(Builder, MultipleSegments) { + std::vector<uint8_t> buffer; + slk::Builder builder( + [&buffer](const uint8_t *data, size_t size, bool have_more) { + for (size_t i = 0; i < size; ++i) buffer.push_back(data[i]); + }); + + auto input = 
GetRandomData(slk::kSegmentMaxDataSize + 100); + builder.Save(input.data(), input.size()); + builder.Finalize(); + + ASSERT_EQ(buffer.size(), input.size() + 3 * sizeof(slk::SegmentSize)); + + auto splits = BufferToBinaryData( + buffer.data(), buffer.size(), + {sizeof(slk::SegmentSize), slk::kSegmentMaxDataSize, + sizeof(slk::SegmentSize), input.size() - slk::kSegmentMaxDataSize, + sizeof(slk::SegmentSize)}); + + auto datas = BufferToBinaryData( + input.data(), input.size(), + {slk::kSegmentMaxDataSize, input.size() - slk::kSegmentMaxDataSize}); + + auto header1_expected = SizeToBinaryData(slk::kSegmentMaxDataSize); + ASSERT_EQ(splits[0], header1_expected); + + ASSERT_EQ(splits[1], datas[0]); + + auto header2_expected = + SizeToBinaryData(input.size() - slk::kSegmentMaxDataSize); + ASSERT_EQ(splits[2], header2_expected); + + ASSERT_EQ(splits[3], datas[1]); + + auto footer_expected = SizeToBinaryData(0); + ASSERT_EQ(splits[4], footer_expected); +} + +TEST(Reader, SingleSegment) { + std::vector<uint8_t> buffer; + slk::Builder builder( + [&buffer](const uint8_t *data, size_t size, bool have_more) { + for (size_t i = 0; i < size; ++i) buffer.push_back(data[i]); + }); + + auto input = GetRandomData(5); + builder.Save(input.data(), input.size()); + builder.Finalize(); + + // test with missing data + for (size_t i = 0; i < buffer.size(); ++i) { + slk::Reader reader(buffer.data(), i); + uint8_t block[slk::kSegmentMaxDataSize]; + ASSERT_THROW( + { + reader.Load(block, input.size()); + reader.Finalize(); + }, + slk::SlkReaderException); + } + + // test with complete data + { + slk::Reader reader(buffer.data(), buffer.size()); + uint8_t block[slk::kSegmentMaxDataSize]; + reader.Load(block, input.size()); + reader.Finalize(); + auto output = BinaryData(block, input.size()); + ASSERT_EQ(output, input); + } + + // test with leftover data + { + auto extended_buffer = + BinaryData(buffer.data(), buffer.size()) + GetRandomData(5); + slk::Reader reader(extended_buffer.data(), 
extended_buffer.size()); + uint8_t block[slk::kSegmentMaxDataSize]; + reader.Load(block, input.size()); + reader.Finalize(); + auto output = BinaryData(block, input.size()); + ASSERT_EQ(output, input); + } + + // read more data than there is in the stream + { + slk::Reader reader(buffer.data(), buffer.size()); + uint8_t block[slk::kSegmentMaxDataSize]; + ASSERT_THROW(reader.Load(block, slk::kSegmentMaxDataSize), + slk::SlkReaderException); + } + + // don't consume all data from the stream + { + slk::Reader reader(buffer.data(), buffer.size()); + uint8_t block[slk::kSegmentMaxDataSize]; + reader.Load(block, input.size() / 2); + ASSERT_THROW(reader.Finalize(), slk::SlkReaderException); + } + + // read data with several loads + { + slk::Reader reader(buffer.data(), buffer.size()); + uint8_t block[slk::kSegmentMaxDataSize]; + for (size_t i = 0; i < input.size(); ++i) { + reader.Load(block + i, 1); + } + reader.Finalize(); + auto output = BinaryData(block, input.size()); + ASSERT_EQ(output, input); + } + + // modify the end mark + buffer[buffer.size() - 1] = 1; + { + slk::Reader reader(buffer.data(), buffer.size()); + uint8_t block[slk::kSegmentMaxDataSize]; + reader.Load(block, input.size()); + ASSERT_THROW(reader.Finalize(), slk::SlkReaderException); + } +} + +TEST(Reader, MultipleSegments) { + std::vector<uint8_t> buffer; + slk::Builder builder( + [&buffer](const uint8_t *data, size_t size, bool have_more) { + for (size_t i = 0; i < size; ++i) buffer.push_back(data[i]); + }); + + auto input = GetRandomData(slk::kSegmentMaxDataSize + 100); + builder.Save(input.data(), input.size()); + builder.Finalize(); + + // test with missing data + for (size_t i = 0; i < buffer.size(); ++i) { + slk::Reader reader(buffer.data(), i); + uint8_t block[slk::kSegmentMaxDataSize * 2]; + ASSERT_THROW( + { + reader.Load(block, input.size()); + reader.Finalize(); + }, + slk::SlkReaderException); + } + + // test with complete data + { + slk::Reader reader(buffer.data(), buffer.size()); + 
uint8_t block[slk::kSegmentMaxDataSize * 2]; + reader.Load(block, input.size()); + reader.Finalize(); + auto output = BinaryData(block, input.size()); + ASSERT_EQ(output, input); + } + + // test with leftover data + { + auto extended_buffer = + BinaryData(buffer.data(), buffer.size()) + GetRandomData(5); + slk::Reader reader(extended_buffer.data(), extended_buffer.size()); + uint8_t block[slk::kSegmentMaxDataSize * 2]; + reader.Load(block, input.size()); + reader.Finalize(); + auto output = BinaryData(block, input.size()); + ASSERT_EQ(output, input); + } + + // read more data than there is in the stream + { + slk::Reader reader(buffer.data(), buffer.size()); + uint8_t block[slk::kSegmentMaxDataSize * 2]; + ASSERT_THROW(reader.Load(block, slk::kSegmentMaxDataSize * 2), + slk::SlkReaderException); + } + + // don't consume all data from the stream + { + slk::Reader reader(buffer.data(), buffer.size()); + uint8_t block[slk::kSegmentMaxDataSize * 2]; + reader.Load(block, input.size() / 2); + ASSERT_THROW(reader.Finalize(), slk::SlkReaderException); + } + + // read data with several loads + { + slk::Reader reader(buffer.data(), buffer.size()); + uint8_t block[slk::kSegmentMaxDataSize * 2]; + for (size_t i = 0; i < input.size(); ++i) { + reader.Load(block + i, 1); + } + reader.Finalize(); + auto output = BinaryData(block, input.size()); + ASSERT_EQ(output, input); + } + + // modify the end mark + buffer[buffer.size() - 1] = 1; + { + slk::Reader reader(buffer.data(), buffer.size()); + uint8_t block[slk::kSegmentMaxDataSize * 2]; + reader.Load(block, input.size()); + ASSERT_THROW(reader.Finalize(), slk::SlkReaderException); + } +} + +TEST(CheckStreamComplete, SingleSegment) { + std::vector<uint8_t> buffer; + slk::Builder builder( + [&buffer](const uint8_t *data, size_t size, bool have_more) { + for (size_t i = 0; i < size; ++i) buffer.push_back(data[i]); + }); + + auto input = GetRandomData(5); + builder.Save(input.data(), input.size()); + builder.Finalize(); + + // test 
with missing data + for (size_t i = 0; i < sizeof(slk::SegmentSize); ++i) { + auto [status, stream_size, data_size] = + slk::CheckStreamComplete(buffer.data(), i); + ASSERT_EQ(status, slk::StreamStatus::PARTIAL); + ASSERT_EQ(stream_size, slk::kSegmentMaxTotalSize); + ASSERT_EQ(data_size, 0); + } + for (size_t i = sizeof(slk::SegmentSize); + i < sizeof(slk::SegmentSize) + input.size(); ++i) { + auto [status, stream_size, data_size] = + slk::CheckStreamComplete(buffer.data(), i); + ASSERT_EQ(status, slk::StreamStatus::PARTIAL); + ASSERT_EQ(stream_size, + slk::kSegmentMaxTotalSize + sizeof(slk::SegmentSize)); + ASSERT_EQ(data_size, 0); + } + for (size_t i = sizeof(slk::SegmentSize) + input.size(); i < buffer.size(); + ++i) { + auto [status, stream_size, data_size] = + slk::CheckStreamComplete(buffer.data(), i); + ASSERT_EQ(status, slk::StreamStatus::PARTIAL); + ASSERT_EQ(stream_size, slk::kSegmentMaxTotalSize + + sizeof(slk::SegmentSize) + input.size()); + ASSERT_EQ(data_size, input.size()); + } + + // test with complete data + { + auto [status, stream_size, data_size] = + slk::CheckStreamComplete(buffer.data(), buffer.size()); + ASSERT_EQ(status, slk::StreamStatus::COMPLETE); + ASSERT_EQ(stream_size, buffer.size()); + ASSERT_EQ(data_size, input.size()); + } + + // test with leftover data + { + auto extended_buffer = + BinaryData(buffer.data(), buffer.size()) + GetRandomData(5); + auto [status, stream_size, data_size] = slk::CheckStreamComplete( + extended_buffer.data(), extended_buffer.size()); + ASSERT_EQ(status, slk::StreamStatus::COMPLETE); + ASSERT_EQ(stream_size, buffer.size()); + ASSERT_EQ(data_size, input.size()); + } +} + +TEST(CheckStreamComplete, MultipleSegments) { + std::vector<uint8_t> buffer; + slk::Builder builder( + [&buffer](const uint8_t *data, size_t size, bool have_more) { + for (size_t i = 0; i < size; ++i) buffer.push_back(data[i]); + }); + + auto input = GetRandomData(slk::kSegmentMaxDataSize + 100); + builder.Save(input.data(), input.size()); 
+ builder.Finalize(); + + // test with missing data + for (size_t i = 0; i < sizeof(slk::SegmentSize); ++i) { + auto [status, stream_size, data_size] = + slk::CheckStreamComplete(buffer.data(), i); + ASSERT_EQ(status, slk::StreamStatus::PARTIAL); + ASSERT_EQ(stream_size, slk::kSegmentMaxTotalSize); + ASSERT_EQ(data_size, 0); + } + for (size_t i = sizeof(slk::SegmentSize); + i < sizeof(slk::SegmentSize) + slk::kSegmentMaxDataSize; ++i) { + auto [status, stream_size, data_size] = + slk::CheckStreamComplete(buffer.data(), i); + ASSERT_EQ(status, slk::StreamStatus::PARTIAL); + ASSERT_EQ(stream_size, + slk::kSegmentMaxTotalSize + sizeof(slk::SegmentSize)); + ASSERT_EQ(data_size, 0); + } + for (size_t i = sizeof(slk::SegmentSize) + slk::kSegmentMaxDataSize; + i < sizeof(slk::SegmentSize) * 2 + slk::kSegmentMaxDataSize; ++i) { + auto [status, stream_size, data_size] = + slk::CheckStreamComplete(buffer.data(), i); + ASSERT_EQ(status, slk::StreamStatus::PARTIAL); + ASSERT_EQ(stream_size, sizeof(slk::SegmentSize) + slk::kSegmentMaxDataSize + + slk::kSegmentMaxTotalSize); + ASSERT_EQ(data_size, slk::kSegmentMaxDataSize); + } + for (size_t i = sizeof(slk::SegmentSize) * 2 + slk::kSegmentMaxDataSize; + i < sizeof(slk::SegmentSize) * 2 + input.size(); ++i) { + auto [status, stream_size, data_size] = + slk::CheckStreamComplete(buffer.data(), i); + ASSERT_EQ(status, slk::StreamStatus::PARTIAL); + ASSERT_EQ(stream_size, sizeof(slk::SegmentSize) * 2 + + slk::kSegmentMaxDataSize + + slk::kSegmentMaxTotalSize); + ASSERT_EQ(data_size, slk::kSegmentMaxDataSize); + } + for (size_t i = sizeof(slk::SegmentSize) * 2 + input.size(); + i < buffer.size(); ++i) { + auto [status, stream_size, data_size] = + slk::CheckStreamComplete(buffer.data(), i); + ASSERT_EQ(status, slk::StreamStatus::PARTIAL); + ASSERT_EQ(stream_size, slk::kSegmentMaxTotalSize + + sizeof(slk::SegmentSize) * 2 + input.size()); + ASSERT_EQ(data_size, input.size()); + } + + // test with complete data + { + auto [status, 
stream_size, data_size] = + slk::CheckStreamComplete(buffer.data(), buffer.size()); + ASSERT_EQ(status, slk::StreamStatus::COMPLETE); + ASSERT_EQ(stream_size, buffer.size()); + ASSERT_EQ(data_size, input.size()); + } + + // test with leftover data + { + auto extended_buffer = + BinaryData(buffer.data(), buffer.size()) + GetRandomData(5); + auto [status, stream_size, data_size] = slk::CheckStreamComplete( + extended_buffer.data(), extended_buffer.size()); + ASSERT_EQ(status, slk::StreamStatus::COMPLETE); + ASSERT_EQ(stream_size, buffer.size()); + ASSERT_EQ(data_size, input.size()); + } +} + +TEST(CheckStreamComplete, InvalidSegment) { + auto input = SizeToBinaryData(0); + auto [status, stream_size, data_size] = + slk::CheckStreamComplete(input.data(), input.size()); + ASSERT_EQ(status, slk::StreamStatus::INVALID); + ASSERT_EQ(stream_size, 0); + ASSERT_EQ(data_size, 0); +}