Started network refactorization.

Summary:
Moved server and worker from bolt to communication. Started
templatization.

Started removal of Bolt class.

Removed unnecessary files from network.

Converted states to template functions.

Bolt::Session is now a template.

Merge remote-tracking branch 'origin/dev' into mg_refactor_network

Merged bolt_serializer.cpp into hpp.

Removed obsolete include.

Initial version of bolt session unit test.

Uncommented leftover log commands.

Reimplemented io::Socket.

Added client-stress.sh script.

Reviewers: dgleich, buda

Reviewed By: dgleich, buda

Subscribers: pullbot, mferencevic, buda

Differential Revision: https://phabricator.memgraph.io/D64
This commit is contained in:
Matej Ferencevic 2017-03-06 13:37:51 +01:00
parent 38c3c513fa
commit 813a3b9eed
56 changed files with 1677 additions and 1349 deletions

View File

@ -298,16 +298,11 @@ set(memgraph_src_files
${src_dir}/utils/string/join.cpp
${src_dir}/utils/string/file.cpp
${src_dir}/utils/numerics/saturate.cpp
${src_dir}/communication/bolt/v1/bolt.cpp
${src_dir}/communication/bolt/v1/states.cpp
${src_dir}/communication/bolt/v1/session.cpp
${src_dir}/communication/bolt/v1/states/error.cpp
${src_dir}/communication/bolt/v1/states/executor.cpp
${src_dir}/communication/bolt/v1/states/init.cpp
${src_dir}/communication/bolt/v1/states/handshake.cpp
${src_dir}/communication/bolt/v1/transport/bolt_decoder.cpp
${src_dir}/communication/bolt/v1/transport/buffer.cpp
${src_dir}/communication/bolt/v1/serialization/bolt_serializer.cpp
${src_dir}/io/network/addrinfo.cpp
${src_dir}/io/network/network_endpoint.cpp
${src_dir}/io/network/socket.cpp
${src_dir}/threading/thread.cpp
${src_dir}/mvcc/id.cpp
# ${src_dir}/snapshot/snapshot_engine.cpp

View File

@ -1,8 +0,0 @@
#pragma once
#include "communication/bolt/v1/serialization/record_stream.hpp"
#include "io/network/socket.hpp"
namespace communication {
using OutputStream = bolt::RecordStream<io::Socket>;
}

View File

@ -1,19 +0,0 @@
#include "communication/bolt/v1/bolt.hpp"
#include "communication/bolt/v1/session.hpp"
namespace bolt {
Bolt::Bolt() {}
Session* Bolt::create_session(io::Socket&& socket) {
// TODO fix session lifecycle handling
// dangling pointers are not cool :)
// TODO attach currently active Db
return new Session(std::forward<io::Socket>(socket), *this);
}
void Bolt::close(Session* session) { session->socket.close(); }
}

View File

@ -1,23 +0,0 @@
#pragma once
#include "communication/bolt/v1/states.hpp"
#include "dbms/dbms.hpp"
#include "io/network/socket.hpp"
namespace bolt {
class Session;
class Bolt {
friend class Session;
public:
Bolt();
Session *create_session(io::Socket &&socket);
void close(Session *session);
States states;
Dbms dbms;
};
}

View File

@ -1,104 +0,0 @@
#include "communication/bolt/v1/serialization/bolt_serializer.hpp"
#include "communication/bolt/v1/transport/chunked_buffer.hpp"
#include "communication/bolt/v1/transport/chunked_encoder.hpp"
#include "communication/bolt/v1/transport/socket_stream.hpp"
#include "io/network/socket.hpp"
#include "database/graph_db.hpp"
#include "database/graph_db_accessor.hpp"
#include "storage/property_value_store.hpp"
#include <cassert>
template <class Stream>
void bolt::BoltSerializer<Stream>::write(const VertexAccessor &vertex) {
// write signatures for the node struct and node data type
encoder.write_struct_header(3);
encoder.write(underlying_cast(pack::Code::Node));
// IMPORTANT: here we write a hardcoded 0 because we don't
// use internal IDs, but need to give something to Bolt
// note that OpenCypher has no id(x) function, so the client
// should not be able to do anything with this value anyway
encoder.write_integer(0); // uID
// write the list of labels
auto labels = vertex.labels();
encoder.write_list_header(labels.size());
for (auto label : labels)
encoder.write_string(vertex.db_accessor().label_name(label));
// write the properties
const PropertyValueStore<GraphDb::Property> &props = vertex.Properties();
encoder.write_map_header(props.size());
props.Accept([this, &vertex](const GraphDb::Property prop,
const PropertyValue &value) {
this->encoder.write(vertex.db_accessor().property_name(prop));
this->write(value);
});
}
template <class Stream>
void bolt::BoltSerializer<Stream>::write(const EdgeAccessor &edge) {
// write signatures for the edge struct and edge data type
encoder.write_struct_header(5);
encoder.write(underlying_cast(pack::Code::Relationship));
// IMPORTANT: here we write a hardcoded 0 because we don't
// use internal IDs, but need to give something to Bolt
// note that OpenCypher has no id(x) function, so the client
// should not be able to do anything with this value anyway
encoder.write_integer(0);
encoder.write_integer(0);
encoder.write_integer(0);
// write the type of the edge
encoder.write(edge.db_accessor().edge_type_name(edge.edge_type()));
// write the property map
const PropertyValueStore<GraphDb::Property> &props = edge.Properties();
encoder.write_map_header(props.size());
props.Accept(
[this, &edge](GraphDb::Property prop, const PropertyValue &value) {
this->encoder.write(edge.db_accessor().property_name(prop));
this->write(value);
});
}
template <class Stream>
void bolt::BoltSerializer<Stream>::write(const PropertyValue &value) {
switch (value.type()) {
case PropertyValue::Type::Null:
encoder.write_null();
return;
case PropertyValue::Type::Bool:
encoder.write_bool(value.Value<bool>());
return;
case PropertyValue::Type::String:
encoder.write_string(value.Value<std::string>());
return;
case PropertyValue::Type::Int:
encoder.write_integer(value.Value<int>());
return;
case PropertyValue::Type::Double:
encoder.write_double(value.Value<double>());
return;
case PropertyValue::Type::List:
// Not implemented
assert(false);
}
}
template <class Stream>
void bolt::BoltSerializer<Stream>::write_failure(
const std::map<std::string, std::string> &data) {
encoder.message_failure();
encoder.write_map_header(data.size());
for (auto const &kv : data) {
write(kv.first);
write(kv.second);
}
}
template class bolt::BoltSerializer<bolt::BoltEncoder<
bolt::ChunkedEncoder<bolt::ChunkedBuffer<bolt::SocketStream<io::Socket>>>>>;

View File

@ -2,11 +2,12 @@
#include "communication/bolt/v1/packing/codes.hpp"
#include "communication/bolt/v1/transport/bolt_encoder.hpp"
#include "communication/bolt/v1/transport/chunked_buffer.hpp"
#include "communication/bolt/v1/transport/chunked_encoder.hpp"
#include "storage/edge_accessor.hpp"
#include "storage/vertex_accessor.hpp"
#include "storage/property_value.hpp"
#include "database/graph_db.hpp"
#include "database/graph_db_accessor.hpp"
#include "storage/property_value_store.hpp"
namespace bolt {
@ -24,7 +25,32 @@ class BoltSerializer {
* }
*
*/
void write(const VertexAccessor &vertex);
void write(const VertexAccessor &vertex) {
// write signatures for the node struct and node data type
encoder.write_struct_header(3);
encoder.write(underlying_cast(pack::Code::Node));
// IMPORTANT: here we write a hardcoded 0 because we don't
// use internal IDs, but need to give something to Bolt
// note that OpenCypher has no id(x) function, so the client
// should not be able to do anything with this value anyway
encoder.write_integer(0); // uID
// write the list of labels
auto labels = vertex.labels();
encoder.write_list_header(labels.size());
for (auto label : labels)
encoder.write_string(vertex.db_accessor().label_name(label));
// write the properties
const PropertyValueStore<GraphDb::Property> &props = vertex.Properties();
encoder.write_map_header(props.size());
props.Accept([this, &vertex](const GraphDb::Property prop,
const PropertyValue &value) {
this->encoder.write(vertex.db_accessor().property_name(prop));
this->write(value);
});
}
/** Serializes the edge accessor into the packstream format
*
@ -40,17 +66,69 @@ class BoltSerializer {
* }
*
*/
void write(const EdgeAccessor &edge);
void write(const EdgeAccessor &edge) {
// write signatures for the edge struct and edge data type
encoder.write_struct_header(5);
encoder.write(underlying_cast(pack::Code::Relationship));
// IMPORTANT: here we write a hardcoded 0 because we don't
// use internal IDs, but need to give something to Bolt
// note that OpenCypher has no id(x) function, so the client
// should not be able to do anything with this value anyway
encoder.write_integer(0);
encoder.write_integer(0);
encoder.write_integer(0);
// write the type of the edge
encoder.write(edge.db_accessor().edge_type_name(edge.edge_type()));
// write the property map
const PropertyValueStore<GraphDb::Property> &props = edge.Properties();
encoder.write_map_header(props.size());
props.Accept(
[this, &edge](GraphDb::Property prop, const PropertyValue &value) {
this->encoder.write(edge.db_accessor().property_name(prop));
this->write(value);
});
}
// TODO document
void write_failure(const std::map<std::string, std::string> &data);
void write_failure(const std::map<std::string, std::string> &data) {
encoder.message_failure();
encoder.write_map_header(data.size());
for (auto const &kv : data) {
write(kv.first);
write(kv.second);
}
}
/**
* Writes a PropertyValue (typically a property value in the edge or vertex).
*
* @param value The value to write.
*/
void write(const PropertyValue &value);
void write(const PropertyValue &value) {
switch (value.type()) {
case PropertyValue::Type::Null:
encoder.write_null();
return;
case PropertyValue::Type::Bool:
encoder.write_bool(value.Value<bool>());
return;
case PropertyValue::Type::String:
encoder.write_string(value.Value<std::string>());
return;
case PropertyValue::Type::Int:
encoder.write_integer(value.Value<int>());
return;
case PropertyValue::Type::Double:
encoder.write_double(value.Value<double>());
return;
case PropertyValue::Type::List:
// Not implemented
assert(false);
}
}
protected:
Stream &encoder;

View File

@ -3,7 +3,6 @@
#include "communication/bolt/v1/serialization/bolt_serializer.hpp"
#include "communication/bolt/v1/transport/chunked_buffer.hpp"
#include "communication/bolt/v1/transport/chunked_encoder.hpp"
#include "communication/bolt/v1/transport/socket_stream.hpp"
#include "logging/default.hpp"
@ -145,13 +144,12 @@ class RecordStream {
Logger logger;
private:
using socket_t = SocketStream<Socket>;
using buffer_t = ChunkedBuffer<socket_t>;
using buffer_t = ChunkedBuffer<Socket>;
using chunked_encoder_t = ChunkedEncoder<buffer_t>;
using bolt_encoder_t = BoltEncoder<chunked_encoder_t>;
using bolt_serializer_t = BoltSerializer<bolt_encoder_t>;
socket_t socket;
Socket &socket;
buffer_t chunked_buffer{socket};
chunked_encoder_t chunked_encoder{chunked_buffer};
bolt_encoder_t bolt_encoder{chunked_encoder};

View File

@ -1,62 +0,0 @@
#pragma once
#include <atomic>
#include <memory>
#include <thread>
#include <vector>
#include "communication/bolt/v1/bolt.hpp"
#include "io/network/server.hpp"
#include "logging/default.hpp"
#include "utils/assert.hpp"
namespace bolt {
template <class Worker>
class Server : public io::Server<Server<Worker>> {
public:
Server(io::Socket&& socket)
: io::Server<Server<Worker>>(std::forward<io::Socket>(socket)),
logger(logging::log->logger("bolt::Server")) {}
void start(size_t n) {
workers.reserve(n);
for (size_t i = 0; i < n; ++i) {
workers.push_back(std::make_shared<Worker>(bolt));
workers.back()->start(alive);
}
while (alive) {
this->wait_and_process_events();
}
}
void shutdown() {
alive.store(false);
for (auto& worker : workers) worker->thread.join();
}
void on_connect() {
debug_assert(idx < workers.size(), "Invalid worker id.");
logger.trace("on connect");
if (UNLIKELY(!workers[idx]->accept(this->socket))) return;
idx = idx == workers.size() - 1 ? 0 : idx + 1;
}
void on_wait_timeout() {}
private:
Bolt bolt;
std::vector<typename Worker::sptr> workers;
std::atomic<bool> alive{true};
int idx{0};
Logger logger;
};
}

View File

@ -1,103 +0,0 @@
#pragma once
#include <atomic>
#include <cstdio>
#include <iomanip>
#include <memory>
#include <sstream>
#include <thread>
#include "communication/bolt/v1/bolt.hpp"
#include "communication/bolt/v1/session.hpp"
#include "io/network/stream_reader.hpp"
#include "logging/default.hpp"
namespace bolt {
template <class Worker>
class Server;
class Worker : public io::StreamReader<Worker, Session> {
friend class bolt::Server<Worker>;
public:
using sptr = std::shared_ptr<Worker>;
Worker(Bolt &bolt) : bolt(bolt) {
logger = logging::log->logger("bolt::Worker");
}
Session &on_connect(io::Socket &&socket) {
logger.trace("Accepting connection on socket {}", socket.id());
return *bolt.get().create_session(std::forward<io::Socket>(socket));
}
void on_error(Session &) {
logger.trace("[on_error] errno = {}", errno);
#ifndef NDEBUG
auto err = io::NetworkError("");
logger.debug("{}", err.what());
#endif
logger.error("Error occured in this session");
}
void on_wait_timeout() {}
Buffer on_alloc(Session &) {
/* logger.trace("[on_alloc] Allocating {}B", sizeof buf); */
return Buffer{buf, sizeof buf};
}
void on_read(Session &session, Buffer &buf) {
logger.trace("[on_read] Received {}B", buf.len);
#ifndef NDEBUG
std::stringstream stream;
for (size_t i = 0; i < buf.len; ++i)
stream << fmt::format("{:02X} ", static_cast<byte>(buf.ptr[i]));
logger.trace("[on_read] {}", stream.str());
#endif
try {
session.execute(reinterpret_cast<const byte *>(buf.ptr), buf.len);
} catch (const std::exception &e) {
logger.error("Error occured while executing statement.");
logger.error("{}", e.what());
// TODO: report to client
}
}
void on_close(Session &session) {
logger.trace("[on_close] Client closed the connection");
session.close();
}
template <class... Args>
void on_exception(Session &session, Args &&... args) {
logger.error("Error occured in this session");
logger.error(args...);
// TODO: Do something about it
}
char buf[65536];
protected:
std::reference_wrapper<Bolt> bolt;
Logger logger;
std::thread thread;
void start(std::atomic<bool> &alive) {
thread = std::thread([&, this]() {
while (alive) wait_and_process_events();
});
}
};
}

View File

@ -1,43 +0,0 @@
#include "communication/bolt/v1/session.hpp"
namespace bolt {
Session::Session(io::Socket &&socket, Bolt &bolt)
: Stream(std::forward<io::Socket>(socket)), bolt(bolt) {
logger = logging::log->logger("Session");
// start with a handshake state
state = bolt.states.handshake.get();
}
bool Session::alive() const { return state != nullptr; }
void Session::execute(const byte *data, size_t len) {
// mark the end of the message
auto end = data + len;
while (true) {
auto size = end - data;
if (LIKELY(connected)) {
logger.debug("Decoding chunk of size {}", size);
auto finished = decoder.decode(data, size);
if (!finished) return;
} else {
logger.debug("Decoding handshake of size {}", size);
decoder.handshake(data, size);
}
state = state->run(*this);
decoder.reset();
}
}
void Session::close() {
logger.debug("Closing session");
bolt.close(this);
}
GraphDbAccessor Session::active_db() { return bolt.dbms.active(); }
}

View File

@ -1,12 +1,18 @@
#pragma once
#include "io/network/epoll.hpp"
#include "io/network/socket.hpp"
#include "io/network/tcp/stream.hpp"
#include "communication/bolt/communication.hpp"
#include "communication/bolt/v1/bolt.hpp"
#include "dbms/dbms.hpp"
#include "query/engine.hpp"
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/states/handshake.hpp"
#include "communication/bolt/v1/states/init.hpp"
#include "communication/bolt/v1/states/executor.hpp"
#include "communication/bolt/v1/states/error.hpp"
#include "communication/bolt/v1/serialization/record_stream.hpp"
#include "communication/bolt/v1/states/state.hpp"
#include "communication/bolt/v1/transport/bolt_decoder.hpp"
#include "communication/bolt/v1/transport/bolt_encoder.hpp"
@ -14,28 +20,85 @@
namespace bolt {
class Session : public io::tcp::Stream<io::Socket> {
template<typename Socket>
class Session {
public:
using Decoder = BoltDecoder;
using OutputStream = communication::OutputStream;
using OutputStream = RecordStream<Socket>;
Session(io::Socket &&socket, Bolt &bolt);
Session(Socket &&socket, Dbms &dbms, QueryEngine<OutputStream> &query_engine)
: socket(std::move(socket)),
dbms(dbms), query_engine(query_engine),
logger(logging::log->logger("Session")) {
event.data.ptr = this;
// start with a handshake state
state = HANDSHAKE;
}
bool alive() const;
bool alive() const { return state != NULLSTATE; }
void execute(const byte *data, size_t len);
int id() const { return socket.id(); }
void close();
void execute(const byte *data, size_t len) {
// mark the end of the message
auto end = data + len;
Bolt &bolt;
while (true) {
auto size = end - data;
GraphDbAccessor active_db();
if (LIKELY(connected)) {
logger.debug("Decoding chunk of size {}", size);
auto finished = decoder.decode(data, size);
if (!finished) return;
} else {
logger.debug("Decoding handshake of size {}", size);
decoder.handshake(data, size);
}
switch(state) {
case HANDSHAKE:
logger.debug("Current state: DEBUG");
state = state_handshake_run<Socket>(decoder, this->socket, &connected);
break;
case INIT:
logger.debug("Current state: INIT");
state = state_init_run<Socket>(output_stream, decoder);
break;
case EXECUTOR:
logger.debug("Current state: EXECUTOR");
state = state_executor_run<Socket>(output_stream, decoder, dbms, query_engine);
break;
case ERROR:
logger.debug("Current state: ERROR");
state = state_error_run<Socket>(output_stream, decoder);
break;
case NULLSTATE:
break;
}
decoder.reset();
}
}
void close() {
logger.debug("Closing session");
this->socket.Close();
}
Socket socket;
io::network::Epoll::Event event;
Dbms &dbms;
QueryEngine<OutputStream> &query_engine;
GraphDbAccessor active_db() { return dbms.active(); }
Decoder decoder;
OutputStream output_stream{socket};
bool connected{false};
State *state;
State state;
protected:
Logger logger;

View File

@ -0,0 +1,13 @@
#pragma once
namespace bolt {
enum State {
HANDSHAKE,
INIT,
EXECUTOR,
ERROR,
NULLSTATE
};
}

View File

@ -1,16 +0,0 @@
#include "communication/bolt/v1/states.hpp"
#include "communication/bolt/v1/states/error.hpp"
#include "communication/bolt/v1/states/executor.hpp"
#include "communication/bolt/v1/states/handshake.hpp"
#include "communication/bolt/v1/states/init.hpp"
namespace bolt {
States::States() {
handshake = std::make_unique<Handshake>();
init = std::make_unique<Init>();
executor = std::make_unique<Executor>();
error = std::make_unique<Error>();
}
}

View File

@ -1,17 +0,0 @@
#pragma once
#include "communication/bolt/v1/states/state.hpp"
#include "logging/log.hpp"
namespace bolt {
class States {
public:
States();
State::uptr handshake;
State::uptr init;
State::uptr executor;
State::uptr error;
};
}

View File

@ -1,47 +0,0 @@
#include "communication/bolt/v1/states/error.hpp"
namespace bolt {
Error::Error() : State(logging::log->logger("Error State")) {}
State* Error::run(Session& session) {
logger.trace("Run");
session.decoder.read_byte();
auto message_type = session.decoder.read_byte();
logger.trace("Message type byte is: {:02X}", message_type);
if (message_type == MessageCode::PullAll) {
session.output_stream.write_ignored();
session.output_stream.chunk();
session.output_stream.send();
return this;
} else if (message_type == MessageCode::AckFailure) {
// TODO reset current statement? is it even necessary?
logger.trace("AckFailure received");
session.output_stream.write_success_empty();
session.output_stream.chunk();
session.output_stream.send();
return session.bolt.states.executor.get();
} else if (message_type == MessageCode::Reset) {
// TODO rollback current transaction
// discard all records waiting to be sent
session.output_stream.write_success_empty();
session.output_stream.chunk();
session.output_stream.send();
return session.bolt.states.executor.get();
}
// TODO: write this as single call
session.output_stream.write_ignored();
session.output_stream.chunk();
session.output_stream.send();
return this;
}
}

View File

@ -1,14 +1,53 @@
#pragma once
#include "communication/bolt/v1/session.hpp"
#include "communication/bolt/v1/states/state.hpp"
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/transport/bolt_decoder.hpp"
#include "communication/bolt/v1/serialization/record_stream.hpp"
#include "logging/default.hpp"
namespace bolt {
class Error : public State {
public:
Error();
template<typename Socket>
State state_error_run(RecordStream<Socket> &output_stream, BoltDecoder &decoder) {
Logger logger = logging::log->logger("State ERROR");
logger.trace("Run");
State *run(Session &session) override;
};
decoder.read_byte();
auto message_type = decoder.read_byte();
logger.trace("Message type byte is: {:02X}", message_type);
if (message_type == MessageCode::PullAll) {
output_stream.write_ignored();
output_stream.chunk();
output_stream.send();
return ERROR;
} else if (message_type == MessageCode::AckFailure) {
// TODO reset current statement? is it even necessary?
logger.trace("AckFailure received");
output_stream.write_success_empty();
output_stream.chunk();
output_stream.send();
return EXECUTOR;
} else if (message_type == MessageCode::Reset) {
// TODO rollback current transaction
// discard all records waiting to be sent
output_stream.write_success_empty();
output_stream.chunk();
output_stream.send();
return EXECUTOR;
}
// TODO: write this as single call
output_stream.write_ignored();
output_stream.chunk();
output_stream.send();
return ERROR;
}
}

View File

@ -1,117 +0,0 @@
#include "communication/bolt/v1/states/executor.hpp"
#include "communication/bolt/v1/messaging/codes.hpp"
#include "database/graph_db_accessor.hpp"
#include "query/frontend/opencypher/parser.hpp"
#ifdef BARRIER
#include "barrier/barrier.cpp"
#endif
namespace bolt {
Executor::Executor() : State(logging::log->logger("Executor")) {}
State *Executor::run(Session &session) {
// just read one byte that represents the struct type, we can skip the
// information contained in this byte
session.decoder.read_byte();
logger.debug("Run");
auto message_type = session.decoder.read_byte();
if (message_type == MessageCode::Run) {
Query q;
q.statement = session.decoder.read_string();
// TODO: refactor bolt exception handling (Ferencevic)
try {
return this->run(session, q);
} catch (const frontend::opencypher::SyntaxException &e) {
session.output_stream.write_failure(
{{"code", "Memgraph.SyntaxException"}, {"message", "Syntax error"}});
session.output_stream.send();
return session.bolt.states.error.get();
} catch (const backend::cpp::GeneratorException &e) {
session.output_stream.write_failure(
{{"code", "Memgraph.GeneratorException"},
{"message", "Unsupported query"}});
session.output_stream.send();
return session.bolt.states.error.get();
} catch (const QueryEngineException &e) {
session.output_stream.write_failure(
{{"code", "Memgraph.QueryEngineException"},
{"message", "Query engine was unable to execute the query"}});
session.output_stream.send();
return session.bolt.states.error.get();
} catch (const StacktraceException &e) {
session.output_stream.write_failure(
{{"code", "Memgraph.StacktraceException"},
{"message", "Unknow exception"}});
session.output_stream.send();
return session.bolt.states.error.get();
} catch (std::exception &e) {
session.output_stream.write_failure(
{{"code", "Memgraph.Exception"}, {"message", "unknow exception"}});
session.output_stream.send();
return session.bolt.states.error.get();
}
} else if (message_type == MessageCode::PullAll) {
pull_all(session);
} else if (message_type == MessageCode::DiscardAll) {
discard_all(session);
} else if (message_type == MessageCode::Reset) {
// TODO: rollback current transaction
// discard all records waiting to be sent
return this;
} else {
logger.error("Unrecognized message recieved");
logger.debug("Invalid message type 0x{:02X}", message_type);
return session.bolt.states.error.get();
}
return this;
}
State *Executor::run(Session &session, Query &query) {
logger.trace("[Run] '{}'", query.statement);
auto db_accessor = session.active_db();
logger.debug("[ActiveDB] '{}'", db_accessor.name());
auto is_successfully_executed =
query_engine.Run(query.statement, db_accessor, session.output_stream);
if (!is_successfully_executed) {
session.output_stream.write_failure(
{{"code", "Memgraph.QueryExecutionFail"},
{"message",
"Query execution has failed (probably there is no "
"element or there are some problems with concurrent "
"access -> client has to resolve problems with "
"concurrent access)"}});
session.output_stream.send();
return session.bolt.states.error.get();
}
return this;
}
void Executor::pull_all(Session &session) {
logger.trace("[PullAll]");
session.output_stream.send();
}
void Executor::discard_all(Session &session) {
logger.trace("[DiscardAll]");
// TODO: discard state
session.output_stream.write_success();
session.output_stream.chunk();
session.output_stream.send();
}
}

View File

@ -1,38 +1,113 @@
#pragma once
#include "communication/bolt/v1/session.hpp"
#include "communication/bolt/v1/states/state.hpp"
#include <string>
#include "dbms/dbms.hpp"
#include "query/engine.hpp"
#include "communication/bolt/v1/states/executor.hpp"
#include "communication/bolt/v1/messaging/codes.hpp"
#include "communication/bolt/v1/state.hpp"
#include "logging/default.hpp"
namespace bolt {
class Executor : public State {
struct Query {
std::string statement;
};
public:
Executor();
State* run(Session& session) override final;
protected:
/* Execute an incoming query
*
*/
State* run(Session& session, Query& query);
/* Send all remaining results to the client
*
*/
void pull_all(Session& session);
/* Discard all remaining results
*
*/
void discard_all(Session& session);
private:
QueryEngine<communication::OutputStream> query_engine;
struct Query {
std::string statement;
};
template<typename Socket>
State state_executor_run(RecordStream<Socket> &output_stream, BoltDecoder &decoder, Dbms &dmbs, QueryEngine<RecordStream<Socket>> &query_engine){
Logger logger = logging::log->logger("State EXECUTOR");
// just read one byte that represents the struct type, we can skip the
// information contained in this byte
decoder.read_byte();
logger.debug("Run");
auto message_type = decoder.read_byte();
if (message_type == MessageCode::Run) {
Query query;
query.statement = decoder.read_string();
// TODO (mferencevic): refactor bolt exception handling
try {
logger.trace("[Run] '{}'", query.statement);
auto db_accessor = dmbs.active();
logger.debug("[ActiveDB] '{}'", db_accessor.name());
auto is_successfully_executed =
query_engine.Run(query.statement, db_accessor, output_stream);
if (!is_successfully_executed) {
output_stream.write_failure(
{{"code", "Memgraph.QueryExecutionFail"},
{"message",
"Query execution has failed (probably there is no "
"element or there are some problems with concurrent "
"access -> client has to resolve problems with "
"concurrent access)"}});
output_stream.send();
return ERROR;
}
return EXECUTOR;
// TODO: RETURN success MAYBE
} catch (const frontend::opencypher::SyntaxException &e) {
output_stream.write_failure(
{{"code", "Memgraph.SyntaxException"}, {"message", "Syntax error"}});
output_stream.send();
return ERROR;
} catch (const backend::cpp::GeneratorException &e) {
output_stream.write_failure(
{{"code", "Memgraph.GeneratorException"},
{"message", "Unsupported query"}});
output_stream.send();
return ERROR;
} catch (const QueryEngineException &e) {
output_stream.write_failure(
{{"code", "Memgraph.QueryEngineException"},
{"message", "Query engine was unable to execute the query"}});
output_stream.send();
return ERROR;
} catch (const StacktraceException &e) {
output_stream.write_failure(
{{"code", "Memgraph.StacktraceException"},
{"message", "Unknow exception"}});
output_stream.send();
return ERROR;
} catch (std::exception &e) {
output_stream.write_failure(
{{"code", "Memgraph.Exception"}, {"message", "unknow exception"}});
output_stream.send();
return ERROR;
}
} else if (message_type == MessageCode::PullAll) {
logger.trace("[PullAll]");
output_stream.send();
} else if (message_type == MessageCode::DiscardAll) {
logger.trace("[DiscardAll]");
// TODO: discard state
output_stream.write_success();
output_stream.chunk();
output_stream.send();
} else if (message_type == MessageCode::Reset) {
// TODO: rollback current transaction
// discard all records waiting to be sent
return EXECUTOR;
} else {
logger.error("Unrecognized message recieved");
logger.debug("Invalid message type 0x{:02X}", message_type);
return ERROR;
}
return EXECUTOR;
}
}

View File

@ -1,27 +0,0 @@
#include "communication/bolt/v1/states/handshake.hpp"
#include "communication/bolt/v1/session.hpp"
namespace bolt {
static constexpr uint32_t preamble = 0x6060B017;
static constexpr byte protocol[4] = {0x00, 0x00, 0x00, 0x01};
Handshake::Handshake() : State(logging::log->logger("Handshake")) {}
State* Handshake::run(Session& session) {
logger.debug("run");
if (UNLIKELY(session.decoder.read_uint32() != preamble)) return nullptr;
// TODO so far we only support version 1 of the protocol so it doesn't
// make sense to check which version the client prefers
// this will change in the future
session.connected = true;
session.socket.write(protocol, sizeof protocol);
return session.bolt.states.init.get();
}
}

View File

@ -1,12 +1,32 @@
#pragma once
#include "communication/bolt/v1/states/state.hpp"
#include "dbms/dbms.hpp"
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/transport/bolt_decoder.hpp"
#include "logging/default.hpp"
namespace bolt {
class Handshake : public State {
public:
Handshake();
State* run(Session& session) override;
};
static constexpr uint32_t preamble = 0x6060B017;
static constexpr byte protocol[4] = {0x00, 0x00, 0x00, 0x01};
template<typename Socket>
State state_handshake_run(BoltDecoder &decoder, Socket &socket_, bool *connected) {
Logger logger = logging::log->logger("State HANDSHAKE");
logger.debug("run");
if (UNLIKELY(decoder.read_uint32() != preamble)) return NULLSTATE;
// TODO so far we only support version 1 of the protocol so it doesn't
// make sense to check which version the client prefers
// this will change in the future
*connected = true;
// TODO: check for success
socket_.Write(protocol, sizeof protocol);
return INIT;
}
}

View File

@ -1,54 +0,0 @@
#include "communication/bolt/v1/states/init.hpp"
#include "communication/bolt/v1/messaging/codes.hpp"
#include "communication/bolt/v1/session.hpp"
#include "utils/likely.hpp"
namespace bolt {
Init::Init() : MessageParser<Init>(logging::log->logger("Init")) {}
State *Init::parse(Session &session, Message &message) {
logger.debug("bolt::Init.parse()");
auto struct_type = session.decoder.read_byte();
if (UNLIKELY((struct_type & 0x0F) > pack::Rule::MaxInitStructSize)) {
logger.debug("{}", struct_type);
logger.debug(
"Expected struct marker of max size 0x{:02} instead of 0x{:02X}",
(unsigned)pack::Rule::MaxInitStructSize, (unsigned)struct_type);
return nullptr;
}
auto message_type = session.decoder.read_byte();
if (UNLIKELY(message_type != MessageCode::Init)) {
logger.debug("Expected Init (0x01) instead of (0x{:02X})",
(unsigned)message_type);
return nullptr;
}
message.client_name = session.decoder.read_string();
if (struct_type == pack::Code::StructTwo) {
// TODO process authentication tokens
}
return this;
}
State *Init::execute(Session &session, Message &message) {
logger.debug("Client connected '{}'", message.client_name);
session.output_stream.write_success_empty();
session.output_stream.chunk();
session.output_stream.send();
return session.bolt.states.executor.get();
}
}

View File

@ -1,18 +1,55 @@
#pragma once
#include "communication/bolt/v1/states/message_parser.hpp"
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/transport/bolt_decoder.hpp"
#include "communication/bolt/v1/serialization/record_stream.hpp"
#include "communication/bolt/v1/messaging/codes.hpp"
#include "logging/default.hpp"
#include "utils/likely.hpp"
namespace bolt {
class Init : public MessageParser<Init> {
public:
struct Message {
std::string client_name;
};
template<typename Socket>
State state_init_run(RecordStream<Socket> &output_stream, BoltDecoder &decoder) {
Logger logger = logging::log->logger("State INIT");
logger.debug("Parsing message");
Init();
auto struct_type = decoder.read_byte();
if (UNLIKELY((struct_type & 0x0F) > pack::Rule::MaxInitStructSize)) {
logger.debug("{}", struct_type);
logger.debug(
"Expected struct marker of max size 0x{:02} instead of 0x{:02X}",
(unsigned)pack::Rule::MaxInitStructSize, (unsigned)struct_type);
return NULLSTATE;
}
auto message_type = decoder.read_byte();
if (UNLIKELY(message_type != MessageCode::Init)) {
logger.debug("Expected Init (0x01) instead of (0x{:02X})",
(unsigned)message_type);
return NULLSTATE;
}
auto client_name = decoder.read_string();
if (struct_type == pack::Code::StructTwo) {
// TODO process authentication tokens
}
logger.debug("Executing state");
logger.debug("Client connected '{}'", client_name);
output_stream.write_success_empty();
output_stream.chunk();
output_stream.send();
return EXECUTOR;
}
State* parse(Session& session, Message& message);
State* execute(Session& session, Message& message);
};
}

View File

@ -1,30 +0,0 @@
#pragma once
#include "communication/bolt/v1/session.hpp"
#include "communication/bolt/v1/states/state.hpp"
#include "utils/crtp.hpp"
namespace bolt {
template <class Derived>
class MessageParser : public State, public Crtp<Derived> {
public:
MessageParser(Logger &&logger) : logger(std::forward<Logger>(logger)) {}
State *run(Session &session) override final {
typename Derived::Message message;
logger.debug("Parsing message");
auto next = this->derived().parse(session, message);
// return next state if parsing was unsuccessful (i.e. error state)
if (next != &this->derived()) return next;
logger.debug("Executing state");
return this->derived().execute(session, message);
}
protected:
Logger logger;
};
}

View File

@ -1,27 +0,0 @@
#pragma once
#include <cstdint>
#include <cstdlib>
#include <memory>
#include "logging/default.hpp"
namespace bolt {
class Session;
class State {
public:
using uptr = std::unique_ptr<State>;
State() = default;
State(Logger logger) : logger(logger) {}
virtual ~State() = default;
virtual State* run(Session& session) = 0;
protected:
Logger logger;
};
}

View File

@ -33,7 +33,8 @@ class ChunkedBuffer {
}
void flush() {
stream.get().write(&buffer.front(), size);
// TODO: check for success
stream.get().Write(&buffer.front(), size);
logger.trace("Flushed {} bytes", size);

View File

@ -1,33 +0,0 @@
#pragma once
#include <cstdint>
#include <cstdio>
#include <vector>
#include "communication/bolt/v1/transport/stream_error.hpp"
#include "io/network/socket.hpp"
namespace bolt {
template <typename Stream>
class SocketStream {
public:
using byte = uint8_t;
SocketStream(Stream& socket) : socket(socket) {}
void write(const byte* data, size_t n) {
while (n > 0) {
auto written = socket.get().write(data, n);
if (UNLIKELY(written == -1)) throw StreamError("Can't write to stream");
n -= written;
data += written;
}
}
private:
std::reference_wrapper<Stream> socket;
};
}

View File

@ -0,0 +1,113 @@
#pragma once
#include <atomic>
#include <memory>
#include <thread>
#include <vector>
#include "dbms/dbms.hpp"
#include "query/engine.hpp"
#include "communication/worker.hpp"
#include "io/network/event_listener.hpp"
#include "logging/default.hpp"
#include "utils/assert.hpp"
namespace communication {
/**
* Communication server.
* Listens for incomming connections on the server port and assings them in a
* round-robin manner to it's workers.
*
* Current Server achitecture:
* incomming connection -> server -> worker -> session
*
* @tparam Session the server can handle different Sessions, each session
* represents a different protocol so the same network infrastructure
* can be used for handling different protocols
* @tparam OutputStream the server has to get the output stream as a template
parameter because the output stream is templated
* @tparam Socket the input/output socket that should be used
*/
template <typename Session, typename OutputStream, typename Socket>
class Server
: public io::network::EventListener<Server<Session, OutputStream, Socket>> {
using Event = io::network::Epoll::Event;
public:
Server(Socket &&socket, Dbms &dbms, QueryEngine<OutputStream> &query_engine)
: socket_(std::forward<Socket>(socket)),
dbms_(dbms),
query_engine_(query_engine),
logger_(logging::log->logger("communication::Server")) {
event_.data.fd = socket_;
// TODO: EPOLLET is hard to use -> figure out how should EPOLLET be used
// event.events = EPOLLIN | EPOLLET;
event_.events = EPOLLIN;
this->listener_.Add(socket_, &event_);
}
void Start(size_t n) {
workers_.reserve(n);
for (size_t i = 0; i < n; ++i) {
workers_.push_back(
std::make_shared<Worker<Session, OutputStream, Socket>>(
dbms_, query_engine_));
workers_.back()->Start(alive_);
}
while (alive_) {
this->WaitAndProcessEvents();
}
}
void Shutdown() {
alive_.store(false);
for (auto &worker : workers_) worker->thread_.join();
}
void OnConnect() {
debug_assert(idx_ < workers_.size(), "Invalid worker id.");
logger_.trace("on connect");
if (UNLIKELY(!workers_[idx_]->Accept(socket_))) return;
idx_ = idx_ == (int)workers_.size() - 1 ? 0 : idx_ + 1;
}
void OnWaitTimeout() {}
void OnDataEvent(Event &event) {
if (UNLIKELY(socket_ != event.data.fd)) return;
this->derived().OnConnect();
}
template <class... Args>
void OnExceptionEvent(Event &event, Args &&... args) {
// TODO: Do something about it
logger_.warn("epoll exception");
}
void OnCloseEvent(Event &event) { close(event.data.fd); }
void OnErrorEvent(Event &event) { close(event.data.fd); }
private:
std::vector<typename Worker<Session, OutputStream, Socket>::sptr> workers_;
std::atomic<bool> alive_{true};
int idx_{0};
Dbms &dbms_;
QueryEngine<OutputStream> &query_engine_;
Event event_;
Socket socket_;
Logger logger_;
};
}

View File

@ -0,0 +1,109 @@
#pragma once
#include <atomic>
#include <cstdio>
#include <iomanip>
#include <memory>
#include <sstream>
#include <thread>
#include "dbms/dbms.hpp"
#include "query/engine.hpp"
#include "communication/bolt/v1/session.hpp"
#include "io/network/network_error.hpp"
#include "io/network/stream_reader.hpp"
#include "logging/default.hpp"
namespace communication {
/**
* Communication worker.
* Listens for incomming data on connections and accepts new connections.
* Also, executes sessions on incomming data.
*
* @tparam Session the worker can handle different Sessions, each session
* represents a different protocol so the same network infrastructure
* can be used for handling different protocols
* @tparam OutputStream the worker has to get the output stream as a template
parameter because the output stream is templated
* @tparam Socket the input/output socket that should be used
*/
template <typename Session, typename OutputStream, typename Socket>
class Worker
: public io::network::StreamReader<Worker<Session, OutputStream, Socket>,
Session> {
using StreamBuffer = io::network::StreamBuffer;
public:
using sptr = std::shared_ptr<Worker<Session, OutputStream, Socket>>;
Worker(Dbms &dbms, QueryEngine<OutputStream> &query_engine)
: dbms_(dbms),
query_engine_(query_engine),
logger_(logging::log->logger("communication::Worker")) {}
Session &OnConnect(Socket &&socket) {
logger_.trace("Accepting connection on socket {}", socket.id());
// TODO fix session lifecycle handling
// dangling pointers are not cool :)
// TODO attach currently active Db
return *(new Session(std::forward<Socket>(socket), dbms_, query_engine_));
}
void OnError(Session &session) {
logger_.error("Error occured in this session");
OnClose(session);
}
void OnWaitTimeout() {}
StreamBuffer OnAlloc(Session &) {
/* logger.trace("[on_alloc] Allocating {}B", sizeof buf); */
return StreamBuffer{buf_, sizeof buf_};
}
void OnRead(Session &session, StreamBuffer &buf) {
logger_.trace("[on_read] Received {}B", buf.len);
try {
session.execute(reinterpret_cast<const byte *>(buf.ptr), buf.len);
} catch (const std::exception &e) {
logger_.error("Error occured while executing statement.");
logger_.error("{}", e.what());
// TODO: report to client
}
}
void OnClose(Session &session) {
logger_.trace("Client closed the connection");
// TODO: remove socket from epoll object
session.close();
delete &session;
}
template <class... Args>
void OnException(Session &session, Args &&... args) {
logger_.error("Error occured in this session");
logger_.error(args...);
// TODO: Do something about it
}
char buf_[65536];
std::thread thread_;
void Start(std::atomic<bool> &alive) {
thread_ = std::thread([&, this]() {
while (alive) this->WaitAndProcessEvents();
});
}
private:
Dbms &dbms_;
QueryEngine<OutputStream> &query_engine_;
Logger logger_;
};
}

View File

@ -0,0 +1,32 @@
#include <netdb.h>
#include <cstring>
#include "io/network/addrinfo.hpp"
#include "io/network/network_error.hpp"
#include "utils/underlying_cast.hpp"
namespace io::network {
AddrInfo::AddrInfo(struct addrinfo* info) : info(info) {}
AddrInfo::~AddrInfo() { freeaddrinfo(info); }
AddrInfo AddrInfo::Get(const char* addr, const char* port) {
struct addrinfo hints;
memset(&hints, 0, sizeof(struct addrinfo));
hints.ai_family = AF_UNSPEC; // IPv4 and IPv6
hints.ai_socktype = SOCK_STREAM; // TCP socket
hints.ai_flags = AI_PASSIVE;
struct addrinfo* result;
auto status = getaddrinfo(addr, port, &hints, &result);
if (status != 0) throw NetworkError(gai_strerror(status));
return AddrInfo(result);
}
AddrInfo::operator struct addrinfo*() { return info; }
}

View File

@ -1,36 +1,20 @@
#pragma once
#include <netdb.h>
#include <cstring>
#include "io/network/network_error.hpp"
#include "utils/underlying_cast.hpp"
namespace io {
namespace io::network {
/**
* Wrapper class for getaddrinfo.
* see: man 3 getaddrinfo
*/
class AddrInfo {
AddrInfo(struct addrinfo* info) : info(info) {}
AddrInfo(struct addrinfo* info);
public:
~AddrInfo() { freeaddrinfo(info); }
~AddrInfo();
static AddrInfo get(const char* addr, const char* port) {
struct addrinfo hints;
memset(&hints, 0, sizeof(struct addrinfo));
static AddrInfo Get(const char* addr, const char* port);
hints.ai_family = AF_UNSPEC; // IPv4 and IPv6
hints.ai_socktype = SOCK_STREAM; // TCP socket
hints.ai_flags = AI_PASSIVE;
struct addrinfo* result;
auto status = getaddrinfo(addr, port, &hints, &result);
if (status != 0) throw NetworkError(gai_strerror(status));
return AddrInfo(result);
}
operator struct addrinfo*() { return info; }
operator struct addrinfo*();
private:
struct addrinfo* info;

View File

@ -7,40 +7,45 @@
#include "logging/default.hpp"
#include "utils/likely.hpp"
namespace io {
namespace io::network {
class EpollError : StacktraceException {
public:
using StacktraceException::StacktraceException;
};
/**
* Wrapper class for epoll.
* Creates an object that listens on file descriptor status changes.
* see: man 4 epoll
*/
class Epoll {
public:
using Event = struct epoll_event;
Epoll(int flags) : logger(logging::log->logger("io::Epoll")) {
epoll_fd = epoll_create1(flags);
Epoll(int flags) : logger_(logging::log->logger("io::Epoll")) {
epoll_fd_ = epoll_create1(flags);
if (UNLIKELY(epoll_fd == -1))
if (UNLIKELY(epoll_fd_ == -1))
throw EpollError("Can't create epoll file descriptor");
}
template <class Stream>
void add(Stream& stream, Event* event) {
auto status = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, stream, event);
void Add(Stream& stream, Event* event) {
auto status = epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, stream, event);
if (UNLIKELY(status))
throw EpollError("Can't add an event to epoll listener.");
}
int wait(Event* events, int max_events, int timeout) {
return epoll_wait(epoll_fd, events, max_events, timeout);
int Wait(Event* events, int max_events, int timeout) {
return epoll_wait(epoll_fd_, events, max_events, timeout);
}
int id() const { return epoll_fd; }
int id() const { return epoll_fd_; }
private:
int epoll_fd;
Logger logger;
int epoll_fd_;
Logger logger_;
};
}

View File

@ -4,17 +4,21 @@
#include "logging/default.hpp"
#include "utils/crtp.hpp"
namespace io {
namespace io::network {
/**
* This class listens to events on an epoll object and calls
* callback functions to process them.
*/
template <class Derived, size_t max_events = 64, int wait_timeout = -1>
class EventListener : public Crtp<Derived> {
public:
using Crtp<Derived>::derived;
EventListener(uint32_t flags = 0)
: listener(flags), logger(logging::log->logger("io::EventListener")) {}
: listener_(flags), logger_(logging::log->logger("io::EventListener")) {}
void wait_and_process_events() {
void WaitAndProcessEvents() {
// TODO hardcoded a wait timeout because of thread joining
// when you shutdown the server. This should be wait_timeout of the
// template parameter and should almost never change from that.
@ -25,36 +29,36 @@ class EventListener : public Crtp<Derived> {
// max_events and stores them in the events array. it waits for
// wait_timeout milliseconds. if wait_timeout is achieved, returns 0
auto n = listener.wait(events, max_events, 200);
auto n = listener_.Wait(events_, max_events, 200);
#ifndef NDEBUG
#ifndef LOG_NO_TRACE
if (n > 0) logger.trace("number of events: {}", n);
if (n > 0) logger_.trace("number of events: {}", n);
#endif
#endif
// go through all events and process them in order
for (int i = 0; i < n; ++i) {
auto &event = events[i];
auto &event = events_[i];
try {
// hangup event
if (UNLIKELY(event.events & EPOLLRDHUP)) {
this->derived().on_close_event(event);
this->derived().OnCloseEvent(event);
continue;
}
// there was an error on the server side
if (UNLIKELY(!(event.events & EPOLLIN) ||
event.events & (EPOLLHUP | EPOLLERR))) {
this->derived().on_error_event(event);
this->derived().OnErrorEvent(event);
continue;
}
// we have some data waiting to be read
this->derived().on_data_event(event);
this->derived().OnDataEvent(event);
} catch (const std::exception &e) {
this->derived().on_exception_event(
this->derived().OnExceptionEvent(
event, "Error occured while processing event \n{}", e.what());
}
}
@ -69,14 +73,14 @@ class EventListener : public Crtp<Derived> {
// is -1 there will never be any timeouts so client should provide
// an empty function. in that case the conditional above and the
// function call will be optimized out by the compiler
this->derived().on_wait_timeout();
this->derived().OnWaitTimeout();
}
protected:
Epoll listener;
Epoll::Event events[max_events];
Epoll listener_;
Epoll::Event events_[max_events];
private:
Logger logger;
Logger logger_;
};
}

View File

@ -1,29 +0,0 @@
#pragma once
#include "io/network/stream_reader.hpp"
namespace io {
template <class Derived, class Stream>
class Client : public StreamReader<Derived, Stream> {
public:
bool connect(const std::string& host, const std::string& port) {
return connect(host.c_str(), port.c_str());
}
bool connect(const char* host, const char* port) {
auto socket = io::Socket::connect(host, port);
if (!socket.is_open()) return false;
socket.set_non_blocking();
auto& stream = this->derived().on_connect(std::move(socket));
stream.event.events = EPOLLIN | EPOLLET | EPOLLRDHUP;
this->add(stream);
return true;
}
};
}

View File

@ -0,0 +1,61 @@
#include "io/network/network_endpoint.hpp"
#include "io/network/network_error.hpp"
#include <arpa/inet.h>
#include <netdb.h>
namespace io::network {
NetworkEndpoint::NetworkEndpoint() : port_(0), family_(0) {
memset(address_, 0, sizeof address_);
memset(port_str_, 0, sizeof port_str_);
}
NetworkEndpoint::NetworkEndpoint(const char* addr, const char* port) {
if (addr == nullptr) throw NetworkEndpointException("Address can't be null!");
if (port == nullptr) throw NetworkEndpointException("Port can't be null!");
// strncpy isn't used because it does not guarantee an ending null terminator
snprintf(address_, sizeof address_, "%s", addr);
snprintf(port_str_, sizeof port_str_, "%s", port);
is_address_valid();
int ret = sscanf(port, "%hu", &port_);
if (ret != 1) throw NetworkEndpointException("Port isn't valid!");
}
NetworkEndpoint::NetworkEndpoint(const std::string& addr,
const std::string& port)
: NetworkEndpoint(addr.c_str(), port.c_str()) {}
NetworkEndpoint::NetworkEndpoint(const char* addr, unsigned short port) {
if (addr == nullptr) throw NetworkEndpointException("Address can't be null!");
snprintf(address_, sizeof address_, "%s", addr);
snprintf(port_str_, sizeof port_str_, "%hu", port);
port_ = port;
is_address_valid();
}
void NetworkEndpoint::is_address_valid() {
in_addr addr4;
in6_addr addr6;
int ret = inet_pton(AF_INET, address_, &addr4);
if (ret != 1) {
ret = inet_pton(AF_INET6, address_, &addr6);
if (ret != 1)
throw NetworkEndpointException(
"Address isn't a valid IPv4 or IPv6 address!");
else
family_ = 6;
} else
family_ = 4;
}
const char* NetworkEndpoint::address() { return address_; }
const char* NetworkEndpoint::port_str() { return port_str_; }
unsigned short NetworkEndpoint::port() { return port_; }
unsigned char NetworkEndpoint::family() { return family_; }
}

View File

@ -0,0 +1,40 @@
#pragma once
#include "utils/exceptions/basic_exception.hpp"
#include <netinet/in.h>
#include <string>
namespace io::network {
class NetworkEndpointException : public BasicException {
public:
using BasicException::BasicException;
};
/**
* This class represents a network endpoint that is used in Socket.
* It is used when connecting to an address and to get the current
* connection address.
*/
class NetworkEndpoint {
public:
NetworkEndpoint();
NetworkEndpoint(const char* addr, const char* port);
NetworkEndpoint(const char* addr, unsigned short port);
NetworkEndpoint(const std::string& addr, const std::string& port);
const char* address();
const char* port_str();
unsigned short port();
unsigned char family();
private:
void is_address_valid();
char address_[INET6_ADDRSTRLEN];
char port_str_[6];
unsigned short port_;
unsigned char family_;
};
}

View File

@ -4,7 +4,7 @@
#include "utils/exceptions/stacktrace_exception.hpp"
namespace io {
namespace io::network {
class NetworkError : public StacktraceException {
public:

View File

@ -1,63 +0,0 @@
#pragma once
#include "io/network/socket.hpp"
#include "tls.hpp"
#include "tls_error.hpp"
#include "utils/types/byte.hpp"
#include <iostream>
namespace io {
class SecureSocket {
public:
SecureSocket(Socket&& socket, const Tls::Context& tls)
: socket(std::forward<Socket>(socket)) {
ssl = SSL_new(tls);
SSL_set_fd(ssl, this->socket);
SSL_set_accept_state(ssl);
if (SSL_accept(ssl) <= 0) ERR_print_errors_fp(stderr);
}
SecureSocket(SecureSocket&& other) {
*this = std::forward<SecureSocket>(other);
}
SecureSocket& operator=(SecureSocket&& other) {
socket = std::move(other.socket);
ssl = other.ssl;
other.ssl = nullptr;
return *this;
}
~SecureSocket() {
if (ssl == nullptr) return;
std::cout << "DELETING SSL" << std::endl;
SSL_free(ssl);
}
int error(int status) { return SSL_get_error(ssl, status); }
int write(const std::string& str) { return write(str.c_str(), str.size()); }
int write(const byte* data, size_t len) { return SSL_write(ssl, data, len); }
int write(const char* data, size_t len) { return SSL_write(ssl, data, len); }
int read(char* buffer, size_t len) { return SSL_read(ssl, buffer, len); }
operator int() { return socket; }
operator Socket&() { return socket; }
private:
Socket socket;
SSL* ssl{nullptr};
};
}

View File

@ -1,51 +0,0 @@
#pragma once
#include <openssl/ssl.h>
#include "io/network/stream_reader.hpp"
#include "logging/default.hpp"
namespace io {
using namespace memory::literals;
template <class Derived, class Stream>
class SecureStreamReader : public StreamReader<Derived, Stream> {
public:
struct Buffer {
char* ptr;
size_t len;
};
SecureStreamReader(uint32_t flags = 0)
: StreamReader<Derived, Stream>(flags) {}
void on_data(Stream& stream) {
while (true) {
// allocate the buffer to fill the data
auto buf = this->derived().on_alloc(stream);
// read from the buffer at most buf.len bytes
auto len = stream.socket.read(buf.ptr, buf.len);
if (LIKELY(len > 0)) {
buf.len = len;
return this->derived().on_read(stream, buf);
}
auto err = stream.socket.error(len);
// the socket is not ready for reading yet
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE ||
err == SSL_ERROR_WANT_X509_LOOKUP) {
return;
}
// the socket notified a close event
if (err == SSL_ERROR_ZERO_RETURN) return stream.close();
// some other error occurred, check errno
return this->derived().on_error(stream);
}
}
};
}

View File

@ -1,43 +0,0 @@
#pragma once
#include "io/network/stream_reader.hpp"
namespace io {
template <class Derived>
class Server : public EventListener<Derived> {
public:
Server(Socket &&socket)
: socket(std::forward<Socket>(socket)),
logger(logging::log->logger("io::Server")) {
event.data.fd = this->socket;
// TODO: EPOLLET is hard to use -> figure out how should EPOLLET be used
// event.events = EPOLLIN | EPOLLET;
event.events = EPOLLIN;
this->listener.add(this->socket, &event);
}
void on_close_event(Epoll::Event &event) { ::close(event.data.fd); }
void on_error_event(Epoll::Event &event) { ::close(event.data.fd); }
void on_data_event(Epoll::Event &event) {
if (UNLIKELY(socket != event.data.fd)) return;
this->derived().on_connect();
}
template <class... Args>
void on_exception_event(Epoll::Event &event, Args &&... args) {
// TODO: Do something about it
logger.warn("epoll exception");
}
protected:
Epoll::Event event;
Socket socket;
Logger logger;
};
}

194
src/io/network/socket.cpp Normal file
View File

@ -0,0 +1,194 @@
#include "io/network/socket.hpp"
#include <cassert>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/types.h>
#include <unistd.h>
#include "io/network/addrinfo.hpp"
#include "utils/likely.hpp"
namespace io::network {
Socket::Socket() : socket_(-1) {}
Socket::Socket(int sock, NetworkEndpoint& endpoint)
: socket_(sock), endpoint_(endpoint) {}
Socket::Socket(const Socket& s) : socket_(s.id()) {}
Socket::Socket(Socket&& other) { *this = std::forward<Socket>(other); }
Socket& Socket::operator=(Socket&& other) {
socket_ = other.socket_;
endpoint_ = other.endpoint_;
other.socket_ = -1;
return *this;
}
Socket::~Socket() {
if (socket_ == -1) return;
close(socket_);
}
void Socket::Close() {
if (socket_ == -1) return;
close(socket_);
socket_ = -1;
}
bool Socket::IsOpen() { return socket_ != -1; }
bool Socket::Connect(NetworkEndpoint& endpoint) {
if (UNLIKELY(socket_ != -1)) return false;
auto info = AddrInfo::Get(endpoint.address(), endpoint.port_str());
for (struct addrinfo* it = info; it != nullptr; it = it->ai_next) {
int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
if (sfd == -1) continue;
if (connect(sfd, it->ai_addr, it->ai_addrlen) == 0) {
socket_ = sfd;
endpoint_ = endpoint;
break;
}
}
if (socket_ == -1) return false;
return true;
}
bool Socket::Bind(NetworkEndpoint& endpoint) {
if (UNLIKELY(socket_ != -1)) return false;
auto info = AddrInfo::Get(endpoint.address(), endpoint.port_str());
for (struct addrinfo* it = info; it != nullptr; it = it->ai_next) {
int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
if (sfd == -1) continue;
int on = 1;
if (setsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) != 0)
continue;
if (bind(sfd, it->ai_addr, it->ai_addrlen) == 0) {
socket_ = sfd;
endpoint_ = endpoint;
break;
}
}
if (socket_ == -1) return false;
return true;
}
bool Socket::SetNonBlocking() {
int flags = fcntl(socket_, F_GETFL, 0);
if (UNLIKELY(flags == -1)) return false;
flags |= O_NONBLOCK;
int ret = fcntl(socket_, F_SETFL, flags);
if (UNLIKELY(ret == -1)) return false;
return true;
}
bool Socket::SetKeepAlive() {
int optval = 1;
socklen_t optlen = sizeof(optval);
if (setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, optlen) < 0)
return false;
optval = 20; // wait 120s before seding keep-alive packets
if (setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void*)&optval, optlen) < 0)
return false;
optval = 4; // 4 keep-alive packets must fail to close
if (setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void*)&optval, optlen) < 0)
return false;
optval = 15; // send keep-alive packets every 15s
if (setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void*)&optval, optlen) < 0)
return false;
return true;
}
bool Socket::Listen(int backlog) { return listen(socket_, backlog) == 0; }
bool Socket::Accept(Socket* s) {
sockaddr_storage addr;
socklen_t addr_size = sizeof addr;
char addr_decoded[INET6_ADDRSTRLEN];
void* addr_src;
unsigned short port;
unsigned char family;
int sfd = accept(socket_, (struct sockaddr*)&addr, &addr_size);
if (UNLIKELY(sfd == -1)) return false;
if (addr.ss_family == AF_INET) {
addr_src = (void*)&(((sockaddr_in*)&addr)->sin_addr);
port = ntohs(((sockaddr_in*)&addr)->sin_port);
family = 4;
} else {
addr_src = (void*)&(((sockaddr_in6*)&addr)->sin6_addr);
port = ntohs(((sockaddr_in6*)&addr)->sin6_port);
family = 6;
}
inet_ntop(addr.ss_family, addr_src, addr_decoded, INET6_ADDRSTRLEN);
NetworkEndpoint endpoint;
try {
endpoint = NetworkEndpoint(addr_decoded, port);
} catch (NetworkEndpointException& e) {
return false;
}
*s = Socket(sfd, endpoint);
return true;
}
Socket::operator int() { return socket_; }
int Socket::id() const { return socket_; }
NetworkEndpoint& Socket::endpoint() { return endpoint_; }
bool Socket::Write(const std::string& str) {
return Write(str.c_str(), str.size());
}
bool Socket::Write(const char* data, size_t len) {
return Write(reinterpret_cast<const uint8_t*>(data), len);
}
bool Socket::Write(const uint8_t* data, size_t len) {
while (len > 0) {
auto written = send(socket_, data, len, 0);
if (UNLIKELY(written == -1)) return false;
len -= written;
data += written;
}
return true;
}
int Socket::Read(void* buffer, size_t len) {
return read(socket_, buffer, len);
}
}

View File

@ -1,157 +1,150 @@
#pragma once
#include <cstdio>
#include <cstring>
#include <stdexcept>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include "io/network/addrinfo.hpp"
#include "utils/likely.hpp"
#include "logging/default.hpp"
#include "io/network/network_endpoint.hpp"
#include <iostream>
namespace io {
namespace io::network {
/**
* This class creates a network socket.
* It is used to connect/bind/listen on a NetworkEndpoint (address + port).
* It has wrappers for setting network socket flags and wrappers for
* reading/writing data from/to the socket.
*/
class Socket {
protected:
Socket(int family, int socket_type, int protocol) {
socket = ::socket(family, socket_type, protocol);
}
public:
using byte = uint8_t;
Socket();
Socket(const Socket& s);
Socket(Socket&& other);
Socket& operator=(Socket&& other);
~Socket();
Socket(int socket = -1) : socket(socket) {}
/**
* Closes the socket if it is open.
*/
void Close();
Socket(const Socket&) = delete;
/**
* Checks whether the socket is open.
*
* @return socket open status:
* true if the socket is open
* false if the socket is closed
*/
bool IsOpen();
Socket(Socket&& other) { *this = std::forward<Socket>(other); }
/**
* Connects the socket to the specified endpoint.
*
* @param endpoint NetworkEndpoint to which to connect to
*
* @return connection success status:
* true if the connect succeeded
* false if the connect failed
*/
bool Connect(NetworkEndpoint& endpoint);
~Socket() {
if (socket == -1) return;
/**
* Binds the socket to the specified endpoint.
*
* @param endpoint NetworkEndpoint to which to bind to
*
* @return bind success status:
* true if the bind succeeded
* false if the bind failed
*/
bool Bind(NetworkEndpoint& endpoint);
#ifndef NDEBUG
logging::debug("DELETING SOCKET");
#endif
/**
* Start listening on the bound socket.
*
* @param backlog maximum number of pending connections in the connection queue
*
* @return listen success status:
* true if the listen succeeded
* false if the listen failed
*/
bool Listen(int backlog);
::close(socket);
}
/**
* Accepts a new connection.
* This function accepts a new connection on a listening socket.
*
* @param s Socket object that will be instantiated with the new connection
*
* @return accept success status:
* true if a new connection was accepted and the socket 's' was instantiated
* false if a new connection accept failed
*/
bool Accept(Socket* s);
void close() {
::close(socket);
socket = -1;
}
/**
* Sets the socket to non-blocking.
*
* @return set non-blocking success status:
* true if the socket was successfully set to non-blocking
* false if the socket was not set to non-blocking
*/
bool SetNonBlocking();
Socket& operator=(Socket&& other) {
this->socket = other.socket;
other.socket = -1;
return *this;
}
/**
* Enables TCP keep-alive on the socket.
*
* @return enable keep-alive success status:
* true if keep-alive was successfully enabled on the socket
* false if keep-alive was not enabled
*/
bool SetKeepAlive();
bool is_open() { return socket != -1; }
// TODO: this will be removed
operator int();
static Socket connect(const std::string& addr, const std::string& port) {
return connect(addr.c_str(), port.c_str());
}
/**
* Returns the socket ID.
* The socket ID is its unix file descriptor number.
*/
int id() const;
static Socket connect(const char* addr, const char* port) {
auto info = AddrInfo::get(addr, port);
/**
* Returns the currently active endpoint of the socket.
*/
NetworkEndpoint& endpoint();
for (struct addrinfo* it = info; it != nullptr; it = it->ai_next) {
auto s = Socket(it->ai_family, it->ai_socktype, it->ai_protocol);
/**
* Write data to the socket.
* Theese functions guarantee that all data will be written.
*
* @param str std::string to write to the socket
* @param data char* or uint8_t* to data that should be written
* @param len length of char* or uint8_t* data
*
* @return write success status:
* true if write succeeded
* false if write failed
*/
bool Write(const std::string& str);
bool Write(const char* data, size_t len);
bool Write(const uint8_t* data, size_t len);
if (!s.is_open()) continue;
/**
* Read data from the socket.
* This function is a direct wrapper for the read function.
*
* @param buffer pointer to the read buffer
* @param len length of the read buffer
*
* @return read success status:
* > 0 if data was read, means number of read bytes
* == 0 if the client closed the connection
* < 0 if an error has occurred
*/
int Read(void* buffer, size_t len);
if (::connect(s, it->ai_addr, it->ai_addrlen) == 0) return s;
}
private:
Socket(int sock, NetworkEndpoint& endpoint);
throw NetworkError("Unable to connect to socket");
}
static Socket bind(const std::string& addr, const std::string& port) {
return bind(addr.c_str(), port.c_str());
}
static Socket bind(const char* addr, const char* port) {
auto info = AddrInfo::get(addr, port);
for (struct addrinfo* it = info; it != nullptr; it = it->ai_next) {
auto s = Socket(it->ai_family, it->ai_socktype, it->ai_protocol);
if (!s.is_open()) continue;
int on = 1;
if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on))) continue;
if (::bind(s, it->ai_addr, it->ai_addrlen) == 0) return s;
}
throw NetworkError("Unable to bind to socket");
}
void set_non_blocking() {
auto flags = fcntl(socket, F_GETFL, 0);
if (UNLIKELY(flags == -1))
throw NetworkError("Cannot read flags from socket");
flags |= O_NONBLOCK;
auto status = fcntl(socket, F_SETFL, flags);
if (UNLIKELY(status == -1))
throw NetworkError("Cannot set NON_BLOCK flag to socket");
}
void listen(int backlog) {
auto status = ::listen(socket, backlog);
if (UNLIKELY(status == -1)) throw NetworkError("Cannot listen on socket");
}
Socket accept(struct sockaddr* addr, socklen_t* len) {
return Socket(::accept(socket, addr, len));
}
operator int() { return socket; }
int id() const { return socket; }
int write(const std::string& str) { return write(str.c_str(), str.size()); }
int write(const char* data, size_t len) {
return write(reinterpret_cast<const byte*>(data), len);
}
int write(const byte* data, size_t len) {
// TODO: use logger
#ifndef NDEBUG
std::stringstream stream;
for (size_t i = 0; i < len; ++i)
stream << fmt::format("{:02X} ", static_cast<byte>(data[i]));
auto str = stream.str();
logging::debug("[Write {}B] {}", len, str);
#endif
return ::write(socket, data, len);
}
int read(void* buffer, size_t len) { return ::read(socket, buffer, len); }
protected:
Logger logger;
int socket;
int socket_;
NetworkEndpoint endpoint_;
};
}

View File

@ -2,7 +2,7 @@
#include "io/network/event_listener.hpp"
namespace io {
namespace io::network {
template <class Derived, class Stream, size_t max_events = 64,
int wait_timeout = -1>
@ -10,26 +10,26 @@ class StreamListener : public EventListener<Derived, max_events, wait_timeout> {
public:
using EventListener<Derived, max_events, wait_timeout>::EventListener;
void add(Stream &stream) {
void Add(Stream &stream) {
// add the stream to the event listener
this->listener.add(stream.socket, &stream.event);
this->listener_.Add(stream.socket, &stream.event);
}
void on_close_event(Epoll::Event &event) {
this->derived().on_close(to_stream(event));
void OnCloseEvent(Epoll::Event &event) {
this->derived().OnClose(to_stream(event));
}
void on_error_event(Epoll::Event &event) {
this->derived().on_error(to_stream(event));
void OnErrorEvent(Epoll::Event &event) {
this->derived().OnError(to_stream(event));
}
void on_data_event(Epoll::Event &event) {
this->derived().on_data(to_stream(event));
void OnDataEvent(Epoll::Event &event) {
this->derived().OnData(to_stream(event));
}
template <class... Args>
void on_exception_event(Epoll::Event &event, Args &&... args) {
this->derived().on_exception(to_stream(event), args...);
void OnExceptionEvent(Epoll::Event &event, Args &&... args) {
this->derived().OnException(to_stream(event), args...);
}
private:

View File

@ -3,82 +3,91 @@
#include "io/network/stream_listener.hpp"
#include "memory/literals.hpp"
namespace io {
namespace io::network {
using namespace memory::literals;
struct StreamBuffer {
char* ptr;
size_t len;
};
/**
* This class is used to get data from a socket that has been notified
* with a data available event.
*/
template <class Derived, class Stream>
class StreamReader : public StreamListener<Derived, Stream> {
public:
struct Buffer {
char* ptr;
size_t len;
};
StreamReader(uint32_t flags = 0)
: StreamListener<Derived, Stream>(flags),
logger(logging::log->logger("io::StreamReader")) {}
logger_(logging::log->logger("io::StreamReader")) {}
bool accept(Socket& socket) {
logger.trace("accept");
bool Accept(Socket& socket) {
logger_.trace("Accept");
// accept a connection from a socket
auto s = socket.accept(nullptr, nullptr);
Socket s;
if (!socket.Accept(&s)) return false;
if (!s.is_open()) return false;
logger_.trace(
"Accepted a connection: scoket {}, address '{}', family {}, port {}",
s.id(), s.endpoint().address(), s.endpoint().family(),
s.endpoint().port());
// make the recieved socket non blocking
s.set_non_blocking();
if (!s.SetKeepAlive()) return false;
auto& stream = this->derived().on_connect(std::move(s));
auto& stream = this->derived().OnConnect(std::move(s));
// we want to listen to an incoming event which is edge triggered and
// we also want to listen on the hangup event
stream.event.events = EPOLLIN | EPOLLET | EPOLLRDHUP;
stream.event.events = EPOLLIN | EPOLLRDHUP;
// add the connection to the event listener
this->add(stream);
this->Add(stream);
return true;
}
void on_data(Stream& stream) {
logger.trace("on data");
void OnData(Stream& stream) {
logger_.trace("On data");
while (true) {
if (UNLIKELY(!stream.alive())) {
stream.close();
logger_.trace("Calling OnClose because the stream isn't alive!");
this->derived().OnClose(stream);
break;
}
// allocate the buffer to fill the data
auto buf = this->derived().on_alloc(stream);
auto buf = this->derived().OnAlloc(stream);
// read from the buffer at most buf.len bytes
buf.len = stream.socket.read(buf.ptr, buf.len);
buf.len = stream.socket.Read(buf.ptr, buf.len);
// check for read errors
if (buf.len == -1) {
// this means we have read all available data
if (LIKELY(errno == EAGAIN)) {
if (LIKELY(errno == EAGAIN || errno == EWOULDBLOCK)) {
break;
}
// some other error occurred, check errno
this->derived().on_error(stream);
this->derived().OnError(stream);
break;
}
// end of file, the client has closed the connection
if (UNLIKELY(buf.len == 0)) {
stream.close();
logger_.trace("Calling OnClose because the socket is closed!");
this->derived().OnClose(stream);
break;
}
this->derived().on_read(stream, buf);
this->derived().OnRead(stream, buf);
}
}
private:
Logger logger;
Logger logger_;
};
}

View File

@ -1,29 +0,0 @@
#pragma once
#include "io/network/epoll.hpp"
#include "io/network/socket.hpp"
namespace io {
namespace tcp {
template <class Socket>
class Stream {
public:
Stream(Socket&& socket) : socket(std::move(socket)) {
// save this to epoll event data baton to access later
event.data.ptr = this;
}
Stream(Stream&& stream) {
socket = std::move(stream.socket);
event = stream.event;
event.data.ptr = this;
}
int id() const { return socket.id(); }
Socket socket;
Epoll::Event event;
};
}
}

View File

@ -1,42 +0,0 @@
#include "io/network/tls.hpp"
#include "io/network/tls_error.hpp"
namespace io {
Tls::Context::Context() {
auto method = SSLv23_server_method();
ctx = SSL_CTX_new(method);
if (!ctx) {
ERR_print_errors_fp(stderr);
throw io::TlsError("Unable to create TLS context");
}
SSL_CTX_set_ecdh_auto(ctx, 1);
}
Tls::Context::~Context() { SSL_CTX_free(ctx); }
Tls::Context& Tls::Context::cert(const std::string& path) {
if (SSL_CTX_use_certificate_file(ctx, path.c_str(), SSL_FILETYPE_PEM) >= 0)
return *this;
ERR_print_errors_fp(stderr);
throw TlsError("Error Loading cert '{}'", path);
}
Tls::Context& Tls::Context::key(const std::string& path) {
if (SSL_CTX_use_PrivateKey_file(ctx, path.c_str(), SSL_FILETYPE_PEM) >= 0)
return *this;
ERR_print_errors_fp(stderr);
throw TlsError("Error Loading private key '{}'", path);
}
void Tls::initialize() {
SSL_load_error_strings();
OpenSSL_add_ssl_algorithms();
}
void Tls::cleanup() { EVP_cleanup(); }
}

View File

@ -1,29 +0,0 @@
#pragma once
#include <string>
#include <openssl/err.h>
#include <openssl/ssl.h>
namespace io {
class Tls {
public:
class Context {
public:
Context();
~Context();
Context& cert(const std::string& path);
Context& key(const std::string& path);
operator SSL_CTX*() const { return ctx; }
private:
SSL_CTX* ctx;
};
static void initialize();
static void cleanup();
};
}

View File

@ -1,11 +0,0 @@
#pragma once
#include "utils/exceptions/stacktrace_exception.hpp"
namespace io {
class TlsError : public StacktraceException {
public:
using StacktraceException::StacktraceException;
};
}

View File

@ -1,9 +1,14 @@
#include <signal.h>
#include <iostream>
#include "communication/bolt/v1/server/server.hpp"
#include "communication/bolt/v1/server/worker.hpp"
#include "dbms/dbms.hpp"
#include "query/engine.hpp"
#include "communication/bolt/v1/session.hpp"
#include "communication/server.hpp"
#include "io/network/network_endpoint.hpp"
#include "io/network/network_error.hpp"
#include "io/network/socket.hpp"
#include "logging/default.hpp"
@ -14,7 +19,13 @@
#include "utils/stacktrace/log.hpp"
#include "utils/terminate_handler.hpp"
static bolt::Server<bolt::Worker> *serverptr;
using endpoint_t = io::network::NetworkEndpoint;
using socket_t = io::network::Socket;
using bolt_server_t =
communication::Server<bolt::Session<socket_t>, bolt::RecordStream<socket_t>,
socket_t>;
static bolt_server_t *serverptr;
Logger logger;
@ -60,28 +71,44 @@ int main(int argc, char **argv) {
// register args
CONFIG_REGISTER_ARGS(argc, argv);
// initialize socket
io::Socket socket;
// initialize endpoint
endpoint_t endpoint;
try {
socket = io::Socket::bind(interface, port);
} catch (io::NetworkError e) {
logger.error("Cannot bind to socket on {} at {}", interface, port);
endpoint = endpoint_t(interface, port);
} catch (io::network::NetworkEndpointException &e) {
logger.error("{}", e.what());
std::exit(EXIT_FAILURE);
}
socket.set_non_blocking();
socket.listen(1024);
// initialize socket
socket_t socket;
if (!socket.Bind(endpoint)) {
logger.error("Cannot bind to socket on {} at {}", interface, port);
std::exit(EXIT_FAILURE);
}
if (!socket.SetNonBlocking()) {
logger.error("Cannot set socket to non blocking!");
std::exit(EXIT_FAILURE);
}
if (!socket.Listen(1024)) {
logger.error("Cannot listen on socket!");
std::exit(EXIT_FAILURE);
}
logger.info("Listening on {} at {}", interface, port);
Dbms dbms;
QueryEngine<bolt::RecordStream<socket_t>> query_engine;
// initialize server
bolt::Server<bolt::Worker> server(std::move(socket));
bolt_server_t server(std::move(socket), dbms, query_engine);
serverptr = &server;
// server start with N threads
// TODO: N should be configurable
auto N = std::thread::hardware_concurrency();
logger.info("Starting {} workers", N);
server.start(N);
server.Start(N);
logger.info("Shutting down...");
return EXIT_SUCCESS;

View File

@ -1,6 +1,5 @@
#pragma once
#include "communication/bolt/communication.hpp"
#include "database/graph_db_accessor.hpp"
#include "query/stripped.hpp"

34
tests/client-stress.sh Executable file
View File

@ -0,0 +1,34 @@
#!/bin/bash
NUM_PARALLEL=4
if [[ "$1" != "" ]]; then
NUM_PARALLEL=$1
fi
for i in $( seq 1 $NUM_PARALLEL ); do
echo "CREATE (n {name: 29383}) RETURN n;" | neo4j-client --insecure --username= --password= neo4j://localhost:7687 >/dev/null 2>/dev/null &
done
running="yes"
count=0
while [[ "$running" != "" ]]; do
running=$( pidof neo4j-client )
num=$( echo "$running" | wc -w )
echo "Running clients: $num"
count=$(( count + 1 ))
if [[ $count -gt 5 ]]; then break; fi
sleep 1
done
if [[ "$running" != "" ]]; then
echo "Something went wrong!"
echo "Running PIDs: $running"
echo "Killing leftover clients..."
kill -9 $running >/dev/null 2>/dev/null
wait $running 2>/dev/null
else
echo "All ok!"
fi

View File

@ -0,0 +1,117 @@
#pragma ONCE
#include <array>
#include <cassert>
#include <cstring>
#include <iostream>
#include <vector>
#include "gtest/gtest.h"
#include "logging/default.hpp"
#include "logging/streams/stdout.hpp"
#include "communication/server.hpp"
#include "dbms/dbms.hpp"
#include "io/network/epoll.hpp"
#include "io/network/socket.hpp"
#include "query/engine.hpp"
static constexpr const int SIZE = 60000;
static constexpr const int REPLY = 10;
using endpoint_t = io::network::NetworkEndpoint;
using socket_t = io::network::Socket;
class TestOutputStream {};
class TestSession {
public:
TestSession(socket_t &&socket, Dbms &dbms,
QueryEngine<TestOutputStream> &query_engine)
: socket(std::move(socket)), logger_(logging::log->logger("TestSession")) {
event.data.ptr = this;
}
bool alive() { return socket.IsOpen(); }
int id() const { return socket.id(); }
void execute(const byte *data, size_t len) {
if (size_ == 0) {
size_ = data[0];
size_ <<= 8;
size_ += data[1];
data += 2;
len -= 2;
}
memcpy(buffer_ + have_, data, len);
have_ += len;
if (have_ < size_) return;
for (int i = 0; i < REPLY; ++i)
ASSERT_TRUE(this->socket.Write(buffer_, size_));
have_ = 0;
size_ = 0;
}
void close() {
logger_.trace("Close session!");
this->socket.Close();
}
char buffer_[SIZE * 2];
uint32_t have_, size_;
Logger logger_;
socket_t socket;
io::network::Epoll::Event event;
};
using test_server_t =
communication::Server<TestSession, TestOutputStream, socket_t>;
void server_start(void* serverptr, int num) {
((test_server_t*)serverptr)->Start(num);
}
void client_run(int num, const char* interface, const char* port, const unsigned char* data, int lo, int hi) {
std::stringstream name;
name << "Client " << num;
Logger logger = logging::log->logger(name.str());
unsigned char buffer[SIZE * REPLY], head[2];
int have, read;
endpoint_t endpoint(interface, port);
socket_t socket;
ASSERT_TRUE(socket.Connect(endpoint));
logger.trace("Socket create: {}", socket.id());
for (int len = lo; len <= hi; len += 100) {
have = 0;
head[0] = (len >> 8) & 0xff;
head[1] = len & 0xff;
ASSERT_TRUE(socket.Write(head, 2));
ASSERT_TRUE(socket.Write(data, len));
logger.trace("Socket write: {}", socket.id());
while (have < len * REPLY) {
read = socket.Read(buffer + have, SIZE);
logger.trace("Socket read: {}", socket.id());
if (read == -1) break;
have += read;
}
for (int i = 0; i < REPLY; ++i)
for (int j = 0; j < len; ++j) ASSERT_EQ(buffer[i * len + j], data[j]);
}
logger.trace("Socket done: {}", socket.id());
socket.Close();
}
void initialize_data(unsigned char* data, int size) {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(0, 255);
for (int i = 0; i < size; ++i) {
data[i] = dis(gen);
}
}

View File

@ -0,0 +1,51 @@
#define NDEBUG
#include "network_common.hpp"
static constexpr const char interface[] = "127.0.0.1";
static constexpr const char port[] = "31337";
unsigned char data[SIZE];
test_server_t *serverptr;
TEST(Network, Server) {
// initialize test data
initialize_data(data, SIZE);
// initialize listen socket
endpoint_t endpoint(interface, port);
socket_t socket;
ASSERT_TRUE(socket.Bind(endpoint));
ASSERT_TRUE(socket.SetNonBlocking());
ASSERT_TRUE(socket.Listen(1024));
// initialize server
Dbms dbms;
QueryEngine<TestOutputStream> query_engine;
test_server_t server(std::move(socket), dbms, query_engine);
serverptr = &server;
// start server
int N = std::thread::hardware_concurrency() / 2;
std::thread server_thread(server_start, serverptr, N);
// start clients
std::vector<std::thread> clients;
for (int i = 0; i < N; ++i)
clients.push_back(std::thread(client_run, i, interface, port, data, 30000, SIZE));
// cleanup clients
for (int i = 0; i < N; ++i) clients[i].join();
// stop server
server.Shutdown();
server_thread.join();
}
int main(int argc, char **argv) {
logging::init_async();
logging::log->pipe(std::make_unique<Stdout>());
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -0,0 +1,62 @@
#define NDEBUG
#include <chrono>
#include "network_common.hpp"
static constexpr const char interface[] = "127.0.0.1";
static constexpr const char port[] = "31337";
unsigned char data[SIZE];
test_server_t *serverptr;
using namespace std::chrono_literals;
TEST(Network, SessionLeak) {
// initialize test data
initialize_data(data, SIZE);
// initialize listen socket
endpoint_t endpoint(interface, port);
socket_t socket;
ASSERT_TRUE(socket.Bind(endpoint));
ASSERT_TRUE(socket.SetNonBlocking());
ASSERT_TRUE(socket.Listen(1024));
// initialize server
Dbms dbms;
QueryEngine<TestOutputStream> query_engine;
test_server_t server(std::move(socket), dbms, query_engine);
serverptr = &server;
// start server
std::thread server_thread(server_start, serverptr, 2);
// start clients
int N = 50;
std::vector<std::thread> clients;
int testlen = 3000;
for (int i = 0; i < N; ++i) {
clients.push_back(std::thread(client_run, i, interface, port, data, testlen, testlen));
std::this_thread::sleep_for(10ms);
}
// cleanup clients
for (int i = 0; i < N; ++i) clients[i].join();
std::this_thread::sleep_for(2s);
// stop server
server.Shutdown();
server_thread.join();
}
// run with "valgrind --leak-check=full ./network_session_leak" to check for memory leaks
int main(int argc, char **argv) {
logging::init_sync();
logging::log->pipe(std::make_unique<Stdout>());
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -3,8 +3,9 @@
// the flag is only used in hardcoded queries compilation
// see usage in plan_compiler.hpp
#ifndef HARDCODED_OUTPUT_STREAM
#include "communication/bolt/communication.hpp"
using Stream = communication::OutputStream;
#include "communication/bolt/v1/serialization/record_stream.hpp"
#include "io/network/socket.hpp"
using Stream = bolt::RecordStream<io::network::Socket>;
#else
#include "../stream/print_record_stream.hpp"
using Stream = PrintRecordStream;

117
tests/unit/bolt_session.cpp Normal file
View File

@ -0,0 +1,117 @@
#include <array>
#include <cassert>
#include <cstring>
#include <iostream>
#include <vector>
#include "gtest/gtest.h"
#include "logging/streams/stdout.hpp"
#include "communication/bolt/v1/serialization/record_stream.hpp"
#include "communication/bolt/v1/session.hpp"
#include "dbms/dbms.hpp"
#include "query/engine.hpp"
class TestSocket {
public:
TestSocket(int socket) : socket(socket) {}
TestSocket(const TestSocket& s) : socket(s.id()){};
TestSocket(TestSocket&& other) { *this = std::forward<TestSocket>(other); }
TestSocket& operator=(TestSocket&& other) {
this->socket = other.socket;
other.socket = -1;
return *this;
}
void Close() { socket = -1; }
bool IsOpen() { return socket != -1; }
int id() const { return socket; }
int Write(const std::string& str) { return Write(str.c_str(), str.size()); }
int Write(const char* data, size_t len) {
return Write(reinterpret_cast<const uint8_t*>(data), len);
}
int Write(const uint8_t* data, size_t len) {
for (int i = 0; i < len; ++i) output.push_back(data[i]);
return len;
}
std::vector<uint8_t> output;
protected:
int socket;
};
const uint8_t handshake_req[] =
"\x60\x60\xb0\x17\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
"\x00\x00";
const uint8_t handshake_resp[] = "\x00\x00\x00\x01";
const uint8_t init_req[] =
"\x00\x3f\xb2\x01\xd0\x15\x6c\x69\x62\x6e\x65\x6f\x34\x6a\x2d\x63\x6c\x69"
"\x65\x6e\x74\x2f\x31\x2e\x32\x2e\x31\xa3\x86\x73\x63\x68\x65\x6d\x65\x85"
"\x62\x61\x73\x69\x63\x89\x70\x72\x69\x6e\x63\x69\x70\x61\x6c\x80\x8b\x63"
"\x72\x65\x64\x65\x6e\x74\x69\x61\x6c\x73\x80\x00\x00";
const uint8_t init_resp[] = "\x00\x03\xb1\x70\xa0\x00\x00";
const uint8_t run_req[] =
"\x00\x26\xb2\x10\xd0\x21\x43\x52\x45\x41\x54\x45\x20\x28\x6e\x20\x7b\x6e"
"\x61\x6d\x65\x3a\x20\x32\x39\x33\x38\x33\x7d\x29\x20\x52\x45\x54\x55\x52"
"\x4e\x20\x6e\xa0\x00\x00";
void print_output(std::vector<uint8_t>& output) {
fprintf(stderr, "output: ");
for (int i = 0; i < output.size(); ++i) {
fprintf(stderr, "%02X ", output[i]);
}
fprintf(stderr, "\n");
}
void check_output(std::vector<uint8_t>& output, const uint8_t* data,
uint64_t len) {
EXPECT_EQ(len, output.size());
for (int i = 0; i < len; ++i) {
EXPECT_EQ(output[i], data[i]);
}
output.clear();
}
TEST(Bolt, Session) {
Dbms dbms;
TestSocket socket(10);
QueryEngine<bolt::RecordStream<TestSocket>> query_engine;
bolt::Session<TestSocket> session(std::move(socket), dbms, query_engine);
std::vector<uint8_t>& output = session.socket.output;
// execute handshake
session.execute(handshake_req, 20);
ASSERT_EQ(session.state, bolt::INIT);
print_output(output);
check_output(output, handshake_resp, 4);
// execute init
session.execute(init_req, 67);
ASSERT_EQ(session.state, bolt::EXECUTOR);
print_output(output);
check_output(output, init_resp, 7);
// execute run
session.execute(run_req, 42);
// TODO: query engine doesn't currently work,
// we should test the query output here and the next state
// ASSERT_EQ(session.state, bolt::EXECUTOR);
// print_output(output);
// check_output(output, run_resp, len);
// TODO: add more tests
session.close();
}
int main(int argc, char** argv) {
logging::init_sync();
logging::log->pipe(std::make_unique<Stdout>());
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -0,0 +1,84 @@
#include <iostream>
#include "gtest/gtest.h"
#include "io/network/network_endpoint.hpp"
#include "io/network/network_error.hpp"
using endpoint_t = io::network::NetworkEndpoint;
using exception_t = io::network::NetworkEndpointException;
TEST(NetworkEndpoint, IPv4) {
endpoint_t endpoint;
// test first constructor
endpoint = endpoint_t("127.0.0.1", "12345");
EXPECT_STREQ(endpoint.address(), "127.0.0.1");
EXPECT_STREQ(endpoint.port_str(), "12345");
EXPECT_EQ(endpoint.port(), 12345);
EXPECT_EQ(endpoint.family(), 4);
// test second constructor
std::string addr("127.0.0.2"), port("12346");
endpoint = endpoint_t(addr, port);
EXPECT_STREQ(endpoint.address(), "127.0.0.2");
EXPECT_STREQ(endpoint.port_str(), "12346");
EXPECT_EQ(endpoint.port(), 12346);
EXPECT_EQ(endpoint.family(), 4);
// test third constructor
endpoint = endpoint_t("127.0.0.1", 12347);
EXPECT_STREQ(endpoint.address(), "127.0.0.1");
EXPECT_STREQ(endpoint.port_str(), "12347");
EXPECT_EQ(endpoint.port(), 12347);
EXPECT_EQ(endpoint.family(), 4);
// test address null
EXPECT_THROW(endpoint_t(nullptr, nullptr), exception_t);
// test address invalid
EXPECT_THROW(endpoint_t("invalid", "12345"), exception_t);
// test port invalid
EXPECT_THROW(endpoint_t("127.0.0.1", "invalid"), exception_t);
}
TEST(NetworkEndpoint, IPv6) {
endpoint_t endpoint;
// test first constructor
endpoint = endpoint_t("ab:cd:ef::1", "12345");
EXPECT_STREQ(endpoint.address(), "ab:cd:ef::1");
EXPECT_STREQ(endpoint.port_str(), "12345");
EXPECT_EQ(endpoint.port(), 12345);
EXPECT_EQ(endpoint.family(), 6);
// test second constructor
std::string addr("ab:cd:ef::2"), port("12346");
endpoint = endpoint_t(addr, port);
EXPECT_STREQ(endpoint.address(), "ab:cd:ef::2");
EXPECT_STREQ(endpoint.port_str(), "12346");
EXPECT_EQ(endpoint.port(), 12346);
EXPECT_EQ(endpoint.family(), 6);
// test third constructor
endpoint = endpoint_t("ab:cd:ef::3", 12347);
EXPECT_STREQ(endpoint.address(), "ab:cd:ef::3");
EXPECT_STREQ(endpoint.port_str(), "12347");
EXPECT_EQ(endpoint.port(), 12347);
EXPECT_EQ(endpoint.family(), 6);
// test address null
EXPECT_THROW(endpoint_t(nullptr, nullptr), exception_t);
// test address invalid
EXPECT_THROW(endpoint_t("::g", "12345"), exception_t);
// test port invalid
EXPECT_THROW(endpoint_t("::1", "invalid"), exception_t);
}
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}