clang format has been run on all hpp and cpp files under src and tests

This commit is contained in:
Marko Budiselic 2017-02-18 11:54:37 +01:00
parent e7f5bd4c21
commit d4e3c4cd10
386 changed files with 19085 additions and 20574 deletions

12
format.sh Executable file
View File

@ -0,0 +1,12 @@
#!/bin/bash
# This scrips runs clang-format recursively on all files under specified
# directories. Formatting configuration is defined in .clang-format.
clang_format="clang-format"
for directory in src tests
do
echo "formatting code under $directory/"
find "$directory" \( -name '*.hpp' -or -name '*.cpp' \) -print0 | xargs -0 "${clang_format}" -i
done

View File

@ -1,9 +1,8 @@
#pragma once
#include "io/network/socket.hpp"
#include "communication/bolt/v1/serialization/record_stream.hpp"
#include "io/network/socket.hpp"
namespace communication
{
using OutputStream = bolt::RecordStream<io::Socket>;
namespace communication {
using OutputStream = bolt::RecordStream<io::Socket>;
}

View File

@ -2,26 +2,18 @@
#include "communication/bolt/v1/session.hpp"
namespace bolt
{
namespace bolt {
Bolt::Bolt()
{
}
Session* Bolt::create_session(io::Socket&& socket)
{
// TODO fix session lifecycle handling
// dangling pointers are not cool :)
// TODO attach currently active Db
return new Session(std::forward<io::Socket>(socket), *this);
}
void Bolt::close(Session* session)
{
session->socket.close();
Bolt::Bolt() {}
Session* Bolt::create_session(io::Socket&& socket) {
// TODO fix session lifecycle handling
// dangling pointers are not cool :)
// TODO attach currently active Db
return new Session(std::forward<io::Socket>(socket), *this);
}
void Bolt::close(Session* session) { session->socket.close(); }
}

View File

@ -4,22 +4,20 @@
#include "dbms/dbms.hpp"
#include "io/network/socket.hpp"
namespace bolt
{
namespace bolt {
class Session;
class Bolt
{
friend class Session;
class Bolt {
friend class Session;
public:
Bolt();
public:
Bolt();
Session *create_session(io::Socket &&socket);
void close(Session *session);
Session *create_session(io::Socket &&socket);
void close(Session *session);
States states;
Dbms dbms;
States states;
Dbms dbms;
};
}

View File

@ -2,16 +2,13 @@
#include <cstddef>
namespace bolt
{
namespace bolt {
namespace config
{
/** chunk size */
static constexpr size_t N = 65535;
namespace config {
/** chunk size */
static constexpr size_t N = 65535;
/** end mark */
static constexpr size_t C = N + 2;
/** end mark */
static constexpr size_t C = N + 2;
}
}

View File

@ -3,44 +3,36 @@
#include "utils/types/byte.hpp"
#include "utils/underlying_cast.hpp"
namespace bolt
{
namespace bolt {
enum class MessageCode : byte
{
Init = 0x01,
AckFailure = 0x0E,
Reset = 0x0F,
enum class MessageCode : byte {
Init = 0x01,
AckFailure = 0x0E,
Reset = 0x0F,
Run = 0x10,
DiscardAll = 0x2F,
PullAll = 0x3F,
Run = 0x10,
DiscardAll = 0x2F,
PullAll = 0x3F,
Record = 0x71,
Success = 0x70,
Ignored = 0x7E,
Failure = 0x7F
Record = 0x71,
Success = 0x70,
Ignored = 0x7E,
Failure = 0x7F
};
inline bool operator==(byte value, MessageCode code)
{
return value == underlying_cast(code);
inline bool operator==(byte value, MessageCode code) {
return value == underlying_cast(code);
}
inline bool operator==(MessageCode code, byte value)
{
return operator==(value, code);
inline bool operator==(MessageCode code, byte value) {
return operator==(value, code);
}
inline bool operator!=(byte value, MessageCode code)
{
return !operator==(value, code);
inline bool operator!=(byte value, MessageCode code) {
return !operator==(value, code);
}
inline bool operator!=(MessageCode code, byte value)
{
return operator!=(value, code);
inline bool operator!=(MessageCode code, byte value) {
return operator!=(value, code);
}
}

View File

@ -2,65 +2,57 @@
#include <cstdint>
namespace bolt
{
namespace bolt {
namespace pack
{
namespace pack {
enum Code : uint8_t
{
TinyString = 0x80,
TinyList = 0x90,
TinyMap = 0xA0,
enum Code : uint8_t {
TinyString = 0x80,
TinyList = 0x90,
TinyMap = 0xA0,
TinyStruct = 0xB0,
StructOne = 0xB1,
StructTwo = 0xB2,
TinyStruct = 0xB0,
StructOne = 0xB1,
StructTwo = 0xB2,
Null = 0xC0,
Null = 0xC0,
Float64 = 0xC1,
Float64 = 0xC1,
False = 0xC2,
True = 0xC3,
False = 0xC2,
True = 0xC3,
Int8 = 0xC8,
Int16 = 0xC9,
Int32 = 0xCA,
Int64 = 0xCB,
Int8 = 0xC8,
Int16 = 0xC9,
Int32 = 0xCA,
Int64 = 0xCB,
Bytes8 = 0xCC,
Bytes16 = 0xCD,
Bytes32 = 0xCE,
Bytes8 = 0xCC,
Bytes16 = 0xCD,
Bytes32 = 0xCE,
String8 = 0xD0,
String16 = 0xD1,
String32 = 0xD2,
String8 = 0xD0,
String16 = 0xD1,
String32 = 0xD2,
List8 = 0xD4,
List16 = 0xD5,
List32 = 0xD6,
List8 = 0xD4,
List16 = 0xD5,
List32 = 0xD6,
Map8 = 0xD8,
Map16 = 0xD9,
Map32 = 0xDA,
MapStream = 0xDB,
Map8 = 0xD8,
Map16 = 0xD9,
Map32 = 0xDA,
MapStream = 0xDB,
Node = 0x4E,
Relationship = 0x52,
Path = 0x50,
Node = 0x4E,
Relationship = 0x52,
Path = 0x50,
Struct8 = 0xDC,
Struct16 = 0xDD,
EndOfStream = 0xDF,
};
enum Rule : uint8_t
{
MaxInitStructSize = 0x02
Struct8 = 0xDC,
Struct16 = 0xDD,
EndOfStream = 0xDF,
};
enum Rule : uint8_t { MaxInitStructSize = 0x02 };
}
}

View File

@ -1,42 +1,39 @@
#pragma once
namespace bolt
{
namespace bolt {
enum class PackType
{
/** denotes absence of a value */
Null,
enum class PackType {
/** denotes absence of a value */
Null,
/** denotes a type with two possible values (t/f) */
Boolean,
/** denotes a type with two possible values (t/f) */
Boolean,
/** 64-bit signed integral number */
Integer,
/** 64-bit signed integral number */
Integer,
/** 64-bit floating point number */
Float,
/** 64-bit floating point number */
Float,
/** binary data */
Bytes,
/** binary data */
Bytes,
/** unicode string */
String,
/** unicode string */
String,
/** collection of values */
List,
/** collection of values */
List,
/** collection of zero or more key/value pairs */
Map,
/** collection of zero or more key/value pairs */
Map,
/** zero or more packstream values */
Struct,
/** zero or more packstream values */
Struct,
/** denotes stream value end */
EndOfStream,
/** denotes stream value end */
EndOfStream,
/** reserved for future use */
Reserved
/** reserved for future use */
Reserved
};
}

View File

@ -8,9 +8,8 @@
#include "database/graph_db.hpp"
#include "storage/typed_value_store.hpp"
template<class Stream>
template <class Stream>
void bolt::BoltSerializer<Stream>::write(const VertexAccessor &vertex) {
// write signatures for the node struct and node data type
encoder.write_struct_header(3);
encoder.write(underlying_cast(pack::Node));
@ -19,7 +18,7 @@ void bolt::BoltSerializer<Stream>::write(const VertexAccessor &vertex) {
// use internal IDs, but need to give something to Bolt
// note that OpenCypher has no id(x) function, so the client
// should not be able to do anything with this value anyway
encoder.write_integer(0); // uID
encoder.write_integer(0); // uID
// write the list of labels
auto labels = vertex.labels();
@ -30,13 +29,14 @@ void bolt::BoltSerializer<Stream>::write(const VertexAccessor &vertex) {
// write the properties
const TypedValueStore<GraphDb::Property> &props = vertex.Properties();
encoder.write_map_header(props.size());
props.Accept([this, &vertex](const GraphDb::Property prop, const TypedValue &value) {
this->encoder.write(vertex.db_accessor().property_name(prop));
this->write(value);
});
props.Accept(
[this, &vertex](const GraphDb::Property prop, const TypedValue &value) {
this->encoder.write(vertex.db_accessor().property_name(prop));
this->write(value);
});
}
template<class Stream>
template <class Stream>
void bolt::BoltSerializer<Stream>::write(const EdgeAccessor &edge) {
// write signatures for the edge struct and edge data type
encoder.write_struct_header(5);
@ -54,7 +54,7 @@ void bolt::BoltSerializer<Stream>::write(const EdgeAccessor &edge) {
encoder.write(edge.db_accessor().edge_type_name(edge.edge_type()));
// write the property map
const TypedValueStore<GraphDb::Property>& props = edge.Properties();
const TypedValueStore<GraphDb::Property> &props = edge.Properties();
encoder.write_map_header(props.size());
props.Accept([this, &edge](GraphDb::Property prop, const TypedValue &value) {
this->encoder.write(edge.db_accessor().property_name(prop));
@ -62,9 +62,8 @@ void bolt::BoltSerializer<Stream>::write(const EdgeAccessor &edge) {
});
}
template<class Stream>
void bolt::BoltSerializer<Stream>::write(const TypedValue& value) {
template <class Stream>
void bolt::BoltSerializer<Stream>::write(const TypedValue &value) {
switch (value.type_) {
case TypedValue::Type::Null:
encoder.write_null();
@ -84,8 +83,9 @@ void bolt::BoltSerializer<Stream>::write(const TypedValue& value) {
}
}
template<class Stream>
void bolt::BoltSerializer<Stream>::write_failure(const std::map<std::string, std::string> &data) {
template <class Stream>
void bolt::BoltSerializer<Stream>::write_failure(
const std::map<std::string, std::string> &data) {
encoder.message_failure();
encoder.write_map_header(data.size());
for (auto const &kv : data) {
@ -94,6 +94,5 @@ void bolt::BoltSerializer<Stream>::write_failure(const std::map<std::string, std
}
}
template
class bolt::BoltSerializer<bolt::BoltEncoder<
template class bolt::BoltSerializer<bolt::BoltEncoder<
bolt::ChunkedEncoder<bolt::ChunkedBuffer<bolt::SocketStream<io::Socket>>>>>;

View File

@ -10,48 +10,49 @@
namespace bolt {
template<class Stream>
class BoltSerializer {
template <class Stream>
class BoltSerializer {
public:
BoltSerializer(Stream &stream) : encoder(stream) {}
public:
BoltSerializer(Stream &stream) : encoder(stream) {}
/** Serializes the vertex accessor into the packstream format
*
* struct[size = 3] Vertex [signature = 0x4E] {
* Integer node_id;
* List<String> labels;
* Map<String, Value> properties;
* }
*
*/
void write(const VertexAccessor &vertex);
/** Serializes the vertex accessor into the packstream format
*
* struct[size = 3] Vertex [signature = 0x4E] {
* Integer node_id;
* List<String> labels;
* Map<String, Value> properties;
* }
*
*/
void write(const VertexAccessor &vertex);
/** Serializes the edge accessor into the packstream format
*
* struct[size = 5] Edge [signature = 0x52] {
* Integer edge_id; // IMPORTANT: always 0 since we
* don't do IDs
* Integer start_node_id; // IMPORTANT: always 0 since we
* don't do IDs
* Integer end_node_id; // IMPORTANT: always 0 since we
* don't do IDs
* String type;
* Map<String, Value> properties;
* }
*
*/
void write(const EdgeAccessor &edge);
/** Serializes the edge accessor into the packstream format
*
* struct[size = 5] Edge [signature = 0x52] {
* Integer edge_id; // IMPORTANT: always 0 since we don't do IDs
* Integer start_node_id; // IMPORTANT: always 0 since we don't do IDs
* Integer end_node_id; // IMPORTANT: always 0 since we don't do IDs
* String type;
* Map<String, Value> properties;
* }
*
*/
void write(const EdgeAccessor &edge);
// TODO document
void write_failure(const std::map<std::string, std::string> &data);
// TODO document
void write_failure(const std::map<std::string, std::string> &data);
/**
* Writes a TypedValue (typically a property value in the edge or vertex).
*
* @param value The value to write.
*/
void write(const TypedValue &value);
/**
* Writes a TypedValue (typically a property value in the edge or vertex).
*
* @param value The value to write.
*/
void write(const TypedValue& value);
protected:
Stream &encoder;
};
protected:
Stream &encoder;
};
}

View File

@ -13,152 +13,148 @@ namespace bolt {
* compiled queries have to use this class in order to return results
* query code should not know about bolt protocol
*/
template<class Socket>
class RecordStream {
public:
RecordStream(Socket &socket) : socket(socket) {
logger = logging::log->logger("Record Stream");
template <class Socket>
class RecordStream {
public:
RecordStream(Socket &socket) : socket(socket) {
logger = logging::log->logger("Record Stream");
}
~RecordStream() = default;
// TODO: create apstract methods that are not bolt specific ---------------
void write_success() {
logger.trace("write_success");
bolt_encoder.message_success();
}
void write_success_empty() {
logger.trace("write_success_empty");
bolt_encoder.message_success_empty();
}
void write_ignored() {
logger.trace("write_ignored");
bolt_encoder.message_ignored();
}
void write_empty_fields() {
bolt_encoder.message_success();
bolt_encoder.write_map_header(1);
bolt_encoder.write_string("fields");
write_list_header(0);
chunk();
}
void write_fields(const std::vector<std::string> &fields) {
// TODO: that should be one level below?
bolt_encoder.message_success();
bolt_encoder.write_map_header(1);
bolt_encoder.write_string("fields");
write_list_header(fields.size());
for (auto &name : fields) {
bolt_encoder.write_string(name);
}
~RecordStream() = default;
chunk();
send();
}
// TODO: create apstract methods that are not bolt specific ---------------
void write_success() {
logger.trace("write_success");
bolt_encoder.message_success();
}
void write_field(const std::string &field) {
bolt_encoder.message_success();
bolt_encoder.write_map_header(1);
bolt_encoder.write_string("fields");
write_list_header(1);
bolt_encoder.write_string(field);
chunk();
send();
}
void write_success_empty() {
logger.trace("write_success_empty");
bolt_encoder.message_success_empty();
}
void write_list_header(size_t size) { bolt_encoder.write_list_header(size); }
void write_ignored() {
logger.trace("write_ignored");
bolt_encoder.message_ignored();
}
void write_record() { bolt_encoder.message_record(); }
void write_empty_fields() {
bolt_encoder.message_success();
bolt_encoder.write_map_header(1);
bolt_encoder.write_string("fields");
write_list_header(0);
chunk();
}
// writes metadata at the end of the message
// TODO: write whole implementation (currently, only type is supported)
// { "stats": { "nodes created": 1, "properties set": 1},
// "type": "r" | "rw" | ...
void write_meta(const std::string &type) {
bolt_encoder.message_success();
bolt_encoder.write_map_header(1);
bolt_encoder.write_string("type");
bolt_encoder.write_string(type);
chunk();
}
void write_fields(const std::vector<std::string> &fields) {
// TODO: that should be one level below?
bolt_encoder.message_success();
void write_failure(const std::map<std::string, std::string> &data) {
serializer.write_failure(data);
chunk();
}
bolt_encoder.write_map_header(1);
bolt_encoder.write_string("fields");
write_list_header(fields.size());
void write_count(const size_t count) {
write_record();
write_list_header(1);
write(count);
chunk();
}
for (auto &name : fields) {
bolt_encoder.write_string(name);
}
void write(const VertexAccessor &vertex) { serializer.write(vertex); }
chunk();
send();
}
void write_vertex_record(const VertexAccessor &va) {
write_record();
write_list_header(1);
write(va);
chunk();
}
void write_field(const std::string &field) {
bolt_encoder.message_success();
bolt_encoder.write_map_header(1);
bolt_encoder.write_string("fields");
write_list_header(1);
bolt_encoder.write_string(field);
chunk();
send();
}
void write(const EdgeAccessor &edge) { serializer.write(edge); }
void write_list_header(size_t size) {
bolt_encoder.write_list_header(size);
}
void write_edge_record(const EdgeAccessor &ea) {
write_record();
write_list_header(1);
write(ea);
chunk();
}
void write_record() { bolt_encoder.message_record(); }
void write(const TypedValue &value) { serializer.write(value); }
// writes metadata at the end of the message
// TODO: write whole implementation (currently, only type is supported)
// { "stats": { "nodes created": 1, "properties set": 1},
// "type": "r" | "rw" | ...
void write_meta(const std::string &type) {
bolt_encoder.message_success();
bolt_encoder.write_map_header(1);
bolt_encoder.write_string("type");
bolt_encoder.write_string(type);
chunk();
}
void send() { chunked_buffer.flush(); }
void write_failure(const std::map<std::string, std::string> &data) {
serializer.write_failure(data);
chunk();
}
void chunk() { chunked_encoder.write_chunk(); }
void write_count(const size_t count) {
write_record();
write_list_header(1);
write(count);
chunk();
}
// TODO WTF is this test doing here?
void _write_test() {
logger.trace("write_test");
void write(const VertexAccessor &vertex) { serializer.write(vertex); }
write_fields({{"name"}});
void write_vertex_record(const VertexAccessor &va) {
write_record();
write_list_header(1);
write(va);
chunk();
}
write_record();
write_list_header(1);
bolt_encoder.write("max");
void write(const EdgeAccessor &edge) { serializer.write(edge); }
write_record();
write_list_header(1);
bolt_encoder.write("paul");
void write_edge_record(const EdgeAccessor &ea) {
write_record();
write_list_header(1);
write(ea);
chunk();
}
write_success_empty();
}
void write(const TypedValue& value) {
serializer.write(value);
}
protected:
Logger logger;
void send() { chunked_buffer.flush(); }
private:
using socket_t = SocketStream<Socket>;
using buffer_t = ChunkedBuffer<socket_t>;
using chunked_encoder_t = ChunkedEncoder<buffer_t>;
using bolt_encoder_t = BoltEncoder<chunked_encoder_t>;
using bolt_serializer_t = BoltSerializer<bolt_encoder_t>;
void chunk() { chunked_encoder.write_chunk(); }
// TODO WTF is this test doing here?
void _write_test() {
logger.trace("write_test");
write_fields({{"name"}});
write_record();
write_list_header(1);
bolt_encoder.write("max");
write_record();
write_list_header(1);
bolt_encoder.write("paul");
write_success_empty();
}
protected:
Logger logger;
private:
using socket_t = SocketStream<Socket>;
using buffer_t = ChunkedBuffer<socket_t>;
using chunked_encoder_t = ChunkedEncoder<buffer_t>;
using bolt_encoder_t = BoltEncoder<chunked_encoder_t>;
using bolt_serializer_t = BoltSerializer<bolt_encoder_t>;
socket_t socket;
buffer_t chunked_buffer{socket};
chunked_encoder_t chunked_encoder{chunked_buffer};
bolt_encoder_t bolt_encoder{chunked_encoder};
bolt_serializer_t serializer{bolt_encoder};
};
socket_t socket;
buffer_t chunked_buffer{socket};
chunked_encoder_t chunked_encoder{chunked_buffer};
bolt_encoder_t bolt_encoder{chunked_encoder};
bolt_serializer_t serializer{bolt_encoder};
};
}

View File

@ -1,72 +1,62 @@
#pragma once
#include <vector>
#include <memory>
#include <thread>
#include <atomic>
#include <cassert>
#include <memory>
#include <thread>
#include <vector>
#include "io/network/server.hpp"
#include "communication/bolt/v1/bolt.hpp"
#include "io/network/server.hpp"
#include "logging/default.hpp"
namespace bolt
{
namespace bolt {
template <class Worker>
class Server : public io::Server<Server<Worker>>
{
public:
Server(io::Socket&& socket)
: io::Server<Server<Worker>>(std::forward<io::Socket>(socket)),
logger(logging::log->logger("bolt::Server")) {}
class Server : public io::Server<Server<Worker>> {
public:
Server(io::Socket&& socket)
: io::Server<Server<Worker>>(std::forward<io::Socket>(socket)),
logger(logging::log->logger("bolt::Server")) {}
void start(size_t n)
{
workers.reserve(n);
void start(size_t n) {
workers.reserve(n);
for(size_t i = 0; i < n; ++i)
{
workers.push_back(std::make_shared<Worker>(bolt));
workers.back()->start(alive);
}
while(alive)
{
this->wait_and_process_events();
}
for (size_t i = 0; i < n; ++i) {
workers.push_back(std::make_shared<Worker>(bolt));
workers.back()->start(alive);
}
void shutdown()
{
alive.store(false);
for(auto& worker : workers)
worker->thread.join();
while (alive) {
this->wait_and_process_events();
}
}
void on_connect()
{
assert(idx < workers.size());
void shutdown() {
alive.store(false);
logger.trace("on connect");
for (auto& worker : workers) worker->thread.join();
}
if(UNLIKELY(!workers[idx]->accept(this->socket)))
return;
void on_connect() {
assert(idx < workers.size());
idx = idx == workers.size() - 1 ? 0 : idx + 1;
}
logger.trace("on connect");
void on_wait_timeout() {}
if (UNLIKELY(!workers[idx]->accept(this->socket))) return;
private:
Bolt bolt;
idx = idx == workers.size() - 1 ? 0 : idx + 1;
}
std::vector<typename Worker::sptr> workers;
std::atomic<bool> alive {true};
void on_wait_timeout() {}
int idx {0};
Logger logger;
private:
Bolt bolt;
std::vector<typename Worker::sptr> workers;
std::atomic<bool> alive{true};
int idx{0};
Logger logger;
};
}

View File

@ -12,103 +12,92 @@
#include "io/network/stream_reader.hpp"
#include "logging/default.hpp"
namespace bolt
{
namespace bolt {
template <class Worker>
class Server;
class Worker : public io::StreamReader<Worker, Session>
{
friend class bolt::Server<Worker>;
class Worker : public io::StreamReader<Worker, Session> {
friend class bolt::Server<Worker>;
public:
using sptr = std::shared_ptr<Worker>;
public:
using sptr = std::shared_ptr<Worker>;
Worker(Bolt &bolt) : bolt(bolt)
{
logger = logging::log->logger("bolt::Worker");
}
Worker(Bolt &bolt) : bolt(bolt) {
logger = logging::log->logger("bolt::Worker");
}
Session &on_connect(io::Socket &&socket)
{
logger.trace("Accepting connection on socket {}", socket.id());
Session &on_connect(io::Socket &&socket) {
logger.trace("Accepting connection on socket {}", socket.id());
return *bolt.get().create_session(std::forward<io::Socket>(socket));
}
return *bolt.get().create_session(std::forward<io::Socket>(socket));
}
void on_error(Session &)
{
logger.trace("[on_error] errno = {}", errno);
void on_error(Session &) {
logger.trace("[on_error] errno = {}", errno);
#ifndef NDEBUG
auto err = io::NetworkError("");
logger.debug("{}", err.what());
auto err = io::NetworkError("");
logger.debug("{}", err.what());
#endif
logger.error("Error occured in this session");
}
logger.error("Error occured in this session");
}
void on_wait_timeout() {}
void on_wait_timeout() {}
Buffer on_alloc(Session &)
{
/* logger.trace("[on_alloc] Allocating {}B", sizeof buf); */
Buffer on_alloc(Session &) {
/* logger.trace("[on_alloc] Allocating {}B", sizeof buf); */
return Buffer{buf, sizeof buf};
}
return Buffer{buf, sizeof buf};
}
void on_read(Session &session, Buffer &buf)
{
logger.trace("[on_read] Received {}B", buf.len);
void on_read(Session &session, Buffer &buf) {
logger.trace("[on_read] Received {}B", buf.len);
#ifndef NDEBUG
std::stringstream stream;
std::stringstream stream;
for (size_t i = 0; i < buf.len; ++i)
stream << fmt::format("{:02X} ", static_cast<byte>(buf.ptr[i]));
for (size_t i = 0; i < buf.len; ++i)
stream << fmt::format("{:02X} ", static_cast<byte>(buf.ptr[i]));
logger.trace("[on_read] {}", stream.str());
logger.trace("[on_read] {}", stream.str());
#endif
try {
session.execute(reinterpret_cast<const byte *>(buf.ptr), buf.len);
} catch (const std::exception &e) {
logger.error("Error occured while executing statement.");
logger.error("{}", e.what());
// TODO: report to client
}
try {
session.execute(reinterpret_cast<const byte *>(buf.ptr), buf.len);
} catch (const std::exception &e) {
logger.error("Error occured while executing statement.");
logger.error("{}", e.what());
// TODO: report to client
}
}
void on_close(Session &session)
{
logger.trace("[on_close] Client closed the connection");
session.close();
}
void on_close(Session &session) {
logger.trace("[on_close] Client closed the connection");
session.close();
}
template <class... Args>
void on_exception(Session &session, Args &&... args)
{
logger.error("Error occured in this session");
logger.error(args...);
template <class... Args>
void on_exception(Session &session, Args &&... args) {
logger.error("Error occured in this session");
logger.error(args...);
// TODO: Do something about it
}
// TODO: Do something about it
}
char buf[65536];
char buf[65536];
protected:
std::reference_wrapper<Bolt> bolt;
protected:
std::reference_wrapper<Bolt> bolt;
Logger logger;
std::thread thread;
Logger logger;
std::thread thread;
void start(std::atomic<bool> &alive)
{
thread = std::thread([&, this]() {
while (alive)
wait_and_process_events();
});
}
void start(std::atomic<bool> &alive) {
thread = std::thread([&, this]() {
while (alive) wait_and_process_events();
});
}
};
}

View File

@ -1,46 +1,42 @@
#include "communication/bolt/v1/session.hpp"
namespace bolt
{
namespace bolt {
Session::Session(io::Socket &&socket, Bolt &bolt)
: Stream(std::forward<io::Socket>(socket)), bolt(bolt)
{
logger = logging::log->logger("Session");
: Stream(std::forward<io::Socket>(socket)), bolt(bolt) {
logger = logging::log->logger("Session");
// start with a handshake state
state = bolt.states.handshake.get();
// start with a handshake state
state = bolt.states.handshake.get();
}
bool Session::alive() const { return state != nullptr; }
void Session::execute(const byte *data, size_t len)
{
// mark the end of the message
auto end = data + len;
void Session::execute(const byte *data, size_t len) {
// mark the end of the message
auto end = data + len;
while (true) {
auto size = end - data;
while (true) {
auto size = end - data;
if (LIKELY(connected)) {
logger.debug("Decoding chunk of size {}", size);
auto finished = decoder.decode(data, size);
if (LIKELY(connected)) {
logger.debug("Decoding chunk of size {}", size);
auto finished = decoder.decode(data, size);
if (!finished) return;
} else {
logger.debug("Decoding handshake of size {}", size);
decoder.handshake(data, size);
}
state = state->run(*this);
decoder.reset();
if (!finished) return;
} else {
logger.debug("Decoding handshake of size {}", size);
decoder.handshake(data, size);
}
state = state->run(*this);
decoder.reset();
}
}
void Session::close()
{
logger.debug("Closing session");
bolt.close(this);
void Session::close() {
logger.debug("Closing session");
bolt.close(this);
}
GraphDbAccessor Session::active_db() { return bolt.dbms.active(); }

View File

@ -3,41 +3,41 @@
#include "io/network/socket.hpp"
#include "io/network/tcp/stream.hpp"
#include "communication/bolt/communication.hpp"
#include "communication/bolt/v1/bolt.hpp"
#include "communication/bolt/v1/serialization/record_stream.hpp"
#include "communication/bolt/v1/states/state.hpp"
#include "communication/bolt/v1/transport/bolt_decoder.hpp"
#include "communication/bolt/v1/transport/bolt_encoder.hpp"
#include "communication/bolt/communication.hpp"
#include "logging/default.hpp"
namespace bolt {
class Session : public io::tcp::Stream<io::Socket> {
public:
using Decoder = BoltDecoder;
using OutputStream = communication::OutputStream;
class Session : public io::tcp::Stream<io::Socket> {
public:
using Decoder = BoltDecoder;
using OutputStream = communication::OutputStream;
Session(io::Socket &&socket, Bolt &bolt);
Session(io::Socket &&socket, Bolt &bolt);
bool alive() const;
bool alive() const;
void execute(const byte *data, size_t len);
void execute(const byte *data, size_t len);
void close();
void close();
Bolt &bolt;
Bolt &bolt;
GraphDbAccessor active_db();
GraphDbAccessor active_db();
Decoder decoder;
OutputStream output_stream{socket};
Decoder decoder;
OutputStream output_stream{socket};
bool connected{false};
State *state;
bool connected{false};
State *state;
protected:
Logger logger;
};
protected:
Logger logger;
};
}

View File

@ -1,19 +1,16 @@
#include "communication/bolt/v1/states.hpp"
#include "communication/bolt/v1/states/handshake.hpp"
#include "communication/bolt/v1/states/init.hpp"
#include "communication/bolt/v1/states/error.hpp"
#include "communication/bolt/v1/states/executor.hpp"
#include "communication/bolt/v1/states/handshake.hpp"
#include "communication/bolt/v1/states/init.hpp"
namespace bolt
{
namespace bolt {
States::States()
{
handshake = std::make_unique<Handshake>();
init = std::make_unique<Init>();
executor = std::make_unique<Executor>();
error = std::make_unique<Error>();
States::States() {
handshake = std::make_unique<Handshake>();
init = std::make_unique<Init>();
executor = std::make_unique<Executor>();
error = std::make_unique<Error>();
}
}

View File

@ -3,18 +3,15 @@
#include "communication/bolt/v1/states/state.hpp"
#include "logging/log.hpp"
namespace bolt
{
namespace bolt {
class States
{
public:
States();
class States {
public:
States();
State::uptr handshake;
State::uptr init;
State::uptr executor;
State::uptr error;
State::uptr handshake;
State::uptr init;
State::uptr executor;
State::uptr error;
};
}

View File

@ -1,55 +1,47 @@
#include "communication/bolt/v1/states/error.hpp"
namespace bolt
{
namespace bolt {
Error::Error() : State(logging::log->logger("Error State")) {}
State* Error::run(Session& session)
{
logger.trace("Run");
State* Error::run(Session& session) {
logger.trace("Run");
session.decoder.read_byte();
auto message_type = session.decoder.read_byte();
session.decoder.read_byte();
auto message_type = session.decoder.read_byte();
logger.trace("Message type byte is: {:02X}", message_type);
logger.trace("Message type byte is: {:02X}", message_type);
if (message_type == MessageCode::PullAll)
{
session.output_stream.write_ignored();
session.output_stream.chunk();
session.output_stream.send();
return this;
}
else if(message_type == MessageCode::AckFailure)
{
// TODO reset current statement? is it even necessary?
logger.trace("AckFailure received");
session.output_stream.write_success_empty();
session.output_stream.chunk();
session.output_stream.send();
return session.bolt.states.executor.get();
}
else if(message_type == MessageCode::Reset)
{
// TODO rollback current transaction
// discard all records waiting to be sent
session.output_stream.write_success_empty();
session.output_stream.chunk();
session.output_stream.send();
return session.bolt.states.executor.get();
}
// TODO: write this as single call
if (message_type == MessageCode::PullAll) {
session.output_stream.write_ignored();
session.output_stream.chunk();
session.output_stream.send();
return this;
}
} else if (message_type == MessageCode::AckFailure) {
// TODO reset current statement? is it even necessary?
logger.trace("AckFailure received");
session.output_stream.write_success_empty();
session.output_stream.chunk();
session.output_stream.send();
return session.bolt.states.executor.get();
} else if (message_type == MessageCode::Reset) {
// TODO rollback current transaction
// discard all records waiting to be sent
session.output_stream.write_success_empty();
session.output_stream.chunk();
session.output_stream.send();
return session.bolt.states.executor.get();
}
// TODO: write this as single call
session.output_stream.write_ignored();
session.output_stream.chunk();
session.output_stream.send();
return this;
}
}

View File

@ -3,15 +3,12 @@
#include "communication/bolt/v1/session.hpp"
#include "communication/bolt/v1/states/state.hpp"
namespace bolt
{
namespace bolt {
class Error : public State
{
public:
Error();
class Error : public State {
public:
Error();
State *run(Session &session) override;
State *run(Session &session) override;
};
}

View File

@ -6,112 +6,93 @@
#include "barrier/barrier.cpp"
#endif
namespace bolt
{
namespace bolt {
Executor::Executor() : State(logging::log->logger("Executor")) {}
State *Executor::run(Session &session)
{
// just read one byte that represents the struct type, we can skip the
// information contained in this byte
session.decoder.read_byte();
State *Executor::run(Session &session) {
// just read one byte that represents the struct type, we can skip the
// information contained in this byte
session.decoder.read_byte();
logger.debug("Run");
logger.debug("Run");
auto message_type = session.decoder.read_byte();
auto message_type = session.decoder.read_byte();
if (message_type == MessageCode::Run)
{
Query q;
if (message_type == MessageCode::Run) {
Query q;
q.statement = session.decoder.read_string();
q.statement = session.decoder.read_string();
try
{
return this->run(session, q);
// TODO: RETURN success MAYBE
}
catch (const QueryEngineException &e)
{
session.output_stream.write_failure(
{{"code", "Memgraph.QueryEngineException"},
{"message", e.what()}});
session.output_stream.send();
return session.bolt.states.error.get();
} catch (std::exception &e) {
session.output_stream.write_failure(
{{"code", "Memgraph.Exception"},
{"message", e.what()}});
session.output_stream.send();
return session.bolt.states.error.get();
}
try {
return this->run(session, q);
// TODO: RETURN success MAYBE
} catch (const QueryEngineException &e) {
session.output_stream.write_failure(
{{"code", "Memgraph.QueryEngineException"}, {"message", e.what()}});
session.output_stream.send();
return session.bolt.states.error.get();
} catch (std::exception &e) {
session.output_stream.write_failure(
{{"code", "Memgraph.Exception"}, {"message", e.what()}});
session.output_stream.send();
return session.bolt.states.error.get();
}
else if (message_type == MessageCode::PullAll)
{
pull_all(session);
}
else if (message_type == MessageCode::DiscardAll)
{
discard_all(session);
}
else if (message_type == MessageCode::Reset)
{
// TODO: rollback current transaction
// discard all records waiting to be sent
return this;
}
else
{
logger.error("Unrecognized message recieved");
logger.debug("Invalid message type 0x{:02X}", message_type);
return session.bolt.states.error.get();
}
} else if (message_type == MessageCode::PullAll) {
pull_all(session);
} else if (message_type == MessageCode::DiscardAll) {
discard_all(session);
} else if (message_type == MessageCode::Reset) {
// TODO: rollback current transaction
// discard all records waiting to be sent
return this;
} else {
logger.error("Unrecognized message recieved");
logger.debug("Invalid message type 0x{:02X}", message_type);
return session.bolt.states.error.get();
}
return this;
}
State *Executor::run(Session &session, Query &query)
{
logger.trace("[Run] '{}'", query.statement);
State *Executor::run(Session &session, Query &query) {
logger.trace("[Run] '{}'", query.statement);
auto db_accessor = session.active_db();
logger.debug("[ActiveDB] '{}'", db_accessor.name());
auto db_accessor = session.active_db();
logger.debug("[ActiveDB] '{}'", db_accessor.name());
auto is_successfully_executed =
query_engine.Run(query.statement, db_accessor, session.output_stream);
if (!is_successfully_executed)
{
session.output_stream.write_failure(
{{"code", "Memgraph.QueryExecutionFail"},
{"message", "Query execution has failed (probably there is no "
"element or there are some problems with concurrent "
"access -> client has to resolve problems with "
"concurrent access)"}});
session.output_stream.send();
return session.bolt.states.error.get();
}
return this;
}
void Executor::pull_all(Session &session)
{
logger.trace("[PullAll]");
auto is_successfully_executed =
query_engine.Run(query.statement, db_accessor, session.output_stream);
if (!is_successfully_executed) {
session.output_stream.write_failure(
{{"code", "Memgraph.QueryExecutionFail"},
{"message",
"Query execution has failed (probably there is no "
"element or there are some problems with concurrent "
"access -> client has to resolve problems with "
"concurrent access)"}});
session.output_stream.send();
return session.bolt.states.error.get();
}
return this;
}
void Executor::discard_all(Session &session)
{
logger.trace("[DiscardAll]");
void Executor::pull_all(Session &session) {
logger.trace("[PullAll]");
// TODO: discard state
session.output_stream.send();
}
session.output_stream.write_success();
session.output_stream.chunk();
session.output_stream.send();
void Executor::discard_all(Session &session) {
logger.trace("[DiscardAll]");
// TODO: discard state
session.output_stream.write_success();
session.output_stream.chunk();
session.output_stream.send();
}
}

View File

@ -1,44 +1,38 @@
#pragma once
#include "communication/bolt/v1/states/state.hpp"
#include "communication/bolt/v1/session.hpp"
#include "communication/bolt/v1/states/state.hpp"
#include "query/engine.hpp"
namespace bolt
{
namespace bolt {
class Executor : public State
{
struct Query
{
std::string statement;
};
class Executor : public State {
struct Query {
std::string statement;
};
public:
Executor();
public:
Executor();
State* run(Session& session) override final;
State* run(Session& session) override final;
protected:
/* Execute an incoming query
*
*/
State* run(Session& session, Query& query);
protected:
/* Execute an incoming query
*
*/
State* run(Session& session, Query& query);
/* Send all remaining results to the client
*
*/
void pull_all(Session& session);
/* Send all remaining results to the client
*
*/
void pull_all(Session& session);
/* Discard all remaining results
*
*/
void discard_all(Session& session);
private:
QueryEngine<communication::OutputStream> query_engine;
/* Discard all remaining results
*
*/
void discard_all(Session& session);
private:
QueryEngine<communication::OutputStream> query_engine;
};
}

View File

@ -2,8 +2,7 @@
#include "communication/bolt/v1/session.hpp"
namespace bolt
{
namespace bolt {
static constexpr uint32_t preamble = 0x6060B017;
@ -11,21 +10,18 @@ static constexpr byte protocol[4] = {0x00, 0x00, 0x00, 0x01};
Handshake::Handshake() : State(logging::log->logger("Handshake")) {}
State* Handshake::run(Session& session)
{
logger.debug("run");
State* Handshake::run(Session& session) {
logger.debug("run");
if(UNLIKELY(session.decoder.read_uint32() != preamble))
return nullptr;
if (UNLIKELY(session.decoder.read_uint32() != preamble)) return nullptr;
// TODO so far we only support version 1 of the protocol so it doesn't
// make sense to check which version the client prefers
// this will change in the future
// TODO so far we only support version 1 of the protocol so it doesn't
// make sense to check which version the client prefers
// this will change in the future
session.connected = true;
session.socket.write(protocol, sizeof protocol);
session.connected = true;
session.socket.write(protocol, sizeof protocol);
return session.bolt.states.init.get();
return session.bolt.states.init.get();
}
}

View File

@ -2,14 +2,11 @@
#include "communication/bolt/v1/states/state.hpp"
namespace bolt
{
namespace bolt {
class Handshake : public State
{
public:
Handshake();
State* run(Session& session) override;
class Handshake : public State {
public:
Handshake();
State* run(Session& session) override;
};
}

View File

@ -5,53 +5,50 @@
#include "utils/likely.hpp"
namespace bolt
{
namespace bolt {
Init::Init() : MessageParser<Init>(logging::log->logger("Init")) {}
State *Init::parse(Session &session, Message &message)
{
logger.debug("bolt::Init.parse()");
State *Init::parse(Session &session, Message &message) {
logger.debug("bolt::Init.parse()");
auto struct_type = session.decoder.read_byte();
auto struct_type = session.decoder.read_byte();
if (UNLIKELY((struct_type & 0x0F) > pack::Rule::MaxInitStructSize)) {
logger.debug("{}", struct_type);
if (UNLIKELY((struct_type & 0x0F) > pack::Rule::MaxInitStructSize)) {
logger.debug("{}", struct_type);
logger.debug(
"Expected struct marker of max size 0x{:02} instead of 0x{:02X}",
(unsigned)pack::Rule::MaxInitStructSize, (unsigned)struct_type);
logger.debug(
"Expected struct marker of max size 0x{:02} instead of 0x{:02X}",
(unsigned)pack::Rule::MaxInitStructSize, (unsigned)struct_type);
return nullptr;
}
return nullptr;
}
auto message_type = session.decoder.read_byte();
auto message_type = session.decoder.read_byte();
if (UNLIKELY(message_type != MessageCode::Init)) {
logger.debug("Expected Init (0x01) instead of (0x{:02X})",
(unsigned)message_type);
if (UNLIKELY(message_type != MessageCode::Init)) {
logger.debug("Expected Init (0x01) instead of (0x{:02X})",
(unsigned)message_type);
return nullptr;
}
return nullptr;
}
message.client_name = session.decoder.read_string();
message.client_name = session.decoder.read_string();
if (struct_type == pack::Code::StructTwo) {
// TODO process authentication tokens
}
if (struct_type == pack::Code::StructTwo) {
// TODO process authentication tokens
}
return this;
return this;
}
State *Init::execute(Session &session, Message &message)
{
logger.debug("Client connected '{}'", message.client_name);
State *Init::execute(Session &session, Message &message) {
logger.debug("Client connected '{}'", message.client_name);
session.output_stream.write_success_empty();
session.output_stream.chunk();
session.output_stream.send();
session.output_stream.write_success_empty();
session.output_stream.chunk();
session.output_stream.send();
return session.bolt.states.executor.get();
return session.bolt.states.executor.get();
}
}

View File

@ -2,21 +2,17 @@
#include "communication/bolt/v1/states/message_parser.hpp"
namespace bolt
{
namespace bolt {
class Init : public MessageParser<Init>
{
public:
struct Message
{
std::string client_name;
};
class Init : public MessageParser<Init> {
public:
struct Message {
std::string client_name;
};
Init();
Init();
State* parse(Session& session, Message& message);
State* execute(Session& session, Message& message);
State* parse(Session& session, Message& message);
State* execute(Session& session, Message& message);
};
}

View File

@ -4,30 +4,27 @@
#include "communication/bolt/v1/states/state.hpp"
#include "utils/crtp.hpp"
namespace bolt
{
namespace bolt {
template <class Derived>
class MessageParser : public State, public Crtp<Derived>
{
public:
MessageParser(Logger &&logger) : logger(std::forward<Logger>(logger)) {}
class MessageParser : public State, public Crtp<Derived> {
public:
MessageParser(Logger &&logger) : logger(std::forward<Logger>(logger)) {}
State *run(Session &session) override final
{
typename Derived::Message message;
State *run(Session &session) override final {
typename Derived::Message message;
logger.debug("Parsing message");
auto next = this->derived().parse(session, message);
logger.debug("Parsing message");
auto next = this->derived().parse(session, message);
// return next state if parsing was unsuccessful (i.e. error state)
if (next != &this->derived()) return next;
// return next state if parsing was unsuccessful (i.e. error state)
if (next != &this->derived()) return next;
logger.debug("Executing state");
return this->derived().execute(session, message);
}
logger.debug("Executing state");
return this->derived().execute(session, message);
}
protected:
Logger logger;
protected:
Logger logger;
};
}

View File

@ -1,30 +1,27 @@
#pragma once
#include <cstdlib>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include "logging/default.hpp"
namespace bolt
{
namespace bolt {
class Session;
class State
{
public:
using uptr = std::unique_ptr<State>;
class State {
public:
using uptr = std::unique_ptr<State>;
State() = default;
State(Logger logger) : logger(logger) {}
State() = default;
State(Logger logger) : logger(logger) {}
virtual ~State() = default;
virtual ~State() = default;
virtual State* run(Session& session) = 0;
virtual State* run(Session& session) = 0;
protected:
Logger logger;
protected:
Logger logger;
};
}

View File

@ -4,59 +4,52 @@
#include "logging/default.hpp"
#include "utils/bswap.hpp"
namespace bolt
{
namespace bolt {
void BoltDecoder::handshake(const byte *&data, size_t len)
{
buffer.write(data, len);
data += len;
void BoltDecoder::handshake(const byte *&data, size_t len) {
buffer.write(data, len);
data += len;
}
bool BoltDecoder::decode(const byte *&data, size_t len)
{
return decoder(data, len);
bool BoltDecoder::decode(const byte *&data, size_t len) {
return decoder(data, len);
}
bool BoltDecoder::empty() const { return pos == buffer.size(); }
void BoltDecoder::reset()
{
buffer.clear();
pos = 0;
void BoltDecoder::reset() {
buffer.clear();
pos = 0;
}
byte BoltDecoder::peek() const { return buffer[pos]; }
byte BoltDecoder::read_byte() { return buffer[pos++]; }
void BoltDecoder::read_bytes(void *dest, size_t n)
{
std::memcpy(dest, buffer.data() + pos, n);
pos += n;
void BoltDecoder::read_bytes(void *dest, size_t n) {
std::memcpy(dest, buffer.data() + pos, n);
pos += n;
}
template <class T>
T parse(const void *data)
{
// reinterpret bytes as the target value
auto value = reinterpret_cast<const T *>(data);
T parse(const void *data) {
// reinterpret bytes as the target value
auto value = reinterpret_cast<const T *>(data);
// swap values to little endian
return bswap(*value);
// swap values to little endian
return bswap(*value);
}
template <class T>
T parse(Buffer &buffer, size_t &pos)
{
// get a pointer to the data we're converting
auto ptr = buffer.data() + pos;
T parse(Buffer &buffer, size_t &pos) {
// get a pointer to the data we're converting
auto ptr = buffer.data() + pos;
// skip sizeof bytes that we're going to read
pos += sizeof(T);
// skip sizeof bytes that we're going to read
pos += sizeof(T);
// read and convert the value
return parse<T>(ptr);
// read and convert the value
return parse<T>(ptr);
}
int16_t BoltDecoder::read_int16() { return parse<int16_t>(buffer, pos); }
@ -71,46 +64,44 @@ int64_t BoltDecoder::read_int64() { return parse<int64_t>(buffer, pos); }
uint64_t BoltDecoder::read_uint64() { return parse<uint64_t>(buffer, pos); }
double BoltDecoder::read_float64()
{
auto v = parse<int64_t>(buffer, pos);
return *reinterpret_cast<const double *>(&v);
double BoltDecoder::read_float64() {
auto v = parse<int64_t>(buffer, pos);
return *reinterpret_cast<const double *>(&v);
}
std::string BoltDecoder::read_string()
{
auto marker = read_byte();
std::string BoltDecoder::read_string() {
auto marker = read_byte();
std::string res;
uint32_t size;
// if the first 4 bits equal to 1000 (0x8), this is a tiny string
if ((marker & 0xF0) == pack::TinyString) {
// size is stored in the lower 4 bits of the marker byte
size = marker & 0x0F;
}
// if the marker is 0xD0, size is an 8-bit unsigned integer
else if (marker == pack::String8) {
size = read_byte();
}
// if the marker is 0xD1, size is a 16-bit big-endian unsigned integer
else if (marker == pack::String16) {
size = read_uint16();
}
// if the marker is 0xD2, size is a 32-bit big-endian unsigned integer
else if (marker == pack::String32) {
size = read_uint32();
} else {
// TODO error?
return res;
}
if (size == 0) return res;
res.append(reinterpret_cast<const char *>(raw()), size);
pos += size;
std::string res;
uint32_t size;
// if the first 4 bits equal to 1000 (0x8), this is a tiny string
if ((marker & 0xF0) == pack::TinyString) {
// size is stored in the lower 4 bits of the marker byte
size = marker & 0x0F;
}
// if the marker is 0xD0, size is an 8-bit unsigned integer
else if (marker == pack::String8) {
size = read_byte();
}
// if the marker is 0xD1, size is a 16-bit big-endian unsigned integer
else if (marker == pack::String16) {
size = read_uint16();
}
// if the marker is 0xD2, size is a 32-bit big-endian unsigned integer
else if (marker == pack::String32) {
size = read_uint32();
} else {
// TODO error?
return res;
}
if (size == 0) return res;
res.append(reinterpret_cast<const char *>(raw()), size);
pos += size;
return res;
}
const byte *BoltDecoder::raw() const { return buffer.data() + pos; }

View File

@ -4,40 +4,38 @@
#include "communication/bolt/v1/transport/chunked_decoder.hpp"
#include "utils/types/byte.hpp"
namespace bolt
{
namespace bolt {
class BoltDecoder
{
public:
void handshake(const byte *&data, size_t len);
bool decode(const byte *&data, size_t len);
class BoltDecoder {
public:
void handshake(const byte *&data, size_t len);
bool decode(const byte *&data, size_t len);
bool empty() const;
void reset();
bool empty() const;
void reset();
byte peek() const;
byte read_byte();
void read_bytes(void *dest, size_t n);
byte peek() const;
byte read_byte();
void read_bytes(void *dest, size_t n);
int16_t read_int16();
uint16_t read_uint16();
int16_t read_int16();
uint16_t read_uint16();
int32_t read_int32();
uint32_t read_uint32();
int32_t read_int32();
uint32_t read_uint32();
int64_t read_int64();
uint64_t read_uint64();
int64_t read_int64();
uint64_t read_uint64();
double read_float64();
double read_float64();
std::string read_string();
std::string read_string();
private:
Buffer buffer;
ChunkedDecoder<Buffer> decoder{buffer};
size_t pos{0};
private:
Buffer buffer;
ChunkedDecoder<Buffer> decoder{buffer};
size_t pos{0};
const byte *raw() const;
const byte *raw() const;
};
}

View File

@ -8,212 +8,189 @@
#include "utils/bswap.hpp"
#include "utils/types/byte.hpp"
namespace bolt
{
namespace bolt {
template <class Stream>
class BoltEncoder
{
static constexpr int64_t plus_2_to_the_31 = 2147483648L;
static constexpr int64_t plus_2_to_the_15 = 32768L;
static constexpr int64_t plus_2_to_the_7 = 128L;
static constexpr int64_t minus_2_to_the_4 = -16L;
static constexpr int64_t minus_2_to_the_7 = -128L;
static constexpr int64_t minus_2_to_the_15 = -32768L;
static constexpr int64_t minus_2_to_the_31 = -2147483648L;
class BoltEncoder {
static constexpr int64_t plus_2_to_the_31 = 2147483648L;
static constexpr int64_t plus_2_to_the_15 = 32768L;
static constexpr int64_t plus_2_to_the_7 = 128L;
static constexpr int64_t minus_2_to_the_4 = -16L;
static constexpr int64_t minus_2_to_the_7 = -128L;
static constexpr int64_t minus_2_to_the_15 = -32768L;
static constexpr int64_t minus_2_to_the_31 = -2147483648L;
public:
BoltEncoder(Stream &stream) : stream(stream)
{
logger = logging::log->logger("Bolt Encoder");
public:
BoltEncoder(Stream &stream) : stream(stream) {
logger = logging::log->logger("Bolt Encoder");
}
void write(byte value) { write_byte(value); }
void write_byte(byte value) {
logger.trace("write byte: {}", value);
stream.write(value);
}
void write(const byte *values, size_t n) { stream.write(values, n); }
void write_null() { stream.write(pack::Null); }
void write(bool value) { write_bool(value); }
void write_bool(bool value) {
if (value)
write_true();
else
write_false();
}
void write_true() { stream.write(pack::True); }
void write_false() { stream.write(pack::False); }
template <class T>
void write_value(T value) {
value = bswap(value);
stream.write(reinterpret_cast<const byte *>(&value), sizeof(value));
}
void write_integer(int64_t value) {
if (value >= minus_2_to_the_4 && value < plus_2_to_the_7) {
write(static_cast<byte>(value));
} else if (value >= minus_2_to_the_7 && value < minus_2_to_the_4) {
write(pack::Int8);
write(static_cast<byte>(value));
} else if (value >= minus_2_to_the_15 && value < plus_2_to_the_15) {
write(pack::Int16);
write_value(static_cast<int16_t>(value));
} else if (value >= minus_2_to_the_31 && value < plus_2_to_the_31) {
write(pack::Int32);
write_value(static_cast<int32_t>(value));
} else {
write(pack::Int64);
write_value(value);
}
}
void write(byte value) { write_byte(value); }
void write(double value) { write_double(value); }
void write_byte(byte value)
{
logger.trace("write byte: {}", value);
stream.write(value);
void write_double(double value) {
write(pack::Float64);
write_value(*reinterpret_cast<const int64_t *>(&value));
}
void write_map_header(size_t size) {
if (size < 0x10) {
write(static_cast<byte>(pack::TinyMap | size));
} else if (size <= 0xFF) {
write(pack::Map8);
write(static_cast<byte>(size));
} else if (size <= 0xFFFF) {
write(pack::Map16);
write_value<uint16_t>(size);
} else {
write(pack::Map32);
write_value<uint32_t>(size);
}
}
void write(const byte *values, size_t n) { stream.write(values, n); }
void write_empty_map() { write(pack::TinyMap); }
void write_null() { stream.write(pack::Null); }
void write(bool value) { write_bool(value); }
void write_bool(bool value)
{
if (value)
write_true();
else
write_false();
void write_list_header(size_t size) {
if (size < 0x10) {
write(static_cast<byte>(pack::TinyList | size));
} else if (size <= 0xFF) {
write(pack::List8);
write(static_cast<byte>(size));
} else if (size <= 0xFFFF) {
write(pack::List16);
write_value<uint16_t>(size);
} else {
write(pack::List32);
write_value<uint32_t>(size);
}
}
void write_true() { stream.write(pack::True); }
void write_empty_list() { write(pack::TinyList); }
void write_false() { stream.write(pack::False); }
template <class T>
void write_value(T value)
{
value = bswap(value);
stream.write(reinterpret_cast<const byte *>(&value), sizeof(value));
void write_string_header(size_t size) {
if (size < 0x10) {
write(static_cast<byte>(pack::TinyString | size));
} else if (size <= 0xFF) {
write(pack::String8);
write(static_cast<byte>(size));
} else if (size <= 0xFFFF) {
write(pack::String16);
write_value<uint16_t>(size);
} else {
write(pack::String32);
write_value<uint32_t>(size);
}
}
void write_integer(int64_t value)
{
if (value >= minus_2_to_the_4 && value < plus_2_to_the_7) {
write(static_cast<byte>(value));
} else if (value >= minus_2_to_the_7 && value < minus_2_to_the_4) {
write(pack::Int8);
write(static_cast<byte>(value));
} else if (value >= minus_2_to_the_15 && value < plus_2_to_the_15) {
write(pack::Int16);
write_value(static_cast<int16_t>(value));
} else if (value >= minus_2_to_the_31 && value < plus_2_to_the_31) {
write(pack::Int32);
write_value(static_cast<int32_t>(value));
} else {
write(pack::Int64);
write_value(value);
}
void write(const std::string &str) { write_string(str); }
void write_string(const std::string &str) {
write_string(str.c_str(), str.size());
}
void write_string(const char *str, size_t len) {
write_string_header(len);
write(reinterpret_cast<const byte *>(str), len);
}
void write_struct_header(size_t size) {
if (size < 0x10) {
write(static_cast<byte>(pack::TinyStruct | size));
} else if (size <= 0xFF) {
write(pack::Struct8);
write(static_cast<byte>(size));
} else {
write(pack::Struct16);
write_value<uint16_t>(size);
}
}
void write(double value) { write_double(value); }
void message_success() {
write_struct_header(1);
write(underlying_cast(MessageCode::Success));
}
void write_double(double value)
{
write(pack::Float64);
write_value(*reinterpret_cast<const int64_t *>(&value));
}
void message_success_empty() {
message_success();
write_empty_map();
}
void write_map_header(size_t size)
{
if (size < 0x10) {
write(static_cast<byte>(pack::TinyMap | size));
} else if (size <= 0xFF) {
write(pack::Map8);
write(static_cast<byte>(size));
} else if (size <= 0xFFFF) {
write(pack::Map16);
write_value<uint16_t>(size);
} else {
write(pack::Map32);
write_value<uint32_t>(size);
}
}
void message_record() {
write_struct_header(1);
write(underlying_cast(MessageCode::Record));
}
void write_empty_map() { write(pack::TinyMap); }
void message_record_empty() {
message_record();
write_empty_list();
}
void write_list_header(size_t size)
{
if (size < 0x10) {
write(static_cast<byte>(pack::TinyList | size));
} else if (size <= 0xFF) {
write(pack::List8);
write(static_cast<byte>(size));
} else if (size <= 0xFFFF) {
write(pack::List16);
write_value<uint16_t>(size);
} else {
write(pack::List32);
write_value<uint32_t>(size);
}
}
void message_ignored() {
write_struct_header(0);
write(underlying_cast(MessageCode::Ignored));
}
void write_empty_list() { write(pack::TinyList); }
void message_failure() {
write_struct_header(1);
write(underlying_cast(MessageCode::Failure));
}
void write_string_header(size_t size)
{
if (size < 0x10) {
write(static_cast<byte>(pack::TinyString | size));
} else if (size <= 0xFF) {
write(pack::String8);
write(static_cast<byte>(size));
} else if (size <= 0xFFFF) {
write(pack::String16);
write_value<uint16_t>(size);
} else {
write(pack::String32);
write_value<uint32_t>(size);
}
}
void message_ignored_empty() {
message_ignored();
write_empty_map();
}
void write(const std::string& str) {
write_string(str);
}
protected:
Logger logger;
void write_string(const std::string &str)
{
write_string(str.c_str(), str.size());
}
void write_string(const char *str, size_t len)
{
write_string_header(len);
write(reinterpret_cast<const byte *>(str), len);
}
void write_struct_header(size_t size)
{
if (size < 0x10) {
write(static_cast<byte>(pack::TinyStruct | size));
} else if (size <= 0xFF) {
write(pack::Struct8);
write(static_cast<byte>(size));
} else {
write(pack::Struct16);
write_value<uint16_t>(size);
}
}
void message_success()
{
write_struct_header(1);
write(underlying_cast(MessageCode::Success));
}
void message_success_empty()
{
message_success();
write_empty_map();
}
void message_record()
{
write_struct_header(1);
write(underlying_cast(MessageCode::Record));
}
void message_record_empty()
{
message_record();
write_empty_list();
}
void message_ignored()
{
write_struct_header(0);
write(underlying_cast(MessageCode::Ignored));
}
void message_failure()
{
write_struct_header(1);
write(underlying_cast(MessageCode::Failure));
}
void message_ignored_empty()
{
message_ignored();
write_empty_map();
}
protected:
Logger logger;
private:
Stream &stream;
private:
Stream &stream;
};
}

View File

@ -1,16 +1,10 @@
#include "communication/bolt/v1/transport/buffer.hpp"
namespace bolt
{
namespace bolt {
void Buffer::write(const byte* data, size_t len)
{
buffer.insert(buffer.end(), data, data + len);
}
void Buffer::clear()
{
buffer.clear();
void Buffer::write(const byte* data, size_t len) {
buffer.insert(buffer.end(), data, data + len);
}
void Buffer::clear() { buffer.clear(); }
}

View File

@ -6,34 +6,21 @@
#include "utils/types/byte.hpp"
namespace bolt
{
namespace bolt {
class Buffer
{
public:
void write(const byte* data, size_t len);
class Buffer {
public:
void write(const byte* data, size_t len);
void clear();
void clear();
size_t size() const
{
return buffer.size();
}
size_t size() const { return buffer.size(); }
byte operator[](size_t idx) const
{
return buffer[idx];
}
byte operator[](size_t idx) const { return buffer[idx]; }
const byte* data() const
{
return buffer.data();
}
const byte* data() const { return buffer.data(); }
private:
std::vector<byte> buffer;
private:
std::vector<byte> buffer;
};
}

View File

@ -1,65 +1,57 @@
#pragma once
#include <cstring>
#include <memory>
#include <vector>
#include <cstring>
#include "communication/bolt/v1/config.hpp"
#include "utils/types/byte.hpp"
#include "logging/default.hpp"
#include "utils/types/byte.hpp"
namespace bolt
{
namespace bolt {
template <class Stream>
class ChunkedBuffer
{
static constexpr size_t C = bolt::config::C; /* chunk size */
class ChunkedBuffer {
static constexpr size_t C = bolt::config::C; /* chunk size */
public:
ChunkedBuffer(Stream &stream) : stream(stream)
{
logger = logging::log->logger("Chunked Buffer");
}
public:
ChunkedBuffer(Stream &stream) : stream(stream) {
logger = logging::log->logger("Chunked Buffer");
}
void write(const byte *values, size_t n)
{
logger.trace("Write {} bytes", n);
void write(const byte *values, size_t n) {
logger.trace("Write {} bytes", n);
// total size of the buffer is now bigger for n
size += n;
// total size of the buffer is now bigger for n
size += n;
// reserve enough spece for the new data
buffer.reserve(size);
// reserve enough spece for the new data
buffer.reserve(size);
// copy new data
std::copy(values, values + n, std::back_inserter(buffer));
}
// copy new data
std::copy(values, values + n, std::back_inserter(buffer));
}
void flush()
{
stream.get().write(&buffer.front(), size);
void flush() {
stream.get().write(&buffer.front(), size);
logger.trace("Flushed {} bytes", size);
logger.trace("Flushed {} bytes", size);
// GC
// TODO: impelement a better strategy
buffer.clear();
// GC
// TODO: impelement a better strategy
buffer.clear();
// reset size
size = 0;
}
// reset size
size = 0;
}
~ChunkedBuffer()
{
}
~ChunkedBuffer() {}
private:
Logger logger;
// every new stream.write creates new TCP package
std::reference_wrapper<Stream> stream;
std::vector<byte> buffer;
size_t size {0};
private:
Logger logger;
// every new stream.write creates new TCP package
std::reference_wrapper<Stream> stream;
std::vector<byte> buffer;
size_t size{0};
};
}

View File

@ -9,62 +9,56 @@
#include "utils/likely.hpp"
#include "utils/types/byte.hpp"
namespace bolt
{
namespace bolt {
template <class Stream>
class ChunkedDecoder
{
public:
class DecoderError : public BasicException
{
public:
using BasicException::BasicException;
};
class ChunkedDecoder {
public:
class DecoderError : public BasicException {
public:
using BasicException::BasicException;
};
ChunkedDecoder(Stream& stream) : stream(stream) {}
ChunkedDecoder(Stream &stream) : stream(stream) {}
/* Decode chunked data
*
* Chunk format looks like:
*
* |Header| Data ||Header| Data || ... || End |
* | 2B | size bytes || 2B | size bytes || ... ||00 00|
*/
bool decode(const byte *&chunk, size_t n)
{
while (n > 0)
{
// get size from first two bytes in the chunk
auto size = get_size(chunk);
/* Decode chunked data
*
* Chunk format looks like:
*
* |Header| Data ||Header| Data || ... || End |
* | 2B | size bytes || 2B | size bytes || ... ||00 00|
*/
bool decode(const byte *&chunk, size_t n) {
while (n > 0) {
// get size from first two bytes in the chunk
auto size = get_size(chunk);
if (UNLIKELY(size + 2 > n))
throw DecoderError("Chunk size larger than available data.");
if (UNLIKELY(size + 2 > n))
throw DecoderError("Chunk size larger than available data.");
// advance chunk to pass those two bytes
chunk += 2;
n -= 2;
// advance chunk to pass those two bytes
chunk += 2;
n -= 2;
// if chunk size is 0, we're done!
if (size == 0) return true;
// if chunk size is 0, we're done!
if (size == 0) return true;
stream.get().write(chunk, size);
stream.get().write(chunk, size);
chunk += size;
n -= size;
}
return false;
chunk += size;
n -= size;
}
bool operator()(const byte *&chunk, size_t n) { return decode(chunk, n); }
return false;
}
private:
std::reference_wrapper<Stream> stream;
bool operator()(const byte *&chunk, size_t n) { return decode(chunk, n); }
size_t get_size(const byte *chunk)
{
return size_t(chunk[0]) << 8 | chunk[1];
}
private:
std::reference_wrapper<Stream> stream;
size_t get_size(const byte *chunk) {
return size_t(chunk[0]) << 8 | chunk[1];
}
};
}

View File

@ -8,85 +8,76 @@
#include "logging/default.hpp"
#include "utils/likely.hpp"
namespace bolt
{
namespace bolt {
template <class Stream>
class ChunkedEncoder
{
static constexpr size_t N = bolt::config::N;
static constexpr size_t C = bolt::config::C;
class ChunkedEncoder {
static constexpr size_t N = bolt::config::N;
static constexpr size_t C = bolt::config::C;
public:
using byte = unsigned char;
public:
using byte = unsigned char;
ChunkedEncoder(Stream &stream)
: logger(logging::log->logger("Chunked Encoder")), stream(stream)
{
ChunkedEncoder(Stream &stream)
: logger(logging::log->logger("Chunked Encoder")), stream(stream) {}
static constexpr size_t chunk_size = N - 2;
void write(byte value) {
if (UNLIKELY(pos == N)) write_chunk();
chunk[pos++] = value;
}
void write(const byte *values, size_t n) {
logger.trace("write {} bytes", n);
while (n > 0) {
auto size = n < N - pos ? n : N - pos;
std::memcpy(chunk.data() + pos, values, size);
pos += size;
n -= size;
// TODO: see how bolt splits message over more TCP packets,
// test for more TCP packets
if (pos == N) write_chunk();
}
}
static constexpr size_t chunk_size = N - 2;
void write_chunk() {
write_chunk_header();
void write(byte value)
{
if (UNLIKELY(pos == N)) write_chunk();
// write two zeros to signal message end
chunk[pos++] = 0x00;
chunk[pos++] = 0x00;
chunk[pos++] = value;
}
flush();
}
void write(const byte *values, size_t n)
{
logger.trace("write {} bytes", n);
private:
Logger logger;
std::reference_wrapper<Stream> stream;
while (n > 0) {
auto size = n < N - pos ? n : N - pos;
std::array<byte, C> chunk;
size_t pos{2};
std::memcpy(chunk.data() + pos, values, size);
void write_chunk_header() {
// write the size of the chunk
uint16_t size = pos - 2;
pos += size;
n -= size;
// write the higher byte
chunk[0] = size >> 8;
// TODO: see how bolt splits message over more TCP packets,
// test for more TCP packets
if (pos == N) write_chunk();
}
}
// write the lower byte
chunk[1] = size & 0xFF;
}
void write_chunk()
{
write_chunk_header();
// write two zeros to signal message end
chunk[pos++] = 0x00;
chunk[pos++] = 0x00;
flush();
}
private:
Logger logger;
std::reference_wrapper<Stream> stream;
std::array<byte, C> chunk;
size_t pos{2};
void write_chunk_header()
{
// write the size of the chunk
uint16_t size = pos - 2;
// write the higher byte
chunk[0] = size >> 8;
// write the lower byte
chunk[1] = size & 0xFF;
}
void flush()
{
// write chunk to the stream
stream.get().write(chunk.data(), pos);
pos = 2;
}
void flush() {
// write chunk to the stream
stream.get().write(chunk.data(), pos);
pos = 2;
}
};
}

View File

@ -1,39 +1,33 @@
#pragma once
#include <cstdint>
#include <vector>
#include <cstdio>
#include <vector>
#include "io/network/socket.hpp"
#include "communication/bolt/v1/transport/stream_error.hpp"
#include "io/network/socket.hpp"
namespace bolt
{
namespace bolt {
template <typename Stream>
class SocketStream
{
public:
using byte = uint8_t;
class SocketStream {
public:
using byte = uint8_t;
SocketStream(Stream& socket) : socket(socket) {}
SocketStream(Stream& socket) : socket(socket) {}
void write(const byte* data, size_t n)
{
while(n > 0)
{
auto written = socket.get().write(data, n);
void write(const byte* data, size_t n) {
while (n > 0) {
auto written = socket.get().write(data, n);
if(UNLIKELY(written == -1))
throw StreamError("Can't write to stream");
if (UNLIKELY(written == -1)) throw StreamError("Can't write to stream");
n -= written;
data += written;
}
n -= written;
data += written;
}
}
private:
std::reference_wrapper<Stream> socket;
private:
std::reference_wrapper<Stream> socket;
};
}

View File

@ -2,13 +2,10 @@
#include "utils/exceptions/basic_exception.hpp"
namespace bolt
{
namespace bolt {
class StreamError : BasicException
{
public:
using BasicException::BasicException;
class StreamError : BasicException {
public:
using BasicException::BasicException;
};
}

View File

@ -7,324 +7,301 @@
#include "utils/bswap.hpp"
#include "utils/types/byte.hpp"
namespace bolt
{
namespace bolt {
// BoltDecoder for streams. Meant for use in SnapshotDecoder.
// This should be recoded to recieve the current caller so that decoder can
// based on a current type call it.
template <class STREAM>
class StreamedBoltDecoder
{
static constexpr int64_t plus_2_to_the_31 = 2147483648L;
static constexpr int64_t plus_2_to_the_15 = 32768L;
static constexpr int64_t plus_2_to_the_7 = 128L;
static constexpr int64_t minus_2_to_the_4 = -16L;
static constexpr int64_t minus_2_to_the_7 = -128L;
static constexpr int64_t minus_2_to_the_15 = -32768L;
static constexpr int64_t minus_2_to_the_31 = -2147483648L;
class StreamedBoltDecoder {
static constexpr int64_t plus_2_to_the_31 = 2147483648L;
static constexpr int64_t plus_2_to_the_15 = 32768L;
static constexpr int64_t plus_2_to_the_7 = 128L;
static constexpr int64_t minus_2_to_the_4 = -16L;
static constexpr int64_t minus_2_to_the_7 = -128L;
static constexpr int64_t minus_2_to_the_15 = -32768L;
static constexpr int64_t minus_2_to_the_31 = -2147483648L;
public:
StreamedBoltDecoder(STREAM &stream) : stream(stream) {}
public:
StreamedBoltDecoder(STREAM &stream) : stream(stream) {}
// Returns mark of a data.
size_t mark() { return peek_byte(); }
// Returns mark of a data.
size_t mark() { return peek_byte(); }
// Calls handle with current primitive data. Throws DecoderException if it
// isn't a primitive.
template <class H, class T>
T accept_primitive(H &handle)
{
switch (byte()) {
case pack::False: {
return handle.handle(false);
}
case pack::True: {
return handle.handle(true);
}
case pack::Float64: {
return handle.handle(read_double());
}
default: {
return handle.handle(integer());
}
};
// Calls handle with current primitive data. Throws DecoderException if it
// isn't a primitive.
template <class H, class T>
T accept_primitive(H &handle) {
switch (byte()) {
case pack::False: {
return handle.handle(false);
}
case pack::True: {
return handle.handle(true);
}
case pack::Float64: {
return handle.handle(read_double());
}
default: { return handle.handle(integer()); }
};
}
// Reads map header. Throws DecoderException if it isn't map header.
size_t map_header() {
auto marker = byte();
size_t size;
if ((marker & 0xF0) == pack::TinyMap) {
size = marker & 0x0F;
} else if (marker == pack::Map8) {
size = byte();
} else if (marker == pack::Map16) {
size = read<uint16_t>();
} else if (marker == pack::Map32) {
size = read<uint32_t>();
} else {
// Error
throw DecoderException(
"StreamedBoltDecoder: Tryed to read map header but found ", marker);
}
// Reads map header. Throws DecoderException if it isn't map header.
size_t map_header()
{
auto marker = byte();
return size;
}
size_t size;
bool is_list() {
auto marker = peek_byte();
if ((marker & 0xF0) == pack::TinyMap) {
size = marker & 0x0F;
if ((marker & 0xF0) == pack::TinyList) {
return true;
} else if (marker == pack::Map8) {
size = byte();
} else if (marker == pack::List8) {
return true;
} else if (marker == pack::Map16) {
size = read<uint16_t>();
} else if (marker == pack::List16) {
return true;
} else if (marker == pack::Map32) {
size = read<uint32_t>();
} else if (marker == pack::List32) {
return true;
} else {
return false;
}
}
} else {
// Error
throw DecoderException(
"StreamedBoltDecoder: Tryed to read map header but found ",
marker);
}
// Reads list header. Throws DecoderException if it isn't list header.
size_t list_header() {
auto marker = byte();
return size;
if ((marker & 0xF0) == pack::TinyList) {
return marker & 0x0F;
} else if (marker == pack::List8) {
return byte();
} else if (marker == pack::List16) {
return read<uint16_t>();
} else if (marker == pack::List32) {
return read<uint32_t>();
} else {
// Error
throw DecoderException(
"StreamedBoltDecoder: Tryed to read list header but found ", marker);
}
}
bool is_bool() {
auto marker = peek_byte();
if (marker == pack::True) {
return true;
} else if (marker == pack::False) {
return true;
} else {
return false;
}
}
// Reads bool.Throws DecoderException if it isn't bool.
bool read_bool() {
auto marker = byte();
if (marker == pack::True) {
return true;
} else if (marker == pack::False) {
return false;
} else {
throw DecoderException(
"StreamedBoltDecoder: Tryed to read bool header but found ", marker);
}
}
bool is_integer() {
auto marker = peek_byte();
if (marker >= minus_2_to_the_4 && marker < plus_2_to_the_7) {
return true;
} else if (marker == pack::Int8) {
return true;
} else if (marker == pack::Int16) {
return true;
} else if (marker == pack::Int32) {
return true;
} else if (marker == pack::Int64) {
return true;
} else {
return false;
}
}
// Reads integer.Throws DecoderException if it isn't integer.
int64_t integer() {
auto marker = byte();
if (marker >= minus_2_to_the_4 && marker < plus_2_to_the_7) {
return marker;
} else if (marker == pack::Int8) {
return byte();
} else if (marker == pack::Int16) {
return read<int16_t>();
} else if (marker == pack::Int32) {
return read<int32_t>();
} else if (marker == pack::Int64) {
return read<int64_t>();
} else {
throw DecoderException(
"StreamedBoltDecoder: Tryed to read integer but found ", marker);
}
}
bool is_double() {
auto marker = peek_byte();
return marker == pack::Float64;
}
// Reads double.Throws DecoderException if it isn't double.
double read_double() {
auto marker = byte();
if (marker == pack::Float64) {
auto tmp = read<int64_t>();
return *reinterpret_cast<const double *>(&tmp);
} else {
throw DecoderException(
"StreamedBoltDecoder: Tryed to read double but found ", marker);
}
}
bool is_string() {
auto marker = peek_byte();
// if the first 4 bits equal to 1000 (0x8), this is a tiny string
if ((marker & 0xF0) == pack::TinyString) {
return true;
}
// if the marker is 0xD0, size is an 8-bit unsigned integer
else if (marker == pack::String8) {
return true;
}
// if the marker is 0xD1, size is a 16-bit big-endian unsigned integer
else if (marker == pack::String16) {
return true;
}
// if the marker is 0xD2, size is a 32-bit big-endian unsigned integer
else if (marker == pack::String32) {
return true;
} else {
return false;
}
}
// Reads string into res. Throws DecoderException if it isn't string.
void string(std::string &res) {
if (!string_try(res)) {
throw DecoderException(
"StreamedBoltDecoder: Tryed to read string but found ",
std::to_string(peek_byte()));
}
}
// Try-s to read string. Retunrns true on success. If it didn't succed
// stream remains unchanged
bool string_try(std::string &res) {
auto marker = peek_byte();
uint32_t size;
// if the first 4 bits equal to 1000 (0x8), this is a tiny string
if ((marker & 0xF0) == pack::TinyString) {
byte();
// size is stored in the lower 4 bits of the marker byte
size = marker & 0x0F;
}
// if the marker is 0xD0, size is an 8-bit unsigned integer
else if (marker == pack::String8) {
byte();
size = byte();
}
// if the marker is 0xD1, size is a 16-bit big-endian unsigned integer
else if (marker == pack::String16) {
byte();
size = read<uint16_t>();
}
// if the marker is 0xD2, size is a 32-bit big-endian unsigned integer
else if (marker == pack::String32) {
byte();
size = read<uint32_t>();
} else {
// Error
return false;
}
bool is_list()
{
auto marker = peek_byte();
if ((marker & 0xF0) == pack::TinyList) {
return true;
} else if (marker == pack::List8) {
return true;
} else if (marker == pack::List16) {
return true;
} else if (marker == pack::List32) {
return true;
} else {
return false;
}
if (size > 0) {
res.resize(size);
stream.read(&res.front(), size);
} else {
res.clear();
}
// Reads list header. Throws DecoderException if it isn't list header.
size_t list_header()
{
auto marker = byte();
return true;
}
if ((marker & 0xF0) == pack::TinyList) {
return marker & 0x0F;
private:
// Reads T from stream. It doens't care for alligment so this is valid only
// for primitives.
template <class T>
T read() {
buffer.resize(sizeof(T));
} else if (marker == pack::List8) {
return byte();
// Load value
stream.read(&buffer.front(), sizeof(T));
} else if (marker == pack::List16) {
return read<uint16_t>();
// reinterpret bytes as the target value
auto value = reinterpret_cast<const T *>(&buffer.front());
} else if (marker == pack::List32) {
return read<uint32_t>();
// swap values to little endian
return bswap(*value);
}
} else {
// Error
throw DecoderException(
"StreamedBoltDecoder: Tryed to read list header but found ",
marker);
}
}
::byte byte() { return stream.get(); }
::byte peek_byte() { return stream.peek(); }
bool is_bool()
{
auto marker = peek_byte();
if (marker == pack::True) {
return true;
} else if (marker == pack::False) {
return true;
} else {
return false;
}
}
// Reads bool.Throws DecoderException if it isn't bool.
bool read_bool()
{
auto marker = byte();
if (marker == pack::True) {
return true;
} else if (marker == pack::False) {
return false;
} else {
throw DecoderException(
"StreamedBoltDecoder: Tryed to read bool header but found ",
marker);
}
}
bool is_integer()
{
auto marker = peek_byte();
if (marker >= minus_2_to_the_4 && marker < plus_2_to_the_7) {
return true;
} else if (marker == pack::Int8) {
return true;
} else if (marker == pack::Int16) {
return true;
} else if (marker == pack::Int32) {
return true;
} else if (marker == pack::Int64) {
return true;
} else {
return false;
}
}
// Reads integer.Throws DecoderException if it isn't integer.
int64_t integer()
{
auto marker = byte();
if (marker >= minus_2_to_the_4 && marker < plus_2_to_the_7) {
return marker;
} else if (marker == pack::Int8) {
return byte();
} else if (marker == pack::Int16) {
return read<int16_t>();
} else if (marker == pack::Int32) {
return read<int32_t>();
} else if (marker == pack::Int64) {
return read<int64_t>();
} else {
throw DecoderException(
"StreamedBoltDecoder: Tryed to read integer but found ",
marker);
}
}
bool is_double()
{
auto marker = peek_byte();
return marker == pack::Float64;
}
// Reads double.Throws DecoderException if it isn't double.
double read_double()
{
auto marker = byte();
if (marker == pack::Float64) {
auto tmp = read<int64_t>();
return *reinterpret_cast<const double *>(&tmp);
} else {
throw DecoderException(
"StreamedBoltDecoder: Tryed to read double but found ", marker);
}
}
bool is_string()
{
auto marker = peek_byte();
// if the first 4 bits equal to 1000 (0x8), this is a tiny string
if ((marker & 0xF0) == pack::TinyString) {
return true;
}
// if the marker is 0xD0, size is an 8-bit unsigned integer
else if (marker == pack::String8) {
return true;
}
// if the marker is 0xD1, size is a 16-bit big-endian unsigned integer
else if (marker == pack::String16) {
return true;
}
// if the marker is 0xD2, size is a 32-bit big-endian unsigned integer
else if (marker == pack::String32) {
return true;
} else {
return false;
}
}
// Reads string into res. Throws DecoderException if it isn't string.
void string(std::string &res)
{
if (!string_try(res)) {
throw DecoderException(
"StreamedBoltDecoder: Tryed to read string but found ",
std::to_string(peek_byte()));
}
}
// Try-s to read string. Retunrns true on success. If it didn't succed
// stream remains unchanged
bool string_try(std::string &res)
{
auto marker = peek_byte();
uint32_t size;
// if the first 4 bits equal to 1000 (0x8), this is a tiny string
if ((marker & 0xF0) == pack::TinyString) {
byte();
// size is stored in the lower 4 bits of the marker byte
size = marker & 0x0F;
}
// if the marker is 0xD0, size is an 8-bit unsigned integer
else if (marker == pack::String8) {
byte();
size = byte();
}
// if the marker is 0xD1, size is a 16-bit big-endian unsigned integer
else if (marker == pack::String16) {
byte();
size = read<uint16_t>();
}
// if the marker is 0xD2, size is a 32-bit big-endian unsigned integer
else if (marker == pack::String32) {
byte();
size = read<uint32_t>();
} else {
// Error
return false;
}
if (size > 0) {
res.resize(size);
stream.read(&res.front(), size);
} else {
res.clear();
}
return true;
}
private:
// Reads T from stream. It doens't care for alligment so this is valid only
// for primitives.
template <class T>
T read()
{
buffer.resize(sizeof(T));
// Load value
stream.read(&buffer.front(), sizeof(T));
// reinterpret bytes as the target value
auto value = reinterpret_cast<const T *>(&buffer.front());
// swap values to little endian
return bswap(*value);
}
::byte byte() { return stream.get(); }
::byte peek_byte() { return stream.peek(); }
STREAM &stream;
std::string buffer;
STREAM &stream;
std::string buffer;
};
};

View File

@ -1,7 +1,6 @@
#include "config/config.hpp"
namespace config
{
namespace config {
const char *MemgraphConfig::env_config_key = "MEMGRAPH_CONFIG";
const char *MemgraphConfig::default_file_path = "/etc/memgraph/config.yaml";
@ -11,8 +10,6 @@ const char *MemgraphConfig::default_file_path = "/etc/memgraph/config.yaml";
// Example:
// --cleaning_cycle_sec or -ccs, etc.
std::set<std::string> MemgraphConfig::arguments = {
"cleaning_cycle_sec",
"snapshot_cycle_sec",
"cleaning_cycle_sec", "snapshot_cycle_sec",
};
}

View File

@ -2,23 +2,21 @@
#include "utils/config/config.hpp"
#include <string>
#include <set>
#include <string>
namespace config
{
namespace config {
// this class is used as a Definition class of config::Config class from utils
// number of elements should be small,
// it depends on implementation of config::Config class
// in other words number of fields in Definition class should be related
// to the number of config keys
class MemgraphConfig
{
public:
static const char *env_config_key;
static const char *default_file_path;
static std::set<std::string> arguments;
class MemgraphConfig {
public:
static const char *env_config_key;
static const char *default_file_path;
static std::set<std::string> arguments;
};
// -- all possible Memgraph's keys --

View File

@ -7,149 +7,132 @@
#include "threading/sync/spinlock.hpp"
template <class block_t = uint8_t, size_t chunk_size = 32768>
class DynamicBitset : Lockable<SpinLock>
{
struct Block
{
Block() = default;
class DynamicBitset : Lockable<SpinLock> {
struct Block {
Block() = default;
Block(Block &) = delete;
Block(Block &&) = delete;
Block(Block &) = delete;
Block(Block &&) = delete;
static constexpr size_t size = sizeof(block_t) * 8;
static constexpr size_t size = sizeof(block_t) * 8;
constexpr block_t bitmask(size_t group_size) const
{
return (block_t)(-1) >> (size - group_size);
}
block_t at(size_t k, size_t n, std::memory_order order)
{
assert(k + n - 1 < size);
return (block.load(order) >> k) & bitmask(n);
}
void set(size_t k, size_t n, std::memory_order order)
{
assert(k + n - 1 < size);
block.fetch_or(bitmask(n) << k, order);
}
void clear(size_t k, size_t n, std::memory_order order)
{
assert(k + n - 1 < size);
block.fetch_and(~(bitmask(n) << k), order);
}
std::atomic<block_t> block{0};
};
struct Chunk
{
Chunk() : next(nullptr)
{
static_assert(chunk_size % sizeof(block_t) == 0,
"chunk size not divisible by block size");
}
Chunk(Chunk &) = delete;
Chunk(Chunk &&) = delete;
static constexpr size_t size = chunk_size * Block::size;
static constexpr size_t n_blocks = chunk_size / sizeof(block_t);
block_t at(size_t k, size_t n, std::memory_order order)
{
return blocks[k / Block::size].at(k % Block::size, n, order);
}
void set(size_t k, size_t n, std::memory_order order)
{
blocks[k / Block::size].set(k % Block::size, n, order);
}
void clear(size_t k, size_t n, std::memory_order order)
{
blocks[k / Block::size].clear(k % Block::size, n, order);
}
Block blocks[n_blocks];
std::atomic<Chunk *> next;
};
public:
DynamicBitset() : head(new Chunk()) {}
DynamicBitset(DynamicBitset &) = delete;
DynamicBitset(DynamicBitset &&) = delete;
~DynamicBitset()
{
auto now = head.load();
while (now != nullptr) {
auto next = now->next.load();
delete now;
now = next;
}
constexpr block_t bitmask(size_t group_size) const {
return (block_t)(-1) >> (size - group_size);
}
block_t at(size_t k, size_t n)
{
auto &chunk = find_chunk(k);
return chunk.at(k, n, std::memory_order_seq_cst);
block_t at(size_t k, size_t n, std::memory_order order) {
assert(k + n - 1 < size);
return (block.load(order) >> k) & bitmask(n);
}
bool at(size_t k)
{
auto &chunk = find_chunk(k);
return chunk.at(k, 1, std::memory_order_seq_cst);
void set(size_t k, size_t n, std::memory_order order) {
assert(k + n - 1 < size);
block.fetch_or(bitmask(n) << k, order);
}
void set(size_t k, size_t n = 1)
{
auto &chunk = find_chunk(k);
return chunk.set(k, n, std::memory_order_seq_cst);
void clear(size_t k, size_t n, std::memory_order order) {
assert(k + n - 1 < size);
block.fetch_and(~(bitmask(n) << k), order);
}
void clear(size_t k, size_t n = 1)
{
auto &chunk = find_chunk(k);
return chunk.clear(k, n, std::memory_order_seq_cst);
std::atomic<block_t> block{0};
};
struct Chunk {
Chunk() : next(nullptr) {
static_assert(chunk_size % sizeof(block_t) == 0,
"chunk size not divisible by block size");
}
private:
Chunk &find_chunk(size_t &k)
{
Chunk *chunk = head.load(), *next = nullptr;
Chunk(Chunk &) = delete;
Chunk(Chunk &&) = delete;
// while i'm not in the right chunk
// (my index is bigger than the size of this chunk)
while (k >= Chunk::size) {
next = chunk->next.load();
static constexpr size_t size = chunk_size * Block::size;
static constexpr size_t n_blocks = chunk_size / sizeof(block_t);
// if a next chunk exists, switch to it and decrement my
// pointer by the size of the current chunk
if (next != nullptr) {
chunk = next;
k -= Chunk::size;
continue;
}
// the next chunk does not exist and we need it. take an exclusive
// lock to prevent others that also want to create a new chunk
// from creating it
auto guard = acquire_unique();
// double-check locking. if the chunk exists now, some other thread
// has just created it, continue searching for my chunk
if (chunk->next.load() != nullptr) continue;
chunk->next.store(new Chunk());
}
assert(chunk != nullptr);
return *chunk;
block_t at(size_t k, size_t n, std::memory_order order) {
return blocks[k / Block::size].at(k % Block::size, n, order);
}
std::atomic<Chunk *> head;
void set(size_t k, size_t n, std::memory_order order) {
blocks[k / Block::size].set(k % Block::size, n, order);
}
void clear(size_t k, size_t n, std::memory_order order) {
blocks[k / Block::size].clear(k % Block::size, n, order);
}
Block blocks[n_blocks];
std::atomic<Chunk *> next;
};
public:
DynamicBitset() : head(new Chunk()) {}
DynamicBitset(DynamicBitset &) = delete;
DynamicBitset(DynamicBitset &&) = delete;
~DynamicBitset() {
auto now = head.load();
while (now != nullptr) {
auto next = now->next.load();
delete now;
now = next;
}
}
block_t at(size_t k, size_t n) {
auto &chunk = find_chunk(k);
return chunk.at(k, n, std::memory_order_seq_cst);
}
bool at(size_t k) {
auto &chunk = find_chunk(k);
return chunk.at(k, 1, std::memory_order_seq_cst);
}
void set(size_t k, size_t n = 1) {
auto &chunk = find_chunk(k);
return chunk.set(k, n, std::memory_order_seq_cst);
}
void clear(size_t k, size_t n = 1) {
auto &chunk = find_chunk(k);
return chunk.clear(k, n, std::memory_order_seq_cst);
}
private:
Chunk &find_chunk(size_t &k) {
Chunk *chunk = head.load(), *next = nullptr;
// while i'm not in the right chunk
// (my index is bigger than the size of this chunk)
while (k >= Chunk::size) {
next = chunk->next.load();
// if a next chunk exists, switch to it and decrement my
// pointer by the size of the current chunk
if (next != nullptr) {
chunk = next;
k -= Chunk::size;
continue;
}
// the next chunk does not exist and we need it. take an exclusive
// lock to prevent others that also want to create a new chunk
// from creating it
auto guard = acquire_unique();
// double-check locking. if the chunk exists now, some other thread
// has just created it, continue searching for my chunk
if (chunk->next.load() != nullptr) continue;
chunk->next.store(new Chunk());
}
assert(chunk != nullptr);
return *chunk;
}
std::atomic<Chunk *> head;
};

View File

@ -13,64 +13,56 @@
* Type specifies the type of data stored
*/
template <class Type, int BucketSize = 8>
class BloomFilter
{
private:
using HashFunction = std::function<uint64_t(const Type &)>;
using CompresionFunction = std::function<int(uint64_t)>;
class BloomFilter {
private:
using HashFunction = std::function<uint64_t(const Type &)>;
using CompresionFunction = std::function<int(uint64_t)>;
std::bitset<BucketSize> filter_;
std::vector<HashFunction> hashes_;
CompresionFunction compression_;
std::vector<int> buckets;
std::bitset<BucketSize> filter_;
std::vector<HashFunction> hashes_;
CompresionFunction compression_;
std::vector<int> buckets;
int default_compression(uint64_t hash) { return hash % BucketSize; }
int default_compression(uint64_t hash) { return hash % BucketSize; }
void get_buckets(const Type &data)
{
for (int i = 0; i < hashes_.size(); i++)
buckets[i] = compression_(hashes_[i](data));
void get_buckets(const Type &data) {
for (int i = 0; i < hashes_.size(); i++)
buckets[i] = compression_(hashes_[i](data));
}
void print_buckets(std::vector<uint64_t> &buckets) {
for (int i = 0; i < buckets.size(); i++) {
std::cout << buckets[i] << " ";
}
std::cout << std::endl;
}
void print_buckets(std::vector<uint64_t> &buckets)
{
for (int i = 0; i < buckets.size(); i++)
{
std::cout << buckets[i] << " ";
}
std::cout << std::endl;
}
public:
BloomFilter(std::vector<HashFunction> funcs,
CompresionFunction compression = {})
: hashes_(funcs) {
if (!compression)
compression_ = std::bind(&BloomFilter::default_compression, this,
std::placeholders::_1);
else
compression_ = compression;
public:
BloomFilter(std::vector<HashFunction> funcs,
CompresionFunction compression = {})
: hashes_(funcs)
{
if (!compression)
compression_ = std::bind(&BloomFilter::default_compression, this,
std::placeholders::_1);
else
compression_ = compression;
buckets.resize(hashes_.size());
}
buckets.resize(hashes_.size());
}
bool contains(const Type &data) {
get_buckets(data);
bool contains_element = true;
bool contains(const Type &data)
{
get_buckets(data);
bool contains_element = true;
for (int i = 0; i < buckets.size(); i++)
contains_element &= filter_[buckets[i]];
for (int i = 0; i < buckets.size(); i++)
contains_element &= filter_[buckets[i]];
return contains_element;
}
return contains_element;
}
void insert(const Type &data) {
get_buckets(data);
void insert(const Type &data)
{
get_buckets(data);
for (int i = 0; i < buckets.size(); i++)
filter_[buckets[i]] = true;
}
for (int i = 0; i < buckets.size(); i++) filter_[buckets[i]] = true;
}
};

View File

@ -11,89 +11,79 @@ template <typename K, typename T>
class Item : public TotalOrdering<Item<K, T>>,
public TotalOrdering<K, Item<K, T>>,
public TotalOrdering<Item<K, T>, K>,
public pair<const K, T>
{
public:
using pair<const K, T>::pair;
public pair<const K, T> {
public:
using pair<const K, T>::pair;
friend constexpr bool operator<(const Item &lhs, const Item &rhs)
{
return lhs.first < rhs.first;
}
friend constexpr bool operator<(const Item &lhs, const Item &rhs) {
return lhs.first < rhs.first;
}
friend constexpr bool operator==(const Item &lhs, const Item &rhs)
{
return lhs.first == rhs.first;
}
friend constexpr bool operator==(const Item &lhs, const Item &rhs) {
return lhs.first == rhs.first;
}
friend constexpr bool operator<(const K &lhs, const Item &rhs)
{
return lhs < rhs.first;
}
friend constexpr bool operator<(const K &lhs, const Item &rhs) {
return lhs < rhs.first;
}
friend constexpr bool operator==(const K &lhs, const Item &rhs)
{
return lhs == rhs.first;
}
friend constexpr bool operator==(const K &lhs, const Item &rhs) {
return lhs == rhs.first;
}
friend constexpr bool operator<(const Item &lhs, const K &rhs)
{
return lhs.first < rhs;
}
friend constexpr bool operator<(const Item &lhs, const K &rhs) {
return lhs.first < rhs;
}
friend constexpr bool operator==(const Item &lhs, const K &rhs)
{
return lhs.first == rhs;
}
friend constexpr bool operator==(const Item &lhs, const K &rhs) {
return lhs.first == rhs;
}
};
// Common base for accessor of all derived containers(ConcurrentMap,
// ConcurrentSet, ...) from SkipList.
template <typename T>
class AccessorBase
{
typedef SkipList<T> list;
typedef typename SkipList<T>::Iterator list_it;
typedef typename SkipList<T>::ConstIterator list_it_con;
class AccessorBase {
typedef SkipList<T> list;
typedef typename SkipList<T>::Iterator list_it;
typedef typename SkipList<T>::ConstIterator list_it_con;
protected:
AccessorBase(list *skiplist) : accessor(skiplist->access()) {}
protected:
AccessorBase(list *skiplist) : accessor(skiplist->access()) {}
public:
AccessorBase(const AccessorBase &) = delete;
public:
AccessorBase(const AccessorBase &) = delete;
AccessorBase(AccessorBase &&other) : accessor(std::move(other.accessor)) {}
AccessorBase(AccessorBase &&other) : accessor(std::move(other.accessor)) {}
~AccessorBase() {}
~AccessorBase() {}
size_t size() { return accessor.size(); };
size_t size() { return accessor.size(); };
list_it begin() { return accessor.begin(); }
list_it begin() { return accessor.begin(); }
list_it_con begin() const { return accessor.cbegin(); }
list_it_con begin() const { return accessor.cbegin(); }
list_it_con cbegin() const { return accessor.cbegin(); }
list_it_con cbegin() const { return accessor.cbegin(); }
list_it end() { return accessor.end(); }
list_it end() { return accessor.end(); }
list_it_con end() const { return accessor.cend(); }
list_it_con end() const { return accessor.cend(); }
list_it_con cend() const { return accessor.cend(); }
list_it_con cend() const { return accessor.cend(); }
template <class K>
typename SkipList<T>::template MultiIterator<K> end(const K &data)
{
return accessor.template mend<K>(data);
}
template <class K>
typename SkipList<T>::template MultiIterator<K> end(const K &data) {
return accessor.template mend<K>(data);
}
template <class K>
typename SkipList<T>::template MultiIterator<K> mend(const K &data)
{
return accessor.template mend<K>(data);
}
template <class K>
typename SkipList<T>::template MultiIterator<K> mend(const K &data) {
return accessor.template mend<K>(data);
}
size_t size() const { return accessor.size(); }
size_t size() const { return accessor.size(); }
protected:
typename list::Accessor accessor;
protected:
typename list::Accessor accessor;
};

View File

@ -1,9 +1,8 @@
#pragma once
#include "data_structures/concurrent/common.hpp"
#include "data_structures/concurrent/skiplist.hpp"
#include "data_structures/concurrent/concurrent_map.hpp"
#include "data_structures/concurrent/skiplist.hpp"
using std::pair;
@ -12,25 +11,25 @@ class ConcurrentBloomMap {
using item_t = Item<Key, Value>;
using list_it = typename SkipList<item_t>::Iterator;
private:
ConcurrentMap<Key, Value> map_;
BloomFilter filter_;
private:
ConcurrentMap<Key, Value> map_;
BloomFilter filter_;
public:
ConcurrentBloomMap(BloomFilter filter) : filter_(filter) {}
public:
ConcurrentBloomMap(BloomFilter filter) : filter_(filter) {}
std::pair<list_it, bool> insert(const Key &key, const Value &data) {
filter_.insert(key);
std::pair<list_it, bool> insert(const Key &key, const Value &data) {
filter_.insert(key);
auto accessor = std::move(map_.access());
auto accessor = std::move(map_.access());
return accessor.insert(key, data);
}
return accessor.insert(key, data);
}
bool contains(const Key &key) {
if (!filter_.contains(key)) return false;
bool contains(const Key &key) {
if (!filter_.contains(key)) return false;
auto accessor = map_.access();
return accessor.contains(key);
}
auto accessor = map_.access();
return accessor.contains(key);
}
};

View File

@ -8,330 +8,308 @@
// TODO: reimplement this. It's correct but somewhat inefecient and it could be
// done better.
template <class T>
class ConcurrentList
{
private:
template <class V>
static V load(std::atomic<V> &atomic)
{
return atomic.load(std::memory_order_acquire);
class ConcurrentList {
private:
template <class V>
static V load(std::atomic<V> &atomic) {
return atomic.load(std::memory_order_acquire);
}
template <class V>
static void store(std::atomic<V> &atomic,
V desired) { // Maybe could be relaxed
atomic.store(desired, std::memory_order_release);
}
template <class V>
static bool cas(
std::atomic<V> &atomic, V expected,
V desired) { // Could be relaxed but must be at least Release.
return atomic.compare_exchange_strong(expected, desired,
std::memory_order_seq_cst);
}
template <class V>
static V *swap(std::atomic<V *> &atomic, V *desired) { // Could be relaxed
return atomic.exchange(desired, std::memory_order_seq_cst);
}
// Basic element in a ConcurrentList
class Node {
public:
Node(const T &data) : data(data) {}
Node(T &&data) : data(std::move(data)) {}
// Carried data
T data;
// Next element in list or nullptr if end.
std::atomic<Node *> next{nullptr};
// Next removed element in list or nullptr if end.
std::atomic<Node *> next_rem{nullptr};
// True if node has logicaly been removed from list.
std::atomic<bool> removed{false};
};
// Base for Mutable and Immutable iterators. Also serves as accessor to the
// list uses for safe garbage disposall.
template <class It>
class IteratorBase : public Crtp<It> {
friend class ConcurrentList;
protected:
IteratorBase() : list(nullptr), curr(nullptr) {}
IteratorBase(ConcurrentList *list) : list(list) {
assert(list != nullptr);
// Increment number of iterators accessing list.
list->active_threads_no_++;
// Start from the begining of list.
reset();
}
template <class V>
static void store(std::atomic<V> &atomic, V desired)
{ // Maybe could be relaxed
atomic.store(desired, std::memory_order_release);
public:
IteratorBase(const IteratorBase &) = delete;
IteratorBase(IteratorBase &&other)
: list(other.list), curr(other.curr), prev(other.prev) {
other.list = nullptr;
other.curr = nullptr;
other.prev = nullptr;
}
template <class V>
static bool cas(std::atomic<V> &atomic, V expected, V desired)
{ // Could be relaxed but must be at least Release.
return atomic.compare_exchange_strong(expected, desired,
std::memory_order_seq_cst);
~IteratorBase() {
if (list == nullptr) {
return;
}
auto head_rem = load(list->removed);
// Next IF checks if this thread is responisble for disposall of
// collected garbage.
// Fetch could be relaxed
// There exist possibility that no one will delete garbage at this
// time but it will be deleted at some other time.
if (list->active_threads_no_.fetch_sub(1) ==
1 && // I am the last one accessing
head_rem != nullptr && // There is some garbage
cas<Node *>(list->removed, head_rem,
nullptr) // No new garbage was added.
) {
// Delete all removed node following chain of next_rem starting
// from head_rem.
auto now = head_rem;
do {
auto next = load(now->next_rem);
delete now;
now = next;
} while (now != nullptr);
}
}
template <class V>
static V *swap(std::atomic<V *> &atomic, V *desired)
{ // Could be relaxed
return atomic.exchange(desired, std::memory_order_seq_cst);
IteratorBase &operator=(IteratorBase const &other) = delete;
IteratorBase &operator=(IteratorBase &&other) = delete;
T &operator*() const {
assert(valid());
return curr->data;
}
T *operator->() const {
assert(valid());
return &(curr->data);
}
// Basic element in a ConcurrentList
class Node
{
public:
Node(const T &data) : data(data) {}
Node(T &&data) : data(std::move(data)) {}
bool valid() const { return curr != nullptr; }
// Carried data
T data;
// Iterating is wait free.
It &operator++() {
assert(valid());
do {
prev = curr;
curr = load(curr->next);
} while (valid() && is_removed()); // Loop ends if end of list is
// found or if not removed
// element is found.
return this->derived();
}
It &operator++(int) { return operator++(); }
// Next element in list or nullptr if end.
std::atomic<Node *> next{nullptr};
// Next removed element in list or nullptr if end.
std::atomic<Node *> next_rem{nullptr};
// True if node has logicaly been removed from list.
std::atomic<bool> removed{false};
};
// Base for Mutable and Immutable iterators. Also serves as accessor to the
// list uses for safe garbage disposall.
template <class It>
class IteratorBase : public Crtp<It>
{
friend class ConcurrentList;
protected:
IteratorBase() : list(nullptr), curr(nullptr) {}
IteratorBase(ConcurrentList *list) : list(list)
{
assert(list != nullptr);
// Increment number of iterators accessing list.
list->active_threads_no_++;
// Start from the begining of list.
reset();
}
public:
IteratorBase(const IteratorBase &) = delete;
IteratorBase(IteratorBase &&other)
: list(other.list), curr(other.curr), prev(other.prev)
{
other.list = nullptr;
other.curr = nullptr;
other.prev = nullptr;
}
~IteratorBase()
{
if (list == nullptr) {
return;
}
auto head_rem = load(list->removed);
// Next IF checks if this thread is responisble for disposall of
// collected garbage.
// Fetch could be relaxed
// There exist possibility that no one will delete garbage at this
// time but it will be deleted at some other time.
if (list->active_threads_no_.fetch_sub(1) == 1 && // I am the last one accessing
head_rem != nullptr && // There is some garbage
cas<Node *>(list->removed, head_rem,
nullptr) // No new garbage was added.
) {
// Delete all removed node following chain of next_rem starting
// from head_rem.
auto now = head_rem;
do {
auto next = load(now->next_rem);
delete now;
now = next;
} while (now != nullptr);
}
}
IteratorBase &operator=(IteratorBase const &other) = delete;
IteratorBase &operator=(IteratorBase &&other) = delete;
T &operator*() const
{
assert(valid());
return curr->data;
}
T *operator->() const
{
assert(valid());
return &(curr->data);
}
bool valid() const { return curr != nullptr; }
// Iterating is wait free.
It &operator++()
{
assert(valid());
do {
prev = curr;
curr = load(curr->next);
} while (valid() && is_removed()); // Loop ends if end of list is
// found or if not removed
// element is found.
return this->derived();
}
It &operator++(int) { return operator++(); }
bool is_removed()
{
assert(valid());
return load(curr->removed);
}
// Returns IteratorBase to begining
void reset()
{
prev = nullptr;
curr = load(list->head);
if (valid() && is_removed()) {
operator++();
}
}
// Adds to the begining of list
// It is lock free but it isn't wait free.
void push(T &&data)
{
// It could be done with unique_ptr but while this could meen memory
// leak on excpetion, unique_ptr could meean use after free. Memory
// leak is less dangerous.
auto node = new Node(data);
Node *next = nullptr;
// Insert at begining of list. Retrys on failure.
do {
next = load(list->head);
// First connect to next.
store(node->next, next);
// Then try to set as head.
} while (!cas(list->head, next, node));
list->count_.fetch_add(1);
}
// True only if this call removed the element. Only reason for fail is
// if the element is already removed.
// Remove has deadlock if another thread dies between marking node for
// removal and the disconnection.
// This can be improved with combinig the removed flag with prev.next or
// curr.next
bool remove()
{
assert(valid());
// Try to logically remove it.
if (cas(curr->removed, false, true)) {
// I removed it!!!
// Try to disconnect it from list.
if (!disconnect()) {
// Disconnection failed because Node relative location in
// list changed. Whe firstly must find it again and then try
// to disconnect it again.
find_and_disconnect();
}
// Add to list of to be garbage collected.
store(curr->next_rem, swap(list->removed, curr));
list->count_.fetch_sub(1);
return true;
}
return false;
}
friend bool operator==(const It &a, const It &b)
{
return a.curr == b.curr;
}
friend bool operator!=(const It &a, const It &b) { return !(a == b); }
private:
// Fids current element starting from the begining of the list Retrys
// until it succesffuly disconnects it.
void find_and_disconnect()
{
Node *bef = nullptr;
auto now = load(list->head);
auto next = load(curr->next);
while (now != nullptr) {
if (now == curr) {
// Found it.
prev = bef; // Set the correct previous node in list.
if (disconnect()) {
// succesffuly disconnected it.
return;
}
// Let's try again from the begining.
bef = nullptr;
now = load(list->head);
} else if (now == next) { // Comparison with next is
// optimization for early return.
return;
} else {
// Now isn't the one whe are looking for lets try next one.
bef = now;
now = load(now->next);
}
}
}
// Trys to disconnect currrent element from
bool disconnect()
{
auto next = load(curr->next);
if (prev != nullptr) {
store(prev->next, next);
if (load(prev->removed)) {
// previous isn't previous any more.
return false;
}
} else if (!cas(list->head, curr, next)) {
return false;
}
return true;
}
ConcurrentList *list;
Node *prev{nullptr};
Node *curr;
};
public:
class ConstIterator : public IteratorBase<ConstIterator>
{
friend class ConcurrentList;
public:
using IteratorBase<ConstIterator>::IteratorBase;
const T &operator*() const
{
return IteratorBase<ConstIterator>::operator*();
}
const T *operator->() const
{
return IteratorBase<ConstIterator>::operator->();
}
operator const T &() const
{
return IteratorBase<ConstIterator>::operator T &();
}
};
class Iterator : public IteratorBase<Iterator>
{
friend class ConcurrentList;
public:
using IteratorBase<Iterator>::IteratorBase;
};
public:
ConcurrentList() = default;
ConcurrentList(ConcurrentList &) = delete;
ConcurrentList(ConcurrentList &&) = delete;
~ConcurrentList()
{
auto now = head.load();
while (now != nullptr) {
auto next = now->next.load();
delete now;
now = next;
}
bool is_removed() {
assert(valid());
return load(curr->removed);
}
void operator=(ConcurrentList &) = delete;
// Returns IteratorBase to begining
void reset() {
prev = nullptr;
curr = load(list->head);
if (valid() && is_removed()) {
operator++();
}
}
Iterator begin() { return Iterator(this); }
// Adds to the begining of list
// It is lock free but it isn't wait free.
void push(T &&data) {
// It could be done with unique_ptr but while this could meen memory
// leak on excpetion, unique_ptr could meean use after free. Memory
// leak is less dangerous.
auto node = new Node(data);
Node *next = nullptr;
// Insert at begining of list. Retrys on failure.
do {
next = load(list->head);
// First connect to next.
store(node->next, next);
// Then try to set as head.
} while (!cas(list->head, next, node));
ConstIterator cbegin() { return ConstIterator(this); }
list->count_.fetch_add(1);
}
Iterator end() { return Iterator(); }
// True only if this call removed the element. Only reason for fail is
// if the element is already removed.
// Remove has deadlock if another thread dies between marking node for
// removal and the disconnection.
// This can be improved with combinig the removed flag with prev.next or
// curr.next
bool remove() {
assert(valid());
// Try to logically remove it.
if (cas(curr->removed, false, true)) {
// I removed it!!!
// Try to disconnect it from list.
if (!disconnect()) {
// Disconnection failed because Node relative location in
// list changed. Whe firstly must find it again and then try
// to disconnect it again.
find_and_disconnect();
}
// Add to list of to be garbage collected.
store(curr->next_rem, swap(list->removed, curr));
list->count_.fetch_sub(1);
return true;
}
return false;
}
ConstIterator cend() { return ConstIterator(); }
friend bool operator==(const It &a, const It &b) {
return a.curr == b.curr;
}
std::size_t active_threads_no() { return active_threads_no_.load(); }
std::size_t size() { return count_.load(); }
friend bool operator!=(const It &a, const It &b) { return !(a == b); }
private:
// TODO: use lazy GC or something else as a garbage collection strategy
// use the same principle as in skiplist
std::atomic<std::size_t> active_threads_no_{0};
std::atomic<std::size_t> count_{0};
std::atomic<Node *> head{nullptr};
std::atomic<Node *> removed{nullptr};
private:
// Fids current element starting from the begining of the list Retrys
// until it succesffuly disconnects it.
void find_and_disconnect() {
Node *bef = nullptr;
auto now = load(list->head);
auto next = load(curr->next);
while (now != nullptr) {
if (now == curr) {
// Found it.
prev = bef; // Set the correct previous node in list.
if (disconnect()) {
// succesffuly disconnected it.
return;
}
// Let's try again from the begining.
bef = nullptr;
now = load(list->head);
} else if (now == next) { // Comparison with next is
// optimization for early return.
return;
} else {
// Now isn't the one whe are looking for lets try next one.
bef = now;
now = load(now->next);
}
}
}
// Trys to disconnect currrent element from
bool disconnect() {
auto next = load(curr->next);
if (prev != nullptr) {
store(prev->next, next);
if (load(prev->removed)) {
// previous isn't previous any more.
return false;
}
} else if (!cas(list->head, curr, next)) {
return false;
}
return true;
}
ConcurrentList *list;
Node *prev{nullptr};
Node *curr;
};
public:
class ConstIterator : public IteratorBase<ConstIterator> {
friend class ConcurrentList;
public:
using IteratorBase<ConstIterator>::IteratorBase;
const T &operator*() const {
return IteratorBase<ConstIterator>::operator*();
}
const T *operator->() const {
return IteratorBase<ConstIterator>::operator->();
}
operator const T &() const {
return IteratorBase<ConstIterator>::operator T &();
}
};
class Iterator : public IteratorBase<Iterator> {
friend class ConcurrentList;
public:
using IteratorBase<Iterator>::IteratorBase;
};
public:
ConcurrentList() = default;
ConcurrentList(ConcurrentList &) = delete;
ConcurrentList(ConcurrentList &&) = delete;
~ConcurrentList() {
auto now = head.load();
while (now != nullptr) {
auto next = now->next.load();
delete now;
now = next;
}
}
void operator=(ConcurrentList &) = delete;
Iterator begin() { return Iterator(this); }
ConstIterator cbegin() { return ConstIterator(this); }
Iterator end() { return Iterator(); }
ConstIterator cend() { return ConstIterator(); }
std::size_t active_threads_no() { return active_threads_no_.load(); }
std::size_t size() { return count_.load(); }
private:
// TODO: use lazy GC or something else as a garbage collection strategy
// use the same principle as in skiplist
std::atomic<std::size_t> active_threads_no_{0};
std::atomic<std::size_t> count_{0};
std::atomic<Node *> head{nullptr};
std::atomic<Node *> removed{nullptr};
};

View File

@ -12,82 +12,70 @@ using std::pair;
* @tparam T is a type of data.
*/
template <typename K, typename T>
class ConcurrentMap
{
typedef Item<K, T> item_t;
typedef SkipList<item_t> list;
typedef typename SkipList<item_t>::Iterator list_it;
typedef typename SkipList<item_t>::ConstIterator list_it_con;
class ConcurrentMap {
typedef Item<K, T> item_t;
typedef SkipList<item_t> list;
typedef typename SkipList<item_t>::Iterator list_it;
typedef typename SkipList<item_t>::ConstIterator list_it_con;
public:
ConcurrentMap() {}
public:
ConcurrentMap() {}
class Accessor : public AccessorBase<item_t>
{
friend class ConcurrentMap;
class Accessor : public AccessorBase<item_t> {
friend class ConcurrentMap;
using AccessorBase<item_t>::AccessorBase;
using AccessorBase<item_t>::AccessorBase;
private:
using AccessorBase<item_t>::accessor;
private:
using AccessorBase<item_t>::accessor;
public:
std::pair<list_it, bool> insert(const K &key, const T &data)
{
return accessor.insert(item_t(key, data));
}
public:
std::pair<list_it, bool> insert(const K &key, const T &data) {
return accessor.insert(item_t(key, data));
}
std::pair<list_it, bool> insert(const K &key, T &&data)
{
return accessor.insert(item_t(key, std::move(data)));
}
std::pair<list_it, bool> insert(const K &key, T &&data) {
return accessor.insert(item_t(key, std::move(data)));
}
std::pair<list_it, bool> insert(K &&key, T &&data)
{
return accessor.insert(
item_t(std::forward<K>(key), std::forward<T>(data)));
}
std::pair<list_it, bool> insert(K &&key, T &&data) {
return accessor.insert(
item_t(std::forward<K>(key), std::forward<T>(data)));
}
template <class... Args1, class... Args2>
std::pair<list_it, bool> emplace(const K &key,
std::tuple<Args1...> first_args,
std::tuple<Args2...> second_args)
{
return accessor.emplace(
key, std::piecewise_construct,
std::forward<std::tuple<Args1...>>(first_args),
std::forward<std::tuple<Args2...>>(second_args));
}
template <class... Args1, class... Args2>
std::pair<list_it, bool> emplace(const K &key,
std::tuple<Args1...> first_args,
std::tuple<Args2...> second_args) {
return accessor.emplace(key, std::piecewise_construct,
std::forward<std::tuple<Args1...>>(first_args),
std::forward<std::tuple<Args2...>>(second_args));
}
list_it_con find(const K &key) const { return accessor.find(key); }
list_it_con find(const K &key) const { return accessor.find(key); }
list_it find(const K &key) { return accessor.find(key); }
list_it find(const K &key) { return accessor.find(key); }
// Returns iterator to item or first larger if it doesn't exist.
list_it_con find_or_larger(const T &item) const
{
return accessor.find_or_larger(item);
}
// Returns iterator to item or first larger if it doesn't exist.
list_it_con find_or_larger(const T &item) const {
return accessor.find_or_larger(item);
}
// Returns iterator to item or first larger if it doesn't exist.
list_it find_or_larger(const T &item)
{
return accessor.find_or_larger(item);
}
// Returns iterator to item or first larger if it doesn't exist.
list_it find_or_larger(const T &item) {
return accessor.find_or_larger(item);
}
bool contains(const K &key) const
{
return this->find(key) != this->end();
}
bool contains(const K &key) const { return this->find(key) != this->end(); }
bool remove(const K &key) { return accessor.remove(key); }
};
bool remove(const K &key) { return accessor.remove(key); }
};
Accessor access() { return Accessor(&skiplist); }
Accessor access() { return Accessor(&skiplist); }
// TODO:
// const Accessor access() const { return Accessor(&skiplist); }
// TODO:
// const Accessor access() const { return Accessor(&skiplist); }
private:
list skiplist;
private:
list skiplist;
};

View File

@ -12,77 +12,63 @@ using std::pair;
* @tparam T is a type of data.
*/
template <typename K, typename T>
class ConcurrentMultiMap
{
typedef Item<K, T> item_t;
typedef SkipList<item_t> list;
typedef typename SkipList<item_t>::Iterator list_it;
typedef typename SkipList<item_t>::ConstIterator list_it_con;
typedef typename SkipList<item_t>::template MultiIterator<K> list_it_multi;
class ConcurrentMultiMap {
typedef Item<K, T> item_t;
typedef SkipList<item_t> list;
typedef typename SkipList<item_t>::Iterator list_it;
typedef typename SkipList<item_t>::ConstIterator list_it_con;
typedef typename SkipList<item_t>::template MultiIterator<K> list_it_multi;
public:
ConcurrentMultiMap() {}
public:
ConcurrentMultiMap() {}
class Accessor : public AccessorBase<item_t>
{
friend class ConcurrentMultiMap<K, T>;
class Accessor : public AccessorBase<item_t> {
friend class ConcurrentMultiMap<K, T>;
using AccessorBase<item_t>::AccessorBase;
using AccessorBase<item_t>::AccessorBase;
private:
using AccessorBase<item_t>::accessor;
private:
using AccessorBase<item_t>::accessor;
public:
list_it insert(const K &key, const T &data)
{
return accessor.insert_non_unique(item_t(key, data));
}
public:
list_it insert(const K &key, const T &data) {
return accessor.insert_non_unique(item_t(key, data));
}
list_it insert(const K &key, T &&data)
{
return accessor.insert_non_unique(
item_t(key, std::forward<T>(data)));
}
list_it insert(const K &key, T &&data) {
return accessor.insert_non_unique(item_t(key, std::forward<T>(data)));
}
list_it insert(K &&key, T &&data)
{
return accessor.insert_non_unique(
item_t(std::forward<K>(key), std::forward<T>(data)));
}
list_it insert(K &&key, T &&data) {
return accessor.insert_non_unique(
item_t(std::forward<K>(key), std::forward<T>(data)));
}
list_it_multi find_multi(const K &key)
{
return accessor.find_multi(key);
}
list_it_multi find_multi(const K &key) { return accessor.find_multi(key); }
list_it_con find(const K &key) const { return accessor.find(key); }
list_it_con find(const K &key) const { return accessor.find(key); }
list_it find(const K &key) { return accessor.find(key); }
list_it find(const K &key) { return accessor.find(key); }
// Returns iterator to item or first larger if it doesn't exist.
list_it_con find_or_larger(const T &item) const
{
return accessor.find_or_larger(item);
}
// Returns iterator to item or first larger if it doesn't exist.
list_it_con find_or_larger(const T &item) const {
return accessor.find_or_larger(item);
}
// Returns iterator to item or first larger if it doesn't exist.
list_it find_or_larger(const T &item)
{
return accessor.find_or_larger(item);
}
// Returns iterator to item or first larger if it doesn't exist.
list_it find_or_larger(const T &item) {
return accessor.find_or_larger(item);
}
bool contains(const K &key) const
{
return this->find(key) != this->end();
}
bool contains(const K &key) const { return this->find(key) != this->end(); }
bool remove(const K &key) { return accessor.remove(key); }
};
bool remove(const K &key) { return accessor.remove(key); }
};
Accessor access() { return Accessor(&skiplist); }
Accessor access() { return Accessor(&skiplist); }
const Accessor access() const { return Accessor(&skiplist); }
const Accessor access() const { return Accessor(&skiplist); }
private:
list skiplist;
private:
list skiplist;
};

View File

@ -5,63 +5,54 @@
// Multi thread safe multiset based on skiplist.
// T - type of data.
template <class T>
class ConcurrentMultiSet
{
typedef SkipList<T> list;
typedef typename SkipList<T>::Iterator list_it;
typedef typename SkipList<T>::ConstIterator list_it_con;
class ConcurrentMultiSet {
typedef SkipList<T> list;
typedef typename SkipList<T>::Iterator list_it;
typedef typename SkipList<T>::ConstIterator list_it_con;
public:
ConcurrentMultiSet() {}
public:
ConcurrentMultiSet() {}
class Accessor : public AccessorBase<T>
{
friend class ConcurrentMultiSet;
class Accessor : public AccessorBase<T> {
friend class ConcurrentMultiSet;
using AccessorBase<T>::AccessorBase;
using AccessorBase<T>::AccessorBase;
private:
using AccessorBase<T>::accessor;
private:
using AccessorBase<T>::accessor;
public:
list_it insert(const T &item)
{
return accessor.insert_non_unique(item);
}
public:
list_it insert(const T &item) { return accessor.insert_non_unique(item); }
list_it insert(T &&item)
{
return accessor.insert_non_unique(std::forward<T>(item));
}
list_it insert(T &&item) {
return accessor.insert_non_unique(std::forward<T>(item));
}
list_it_con find(const T &item) const { return accessor.find(item); }
list_it_con find(const T &item) const { return accessor.find(item); }
list_it find(const T &item) { return accessor.find(item); }
list_it find(const T &item) { return accessor.find(item); }
// Returns iterator to item or first larger if it doesn't exist.
list_it_con find_or_larger(const T &item) const
{
return accessor.find_or_larger(item);
}
// Returns iterator to item or first larger if it doesn't exist.
list_it_con find_or_larger(const T &item) const {
return accessor.find_or_larger(item);
}
// Returns iterator to item or first larger if it doesn't exist.
list_it find_or_larger(const T &item)
{
return accessor.find_or_larger(item);
}
// Returns iterator to item or first larger if it doesn't exist.
list_it find_or_larger(const T &item) {
return accessor.find_or_larger(item);
}
bool contains(const T &item) const
{
return this->find(item) != this->end();
}
bool contains(const T &item) const {
return this->find(item) != this->end();
}
bool remove(const T &item) { return accessor.remove(item); }
};
bool remove(const T &item) { return accessor.remove(item); }
};
Accessor access() { return Accessor(&skiplist); }
Accessor access() { return Accessor(&skiplist); }
const Accessor access() const { return Accessor(&skiplist); }
const Accessor access() const { return Accessor(&skiplist); }
private:
list skiplist;
private:
list skiplist;
};

View File

@ -5,13 +5,13 @@
// Multi thread safe set based on skiplist.
// T - type of data.
template<class T>
template <class T>
class ConcurrentSet {
typedef SkipList<T> list;
typedef typename SkipList<T>::Iterator list_it;
typedef typename SkipList<T>::ConstIterator list_it_con;
public:
public:
ConcurrentSet() {}
class Accessor : public AccessorBase<T> {
@ -19,10 +19,10 @@ public:
using AccessorBase<T>::AccessorBase;
private:
private:
using AccessorBase<T>::accessor;
public:
public:
std::pair<list_it, bool> insert(const T &item) {
return accessor.insert(item);
}
@ -36,19 +36,19 @@ public:
list_it find(const T &item) { return accessor.find(item); }
// Returns iterator to item or first larger if it doesn't exist.
template<class K>
template <class K>
list_it_con find_or_larger(const K &item) const {
return accessor.find_or_larger(item);
}
// Returns iterator to item or first larger if it doesn't exist.
template<class K>
template <class K>
list_it find_or_larger(const K &item) {
return accessor.find_or_larger(item);
}
// Returns iterator to item or first larger if it doesn't exist.
template<class K>
template <class K>
list_it_con cfind_or_larger(const K &item) {
return accessor.template find_or_larger<list_it_con, K>(item);
}
@ -64,6 +64,6 @@ public:
const Accessor access() const { return Accessor(&skiplist); }
private:
private:
list skiplist;
};

View File

@ -192,9 +192,7 @@ class SkipList : private Lockable<lock_t> {
this->data.emplace(std::forward<Args>(args)...);
}
Node(const T &data, uint8_t height) : Node(height) {
this->data.set(data);
}
Node(const T &data, uint8_t height) : Node(height) { this->data.set(data); }
Node(T &&data, uint8_t height) : Node(height) {
this->data.set(std::move(data));
@ -391,8 +389,7 @@ class SkipList : private Lockable<lock_t> {
pred = node, node = pred->forward(level);
}
if (level_found == -1 && !less(prev, node))
level_found = level;
if (level_found == -1 && !less(prev, node)) level_found = level;
preds_[level] = pred;
}
@ -955,7 +952,8 @@ class SkipList : private Lockable<lock_t> {
* of node.
*/
// TODO this code is not DRY w.r.t. the other insert function (rvalue ref)
std::pair<Iterator, bool> insert(Node *preds[], Node *succs[], const T &data) {
std::pair<Iterator, bool> insert(Node *preds[], Node *succs[],
const T &data) {
while (true) {
// TODO: before here was data.first
auto level = find_path(this, H - 1, data, preds, succs);
@ -978,9 +976,9 @@ class SkipList : private Lockable<lock_t> {
// has the locks
if (!lock_nodes<true>(height, guards, preds, succs)) continue;
return {insert_here(Node::create(data, height), preds, succs,
height, guards),
true};
return {
insert_here(Node::create(data, height), preds, succs, height, guards),
true};
}
}

View File

@ -3,56 +3,54 @@
// TODO: remove from here and from the project
#include <iostream>
#include "logging/loggable.hpp"
#include "memory/freelist.hpp"
#include "memory/lazy_gc.hpp"
#include "threading/sync/spinlock.hpp"
#include "logging/loggable.hpp"
template <class T, class lock_t = SpinLock>
class SkiplistGC : public LazyGC<SkiplistGC<T, lock_t>, lock_t>, public Loggable
{
public:
SkiplistGC() : Loggable("SkiplistGC") {}
class SkiplistGC : public LazyGC<SkiplistGC<T, lock_t>, lock_t>,
public Loggable {
public:
SkiplistGC() : Loggable("SkiplistGC") {}
// release_ref method should be called by a thread
// when the thread finish it job over object
// which has to be lazy cleaned
// if thread counter becames zero, all objects in the local_freelist
// are going to be deleted
// the only problem with this approach is that
// GC may never be called, but for now we can deal with that
void release_ref()
// release_ref method should be called by a thread
// when the thread finish it job over object
// which has to be lazy cleaned
// if thread counter becames zero, all objects in the local_freelist
// are going to be deleted
// the only problem with this approach is that
// GC may never be called, but for now we can deal with that
void release_ref() {
std::vector<T *> local_freelist;
// take freelist if there is no more threads
{
std::vector<T *> local_freelist;
// take freelist if there is no more threads
{
auto lock = this->acquire_unique();
assert(this->count > 0);
--this->count;
if (this->count == 0) {
freelist.swap(local_freelist);
}
}
if (local_freelist.size() > 0) {
logger.trace("GC started");
logger.trace("Local list size: {}", local_freelist.size());
long long counter = 0;
// destroy all elements from local_freelist
for (auto element : local_freelist) {
if (element->flags.is_marked()) {
T::destroy(element);
counter++;
}
}
logger.trace("Number of destroyed elements: {}", counter);
}
auto lock = this->acquire_unique();
assert(this->count > 0);
--this->count;
if (this->count == 0) {
freelist.swap(local_freelist);
}
}
void collect(T *node) { freelist.add(node); }
if (local_freelist.size() > 0) {
logger.trace("GC started");
logger.trace("Local list size: {}", local_freelist.size());
long long counter = 0;
// destroy all elements from local_freelist
for (auto element : local_freelist) {
if (element->flags.is_marked()) {
T::destroy(element);
counter++;
}
}
logger.trace("Number of destroyed elements: {}", counter);
}
}
private:
FreeList<T> freelist;
void collect(T *node) { freelist.add(node); }
private:
FreeList<T> freelist;
};

View File

@ -1,11 +1,11 @@
#pragma once
#include <vector>
#include <algorithm>
#include <functional>
#include <vector>
#include "math.hpp"
#include "kdnode.hpp"
#include "math.hpp"
namespace kd {
@ -13,55 +13,52 @@ template <class T, class U>
using Nodes = std::vector<KdNode<T, U>*>;
template <class T, class U>
KdNode<T, U>* build(Nodes<T, U>& nodes, byte axis = 0)
{
// if there are no elements left, we've completed building of this branch
if(nodes.empty())
return nullptr;
KdNode<T, U>* build(Nodes<T, U>& nodes, byte axis = 0) {
// if there are no elements left, we've completed building of this branch
if (nodes.empty()) return nullptr;
// comparison function to use for sorting the elements
auto fsort = [axis](KdNode<T, U>* a, KdNode<T, U>* b) -> bool
{ return kd::math::axial_distance(a->coord, b->coord, axis) < 0; };
// comparison function to use for sorting the elements
auto fsort = [axis](KdNode<T, U>* a, KdNode<T, U>* b) -> bool {
return kd::math::axial_distance(a->coord, b->coord, axis) < 0;
};
size_t median = nodes.size() / 2;
size_t median = nodes.size() / 2;
// partial sort nodes vector to compute median and ensure that elements
// less than median are positioned before the median so we can slice it
// nicely
// partial sort nodes vector to compute median and ensure that elements
// less than median are positioned before the median so we can slice it
// nicely
// internal implementation is O(n) worst case
// tl;dr http://en.wikipedia.org/wiki/Introselect
std::nth_element(nodes.begin(), nodes.begin() + median, nodes.end(), fsort);
// internal implementation is O(n) worst case
// tl;dr http://en.wikipedia.org/wiki/Introselect
std::nth_element(nodes.begin(), nodes.begin() + median, nodes.end(), fsort);
// set axis for the node
auto node = nodes.at(median);
node->axis = axis;
// slice the vector into two halves
auto left = Nodes<T, U>(nodes.begin(), nodes.begin() + median);
auto right = Nodes<T, U>(nodes.begin() + median + 1, nodes.end());
// recursively build left and right branches
node->left = build(left, axis ^ 1);
node->right = build(right, axis ^ 1);
// set axis for the node
auto node = nodes.at(median);
node->axis = axis;
return node;
// slice the vector into two halves
auto left = Nodes<T, U>(nodes.begin(), nodes.begin() + median);
auto right = Nodes<T, U>(nodes.begin() + median + 1, nodes.end());
// recursively build left and right branches
node->left = build(left, axis ^ 1);
node->right = build(right, axis ^ 1);
return node;
}
template <class T, class U, class It>
KdNode<T, U>* build(It first, It last)
{
Nodes<T, U> kdnodes;
KdNode<T, U>* build(It first, It last) {
Nodes<T, U> kdnodes;
std::transform(first, last, std::back_inserter(kdnodes),
[&](const std::pair<Point<T>, U>& element) {
auto key = element.first;
auto data = element.second;
return new KdNode<T, U>(key, data);
});
std::transform(first, last, std::back_inserter(kdnodes),
[&](const std::pair<Point<T>, U>& element) {
auto key = element.first;
auto data = element.second;
return new KdNode<T, U>(key, data);
});
// build the tree from the kdnodes and return the root node
return build(kdnodes);
// build the tree from the kdnodes and return the root node
return build(kdnodes);
}
}

View File

@ -7,38 +7,40 @@
namespace kd {
template <class T, class U>
class KdNode
{
public:
KdNode(const U& data)
: axis(0), coord(Point<T>(0, 0)), left(nullptr), right(nullptr), data(data) { }
class KdNode {
public:
KdNode(const U& data)
: axis(0),
coord(Point<T>(0, 0)),
left(nullptr),
right(nullptr),
data(data) {}
KdNode(const Point<T>& coord, const U& data)
: axis(0), coord(coord), left(nullptr), right(nullptr), data(data) { }
KdNode(const Point<T>& coord, const U& data)
: axis(0), coord(coord), left(nullptr), right(nullptr), data(data) {}
KdNode(unsigned char axis, const Point<T>& coord, const U& data)
: axis(axis), coord(coord), left(nullptr), right(nullptr), data(data) { }
KdNode(unsigned char axis, const Point<T>& coord, const U& data)
: axis(axis), coord(coord), left(nullptr), right(nullptr), data(data) {}
KdNode(unsigned char axis, const Point<T>& coord, KdNode<T, U>* left, KdNode<T, U>* right, const U& data)
: axis(axis), coord(coord), left(left), right(right), data(data) { }
KdNode(unsigned char axis, const Point<T>& coord, KdNode<T, U>* left,
KdNode<T, U>* right, const U& data)
: axis(axis), coord(coord), left(left), right(right), data(data) {}
~KdNode();
~KdNode();
unsigned char axis;
unsigned char axis;
Point<T> coord;
Point<T> coord;
KdNode<T, U>* left;
KdNode<T, U>* right;
KdNode<T, U>* left;
KdNode<T, U>* right;
U data;
U data;
};
template <class T, class U>
KdNode<T, U>::~KdNode()
{
delete left;
delete right;
KdNode<T, U>::~KdNode() {
delete left;
delete right;
}
}

View File

@ -5,36 +5,31 @@
#include "build.hpp"
#include "nns.hpp"
namespace kd
{
namespace kd {
template <class T, class U>
class KdTree
{
public:
KdTree() {}
class KdTree {
public:
KdTree() {}
template <class It>
KdTree(It first, It last);
template <class It>
KdTree(It first, It last);
const U& lookup(const Point<T>& pk) const;
const U& lookup(const Point<T>& pk) const;
protected:
std::unique_ptr<KdNode<float, U>> root;
protected:
std::unique_ptr<KdNode<float, U>> root;
};
template <class T, class U>
const U& KdTree<T, U>::lookup(const Point<T>& pk) const
{
// do a nearest neighbour search on the tree
return kd::nns(pk, root.get())->data;
const U& KdTree<T, U>::lookup(const Point<T>& pk) const {
// do a nearest neighbour search on the tree
return kd::nns(pk, root.get())->data;
}
template <class T, class U>
template <class It>
KdTree<T, U>::KdTree(It first, It last)
{
root.reset(kd::build<T, U, It>(first, last));
KdTree<T, U>::KdTree(It first, It last) {
root.reset(kd::build<T, U, It>(first, last));
}
}

View File

@ -1,7 +1,7 @@
#pragma once
#include <limits>
#include <cmath>
#include <limits>
#include "point.hpp"
@ -11,30 +11,24 @@ namespace math {
using byte = unsigned char;
// returns the squared distance between two points
template<class T>
T distance_sq(const Point<T>& a, const Point<T>& b)
{
auto dx = a.longitude - b.longitude;
auto dy = a.latitude - b.latitude;
return dx * dx + dy * dy;
template <class T>
T distance_sq(const Point<T>& a, const Point<T>& b) {
auto dx = a.longitude - b.longitude;
auto dy = a.latitude - b.latitude;
return dx * dx + dy * dy;
}
// returns the distance between two points
template<class T>
T distance(const Point<T>& a, const Point<T>& b)
{
return std::sqrt(distance_sq(a, b));
template <class T>
T distance(const Point<T>& a, const Point<T>& b) {
return std::sqrt(distance_sq(a, b));
}
// returns the distance between two points looking at a specific axis
// \param axis 0 if abscissa else 1 if ordinate
template <class T>
T axial_distance(const Point<T>& a, const Point<T>& b, byte axis)
{
return axis == 0 ?
a.longitude - b.longitude:
a.latitude - b.latitude;
}
T axial_distance(const Point<T>& a, const Point<T>& b, byte axis) {
return axis == 0 ? a.longitude - b.longitude : a.latitude - b.latitude;
}
}
}

View File

@ -1,86 +1,80 @@
#pragma once
#include "kdnode.hpp"
#include "math.hpp"
#include "point.hpp"
#include "kdnode.hpp"
namespace kd {
// helper class for calculating the nearest neighbour in a kdtree
// helper class for calculating the nearest neighbour in a kdtree
template <class T, class U>
struct Result
{
Result()
: node(nullptr), distance_sq(std::numeric_limits<T>::infinity()) {}
struct Result {
Result() : node(nullptr), distance_sq(std::numeric_limits<T>::infinity()) {}
Result(const KdNode<T, U>* node, T distance_sq)
: node(node), distance_sq(distance_sq) {}
Result(const KdNode<T, U>* node, T distance_sq)
: node(node), distance_sq(distance_sq) {}
const KdNode<T, U>* node;
T distance_sq;
const KdNode<T, U>* node;
T distance_sq;
};
// a recursive implementation for the kdtree nearest neighbour search
// \param p the point for which we search for the nearest neighbour
// \param node the root of the subtree during recursive descent
// \param best the place to save the best result so far
template <class T, class U>
void nns(const Point<T>& p, const KdNode<T, U>* const node, Result<T, U>& best)
{
if(node == nullptr)
return;
void nns(const Point<T>& p, const KdNode<T, U>* const node,
Result<T, U>& best) {
if (node == nullptr) return;
T d = math::distance_sq(p, node->coord);
T d = math::distance_sq(p, node->coord);
// keep record of the closest point C found so far
if(d < best.distance_sq)
{
best.node = node;
best.distance_sq = d;
}
// keep record of the closest point C found so far
if (d < best.distance_sq) {
best.node = node;
best.distance_sq = d;
}
// where to traverse next?
// what to prune?
// where to traverse next?
// what to prune?
// |
// possible |
// prune *
// area | - - - - -* P
// |
//
// |----------|
// dx
//
// possible prune
// RIGHT area
//
// --------*------ ---
// | |
// LEFT |
// | | dy
// |
// | |
// * p ---
// |
// possible |
// prune *
// area | - - - - -* P
// |
//
// |----------|
// dx
//
T axd = math::axial_distance(p, node->coord, node->axis);
// possible prune
// RIGHT area
//
// --------*------ ---
// | |
// LEFT |
// | | dy
// |
// | |
// * p ---
// traverse the subtree in order that
// maximizes the probability for pruning
auto near = axd > 0 ? node->right : node->left;
auto far = axd > 0 ? node->left : node->right;
T axd = math::axial_distance(p, node->coord, node->axis);
// try near first
nns(p, near, best);
// traverse the subtree in order that
// maximizes the probability for pruning
auto near = axd > 0 ? node->right : node->left;
auto far = axd > 0 ? node->left : node->right;
// prune subtrees once their bounding boxes say
// that they can't contain any point closer than C
if(axd * axd >= best.distance_sq)
return;
// try near first
nns(p, near, best);
// try other subtree
nns(p, far, best);
// prune subtrees once their bounding boxes say
// that they can't contain any point closer than C
if (axd * axd >= best.distance_sq) return;
// try other subtree
nns(p, far, best);
}
// an implementation for the kdtree nearest neighbour search
@ -88,14 +82,12 @@ void nns(const Point<T>& p, const KdNode<T, U>* const node, Result<T, U>& best)
// \param root the root of the tree
// \return the nearest neighbour for the point p
template <class T, class U>
const KdNode<T, U>* nns(const Point<T>& p, const KdNode<T, U>* root)
{
Result<T, U> best;
const KdNode<T, U>* nns(const Point<T>& p, const KdNode<T, U>* root) {
Result<T, U> best;
// begin recursive search
nns(p, root, best);
// begin recursive search
nns(p, root, best);
return best.node;
return best.node;
}
}

View File

@ -5,26 +5,22 @@
namespace kd {
template <class T>
class Point
{
public:
Point(T latitude, T longitude)
: latitude(latitude), longitude(longitude) {}
class Point {
public:
Point(T latitude, T longitude) : latitude(latitude), longitude(longitude) {}
// latitude
// y
// ^
// |
// 0---> x longitude
// latitude
// y
// ^
// |
// 0---> x longitude
T latitude;
T longitude;
T latitude;
T longitude;
/// nice stream formatting with the standard << operator
friend std::ostream& operator<< (std::ostream& stream, const Point& p) {
return stream << "(lat: " << p.latitude
<< ", lng: " << p.longitude << ')';
}
/// nice stream formatting with the standard << operator
friend std::ostream& operator<<(std::ostream& stream, const Point& p) {
return stream << "(lat: " << p.latitude << ", lng: " << p.longitude << ')';
}
};
}

View File

@ -1,243 +1,216 @@
#pragma once
#include <atomic>
#include <unistd.h>
#include <atomic>
#include "threading/sync/lockable.hpp"
#include "memory/hp.hpp"
#include "threading/sync/lockable.hpp"
namespace lockfree
{
namespace lockfree {
template <class T, size_t sleep_time = 250>
class List : Lockable<SpinLock>
{
public:
List() = default;
class List : Lockable<SpinLock> {
public:
List() = default;
List(List&) = delete;
List(List&&) = delete;
List(List&) = delete;
List(List&&) = delete;
void operator=(List&) = delete;
void operator=(List&) = delete;
class read_iterator
{
public:
// constructor
read_iterator(T* curr) :
curr(curr),
hazard_ref(std::move(memory::HP::get().insert(curr))) {}
class read_iterator {
public:
// constructor
read_iterator(T* curr)
: curr(curr), hazard_ref(std::move(memory::HP::get().insert(curr))) {}
// no copy constructor
read_iterator(read_iterator& other) = delete;
// no copy constructor
read_iterator(read_iterator& other) = delete;
// move constructor
read_iterator(read_iterator&& other) :
curr(other.curr),
hazard_ref(std::move(other.hazard_ref)) {}
T& operator*() { return *curr; }
T* operator->() { return curr; }
// move constructor
read_iterator(read_iterator&& other)
: curr(other.curr), hazard_ref(std::move(other.hazard_ref)) {}
operator T*() { return curr; }
T& operator*() { return *curr; }
T* operator->() { return curr; }
read_iterator& operator++()
{
auto& hp = memory::HP::get();
hazard_ref = std::move(hp.insert(curr->next.load()));
operator T*() { return curr; }
curr = curr->next.load();
return *this;
}
read_iterator& operator++(int)
{
return operator++();
}
read_iterator& operator++() {
auto& hp = memory::HP::get();
hazard_ref = std::move(hp.insert(curr->next.load()));
bool has_next()
{
if (curr->next == nullptr)
return false;
return true;
}
private:
T* curr;
memory::HP::reference hazard_ref;
};
class read_write_iterator
{
friend class List<T, sleep_time>;
public:
read_write_iterator(T* prev, T* curr) :
prev(prev),
curr(curr),
hazard_ref(std::move(memory::HP::get().insert(curr))) {}
// no copy constructor
read_write_iterator(read_write_iterator& other) = delete;
// move constructor
read_write_iterator(read_write_iterator&& other) :
prev(other.prev),
curr(other.curr),
hazard_ref(std::move(other.hazard_ref)) {}
T& operator*() { return *curr; }
T* operator->() { return curr; }
operator T*() { return curr; }
read_write_iterator& operator++()
{
auto& hp = memory::HP::get();
hazard_ref = std::move(hp.insert(curr->next.load()));
prev = curr;
curr = curr->next.load();
return *this;
}
read_write_iterator& operator++(int)
{
return operator++();
}
private:
T* prev;
T* curr;
memory::HP::reference hazard_ref;
};
read_iterator begin()
{
return read_iterator(head.load());
curr = curr->next.load();
return *this;
}
read_write_iterator rw_begin()
{
return read_write_iterator(nullptr, head.load());
read_iterator& operator++(int) { return operator++(); }
bool has_next() {
if (curr->next == nullptr) return false;
return true;
}
void push_front(T* node)
{
// we want to push an item to front of a list like this
// HEAD --> [1] --> [2] --> [3] --> ...
// read the value of head atomically and set the node's next pointer
// to point to the same location as head
private:
T* curr;
memory::HP::reference hazard_ref;
};
// HEAD --------> [1] --> [2] --> [3] --> ...
// |
// |
// NODE ------+
class read_write_iterator {
friend class List<T, sleep_time>;
T* h = node->next = head.load();
public:
read_write_iterator(T* prev, T* curr)
: prev(prev),
curr(curr),
hazard_ref(std::move(memory::HP::get().insert(curr))) {}
// atomically do: if the value of node->next is equal to current value
// of head, make the head to point to the node.
// if this fails (another thread has just made progress), update the
// value of node->next to the current value of head and retry again
// until you succeed
// no copy constructor
read_write_iterator(read_write_iterator& other) = delete;
// HEAD ----|CAS|----------> [1] --> [2] --> [3] --> ...
// | | |
// | v |
// +-------|CAS|---> NODE ---+
// move constructor
read_write_iterator(read_write_iterator&& other)
: prev(other.prev),
curr(other.curr),
hazard_ref(std::move(other.hazard_ref)) {}
while(!head.compare_exchange_weak(h, node))
{
node->next.store(h);
usleep(sleep_time);
}
T& operator*() { return *curr; }
T* operator->() { return curr; }
// the final state of the list after compare-and-swap looks like this
operator T*() { return curr; }
// HEAD [1] --> [2] --> [3] --> ...
// | |
// | |
// +---> NODE ---+
read_write_iterator& operator++() {
auto& hp = memory::HP::get();
hazard_ref = std::move(hp.insert(curr->next.load()));
prev = curr;
curr = curr->next.load();
return *this;
}
bool remove(read_write_iterator& it)
{
// acquire an exclusive guard.
// we only care about push_front and iterator performance so we can
// we only care about push_front and iterator performance so we can
// tradeoff some remove speed for better reads and inserts. remove is
// used exclusively by the GC thread(s) so it can be slower
auto guard = acquire_unique();
read_write_iterator& operator++(int) { return operator++(); }
// even though concurrent removes are synchronized, we need to worry
// about concurrent reads (solved by using atomics) and concurrent
// inserts to head (VERY dangerous, suffers from ABA problem, solved
// by simply not deleting the head node until it gets pushed further
// down the list)
private:
T* prev;
T* curr;
memory::HP::reference hazard_ref;
};
// check if we're deleting the head node. we can't do that because of
// the ABA problem so just return false for now. the logic behind this
// is that this node will move further down the list next time the
// garbage collector traverses this list and therefore it will become
// deletable
if(it.prev == nullptr) {
std::cout << "prev null" << std::endl;
return false;
}
read_iterator begin() { return read_iterator(head.load()); }
// HEAD --> ... --> [i] --> [i + 1] --> [i + 2] --> ...
//
// prev curr next
read_write_iterator rw_begin() {
return read_write_iterator(nullptr, head.load());
}
auto prev = it.prev;
auto curr = it.curr;
auto next = curr->next.load(std::memory_order_acquire);
void push_front(T* node) {
// we want to push an item to front of a list like this
// HEAD --> [1] --> [2] --> [3] --> ...
// effectively remove the curr node from the list
// read the value of head atomically and set the node's next pointer
// to point to the same location as head
// +---------------------+
// | |
// | v
// HEAD --> ... --> [i] [i + 1] --> [i + 2] --> ...
//
// prev curr next
// HEAD --------> [1] --> [2] --> [3] --> ...
// |
// |
// NODE ------+
prev->next.store(next, std::memory_order_release);
T* h = node->next = head.load();
// curr is now removed from the list so no iterators will be able
// to reach it at this point, but we still need to check the hazard
// pointers and wait until everyone who currently holds a reference to
// it has stopped using it before we can physically delete it
// TODO: test more appropriate
auto& hp = memory::HP::get();
// atomically do: if the value of node->next is equal to current value
// of head, make the head to point to the node.
// if this fails (another thread has just made progress), update the
// value of node->next to the current value of head and retry again
// until you succeed
while(hp.find(reinterpret_cast<uintptr_t>(curr)))
sleep(sleep_time);
delete curr;
return true;
// HEAD ----|CAS|----------> [1] --> [2] --> [3] --> ...
// | | |
// | v |
// +-------|CAS|---> NODE ---+
while (!head.compare_exchange_weak(h, node)) {
node->next.store(h);
usleep(sleep_time);
}
private:
std::atomic<T*> head { nullptr };
// the final state of the list after compare-and-swap looks like this
// HEAD [1] --> [2] --> [3] --> ...
// | |
// | |
// +---> NODE ---+
}
bool remove(read_write_iterator& it) {
// acquire an exclusive guard.
// we only care about push_front and iterator performance so we can
// we only care about push_front and iterator performance so we can
// tradeoff some remove speed for better reads and inserts. remove is
// used exclusively by the GC thread(s) so it can be slower
auto guard = acquire_unique();
// even though concurrent removes are synchronized, we need to worry
// about concurrent reads (solved by using atomics) and concurrent
// inserts to head (VERY dangerous, suffers from ABA problem, solved
// by simply not deleting the head node until it gets pushed further
// down the list)
// check if we're deleting the head node. we can't do that because of
// the ABA problem so just return false for now. the logic behind this
// is that this node will move further down the list next time the
// garbage collector traverses this list and therefore it will become
// deletable
if (it.prev == nullptr) {
std::cout << "prev null" << std::endl;
return false;
}
// HEAD --> ... --> [i] --> [i + 1] --> [i + 2] --> ...
//
// prev curr next
auto prev = it.prev;
auto curr = it.curr;
auto next = curr->next.load(std::memory_order_acquire);
// effectively remove the curr node from the list
// +---------------------+
// | |
// | v
// HEAD --> ... --> [i] [i + 1] --> [i + 2] --> ...
//
// prev curr next
prev->next.store(next, std::memory_order_release);
// curr is now removed from the list so no iterators will be able
// to reach it at this point, but we still need to check the hazard
// pointers and wait until everyone who currently holds a reference to
// it has stopped using it before we can physically delete it
// TODO: test more appropriate
auto& hp = memory::HP::get();
while (hp.find(reinterpret_cast<uintptr_t>(curr))) sleep(sleep_time);
delete curr;
return true;
}
private:
std::atomic<T*> head{nullptr};
};
template <class T, size_t sleep_time>
bool operator==(typename List<T, sleep_time>::read_iterator& a,
typename List<T, sleep_time>::read_iterator& b)
{
return a->curr == b->curr;
typename List<T, sleep_time>::read_iterator& b) {
return a->curr == b->curr;
}
template <class T, size_t sleep_time>
bool operator!=(typename List<T, sleep_time>::read_iterator& a,
typename List<T, sleep_time>::read_iterator& b)
{
return !operator==(a, b);
typename List<T, sleep_time>::read_iterator& b) {
return !operator==(a, b);
}
}

View File

@ -11,315 +11,275 @@
// D must have method K& get_key()
// K must be comparable with ==.
template <class K, class D, size_t init_size_pow2 = 2>
class RhBase
{
protected:
class Combined
{
class RhBase {
protected:
class Combined {
public:
Combined() : data(0) {}
public:
Combined() : data(0) {}
Combined(D *data, size_t off) { this->data = ((size_t)data) | off; }
Combined(D *data, size_t off) { this->data = ((size_t)data) | off; }
bool valid() const { return data != 0; }
bool valid() const { return data != 0; }
size_t off() const { return data & 0x7; }
size_t off() const { return data & 0x7; }
void decrement_off_unsafe() { data--; }
void decrement_off_unsafe() { data--; }
bool decrement_off()
{
if (off() > 0) {
data--;
return true;
}
return false;
}
bool increment_off()
{
if (off() < 7) {
data++;
return true;
}
return false;
}
D *ptr() const { return (D *)(data & (~(0x7))); }
bool equal(const K &key, size_t off)
{
return this->off() == off && key == ptr()->get_key();
}
friend bool operator==(const Combined &a, const Combined &b)
{
return a.off() == b.off() &&
a.ptr()->get_key() == b.ptr()->get_key();
}
friend bool operator!=(const Combined &a, const Combined &b)
{
return !(a == b);
}
private:
size_t data;
};
// Base for all iterators. It can start from any point in map.
template <class It>
class IteratorBase : public Crtp<It>
{
protected:
IteratorBase() : map(nullptr) { advanced = index = ~((size_t)0); }
IteratorBase(const RhBase *map)
{
index = 0;
while (index < map->capacity && !map->array[index].valid()) {
index++;
}
if (index >= map->capacity) {
this->map = nullptr;
advanced = index = ~((size_t)0);
} else {
this->map = map;
advanced = index;
}
}
IteratorBase(const RhBase *map, size_t start)
: map(map), index(start), advanced(0)
{
}
const RhBase *map;
// How many times did whe advance.
size_t advanced;
// Current position in array
size_t index;
public:
IteratorBase(const IteratorBase &) = default;
IteratorBase(IteratorBase &&) = default;
D *operator*()
{
assert(index < map->capacity && map->array[index].valid());
return map->array[index].ptr();
}
D *operator->()
{
assert(index < map->capacity && map->array[index].valid());
return map->array[index].ptr();
}
It &operator++()
{
assert(index < map->capacity && map->array[index].valid());
auto mask = map->mask();
do {
advanced++;
if (advanced >= map->capacity) {
// Whe have advanced more than the capacity of map is so whe
// are done.
map = nullptr;
advanced = index = ~((size_t)0);
break;
}
index = (index + 1) & mask;
} while (!map->array[index].valid()); // Check if there is element
// at current position.
return this->derived();
}
It &operator++(int) { return operator++(); }
friend bool operator==(const It &a, const It &b)
{
return a.index == b.index && a.map == b.map;
}
friend bool operator!=(const It &a, const It &b) { return !(a == b); }
};
public:
class ConstIterator : public IteratorBase<ConstIterator>
{
friend class RhBase;
protected:
ConstIterator(const RhBase *map) : IteratorBase<ConstIterator>(map) {}
ConstIterator(const RhBase *map, size_t index)
: IteratorBase<ConstIterator>(map, index)
{
}
public:
ConstIterator() = default;
ConstIterator(const ConstIterator &) = default;
const D *operator->()
{
return IteratorBase<ConstIterator>::operator->();
}
const D *operator*()
{
return IteratorBase<ConstIterator>::operator*();
}
};
class Iterator : public IteratorBase<Iterator>
{
friend class RhBase;
protected:
Iterator(const RhBase *map) : IteratorBase<Iterator>(map) {}
Iterator(const RhBase *map, size_t index)
: IteratorBase<Iterator>(map, index)
{
}
public:
Iterator() = default;
Iterator(const Iterator &) = default;
};
RhBase() {}
RhBase(const RhBase &other) { copy_from(other); }
RhBase(RhBase &&other) { take_from(std::move(other)); }
~RhBase() { this->clear(); }
RhBase &operator=(const RhBase &other)
{
clear();
copy_from(other);
return *this;
}
RhBase &operator=(RhBase &&other)
{
clear();
take_from(std::move(other));
return *this;
}
Iterator begin() { return Iterator(this); }
ConstIterator begin() const { return ConstIterator(this); }
ConstIterator cbegin() const { return ConstIterator(this); }
Iterator end() { return Iterator(); }
ConstIterator end() const { return ConstIterator(); }
ConstIterator cend() const { return ConstIterator(); }
protected:
// Copys RAW BYTE data from other RhBase.
void copy_from(const RhBase &other)
{
capacity = other.capacity;
count = other.count;
if (capacity > 0) {
size_t bytes = sizeof(Combined) * capacity;
array = (Combined *)malloc(bytes);
memcpy(array, other.array, bytes);
} else {
array = nullptr;
}
}
// Takes data from other RhBase.
void take_from(RhBase &&other)
{
capacity = other.capacity;
count = other.count;
array = other.array;
other.array = nullptr;
other.count = 0;
other.capacity = 0;
}
// Initiazes array with given capacity.
void init_array(size_t capacity)
{
size_t bytes = sizeof(Combined) * capacity;
array = (Combined *)malloc(bytes);
std::memset(array, 0, bytes);
this->capacity = capacity;
}
// True if before array has some values.
// Before array must be released in the caller.
bool increase_size()
{
if (capacity == 0) {
// assert(array == nullptr && count == 0);
size_t new_size = 1 << init_size_pow2;
init_array(new_size);
return false;
}
size_t new_size = capacity * 2;
init_array(new_size);
count = 0;
bool decrement_off() {
if (off() > 0) {
data--;
return true;
}
return false;
}
Iterator create_it(size_t index) { return Iterator(this, index); }
ConstIterator create_it(size_t index) const
{
return ConstIterator(this, index);
bool increment_off() {
if (off() < 7) {
data++;
return true;
}
return false;
}
public:
// Cleares all data.
void clear()
{
free(array);
array = nullptr;
capacity = 0;
count = 0;
D *ptr() const { return (D *)(data & (~(0x7))); }
bool equal(const K &key, size_t off) {
return this->off() == off && key == ptr()->get_key();
}
size_t size() const { return count; }
protected:
size_t before_index(size_t now, size_t mask)
{
return (now - 1) & mask; // THIS IS VALID
friend bool operator==(const Combined &a, const Combined &b) {
return a.off() == b.off() && a.ptr()->get_key() == b.ptr()->get_key();
}
size_t index(const K &key, size_t mask) const
{
return hash(std::hash<K>()(key)) & mask;
friend bool operator!=(const Combined &a, const Combined &b) {
return !(a == b);
}
// NOTE: This is rather expensive but offers good distribution.
size_t hash(size_t x) const
{
x = (x ^ (x >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
x = (x ^ (x >> 27)) * UINT64_C(0x94d049bb133111eb);
x = x ^ (x >> 31);
return x;
private:
size_t data;
};
// Base for all iterators. It can start from any point in map.
template <class It>
class IteratorBase : public Crtp<It> {
protected:
IteratorBase() : map(nullptr) { advanced = index = ~((size_t)0); }
IteratorBase(const RhBase *map) {
index = 0;
while (index < map->capacity && !map->array[index].valid()) {
index++;
}
if (index >= map->capacity) {
this->map = nullptr;
advanced = index = ~((size_t)0);
} else {
this->map = map;
advanced = index;
}
}
IteratorBase(const RhBase *map, size_t start)
: map(map), index(start), advanced(0) {}
const RhBase *map;
// How many times did whe advance.
size_t advanced;
// Current position in array
size_t index;
public:
IteratorBase(const IteratorBase &) = default;
IteratorBase(IteratorBase &&) = default;
D *operator*() {
assert(index < map->capacity && map->array[index].valid());
return map->array[index].ptr();
}
size_t mask() const { return capacity - 1; }
D *operator->() {
assert(index < map->capacity && map->array[index].valid());
return map->array[index].ptr();
}
Combined *array = nullptr;
size_t capacity = 0;
size_t count = 0;
It &operator++() {
assert(index < map->capacity && map->array[index].valid());
auto mask = map->mask();
do {
advanced++;
if (advanced >= map->capacity) {
// Whe have advanced more than the capacity of map is so whe
// are done.
map = nullptr;
advanced = index = ~((size_t)0);
break;
}
index = (index + 1) & mask;
} while (!map->array[index].valid()); // Check if there is element
// at current position.
friend class IteratorBase<Iterator>;
friend class IteratorBase<ConstIterator>;
return this->derived();
}
It &operator++(int) { return operator++(); }
friend bool operator==(const It &a, const It &b) {
return a.index == b.index && a.map == b.map;
}
friend bool operator!=(const It &a, const It &b) { return !(a == b); }
};
public:
class ConstIterator : public IteratorBase<ConstIterator> {
friend class RhBase;
protected:
ConstIterator(const RhBase *map) : IteratorBase<ConstIterator>(map) {}
ConstIterator(const RhBase *map, size_t index)
: IteratorBase<ConstIterator>(map, index) {}
public:
ConstIterator() = default;
ConstIterator(const ConstIterator &) = default;
const D *operator->() { return IteratorBase<ConstIterator>::operator->(); }
const D *operator*() { return IteratorBase<ConstIterator>::operator*(); }
};
class Iterator : public IteratorBase<Iterator> {
friend class RhBase;
protected:
Iterator(const RhBase *map) : IteratorBase<Iterator>(map) {}
Iterator(const RhBase *map, size_t index)
: IteratorBase<Iterator>(map, index) {}
public:
Iterator() = default;
Iterator(const Iterator &) = default;
};
RhBase() {}
RhBase(const RhBase &other) { copy_from(other); }
RhBase(RhBase &&other) { take_from(std::move(other)); }
~RhBase() { this->clear(); }
RhBase &operator=(const RhBase &other) {
clear();
copy_from(other);
return *this;
}
RhBase &operator=(RhBase &&other) {
clear();
take_from(std::move(other));
return *this;
}
Iterator begin() { return Iterator(this); }
ConstIterator begin() const { return ConstIterator(this); }
ConstIterator cbegin() const { return ConstIterator(this); }
Iterator end() { return Iterator(); }
ConstIterator end() const { return ConstIterator(); }
ConstIterator cend() const { return ConstIterator(); }
protected:
// Copys RAW BYTE data from other RhBase.
void copy_from(const RhBase &other) {
capacity = other.capacity;
count = other.count;
if (capacity > 0) {
size_t bytes = sizeof(Combined) * capacity;
array = (Combined *)malloc(bytes);
memcpy(array, other.array, bytes);
} else {
array = nullptr;
}
}
// Takes data from other RhBase.
void take_from(RhBase &&other) {
capacity = other.capacity;
count = other.count;
array = other.array;
other.array = nullptr;
other.count = 0;
other.capacity = 0;
}
// Initiazes array with given capacity.
void init_array(size_t capacity) {
size_t bytes = sizeof(Combined) * capacity;
array = (Combined *)malloc(bytes);
std::memset(array, 0, bytes);
this->capacity = capacity;
}
// True if before array has some values.
// Before array must be released in the caller.
bool increase_size() {
if (capacity == 0) {
// assert(array == nullptr && count == 0);
size_t new_size = 1 << init_size_pow2;
init_array(new_size);
return false;
}
size_t new_size = capacity * 2;
init_array(new_size);
count = 0;
return true;
}
Iterator create_it(size_t index) { return Iterator(this, index); }
ConstIterator create_it(size_t index) const {
return ConstIterator(this, index);
}
public:
// Cleares all data.
void clear() {
free(array);
array = nullptr;
capacity = 0;
count = 0;
}
size_t size() const { return count; }
protected:
size_t before_index(size_t now, size_t mask) {
return (now - 1) & mask; // THIS IS VALID
}
size_t index(const K &key, size_t mask) const {
return hash(std::hash<K>()(key)) & mask;
}
// NOTE: This is rather expensive but offers good distribution.
size_t hash(size_t x) const {
x = (x ^ (x >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
x = (x ^ (x >> 27)) * UINT64_C(0x94d049bb133111eb);
x = x ^ (x >> 31);
return x;
}
size_t mask() const { return capacity - 1; }
Combined *array = nullptr;
size_t capacity = 0;
size_t count = 0;
friend class IteratorBase<Iterator>;
friend class IteratorBase<ConstIterator>;
};

View File

@ -12,166 +12,160 @@
// K must be comparable with ==.
// HashMap behaves as if it isn't owner of entrys.
template <class K, class D, size_t init_size_pow2 = 2>
class RhHashMap : public RhBase<K, D, init_size_pow2>
{
typedef RhBase<K, D, init_size_pow2> base;
using base::array;
using base::index;
using base::capacity;
using base::count;
using typename base::Combined;
class RhHashMap : public RhBase<K, D, init_size_pow2> {
typedef RhBase<K, D, init_size_pow2> base;
using base::array;
using base::index;
using base::capacity;
using base::count;
using typename base::Combined;
void increase_size()
{
size_t old_size = capacity;
auto a = array;
if (base::increase_size()) {
for (int i = 0; i < old_size; i++) {
if (a[i].valid()) {
insert(a[i].ptr());
}
}
void increase_size() {
size_t old_size = capacity;
auto a = array;
if (base::increase_size()) {
for (int i = 0; i < old_size; i++) {
if (a[i].valid()) {
insert(a[i].ptr());
}
free(a);
}
}
public:
using base::RhBase;
free(a);
}
bool contains(const K &key) { return find(key).is_present(); }
public:
using base::RhBase;
OptionPtr<D> find(const K key)
{
size_t mask = this->mask();
size_t now = index(key, mask);
size_t off = 0;
size_t border = 8 <= capacity ? 8 : capacity;
bool contains(const K &key) { return find(key).is_present(); }
while (off < border) {
Combined other = array[now];
if (other.valid()) {
auto other_off = other.off();
if (other_off == off && key == other.ptr()->get_key()) {
// Found data.
return OptionPtr<D>(other.ptr());
OptionPtr<D> find(const K key) {
size_t mask = this->mask();
size_t now = index(key, mask);
size_t off = 0;
size_t border = 8 <= capacity ? 8 : capacity;
} else if (other_off < off) { // Other is rich
break;
} // Else other has equal or greater offset, so he is poor.
} else {
// Empty slot means that there is no searched data.
break;
while (off < border) {
Combined other = array[now];
if (other.valid()) {
auto other_off = other.off();
if (other_off == off && key == other.ptr()->get_key()) {
// Found data.
return OptionPtr<D>(other.ptr());
} else if (other_off < off) { // Other is rich
break;
} // Else other has equal or greater offset, so he is poor.
} else {
// Empty slot means that there is no searched data.
break;
}
off++;
now = (now + 1) & mask;
}
return OptionPtr<D>();
}
// Inserts element. Returns true if element wasn't in the map.
bool insert(D *data) {
if (count < capacity) {
size_t mask = this->mask();
auto key = std::ref(data->get_key());
size_t now = index(key, mask);
size_t off = 0;
size_t border = 8 <= capacity ? 8 : capacity;
while (off < border) {
Combined other = array[now];
if (other.valid()) {
auto other_off = other.off();
if (other_off == off && key == other.ptr()->get_key()) {
// Element already exists.
return false;
} else if (other_off < off) { // Other is rich
// Set data.
array[now] = Combined(data, off);
// Move other data to the higher indexes,
while (other.increment_off()) {
now = (now + 1) & mask;
auto tmp = array[now];
array[now] = other;
other = tmp;
if (!other.valid()) {
count++;
return true;
}
}
data = other.ptr();
break; // Cant insert removed element because it would
// be to far from his real place.
} // Else other has equal or greater offset, so he is poor.
} else {
// Data can be placed in this empty slot.
array[now] = Combined(data, off);
count++;
return true;
}
off++;
off++;
now = (now + 1) & mask;
}
}
// There isn't enough space for element pointed by data so whe must
// increase array.
increase_size();
return insert(data);
}
// Removes element. Returns removed element if it existed.
OptionPtr<D> remove(const K &key) {
size_t mask = this->mask();
size_t now = index(key, mask);
size_t off = 0;
size_t border = 8 <= capacity ? 8 : capacity;
while (off < border) {
Combined other = array[now];
if (other.valid()) {
auto other_off = other.off();
auto other_ptr = other.ptr();
if (other_off == off && key == other_ptr->get_key()) { // Found it
auto before = now;
// Whe must move other elements one slot lower.
do {
// This is alright even for off=0 on found element
// because it wont be seen.
other.decrement_off_unsafe();
array[before] = other;
before = now;
now = (now + 1) & mask;
}
return OptionPtr<D>();
}
// Inserts element. Returns true if element wasn't in the map.
bool insert(D *data)
{
if (count < capacity) {
size_t mask = this->mask();
auto key = std::ref(data->get_key());
size_t now = index(key, mask);
size_t off = 0;
size_t border = 8 <= capacity ? 8 : capacity;
while (off < border) {
Combined other = array[now];
if (other.valid()) {
auto other_off = other.off();
if (other_off == off && key == other.ptr()->get_key()) {
// Element already exists.
return false;
} else if (other_off < off) { // Other is rich
// Set data.
array[now] = Combined(data, off);
// Move other data to the higher indexes,
while (other.increment_off()) {
now = (now + 1) & mask;
auto tmp = array[now];
array[now] = other;
other = tmp;
if (!other.valid()) {
count++;
return true;
}
}
data = other.ptr();
break; // Cant insert removed element because it would
// be to far from his real place.
} // Else other has equal or greater offset, so he is poor.
} else {
// Data can be placed in this empty slot.
array[now] = Combined(data, off);
count++;
return true;
}
off++;
now = (now + 1) & mask;
}
}
// There isn't enough space for element pointed by data so whe must
// increase array.
increase_size();
return insert(data);
}
// Removes element. Returns removed element if it existed.
OptionPtr<D> remove(const K &key)
{
size_t mask = this->mask();
size_t now = index(key, mask);
size_t off = 0;
size_t border = 8 <= capacity ? 8 : capacity;
while (off < border) {
Combined other = array[now];
if (other.valid()) {
auto other_off = other.off();
auto other_ptr = other.ptr();
if (other_off == off &&
key == other_ptr->get_key()) { // Found it
auto before = now;
// Whe must move other elements one slot lower.
do {
// This is alright even for off=0 on found element
// because it wont be seen.
other.decrement_off_unsafe();
array[before] = other;
before = now;
now = (now + 1) & mask;
other = array[now];
} while (other.valid() &&
other.off() > 0); // Exit if whe encounter empty
// slot or data which is exactly
// in slot which it want's to be.
array[before] = Combined();
count--;
return OptionPtr<D>(other_ptr);
} else if (other_off < off) { // Other is rich
break;
} // Else other has equal or greater offset, so he is poor.
} else {
// If the element to be removed existed in map it would be here.
break;
}
off++;
now = (now + 1) & mask;
}
return OptionPtr<D>();
other = array[now];
} while (other.valid() &&
other.off() > 0); // Exit if whe encounter empty
// slot or data which is exactly
// in slot which it want's to be.
array[before] = Combined();
count--;
return OptionPtr<D>(other_ptr);
} else if (other_off < off) { // Other is rich
break;
} // Else other has equal or greater offset, so he is poor.
} else {
// If the element to be removed existed in map it would be here.
break;
}
off++;
now = (now + 1) & mask;
}
return OptionPtr<D>();
}
};

View File

@ -37,322 +37,305 @@
// |...|c:a|c|...|c|b|...|b||a|...|a|...| => off(a) = 2
// ...
template <class K, class D, size_t init_size_pow2 = 2>
class RhHashMultiMap : public RhBase<K, D, init_size_pow2>
{
typedef RhBase<K, D, init_size_pow2> base;
using base::array;
using base::index;
using base::capacity;
using base::count;
using typename base::Combined;
using base::before_index;
using base::create_it;
class RhHashMultiMap : public RhBase<K, D, init_size_pow2> {
typedef RhBase<K, D, init_size_pow2> base;
using base::array;
using base::index;
using base::capacity;
using base::count;
using typename base::Combined;
using base::before_index;
using base::create_it;
void increase_size()
{
size_t old_size = capacity;
auto a = array;
if (base::increase_size()) {
for (int i = 0; i < old_size; i++) {
if (a[i].valid()) {
add(a[i].ptr());
}
}
void increase_size() {
size_t old_size = capacity;
auto a = array;
if (base::increase_size()) {
for (int i = 0; i < old_size; i++) {
if (a[i].valid()) {
add(a[i].ptr());
}
free(a);
}
}
public:
using base::RhBase;
using base::end;
using typename base::ConstIterator;
using typename base::Iterator;
free(a);
}
bool contains(const K &key) const { return find_index(key).is_present(); }
public:
using base::RhBase;
using base::end;
using typename base::ConstIterator;
using typename base::Iterator;
Iterator find(const K &key_in)
{
auto index = find_index(key_in);
if (index) {
return create_it(index.get());
bool contains(const K &key) const { return find_index(key).is_present(); }
Iterator find(const K &key_in) {
auto index = find_index(key_in);
if (index) {
return create_it(index.get());
} else {
return end();
}
}
ConstIterator find(const K &key_in) const {
auto index = find_index(key_in);
if (index) {
return create_it(index.get());
} else {
return end();
}
}
private:
Option<size_t> find_index(const K &key_in) const {
if (count > 0) {
auto key = std::ref(key_in);
size_t mask = this->mask();
size_t now = index(key, mask);
size_t off = 0;
size_t border = 8 <= capacity ? 8 : capacity;
Combined other = array[now];
while (other.valid() && off < border) {
auto other_off = other.off();
if (other_off == off && key == other.ptr()->get_key()) {
return Option<size_t>(now);
} else if (other_off < off) { // Other is rich
break;
} else { // Else other has equal or greater off, so he is poor.
if (UNLIKELY(skip(now, other, other_off, mask))) {
break;
}
off++;
}
}
}
return Option<size_t>();
}
public:
// Inserts element.
void add(D *data) { add(data->get_key(), data); }
// Inserts element with the given key.
void add(const K &key_in, D *data) {
assert(key_in == data->get_key());
if (count < capacity) {
auto key = std::ref(key_in);
size_t mask = this->mask();
size_t now = index(key, mask);
size_t start = now;
size_t off = 0;
size_t border = 8 <= capacity ? 8 : capacity;
Combined other = array[now];
while (off < border) {
if (other.valid()) {
const size_t other_off = other.off();
bool multi = false;
if (other_off == off && other.ptr()->get_key() == key) {
// Found the same
// Must skip same keyd values to insert new value at the
// end.
do {
now = (now + 1) & mask;
other = array[now];
if (!other.valid()) {
// Found empty slot in which data ca be added.
set(now, data, off);
return;
}
} while (other.equal(key, off));
// There is no empty slot after same keyed values.
multi = true;
} else if (other_off > off ||
other_poor(other, mask, start,
now)) { // Else other has equal or
// greater off, so he is poor.
skip(now, other, other_off, mask); // TRUE IS IMPOSSIBLE
off++;
continue;
}
// Data will be insrted at current slot and all other data
// will be displaced for one slot.
array[now] = Combined(data, off);
auto start_insert = now;
while (is_off_adjusted(other, mask, start_insert, now, multi) ||
other.increment_off()) {
now = (now + 1) & mask;
auto tmp = array[now];
array[now] = other;
other = tmp;
if (!other.valid()) {
// Found empty slot which means i can finish now.
count++;
return;
}
}
data = other.ptr();
break; // Cant insert removed element
} else {
return end();
// Found empty slot for data.
set(now, data, off);
return;
}
}
}
ConstIterator find(const K &key_in) const
{
auto index = find_index(key_in);
if (index) {
return create_it(index.get());
} else {
return end();
}
}
// There is't enough space for data.
increase_size();
add(data);
}
private:
Option<size_t> find_index(const K &key_in) const
{
if (count > 0) {
auto key = std::ref(key_in);
size_t mask = this->mask();
size_t now = index(key, mask);
size_t off = 0;
size_t border = 8 <= capacity ? 8 : capacity;
Combined other = array[now];
while (other.valid() && off < border) {
auto other_off = other.off();
if (other_off == off && key == other.ptr()->get_key()) {
return Option<size_t>(now);
// Removes element equal by key and value. Returns true if it existed.
bool remove(D *data) {
if (count > 0) {
auto key = std::ref(data->get_key());
size_t mask = this->mask();
size_t now = index(key, mask);
size_t off = 0;
size_t border = 8 <= capacity ? 8 : capacity;
Combined other = array[now];
} else if (other_off < off) { // Other is rich
break;
} else { // Else other has equal or greater off, so he is poor.
if (UNLIKELY(skip(now, other, other_off, mask))) {
break;
}
off++;
}
while (other.valid() && off < border) {
const size_t other_off = other.off();
if (other_off == off && key == other.ptr()->get_key()) {
// Found same key data.
auto founded = capacity;
size_t started = now;
bool multi = false;
// Must find slot with searched data.
do {
if (other.ptr() == data) {
// founded it.
founded = now;
}
}
return Option<size_t>();
}
public:
// Inserts element.
void add(D *data) { add(data->get_key(), data); }
// Inserts element with the given key.
void add(const K &key_in, D *data)
{
assert(key_in == data->get_key());
if (count < capacity) {
auto key = std::ref(key_in);
size_t mask = this->mask();
size_t now = index(key, mask);
size_t start = now;
size_t off = 0;
size_t border = 8 <= capacity ? 8 : capacity;
Combined other = array[now];
while (off < border) {
if (other.valid()) {
const size_t other_off = other.off();
bool multi = false;
if (other_off == off && other.ptr()->get_key() == key) {
// Found the same
// Must skip same keyd values to insert new value at the
// end.
do {
now = (now + 1) & mask;
other = array[now];
if (!other.valid()) {
// Found empty slot in which data ca be added.
set(now, data, off);
return;
}
} while (other.equal(key, off));
// There is no empty slot after same keyed values.
multi = true;
} else if (other_off > off ||
other_poor(other, mask, start,
now)) { // Else other has equal or
// greater off, so he is poor.
skip(now, other, other_off, mask); // TRUE IS IMPOSSIBLE
off++;
continue;
}
// Data will be insrted at current slot and all other data
// will be displaced for one slot.
array[now] = Combined(data, off);
auto start_insert = now;
while (is_off_adjusted(other, mask, start_insert, now,
multi) ||
other.increment_off()) {
now = (now + 1) & mask;
auto tmp = array[now];
array[now] = other;
other = tmp;
if (!other.valid()) {
// Found empty slot which means i can finish now.
count++;
return;
}
}
data = other.ptr();
break; // Cant insert removed element
} else {
// Found empty slot for data.
set(now, data, off);
return;
}
}
}
// There is't enough space for data.
increase_size();
add(data);
}
// Removes element equal by key and value. Returns true if it existed.
bool remove(D *data)
{
if (count > 0) {
auto key = std::ref(data->get_key());
size_t mask = this->mask();
size_t now = index(key, mask);
size_t off = 0;
size_t border = 8 <= capacity ? 8 : capacity;
Combined other = array[now];
while (other.valid() && off < border) {
const size_t other_off = other.off();
if (other_off == off && key == other.ptr()->get_key()) {
// Found same key data.
auto founded = capacity;
size_t started = now;
bool multi = false;
// Must find slot with searched data.
do {
if (other.ptr() == data) {
// founded it.
founded = now;
}
now = (now + 1) & mask;
other = array[now];
if (!other.valid() || UNLIKELY(started == now)) {
// Reason is possibility of map full of same values.
break;
}
} while (other.equal(key, off) && (multi = true));
if (founded == capacity) {
// Didn't found the data.
return false;
}
// Data will be removed by moving other data by one slot
// before.
auto bef = before_index(now, mask);
array[founded] = array[bef];
auto start_rem = bef;
while (other.valid() &&
(is_off_adjusted_rem(other, mask, start_rem, bef,
now, multi) ||
other.decrement_off())) {
array[bef] = other;
bef = now;
now = (now + 1) & mask;
other = array[now];
}
array[bef] = Combined();
count--;
return true;
} else if (other_off < off) { // Other is rich
break;
} else { // Else other has equal or greater off, so he is poor.
// Must skip values of same keys but different key than
// data.
if (UNLIKELY(skip(now, other, other_off, mask))) {
break;
}
off++;
}
}
}
return false;
}
private:
// Skips same key valus as other. true if whole map is full of same key
// values.
bool skip(size_t &now, Combined &other, size_t other_off, size_t mask) const
{
auto other_key = other.ptr()->get_key();
size_t start = now;
do {
now = (now + 1) & mask;
other = array[now];
if (UNLIKELY(start == now)) { // Reason is possibility of map
// full of same values.
return true;
if (!other.valid() || UNLIKELY(started == now)) {
// Reason is possibility of map full of same values.
break;
}
} while (other.valid() && other.equal(other_key, other_off));
return false;
}
} while (other.equal(key, off) && (multi = true));
void set(size_t now, D *data, size_t off)
{
array[now] = Combined(data, off);
count++;
}
// True if no adjusment is needed, false otherwise.
bool is_off_adjusted(Combined &com, size_t mask, size_t start, size_t now,
bool multi)
{
if (com.off() == 0) { // Must be adjusted
if (founded == capacity) {
// Didn't found the data.
return false;
}
size_t cin = index(com.ptr()->get_key(), mask);
if (outside(start, now, cin)) { // Outside [start,now] interval
return multi;
}
auto a = array[cin];
auto b = array[(cin + 1) & mask];
return a == b;
// Check if different key has eneterd in to
// range of other.
}
}
bool other_poor(Combined other, size_t mask, size_t start, size_t now)
{
// If other index is smaller then he is poorer.
return outside_left_weak(start, now,
index(other.ptr()->get_key(), mask));
}
// Data will be removed by moving other data by one slot
// before.
auto bef = before_index(now, mask);
array[founded] = array[bef];
// True if no adjusment is needed, false otherwise.
bool is_off_adjusted_rem(Combined &com, size_t mask, size_t start,
size_t bef, size_t now, bool multi)
{
if (com.off() == 0) { // Must be adjusted
return false;
}
size_t cin = index(com.ptr()->get_key(), mask);
if (cin == bef) {
return false;
}
if (outside(start, now, cin)) {
return multi;
}
auto a = array[cin];
auto b = array[before_index(cin, mask)];
return b.valid() && a == b;
// Check if different key has eneterd in to
// range of other.
}
auto start_rem = bef;
while (other.valid() && (is_off_adjusted_rem(other, mask, start_rem,
bef, now, multi) ||
other.decrement_off())) {
array[bef] = other;
bef = now;
now = (now + 1) & mask;
other = array[now];
}
// True if p is uutside [start,end] interval
bool outside(size_t start, size_t end, size_t p)
{
return (start <= end && (p < start || p > end)) ||
(end < start && p < start && p > end);
}
array[bef] = Combined();
count--;
return true;
// True if p is outside <start,end] interval
bool outside_left_weak(size_t start, size_t end, size_t p)
{
return (start <= end && (p <= start || p > end)) ||
(end < start && p <= start && p > end);
} else if (other_off < off) { // Other is rich
break;
} else { // Else other has equal or greater off, so he is poor.
// Must skip values of same keys but different key than
// data.
if (UNLIKELY(skip(now, other, other_off, mask))) {
break;
}
off++;
}
}
}
return false;
}
private:
// Skips same key valus as other. true if whole map is full of same key
// values.
bool skip(size_t &now, Combined &other, size_t other_off, size_t mask) const {
auto other_key = other.ptr()->get_key();
size_t start = now;
do {
now = (now + 1) & mask;
other = array[now];
if (UNLIKELY(start == now)) { // Reason is possibility of map
// full of same values.
return true;
}
} while (other.valid() && other.equal(other_key, other_off));
return false;
}
void set(size_t now, D *data, size_t off) {
array[now] = Combined(data, off);
count++;
}
// True if no adjusment is needed, false otherwise.
bool is_off_adjusted(Combined &com, size_t mask, size_t start, size_t now,
bool multi) {
if (com.off() == 0) { // Must be adjusted
return false;
}
size_t cin = index(com.ptr()->get_key(), mask);
if (outside(start, now, cin)) { // Outside [start,now] interval
return multi;
}
auto a = array[cin];
auto b = array[(cin + 1) & mask];
return a == b;
// Check if different key has eneterd in to
// range of other.
}
bool other_poor(Combined other, size_t mask, size_t start, size_t now) {
// If other index is smaller then he is poorer.
return outside_left_weak(start, now, index(other.ptr()->get_key(), mask));
}
// True if no adjusment is needed, false otherwise.
bool is_off_adjusted_rem(Combined &com, size_t mask, size_t start, size_t bef,
size_t now, bool multi) {
if (com.off() == 0) { // Must be adjusted
return false;
}
size_t cin = index(com.ptr()->get_key(), mask);
if (cin == bef) {
return false;
}
if (outside(start, now, cin)) {
return multi;
}
auto a = array[cin];
auto b = array[before_index(cin, mask)];
return b.valid() && a == b;
// Check if different key has eneterd in to
// range of other.
}
// True if p is uutside [start,end] interval
bool outside(size_t start, size_t end, size_t p) {
return (start <= end && (p < start || p > end)) ||
(end < start && p < start && p > end);
}
// True if p is outside <start,end] interval
bool outside_left_weak(size_t start, size_t end, size_t p) {
return (start <= end && (p <= start || p > end)) ||
(end < start && p <= start && p > end);
}
};

View File

@ -5,50 +5,45 @@
#include "utils/numerics/log2.hpp"
template <typename PtrT>
struct PointerPackTraits
{
// here is a place to embed something like platform specific things
// TODO: cover more cases
constexpr static int free_bits = utils::log2(alignof(PtrT));
struct PointerPackTraits {
// here is a place to embed something like platform specific things
// TODO: cover more cases
constexpr static int free_bits = utils::log2(alignof(PtrT));
static auto get_ptr(uintptr_t value) { return (PtrT)(value); }
static auto get_ptr(uintptr_t value) { return (PtrT)(value); }
};
template <typename PtrT, int IntBits, typename IntT = unsigned,
typename PtrTraits = PointerPackTraits<PtrT>>
class PtrInt
{
private:
constexpr static int int_shift = PtrTraits::free_bits - IntBits;
constexpr static uintptr_t ptr_mask =
~(uintptr_t)(((intptr_t)1 << PtrTraits::free_bits) - 1);
constexpr static uintptr_t int_mask =
(uintptr_t)(((intptr_t)1 << IntBits) - 1);
class PtrInt {
private:
constexpr static int int_shift = PtrTraits::free_bits - IntBits;
constexpr static uintptr_t ptr_mask =
~(uintptr_t)(((intptr_t)1 << PtrTraits::free_bits) - 1);
constexpr static uintptr_t int_mask =
(uintptr_t)(((intptr_t)1 << IntBits) - 1);
uintptr_t value {0};
uintptr_t value{0};
public:
PtrInt(PtrT pointer, IntT integer)
{
set_ptr(pointer);
set_int(integer);
}
public:
PtrInt(PtrT pointer, IntT integer) {
set_ptr(pointer);
set_int(integer);
}
auto set_ptr(PtrT pointer)
{
auto integer = static_cast<uintptr_t>(get_int());
auto ptr = reinterpret_cast<uintptr_t>(pointer);
value = (ptr_mask & ptr) | (integer << int_shift);
}
auto set_ptr(PtrT pointer) {
auto integer = static_cast<uintptr_t>(get_int());
auto ptr = reinterpret_cast<uintptr_t>(pointer);
value = (ptr_mask & ptr) | (integer << int_shift);
}
auto set_int(IntT integer)
{
auto ptr = reinterpret_cast<uintptr_t>(get_ptr());
auto int_shifted = static_cast<uintptr_t>(integer << int_shift);
value = (int_mask & int_shifted) | ptr;
}
auto set_int(IntT integer) {
auto ptr = reinterpret_cast<uintptr_t>(get_ptr());
auto int_shifted = static_cast<uintptr_t>(integer << int_shift);
value = (int_mask & int_shifted) | ptr;
}
auto get_ptr() const { return PtrTraits::get_ptr(value & ptr_mask); }
auto get_ptr() const { return PtrTraits::get_ptr(value & ptr_mask); }
auto get_int() const { return (IntT)((value >> int_shift) & int_mask); }
auto get_int() const { return (IntT)((value >> int_shift) & int_mask); }
};

View File

@ -3,85 +3,75 @@
#include <atomic>
#include <memory>
namespace lockfree
{
namespace lockfree {
template <class T, size_t N>
class BoundedSpscQueue
{
public:
static constexpr size_t size = N;
class BoundedSpscQueue {
public:
static constexpr size_t size = N;
BoundedSpscQueue() = default;
BoundedSpscQueue() = default;
BoundedSpscQueue(const BoundedSpscQueue&) = delete;
BoundedSpscQueue(BoundedSpscQueue&&) = delete;
BoundedSpscQueue(const BoundedSpscQueue&) = delete;
BoundedSpscQueue(BoundedSpscQueue&&) = delete;
BoundedSpscQueue& operator=(const BoundedSpscQueue&) = delete;
BoundedSpscQueue& operator=(const BoundedSpscQueue&) = delete;
bool push(const T& item)
{
// load the current tail
// [] [] [1] [2] [3] [4] [5] [$] []
// H T
auto t = tail.load(std::memory_order_relaxed);
bool push(const T& item) {
// load the current tail
// [] [] [1] [2] [3] [4] [5] [$] []
// H T
auto t = tail.load(std::memory_order_relaxed);
// what will next tail be after we push
// [] [] [1] [2] [3] [4] [5] [$] [ ]
// H T T'
auto next = increment(t);
// what will next tail be after we push
// [] [] [1] [2] [3] [4] [5] [$] [ ]
// H T T'
auto next = increment(t);
// check if queue is full and do nothing if it is
// [3] [4] [5] [6] [7] [8] [$] [ 1 ] [2]
// T T'H
if(next == head.load(std::memory_order_acquire))
return false;
// check if queue is full and do nothing if it is
// [3] [4] [5] [6] [7] [8] [$] [ 1 ] [2]
// T T'H
if (next == head.load(std::memory_order_acquire)) return false;
// insert the item into the empty spot
// [] [] [1] [2] [3] [4] [5] [ ] []
// H T T'
items[t] = item;
// insert the item into the empty spot
// [] [] [1] [2] [3] [4] [5] [ ] []
// H T T'
items[t] = item;
// release the tail to the consumer (serialization point)
// [] [] [1] [2] [3] [4] [5] [ $ ] []
// H T T'
tail.store(next, std::memory_order_release);
// release the tail to the consumer (serialization point)
// [] [] [1] [2] [3] [4] [5] [ $ ] []
// H T T'
tail.store(next, std::memory_order_release);
return true;
}
return true;
}
bool pop(T& item)
{
// [] [] [1] [2] [3] [4] [5] [$] []
// H T
auto h = head.load(std::memory_order_relaxed);
bool pop(T& item) {
// [] [] [1] [2] [3] [4] [5] [$] []
// H T
auto h = head.load(std::memory_order_relaxed);
// [] [] [] [] [ $ ] [] [] [] []
// H T
if(h == tail.load(std::memory_order_acquire))
return false;
// [] [] [] [] [ $ ] [] [] [] []
// H T
if (h == tail.load(std::memory_order_acquire)) return false;
// move an item from the queue
item = std::move(items[h]);
// move an item from the queue
item = std::move(items[h]);
// serialization point wrt producer
// [] [] [] [2] [3] [4] [5] [$] []
// H T
head.store(increment(h), std::memory_order_release);
// serialization point wrt producer
// [] [] [] [2] [3] [4] [5] [$] []
// H T
head.store(increment(h), std::memory_order_release);
return true;
}
return true;
}
private:
static constexpr size_t capacity = N + 1;
private:
static constexpr size_t capacity = N + 1;
std::array<T, capacity> items;
std::atomic<size_t> head {0}, tail {0};
std::array<T, capacity> items;
std::atomic<size_t> head{0}, tail{0};
size_t increment(size_t idx) const
{
return (idx + 1) % capacity;
}
size_t increment(size_t idx) const { return (idx + 1) % capacity; }
};
}

View File

@ -3,12 +3,11 @@
#include <atomic>
#include <memory>
namespace lockfree
{
namespace lockfree {
/** @brief Multiple-Producer Single-Consumer Queue
* A wait-free (*) multiple-producer single-consumer queue.
*
*
* features:
* - wait-free
* - fast producers (only one atomic XCHG and and one atomic store with
@ -22,7 +21,8 @@ namespace lockfree
*
* (*) there is a small window of inconsistency from the lock free design
* see the url below for details
* URL: http://www.1024cores.net/home/lock-free-algorithms/queues/intrusive-mpsc-node-based-queue
* URL:
* http://www.1024cores.net/home/lock-free-algorithms/queues/intrusive-mpsc-node-based-queue
*
* mine is not intrusive for better modularity, but with slightly worse
* performance because it needs to do two memory allocations instead of
@ -31,119 +31,111 @@ namespace lockfree
* @tparam T Type of the items to store in the queue
*/
template <class T>
class MpscQueue
{
struct Node
{
Node(Node* next, std::unique_ptr<T>&& item)
: next(next), item(std::forward<std::unique_ptr<T>>(item)) {}
class MpscQueue {
struct Node {
Node(Node* next, std::unique_ptr<T>&& item)
: next(next), item(std::forward<std::unique_ptr<T>>(item)) {}
std::atomic<Node*> next;
std::unique_ptr<T> item;
};
std::atomic<Node*> next;
std::unique_ptr<T> item;
};
public:
MpscQueue()
{
auto stub = new Node(nullptr, nullptr);
head.store(stub);
tail = stub;
public:
MpscQueue() {
auto stub = new Node(nullptr, nullptr);
head.store(stub);
tail = stub;
}
~MpscQueue() {
// purge all elements from the queue
while (pop()) {
}
~MpscQueue()
{
// purge all elements from the queue
while(pop()) {}
// we are left with a stub, delete that
delete tail;
}
// we are left with a stub, delete that
delete tail;
MpscQueue(MpscQueue&) = delete;
MpscQueue(MpscQueue&&) = delete;
/** @brief Pushes an item into the queue.
*
* Pushes an item into the front of the queue.
*
* @param item std::unique_ptr<T> An item to push into the queue
* @return void
*/
void push(std::unique_ptr<T>&& item) {
push(new Node(nullptr, std::forward<std::unique_ptr<T>>(item)));
}
/** @brief Pops a node from the queue.
*
* Pops and returns a node from the back of the queue.
*
* @return std::unique_ptr<T> A pointer to the node popped from the
* queue, nullptr if nothing was popped
*/
std::unique_ptr<T> pop() {
auto tail = this->tail;
// serialization point wrt producers
auto next = tail->next.load(std::memory_order_acquire);
if (next) {
// remove the last stub from the queue
// make [2] the next stub and return it's data
//
// H --> [n] <- ... <- [2] <--+--[STUB] +-- T
// | |
// +-----------+
this->tail = next;
// delete the stub node
// H --> [n] <- ... <- [STUB] <-- T
delete tail;
return std::move(next->item);
}
MpscQueue(MpscQueue&) = delete;
MpscQueue(MpscQueue&&) = delete;
return nullptr;
}
/** @brief Pushes an item into the queue.
*
* Pushes an item into the front of the queue.
*
* @param item std::unique_ptr<T> An item to push into the queue
* @return void
*/
void push(std::unique_ptr<T>&& item)
{
push(new Node(nullptr, std::forward<std::unique_ptr<T>>(item)));
}
private:
std::atomic<Node*> head;
Node* tail;
/** @brief Pops a node from the queue.
*
* Pops and returns a node from the back of the queue.
*
* @return std::unique_ptr<T> A pointer to the node popped from the
* queue, nullptr if nothing was popped
*/
std::unique_ptr<T> pop()
{
auto tail = this->tail;
// serialization point wrt producers
auto next = tail->next.load(std::memory_order_acquire);
/** @brief Pushes a new node into the queue.
*
* Pushes a new node containing the item into the front of the queue.
*
* @param node Node* A pointer to node you want to push into the queue
* @return void
*/
void push(Node* node) {
// initial state
// H --> [3] <- [2] <- [STUB] <-- T
if(next)
{
// remove the last stub from the queue
// make [2] the next stub and return it's data
//
// H --> [n] <- ... <- [2] <--+--[STUB] +-- T
// | |
// +-----------+
this->tail = next;
// serialization point wrt producers, acquire-release
auto old = head.exchange(node, std::memory_order_acq_rel);
// delete the stub node
// H --> [n] <- ... <- [STUB] <-- T
delete tail;
// after exchange
// H --> [4] [3] <- [2] <- [STUB] <-- T
return std::move(next->item);
}
// this is the window of inconsistency, if the producer is blocked
// here, the consumer is also blocked. but this window is extremely
// small, it's followed by a store operation which is a
// serialization point wrt consumer
return nullptr;
}
// old holds a pointer to node [3] and we need to link the [3] to a
// newly created node [4] using release semantics
private:
std::atomic<Node*> head;
Node* tail;
// serialization point wrt consumer, release
old->next.store(node, std::memory_order_release);
/** @brief Pushes a new node into the queue.
*
* Pushes a new node containing the item into the front of the queue.
*
* @param node Node* A pointer to node you want to push into the queue
* @return void
*/
void push(Node* node)
{
// initial state
// H --> [3] <- [2] <- [STUB] <-- T
// serialization point wrt producers, acquire-release
auto old = head.exchange(node, std::memory_order_acq_rel);
// after exchange
// H --> [4] [3] <- [2] <- [STUB] <-- T
// this is the window of inconsistency, if the producer is blocked
// here, the consumer is also blocked. but this window is extremely
// small, it's followed by a store operation which is a
// serialization point wrt consumer
// old holds a pointer to node [3] and we need to link the [3] to a
// newly created node [4] using release semantics
// serialization point wrt consumer, release
old->next.store(node, std::memory_order_release);
// finally, we have a queue like this
// H --> [4] <- [3] <- [2] <- [1] <-- T
}
// finally, we have a queue like this
// H --> [4] <- [3] <- [2] <- [1] <-- T
}
};
}

View File

@ -6,58 +6,48 @@
#include "threading/sync/spinlock.hpp"
template <class T>
class SlQueue : Lockable<SpinLock>
{
public:
class SlQueue : Lockable<SpinLock> {
public:
template <class... Args>
void emplace(Args&&... args) {
auto guard = acquire_unique();
queue.emplace(args...);
}
template <class... Args>
void emplace(Args&&... args)
{
auto guard = acquire_unique();
queue.emplace(args...);
}
void push(const T& item) {
auto guard = acquire_unique();
queue.push(item);
}
void push(const T& item)
{
auto guard = acquire_unique();
queue.push(item);
}
T front() {
auto guard = acquire_unique();
return queue.front();
}
T front()
{
auto guard = acquire_unique();
return queue.front();
}
void pop() {
auto guard = acquire_unique();
queue.pop();
}
void pop()
{
auto guard = acquire_unique();
queue.pop();
}
bool pop(T& item) {
auto guard = acquire_unique();
if (queue.empty()) return false;
bool pop(T& item)
{
auto guard = acquire_unique();
if(queue.empty())
return false;
item = std::move(queue.front());
queue.pop();
return true;
}
item = std::move(queue.front());
queue.pop();
return true;
}
bool empty() {
auto guard = acquire_unique();
return queue.empty();
}
bool empty()
{
auto guard = acquire_unique();
return queue.empty();
}
size_t size() {
auto guard = acquire_unique();
return queue.size();
}
size_t size()
{
auto guard = acquire_unique();
return queue.size();
}
private:
std::queue<T> queue;
private:
std::queue<T> queue;
};

View File

@ -5,10 +5,8 @@
#include "threading/sync/spinlock.hpp"
template <class K, class T>
class SlRbTree : Lockable<SpinLock>
{
public:
private:
std::map<K, T> tree;
class SlRbTree : Lockable<SpinLock> {
public:
private:
std::map<K, T> tree;
};

View File

@ -2,31 +2,27 @@
#include <stack>
#include "threading/sync/spinlock.hpp"
#include "threading/sync/lockable.hpp"
#include "threading/sync/spinlock.hpp"
template <class T>
class SpinLockStack : Lockable<SpinLock>
{
public:
class SpinLockStack : Lockable<SpinLock> {
public:
T pop() {
auto guard = acquire();
T pop()
{
auto guard = acquire();
T elem = stack.top();
stack.pop();
T elem = stack.top();
stack.pop();
return elem;
}
return elem;
}
void push(const T& elem) {
auto guard = acquire();
void push(const T& elem)
{
auto guard = acquire();
stack.push(elem);
}
stack.push(elem);
}
private:
std::stack<T> stack;
private:
std::stack<T> stack;
};

View File

@ -1,8 +1,6 @@
#pragma once
template <class T>
class ArrayStack
{
private:
class ArrayStack {
private:
};

View File

@ -5,61 +5,54 @@
// data structure namespace short ds
// TODO: document strategy related to namespace naming
// (namespace names should be short but eazy to memorize)
namespace ds
{
namespace ds {
// static array is data structure which size (capacity) can be known at compile
// time
// this data structure isn't concurrent
template <typename T, size_t N>
class static_array
{
public:
// default constructor
static_array() {}
class static_array {
public:
// default constructor
static_array() {}
// explicit constructor which populates the data array with
// initial values, array structure after initialization
// is N * [initial_value]
explicit static_array(const T &initial_value)
{
for (size_t i = 0; i < size(); ++i) {
data[i] = initial_value;
}
// explicit constructor which populates the data array with
// initial values, array structure after initialization
// is N * [initial_value]
explicit static_array(const T &initial_value) {
for (size_t i = 0; i < size(); ++i) {
data[i] = initial_value;
}
}
// returns array size
size_t size() const { return N; }
// returns array size
size_t size() const { return N; }
// returns element reference on specific index
T &operator[](size_t index)
{
runtime_assert(index < N, "Index " << index << " must be less than "
<< N);
return data[index];
}
// returns element reference on specific index
T &operator[](size_t index) {
runtime_assert(index < N, "Index " << index << " must be less than " << N);
return data[index];
}
// returns const element reference on specific index
const T &operator[](size_t index) const
{
runtime_assert(index < N, "Index " << index << " must be less than "
<< N);
return data[index];
}
// returns const element reference on specific index
const T &operator[](size_t index) const {
runtime_assert(index < N, "Index " << index << " must be less than " << N);
return data[index];
}
// returns begin iterator
T *begin() { return &data[0]; }
// returns begin iterator
T *begin() { return &data[0]; }
// returns const begin iterator
const T *begin() const { return &data[0]; }
// returns const begin iterator
const T *begin() const { return &data[0]; }
// returns end iterator
T *end() { return &data[N]; }
// returns end iterator
T *end() { return &data[N]; }
// returns const end iterator
const T *end() const { return &data[N]; }
// returns const end iterator
const T *end() const { return &data[N]; }
private:
T data[N];
private:
T data[N];
};
}

View File

@ -1,7 +1,7 @@
#pragma once
#include <vector>
#include <memory>
#include <vector>
template <class uintXX_t = uint32_t>
/**
@ -9,93 +9,79 @@ template <class uintXX_t = uint32_t>
* setting and checking in logarithmic complexity. Memory
* complexity is linear.
*/
class UnionFind
{
public:
/**
* Constructor, creates a UnionFind structure of fixed size.
*
* @param n Number of elements in the data structure.
*/
UnionFind(uintXX_t n) : set_count(n), count(n), parent(n)
{
for(auto i = 0; i < n; ++i)
count[i] = 1, parent[i] = i;
}
class UnionFind {
public:
/**
* Constructor, creates a UnionFind structure of fixed size.
*
* @param n Number of elements in the data structure.
*/
UnionFind(uintXX_t n) : set_count(n), count(n), parent(n) {
for (auto i = 0; i < n; ++i) count[i] = 1, parent[i] = i;
}
/**
* Connects two elements (and thereby the sets they belong
* to). If they are already connected the function has no effect.
*
* Has O(alpha(n)) time complexity.
*
* @param p First element.
* @param q Second element.
*/
void connect(uintXX_t p, uintXX_t q)
{
auto rp = root(p);
auto rq = root(q);
/**
* Connects two elements (and thereby the sets they belong
* to). If they are already connected the function has no effect.
*
* Has O(alpha(n)) time complexity.
*
* @param p First element.
* @param q Second element.
*/
void connect(uintXX_t p, uintXX_t q) {
auto rp = root(p);
auto rq = root(q);
// if roots are equal, we don't have to do anything
if(rp == rq)
return;
// if roots are equal, we don't have to do anything
if (rp == rq) return;
// merge the smaller subtree to the root of the larger subtree
if(count[rp] < count[rq])
parent[rp] = rq, count[rp] += count[rp];
else
parent[rq] = rp, count[rp] += count[rq];
// merge the smaller subtree to the root of the larger subtree
if (count[rp] < count[rq])
parent[rp] = rq, count[rp] += count[rp];
else
parent[rq] = rp, count[rp] += count[rq];
// update the number of groups
set_count--;
}
// update the number of groups
set_count--;
}
/**
* Indicates if two elements are connected. Has O(alpha(n)) time
* complexity.
*
* @param p First element.
* @param q Second element.
* @return See above.
*/
bool find(uintXX_t p, uintXX_t q)
{
return root(p) == root(q);
}
/**
* Indicates if two elements are connected. Has O(alpha(n)) time
* complexity.
*
* @param p First element.
* @param q Second element.
* @return See above.
*/
bool find(uintXX_t p, uintXX_t q) { return root(p) == root(q); }
/**
* Returns the number of disjoint sets in this UnionFind.
*
* @return See above.
*/
uintXX_t size() const
{
return set_count;
}
/**
* Returns the number of disjoint sets in this UnionFind.
*
* @return See above.
*/
uintXX_t size() const { return set_count; }
private:
uintXX_t set_count;
private:
uintXX_t set_count;
// array of subtree counts
std::vector<uintXX_t> count;
// array of subtree counts
std::vector<uintXX_t> count;
// array of tree indices
std::vector<uintXX_t> parent;
uintXX_t root(uintXX_t p)
{
auto r = p;
auto newp = p;
// array of tree indices
std::vector<uintXX_t> parent;
// find the node connected to itself, that's the root
while(parent[r] != r)
r = parent[r];
uintXX_t root(uintXX_t p) {
auto r = p;
auto newp = p;
// do some path compression to enable faster searches
while(p != r)
newp = parent[p], parent[p] = r, p = newp;
// find the node connected to itself, that's the root
while (parent[r] != r) r = parent[r];
return r;
}
// do some path compression to enable faster searches
while (p != r) newp = parent[p], parent[p] = r, p = newp;
return r;
}
};

View File

@ -12,7 +12,6 @@
* be created. Typically due to database overload.
*/
class CreationException : public BasicException {
public:
public:
using BasicException::BasicException;
};

View File

@ -1,10 +1,10 @@
#include "database/graph_db.hpp"
#include <storage/edge.hpp>
#include "database/creation_exception.hpp"
#include "database/graph_db.hpp"
//#include "snapshot/snapshoter.hpp"
GraphDb::GraphDb(const std::string &name, bool import_snapshot) : name_(name) {
// if (import_snapshot)
// snap_engine.import();
// if (import_snapshot)
// snap_engine.import();
}

View File

@ -1,11 +1,11 @@
#pragma once
#include "data_structures/concurrent/skiplist.hpp"
#include "transactions/engine.hpp"
#include "mvcc/version_list.hpp"
#include "utils/pass_key.hpp"
#include "data_structures/concurrent/concurrent_set.hpp"
#include "data_structures/concurrent/skiplist.hpp"
#include "mvcc/version_list.hpp"
#include "storage/unique_object_store.hpp"
#include "transactions/engine.hpp"
#include "utils/pass_key.hpp"
// forward declaring Edge and Vertex because they use
// GraphDb::Label etc., and therefore include this header
@ -26,13 +26,11 @@ class EdgeAccessor;
* exposed to client functions. The GraphDbAccessor is used for that.
*/
class GraphDb {
public:
public:
// definitions for what data types are used for a Label, Property, EdgeType
using Label = std::string*;
using EdgeType = std::string*;
using Property = std::string*;
using Label = std::string *;
using EdgeType = std::string *;
using Property = std::string *;
/**
* Construct database with a custom name.
@ -52,19 +50,19 @@ public:
tx::Engine tx_engine;
/** garbage collector related to this database*/
// TODO bring back garbage collection
// Garbage garbage = {tx_engine};
// TODO bring back garbage collection
// Garbage garbage = {tx_engine};
// TODO bring back shapshot engine
// SnapshotEngine snap_engine = {*this};
// TODO bring back shapshot engine
// SnapshotEngine snap_engine = {*this};
// database name
// TODO consider if this is even necessary
const std::string name_;
// main storage for the graph
SkipList<mvcc::VersionList<Vertex>*> vertices_;
SkipList<mvcc::VersionList<Edge>*> edges_;
SkipList<mvcc::VersionList<Vertex> *> vertices_;
SkipList<mvcc::VersionList<Edge> *> edges_;
// unique object stores
ConcurrentSet<std::string> labels_;

View File

@ -1,37 +1,33 @@
#include "database/creation_exception.hpp"
#include "database/graph_db_accessor.hpp"
#include "database/creation_exception.hpp"
#include "storage/vertex.hpp"
#include "storage/vertex_accessor.hpp"
#include "storage/edge.hpp"
#include "storage/edge_accessor.hpp"
#include "storage/vertex.hpp"
#include "storage/vertex_accessor.hpp"
GraphDbAccessor::GraphDbAccessor(GraphDb& db)
: db_(db), transaction_(std::move(db.tx_engine.begin())) {}
GraphDbAccessor::GraphDbAccessor(GraphDb& db) : db_(db), transaction_(std::move(db.tx_engine.begin())) {}
const std::string& GraphDbAccessor::name() const {
return db_.name_;
}
const std::string& GraphDbAccessor::name() const { return db_.name_; }
VertexAccessor GraphDbAccessor::insert_vertex() {
// create a vertex
auto vertex_vlist = new mvcc::VersionList<Vertex>();
Vertex *vertex = vertex_vlist->insert(transaction_);
Vertex* vertex = vertex_vlist->insert(transaction_);
// insert the newly created record into the main storage
// TODO make the number of tries configurable
for (int i = 0; i < 5; ++i) {
bool success = db_.vertices_.access().insert(vertex_vlist).second;
if (success)
return VertexAccessor(*vertex_vlist, *vertex, *this);
if (success) return VertexAccessor(*vertex_vlist, *vertex, *this);
// TODO sleep for some configurable amount of time
}
throw CreationException("Unable to create a Vertex after 5 attempts");
}
bool GraphDbAccessor::remove_vertex(VertexAccessor &vertex_accessor) {
bool GraphDbAccessor::remove_vertex(VertexAccessor& vertex_accessor) {
// TODO consider if this works well with MVCC
if (vertex_accessor.out_degree() > 0 || vertex_accessor.in_degree() > 0)
return false;
@ -40,15 +36,13 @@ bool GraphDbAccessor::remove_vertex(VertexAccessor &vertex_accessor) {
return true;
}
void GraphDbAccessor::detach_remove_vertex(VertexAccessor &vertex_accessor) {
void GraphDbAccessor::detach_remove_vertex(VertexAccessor& vertex_accessor) {
// removing edges via accessors is both safe
// and it should remove all the pointers in the relevant
// vertices (including this one)
for (auto edge_accessor : vertex_accessor.in())
remove_edge(edge_accessor);
for (auto edge_accessor : vertex_accessor.in()) remove_edge(edge_accessor);
for (auto edge_accessor : vertex_accessor.out())
remove_edge(edge_accessor);
for (auto edge_accessor : vertex_accessor.out()) remove_edge(edge_accessor);
// mvcc removal of the vertex
vertex_accessor.vlist_.remove(&vertex_accessor.update(), transaction_);
@ -60,24 +54,22 @@ std::vector<VertexAccessor> GraphDbAccessor::vertices() {
std::vector<VertexAccessor> accessors;
accessors.reserve(sl_accessor.size());
for (auto vlist : sl_accessor){
for (auto vlist : sl_accessor) {
auto record = vlist->find(transaction_);
if (record == nullptr)
continue;
if (record == nullptr) continue;
accessors.emplace_back(*vlist, *record, *this);
}
}
return accessors;
}
EdgeAccessor GraphDbAccessor::insert_edge(
VertexAccessor& from,
VertexAccessor& to,
GraphDb::EdgeType edge_type) {
EdgeAccessor GraphDbAccessor::insert_edge(VertexAccessor& from,
VertexAccessor& to,
GraphDb::EdgeType edge_type) {
// create an edge
auto edge_vlist = new mvcc::VersionList<Edge>();
Edge* edge = edge_vlist->insert(transaction_, from.vlist_, to.vlist_, edge_type);
Edge* edge =
edge_vlist->insert(transaction_, from.vlist_, to.vlist_, edge_type);
// set the vertex connections to this edge
from.update().out_.emplace_back(edge_vlist);
@ -87,8 +79,7 @@ EdgeAccessor GraphDbAccessor::insert_edge(
// TODO make the number of tries configurable
for (int i = 0; i < 5; ++i) {
bool success = db_.edges_.access().insert(edge_vlist).second;
if (success)
return EdgeAccessor(*edge_vlist, *edge, *this);
if (success) return EdgeAccessor(*edge_vlist, *edge, *this);
// TODO sleep for some amount of time
}
@ -99,7 +90,8 @@ EdgeAccessor GraphDbAccessor::insert_edge(
* Removes the given edge pointer from a vector of pointers.
* Does NOT maintain edge pointer ordering (for efficiency).
*/
void swap_out_edge(std::vector<mvcc::VersionList<Edge>*> &edges, mvcc::VersionList<Edge> *edge) {
void swap_out_edge(std::vector<mvcc::VersionList<Edge>*>& edges,
mvcc::VersionList<Edge>* edge) {
auto found = std::find(edges.begin(), edges.end(), edge);
assert(found != edges.end());
std::swap(*found, edges.back());
@ -118,10 +110,9 @@ std::vector<EdgeAccessor> GraphDbAccessor::edges() {
std::vector<EdgeAccessor> accessors;
accessors.reserve(sl_accessor.size());
for (auto vlist : sl_accessor){
for (auto vlist : sl_accessor) {
auto record = vlist->find(transaction_);
if (record == nullptr)
continue;
if (record == nullptr) continue;
accessors.emplace_back(*vlist, *record, *this);
}
@ -136,11 +127,13 @@ std::string& GraphDbAccessor::label_name(const GraphDb::Label label) const {
return *label;
}
GraphDb::EdgeType GraphDbAccessor::edge_type(const std::string& edge_type_name){
GraphDb::EdgeType GraphDbAccessor::edge_type(
const std::string& edge_type_name) {
return &(*db_.edge_types_.access().insert(edge_type_name).first);
}
std::string& GraphDbAccessor::edge_type_name(const GraphDb::EdgeType edge_type) const {
std::string& GraphDbAccessor::edge_type_name(
const GraphDb::EdgeType edge_type) const {
return *edge_type;
}
@ -148,6 +141,7 @@ GraphDb::Property GraphDbAccessor::property(const std::string& property_name) {
return &(*db_.properties_.access().insert(property_name).first);
}
std::string& GraphDbAccessor::property_name(const GraphDb::Property property) const {
std::string& GraphDbAccessor::property_name(
const GraphDb::Property property) const {
return *property;
}

View File

@ -8,7 +8,6 @@
#include "graph_db.hpp"
#include "transactions/transaction.hpp"
/**
* An accessor for the database object: exposes functions
* for operating on the database. All the functions in
@ -18,9 +17,7 @@
* the creation.
*/
class GraphDbAccessor {
public:
public:
/**
* Creates an accessor for the given database.
*
@ -48,7 +45,7 @@ public:
* @param vertex_accessor Accessor to vertex.
* @return If or not the vertex was deleted.
*/
bool remove_vertex(VertexAccessor &vertex_accessor);
bool remove_vertex(VertexAccessor& vertex_accessor);
/**
* Removes the vertex of the given accessor along with all it's outgoing
@ -56,7 +53,7 @@ public:
*
* @param vertex_accessor Accessor to a vertex.
*/
void detach_remove_vertex(VertexAccessor &vertex_accessor);
void detach_remove_vertex(VertexAccessor& vertex_accessor);
/**
* Returns accessors to all the vertices in the graph.
@ -73,7 +70,8 @@ public:
* @param type Edge type.
* @return An accessor to the edge.
*/
EdgeAccessor insert_edge(VertexAccessor& from, VertexAccessor& to, GraphDb::EdgeType type);
EdgeAccessor insert_edge(VertexAccessor& from, VertexAccessor& to,
GraphDb::EdgeType type);
/**
* Removes an edge from the graph.
@ -134,6 +132,6 @@ public:
/** The current transaction */
tx::Transaction transaction_;
private:
private:
GraphDb& db_;
};

View File

@ -9,65 +9,61 @@
#include "logging/default.hpp"
Cleaning::Cleaning(ConcurrentMap<std::string, GraphDb> &dbs, size_t cleaning_cycle)
: dbms(dbs), cleaning_cycle(cleaning_cycle)
{
// Start the cleaning thread
cleaners.push_back(
std::make_unique<Thread>([&, cleaning_cycle = cleaning_cycle ]() {
Logger logger = logging::log->logger("Cleaner");
logger.info("Started with cleaning cycle of {} sec",
cleaning_cycle);
Cleaning::Cleaning(ConcurrentMap<std::string, GraphDb> &dbs,
size_t cleaning_cycle)
: dbms(dbs), cleaning_cycle(cleaning_cycle) {
// Start the cleaning thread
cleaners.push_back(
std::make_unique<Thread>([&, cleaning_cycle = cleaning_cycle ]() {
Logger logger = logging::log->logger("Cleaner");
logger.info("Started with cleaning cycle of {} sec", cleaning_cycle);
std::time_t last_clean = std::time(nullptr);
while (cleaning.load(std::memory_order_acquire)) {
std::time_t now = std::time(nullptr);
std::time_t last_clean = std::time(nullptr);
while (cleaning.load(std::memory_order_acquire)) {
std::time_t now = std::time(nullptr);
// Maybe it's cleaning time.
if (now >= last_clean + cleaning_cycle) {
logger.info("Started cleaning cyle");
// Maybe it's cleaning time.
if (now >= last_clean + cleaning_cycle) {
logger.info("Started cleaning cyle");
// Clean all databases
for (auto &db : dbs.access()) {
logger.info("Cleaning database \"{}\"", db.first);
DbTransaction t(db.second);
try {
logger.info("Cleaning edges");
t.clean_edge_section();
// Clean all databases
for (auto &db : dbs.access()) {
logger.info("Cleaning database \"{}\"", db.first);
DbTransaction t(db.second);
try {
logger.info("Cleaning edges");
t.clean_edge_section();
logger.info("Cleaning vertices");
t.clean_vertex_section();
logger.info("Cleaning vertices");
t.clean_vertex_section();
logger.info("Cleaning garbage");
db.second.garbage.clean();
logger.info("Cleaning garbage");
db.second.garbage.clean();
} catch (const std::exception &e) {
logger.error(
"Error occured while cleaning database \"{}\"",
db.first);
logger.error("{}", e.what());
}
// NOTE: Whe should commit even if error occured.
t.trans.commit();
}
last_clean = now;
logger.info("Finished cleaning cyle");
} else {
// Cleaning isn't scheduled for now so i should sleep.
std::this_thread::sleep_for(std::chrono::seconds(1));
}
} catch (const std::exception &e) {
logger.error("Error occured while cleaning database \"{}\"",
db.first);
logger.error("{}", e.what());
}
// NOTE: Whe should commit even if error occured.
t.trans.commit();
}
}));
last_clean = now;
logger.info("Finished cleaning cyle");
} else {
// Cleaning isn't scheduled for now so i should sleep.
std::this_thread::sleep_for(std::chrono::seconds(1));
}
}
}));
}
Cleaning::~Cleaning()
{
// Stop cleaning
cleaning.store(false, std::memory_order_release);
for (auto &t : cleaners) {
// Join with cleaners
t.get()->join();
}
Cleaning::~Cleaning() {
// Stop cleaning
cleaning.store(false, std::memory_order_release);
for (auto &t : cleaners) {
// Join with cleaners
t.get()->join();
}
}

View File

@ -5,24 +5,22 @@
class Thread;
class Cleaning
{
class Cleaning {
public:
// How much sec is a cleaning_cycle in which cleaner will clean at most
// once. Starts cleaner thread.
Cleaning(ConcurrentMap<std::string, GraphDb> &dbs, size_t cleaning_cycle);
public:
// How much sec is a cleaning_cycle in which cleaner will clean at most
// once. Starts cleaner thread.
Cleaning(ConcurrentMap<std::string, GraphDb> &dbs, size_t cleaning_cycle);
// Destroys this object after this thread joins cleaning thread.
~Cleaning();
// Destroys this object after this thread joins cleaning thread.
~Cleaning();
private:
ConcurrentMap<std::string, GraphDb> &dbms;
private:
ConcurrentMap<std::string, GraphDb> &dbms;
const size_t cleaning_cycle;
const size_t cleaning_cycle;
std::vector<std::unique_ptr<Thread>> cleaners;
std::vector<std::unique_ptr<Thread>> cleaners;
// Should i continue cleaning.
std::atomic<bool> cleaning = {true};
// Should i continue cleaning.
std::atomic<bool> cleaning = {true};
};

View File

@ -1,21 +1,21 @@
#include "dbms/dbms.hpp"
GraphDbAccessor Dbms::active() {
return GraphDbAccessor(*active_db.load(std::memory_order_acquire));
return GraphDbAccessor(*active_db.load(std::memory_order_acquire));
}
GraphDbAccessor Dbms::active(const std::string &name) {
auto acc = dbs.access();
// create db if it doesn't exist
auto it = acc.find(name);
if (it == acc.end()) {
it = acc.emplace(name, std::forward_as_tuple(name),
std::forward_as_tuple(name))
.first;
}
auto acc = dbs.access();
// create db if it doesn't exist
auto it = acc.find(name);
if (it == acc.end()) {
it = acc.emplace(name, std::forward_as_tuple(name),
std::forward_as_tuple(name))
.first;
}
// set and return active db
auto &db = it->second;
active_db.store(&db, std::memory_order_release);
return GraphDbAccessor(db);
// set and return active db
auto &db = it->second;
active_db.store(&db, std::memory_order_release);
return GraphDbAccessor(db);
}

View File

@ -8,42 +8,42 @@
//#include "dbms/cleaner.hpp"
//#include "snapshot/snapshoter.hpp"
class Dbms
{
public:
Dbms() {
// create the default database and set is a active
active("default");
}
class Dbms {
public:
Dbms() {
// create the default database and set is a active
active("default");
}
/**
* Returns an accessor to the active database.
*/
GraphDbAccessor active();
/**
* Returns an accessor to the active database.
*/
GraphDbAccessor active();
/**
* Set the database with the given name to be active.
* If there is no database with the given name,
* it's created.
*
* @return an accessor to the database with the given name.
*/
GraphDbAccessor active(const std::string &name);
/**
* Set the database with the given name to be active.
* If there is no database with the given name,
* it's created.
*
* @return an accessor to the database with the given name.
*/
GraphDbAccessor active(const std::string &name);
// TODO: DELETE action
// TODO: DELETE action
private:
// dbs container
ConcurrentMap<std::string, GraphDb> dbs;
private:
// dbs container
ConcurrentMap<std::string, GraphDb> dbs;
// currently active database
std::atomic<GraphDb *> active_db;
// currently active database
std::atomic<GraphDb *> active_db;
// // Cleaning thread.
// TODO re-enable cleaning
// Cleaning cleaning = {dbs, CONFIG_INTEGER(config::CLEANING_CYCLE_SEC)};
//
// // Snapshoting thread.
// TODO re-enable cleaning
// Snapshoter snapshoter = {dbs, CONFIG_INTEGER(config::SNAPSHOT_CYCLE_SEC)};
// // Cleaning thread.
// TODO re-enable cleaning
// Cleaning cleaning = {dbs, CONFIG_INTEGER(config::CLEANING_CYCLE_SEC)};
//
// // Snapshoting thread.
// TODO re-enable cleaning
// Snapshoter snapshoter = {dbs,
// CONFIG_INTEGER(config::SNAPSHOT_CYCLE_SEC)};
};

View File

@ -24,158 +24,149 @@ using namespace std;
static Option<VertexAccessor> empty_op_vacc;
// Base importer with common facilities.
class BaseImporter
{
class BaseImporter {
public:
BaseImporter(DbAccessor &db, Logger &&logger)
: db(db), logger(std::move(logger)) {}
public:
BaseImporter(DbAccessor &db, Logger &&logger)
: db(db), logger(std::move(logger))
{
}
char *cstr(string &str) { return &str[0]; }
char *cstr(string &str) { return &str[0]; }
bool split(string &str, char mark, vector<char *> &sub_str) {
return split(cstr(str), mark, sub_str);
}
bool split(string &str, char mark, vector<char *> &sub_str)
{
return split(cstr(str), mark, sub_str);
}
// Occurances of mark are changed with '\0'. sub_str is filled with
// pointers to parts of str splited by mark in ascending order. Empty
// sub_str are included. Doesn't split inside quotations and
// open_bracket,closed_bracket.
// Returns true if it was succesfully parsed.
bool split(char *str, char mark, vector<char *> &sub_str) {
int head = 0;
bool in_text = false;
bool in_array = false;
// Occurances of mark are changed with '\0'. sub_str is filled with
// pointers to parts of str splited by mark in ascending order. Empty
// sub_str are included. Doesn't split inside quotations and
// open_bracket,closed_bracket.
// Returns true if it was succesfully parsed.
bool split(char *str, char mark, vector<char *> &sub_str)
{
for (int i = 0; str[i] != '\0'; i++) {
char &c = str[i];
int head = 0;
bool in_text = false;
bool in_array = false;
for (int i = 0; str[i] != '\0'; i++) {
char &c = str[i];
// IN TEXT check
if (c == quotations_mark) {
in_text = !in_text;
if (in_text && head == i) {
c = '\0';
head = i + 1;
} else if (!in_text && !in_array) {
c = '\0';
}
continue;
} else if (in_text) {
continue;
}
// IN ARRAY check
if (c == open_bracket) {
if (in_array) {
logger.error("Nested arrays aren't supported.");
return false;
}
in_array = true;
continue;
}
if (in_array) {
if (c == closed_bracket) {
in_array = false;
}
continue;
}
// SPLIT CHECK
if (c == mark) {
c = '\0';
sub_str.push_back(&str[head]);
head = i + 1;
}
// IN TEXT check
if (c == quotations_mark) {
in_text = !in_text;
if (in_text && head == i) {
c = '\0';
head = i + 1;
} else if (!in_text && !in_array) {
c = '\0';
}
continue;
} else if (in_text) {
continue;
}
sub_str.push_back(&str[head]);
return true;
}
// Extracts parts of str while stripping parts of array chars and qutation
// marks. Parts are separated with delimiter.
void extract(char *str, const char delimiter, vector<char *> &sub_str)
{
int head = 0;
bool in_text = false;
for (int i = 0; str[i] != '\0'; i++) {
char &c = str[i];
// IN TEXT check
if (c == quotations_mark) {
in_text = !in_text;
if (in_text) {
} else {
c = '\0';
sub_str.push_back(&str[head]);
head = i + 1;
}
head = i + 1;
continue;
} else if (in_text) {
continue;
}
// IN ARRAY check
if (c == open_bracket) {
head = i + 1;
continue;
} else if (c == closed_bracket) {
c = '\0';
if (i > head) {
sub_str.push_back(&str[head]);
}
head = i + 1;
continue;
}
// SPLIT CHECK
if (c == delimiter) {
c = '\0';
if (i > head) {
sub_str.push_back(&str[head]);
}
head = i + 1;
} else if (c == ' ' && i == head) {
head++;
}
// IN ARRAY check
if (c == open_bracket) {
if (in_array) {
logger.error("Nested arrays aren't supported.");
return false;
}
in_array = true;
continue;
}
if (in_array) {
if (c == closed_bracket) {
in_array = false;
}
continue;
}
// SPLIT CHECK
if (c == mark) {
c = '\0';
sub_str.push_back(&str[head]);
head = i + 1;
}
}
// Optionaly return vertex with given import local id if it exists.
Option<VertexAccessor> const &get_vertex(size_t id)
{
if (vertices.size() > id) {
return vertices[id];
sub_str.push_back(&str[head]);
return true;
}
// Extracts parts of str while stripping parts of array chars and qutation
// marks. Parts are separated with delimiter.
void extract(char *str, const char delimiter, vector<char *> &sub_str) {
int head = 0;
bool in_text = false;
for (int i = 0; str[i] != '\0'; i++) {
char &c = str[i];
// IN TEXT check
if (c == quotations_mark) {
in_text = !in_text;
if (in_text) {
} else {
cout << vertices.size() << " -> " << id << endl;
return empty_op_vacc;
c = '\0';
sub_str.push_back(&str[head]);
head = i + 1;
}
head = i + 1;
continue;
} else if (in_text) {
continue;
}
// IN ARRAY check
if (c == open_bracket) {
head = i + 1;
continue;
} else if (c == closed_bracket) {
c = '\0';
if (i > head) {
sub_str.push_back(&str[head]);
}
head = i + 1;
continue;
}
// SPLIT CHECK
if (c == delimiter) {
c = '\0';
if (i > head) {
sub_str.push_back(&str[head]);
}
head = i + 1;
} else if (c == ' ' && i == head) {
head++;
}
}
public:
DbAccessor &db;
Logger logger;
sub_str.push_back(&str[head]);
}
// Varius marks and delimiters. They can be freely changed here and
// everything will work.
char parts_mark = ',';
char parts_array_mark = ',';
char type_mark = ':';
char quotations_mark = '"';
char open_bracket = '[';
char closed_bracket = ']';
// Optionaly return vertex with given import local id if it exists.
Option<VertexAccessor> const &get_vertex(size_t id) {
if (vertices.size() > id) {
return vertices[id];
} else {
cout << vertices.size() << " -> " << id << endl;
return empty_op_vacc;
}
}
protected:
// All created vertices which have import local id.
vector<Option<VertexAccessor>> vertices;
public:
DbAccessor &db;
Logger logger;
// Varius marks and delimiters. They can be freely changed here and
// everything will work.
char parts_mark = ',';
char parts_array_mark = ',';
char type_mark = ':';
char quotations_mark = '"';
char open_bracket = '[';
char closed_bracket = ']';
protected:
// All created vertices which have import local id.
vector<Option<VertexAccessor>> vertices;
};

View File

@ -59,311 +59,293 @@ bool equal_str(const char *a, const char *b) { return strcasecmp(a, b) == 0; }
// If name is missing the column data wont be saved into the elements.
// if the type is missing the column will be interperted as type string. If
// neither name nor type are present column will be skipped.
class CSVImporter : public BaseImporter
{
class CSVImporter : public BaseImporter {
public:
CSVImporter(DbAccessor &db)
: BaseImporter(db, logging::log->logger("CSV_import")) {}
public:
CSVImporter(DbAccessor &db)
: BaseImporter(db, logging::log->logger("CSV_import"))
{
// Loads data from stream and returns number of loaded vertexes.
size_t import_vertices(std::fstream &file) {
return import<TypeGroupVertex>(file, create_vertex, true);
}
// Loads data from stream and returns number of loaded edges.
size_t import_edges(std::fstream &file) {
return import<TypeGroupEdge>(file, create_edge, false);
}
private:
// Loads data from file and returns number of loaded name.
// TG - TypeGroup
// F - function which will create element from filled element skelleton.
template <class TG, class F>
size_t import(std::fstream &file, F f, bool vertex) {
string line;
vector<char *> sub_str;
vector<unique_ptr<Filler>> fillers;
vector<char *> tmp;
// HEADERS
if (!getline(file, line)) {
logger.error("No lines");
return 0;
}
// Loads data from stream and returns number of loaded vertexes.
size_t import_vertices(std::fstream &file)
{
return import<TypeGroupVertex>(file, create_vertex, true);
if (!split(line, parts_mark, sub_str)) {
logger.error("Illegal headers");
return 0;
}
// Loads data from stream and returns number of loaded edges.
size_t import_edges(std::fstream &file)
{
return import<TypeGroupEdge>(file, create_edge, false);
for (auto p : sub_str) {
auto o = get_filler<TG>(p, tmp, vertex);
if (o.is_present()) {
fillers.push_back(o.take());
} else {
return 0;
}
}
sub_str.clear();
// LOAD DATA LINES
size_t count = 0;
size_t line_no = 1;
ElementSkeleton es(db);
while (std::getline(file, line)) {
sub_str.clear();
es.clear();
if (split(line, parts_mark, sub_str)) {
check_for_part_count(sub_str.size() - fillers.size(), line_no);
int n = min(sub_str.size(), fillers.size());
for (int i = 0; i < n; i++) {
auto er = fillers[i]->fill(es, sub_str[i]);
if (er.is_present()) {
logger.error("{} on line: {}", er.get(), line_no);
}
}
if (f(this, es, line_no)) {
count++;
}
}
line_no++;
}
private:
// Loads data from file and returns number of loaded name.
// TG - TypeGroup
// F - function which will create element from filled element skelleton.
template <class TG, class F>
size_t import(std::fstream &file, F f, bool vertex)
{
string line;
vector<char *> sub_str;
vector<unique_ptr<Filler>> fillers;
vector<char *> tmp;
// HEADERS
if (!getline(file, line)) {
logger.error("No lines");
return 0;
}
if (!split(line, parts_mark, sub_str)) {
logger.error("Illegal headers");
return 0;
}
for (auto p : sub_str) {
auto o = get_filler<TG>(p, tmp, vertex);
if (o.is_present()) {
fillers.push_back(o.take());
} else {
return 0;
}
}
sub_str.clear();
// LOAD DATA LINES
size_t count = 0;
size_t line_no = 1;
ElementSkeleton es(db);
while (std::getline(file, line)) {
sub_str.clear();
es.clear();
if (split(line, parts_mark, sub_str)) {
check_for_part_count(sub_str.size() - fillers.size(), line_no);
int n = min(sub_str.size(), fillers.size());
for (int i = 0; i < n; i++) {
auto er = fillers[i]->fill(es, sub_str[i]);
if (er.is_present()) {
logger.error("{} on line: {}", er.get(), line_no);
}
}
if (f(this, es, line_no)) {
count++;
}
}
line_no++;
}
return count;
}
static bool create_vertex(CSVImporter *im, ElementSkeleton &es,
size_t line_no)
{
auto va = es.add_vertex();
auto id = es.element_id();
if (id.is_present()) {
if (im->vertices.size() <= id.get()) {
Option<VertexAccessor> empty = make_option<VertexAccessor>();
im->vertices.insert(im->vertices.end(),
id.get() - im->vertices.size() + 1, empty);
}
if (im->vertices[id.get()].is_present()) {
im->logger.error("Vertex on line: {} has same id with another "
"previously loaded vertex",
line_no);
return false;
} else {
im->vertices[id.get()] = make_option(std::move(va));
return true;
}
} else {
im->logger.warn("Missing import local vertex id for vertex on "
"line: {}",
line_no);
}
return count;
}
static bool create_vertex(CSVImporter *im, ElementSkeleton &es,
size_t line_no) {
auto va = es.add_vertex();
auto id = es.element_id();
if (id.is_present()) {
if (im->vertices.size() <= id.get()) {
Option<VertexAccessor> empty = make_option<VertexAccessor>();
im->vertices.insert(im->vertices.end(),
id.get() - im->vertices.size() + 1, empty);
}
if (im->vertices[id.get()].is_present()) {
im->logger.error(
"Vertex on line: {} has same id with another "
"previously loaded vertex",
line_no);
return false;
} else {
im->vertices[id.get()] = make_option(std::move(va));
return true;
}
} else {
im->logger.warn(
"Missing import local vertex id for vertex on "
"line: {}",
line_no);
}
static bool create_edge(CSVImporter *im, ElementSkeleton &es,
size_t line_no)
{
auto o = es.add_edge();
if (!o.is_present()) {
return true;
} else {
im->logger.error("{} on line: {}", o.get(), line_no);
return false;
}
return true;
}
static bool create_edge(CSVImporter *im, ElementSkeleton &es,
size_t line_no) {
auto o = es.add_edge();
if (!o.is_present()) {
return true;
} else {
im->logger.error("{} on line: {}", o.get(), line_no);
return false;
}
}
template <class TG>
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey property_key(
const char *name, Flags type) {
assert(false);
}
// Returns filler for name:type in header_part. None if error occured.
template <class TG>
Option<unique_ptr<Filler>> get_filler(char *header_part,
vector<char *> &tmp_vec, bool vertex) {
tmp_vec.clear();
split(header_part, type_mark, tmp_vec);
const char *name = tmp_vec[0];
const char *type = tmp_vec[1];
if (tmp_vec.size() > 2) {
logger.error("To much sub parts in header part");
return make_option<unique_ptr<Filler>>();
} else if (tmp_vec.size() < 2) {
if (tmp_vec.size() == 1) {
logger.warn(
"Column: {} doesn't have specified type so string "
"type will be used",
tmp_vec[0]);
name = tmp_vec[0];
type = _string;
} else {
logger.warn("Empty colum definition, skiping column.");
std::unique_ptr<Filler> f(new SkipFiller());
return make_option(std::move(f));
}
} else {
name = tmp_vec[0];
type = tmp_vec[1];
}
template <class TG>
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey
property_key(const char *name, Flags type)
{
assert(false);
// Create adequat filler
if (equal_str(type, "id")) {
std::unique_ptr<Filler> f(
name[0] == '\0' ? new IdFiller<TG>()
: new IdFiller<TG>(make_option(
property_key<TG>(name, Flags::Int64))));
return make_option(std::move(f));
} else if (equal_str(type, "start_id") || equal_str(type, "from_id") ||
equal_str(type, "from") || equal_str(type, "source")) {
std::unique_ptr<Filler> f(new FromFiller(*this));
return make_option(std::move(f));
} else if (equal_str(type, "label")) {
std::unique_ptr<Filler> f(new LabelFiller(*this));
return make_option(std::move(f));
} else if (equal_str(type, "end_id") || equal_str(type, "to_id") ||
equal_str(type, "to") || equal_str(type, "target")) {
std::unique_ptr<Filler> f(new ToFiller(*this));
return make_option(std::move(f));
} else if (equal_str(type, "type")) {
std::unique_ptr<Filler> f(new TypeFiller(*this));
return make_option(std::move(f));
} else if (name[0] == '\0') { // OTHER FILLERS REQUIRE NAME
logger.warn("Unnamed column of type: {} will be skipped.", type);
std::unique_ptr<Filler> f(new SkipFiller());
return make_option(std::move(f));
// *********************** PROPERTIES
} else if (equal_str(type, "bool")) {
std::unique_ptr<Filler> f(
new BoolFiller<TG>(property_key<TG>(name, Flags::Bool)));
return make_option(std::move(f));
} else if (equal_str(type, "double") ||
(UPLIFT_PRIMITIVES && equal_str(type, "float"))) {
std::unique_ptr<Filler> f(
new DoubleFiller<TG>(property_key<TG>(name, Flags::Double)));
return make_option(std::move(f));
} else if (equal_str(type, "float")) {
std::unique_ptr<Filler> f(
new FloatFiller<TG>(property_key<TG>(name, Flags::Float)));
return make_option(std::move(f));
} else if (equal_str(type, "long") ||
(UPLIFT_PRIMITIVES && equal_str(type, "int"))) {
std::unique_ptr<Filler> f(
new Int64Filler<TG>(property_key<TG>(name, Flags::Int64)));
return make_option(std::move(f));
} else if (equal_str(type, "int")) {
std::unique_ptr<Filler> f(
new Int32Filler<TG>(property_key<TG>(name, Flags::Int32)));
return make_option(std::move(f));
} else if (equal_str(type, "string")) {
std::unique_ptr<Filler> f(
new StringFiller<TG>(property_key<TG>(name, Flags::String)));
return make_option(std::move(f));
} else if (equal_str(type, "bool[]")) {
std::unique_ptr<Filler> f(make_array_filler<TG, bool, ArrayBool>(
*this, property_key<TG>(name, Flags::ArrayBool), to_bool));
return make_option(std::move(f));
} else if (equal_str(type, "double[]") ||
(UPLIFT_PRIMITIVES && equal_str(type, "float[]"))) {
std::unique_ptr<Filler> f(make_array_filler<TG, double, ArrayDouble>(
*this, property_key<TG>(name, Flags::ArrayDouble), to_double));
return make_option(std::move(f));
} else if (equal_str(type, "float[]")) {
std::unique_ptr<Filler> f(make_array_filler<TG, float, ArrayFloat>(
*this, property_key<TG>(name, Flags::ArrayFloat), to_float));
return make_option(std::move(f));
} else if (equal_str(type, "long[]") ||
(UPLIFT_PRIMITIVES && equal_str(type, "int[]"))) {
std::unique_ptr<Filler> f(make_array_filler<TG, int64_t, ArrayInt64>(
*this, property_key<TG>(name, Flags::ArrayInt64), to_int64));
return make_option(std::move(f));
} else if (equal_str(type, "int[]")) {
std::unique_ptr<Filler> f(make_array_filler<TG, int32_t, ArrayInt32>(
*this, property_key<TG>(name, Flags::ArrayInt32), to_int32));
return make_option(std::move(f));
} else if (equal_str(type, "string[]")) {
std::unique_ptr<Filler> f(make_array_filler<TG, std::string, ArrayString>(
*this, property_key<TG>(name, Flags::ArrayString), to_string));
return make_option(std::move(f));
} else {
logger.error("Unknown type: {}", type);
return make_option<unique_ptr<Filler>>();
}
}
// Returns filler for name:type in header_part. None if error occured.
template <class TG>
Option<unique_ptr<Filler>> get_filler(char *header_part,
vector<char *> &tmp_vec, bool vertex)
{
tmp_vec.clear();
split(header_part, type_mark, tmp_vec);
const char *name = tmp_vec[0];
const char *type = tmp_vec[1];
if (tmp_vec.size() > 2) {
logger.error("To much sub parts in header part");
return make_option<unique_ptr<Filler>>();
} else if (tmp_vec.size() < 2) {
if (tmp_vec.size() == 1) {
logger.warn("Column: {} doesn't have specified type so string "
"type will be used",
tmp_vec[0]);
name = tmp_vec[0];
type = _string;
} else {
logger.warn("Empty colum definition, skiping column.");
std::unique_ptr<Filler> f(new SkipFiller());
return make_option(std::move(f));
}
} else {
name = tmp_vec[0];
type = tmp_vec[1];
}
// Create adequat filler
if (equal_str(type, "id")) {
std::unique_ptr<Filler> f(
name[0] == '\0' ? new IdFiller<TG>()
: new IdFiller<TG>(make_option(
property_key<TG>(name, Flags::Int64))));
return make_option(std::move(f));
} else if (equal_str(type, "start_id") || equal_str(type, "from_id") ||
equal_str(type, "from") || equal_str(type, "source")) {
std::unique_ptr<Filler> f(new FromFiller(*this));
return make_option(std::move(f));
} else if (equal_str(type, "label")) {
std::unique_ptr<Filler> f(new LabelFiller(*this));
return make_option(std::move(f));
} else if (equal_str(type, "end_id") || equal_str(type, "to_id") ||
equal_str(type, "to") || equal_str(type, "target")) {
std::unique_ptr<Filler> f(new ToFiller(*this));
return make_option(std::move(f));
} else if (equal_str(type, "type")) {
std::unique_ptr<Filler> f(new TypeFiller(*this));
return make_option(std::move(f));
} else if (name[0] == '\0') { // OTHER FILLERS REQUIRE NAME
logger.warn("Unnamed column of type: {} will be skipped.", type);
std::unique_ptr<Filler> f(new SkipFiller());
return make_option(std::move(f));
// *********************** PROPERTIES
} else if (equal_str(type, "bool")) {
std::unique_ptr<Filler> f(
new BoolFiller<TG>(property_key<TG>(name, Flags::Bool)));
return make_option(std::move(f));
} else if (equal_str(type, "double") ||
(UPLIFT_PRIMITIVES && equal_str(type, "float"))) {
std::unique_ptr<Filler> f(
new DoubleFiller<TG>(property_key<TG>(name, Flags::Double)));
return make_option(std::move(f));
} else if (equal_str(type, "float")) {
std::unique_ptr<Filler> f(
new FloatFiller<TG>(property_key<TG>(name, Flags::Float)));
return make_option(std::move(f));
} else if (equal_str(type, "long") ||
(UPLIFT_PRIMITIVES && equal_str(type, "int"))) {
std::unique_ptr<Filler> f(
new Int64Filler<TG>(property_key<TG>(name, Flags::Int64)));
return make_option(std::move(f));
} else if (equal_str(type, "int")) {
std::unique_ptr<Filler> f(
new Int32Filler<TG>(property_key<TG>(name, Flags::Int32)));
return make_option(std::move(f));
} else if (equal_str(type, "string")) {
std::unique_ptr<Filler> f(
new StringFiller<TG>(property_key<TG>(name, Flags::String)));
return make_option(std::move(f));
} else if (equal_str(type, "bool[]")) {
std::unique_ptr<Filler> f(make_array_filler<TG, bool, ArrayBool>(
*this, property_key<TG>(name, Flags::ArrayBool), to_bool));
return make_option(std::move(f));
} else if (equal_str(type, "double[]") ||
(UPLIFT_PRIMITIVES && equal_str(type, "float[]"))) {
std::unique_ptr<Filler> f(
make_array_filler<TG, double, ArrayDouble>(
*this, property_key<TG>(name, Flags::ArrayDouble),
to_double));
return make_option(std::move(f));
} else if (equal_str(type, "float[]")) {
std::unique_ptr<Filler> f(make_array_filler<TG, float, ArrayFloat>(
*this, property_key<TG>(name, Flags::ArrayFloat), to_float));
return make_option(std::move(f));
} else if (equal_str(type, "long[]") ||
(UPLIFT_PRIMITIVES && equal_str(type, "int[]"))) {
std::unique_ptr<Filler> f(
make_array_filler<TG, int64_t, ArrayInt64>(
*this, property_key<TG>(name, Flags::ArrayInt64),
to_int64));
return make_option(std::move(f));
} else if (equal_str(type, "int[]")) {
std::unique_ptr<Filler> f(
make_array_filler<TG, int32_t, ArrayInt32>(
*this, property_key<TG>(name, Flags::ArrayInt32),
to_int32));
return make_option(std::move(f));
} else if (equal_str(type, "string[]")) {
std::unique_ptr<Filler> f(
make_array_filler<TG, std::string, ArrayString>(
*this, property_key<TG>(name, Flags::ArrayString),
to_string));
return make_option(std::move(f));
} else {
logger.error("Unknown type: {}", type);
return make_option<unique_ptr<Filler>>();
}
}
void check_for_part_count(long diff, long line_no)
{
if (diff != 0) {
if (diff < 0) {
logger.warn("Line no: {} has less parts then specified in "
"header. Missing: {} parts",
line_no, diff);
} else {
logger.warn("Line no: {} has more parts then specified in "
"header. Extra: {} parts",
line_no, diff);
}
}
void check_for_part_count(long diff, long line_no) {
if (diff != 0) {
if (diff < 0) {
logger.warn(
"Line no: {} has less parts then specified in "
"header. Missing: {} parts",
line_no, diff);
} else {
logger.warn(
"Line no: {} has more parts then specified in "
"header. Extra: {} parts",
line_no, diff);
}
}
}
};
template <>
PropertyFamily<TypeGroupVertex>::PropertyType::PropertyFamilyKey
CSVImporter::property_key<TypeGroupVertex>(const char *name, Flags type)
{
return db.vertex_property_key(name, Type(type));
CSVImporter::property_key<TypeGroupVertex>(const char *name, Flags type) {
return db.vertex_property_key(name, Type(type));
}
template <>
PropertyFamily<TypeGroupEdge>::PropertyType::PropertyFamilyKey
CSVImporter::property_key<TypeGroupEdge>(const char *name, Flags type)
{
return db.edge_property_key(name, Type(type));
CSVImporter::property_key<TypeGroupEdge>(const char *name, Flags type) {
return db.edge_property_key(name, Type(type));
}
// Imports all -v "vertex_file_path.csv" vertices and -e "edge_file_path.csv"
@ -371,48 +353,47 @@ CSVImporter::property_key<TypeGroupEdge>(const char *name, Flags type)
// -d delimiter => sets delimiter for parsing .csv files. Default is ,
// -ad delimiter => sets delimiter for parsing arrays in .csv. Default is
// Returns (no loaded vertices,no loaded edges)
std::pair<size_t, size_t>
import_csv_from_arguments(Db &db, std::vector<std::string> &para)
{
DbAccessor t(db);
CSVImporter imp(t);
std::pair<size_t, size_t> import_csv_from_arguments(
Db &db, std::vector<std::string> &para) {
DbAccessor t(db);
CSVImporter imp(t);
imp.parts_mark = get_argument(para, "-d", ",")[0];
imp.parts_array_mark = get_argument(para, "-ad", ",")[0];
imp.parts_mark = get_argument(para, "-d", ",")[0];
imp.parts_array_mark = get_argument(para, "-ad", ",")[0];
// IMPORT VERTICES
size_t l_v = 0;
auto o = take_argument(para, "-v");
while (o.is_present()) {
std::fstream file(o.get());
// IMPORT VERTICES
size_t l_v = 0;
auto o = take_argument(para, "-v");
while (o.is_present()) {
std::fstream file(o.get());
imp.logger.info("Importing vertices from file: {}", o.get());
imp.logger.info("Importing vertices from file: {}", o.get());
auto n = imp.import_vertices(file);
l_v = +n;
auto n = imp.import_vertices(file);
l_v = +n;
imp.logger.info("Loaded: {} vertices from {}", n, o.get());
imp.logger.info("Loaded: {} vertices from {}", n, o.get());
o = take_argument(para, "-v");
}
o = take_argument(para, "-v");
}
// IMPORT EDGES
size_t l_e = 0;
o = take_argument(para, "-e");
while (o.is_present()) {
std::fstream file(o.get());
imp.logger.info("Importing edges from file: {}", o.get());
auto n = imp.import_edges(file);
l_e = +n;
imp.logger.info("Loaded: {} edges from {}", n, o.get());
// IMPORT EDGES
size_t l_e = 0;
o = take_argument(para, "-e");
while (o.is_present()) {
std::fstream file(o.get());
}
imp.logger.info("Importing edges from file: {}", o.get());
t.commit();
auto n = imp.import_edges(file);
l_e = +n;
imp.logger.info("Loaded: {} edges from {}", n, o.get());
o = take_argument(para, "-e");
}
t.commit();
return std::make_pair(l_v, l_e);
return std::make_pair(l_v, l_e);
}

View File

@ -3,16 +3,15 @@
#include <cassert>
#include "database/db_accessor.hpp"
#include "storage/vertex_accessor.hpp"
#include "storage/model/typed_value.hpp"
#include "storage/model/typed_value_store.hpp"
#include "storage/vertex_accessor.hpp"
// Holder for element data which he can then insert as a vertex or edge into the
// database depending on the available data and called add_* method.
class ElementSkeleton {
public:
ElementSkeleton(DbAccessor &db) : db(db) {};
public:
ElementSkeleton(DbAccessor &db) : db(db){};
void add_property(StoredProperty<TypeGroupVertex> &&prop) {
properties_v.push_back(std::move(prop));
@ -22,9 +21,7 @@ public:
properties_e.push_back(std::move(prop));
}
void set_element_id(size_t id) {
el_id = make_option<size_t>(std::move(id));
}
void set_element_id(size_t id) { el_id = make_option<size_t>(std::move(id)); }
void add_label(Label const &label) { labels.push_back(&label); }
@ -91,7 +88,7 @@ public:
// Returns import local id.
Option<size_t> element_id() { return el_id; }
private:
private:
DbAccessor &db;
Option<size_t> el_id;

View File

@ -9,26 +9,21 @@
// Parses boolean.
// TG - Type group
template <class TG>
class BoolFiller : public Filler
{
class BoolFiller : public Filler {
public:
BoolFiller(typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key)
: key(key) {}
public:
BoolFiller(typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key)
: key(key)
{
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final {
if (str[0] != '\0') {
data.add_property(StoredProperty<TG>(Bool(to_bool(str)), key));
}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final
{
if (str[0] != '\0') {
data.add_property(StoredProperty<TG>(Bool(to_bool(str)), key));
}
return make_option<std::string>();
}
return make_option<std::string>();
}
private:
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key;
private:
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key;
};

View File

@ -1,15 +1,14 @@
#pragma once
#include <strings.h>
#include <cstdlib>
#include <cstdlib>
#include <iostream>
#include <string>
#include <strings.h>
#include "storage/model/properties/all.hpp"
bool string2bool(const char *v)
{
return strcasecmp(v, "true") == 0 || atoi(v) != 0;
bool string2bool(const char *v) {
return strcasecmp(v, "true") == 0 || atoi(v) != 0;
}
bool to_bool(const char *str) { return string2bool(str); }

View File

@ -5,10 +5,9 @@
// Common class for varius classes which accept one part from data line in
// import, parses it and adds it into element skelleton.
class Filler
{
public:
// Fills skeleton with data from str. Returns error description if
// error occurs.
virtual Option<std::string> fill(ElementSkeleton &data, char *str) = 0;
class Filler {
public:
// Fills skeleton with data from str. Returns error description if
// error occurs.
virtual Option<std::string> fill(ElementSkeleton &data, char *str) = 0;
};

View File

@ -9,26 +9,20 @@
// Parses float.
// TG - Type group
template <class TG>
class FloatFiller : public Filler
{
public:
FloatFiller(
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key)
: key(key)
{
}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final
{
if (str[0] != '\0') {
data.add_property(StoredProperty<TG>(Float(to_float(str)), key));
}
return make_option<std::string>();
class FloatFiller : public Filler {
public:
FloatFiller(typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key)
: key(key) {}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final {
if (str[0] != '\0') {
data.add_property(StoredProperty<TG>(Float(to_float(str)), key));
}
private:
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key;
return make_option<std::string>();
}
private:
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key;
};

View File

@ -7,32 +7,28 @@
#include "storage/model/properties/property_family.hpp"
// Parses from id of vertex for edge.
class FromFiller : public Filler
{
class FromFiller : public Filler {
public:
FromFiller(BaseImporter &db) : bim(db) {}
public:
FromFiller(BaseImporter &db) : bim(db) {}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final
{
if (str[0] != '\0') {
auto id = atol(str);
Option<VertexAccessor> const &oav = bim.get_vertex(id);
if (oav.is_present()) {
data.set_from(VertexAccessor(oav.get()));
return make_option<std::string>();
} else {
return make_option(
std::string("Unknown vertex in from field with id: ") +
str);
}
} else {
return make_option(std::string("From field must be spceified"));
}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final {
if (str[0] != '\0') {
auto id = atol(str);
Option<VertexAccessor> const &oav = bim.get_vertex(id);
if (oav.is_present()) {
data.set_from(VertexAccessor(oav.get()));
return make_option<std::string>();
} else {
return make_option(
std::string("Unknown vertex in from field with id: ") + str);
}
} else {
return make_option(std::string("From field must be spceified"));
}
}
private:
BaseImporter &bim;
private:
BaseImporter &bim;
};

View File

@ -5,40 +5,31 @@
// Parses import local Id.
// TG - Type group
template <class TG>
class IdFiller : public Filler
{
class IdFiller : public Filler {
public:
IdFiller()
: key(make_option<
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey>()) {}
public:
IdFiller()
: key(make_option<
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey>())
{
IdFiller(
Option<typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey> key)
: key(key) {
assert(!key.is_present() || key.get().prop_type() == Type(Flags::Int64));
}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final {
if (str[0] != '\0') {
data.set_element_id(atol(str));
if (key.is_present()) {
data.add_property(StoredProperty<TG>(Int64(to_int64(str)), key.get()));
}
}
IdFiller(
Option<typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey>
key)
: key(key)
{
assert(!key.is_present() ||
key.get().prop_type() == Type(Flags::Int64));
}
return make_option<std::string>();
}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final
{
if (str[0] != '\0') {
data.set_element_id(atol(str));
if (key.is_present()) {
data.add_property(
StoredProperty<TG>(Int64(to_int64(str)), key.get()));
}
}
return make_option<std::string>();
}
private:
Option<typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey> key;
private:
Option<typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey> key;
};

View File

@ -9,26 +9,20 @@
// Parses int64.
// TG - Type group
template <class TG>
class Int64Filler : public Filler
{
public:
Int64Filler(
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key)
: key(key)
{
}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final
{
if (str[0] != '\0') {
data.add_property(StoredProperty<TG>(Int64(to_int64(str)), key));
}
return make_option<std::string>();
class Int64Filler : public Filler {
public:
Int64Filler(typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key)
: key(key) {}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final {
if (str[0] != '\0') {
data.add_property(StoredProperty<TG>(Int64(to_int64(str)), key));
}
private:
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key;
return make_option<std::string>();
}
private:
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key;
};

View File

@ -4,27 +4,24 @@
#include "import/fillings/filler.hpp"
// Parses array of labels.
class LabelFiller : public Filler
{
class LabelFiller : public Filler {
public:
LabelFiller(BaseImporter &db) : bim(db) {}
public:
LabelFiller(BaseImporter &db) : bim(db) {}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final
{
sub_str.clear();
bim.extract(str, bim.parts_array_mark, sub_str);
for (auto s : sub_str) {
if (s[0] != '\0') {
data.add_label(bim.db.label_find_or_create(s));
}
}
return make_option<std::string>();
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final {
sub_str.clear();
bim.extract(str, bim.parts_array_mark, sub_str);
for (auto s : sub_str) {
if (s[0] != '\0') {
data.add_label(bim.db.label_find_or_create(s));
}
}
return make_option<std::string>();
}
private:
BaseImporter &bim;
vector<char *> sub_str;
private:
BaseImporter &bim;
vector<char *> sub_str;
};

View File

@ -7,14 +7,11 @@
#include "storage/model/properties/property_family.hpp"
// Skips column.
class SkipFiller : public Filler
{
public:
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final
{
return make_option<std::string>();
}
class SkipFiller : public Filler {
public:
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final {
return make_option<std::string>();
}
};

View File

@ -9,26 +9,20 @@
// Parses string.
// TG - Type group
template <class TG>
class StringFiller : public Filler
{
public:
StringFiller(
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key)
: key(key)
{
}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final
{
if (str[0] != '\0') {
data.add_property(StoredProperty<TG>(String(to_string(str)), key));
}
return make_option<std::string>();
class StringFiller : public Filler {
public:
StringFiller(typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key)
: key(key) {}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final {
if (str[0] != '\0') {
data.add_property(StoredProperty<TG>(String(to_string(str)), key));
}
private:
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key;
return make_option<std::string>();
}
private:
typename PropertyFamily<TG>::PropertyType::PropertyFamilyKey key;
};

View File

@ -7,31 +7,28 @@
#include "storage/model/properties/property_family.hpp"
// Parses to import local id of vertex for edge.
class ToFiller : public Filler
{
class ToFiller : public Filler {
public:
ToFiller(BaseImporter &db) : bim(db) {}
public:
ToFiller(BaseImporter &db) : bim(db) {}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final
{
if (str[0] != '\0') {
auto id = atol(str);
Option<VertexAccessor> const &oav = bim.get_vertex(id);
if (oav.is_present()) {
data.set_to(VertexAccessor(oav.get()));
return make_option<std::string>();
} else {
return make_option(
std::string("Unknown vertex in to field with id: ") + str);
}
} else {
return make_option(std::string("To field must be spceified"));
}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final {
if (str[0] != '\0') {
auto id = atol(str);
Option<VertexAccessor> const &oav = bim.get_vertex(id);
if (oav.is_present()) {
data.set_to(VertexAccessor(oav.get()));
return make_option<std::string>();
} else {
return make_option(std::string("Unknown vertex in to field with id: ") +
str);
}
} else {
return make_option(std::string("To field must be spceified"));
}
}
private:
BaseImporter &bim;
private:
BaseImporter &bim;
};

View File

@ -4,23 +4,20 @@
#include "import/fillings/filler.hpp"
// Parses type of edge.
class TypeFiller : public Filler
{
class TypeFiller : public Filler {
public:
TypeFiller(BaseImporter &db) : bim(db) {}
public:
TypeFiller(BaseImporter &db) : bim(db) {}
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final
{
if (str[0] != '\0') {
data.set_type(bim.db.type_find_or_create(str));
}
return make_option<std::string>();
// Fills skeleton with data from str. Returns error description if
// error occurs.
Option<std::string> fill(ElementSkeleton &data, char *str) final {
if (str[0] != '\0') {
data.set_type(bim.db.type_find_or_create(str));
}
private:
BaseImporter &bim;
return make_option<std::string>();
}
private:
BaseImporter &bim;
};

View File

@ -1,46 +1,38 @@
#pragma once
#include <cstring>
#include <netdb.h>
#include <cstring>
#include "io/network/network_error.hpp"
#include "utils/underlying_cast.hpp"
namespace io
{
namespace io {
class AddrInfo
{
AddrInfo(struct addrinfo* info) : info(info) {}
class AddrInfo {
AddrInfo(struct addrinfo* info) : info(info) {}
public:
~AddrInfo()
{
freeaddrinfo(info);
}
public:
~AddrInfo() { freeaddrinfo(info); }
static AddrInfo get(const char* addr, const char* port)
{
struct addrinfo hints;
memset(&hints, 0, sizeof(struct addrinfo));
static AddrInfo get(const char* addr, const char* port) {
struct addrinfo hints;
memset(&hints, 0, sizeof(struct addrinfo));
hints.ai_family = AF_UNSPEC; // IPv4 and IPv6
hints.ai_socktype = SOCK_STREAM; // TCP socket
hints.ai_flags = AI_PASSIVE;
hints.ai_family = AF_UNSPEC; // IPv4 and IPv6
hints.ai_socktype = SOCK_STREAM; // TCP socket
hints.ai_flags = AI_PASSIVE;
struct addrinfo* result;
auto status = getaddrinfo(addr, port, &hints, &result);
struct addrinfo* result;
auto status = getaddrinfo(addr, port, &hints, &result);
if(status != 0)
throw NetworkError(gai_strerror(status));
if (status != 0) throw NetworkError(gai_strerror(status));
return AddrInfo(result);
}
return AddrInfo(result);
}
operator struct addrinfo*() { return info; }
operator struct addrinfo*() { return info; }
private:
struct addrinfo* info;
private:
struct addrinfo* info;
};
}

View File

@ -2,34 +2,28 @@
#include "io/network/stream_reader.hpp"
namespace io
{
namespace io {
template <class Derived, class Stream>
class Client : public StreamReader<Derived, Stream>
{
public:
bool connect(const std::string& host, const std::string& port)
{
return connect(host.c_str(), port.c_str());
}
class Client : public StreamReader<Derived, Stream> {
public:
bool connect(const std::string& host, const std::string& port) {
return connect(host.c_str(), port.c_str());
}
bool connect(const char* host, const char* port)
{
auto socket = io::Socket::connect(host, port);
bool connect(const char* host, const char* port) {
auto socket = io::Socket::connect(host, port);
if(!socket.is_open())
return false;
if (!socket.is_open()) return false;
socket.set_non_blocking();
socket.set_non_blocking();
auto& stream = this->derived().on_connect(std::move(socket));
auto& stream = this->derived().on_connect(std::move(socket));
stream.event.events = EPOLLIN | EPOLLET | EPOLLRDHUP;
this->add(stream);
stream.event.events = EPOLLIN | EPOLLET | EPOLLRDHUP;
this->add(stream);
return true;
}
return true;
}
};
}

View File

@ -4,54 +4,43 @@
#include <sys/epoll.h>
#include "io/network/socket.hpp"
#include "utils/likely.hpp"
#include "logging/default.hpp"
#include "utils/likely.hpp"
namespace io
{
namespace io {
class EpollError : BasicException
{
public:
using BasicException::BasicException;
class EpollError : BasicException {
public:
using BasicException::BasicException;
};
class Epoll
{
public:
using Event = struct epoll_event;
class Epoll {
public:
using Event = struct epoll_event;
Epoll(int flags) :
logger(logging::log->logger("io::Epoll"))
{
epoll_fd = epoll_create1(flags);
Epoll(int flags) : logger(logging::log->logger("io::Epoll")) {
epoll_fd = epoll_create1(flags);
if(UNLIKELY(epoll_fd == -1))
throw EpollError("Can't create epoll file descriptor");
}
if (UNLIKELY(epoll_fd == -1))
throw EpollError("Can't create epoll file descriptor");
}
template <class Stream>
void add(Stream& stream, Event* event)
{
auto status = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, stream, event);
template <class Stream>
void add(Stream& stream, Event* event) {
auto status = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, stream, event);
if(UNLIKELY(status))
throw EpollError("Can't add an event to epoll listener.");
}
if (UNLIKELY(status))
throw EpollError("Can't add an event to epoll listener.");
}
int wait(Event* events, int max_events, int timeout)
{
return epoll_wait(epoll_fd, events, max_events, timeout);
}
int wait(Event* events, int max_events, int timeout) {
return epoll_wait(epoll_fd, events, max_events, timeout);
}
int id() const
{
return epoll_fd;
}
int id() const { return epoll_fd; }
private:
int epoll_fd;
Logger logger;
private:
int epoll_fd;
Logger logger;
};
}

View File

@ -1,90 +1,82 @@
#pragma once
#include "io/network/epoll.hpp"
#include "utils/crtp.hpp"
#include "logging/default.hpp"
#include "utils/crtp.hpp"
namespace io
{
namespace io {
template <class Derived, size_t max_events = 64, int wait_timeout = -1>
class EventListener : public Crtp<Derived>
{
public:
using Crtp<Derived>::derived;
class EventListener : public Crtp<Derived> {
public:
using Crtp<Derived>::derived;
EventListener(uint32_t flags = 0) :
listener(flags),
logger(logging::log->logger("io::EventListener"))
{
}
EventListener(uint32_t flags = 0)
: listener(flags), logger(logging::log->logger("io::EventListener")) {}
void wait_and_process_events()
{
// TODO hardcoded a wait timeout because of thread joining
// when you shutdown the server. This should be wait_timeout of the
// template parameter and should almost never change from that.
// thread joining should be resolved using a signal that interrupts
// the system call.
void wait_and_process_events() {
// TODO hardcoded a wait timeout because of thread joining
// when you shutdown the server. This should be wait_timeout of the
// template parameter and should almost never change from that.
// thread joining should be resolved using a signal that interrupts
// the system call.
// waits for an event/multiple events and returns a maximum of
// max_events and stores them in the events array. it waits for
// wait_timeout milliseconds. if wait_timeout is achieved, returns 0
auto n = listener.wait(events, max_events, 200);
// waits for an event/multiple events and returns a maximum of
// max_events and stores them in the events array. it waits for
// wait_timeout milliseconds. if wait_timeout is achieved, returns 0
auto n = listener.wait(events, max_events, 200);
#ifndef NDEBUG
#ifndef LOG_NO_TRACE
if (n > 0)
logger.trace("number of events: {}", n);
if (n > 0) logger.trace("number of events: {}", n);
#endif
#endif
// go through all events and process them in order
for (int i = 0; i < n; ++i) {
auto &event = events[i];
// go through all events and process them in order
for (int i = 0; i < n; ++i) {
auto &event = events[i];
try {
// hangup event
if (UNLIKELY(event.events & EPOLLRDHUP)) {
this->derived().on_close_event(event);
continue;
}
// there was an error on the server side
if (UNLIKELY(!(event.events & EPOLLIN) ||
event.events & (EPOLLHUP | EPOLLERR))) {
this->derived().on_error_event(event);
continue;
}
// we have some data waiting to be read
this->derived().on_data_event(event);
} catch (const std::exception &e) {
this->derived().on_exception_event(
event, "Error occured while processing event \n{}",
e.what());
}
try {
// hangup event
if (UNLIKELY(event.events & EPOLLRDHUP)) {
this->derived().on_close_event(event);
continue;
}
// this will be optimized out :D
if (wait_timeout < 0) return;
// there was an error on the server side
if (UNLIKELY(!(event.events & EPOLLIN) ||
event.events & (EPOLLHUP | EPOLLERR))) {
this->derived().on_error_event(event);
continue;
}
// if there was events, continue to wait on new events
if (n != 0) return;
// wait timeout occurred and there were no events. if wait_timeout
// is -1 there will never be any timeouts so client should provide
// an empty function. in that case the conditional above and the
// function call will be optimized out by the compiler
this->derived().on_wait_timeout();
// we have some data waiting to be read
this->derived().on_data_event(event);
} catch (const std::exception &e) {
this->derived().on_exception_event(
event, "Error occured while processing event \n{}", e.what());
}
}
protected:
Epoll listener;
Epoll::Event events[max_events];
// this will be optimized out :D
if (wait_timeout < 0) return;
private:
Logger logger;
// if there was events, continue to wait on new events
if (n != 0) return;
// wait timeout occurred and there were no events. if wait_timeout
// is -1 there will never be any timeouts so client should provide
// an empty function. in that case the conditional above and the
// function call will be optimized out by the compiler
this->derived().on_wait_timeout();
}
protected:
Epoll listener;
Epoll::Event events[max_events];
private:
Logger logger;
};
}

View File

@ -4,13 +4,10 @@
#include "utils/exceptions/basic_exception.hpp"
namespace io
{
namespace io {
class NetworkError : public BasicException
{
public:
using BasicException::BasicException;
class NetworkError : public BasicException {
public:
using BasicException::BasicException;
};
}

View File

@ -1,93 +1,63 @@
#pragma once
#include "tls.hpp"
#include "io/network/socket.hpp"
#include "tls.hpp"
#include "tls_error.hpp"
#include "utils/types/byte.hpp"
#include <iostream>
namespace io
{
namespace io {
class SecureSocket
{
public:
SecureSocket(Socket&& socket, const Tls::Context& tls)
: socket(std::forward<Socket>(socket))
{
ssl = SSL_new(tls);
SSL_set_fd(ssl, this->socket);
class SecureSocket {
public:
SecureSocket(Socket&& socket, const Tls::Context& tls)
: socket(std::forward<Socket>(socket)) {
ssl = SSL_new(tls);
SSL_set_fd(ssl, this->socket);
SSL_set_accept_state(ssl);
SSL_set_accept_state(ssl);
if(SSL_accept(ssl) <= 0)
ERR_print_errors_fp(stderr);
}
if (SSL_accept(ssl) <= 0) ERR_print_errors_fp(stderr);
}
SecureSocket(SecureSocket&& other)
{
*this = std::forward<SecureSocket>(other);
}
SecureSocket(SecureSocket&& other) {
*this = std::forward<SecureSocket>(other);
}
SecureSocket& operator=(SecureSocket&& other)
{
socket = std::move(other.socket);
SecureSocket& operator=(SecureSocket&& other) {
socket = std::move(other.socket);
ssl = other.ssl;
other.ssl = nullptr;
ssl = other.ssl;
other.ssl = nullptr;
return *this;
}
return *this;
}
~SecureSocket()
{
if(ssl == nullptr)
return;
~SecureSocket() {
if (ssl == nullptr) return;
std::cout << "DELETING SSL" << std::endl;
std::cout << "DELETING SSL" << std::endl;
SSL_free(ssl);
}
SSL_free(ssl);
}
int error(int status)
{
return SSL_get_error(ssl, status);
}
int error(int status) { return SSL_get_error(ssl, status); }
int write(const std::string& str)
{
return write(str.c_str(), str.size());
}
int write(const std::string& str) { return write(str.c_str(), str.size()); }
int write(const byte* data, size_t len)
{
return SSL_write(ssl, data, len);
}
int write(const byte* data, size_t len) { return SSL_write(ssl, data, len); }
int write(const char* data, size_t len)
{
return SSL_write(ssl, data, len);
}
int write(const char* data, size_t len) { return SSL_write(ssl, data, len); }
int read(char* buffer, size_t len)
{
return SSL_read(ssl, buffer, len);
}
int read(char* buffer, size_t len) { return SSL_read(ssl, buffer, len); }
operator int()
{
return socket;
}
operator int() { return socket; }
operator Socket&()
{
return socket;
}
operator Socket&() { return socket; }
private:
Socket socket;
SSL* ssl {nullptr};
private:
Socket socket;
SSL* ssl{nullptr};
};
}

Some files were not shown because too many files have changed in this diff Show More