Started network refactorization.
Summary: Moved server and worker from bolt to communication. Started templatization. Started removal of Bolt class. Removed unnecessary files from network. Converted states to template functions. Bolt::Session is now a template. Merge remote-tracking branch 'origin/dev' into mg_refactor_network Merged bolt_serializer.cpp into hpp. Removed obsolete include. Initial version of bolt session unit test. Uncommented leftover log commands. Reimplemented io::Socket. Added client-stress.sh script. Reviewers: dgleich, buda Reviewed By: dgleich, buda Subscribers: pullbot, mferencevic, buda Differential Revision: https://phabricator.memgraph.io/D64
This commit is contained in:
parent
38c3c513fa
commit
813a3b9eed
@ -298,16 +298,11 @@ set(memgraph_src_files
|
||||
${src_dir}/utils/string/join.cpp
|
||||
${src_dir}/utils/string/file.cpp
|
||||
${src_dir}/utils/numerics/saturate.cpp
|
||||
${src_dir}/communication/bolt/v1/bolt.cpp
|
||||
${src_dir}/communication/bolt/v1/states.cpp
|
||||
${src_dir}/communication/bolt/v1/session.cpp
|
||||
${src_dir}/communication/bolt/v1/states/error.cpp
|
||||
${src_dir}/communication/bolt/v1/states/executor.cpp
|
||||
${src_dir}/communication/bolt/v1/states/init.cpp
|
||||
${src_dir}/communication/bolt/v1/states/handshake.cpp
|
||||
${src_dir}/communication/bolt/v1/transport/bolt_decoder.cpp
|
||||
${src_dir}/communication/bolt/v1/transport/buffer.cpp
|
||||
${src_dir}/communication/bolt/v1/serialization/bolt_serializer.cpp
|
||||
${src_dir}/io/network/addrinfo.cpp
|
||||
${src_dir}/io/network/network_endpoint.cpp
|
||||
${src_dir}/io/network/socket.cpp
|
||||
${src_dir}/threading/thread.cpp
|
||||
${src_dir}/mvcc/id.cpp
|
||||
# ${src_dir}/snapshot/snapshot_engine.cpp
|
||||
|
@ -1,8 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "communication/bolt/v1/serialization/record_stream.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
|
||||
namespace communication {
|
||||
using OutputStream = bolt::RecordStream<io::Socket>;
|
||||
}
|
@ -1,19 +0,0 @@
|
||||
#include "communication/bolt/v1/bolt.hpp"
|
||||
|
||||
#include "communication/bolt/v1/session.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
Bolt::Bolt() {}
|
||||
|
||||
Session* Bolt::create_session(io::Socket&& socket) {
|
||||
// TODO fix session lifecycle handling
|
||||
// dangling pointers are not cool :)
|
||||
|
||||
// TODO attach currently active Db
|
||||
|
||||
return new Session(std::forward<io::Socket>(socket), *this);
|
||||
}
|
||||
|
||||
void Bolt::close(Session* session) { session->socket.close(); }
|
||||
}
|
@ -1,23 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "communication/bolt/v1/states.hpp"
|
||||
#include "dbms/dbms.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
class Session;
|
||||
|
||||
class Bolt {
|
||||
friend class Session;
|
||||
|
||||
public:
|
||||
Bolt();
|
||||
|
||||
Session *create_session(io::Socket &&socket);
|
||||
void close(Session *session);
|
||||
|
||||
States states;
|
||||
Dbms dbms;
|
||||
};
|
||||
}
|
@ -1,104 +0,0 @@
|
||||
#include "communication/bolt/v1/serialization/bolt_serializer.hpp"
|
||||
|
||||
#include "communication/bolt/v1/transport/chunked_buffer.hpp"
|
||||
#include "communication/bolt/v1/transport/chunked_encoder.hpp"
|
||||
#include "communication/bolt/v1/transport/socket_stream.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
|
||||
#include "database/graph_db.hpp"
|
||||
#include "database/graph_db_accessor.hpp"
|
||||
#include "storage/property_value_store.hpp"
|
||||
#include <cassert>
|
||||
|
||||
template <class Stream>
|
||||
void bolt::BoltSerializer<Stream>::write(const VertexAccessor &vertex) {
|
||||
// write signatures for the node struct and node data type
|
||||
encoder.write_struct_header(3);
|
||||
encoder.write(underlying_cast(pack::Code::Node));
|
||||
|
||||
// IMPORTANT: here we write a hardcoded 0 because we don't
|
||||
// use internal IDs, but need to give something to Bolt
|
||||
// note that OpenCypher has no id(x) function, so the client
|
||||
// should not be able to do anything with this value anyway
|
||||
encoder.write_integer(0); // uID
|
||||
|
||||
// write the list of labels
|
||||
auto labels = vertex.labels();
|
||||
encoder.write_list_header(labels.size());
|
||||
for (auto label : labels)
|
||||
encoder.write_string(vertex.db_accessor().label_name(label));
|
||||
|
||||
// write the properties
|
||||
const PropertyValueStore<GraphDb::Property> &props = vertex.Properties();
|
||||
encoder.write_map_header(props.size());
|
||||
props.Accept([this, &vertex](const GraphDb::Property prop,
|
||||
const PropertyValue &value) {
|
||||
this->encoder.write(vertex.db_accessor().property_name(prop));
|
||||
this->write(value);
|
||||
});
|
||||
}
|
||||
|
||||
template <class Stream>
|
||||
void bolt::BoltSerializer<Stream>::write(const EdgeAccessor &edge) {
|
||||
// write signatures for the edge struct and edge data type
|
||||
encoder.write_struct_header(5);
|
||||
encoder.write(underlying_cast(pack::Code::Relationship));
|
||||
|
||||
// IMPORTANT: here we write a hardcoded 0 because we don't
|
||||
// use internal IDs, but need to give something to Bolt
|
||||
// note that OpenCypher has no id(x) function, so the client
|
||||
// should not be able to do anything with this value anyway
|
||||
encoder.write_integer(0);
|
||||
encoder.write_integer(0);
|
||||
encoder.write_integer(0);
|
||||
|
||||
// write the type of the edge
|
||||
encoder.write(edge.db_accessor().edge_type_name(edge.edge_type()));
|
||||
|
||||
// write the property map
|
||||
const PropertyValueStore<GraphDb::Property> &props = edge.Properties();
|
||||
encoder.write_map_header(props.size());
|
||||
props.Accept(
|
||||
[this, &edge](GraphDb::Property prop, const PropertyValue &value) {
|
||||
this->encoder.write(edge.db_accessor().property_name(prop));
|
||||
this->write(value);
|
||||
});
|
||||
}
|
||||
|
||||
template <class Stream>
|
||||
void bolt::BoltSerializer<Stream>::write(const PropertyValue &value) {
|
||||
switch (value.type()) {
|
||||
case PropertyValue::Type::Null:
|
||||
encoder.write_null();
|
||||
return;
|
||||
case PropertyValue::Type::Bool:
|
||||
encoder.write_bool(value.Value<bool>());
|
||||
return;
|
||||
case PropertyValue::Type::String:
|
||||
encoder.write_string(value.Value<std::string>());
|
||||
return;
|
||||
case PropertyValue::Type::Int:
|
||||
encoder.write_integer(value.Value<int>());
|
||||
return;
|
||||
case PropertyValue::Type::Double:
|
||||
encoder.write_double(value.Value<double>());
|
||||
return;
|
||||
case PropertyValue::Type::List:
|
||||
// Not implemented
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
template <class Stream>
|
||||
void bolt::BoltSerializer<Stream>::write_failure(
|
||||
const std::map<std::string, std::string> &data) {
|
||||
encoder.message_failure();
|
||||
encoder.write_map_header(data.size());
|
||||
for (auto const &kv : data) {
|
||||
write(kv.first);
|
||||
write(kv.second);
|
||||
}
|
||||
}
|
||||
|
||||
template class bolt::BoltSerializer<bolt::BoltEncoder<
|
||||
bolt::ChunkedEncoder<bolt::ChunkedBuffer<bolt::SocketStream<io::Socket>>>>>;
|
@ -2,11 +2,12 @@
|
||||
|
||||
#include "communication/bolt/v1/packing/codes.hpp"
|
||||
#include "communication/bolt/v1/transport/bolt_encoder.hpp"
|
||||
#include "communication/bolt/v1/transport/chunked_buffer.hpp"
|
||||
#include "communication/bolt/v1/transport/chunked_encoder.hpp"
|
||||
|
||||
#include "storage/edge_accessor.hpp"
|
||||
#include "storage/vertex_accessor.hpp"
|
||||
|
||||
#include "storage/property_value.hpp"
|
||||
#include "database/graph_db.hpp"
|
||||
#include "database/graph_db_accessor.hpp"
|
||||
#include "storage/property_value_store.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
@ -24,7 +25,32 @@ class BoltSerializer {
|
||||
* }
|
||||
*
|
||||
*/
|
||||
void write(const VertexAccessor &vertex);
|
||||
void write(const VertexAccessor &vertex) {
|
||||
// write signatures for the node struct and node data type
|
||||
encoder.write_struct_header(3);
|
||||
encoder.write(underlying_cast(pack::Code::Node));
|
||||
|
||||
// IMPORTANT: here we write a hardcoded 0 because we don't
|
||||
// use internal IDs, but need to give something to Bolt
|
||||
// note that OpenCypher has no id(x) function, so the client
|
||||
// should not be able to do anything with this value anyway
|
||||
encoder.write_integer(0); // uID
|
||||
|
||||
// write the list of labels
|
||||
auto labels = vertex.labels();
|
||||
encoder.write_list_header(labels.size());
|
||||
for (auto label : labels)
|
||||
encoder.write_string(vertex.db_accessor().label_name(label));
|
||||
|
||||
// write the properties
|
||||
const PropertyValueStore<GraphDb::Property> &props = vertex.Properties();
|
||||
encoder.write_map_header(props.size());
|
||||
props.Accept([this, &vertex](const GraphDb::Property prop,
|
||||
const PropertyValue &value) {
|
||||
this->encoder.write(vertex.db_accessor().property_name(prop));
|
||||
this->write(value);
|
||||
});
|
||||
}
|
||||
|
||||
/** Serializes the edge accessor into the packstream format
|
||||
*
|
||||
@ -40,17 +66,69 @@ class BoltSerializer {
|
||||
* }
|
||||
*
|
||||
*/
|
||||
void write(const EdgeAccessor &edge);
|
||||
void write(const EdgeAccessor &edge) {
|
||||
// write signatures for the edge struct and edge data type
|
||||
encoder.write_struct_header(5);
|
||||
encoder.write(underlying_cast(pack::Code::Relationship));
|
||||
|
||||
// IMPORTANT: here we write a hardcoded 0 because we don't
|
||||
// use internal IDs, but need to give something to Bolt
|
||||
// note that OpenCypher has no id(x) function, so the client
|
||||
// should not be able to do anything with this value anyway
|
||||
encoder.write_integer(0);
|
||||
encoder.write_integer(0);
|
||||
encoder.write_integer(0);
|
||||
|
||||
// write the type of the edge
|
||||
encoder.write(edge.db_accessor().edge_type_name(edge.edge_type()));
|
||||
|
||||
// write the property map
|
||||
const PropertyValueStore<GraphDb::Property> &props = edge.Properties();
|
||||
encoder.write_map_header(props.size());
|
||||
props.Accept(
|
||||
[this, &edge](GraphDb::Property prop, const PropertyValue &value) {
|
||||
this->encoder.write(edge.db_accessor().property_name(prop));
|
||||
this->write(value);
|
||||
});
|
||||
}
|
||||
|
||||
// TODO document
|
||||
void write_failure(const std::map<std::string, std::string> &data);
|
||||
void write_failure(const std::map<std::string, std::string> &data) {
|
||||
encoder.message_failure();
|
||||
encoder.write_map_header(data.size());
|
||||
for (auto const &kv : data) {
|
||||
write(kv.first);
|
||||
write(kv.second);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes a PropertyValue (typically a property value in the edge or vertex).
|
||||
*
|
||||
* @param value The value to write.
|
||||
*/
|
||||
void write(const PropertyValue &value);
|
||||
void write(const PropertyValue &value) {
|
||||
switch (value.type()) {
|
||||
case PropertyValue::Type::Null:
|
||||
encoder.write_null();
|
||||
return;
|
||||
case PropertyValue::Type::Bool:
|
||||
encoder.write_bool(value.Value<bool>());
|
||||
return;
|
||||
case PropertyValue::Type::String:
|
||||
encoder.write_string(value.Value<std::string>());
|
||||
return;
|
||||
case PropertyValue::Type::Int:
|
||||
encoder.write_integer(value.Value<int>());
|
||||
return;
|
||||
case PropertyValue::Type::Double:
|
||||
encoder.write_double(value.Value<double>());
|
||||
return;
|
||||
case PropertyValue::Type::List:
|
||||
// Not implemented
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
Stream &encoder;
|
||||
|
@ -3,7 +3,6 @@
|
||||
#include "communication/bolt/v1/serialization/bolt_serializer.hpp"
|
||||
#include "communication/bolt/v1/transport/chunked_buffer.hpp"
|
||||
#include "communication/bolt/v1/transport/chunked_encoder.hpp"
|
||||
#include "communication/bolt/v1/transport/socket_stream.hpp"
|
||||
|
||||
#include "logging/default.hpp"
|
||||
|
||||
@ -145,13 +144,12 @@ class RecordStream {
|
||||
Logger logger;
|
||||
|
||||
private:
|
||||
using socket_t = SocketStream<Socket>;
|
||||
using buffer_t = ChunkedBuffer<socket_t>;
|
||||
using buffer_t = ChunkedBuffer<Socket>;
|
||||
using chunked_encoder_t = ChunkedEncoder<buffer_t>;
|
||||
using bolt_encoder_t = BoltEncoder<chunked_encoder_t>;
|
||||
using bolt_serializer_t = BoltSerializer<bolt_encoder_t>;
|
||||
|
||||
socket_t socket;
|
||||
Socket &socket;
|
||||
buffer_t chunked_buffer{socket};
|
||||
chunked_encoder_t chunked_encoder{chunked_buffer};
|
||||
bolt_encoder_t bolt_encoder{chunked_encoder};
|
||||
|
@ -1,62 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "communication/bolt/v1/bolt.hpp"
|
||||
#include "io/network/server.hpp"
|
||||
#include "logging/default.hpp"
|
||||
#include "utils/assert.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
template <class Worker>
|
||||
class Server : public io::Server<Server<Worker>> {
|
||||
public:
|
||||
Server(io::Socket&& socket)
|
||||
: io::Server<Server<Worker>>(std::forward<io::Socket>(socket)),
|
||||
logger(logging::log->logger("bolt::Server")) {}
|
||||
|
||||
void start(size_t n) {
|
||||
workers.reserve(n);
|
||||
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
workers.push_back(std::make_shared<Worker>(bolt));
|
||||
workers.back()->start(alive);
|
||||
}
|
||||
|
||||
while (alive) {
|
||||
this->wait_and_process_events();
|
||||
}
|
||||
}
|
||||
|
||||
void shutdown() {
|
||||
alive.store(false);
|
||||
|
||||
for (auto& worker : workers) worker->thread.join();
|
||||
}
|
||||
|
||||
void on_connect() {
|
||||
debug_assert(idx < workers.size(), "Invalid worker id.");
|
||||
|
||||
logger.trace("on connect");
|
||||
|
||||
if (UNLIKELY(!workers[idx]->accept(this->socket))) return;
|
||||
|
||||
idx = idx == workers.size() - 1 ? 0 : idx + 1;
|
||||
}
|
||||
|
||||
void on_wait_timeout() {}
|
||||
|
||||
private:
|
||||
Bolt bolt;
|
||||
|
||||
std::vector<typename Worker::sptr> workers;
|
||||
std::atomic<bool> alive{true};
|
||||
|
||||
int idx{0};
|
||||
Logger logger;
|
||||
};
|
||||
}
|
@ -1,103 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdio>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
#include "communication/bolt/v1/bolt.hpp"
|
||||
#include "communication/bolt/v1/session.hpp"
|
||||
#include "io/network/stream_reader.hpp"
|
||||
#include "logging/default.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
template <class Worker>
|
||||
class Server;
|
||||
|
||||
class Worker : public io::StreamReader<Worker, Session> {
|
||||
friend class bolt::Server<Worker>;
|
||||
|
||||
public:
|
||||
using sptr = std::shared_ptr<Worker>;
|
||||
|
||||
Worker(Bolt &bolt) : bolt(bolt) {
|
||||
logger = logging::log->logger("bolt::Worker");
|
||||
}
|
||||
|
||||
Session &on_connect(io::Socket &&socket) {
|
||||
logger.trace("Accepting connection on socket {}", socket.id());
|
||||
|
||||
return *bolt.get().create_session(std::forward<io::Socket>(socket));
|
||||
}
|
||||
|
||||
void on_error(Session &) {
|
||||
logger.trace("[on_error] errno = {}", errno);
|
||||
|
||||
#ifndef NDEBUG
|
||||
auto err = io::NetworkError("");
|
||||
logger.debug("{}", err.what());
|
||||
#endif
|
||||
|
||||
logger.error("Error occured in this session");
|
||||
}
|
||||
|
||||
void on_wait_timeout() {}
|
||||
|
||||
Buffer on_alloc(Session &) {
|
||||
/* logger.trace("[on_alloc] Allocating {}B", sizeof buf); */
|
||||
|
||||
return Buffer{buf, sizeof buf};
|
||||
}
|
||||
|
||||
void on_read(Session &session, Buffer &buf) {
|
||||
logger.trace("[on_read] Received {}B", buf.len);
|
||||
|
||||
#ifndef NDEBUG
|
||||
std::stringstream stream;
|
||||
|
||||
for (size_t i = 0; i < buf.len; ++i)
|
||||
stream << fmt::format("{:02X} ", static_cast<byte>(buf.ptr[i]));
|
||||
|
||||
logger.trace("[on_read] {}", stream.str());
|
||||
#endif
|
||||
|
||||
try {
|
||||
session.execute(reinterpret_cast<const byte *>(buf.ptr), buf.len);
|
||||
} catch (const std::exception &e) {
|
||||
logger.error("Error occured while executing statement.");
|
||||
logger.error("{}", e.what());
|
||||
// TODO: report to client
|
||||
}
|
||||
}
|
||||
|
||||
void on_close(Session &session) {
|
||||
logger.trace("[on_close] Client closed the connection");
|
||||
session.close();
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
void on_exception(Session &session, Args &&... args) {
|
||||
logger.error("Error occured in this session");
|
||||
logger.error(args...);
|
||||
|
||||
// TODO: Do something about it
|
||||
}
|
||||
|
||||
char buf[65536];
|
||||
|
||||
protected:
|
||||
std::reference_wrapper<Bolt> bolt;
|
||||
|
||||
Logger logger;
|
||||
std::thread thread;
|
||||
|
||||
void start(std::atomic<bool> &alive) {
|
||||
thread = std::thread([&, this]() {
|
||||
while (alive) wait_and_process_events();
|
||||
});
|
||||
}
|
||||
};
|
||||
}
|
@ -1,43 +0,0 @@
|
||||
#include "communication/bolt/v1/session.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
Session::Session(io::Socket &&socket, Bolt &bolt)
|
||||
: Stream(std::forward<io::Socket>(socket)), bolt(bolt) {
|
||||
logger = logging::log->logger("Session");
|
||||
|
||||
// start with a handshake state
|
||||
state = bolt.states.handshake.get();
|
||||
}
|
||||
|
||||
bool Session::alive() const { return state != nullptr; }
|
||||
|
||||
void Session::execute(const byte *data, size_t len) {
|
||||
// mark the end of the message
|
||||
auto end = data + len;
|
||||
|
||||
while (true) {
|
||||
auto size = end - data;
|
||||
|
||||
if (LIKELY(connected)) {
|
||||
logger.debug("Decoding chunk of size {}", size);
|
||||
auto finished = decoder.decode(data, size);
|
||||
|
||||
if (!finished) return;
|
||||
} else {
|
||||
logger.debug("Decoding handshake of size {}", size);
|
||||
decoder.handshake(data, size);
|
||||
}
|
||||
|
||||
state = state->run(*this);
|
||||
decoder.reset();
|
||||
}
|
||||
}
|
||||
|
||||
void Session::close() {
|
||||
logger.debug("Closing session");
|
||||
bolt.close(this);
|
||||
}
|
||||
|
||||
GraphDbAccessor Session::active_db() { return bolt.dbms.active(); }
|
||||
}
|
@ -1,12 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "io/network/epoll.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
#include "io/network/tcp/stream.hpp"
|
||||
|
||||
#include "communication/bolt/communication.hpp"
|
||||
#include "communication/bolt/v1/bolt.hpp"
|
||||
#include "dbms/dbms.hpp"
|
||||
#include "query/engine.hpp"
|
||||
|
||||
#include "communication/bolt/v1/state.hpp"
|
||||
#include "communication/bolt/v1/states/handshake.hpp"
|
||||
#include "communication/bolt/v1/states/init.hpp"
|
||||
#include "communication/bolt/v1/states/executor.hpp"
|
||||
#include "communication/bolt/v1/states/error.hpp"
|
||||
|
||||
#include "communication/bolt/v1/serialization/record_stream.hpp"
|
||||
#include "communication/bolt/v1/states/state.hpp"
|
||||
#include "communication/bolt/v1/transport/bolt_decoder.hpp"
|
||||
#include "communication/bolt/v1/transport/bolt_encoder.hpp"
|
||||
|
||||
@ -14,28 +20,85 @@
|
||||
|
||||
namespace bolt {
|
||||
|
||||
class Session : public io::tcp::Stream<io::Socket> {
|
||||
template<typename Socket>
|
||||
class Session {
|
||||
public:
|
||||
using Decoder = BoltDecoder;
|
||||
using OutputStream = communication::OutputStream;
|
||||
using OutputStream = RecordStream<Socket>;
|
||||
|
||||
Session(io::Socket &&socket, Bolt &bolt);
|
||||
Session(Socket &&socket, Dbms &dbms, QueryEngine<OutputStream> &query_engine)
|
||||
: socket(std::move(socket)),
|
||||
dbms(dbms), query_engine(query_engine),
|
||||
logger(logging::log->logger("Session")) {
|
||||
event.data.ptr = this;
|
||||
// start with a handshake state
|
||||
state = HANDSHAKE;
|
||||
}
|
||||
|
||||
bool alive() const;
|
||||
bool alive() const { return state != NULLSTATE; }
|
||||
|
||||
void execute(const byte *data, size_t len);
|
||||
int id() const { return socket.id(); }
|
||||
|
||||
void close();
|
||||
void execute(const byte *data, size_t len) {
|
||||
// mark the end of the message
|
||||
auto end = data + len;
|
||||
|
||||
Bolt &bolt;
|
||||
while (true) {
|
||||
auto size = end - data;
|
||||
|
||||
GraphDbAccessor active_db();
|
||||
if (LIKELY(connected)) {
|
||||
logger.debug("Decoding chunk of size {}", size);
|
||||
auto finished = decoder.decode(data, size);
|
||||
|
||||
if (!finished) return;
|
||||
} else {
|
||||
logger.debug("Decoding handshake of size {}", size);
|
||||
decoder.handshake(data, size);
|
||||
}
|
||||
|
||||
switch(state) {
|
||||
case HANDSHAKE:
|
||||
logger.debug("Current state: DEBUG");
|
||||
state = state_handshake_run<Socket>(decoder, this->socket, &connected);
|
||||
break;
|
||||
case INIT:
|
||||
logger.debug("Current state: INIT");
|
||||
state = state_init_run<Socket>(output_stream, decoder);
|
||||
break;
|
||||
case EXECUTOR:
|
||||
logger.debug("Current state: EXECUTOR");
|
||||
state = state_executor_run<Socket>(output_stream, decoder, dbms, query_engine);
|
||||
break;
|
||||
case ERROR:
|
||||
logger.debug("Current state: ERROR");
|
||||
state = state_error_run<Socket>(output_stream, decoder);
|
||||
break;
|
||||
case NULLSTATE:
|
||||
break;
|
||||
}
|
||||
|
||||
decoder.reset();
|
||||
}
|
||||
}
|
||||
|
||||
void close() {
|
||||
logger.debug("Closing session");
|
||||
this->socket.Close();
|
||||
}
|
||||
|
||||
Socket socket;
|
||||
io::network::Epoll::Event event;
|
||||
|
||||
Dbms &dbms;
|
||||
QueryEngine<OutputStream> &query_engine;
|
||||
|
||||
GraphDbAccessor active_db() { return dbms.active(); }
|
||||
|
||||
Decoder decoder;
|
||||
OutputStream output_stream{socket};
|
||||
|
||||
bool connected{false};
|
||||
State *state;
|
||||
State state;
|
||||
|
||||
protected:
|
||||
Logger logger;
|
||||
|
13
src/communication/bolt/v1/state.hpp
Normal file
13
src/communication/bolt/v1/state.hpp
Normal file
@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
namespace bolt {
|
||||
|
||||
enum State {
|
||||
HANDSHAKE,
|
||||
INIT,
|
||||
EXECUTOR,
|
||||
ERROR,
|
||||
NULLSTATE
|
||||
};
|
||||
|
||||
}
|
@ -1,16 +0,0 @@
|
||||
#include "communication/bolt/v1/states.hpp"
|
||||
|
||||
#include "communication/bolt/v1/states/error.hpp"
|
||||
#include "communication/bolt/v1/states/executor.hpp"
|
||||
#include "communication/bolt/v1/states/handshake.hpp"
|
||||
#include "communication/bolt/v1/states/init.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
States::States() {
|
||||
handshake = std::make_unique<Handshake>();
|
||||
init = std::make_unique<Init>();
|
||||
executor = std::make_unique<Executor>();
|
||||
error = std::make_unique<Error>();
|
||||
}
|
||||
}
|
@ -1,17 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "communication/bolt/v1/states/state.hpp"
|
||||
#include "logging/log.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
class States {
|
||||
public:
|
||||
States();
|
||||
|
||||
State::uptr handshake;
|
||||
State::uptr init;
|
||||
State::uptr executor;
|
||||
State::uptr error;
|
||||
};
|
||||
}
|
@ -1,47 +0,0 @@
|
||||
#include "communication/bolt/v1/states/error.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
Error::Error() : State(logging::log->logger("Error State")) {}
|
||||
|
||||
State* Error::run(Session& session) {
|
||||
logger.trace("Run");
|
||||
|
||||
session.decoder.read_byte();
|
||||
auto message_type = session.decoder.read_byte();
|
||||
|
||||
logger.trace("Message type byte is: {:02X}", message_type);
|
||||
|
||||
if (message_type == MessageCode::PullAll) {
|
||||
session.output_stream.write_ignored();
|
||||
session.output_stream.chunk();
|
||||
session.output_stream.send();
|
||||
return this;
|
||||
} else if (message_type == MessageCode::AckFailure) {
|
||||
// TODO reset current statement? is it even necessary?
|
||||
logger.trace("AckFailure received");
|
||||
|
||||
session.output_stream.write_success_empty();
|
||||
session.output_stream.chunk();
|
||||
session.output_stream.send();
|
||||
|
||||
return session.bolt.states.executor.get();
|
||||
} else if (message_type == MessageCode::Reset) {
|
||||
// TODO rollback current transaction
|
||||
// discard all records waiting to be sent
|
||||
|
||||
session.output_stream.write_success_empty();
|
||||
session.output_stream.chunk();
|
||||
session.output_stream.send();
|
||||
|
||||
return session.bolt.states.executor.get();
|
||||
}
|
||||
|
||||
// TODO: write this as single call
|
||||
session.output_stream.write_ignored();
|
||||
session.output_stream.chunk();
|
||||
session.output_stream.send();
|
||||
|
||||
return this;
|
||||
}
|
||||
}
|
@ -1,14 +1,53 @@
|
||||
#pragma once
|
||||
|
||||
#include "communication/bolt/v1/session.hpp"
|
||||
#include "communication/bolt/v1/states/state.hpp"
|
||||
#include "communication/bolt/v1/state.hpp"
|
||||
#include "communication/bolt/v1/transport/bolt_decoder.hpp"
|
||||
#include "communication/bolt/v1/serialization/record_stream.hpp"
|
||||
|
||||
#include "logging/default.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
class Error : public State {
|
||||
public:
|
||||
Error();
|
||||
template<typename Socket>
|
||||
State state_error_run(RecordStream<Socket> &output_stream, BoltDecoder &decoder) {
|
||||
Logger logger = logging::log->logger("State ERROR");
|
||||
logger.trace("Run");
|
||||
|
||||
State *run(Session &session) override;
|
||||
};
|
||||
decoder.read_byte();
|
||||
auto message_type = decoder.read_byte();
|
||||
|
||||
logger.trace("Message type byte is: {:02X}", message_type);
|
||||
|
||||
if (message_type == MessageCode::PullAll) {
|
||||
output_stream.write_ignored();
|
||||
output_stream.chunk();
|
||||
output_stream.send();
|
||||
return ERROR;
|
||||
} else if (message_type == MessageCode::AckFailure) {
|
||||
// TODO reset current statement? is it even necessary?
|
||||
logger.trace("AckFailure received");
|
||||
|
||||
output_stream.write_success_empty();
|
||||
output_stream.chunk();
|
||||
output_stream.send();
|
||||
|
||||
return EXECUTOR;
|
||||
} else if (message_type == MessageCode::Reset) {
|
||||
// TODO rollback current transaction
|
||||
// discard all records waiting to be sent
|
||||
|
||||
output_stream.write_success_empty();
|
||||
output_stream.chunk();
|
||||
output_stream.send();
|
||||
|
||||
return EXECUTOR;
|
||||
}
|
||||
|
||||
// TODO: write this as single call
|
||||
output_stream.write_ignored();
|
||||
output_stream.chunk();
|
||||
output_stream.send();
|
||||
|
||||
return ERROR;
|
||||
}
|
||||
}
|
||||
|
@ -1,117 +0,0 @@
|
||||
#include "communication/bolt/v1/states/executor.hpp"
|
||||
#include "communication/bolt/v1/messaging/codes.hpp"
|
||||
#include "database/graph_db_accessor.hpp"
|
||||
#include "query/frontend/opencypher/parser.hpp"
|
||||
|
||||
#ifdef BARRIER
|
||||
#include "barrier/barrier.cpp"
|
||||
#endif
|
||||
|
||||
namespace bolt {
|
||||
|
||||
Executor::Executor() : State(logging::log->logger("Executor")) {}
|
||||
|
||||
State *Executor::run(Session &session) {
|
||||
// just read one byte that represents the struct type, we can skip the
|
||||
// information contained in this byte
|
||||
session.decoder.read_byte();
|
||||
|
||||
logger.debug("Run");
|
||||
|
||||
auto message_type = session.decoder.read_byte();
|
||||
|
||||
if (message_type == MessageCode::Run) {
|
||||
Query q;
|
||||
|
||||
q.statement = session.decoder.read_string();
|
||||
|
||||
// TODO: refactor bolt exception handling (Ferencevic)
|
||||
try {
|
||||
return this->run(session, q);
|
||||
} catch (const frontend::opencypher::SyntaxException &e) {
|
||||
session.output_stream.write_failure(
|
||||
{{"code", "Memgraph.SyntaxException"}, {"message", "Syntax error"}});
|
||||
session.output_stream.send();
|
||||
return session.bolt.states.error.get();
|
||||
} catch (const backend::cpp::GeneratorException &e) {
|
||||
session.output_stream.write_failure(
|
||||
{{"code", "Memgraph.GeneratorException"},
|
||||
{"message", "Unsupported query"}});
|
||||
session.output_stream.send();
|
||||
return session.bolt.states.error.get();
|
||||
} catch (const QueryEngineException &e) {
|
||||
session.output_stream.write_failure(
|
||||
{{"code", "Memgraph.QueryEngineException"},
|
||||
{"message", "Query engine was unable to execute the query"}});
|
||||
session.output_stream.send();
|
||||
return session.bolt.states.error.get();
|
||||
} catch (const StacktraceException &e) {
|
||||
session.output_stream.write_failure(
|
||||
{{"code", "Memgraph.StacktraceException"},
|
||||
{"message", "Unknow exception"}});
|
||||
session.output_stream.send();
|
||||
return session.bolt.states.error.get();
|
||||
} catch (std::exception &e) {
|
||||
session.output_stream.write_failure(
|
||||
{{"code", "Memgraph.Exception"}, {"message", "unknow exception"}});
|
||||
session.output_stream.send();
|
||||
return session.bolt.states.error.get();
|
||||
}
|
||||
} else if (message_type == MessageCode::PullAll) {
|
||||
pull_all(session);
|
||||
} else if (message_type == MessageCode::DiscardAll) {
|
||||
discard_all(session);
|
||||
} else if (message_type == MessageCode::Reset) {
|
||||
// TODO: rollback current transaction
|
||||
// discard all records waiting to be sent
|
||||
return this;
|
||||
} else {
|
||||
logger.error("Unrecognized message recieved");
|
||||
logger.debug("Invalid message type 0x{:02X}", message_type);
|
||||
|
||||
return session.bolt.states.error.get();
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
State *Executor::run(Session &session, Query &query) {
|
||||
logger.trace("[Run] '{}'", query.statement);
|
||||
|
||||
auto db_accessor = session.active_db();
|
||||
logger.debug("[ActiveDB] '{}'", db_accessor.name());
|
||||
|
||||
auto is_successfully_executed =
|
||||
query_engine.Run(query.statement, db_accessor, session.output_stream);
|
||||
|
||||
if (!is_successfully_executed) {
|
||||
session.output_stream.write_failure(
|
||||
{{"code", "Memgraph.QueryExecutionFail"},
|
||||
{"message",
|
||||
"Query execution has failed (probably there is no "
|
||||
"element or there are some problems with concurrent "
|
||||
"access -> client has to resolve problems with "
|
||||
"concurrent access)"}});
|
||||
session.output_stream.send();
|
||||
return session.bolt.states.error.get();
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
void Executor::pull_all(Session &session) {
|
||||
logger.trace("[PullAll]");
|
||||
|
||||
session.output_stream.send();
|
||||
}
|
||||
|
||||
void Executor::discard_all(Session &session) {
|
||||
logger.trace("[DiscardAll]");
|
||||
|
||||
// TODO: discard state
|
||||
|
||||
session.output_stream.write_success();
|
||||
session.output_stream.chunk();
|
||||
session.output_stream.send();
|
||||
}
|
||||
}
|
@ -1,38 +1,113 @@
|
||||
#pragma once
|
||||
|
||||
#include "communication/bolt/v1/session.hpp"
|
||||
#include "communication/bolt/v1/states/state.hpp"
|
||||
#include <string>
|
||||
|
||||
#include "dbms/dbms.hpp"
|
||||
#include "query/engine.hpp"
|
||||
#include "communication/bolt/v1/states/executor.hpp"
|
||||
#include "communication/bolt/v1/messaging/codes.hpp"
|
||||
#include "communication/bolt/v1/state.hpp"
|
||||
|
||||
#include "logging/default.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
class Executor : public State {
|
||||
struct Query {
|
||||
std::string statement;
|
||||
};
|
||||
|
||||
public:
|
||||
Executor();
|
||||
|
||||
State* run(Session& session) override final;
|
||||
|
||||
protected:
|
||||
/* Execute an incoming query
|
||||
*
|
||||
*/
|
||||
State* run(Session& session, Query& query);
|
||||
|
||||
/* Send all remaining results to the client
|
||||
*
|
||||
*/
|
||||
void pull_all(Session& session);
|
||||
|
||||
/* Discard all remaining results
|
||||
*
|
||||
*/
|
||||
void discard_all(Session& session);
|
||||
|
||||
private:
|
||||
QueryEngine<communication::OutputStream> query_engine;
|
||||
struct Query {
|
||||
std::string statement;
|
||||
};
|
||||
|
||||
template<typename Socket>
|
||||
State state_executor_run(RecordStream<Socket> &output_stream, BoltDecoder &decoder, Dbms &dmbs, QueryEngine<RecordStream<Socket>> &query_engine){
|
||||
Logger logger = logging::log->logger("State EXECUTOR");
|
||||
// just read one byte that represents the struct type, we can skip the
|
||||
// information contained in this byte
|
||||
decoder.read_byte();
|
||||
|
||||
logger.debug("Run");
|
||||
|
||||
auto message_type = decoder.read_byte();
|
||||
|
||||
if (message_type == MessageCode::Run) {
|
||||
Query query;
|
||||
|
||||
query.statement = decoder.read_string();
|
||||
|
||||
// TODO (mferencevic): refactor bolt exception handling
|
||||
try {
|
||||
logger.trace("[Run] '{}'", query.statement);
|
||||
|
||||
auto db_accessor = dmbs.active();
|
||||
logger.debug("[ActiveDB] '{}'", db_accessor.name());
|
||||
|
||||
auto is_successfully_executed =
|
||||
query_engine.Run(query.statement, db_accessor, output_stream);
|
||||
|
||||
if (!is_successfully_executed) {
|
||||
output_stream.write_failure(
|
||||
{{"code", "Memgraph.QueryExecutionFail"},
|
||||
{"message",
|
||||
"Query execution has failed (probably there is no "
|
||||
"element or there are some problems with concurrent "
|
||||
"access -> client has to resolve problems with "
|
||||
"concurrent access)"}});
|
||||
output_stream.send();
|
||||
return ERROR;
|
||||
}
|
||||
|
||||
return EXECUTOR;
|
||||
// TODO: RETURN success MAYBE
|
||||
} catch (const frontend::opencypher::SyntaxException &e) {
|
||||
output_stream.write_failure(
|
||||
{{"code", "Memgraph.SyntaxException"}, {"message", "Syntax error"}});
|
||||
output_stream.send();
|
||||
return ERROR;
|
||||
} catch (const backend::cpp::GeneratorException &e) {
|
||||
output_stream.write_failure(
|
||||
{{"code", "Memgraph.GeneratorException"},
|
||||
{"message", "Unsupported query"}});
|
||||
output_stream.send();
|
||||
return ERROR;
|
||||
} catch (const QueryEngineException &e) {
|
||||
output_stream.write_failure(
|
||||
{{"code", "Memgraph.QueryEngineException"},
|
||||
{"message", "Query engine was unable to execute the query"}});
|
||||
output_stream.send();
|
||||
return ERROR;
|
||||
} catch (const StacktraceException &e) {
|
||||
output_stream.write_failure(
|
||||
{{"code", "Memgraph.StacktraceException"},
|
||||
{"message", "Unknow exception"}});
|
||||
output_stream.send();
|
||||
return ERROR;
|
||||
} catch (std::exception &e) {
|
||||
output_stream.write_failure(
|
||||
{{"code", "Memgraph.Exception"}, {"message", "unknow exception"}});
|
||||
output_stream.send();
|
||||
return ERROR;
|
||||
}
|
||||
} else if (message_type == MessageCode::PullAll) {
|
||||
logger.trace("[PullAll]");
|
||||
output_stream.send();
|
||||
} else if (message_type == MessageCode::DiscardAll) {
|
||||
logger.trace("[DiscardAll]");
|
||||
|
||||
// TODO: discard state
|
||||
|
||||
output_stream.write_success();
|
||||
output_stream.chunk();
|
||||
output_stream.send();
|
||||
} else if (message_type == MessageCode::Reset) {
|
||||
// TODO: rollback current transaction
|
||||
// discard all records waiting to be sent
|
||||
return EXECUTOR;
|
||||
} else {
|
||||
logger.error("Unrecognized message recieved");
|
||||
logger.debug("Invalid message type 0x{:02X}", message_type);
|
||||
|
||||
return ERROR;
|
||||
}
|
||||
|
||||
return EXECUTOR;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,27 +0,0 @@
|
||||
#include "communication/bolt/v1/states/handshake.hpp"
|
||||
|
||||
#include "communication/bolt/v1/session.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
static constexpr uint32_t preamble = 0x6060B017;
|
||||
|
||||
static constexpr byte protocol[4] = {0x00, 0x00, 0x00, 0x01};
|
||||
|
||||
Handshake::Handshake() : State(logging::log->logger("Handshake")) {}
|
||||
|
||||
State* Handshake::run(Session& session) {
|
||||
logger.debug("run");
|
||||
|
||||
if (UNLIKELY(session.decoder.read_uint32() != preamble)) return nullptr;
|
||||
|
||||
// TODO so far we only support version 1 of the protocol so it doesn't
|
||||
// make sense to check which version the client prefers
|
||||
// this will change in the future
|
||||
|
||||
session.connected = true;
|
||||
session.socket.write(protocol, sizeof protocol);
|
||||
|
||||
return session.bolt.states.init.get();
|
||||
}
|
||||
}
|
@ -1,12 +1,32 @@
|
||||
#pragma once
|
||||
|
||||
#include "communication/bolt/v1/states/state.hpp"
|
||||
#include "dbms/dbms.hpp"
|
||||
#include "communication/bolt/v1/state.hpp"
|
||||
#include "communication/bolt/v1/transport/bolt_decoder.hpp"
|
||||
|
||||
#include "logging/default.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
class Handshake : public State {
|
||||
public:
|
||||
Handshake();
|
||||
State* run(Session& session) override;
|
||||
};
|
||||
static constexpr uint32_t preamble = 0x6060B017;
|
||||
|
||||
static constexpr byte protocol[4] = {0x00, 0x00, 0x00, 0x01};
|
||||
|
||||
template<typename Socket>
|
||||
State state_handshake_run(BoltDecoder &decoder, Socket &socket_, bool *connected) {
|
||||
Logger logger = logging::log->logger("State HANDSHAKE");
|
||||
logger.debug("run");
|
||||
|
||||
if (UNLIKELY(decoder.read_uint32() != preamble)) return NULLSTATE;
|
||||
|
||||
// TODO so far we only support version 1 of the protocol so it doesn't
|
||||
// make sense to check which version the client prefers
|
||||
// this will change in the future
|
||||
|
||||
*connected = true;
|
||||
// TODO: check for success
|
||||
socket_.Write(protocol, sizeof protocol);
|
||||
|
||||
return INIT;
|
||||
}
|
||||
}
|
||||
|
@ -1,54 +0,0 @@
|
||||
#include "communication/bolt/v1/states/init.hpp"
|
||||
|
||||
#include "communication/bolt/v1/messaging/codes.hpp"
|
||||
#include "communication/bolt/v1/session.hpp"
|
||||
|
||||
#include "utils/likely.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
Init::Init() : MessageParser<Init>(logging::log->logger("Init")) {}
|
||||
|
||||
State *Init::parse(Session &session, Message &message) {
|
||||
logger.debug("bolt::Init.parse()");
|
||||
|
||||
auto struct_type = session.decoder.read_byte();
|
||||
|
||||
if (UNLIKELY((struct_type & 0x0F) > pack::Rule::MaxInitStructSize)) {
|
||||
logger.debug("{}", struct_type);
|
||||
|
||||
logger.debug(
|
||||
"Expected struct marker of max size 0x{:02} instead of 0x{:02X}",
|
||||
(unsigned)pack::Rule::MaxInitStructSize, (unsigned)struct_type);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto message_type = session.decoder.read_byte();
|
||||
|
||||
if (UNLIKELY(message_type != MessageCode::Init)) {
|
||||
logger.debug("Expected Init (0x01) instead of (0x{:02X})",
|
||||
(unsigned)message_type);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
message.client_name = session.decoder.read_string();
|
||||
|
||||
if (struct_type == pack::Code::StructTwo) {
|
||||
// TODO process authentication tokens
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
State *Init::execute(Session &session, Message &message) {
|
||||
logger.debug("Client connected '{}'", message.client_name);
|
||||
|
||||
session.output_stream.write_success_empty();
|
||||
session.output_stream.chunk();
|
||||
session.output_stream.send();
|
||||
|
||||
return session.bolt.states.executor.get();
|
||||
}
|
||||
}
|
@ -1,18 +1,55 @@
|
||||
#pragma once
|
||||
|
||||
#include "communication/bolt/v1/states/message_parser.hpp"
|
||||
#include "communication/bolt/v1/state.hpp"
|
||||
#include "communication/bolt/v1/transport/bolt_decoder.hpp"
|
||||
#include "communication/bolt/v1/serialization/record_stream.hpp"
|
||||
#include "communication/bolt/v1/messaging/codes.hpp"
|
||||
|
||||
#include "logging/default.hpp"
|
||||
#include "utils/likely.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
class Init : public MessageParser<Init> {
|
||||
public:
|
||||
struct Message {
|
||||
std::string client_name;
|
||||
};
|
||||
template<typename Socket>
|
||||
State state_init_run(RecordStream<Socket> &output_stream, BoltDecoder &decoder) {
|
||||
Logger logger = logging::log->logger("State INIT");
|
||||
logger.debug("Parsing message");
|
||||
|
||||
Init();
|
||||
auto struct_type = decoder.read_byte();
|
||||
|
||||
if (UNLIKELY((struct_type & 0x0F) > pack::Rule::MaxInitStructSize)) {
|
||||
logger.debug("{}", struct_type);
|
||||
|
||||
logger.debug(
|
||||
"Expected struct marker of max size 0x{:02} instead of 0x{:02X}",
|
||||
(unsigned)pack::Rule::MaxInitStructSize, (unsigned)struct_type);
|
||||
|
||||
return NULLSTATE;
|
||||
}
|
||||
|
||||
auto message_type = decoder.read_byte();
|
||||
|
||||
if (UNLIKELY(message_type != MessageCode::Init)) {
|
||||
logger.debug("Expected Init (0x01) instead of (0x{:02X})",
|
||||
(unsigned)message_type);
|
||||
|
||||
return NULLSTATE;
|
||||
}
|
||||
|
||||
auto client_name = decoder.read_string();
|
||||
|
||||
if (struct_type == pack::Code::StructTwo) {
|
||||
// TODO process authentication tokens
|
||||
}
|
||||
|
||||
logger.debug("Executing state");
|
||||
logger.debug("Client connected '{}'", client_name);
|
||||
|
||||
output_stream.write_success_empty();
|
||||
output_stream.chunk();
|
||||
output_stream.send();
|
||||
|
||||
return EXECUTOR;
|
||||
}
|
||||
|
||||
State* parse(Session& session, Message& message);
|
||||
State* execute(Session& session, Message& message);
|
||||
};
|
||||
}
|
||||
|
@ -1,30 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "communication/bolt/v1/session.hpp"
|
||||
#include "communication/bolt/v1/states/state.hpp"
|
||||
#include "utils/crtp.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
template <class Derived>
|
||||
class MessageParser : public State, public Crtp<Derived> {
|
||||
public:
|
||||
MessageParser(Logger &&logger) : logger(std::forward<Logger>(logger)) {}
|
||||
|
||||
State *run(Session &session) override final {
|
||||
typename Derived::Message message;
|
||||
|
||||
logger.debug("Parsing message");
|
||||
auto next = this->derived().parse(session, message);
|
||||
|
||||
// return next state if parsing was unsuccessful (i.e. error state)
|
||||
if (next != &this->derived()) return next;
|
||||
|
||||
logger.debug("Executing state");
|
||||
return this->derived().execute(session, message);
|
||||
}
|
||||
|
||||
protected:
|
||||
Logger logger;
|
||||
};
|
||||
}
|
@ -1,27 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
|
||||
#include "logging/default.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
class Session;
|
||||
|
||||
class State {
|
||||
public:
|
||||
using uptr = std::unique_ptr<State>;
|
||||
|
||||
State() = default;
|
||||
State(Logger logger) : logger(logger) {}
|
||||
|
||||
virtual ~State() = default;
|
||||
|
||||
virtual State* run(Session& session) = 0;
|
||||
|
||||
protected:
|
||||
Logger logger;
|
||||
};
|
||||
}
|
@ -33,7 +33,8 @@ class ChunkedBuffer {
|
||||
}
|
||||
|
||||
void flush() {
|
||||
stream.get().write(&buffer.front(), size);
|
||||
// TODO: check for success
|
||||
stream.get().Write(&buffer.front(), size);
|
||||
|
||||
logger.trace("Flushed {} bytes", size);
|
||||
|
||||
|
@ -1,33 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <vector>
|
||||
|
||||
#include "communication/bolt/v1/transport/stream_error.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
template <typename Stream>
|
||||
class SocketStream {
|
||||
public:
|
||||
using byte = uint8_t;
|
||||
|
||||
SocketStream(Stream& socket) : socket(socket) {}
|
||||
|
||||
void write(const byte* data, size_t n) {
|
||||
while (n > 0) {
|
||||
auto written = socket.get().write(data, n);
|
||||
|
||||
if (UNLIKELY(written == -1)) throw StreamError("Can't write to stream");
|
||||
|
||||
n -= written;
|
||||
data += written;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::reference_wrapper<Stream> socket;
|
||||
};
|
||||
}
|
113
src/communication/server.hpp
Normal file
113
src/communication/server.hpp
Normal file
@ -0,0 +1,113 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "dbms/dbms.hpp"
|
||||
#include "query/engine.hpp"
|
||||
|
||||
#include "communication/worker.hpp"
|
||||
#include "io/network/event_listener.hpp"
|
||||
#include "logging/default.hpp"
|
||||
#include "utils/assert.hpp"
|
||||
|
||||
namespace communication {
|
||||
|
||||
/**
|
||||
* Communication server.
|
||||
* Listens for incomming connections on the server port and assings them in a
|
||||
* round-robin manner to it's workers.
|
||||
*
|
||||
* Current Server achitecture:
|
||||
* incomming connection -> server -> worker -> session
|
||||
*
|
||||
* @tparam Session the server can handle different Sessions, each session
|
||||
* represents a different protocol so the same network infrastructure
|
||||
* can be used for handling different protocols
|
||||
* @tparam OutputStream the server has to get the output stream as a template
|
||||
parameter because the output stream is templated
|
||||
* @tparam Socket the input/output socket that should be used
|
||||
*/
|
||||
template <typename Session, typename OutputStream, typename Socket>
|
||||
class Server
|
||||
: public io::network::EventListener<Server<Session, OutputStream, Socket>> {
|
||||
using Event = io::network::Epoll::Event;
|
||||
|
||||
public:
|
||||
Server(Socket &&socket, Dbms &dbms, QueryEngine<OutputStream> &query_engine)
|
||||
: socket_(std::forward<Socket>(socket)),
|
||||
dbms_(dbms),
|
||||
query_engine_(query_engine),
|
||||
logger_(logging::log->logger("communication::Server")) {
|
||||
event_.data.fd = socket_;
|
||||
|
||||
// TODO: EPOLLET is hard to use -> figure out how should EPOLLET be used
|
||||
// event.events = EPOLLIN | EPOLLET;
|
||||
event_.events = EPOLLIN;
|
||||
|
||||
this->listener_.Add(socket_, &event_);
|
||||
}
|
||||
|
||||
void Start(size_t n) {
|
||||
workers_.reserve(n);
|
||||
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
workers_.push_back(
|
||||
std::make_shared<Worker<Session, OutputStream, Socket>>(
|
||||
dbms_, query_engine_));
|
||||
workers_.back()->Start(alive_);
|
||||
}
|
||||
|
||||
while (alive_) {
|
||||
this->WaitAndProcessEvents();
|
||||
}
|
||||
}
|
||||
|
||||
void Shutdown() {
|
||||
alive_.store(false);
|
||||
|
||||
for (auto &worker : workers_) worker->thread_.join();
|
||||
}
|
||||
|
||||
void OnConnect() {
|
||||
debug_assert(idx_ < workers_.size(), "Invalid worker id.");
|
||||
|
||||
logger_.trace("on connect");
|
||||
|
||||
if (UNLIKELY(!workers_[idx_]->Accept(socket_))) return;
|
||||
|
||||
idx_ = idx_ == (int)workers_.size() - 1 ? 0 : idx_ + 1;
|
||||
}
|
||||
|
||||
void OnWaitTimeout() {}
|
||||
|
||||
void OnDataEvent(Event &event) {
|
||||
if (UNLIKELY(socket_ != event.data.fd)) return;
|
||||
|
||||
this->derived().OnConnect();
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
void OnExceptionEvent(Event &event, Args &&... args) {
|
||||
// TODO: Do something about it
|
||||
logger_.warn("epoll exception");
|
||||
}
|
||||
|
||||
void OnCloseEvent(Event &event) { close(event.data.fd); }
|
||||
|
||||
void OnErrorEvent(Event &event) { close(event.data.fd); }
|
||||
|
||||
private:
|
||||
std::vector<typename Worker<Session, OutputStream, Socket>::sptr> workers_;
|
||||
std::atomic<bool> alive_{true};
|
||||
int idx_{0};
|
||||
|
||||
Dbms &dbms_;
|
||||
QueryEngine<OutputStream> &query_engine_;
|
||||
Event event_;
|
||||
Socket socket_;
|
||||
Logger logger_;
|
||||
};
|
||||
}
|
109
src/communication/worker.hpp
Normal file
109
src/communication/worker.hpp
Normal file
@ -0,0 +1,109 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdio>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
#include "dbms/dbms.hpp"
|
||||
#include "query/engine.hpp"
|
||||
|
||||
#include "communication/bolt/v1/session.hpp"
|
||||
#include "io/network/network_error.hpp"
|
||||
#include "io/network/stream_reader.hpp"
|
||||
#include "logging/default.hpp"
|
||||
|
||||
namespace communication {
|
||||
|
||||
/**
|
||||
* Communication worker.
|
||||
* Listens for incomming data on connections and accepts new connections.
|
||||
* Also, executes sessions on incomming data.
|
||||
*
|
||||
* @tparam Session the worker can handle different Sessions, each session
|
||||
* represents a different protocol so the same network infrastructure
|
||||
* can be used for handling different protocols
|
||||
* @tparam OutputStream the worker has to get the output stream as a template
|
||||
parameter because the output stream is templated
|
||||
* @tparam Socket the input/output socket that should be used
|
||||
*/
|
||||
template <typename Session, typename OutputStream, typename Socket>
|
||||
class Worker
|
||||
: public io::network::StreamReader<Worker<Session, OutputStream, Socket>,
|
||||
Session> {
|
||||
using StreamBuffer = io::network::StreamBuffer;
|
||||
|
||||
public:
|
||||
using sptr = std::shared_ptr<Worker<Session, OutputStream, Socket>>;
|
||||
|
||||
Worker(Dbms &dbms, QueryEngine<OutputStream> &query_engine)
|
||||
: dbms_(dbms),
|
||||
query_engine_(query_engine),
|
||||
logger_(logging::log->logger("communication::Worker")) {}
|
||||
|
||||
Session &OnConnect(Socket &&socket) {
|
||||
logger_.trace("Accepting connection on socket {}", socket.id());
|
||||
|
||||
// TODO fix session lifecycle handling
|
||||
// dangling pointers are not cool :)
|
||||
// TODO attach currently active Db
|
||||
return *(new Session(std::forward<Socket>(socket), dbms_, query_engine_));
|
||||
}
|
||||
|
||||
void OnError(Session &session) {
|
||||
logger_.error("Error occured in this session");
|
||||
OnClose(session);
|
||||
}
|
||||
|
||||
void OnWaitTimeout() {}
|
||||
|
||||
StreamBuffer OnAlloc(Session &) {
|
||||
/* logger.trace("[on_alloc] Allocating {}B", sizeof buf); */
|
||||
|
||||
return StreamBuffer{buf_, sizeof buf_};
|
||||
}
|
||||
|
||||
void OnRead(Session &session, StreamBuffer &buf) {
|
||||
logger_.trace("[on_read] Received {}B", buf.len);
|
||||
|
||||
try {
|
||||
session.execute(reinterpret_cast<const byte *>(buf.ptr), buf.len);
|
||||
} catch (const std::exception &e) {
|
||||
logger_.error("Error occured while executing statement.");
|
||||
logger_.error("{}", e.what());
|
||||
// TODO: report to client
|
||||
}
|
||||
}
|
||||
|
||||
void OnClose(Session &session) {
|
||||
logger_.trace("Client closed the connection");
|
||||
// TODO: remove socket from epoll object
|
||||
session.close();
|
||||
delete &session;
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
void OnException(Session &session, Args &&... args) {
|
||||
logger_.error("Error occured in this session");
|
||||
logger_.error(args...);
|
||||
|
||||
// TODO: Do something about it
|
||||
}
|
||||
|
||||
char buf_[65536];
|
||||
std::thread thread_;
|
||||
|
||||
void Start(std::atomic<bool> &alive) {
|
||||
thread_ = std::thread([&, this]() {
|
||||
while (alive) this->WaitAndProcessEvents();
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
Dbms &dbms_;
|
||||
QueryEngine<OutputStream> &query_engine_;
|
||||
Logger logger_;
|
||||
};
|
||||
}
|
32
src/io/network/addrinfo.cpp
Normal file
32
src/io/network/addrinfo.cpp
Normal file
@ -0,0 +1,32 @@
|
||||
#include <netdb.h>
|
||||
#include <cstring>
|
||||
|
||||
#include "io/network/addrinfo.hpp"
|
||||
|
||||
#include "io/network/network_error.hpp"
|
||||
#include "utils/underlying_cast.hpp"
|
||||
|
||||
namespace io::network {
|
||||
|
||||
AddrInfo::AddrInfo(struct addrinfo* info) : info(info) {}
|
||||
|
||||
AddrInfo::~AddrInfo() { freeaddrinfo(info); }
|
||||
|
||||
AddrInfo AddrInfo::Get(const char* addr, const char* port) {
|
||||
struct addrinfo hints;
|
||||
memset(&hints, 0, sizeof(struct addrinfo));
|
||||
|
||||
hints.ai_family = AF_UNSPEC; // IPv4 and IPv6
|
||||
hints.ai_socktype = SOCK_STREAM; // TCP socket
|
||||
hints.ai_flags = AI_PASSIVE;
|
||||
|
||||
struct addrinfo* result;
|
||||
auto status = getaddrinfo(addr, port, &hints, &result);
|
||||
|
||||
if (status != 0) throw NetworkError(gai_strerror(status));
|
||||
|
||||
return AddrInfo(result);
|
||||
}
|
||||
|
||||
AddrInfo::operator struct addrinfo*() { return info; }
|
||||
}
|
@ -1,36 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include <netdb.h>
|
||||
#include <cstring>
|
||||
|
||||
#include "io/network/network_error.hpp"
|
||||
#include "utils/underlying_cast.hpp"
|
||||
|
||||
namespace io {
|
||||
namespace io::network {
|
||||
|
||||
/**
|
||||
* Wrapper class for getaddrinfo.
|
||||
* see: man 3 getaddrinfo
|
||||
*/
|
||||
class AddrInfo {
|
||||
AddrInfo(struct addrinfo* info) : info(info) {}
|
||||
AddrInfo(struct addrinfo* info);
|
||||
|
||||
public:
|
||||
~AddrInfo() { freeaddrinfo(info); }
|
||||
~AddrInfo();
|
||||
|
||||
static AddrInfo get(const char* addr, const char* port) {
|
||||
struct addrinfo hints;
|
||||
memset(&hints, 0, sizeof(struct addrinfo));
|
||||
static AddrInfo Get(const char* addr, const char* port);
|
||||
|
||||
hints.ai_family = AF_UNSPEC; // IPv4 and IPv6
|
||||
hints.ai_socktype = SOCK_STREAM; // TCP socket
|
||||
hints.ai_flags = AI_PASSIVE;
|
||||
|
||||
struct addrinfo* result;
|
||||
auto status = getaddrinfo(addr, port, &hints, &result);
|
||||
|
||||
if (status != 0) throw NetworkError(gai_strerror(status));
|
||||
|
||||
return AddrInfo(result);
|
||||
}
|
||||
|
||||
operator struct addrinfo*() { return info; }
|
||||
operator struct addrinfo*();
|
||||
|
||||
private:
|
||||
struct addrinfo* info;
|
||||
|
@ -7,40 +7,45 @@
|
||||
#include "logging/default.hpp"
|
||||
#include "utils/likely.hpp"
|
||||
|
||||
namespace io {
|
||||
namespace io::network {
|
||||
|
||||
class EpollError : StacktraceException {
|
||||
public:
|
||||
using StacktraceException::StacktraceException;
|
||||
};
|
||||
|
||||
/**
|
||||
* Wrapper class for epoll.
|
||||
* Creates an object that listens on file descriptor status changes.
|
||||
* see: man 4 epoll
|
||||
*/
|
||||
class Epoll {
|
||||
public:
|
||||
using Event = struct epoll_event;
|
||||
|
||||
Epoll(int flags) : logger(logging::log->logger("io::Epoll")) {
|
||||
epoll_fd = epoll_create1(flags);
|
||||
Epoll(int flags) : logger_(logging::log->logger("io::Epoll")) {
|
||||
epoll_fd_ = epoll_create1(flags);
|
||||
|
||||
if (UNLIKELY(epoll_fd == -1))
|
||||
if (UNLIKELY(epoll_fd_ == -1))
|
||||
throw EpollError("Can't create epoll file descriptor");
|
||||
}
|
||||
|
||||
template <class Stream>
|
||||
void add(Stream& stream, Event* event) {
|
||||
auto status = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, stream, event);
|
||||
void Add(Stream& stream, Event* event) {
|
||||
auto status = epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, stream, event);
|
||||
|
||||
if (UNLIKELY(status))
|
||||
throw EpollError("Can't add an event to epoll listener.");
|
||||
}
|
||||
|
||||
int wait(Event* events, int max_events, int timeout) {
|
||||
return epoll_wait(epoll_fd, events, max_events, timeout);
|
||||
int Wait(Event* events, int max_events, int timeout) {
|
||||
return epoll_wait(epoll_fd_, events, max_events, timeout);
|
||||
}
|
||||
|
||||
int id() const { return epoll_fd; }
|
||||
int id() const { return epoll_fd_; }
|
||||
|
||||
private:
|
||||
int epoll_fd;
|
||||
Logger logger;
|
||||
int epoll_fd_;
|
||||
Logger logger_;
|
||||
};
|
||||
}
|
||||
|
@ -4,17 +4,21 @@
|
||||
#include "logging/default.hpp"
|
||||
#include "utils/crtp.hpp"
|
||||
|
||||
namespace io {
|
||||
namespace io::network {
|
||||
|
||||
/**
|
||||
* This class listens to events on an epoll object and calls
|
||||
* callback functions to process them.
|
||||
*/
|
||||
template <class Derived, size_t max_events = 64, int wait_timeout = -1>
|
||||
class EventListener : public Crtp<Derived> {
|
||||
public:
|
||||
using Crtp<Derived>::derived;
|
||||
|
||||
EventListener(uint32_t flags = 0)
|
||||
: listener(flags), logger(logging::log->logger("io::EventListener")) {}
|
||||
: listener_(flags), logger_(logging::log->logger("io::EventListener")) {}
|
||||
|
||||
void wait_and_process_events() {
|
||||
void WaitAndProcessEvents() {
|
||||
// TODO hardcoded a wait timeout because of thread joining
|
||||
// when you shutdown the server. This should be wait_timeout of the
|
||||
// template parameter and should almost never change from that.
|
||||
@ -25,36 +29,36 @@ class EventListener : public Crtp<Derived> {
|
||||
// max_events and stores them in the events array. it waits for
|
||||
// wait_timeout milliseconds. if wait_timeout is achieved, returns 0
|
||||
|
||||
auto n = listener.wait(events, max_events, 200);
|
||||
auto n = listener_.Wait(events_, max_events, 200);
|
||||
|
||||
#ifndef NDEBUG
|
||||
#ifndef LOG_NO_TRACE
|
||||
if (n > 0) logger.trace("number of events: {}", n);
|
||||
if (n > 0) logger_.trace("number of events: {}", n);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// go through all events and process them in order
|
||||
for (int i = 0; i < n; ++i) {
|
||||
auto &event = events[i];
|
||||
auto &event = events_[i];
|
||||
|
||||
try {
|
||||
// hangup event
|
||||
if (UNLIKELY(event.events & EPOLLRDHUP)) {
|
||||
this->derived().on_close_event(event);
|
||||
this->derived().OnCloseEvent(event);
|
||||
continue;
|
||||
}
|
||||
|
||||
// there was an error on the server side
|
||||
if (UNLIKELY(!(event.events & EPOLLIN) ||
|
||||
event.events & (EPOLLHUP | EPOLLERR))) {
|
||||
this->derived().on_error_event(event);
|
||||
this->derived().OnErrorEvent(event);
|
||||
continue;
|
||||
}
|
||||
|
||||
// we have some data waiting to be read
|
||||
this->derived().on_data_event(event);
|
||||
this->derived().OnDataEvent(event);
|
||||
} catch (const std::exception &e) {
|
||||
this->derived().on_exception_event(
|
||||
this->derived().OnExceptionEvent(
|
||||
event, "Error occured while processing event \n{}", e.what());
|
||||
}
|
||||
}
|
||||
@ -69,14 +73,14 @@ class EventListener : public Crtp<Derived> {
|
||||
// is -1 there will never be any timeouts so client should provide
|
||||
// an empty function. in that case the conditional above and the
|
||||
// function call will be optimized out by the compiler
|
||||
this->derived().on_wait_timeout();
|
||||
this->derived().OnWaitTimeout();
|
||||
}
|
||||
|
||||
protected:
|
||||
Epoll listener;
|
||||
Epoll::Event events[max_events];
|
||||
Epoll listener_;
|
||||
Epoll::Event events_[max_events];
|
||||
|
||||
private:
|
||||
Logger logger;
|
||||
Logger logger_;
|
||||
};
|
||||
}
|
||||
|
@ -1,29 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "io/network/stream_reader.hpp"
|
||||
|
||||
namespace io {
|
||||
|
||||
template <class Derived, class Stream>
|
||||
class Client : public StreamReader<Derived, Stream> {
|
||||
public:
|
||||
bool connect(const std::string& host, const std::string& port) {
|
||||
return connect(host.c_str(), port.c_str());
|
||||
}
|
||||
|
||||
bool connect(const char* host, const char* port) {
|
||||
auto socket = io::Socket::connect(host, port);
|
||||
|
||||
if (!socket.is_open()) return false;
|
||||
|
||||
socket.set_non_blocking();
|
||||
|
||||
auto& stream = this->derived().on_connect(std::move(socket));
|
||||
|
||||
stream.event.events = EPOLLIN | EPOLLET | EPOLLRDHUP;
|
||||
this->add(stream);
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
}
|
61
src/io/network/network_endpoint.cpp
Normal file
61
src/io/network/network_endpoint.cpp
Normal file
@ -0,0 +1,61 @@
|
||||
#include "io/network/network_endpoint.hpp"
|
||||
#include "io/network/network_error.hpp"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <netdb.h>
|
||||
|
||||
namespace io::network {
|
||||
|
||||
NetworkEndpoint::NetworkEndpoint() : port_(0), family_(0) {
|
||||
memset(address_, 0, sizeof address_);
|
||||
memset(port_str_, 0, sizeof port_str_);
|
||||
}
|
||||
|
||||
NetworkEndpoint::NetworkEndpoint(const char* addr, const char* port) {
|
||||
if (addr == nullptr) throw NetworkEndpointException("Address can't be null!");
|
||||
if (port == nullptr) throw NetworkEndpointException("Port can't be null!");
|
||||
|
||||
// strncpy isn't used because it does not guarantee an ending null terminator
|
||||
snprintf(address_, sizeof address_, "%s", addr);
|
||||
snprintf(port_str_, sizeof port_str_, "%s", port);
|
||||
|
||||
is_address_valid();
|
||||
|
||||
int ret = sscanf(port, "%hu", &port_);
|
||||
if (ret != 1) throw NetworkEndpointException("Port isn't valid!");
|
||||
}
|
||||
|
||||
NetworkEndpoint::NetworkEndpoint(const std::string& addr,
|
||||
const std::string& port)
|
||||
: NetworkEndpoint(addr.c_str(), port.c_str()) {}
|
||||
|
||||
NetworkEndpoint::NetworkEndpoint(const char* addr, unsigned short port) {
|
||||
if (addr == nullptr) throw NetworkEndpointException("Address can't be null!");
|
||||
|
||||
snprintf(address_, sizeof address_, "%s", addr);
|
||||
snprintf(port_str_, sizeof port_str_, "%hu", port);
|
||||
port_ = port;
|
||||
|
||||
is_address_valid();
|
||||
}
|
||||
|
||||
void NetworkEndpoint::is_address_valid() {
|
||||
in_addr addr4;
|
||||
in6_addr addr6;
|
||||
int ret = inet_pton(AF_INET, address_, &addr4);
|
||||
if (ret != 1) {
|
||||
ret = inet_pton(AF_INET6, address_, &addr6);
|
||||
if (ret != 1)
|
||||
throw NetworkEndpointException(
|
||||
"Address isn't a valid IPv4 or IPv6 address!");
|
||||
else
|
||||
family_ = 6;
|
||||
} else
|
||||
family_ = 4;
|
||||
}
|
||||
|
||||
const char* NetworkEndpoint::address() { return address_; }
|
||||
const char* NetworkEndpoint::port_str() { return port_str_; }
|
||||
unsigned short NetworkEndpoint::port() { return port_; }
|
||||
unsigned char NetworkEndpoint::family() { return family_; }
|
||||
}
|
40
src/io/network/network_endpoint.hpp
Normal file
40
src/io/network/network_endpoint.hpp
Normal file
@ -0,0 +1,40 @@
|
||||
#pragma once
|
||||
|
||||
#include "utils/exceptions/basic_exception.hpp"
|
||||
|
||||
#include <netinet/in.h>
|
||||
#include <string>
|
||||
|
||||
namespace io::network {
|
||||
|
||||
class NetworkEndpointException : public BasicException {
|
||||
public:
|
||||
using BasicException::BasicException;
|
||||
};
|
||||
|
||||
/**
|
||||
* This class represents a network endpoint that is used in Socket.
|
||||
* It is used when connecting to an address and to get the current
|
||||
* connection address.
|
||||
*/
|
||||
class NetworkEndpoint {
|
||||
public:
|
||||
NetworkEndpoint();
|
||||
NetworkEndpoint(const char* addr, const char* port);
|
||||
NetworkEndpoint(const char* addr, unsigned short port);
|
||||
NetworkEndpoint(const std::string& addr, const std::string& port);
|
||||
|
||||
const char* address();
|
||||
const char* port_str();
|
||||
unsigned short port();
|
||||
unsigned char family();
|
||||
|
||||
private:
|
||||
void is_address_valid();
|
||||
|
||||
char address_[INET6_ADDRSTRLEN];
|
||||
char port_str_[6];
|
||||
unsigned short port_;
|
||||
unsigned char family_;
|
||||
};
|
||||
}
|
@ -4,7 +4,7 @@
|
||||
|
||||
#include "utils/exceptions/stacktrace_exception.hpp"
|
||||
|
||||
namespace io {
|
||||
namespace io::network {
|
||||
|
||||
class NetworkError : public StacktraceException {
|
||||
public:
|
||||
|
@ -1,63 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "io/network/socket.hpp"
|
||||
#include "tls.hpp"
|
||||
#include "tls_error.hpp"
|
||||
#include "utils/types/byte.hpp"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace io {
|
||||
|
||||
class SecureSocket {
|
||||
public:
|
||||
SecureSocket(Socket&& socket, const Tls::Context& tls)
|
||||
: socket(std::forward<Socket>(socket)) {
|
||||
ssl = SSL_new(tls);
|
||||
SSL_set_fd(ssl, this->socket);
|
||||
|
||||
SSL_set_accept_state(ssl);
|
||||
|
||||
if (SSL_accept(ssl) <= 0) ERR_print_errors_fp(stderr);
|
||||
}
|
||||
|
||||
SecureSocket(SecureSocket&& other) {
|
||||
*this = std::forward<SecureSocket>(other);
|
||||
}
|
||||
|
||||
SecureSocket& operator=(SecureSocket&& other) {
|
||||
socket = std::move(other.socket);
|
||||
|
||||
ssl = other.ssl;
|
||||
other.ssl = nullptr;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
~SecureSocket() {
|
||||
if (ssl == nullptr) return;
|
||||
|
||||
std::cout << "DELETING SSL" << std::endl;
|
||||
|
||||
SSL_free(ssl);
|
||||
}
|
||||
|
||||
int error(int status) { return SSL_get_error(ssl, status); }
|
||||
|
||||
int write(const std::string& str) { return write(str.c_str(), str.size()); }
|
||||
|
||||
int write(const byte* data, size_t len) { return SSL_write(ssl, data, len); }
|
||||
|
||||
int write(const char* data, size_t len) { return SSL_write(ssl, data, len); }
|
||||
|
||||
int read(char* buffer, size_t len) { return SSL_read(ssl, buffer, len); }
|
||||
|
||||
operator int() { return socket; }
|
||||
|
||||
operator Socket&() { return socket; }
|
||||
|
||||
private:
|
||||
Socket socket;
|
||||
SSL* ssl{nullptr};
|
||||
};
|
||||
}
|
@ -1,51 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
#include "io/network/stream_reader.hpp"
|
||||
#include "logging/default.hpp"
|
||||
|
||||
namespace io {
|
||||
using namespace memory::literals;
|
||||
|
||||
template <class Derived, class Stream>
|
||||
class SecureStreamReader : public StreamReader<Derived, Stream> {
|
||||
public:
|
||||
struct Buffer {
|
||||
char* ptr;
|
||||
size_t len;
|
||||
};
|
||||
|
||||
SecureStreamReader(uint32_t flags = 0)
|
||||
: StreamReader<Derived, Stream>(flags) {}
|
||||
|
||||
void on_data(Stream& stream) {
|
||||
while (true) {
|
||||
// allocate the buffer to fill the data
|
||||
auto buf = this->derived().on_alloc(stream);
|
||||
|
||||
// read from the buffer at most buf.len bytes
|
||||
auto len = stream.socket.read(buf.ptr, buf.len);
|
||||
|
||||
if (LIKELY(len > 0)) {
|
||||
buf.len = len;
|
||||
return this->derived().on_read(stream, buf);
|
||||
}
|
||||
|
||||
auto err = stream.socket.error(len);
|
||||
|
||||
// the socket is not ready for reading yet
|
||||
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE ||
|
||||
err == SSL_ERROR_WANT_X509_LOOKUP) {
|
||||
return;
|
||||
}
|
||||
|
||||
// the socket notified a close event
|
||||
if (err == SSL_ERROR_ZERO_RETURN) return stream.close();
|
||||
|
||||
// some other error occurred, check errno
|
||||
return this->derived().on_error(stream);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
@ -1,43 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "io/network/stream_reader.hpp"
|
||||
|
||||
namespace io {
|
||||
|
||||
template <class Derived>
|
||||
class Server : public EventListener<Derived> {
|
||||
public:
|
||||
Server(Socket &&socket)
|
||||
: socket(std::forward<Socket>(socket)),
|
||||
logger(logging::log->logger("io::Server")) {
|
||||
event.data.fd = this->socket;
|
||||
|
||||
// TODO: EPOLLET is hard to use -> figure out how should EPOLLET be used
|
||||
// event.events = EPOLLIN | EPOLLET;
|
||||
event.events = EPOLLIN;
|
||||
|
||||
this->listener.add(this->socket, &event);
|
||||
}
|
||||
|
||||
void on_close_event(Epoll::Event &event) { ::close(event.data.fd); }
|
||||
|
||||
void on_error_event(Epoll::Event &event) { ::close(event.data.fd); }
|
||||
|
||||
void on_data_event(Epoll::Event &event) {
|
||||
if (UNLIKELY(socket != event.data.fd)) return;
|
||||
|
||||
this->derived().on_connect();
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
void on_exception_event(Epoll::Event &event, Args &&... args) {
|
||||
// TODO: Do something about it
|
||||
logger.warn("epoll exception");
|
||||
}
|
||||
|
||||
protected:
|
||||
Epoll::Event event;
|
||||
Socket socket;
|
||||
Logger logger;
|
||||
};
|
||||
}
|
194
src/io/network/socket.cpp
Normal file
194
src/io/network/socket.cpp
Normal file
@ -0,0 +1,194 @@
|
||||
#include "io/network/socket.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <errno.h>
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <sys/epoll.h>
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "io/network/addrinfo.hpp"
|
||||
#include "utils/likely.hpp"
|
||||
|
||||
namespace io::network {
|
||||
|
||||
Socket::Socket() : socket_(-1) {}
|
||||
|
||||
Socket::Socket(int sock, NetworkEndpoint& endpoint)
|
||||
: socket_(sock), endpoint_(endpoint) {}
|
||||
|
||||
Socket::Socket(const Socket& s) : socket_(s.id()) {}
|
||||
|
||||
Socket::Socket(Socket&& other) { *this = std::forward<Socket>(other); }
|
||||
|
||||
Socket& Socket::operator=(Socket&& other) {
|
||||
socket_ = other.socket_;
|
||||
endpoint_ = other.endpoint_;
|
||||
other.socket_ = -1;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Socket::~Socket() {
|
||||
if (socket_ == -1) return;
|
||||
close(socket_);
|
||||
}
|
||||
|
||||
void Socket::Close() {
|
||||
if (socket_ == -1) return;
|
||||
close(socket_);
|
||||
socket_ = -1;
|
||||
}
|
||||
|
||||
bool Socket::IsOpen() { return socket_ != -1; }
|
||||
|
||||
bool Socket::Connect(NetworkEndpoint& endpoint) {
|
||||
if (UNLIKELY(socket_ != -1)) return false;
|
||||
|
||||
auto info = AddrInfo::Get(endpoint.address(), endpoint.port_str());
|
||||
|
||||
for (struct addrinfo* it = info; it != nullptr; it = it->ai_next) {
|
||||
int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
|
||||
if (sfd == -1) continue;
|
||||
if (connect(sfd, it->ai_addr, it->ai_addrlen) == 0) {
|
||||
socket_ = sfd;
|
||||
endpoint_ = endpoint;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (socket_ == -1) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Socket::Bind(NetworkEndpoint& endpoint) {
|
||||
if (UNLIKELY(socket_ != -1)) return false;
|
||||
|
||||
auto info = AddrInfo::Get(endpoint.address(), endpoint.port_str());
|
||||
|
||||
for (struct addrinfo* it = info; it != nullptr; it = it->ai_next) {
|
||||
int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
|
||||
if (sfd == -1) continue;
|
||||
|
||||
int on = 1;
|
||||
if (setsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) != 0)
|
||||
continue;
|
||||
|
||||
if (bind(sfd, it->ai_addr, it->ai_addrlen) == 0) {
|
||||
socket_ = sfd;
|
||||
endpoint_ = endpoint;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (socket_ == -1) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Socket::SetNonBlocking() {
|
||||
int flags = fcntl(socket_, F_GETFL, 0);
|
||||
|
||||
if (UNLIKELY(flags == -1)) return false;
|
||||
|
||||
flags |= O_NONBLOCK;
|
||||
int ret = fcntl(socket_, F_SETFL, flags);
|
||||
|
||||
if (UNLIKELY(ret == -1)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Socket::SetKeepAlive() {
|
||||
int optval = 1;
|
||||
socklen_t optlen = sizeof(optval);
|
||||
|
||||
if (setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, optlen) < 0)
|
||||
return false;
|
||||
|
||||
optval = 20; // wait 120s before seding keep-alive packets
|
||||
if (setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void*)&optval, optlen) < 0)
|
||||
return false;
|
||||
|
||||
optval = 4; // 4 keep-alive packets must fail to close
|
||||
if (setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void*)&optval, optlen) < 0)
|
||||
return false;
|
||||
|
||||
optval = 15; // send keep-alive packets every 15s
|
||||
if (setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void*)&optval, optlen) < 0)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Socket::Listen(int backlog) { return listen(socket_, backlog) == 0; }
|
||||
|
||||
bool Socket::Accept(Socket* s) {
|
||||
sockaddr_storage addr;
|
||||
socklen_t addr_size = sizeof addr;
|
||||
char addr_decoded[INET6_ADDRSTRLEN];
|
||||
void* addr_src;
|
||||
unsigned short port;
|
||||
unsigned char family;
|
||||
|
||||
int sfd = accept(socket_, (struct sockaddr*)&addr, &addr_size);
|
||||
if (UNLIKELY(sfd == -1)) return false;
|
||||
|
||||
if (addr.ss_family == AF_INET) {
|
||||
addr_src = (void*)&(((sockaddr_in*)&addr)->sin_addr);
|
||||
port = ntohs(((sockaddr_in*)&addr)->sin_port);
|
||||
family = 4;
|
||||
} else {
|
||||
addr_src = (void*)&(((sockaddr_in6*)&addr)->sin6_addr);
|
||||
port = ntohs(((sockaddr_in6*)&addr)->sin6_port);
|
||||
family = 6;
|
||||
}
|
||||
|
||||
inet_ntop(addr.ss_family, addr_src, addr_decoded, INET6_ADDRSTRLEN);
|
||||
|
||||
NetworkEndpoint endpoint;
|
||||
try {
|
||||
endpoint = NetworkEndpoint(addr_decoded, port);
|
||||
} catch (NetworkEndpointException& e) {
|
||||
return false;
|
||||
}
|
||||
|
||||
*s = Socket(sfd, endpoint);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Socket::operator int() { return socket_; }
|
||||
|
||||
int Socket::id() const { return socket_; }
|
||||
NetworkEndpoint& Socket::endpoint() { return endpoint_; }
|
||||
|
||||
bool Socket::Write(const std::string& str) {
|
||||
return Write(str.c_str(), str.size());
|
||||
}
|
||||
|
||||
bool Socket::Write(const char* data, size_t len) {
|
||||
return Write(reinterpret_cast<const uint8_t*>(data), len);
|
||||
}
|
||||
|
||||
bool Socket::Write(const uint8_t* data, size_t len) {
|
||||
while (len > 0) {
|
||||
auto written = send(socket_, data, len, 0);
|
||||
if (UNLIKELY(written == -1)) return false;
|
||||
len -= written;
|
||||
data += written;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int Socket::Read(void* buffer, size_t len) {
|
||||
return read(socket_, buffer, len);
|
||||
}
|
||||
}
|
@ -1,157 +1,150 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <errno.h>
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <sys/epoll.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "io/network/addrinfo.hpp"
|
||||
#include "utils/likely.hpp"
|
||||
|
||||
#include "logging/default.hpp"
|
||||
#include "io/network/network_endpoint.hpp"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace io {
|
||||
namespace io::network {
|
||||
|
||||
/**
|
||||
* This class creates a network socket.
|
||||
* It is used to connect/bind/listen on a NetworkEndpoint (address + port).
|
||||
* It has wrappers for setting network socket flags and wrappers for
|
||||
* reading/writing data from/to the socket.
|
||||
*/
|
||||
class Socket {
|
||||
protected:
|
||||
Socket(int family, int socket_type, int protocol) {
|
||||
socket = ::socket(family, socket_type, protocol);
|
||||
}
|
||||
|
||||
public:
|
||||
using byte = uint8_t;
|
||||
Socket();
|
||||
Socket(const Socket& s);
|
||||
Socket(Socket&& other);
|
||||
Socket& operator=(Socket&& other);
|
||||
~Socket();
|
||||
|
||||
Socket(int socket = -1) : socket(socket) {}
|
||||
/**
|
||||
* Closes the socket if it is open.
|
||||
*/
|
||||
void Close();
|
||||
|
||||
Socket(const Socket&) = delete;
|
||||
/**
|
||||
* Checks whether the socket is open.
|
||||
*
|
||||
* @return socket open status:
|
||||
* true if the socket is open
|
||||
* false if the socket is closed
|
||||
*/
|
||||
bool IsOpen();
|
||||
|
||||
Socket(Socket&& other) { *this = std::forward<Socket>(other); }
|
||||
/**
|
||||
* Connects the socket to the specified endpoint.
|
||||
*
|
||||
* @param endpoint NetworkEndpoint to which to connect to
|
||||
*
|
||||
* @return connection success status:
|
||||
* true if the connect succeeded
|
||||
* false if the connect failed
|
||||
*/
|
||||
bool Connect(NetworkEndpoint& endpoint);
|
||||
|
||||
~Socket() {
|
||||
if (socket == -1) return;
|
||||
/**
|
||||
* Binds the socket to the specified endpoint.
|
||||
*
|
||||
* @param endpoint NetworkEndpoint to which to bind to
|
||||
*
|
||||
* @return bind success status:
|
||||
* true if the bind succeeded
|
||||
* false if the bind failed
|
||||
*/
|
||||
bool Bind(NetworkEndpoint& endpoint);
|
||||
|
||||
#ifndef NDEBUG
|
||||
logging::debug("DELETING SOCKET");
|
||||
#endif
|
||||
/**
|
||||
* Start listening on the bound socket.
|
||||
*
|
||||
* @param backlog maximum number of pending connections in the connection queue
|
||||
*
|
||||
* @return listen success status:
|
||||
* true if the listen succeeded
|
||||
* false if the listen failed
|
||||
*/
|
||||
bool Listen(int backlog);
|
||||
|
||||
::close(socket);
|
||||
}
|
||||
/**
|
||||
* Accepts a new connection.
|
||||
* This function accepts a new connection on a listening socket.
|
||||
*
|
||||
* @param s Socket object that will be instantiated with the new connection
|
||||
*
|
||||
* @return accept success status:
|
||||
* true if a new connection was accepted and the socket 's' was instantiated
|
||||
* false if a new connection accept failed
|
||||
*/
|
||||
bool Accept(Socket* s);
|
||||
|
||||
void close() {
|
||||
::close(socket);
|
||||
socket = -1;
|
||||
}
|
||||
/**
|
||||
* Sets the socket to non-blocking.
|
||||
*
|
||||
* @return set non-blocking success status:
|
||||
* true if the socket was successfully set to non-blocking
|
||||
* false if the socket was not set to non-blocking
|
||||
*/
|
||||
bool SetNonBlocking();
|
||||
|
||||
Socket& operator=(Socket&& other) {
|
||||
this->socket = other.socket;
|
||||
other.socket = -1;
|
||||
return *this;
|
||||
}
|
||||
/**
|
||||
* Enables TCP keep-alive on the socket.
|
||||
*
|
||||
* @return enable keep-alive success status:
|
||||
* true if keep-alive was successfully enabled on the socket
|
||||
* false if keep-alive was not enabled
|
||||
*/
|
||||
bool SetKeepAlive();
|
||||
|
||||
bool is_open() { return socket != -1; }
|
||||
// TODO: this will be removed
|
||||
operator int();
|
||||
|
||||
static Socket connect(const std::string& addr, const std::string& port) {
|
||||
return connect(addr.c_str(), port.c_str());
|
||||
}
|
||||
/**
|
||||
* Returns the socket ID.
|
||||
* The socket ID is its unix file descriptor number.
|
||||
*/
|
||||
int id() const;
|
||||
|
||||
static Socket connect(const char* addr, const char* port) {
|
||||
auto info = AddrInfo::get(addr, port);
|
||||
/**
|
||||
* Returns the currently active endpoint of the socket.
|
||||
*/
|
||||
NetworkEndpoint& endpoint();
|
||||
|
||||
for (struct addrinfo* it = info; it != nullptr; it = it->ai_next) {
|
||||
auto s = Socket(it->ai_family, it->ai_socktype, it->ai_protocol);
|
||||
/**
|
||||
* Write data to the socket.
|
||||
* Theese functions guarantee that all data will be written.
|
||||
*
|
||||
* @param str std::string to write to the socket
|
||||
* @param data char* or uint8_t* to data that should be written
|
||||
* @param len length of char* or uint8_t* data
|
||||
*
|
||||
* @return write success status:
|
||||
* true if write succeeded
|
||||
* false if write failed
|
||||
*/
|
||||
bool Write(const std::string& str);
|
||||
bool Write(const char* data, size_t len);
|
||||
bool Write(const uint8_t* data, size_t len);
|
||||
|
||||
if (!s.is_open()) continue;
|
||||
/**
|
||||
* Read data from the socket.
|
||||
* This function is a direct wrapper for the read function.
|
||||
*
|
||||
* @param buffer pointer to the read buffer
|
||||
* @param len length of the read buffer
|
||||
*
|
||||
* @return read success status:
|
||||
* > 0 if data was read, means number of read bytes
|
||||
* == 0 if the client closed the connection
|
||||
* < 0 if an error has occurred
|
||||
*/
|
||||
int Read(void* buffer, size_t len);
|
||||
|
||||
if (::connect(s, it->ai_addr, it->ai_addrlen) == 0) return s;
|
||||
}
|
||||
private:
|
||||
Socket(int sock, NetworkEndpoint& endpoint);
|
||||
|
||||
throw NetworkError("Unable to connect to socket");
|
||||
}
|
||||
|
||||
static Socket bind(const std::string& addr, const std::string& port) {
|
||||
return bind(addr.c_str(), port.c_str());
|
||||
}
|
||||
|
||||
static Socket bind(const char* addr, const char* port) {
|
||||
auto info = AddrInfo::get(addr, port);
|
||||
|
||||
for (struct addrinfo* it = info; it != nullptr; it = it->ai_next) {
|
||||
auto s = Socket(it->ai_family, it->ai_socktype, it->ai_protocol);
|
||||
|
||||
if (!s.is_open()) continue;
|
||||
|
||||
int on = 1;
|
||||
if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on))) continue;
|
||||
|
||||
if (::bind(s, it->ai_addr, it->ai_addrlen) == 0) return s;
|
||||
}
|
||||
|
||||
throw NetworkError("Unable to bind to socket");
|
||||
}
|
||||
|
||||
void set_non_blocking() {
|
||||
auto flags = fcntl(socket, F_GETFL, 0);
|
||||
|
||||
if (UNLIKELY(flags == -1))
|
||||
throw NetworkError("Cannot read flags from socket");
|
||||
|
||||
flags |= O_NONBLOCK;
|
||||
|
||||
auto status = fcntl(socket, F_SETFL, flags);
|
||||
|
||||
if (UNLIKELY(status == -1))
|
||||
throw NetworkError("Cannot set NON_BLOCK flag to socket");
|
||||
}
|
||||
|
||||
void listen(int backlog) {
|
||||
auto status = ::listen(socket, backlog);
|
||||
|
||||
if (UNLIKELY(status == -1)) throw NetworkError("Cannot listen on socket");
|
||||
}
|
||||
|
||||
Socket accept(struct sockaddr* addr, socklen_t* len) {
|
||||
return Socket(::accept(socket, addr, len));
|
||||
}
|
||||
|
||||
operator int() { return socket; }
|
||||
|
||||
int id() const { return socket; }
|
||||
|
||||
int write(const std::string& str) { return write(str.c_str(), str.size()); }
|
||||
|
||||
int write(const char* data, size_t len) {
|
||||
return write(reinterpret_cast<const byte*>(data), len);
|
||||
}
|
||||
|
||||
int write(const byte* data, size_t len) {
|
||||
// TODO: use logger
|
||||
#ifndef NDEBUG
|
||||
std::stringstream stream;
|
||||
|
||||
for (size_t i = 0; i < len; ++i)
|
||||
stream << fmt::format("{:02X} ", static_cast<byte>(data[i]));
|
||||
|
||||
auto str = stream.str();
|
||||
|
||||
logging::debug("[Write {}B] {}", len, str);
|
||||
#endif
|
||||
|
||||
return ::write(socket, data, len);
|
||||
}
|
||||
|
||||
int read(void* buffer, size_t len) { return ::read(socket, buffer, len); }
|
||||
|
||||
protected:
|
||||
Logger logger;
|
||||
int socket;
|
||||
int socket_;
|
||||
NetworkEndpoint endpoint_;
|
||||
};
|
||||
}
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#include "io/network/event_listener.hpp"
|
||||
|
||||
namespace io {
|
||||
namespace io::network {
|
||||
|
||||
template <class Derived, class Stream, size_t max_events = 64,
|
||||
int wait_timeout = -1>
|
||||
@ -10,26 +10,26 @@ class StreamListener : public EventListener<Derived, max_events, wait_timeout> {
|
||||
public:
|
||||
using EventListener<Derived, max_events, wait_timeout>::EventListener;
|
||||
|
||||
void add(Stream &stream) {
|
||||
void Add(Stream &stream) {
|
||||
// add the stream to the event listener
|
||||
this->listener.add(stream.socket, &stream.event);
|
||||
this->listener_.Add(stream.socket, &stream.event);
|
||||
}
|
||||
|
||||
void on_close_event(Epoll::Event &event) {
|
||||
this->derived().on_close(to_stream(event));
|
||||
void OnCloseEvent(Epoll::Event &event) {
|
||||
this->derived().OnClose(to_stream(event));
|
||||
}
|
||||
|
||||
void on_error_event(Epoll::Event &event) {
|
||||
this->derived().on_error(to_stream(event));
|
||||
void OnErrorEvent(Epoll::Event &event) {
|
||||
this->derived().OnError(to_stream(event));
|
||||
}
|
||||
|
||||
void on_data_event(Epoll::Event &event) {
|
||||
this->derived().on_data(to_stream(event));
|
||||
void OnDataEvent(Epoll::Event &event) {
|
||||
this->derived().OnData(to_stream(event));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
void on_exception_event(Epoll::Event &event, Args &&... args) {
|
||||
this->derived().on_exception(to_stream(event), args...);
|
||||
void OnExceptionEvent(Epoll::Event &event, Args &&... args) {
|
||||
this->derived().OnException(to_stream(event), args...);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -3,82 +3,91 @@
|
||||
#include "io/network/stream_listener.hpp"
|
||||
#include "memory/literals.hpp"
|
||||
|
||||
namespace io {
|
||||
namespace io::network {
|
||||
using namespace memory::literals;
|
||||
|
||||
struct StreamBuffer {
|
||||
char* ptr;
|
||||
size_t len;
|
||||
};
|
||||
|
||||
/**
|
||||
* This class is used to get data from a socket that has been notified
|
||||
* with a data available event.
|
||||
*/
|
||||
template <class Derived, class Stream>
|
||||
class StreamReader : public StreamListener<Derived, Stream> {
|
||||
public:
|
||||
struct Buffer {
|
||||
char* ptr;
|
||||
size_t len;
|
||||
};
|
||||
|
||||
StreamReader(uint32_t flags = 0)
|
||||
: StreamListener<Derived, Stream>(flags),
|
||||
logger(logging::log->logger("io::StreamReader")) {}
|
||||
logger_(logging::log->logger("io::StreamReader")) {}
|
||||
|
||||
bool accept(Socket& socket) {
|
||||
logger.trace("accept");
|
||||
bool Accept(Socket& socket) {
|
||||
logger_.trace("Accept");
|
||||
|
||||
// accept a connection from a socket
|
||||
auto s = socket.accept(nullptr, nullptr);
|
||||
Socket s;
|
||||
if (!socket.Accept(&s)) return false;
|
||||
|
||||
if (!s.is_open()) return false;
|
||||
logger_.trace(
|
||||
"Accepted a connection: scoket {}, address '{}', family {}, port {}",
|
||||
s.id(), s.endpoint().address(), s.endpoint().family(),
|
||||
s.endpoint().port());
|
||||
|
||||
// make the recieved socket non blocking
|
||||
s.set_non_blocking();
|
||||
if (!s.SetKeepAlive()) return false;
|
||||
|
||||
auto& stream = this->derived().on_connect(std::move(s));
|
||||
auto& stream = this->derived().OnConnect(std::move(s));
|
||||
|
||||
// we want to listen to an incoming event which is edge triggered and
|
||||
// we also want to listen on the hangup event
|
||||
stream.event.events = EPOLLIN | EPOLLET | EPOLLRDHUP;
|
||||
stream.event.events = EPOLLIN | EPOLLRDHUP;
|
||||
|
||||
// add the connection to the event listener
|
||||
this->add(stream);
|
||||
this->Add(stream);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void on_data(Stream& stream) {
|
||||
logger.trace("on data");
|
||||
void OnData(Stream& stream) {
|
||||
logger_.trace("On data");
|
||||
|
||||
while (true) {
|
||||
if (UNLIKELY(!stream.alive())) {
|
||||
stream.close();
|
||||
logger_.trace("Calling OnClose because the stream isn't alive!");
|
||||
this->derived().OnClose(stream);
|
||||
break;
|
||||
}
|
||||
|
||||
// allocate the buffer to fill the data
|
||||
auto buf = this->derived().on_alloc(stream);
|
||||
auto buf = this->derived().OnAlloc(stream);
|
||||
|
||||
// read from the buffer at most buf.len bytes
|
||||
buf.len = stream.socket.read(buf.ptr, buf.len);
|
||||
buf.len = stream.socket.Read(buf.ptr, buf.len);
|
||||
|
||||
// check for read errors
|
||||
if (buf.len == -1) {
|
||||
// this means we have read all available data
|
||||
if (LIKELY(errno == EAGAIN)) {
|
||||
if (LIKELY(errno == EAGAIN || errno == EWOULDBLOCK)) {
|
||||
break;
|
||||
}
|
||||
|
||||
// some other error occurred, check errno
|
||||
this->derived().on_error(stream);
|
||||
this->derived().OnError(stream);
|
||||
break;
|
||||
}
|
||||
|
||||
// end of file, the client has closed the connection
|
||||
if (UNLIKELY(buf.len == 0)) {
|
||||
stream.close();
|
||||
logger_.trace("Calling OnClose because the socket is closed!");
|
||||
this->derived().OnClose(stream);
|
||||
break;
|
||||
}
|
||||
|
||||
this->derived().on_read(stream, buf);
|
||||
this->derived().OnRead(stream, buf);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Logger logger;
|
||||
Logger logger_;
|
||||
};
|
||||
}
|
||||
|
@ -1,29 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "io/network/epoll.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
|
||||
namespace io {
|
||||
namespace tcp {
|
||||
|
||||
template <class Socket>
|
||||
class Stream {
|
||||
public:
|
||||
Stream(Socket&& socket) : socket(std::move(socket)) {
|
||||
// save this to epoll event data baton to access later
|
||||
event.data.ptr = this;
|
||||
}
|
||||
|
||||
Stream(Stream&& stream) {
|
||||
socket = std::move(stream.socket);
|
||||
event = stream.event;
|
||||
event.data.ptr = this;
|
||||
}
|
||||
|
||||
int id() const { return socket.id(); }
|
||||
|
||||
Socket socket;
|
||||
Epoll::Event event;
|
||||
};
|
||||
}
|
||||
}
|
@ -1,42 +0,0 @@
|
||||
#include "io/network/tls.hpp"
|
||||
#include "io/network/tls_error.hpp"
|
||||
|
||||
namespace io {
|
||||
|
||||
Tls::Context::Context() {
|
||||
auto method = SSLv23_server_method();
|
||||
ctx = SSL_CTX_new(method);
|
||||
|
||||
if (!ctx) {
|
||||
ERR_print_errors_fp(stderr);
|
||||
throw io::TlsError("Unable to create TLS context");
|
||||
}
|
||||
|
||||
SSL_CTX_set_ecdh_auto(ctx, 1);
|
||||
}
|
||||
|
||||
Tls::Context::~Context() { SSL_CTX_free(ctx); }
|
||||
|
||||
Tls::Context& Tls::Context::cert(const std::string& path) {
|
||||
if (SSL_CTX_use_certificate_file(ctx, path.c_str(), SSL_FILETYPE_PEM) >= 0)
|
||||
return *this;
|
||||
|
||||
ERR_print_errors_fp(stderr);
|
||||
throw TlsError("Error Loading cert '{}'", path);
|
||||
}
|
||||
|
||||
Tls::Context& Tls::Context::key(const std::string& path) {
|
||||
if (SSL_CTX_use_PrivateKey_file(ctx, path.c_str(), SSL_FILETYPE_PEM) >= 0)
|
||||
return *this;
|
||||
|
||||
ERR_print_errors_fp(stderr);
|
||||
throw TlsError("Error Loading private key '{}'", path);
|
||||
}
|
||||
|
||||
void Tls::initialize() {
|
||||
SSL_load_error_strings();
|
||||
OpenSSL_add_ssl_algorithms();
|
||||
}
|
||||
|
||||
void Tls::cleanup() { EVP_cleanup(); }
|
||||
}
|
@ -1,29 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
namespace io {
|
||||
|
||||
class Tls {
|
||||
public:
|
||||
class Context {
|
||||
public:
|
||||
Context();
|
||||
~Context();
|
||||
|
||||
Context& cert(const std::string& path);
|
||||
Context& key(const std::string& path);
|
||||
|
||||
operator SSL_CTX*() const { return ctx; }
|
||||
|
||||
private:
|
||||
SSL_CTX* ctx;
|
||||
};
|
||||
|
||||
static void initialize();
|
||||
static void cleanup();
|
||||
};
|
||||
}
|
@ -1,11 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "utils/exceptions/stacktrace_exception.hpp"
|
||||
|
||||
namespace io {
|
||||
|
||||
class TlsError : public StacktraceException {
|
||||
public:
|
||||
using StacktraceException::StacktraceException;
|
||||
};
|
||||
}
|
@ -1,9 +1,14 @@
|
||||
#include <signal.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "communication/bolt/v1/server/server.hpp"
|
||||
#include "communication/bolt/v1/server/worker.hpp"
|
||||
#include "dbms/dbms.hpp"
|
||||
#include "query/engine.hpp"
|
||||
|
||||
#include "communication/bolt/v1/session.hpp"
|
||||
#include "communication/server.hpp"
|
||||
|
||||
#include "io/network/network_endpoint.hpp"
|
||||
#include "io/network/network_error.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
|
||||
#include "logging/default.hpp"
|
||||
@ -14,7 +19,13 @@
|
||||
#include "utils/stacktrace/log.hpp"
|
||||
#include "utils/terminate_handler.hpp"
|
||||
|
||||
static bolt::Server<bolt::Worker> *serverptr;
|
||||
using endpoint_t = io::network::NetworkEndpoint;
|
||||
using socket_t = io::network::Socket;
|
||||
using bolt_server_t =
|
||||
communication::Server<bolt::Session<socket_t>, bolt::RecordStream<socket_t>,
|
||||
socket_t>;
|
||||
|
||||
static bolt_server_t *serverptr;
|
||||
|
||||
Logger logger;
|
||||
|
||||
@ -60,28 +71,44 @@ int main(int argc, char **argv) {
|
||||
// register args
|
||||
CONFIG_REGISTER_ARGS(argc, argv);
|
||||
|
||||
// initialize socket
|
||||
io::Socket socket;
|
||||
// initialize endpoint
|
||||
endpoint_t endpoint;
|
||||
try {
|
||||
socket = io::Socket::bind(interface, port);
|
||||
} catch (io::NetworkError e) {
|
||||
logger.error("Cannot bind to socket on {} at {}", interface, port);
|
||||
endpoint = endpoint_t(interface, port);
|
||||
} catch (io::network::NetworkEndpointException &e) {
|
||||
logger.error("{}", e.what());
|
||||
std::exit(EXIT_FAILURE);
|
||||
}
|
||||
socket.set_non_blocking();
|
||||
socket.listen(1024);
|
||||
|
||||
// initialize socket
|
||||
socket_t socket;
|
||||
if (!socket.Bind(endpoint)) {
|
||||
logger.error("Cannot bind to socket on {} at {}", interface, port);
|
||||
std::exit(EXIT_FAILURE);
|
||||
}
|
||||
if (!socket.SetNonBlocking()) {
|
||||
logger.error("Cannot set socket to non blocking!");
|
||||
std::exit(EXIT_FAILURE);
|
||||
}
|
||||
if (!socket.Listen(1024)) {
|
||||
logger.error("Cannot listen on socket!");
|
||||
std::exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
logger.info("Listening on {} at {}", interface, port);
|
||||
|
||||
Dbms dbms;
|
||||
QueryEngine<bolt::RecordStream<socket_t>> query_engine;
|
||||
|
||||
// initialize server
|
||||
bolt::Server<bolt::Worker> server(std::move(socket));
|
||||
bolt_server_t server(std::move(socket), dbms, query_engine);
|
||||
serverptr = &server;
|
||||
|
||||
// server start with N threads
|
||||
// TODO: N should be configurable
|
||||
auto N = std::thread::hardware_concurrency();
|
||||
logger.info("Starting {} workers", N);
|
||||
server.start(N);
|
||||
server.Start(N);
|
||||
|
||||
logger.info("Shutting down...");
|
||||
return EXIT_SUCCESS;
|
||||
|
@ -1,6 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "communication/bolt/communication.hpp"
|
||||
#include "database/graph_db_accessor.hpp"
|
||||
#include "query/stripped.hpp"
|
||||
|
||||
|
34
tests/client-stress.sh
Executable file
34
tests/client-stress.sh
Executable file
@ -0,0 +1,34 @@
|
||||
#!/bin/bash
|
||||
|
||||
NUM_PARALLEL=4
|
||||
|
||||
if [[ "$1" != "" ]]; then
|
||||
NUM_PARALLEL=$1
|
||||
fi
|
||||
|
||||
for i in $( seq 1 $NUM_PARALLEL ); do
|
||||
echo "CREATE (n {name: 29383}) RETURN n;" | neo4j-client --insecure --username= --password= neo4j://localhost:7687 >/dev/null 2>/dev/null &
|
||||
done
|
||||
|
||||
running="yes"
|
||||
count=0
|
||||
|
||||
while [[ "$running" != "" ]]; do
|
||||
running=$( pidof neo4j-client )
|
||||
num=$( echo "$running" | wc -w )
|
||||
echo "Running clients: $num"
|
||||
count=$(( count + 1 ))
|
||||
if [[ $count -gt 5 ]]; then break; fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
if [[ "$running" != "" ]]; then
|
||||
echo "Something went wrong!"
|
||||
echo "Running PIDs: $running"
|
||||
echo "Killing leftover clients..."
|
||||
kill -9 $running >/dev/null 2>/dev/null
|
||||
wait $running 2>/dev/null
|
||||
else
|
||||
echo "All ok!"
|
||||
fi
|
||||
|
117
tests/concurrent/network_common.hpp
Normal file
117
tests/concurrent/network_common.hpp
Normal file
@ -0,0 +1,117 @@
|
||||
#pragma ONCE
|
||||
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "logging/default.hpp"
|
||||
#include "logging/streams/stdout.hpp"
|
||||
|
||||
#include "communication/server.hpp"
|
||||
#include "dbms/dbms.hpp"
|
||||
#include "io/network/epoll.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
#include "query/engine.hpp"
|
||||
|
||||
|
||||
static constexpr const int SIZE = 60000;
|
||||
static constexpr const int REPLY = 10;
|
||||
|
||||
using endpoint_t = io::network::NetworkEndpoint;
|
||||
using socket_t = io::network::Socket;
|
||||
|
||||
class TestOutputStream {};
|
||||
|
||||
class TestSession {
|
||||
public:
|
||||
TestSession(socket_t &&socket, Dbms &dbms,
|
||||
QueryEngine<TestOutputStream> &query_engine)
|
||||
: socket(std::move(socket)), logger_(logging::log->logger("TestSession")) {
|
||||
event.data.ptr = this;
|
||||
}
|
||||
|
||||
bool alive() { return socket.IsOpen(); }
|
||||
|
||||
int id() const { return socket.id(); }
|
||||
|
||||
void execute(const byte *data, size_t len) {
|
||||
if (size_ == 0) {
|
||||
size_ = data[0];
|
||||
size_ <<= 8;
|
||||
size_ += data[1];
|
||||
data += 2;
|
||||
len -= 2;
|
||||
}
|
||||
memcpy(buffer_ + have_, data, len);
|
||||
have_ += len;
|
||||
if (have_ < size_) return;
|
||||
|
||||
for (int i = 0; i < REPLY; ++i)
|
||||
ASSERT_TRUE(this->socket.Write(buffer_, size_));
|
||||
|
||||
have_ = 0;
|
||||
size_ = 0;
|
||||
}
|
||||
|
||||
void close() {
|
||||
logger_.trace("Close session!");
|
||||
this->socket.Close();
|
||||
}
|
||||
|
||||
char buffer_[SIZE * 2];
|
||||
uint32_t have_, size_;
|
||||
|
||||
Logger logger_;
|
||||
socket_t socket;
|
||||
io::network::Epoll::Event event;
|
||||
};
|
||||
|
||||
using test_server_t =
|
||||
communication::Server<TestSession, TestOutputStream, socket_t>;
|
||||
|
||||
void server_start(void* serverptr, int num) {
|
||||
((test_server_t*)serverptr)->Start(num);
|
||||
}
|
||||
|
||||
void client_run(int num, const char* interface, const char* port, const unsigned char* data, int lo, int hi) {
|
||||
std::stringstream name;
|
||||
name << "Client " << num;
|
||||
Logger logger = logging::log->logger(name.str());
|
||||
unsigned char buffer[SIZE * REPLY], head[2];
|
||||
int have, read;
|
||||
endpoint_t endpoint(interface, port);
|
||||
socket_t socket;
|
||||
ASSERT_TRUE(socket.Connect(endpoint));
|
||||
logger.trace("Socket create: {}", socket.id());
|
||||
for (int len = lo; len <= hi; len += 100) {
|
||||
have = 0;
|
||||
head[0] = (len >> 8) & 0xff;
|
||||
head[1] = len & 0xff;
|
||||
ASSERT_TRUE(socket.Write(head, 2));
|
||||
ASSERT_TRUE(socket.Write(data, len));
|
||||
logger.trace("Socket write: {}", socket.id());
|
||||
while (have < len * REPLY) {
|
||||
read = socket.Read(buffer + have, SIZE);
|
||||
logger.trace("Socket read: {}", socket.id());
|
||||
if (read == -1) break;
|
||||
have += read;
|
||||
}
|
||||
for (int i = 0; i < REPLY; ++i)
|
||||
for (int j = 0; j < len; ++j) ASSERT_EQ(buffer[i * len + j], data[j]);
|
||||
}
|
||||
logger.trace("Socket done: {}", socket.id());
|
||||
socket.Close();
|
||||
}
|
||||
|
||||
void initialize_data(unsigned char* data, int size) {
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<> dis(0, 255);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
data[i] = dis(gen);
|
||||
}
|
||||
}
|
51
tests/concurrent/network_server.cpp
Normal file
51
tests/concurrent/network_server.cpp
Normal file
@ -0,0 +1,51 @@
|
||||
#define NDEBUG
|
||||
|
||||
#include "network_common.hpp"
|
||||
|
||||
static constexpr const char interface[] = "127.0.0.1";
|
||||
static constexpr const char port[] = "31337";
|
||||
|
||||
unsigned char data[SIZE];
|
||||
|
||||
test_server_t *serverptr;
|
||||
|
||||
TEST(Network, Server) {
|
||||
// initialize test data
|
||||
initialize_data(data, SIZE);
|
||||
|
||||
// initialize listen socket
|
||||
endpoint_t endpoint(interface, port);
|
||||
socket_t socket;
|
||||
ASSERT_TRUE(socket.Bind(endpoint));
|
||||
ASSERT_TRUE(socket.SetNonBlocking());
|
||||
ASSERT_TRUE(socket.Listen(1024));
|
||||
|
||||
// initialize server
|
||||
Dbms dbms;
|
||||
QueryEngine<TestOutputStream> query_engine;
|
||||
test_server_t server(std::move(socket), dbms, query_engine);
|
||||
serverptr = &server;
|
||||
|
||||
// start server
|
||||
int N = std::thread::hardware_concurrency() / 2;
|
||||
std::thread server_thread(server_start, serverptr, N);
|
||||
|
||||
// start clients
|
||||
std::vector<std::thread> clients;
|
||||
for (int i = 0; i < N; ++i)
|
||||
clients.push_back(std::thread(client_run, i, interface, port, data, 30000, SIZE));
|
||||
|
||||
// cleanup clients
|
||||
for (int i = 0; i < N; ++i) clients[i].join();
|
||||
|
||||
// stop server
|
||||
server.Shutdown();
|
||||
server_thread.join();
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
logging::init_async();
|
||||
logging::log->pipe(std::make_unique<Stdout>());
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
62
tests/concurrent/network_session_leak.cpp
Normal file
62
tests/concurrent/network_session_leak.cpp
Normal file
@ -0,0 +1,62 @@
|
||||
#define NDEBUG
|
||||
|
||||
#include <chrono>
|
||||
|
||||
#include "network_common.hpp"
|
||||
|
||||
static constexpr const char interface[] = "127.0.0.1";
|
||||
static constexpr const char port[] = "31337";
|
||||
|
||||
unsigned char data[SIZE];
|
||||
|
||||
test_server_t *serverptr;
|
||||
|
||||
using namespace std::chrono_literals;
|
||||
|
||||
TEST(Network, SessionLeak) {
|
||||
// initialize test data
|
||||
initialize_data(data, SIZE);
|
||||
|
||||
// initialize listen socket
|
||||
endpoint_t endpoint(interface, port);
|
||||
socket_t socket;
|
||||
ASSERT_TRUE(socket.Bind(endpoint));
|
||||
ASSERT_TRUE(socket.SetNonBlocking());
|
||||
ASSERT_TRUE(socket.Listen(1024));
|
||||
|
||||
// initialize server
|
||||
Dbms dbms;
|
||||
QueryEngine<TestOutputStream> query_engine;
|
||||
test_server_t server(std::move(socket), dbms, query_engine);
|
||||
serverptr = &server;
|
||||
|
||||
// start server
|
||||
std::thread server_thread(server_start, serverptr, 2);
|
||||
|
||||
// start clients
|
||||
int N = 50;
|
||||
std::vector<std::thread> clients;
|
||||
|
||||
int testlen = 3000;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
clients.push_back(std::thread(client_run, i, interface, port, data, testlen, testlen));
|
||||
std::this_thread::sleep_for(10ms);
|
||||
}
|
||||
|
||||
// cleanup clients
|
||||
for (int i = 0; i < N; ++i) clients[i].join();
|
||||
|
||||
std::this_thread::sleep_for(2s);
|
||||
|
||||
// stop server
|
||||
server.Shutdown();
|
||||
server_thread.join();
|
||||
}
|
||||
|
||||
// run with "valgrind --leak-check=full ./network_session_leak" to check for memory leaks
|
||||
int main(int argc, char **argv) {
|
||||
logging::init_sync();
|
||||
logging::log->pipe(std::make_unique<Stdout>());
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -3,8 +3,9 @@
|
||||
// the flag is only used in hardcoded queries compilation
|
||||
// see usage in plan_compiler.hpp
|
||||
#ifndef HARDCODED_OUTPUT_STREAM
|
||||
#include "communication/bolt/communication.hpp"
|
||||
using Stream = communication::OutputStream;
|
||||
#include "communication/bolt/v1/serialization/record_stream.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
using Stream = bolt::RecordStream<io::network::Socket>;
|
||||
#else
|
||||
#include "../stream/print_record_stream.hpp"
|
||||
using Stream = PrintRecordStream;
|
||||
|
117
tests/unit/bolt_session.cpp
Normal file
117
tests/unit/bolt_session.cpp
Normal file
@ -0,0 +1,117 @@
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "logging/streams/stdout.hpp"
|
||||
|
||||
#include "communication/bolt/v1/serialization/record_stream.hpp"
|
||||
#include "communication/bolt/v1/session.hpp"
|
||||
#include "dbms/dbms.hpp"
|
||||
#include "query/engine.hpp"
|
||||
|
||||
class TestSocket {
|
||||
public:
|
||||
TestSocket(int socket) : socket(socket) {}
|
||||
TestSocket(const TestSocket& s) : socket(s.id()){};
|
||||
TestSocket(TestSocket&& other) { *this = std::forward<TestSocket>(other); }
|
||||
|
||||
TestSocket& operator=(TestSocket&& other) {
|
||||
this->socket = other.socket;
|
||||
other.socket = -1;
|
||||
return *this;
|
||||
}
|
||||
|
||||
void Close() { socket = -1; }
|
||||
bool IsOpen() { return socket != -1; }
|
||||
|
||||
int id() const { return socket; }
|
||||
|
||||
int Write(const std::string& str) { return Write(str.c_str(), str.size()); }
|
||||
int Write(const char* data, size_t len) {
|
||||
return Write(reinterpret_cast<const uint8_t*>(data), len);
|
||||
}
|
||||
int Write(const uint8_t* data, size_t len) {
|
||||
for (int i = 0; i < len; ++i) output.push_back(data[i]);
|
||||
return len;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> output;
|
||||
|
||||
protected:
|
||||
int socket;
|
||||
};
|
||||
|
||||
const uint8_t handshake_req[] =
|
||||
"\x60\x60\xb0\x17\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
"\x00\x00";
|
||||
const uint8_t handshake_resp[] = "\x00\x00\x00\x01";
|
||||
const uint8_t init_req[] =
|
||||
"\x00\x3f\xb2\x01\xd0\x15\x6c\x69\x62\x6e\x65\x6f\x34\x6a\x2d\x63\x6c\x69"
|
||||
"\x65\x6e\x74\x2f\x31\x2e\x32\x2e\x31\xa3\x86\x73\x63\x68\x65\x6d\x65\x85"
|
||||
"\x62\x61\x73\x69\x63\x89\x70\x72\x69\x6e\x63\x69\x70\x61\x6c\x80\x8b\x63"
|
||||
"\x72\x65\x64\x65\x6e\x74\x69\x61\x6c\x73\x80\x00\x00";
|
||||
const uint8_t init_resp[] = "\x00\x03\xb1\x70\xa0\x00\x00";
|
||||
const uint8_t run_req[] =
|
||||
"\x00\x26\xb2\x10\xd0\x21\x43\x52\x45\x41\x54\x45\x20\x28\x6e\x20\x7b\x6e"
|
||||
"\x61\x6d\x65\x3a\x20\x32\x39\x33\x38\x33\x7d\x29\x20\x52\x45\x54\x55\x52"
|
||||
"\x4e\x20\x6e\xa0\x00\x00";
|
||||
|
||||
void print_output(std::vector<uint8_t>& output) {
|
||||
fprintf(stderr, "output: ");
|
||||
for (int i = 0; i < output.size(); ++i) {
|
||||
fprintf(stderr, "%02X ", output[i]);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
void check_output(std::vector<uint8_t>& output, const uint8_t* data,
|
||||
uint64_t len) {
|
||||
EXPECT_EQ(len, output.size());
|
||||
for (int i = 0; i < len; ++i) {
|
||||
EXPECT_EQ(output[i], data[i]);
|
||||
}
|
||||
output.clear();
|
||||
}
|
||||
|
||||
TEST(Bolt, Session) {
|
||||
Dbms dbms;
|
||||
TestSocket socket(10);
|
||||
QueryEngine<bolt::RecordStream<TestSocket>> query_engine;
|
||||
bolt::Session<TestSocket> session(std::move(socket), dbms, query_engine);
|
||||
std::vector<uint8_t>& output = session.socket.output;
|
||||
|
||||
// execute handshake
|
||||
session.execute(handshake_req, 20);
|
||||
ASSERT_EQ(session.state, bolt::INIT);
|
||||
print_output(output);
|
||||
check_output(output, handshake_resp, 4);
|
||||
|
||||
// execute init
|
||||
session.execute(init_req, 67);
|
||||
ASSERT_EQ(session.state, bolt::EXECUTOR);
|
||||
print_output(output);
|
||||
check_output(output, init_resp, 7);
|
||||
|
||||
// execute run
|
||||
session.execute(run_req, 42);
|
||||
// TODO: query engine doesn't currently work,
|
||||
// we should test the query output here and the next state
|
||||
// ASSERT_EQ(session.state, bolt::EXECUTOR);
|
||||
// print_output(output);
|
||||
// check_output(output, run_resp, len);
|
||||
|
||||
// TODO: add more tests
|
||||
|
||||
session.close();
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
logging::init_sync();
|
||||
logging::log->pipe(std::make_unique<Stdout>());
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
84
tests/unit/network_endpoint.cpp
Normal file
84
tests/unit/network_endpoint.cpp
Normal file
@ -0,0 +1,84 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "io/network/network_endpoint.hpp"
|
||||
#include "io/network/network_error.hpp"
|
||||
|
||||
using endpoint_t = io::network::NetworkEndpoint;
|
||||
using exception_t = io::network::NetworkEndpointException;
|
||||
|
||||
TEST(NetworkEndpoint, IPv4) {
|
||||
endpoint_t endpoint;
|
||||
|
||||
// test first constructor
|
||||
endpoint = endpoint_t("127.0.0.1", "12345");
|
||||
EXPECT_STREQ(endpoint.address(), "127.0.0.1");
|
||||
EXPECT_STREQ(endpoint.port_str(), "12345");
|
||||
EXPECT_EQ(endpoint.port(), 12345);
|
||||
EXPECT_EQ(endpoint.family(), 4);
|
||||
|
||||
// test second constructor
|
||||
std::string addr("127.0.0.2"), port("12346");
|
||||
endpoint = endpoint_t(addr, port);
|
||||
EXPECT_STREQ(endpoint.address(), "127.0.0.2");
|
||||
EXPECT_STREQ(endpoint.port_str(), "12346");
|
||||
EXPECT_EQ(endpoint.port(), 12346);
|
||||
EXPECT_EQ(endpoint.family(), 4);
|
||||
|
||||
// test third constructor
|
||||
endpoint = endpoint_t("127.0.0.1", 12347);
|
||||
EXPECT_STREQ(endpoint.address(), "127.0.0.1");
|
||||
EXPECT_STREQ(endpoint.port_str(), "12347");
|
||||
EXPECT_EQ(endpoint.port(), 12347);
|
||||
EXPECT_EQ(endpoint.family(), 4);
|
||||
|
||||
// test address null
|
||||
EXPECT_THROW(endpoint_t(nullptr, nullptr), exception_t);
|
||||
|
||||
// test address invalid
|
||||
EXPECT_THROW(endpoint_t("invalid", "12345"), exception_t);
|
||||
|
||||
// test port invalid
|
||||
EXPECT_THROW(endpoint_t("127.0.0.1", "invalid"), exception_t);
|
||||
}
|
||||
|
||||
TEST(NetworkEndpoint, IPv6) {
|
||||
endpoint_t endpoint;
|
||||
|
||||
// test first constructor
|
||||
endpoint = endpoint_t("ab:cd:ef::1", "12345");
|
||||
EXPECT_STREQ(endpoint.address(), "ab:cd:ef::1");
|
||||
EXPECT_STREQ(endpoint.port_str(), "12345");
|
||||
EXPECT_EQ(endpoint.port(), 12345);
|
||||
EXPECT_EQ(endpoint.family(), 6);
|
||||
|
||||
// test second constructor
|
||||
std::string addr("ab:cd:ef::2"), port("12346");
|
||||
endpoint = endpoint_t(addr, port);
|
||||
EXPECT_STREQ(endpoint.address(), "ab:cd:ef::2");
|
||||
EXPECT_STREQ(endpoint.port_str(), "12346");
|
||||
EXPECT_EQ(endpoint.port(), 12346);
|
||||
EXPECT_EQ(endpoint.family(), 6);
|
||||
|
||||
// test third constructor
|
||||
endpoint = endpoint_t("ab:cd:ef::3", 12347);
|
||||
EXPECT_STREQ(endpoint.address(), "ab:cd:ef::3");
|
||||
EXPECT_STREQ(endpoint.port_str(), "12347");
|
||||
EXPECT_EQ(endpoint.port(), 12347);
|
||||
EXPECT_EQ(endpoint.family(), 6);
|
||||
|
||||
// test address null
|
||||
EXPECT_THROW(endpoint_t(nullptr, nullptr), exception_t);
|
||||
|
||||
// test address invalid
|
||||
EXPECT_THROW(endpoint_t("::g", "12345"), exception_t);
|
||||
|
||||
// test port invalid
|
||||
EXPECT_THROW(endpoint_t("::1", "invalid"), exception_t);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
Loading…
Reference in New Issue
Block a user