Begin refactor of network layer
Reviewers: mferencevic Reviewed By: mferencevic Subscribers: mferencevic, pullbot Differential Revision: https://phabricator.memgraph.io/D855
This commit is contained in:
parent
0f5c2bb6c3
commit
f10380a861
@ -72,7 +72,7 @@ class Session {
|
||||
/**
|
||||
* @return the socket id
|
||||
*/
|
||||
int Id() const { return socket_.id(); }
|
||||
int Id() const { return socket_.fd(); }
|
||||
|
||||
/**
|
||||
* Executes the session after data has been read into the buffer.
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <experimental/optional>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
@ -10,7 +11,7 @@
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include "communication/worker.hpp"
|
||||
#include "io/network/event_listener.hpp"
|
||||
#include "io/network/socket_event_dispatcher.hpp"
|
||||
#include "utils/assert.hpp"
|
||||
|
||||
namespace communication {
|
||||
@ -34,42 +35,42 @@ namespace communication {
|
||||
* @tparam SessionData the class with objects that will be forwarded to the
|
||||
* session
|
||||
*/
|
||||
// TODO: Remove Socket templatisation. Socket requirements are very specific.
|
||||
// It needs to be in non blocking mode, etc.
|
||||
template <typename Session, typename Socket, typename SessionData>
|
||||
class Server
|
||||
: public io::network::EventListener<Server<Session, Socket, SessionData>> {
|
||||
using Event = io::network::Epoll::Event;
|
||||
|
||||
class Server {
|
||||
public:
|
||||
using worker_t = Worker<Session, Socket, SessionData>;
|
||||
|
||||
Server(Socket &&socket, SessionData &session_data)
|
||||
: socket_(std::forward<Socket>(socket)), session_data_(session_data) {
|
||||
event_.data.fd = socket_;
|
||||
|
||||
// TODO: EPOLLET is hard to use -> figure out how should EPOLLET be used
|
||||
// event.events = EPOLLIN | EPOLLET;
|
||||
event_.events = EPOLLIN;
|
||||
|
||||
this->listener_.Add(socket_, &event_);
|
||||
}
|
||||
: socket_(std::move(socket)), session_data_(session_data) {}
|
||||
|
||||
void Start(size_t n) {
|
||||
std::cout << fmt::format("Starting {} workers", n) << std::endl;
|
||||
workers_.reserve(n);
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
workers_.push_back(std::make_unique<Worker<Session, Socket, SessionData>>(
|
||||
session_data_));
|
||||
workers_.back()->Start(alive_);
|
||||
workers_.push_back(std::make_unique<worker_t>(session_data_));
|
||||
worker_threads_.emplace_back(
|
||||
[this](worker_t &worker) -> void { worker.Start(alive_); },
|
||||
std::ref(*workers_.back()));
|
||||
}
|
||||
std::cout << "Server is fully armed and operational" << std::endl;
|
||||
std::cout << fmt::format("Listening on {} at {}",
|
||||
socket_.endpoint().address(),
|
||||
socket_.endpoint().port())
|
||||
<< std::endl;
|
||||
|
||||
io::network::SocketEventDispatcher<ConnectionAcceptor> dispatcher;
|
||||
ConnectionAcceptor acceptor(socket_, *this);
|
||||
dispatcher.AddListener(socket_.fd(), acceptor, EPOLLIN);
|
||||
while (alive_) {
|
||||
this->WaitAndProcessEvents();
|
||||
dispatcher.WaitAndProcessEvents();
|
||||
}
|
||||
|
||||
std::cout << "Shutting down..." << std::endl;
|
||||
for (auto &worker : workers_) worker->thread_.join();
|
||||
for (auto &worker_thread : worker_threads_) {
|
||||
worker_thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
void Shutdown() {
|
||||
@ -78,41 +79,55 @@ class Server
|
||||
alive_.store(false);
|
||||
}
|
||||
|
||||
void OnConnect() {
|
||||
debug_assert(idx_ < workers_.size(), "Invalid worker id.");
|
||||
|
||||
DLOG(INFO) << "on connect";
|
||||
|
||||
if (UNLIKELY(!workers_[idx_]->Accept(socket_))) return;
|
||||
|
||||
idx_ = idx_ == static_cast<int>(workers_.size()) - 1 ? 0 : idx_ + 1;
|
||||
}
|
||||
|
||||
void OnWaitTimeout() {}
|
||||
|
||||
void OnDataEvent(Event &event) {
|
||||
if (UNLIKELY(socket_ != event.data.fd)) return;
|
||||
|
||||
this->derived().OnConnect();
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
void OnExceptionEvent(Event &, Args &&...) {
|
||||
// TODO: Do something about it
|
||||
DLOG(WARNING) << "epoll exception";
|
||||
}
|
||||
|
||||
void OnCloseEvent(Event &event) { close(event.data.fd); }
|
||||
|
||||
void OnErrorEvent(Event &event) { close(event.data.fd); }
|
||||
|
||||
private:
|
||||
std::vector<typename Worker<Session, Socket, SessionData>::uptr> workers_;
|
||||
class ConnectionAcceptor : public io::network::BaseListener {
|
||||
public:
|
||||
ConnectionAcceptor(Socket &socket,
|
||||
Server<Session, Socket, SessionData> &server)
|
||||
: io::network::BaseListener(socket), server_(server) {}
|
||||
|
||||
void OnData() {
|
||||
debug_assert(server_.idx_ < server_.workers_.size(),
|
||||
"Invalid worker id.");
|
||||
DLOG(INFO) << "On connect";
|
||||
auto connection = AcceptConnection();
|
||||
if (UNLIKELY(!connection)) {
|
||||
// Connection is not available anymore or configuration failed.
|
||||
return;
|
||||
}
|
||||
server_.workers_[server_.idx_]->AddConnection(std::move(*connection));
|
||||
server_.idx_ = (server_.idx_ + 1) % server_.workers_.size();
|
||||
}
|
||||
|
||||
private:
|
||||
// Accepts connection on socket_ and configures new connections. If done
|
||||
// successfuly new socket (connection) is returner, nullopt otherwise.
|
||||
std::experimental::optional<Socket> AcceptConnection() {
|
||||
DLOG(INFO) << "Accept new connection on socket: " << socket_.fd();
|
||||
|
||||
// Accept a connection from a socket.
|
||||
auto s = socket_.Accept();
|
||||
if (!s) return std::experimental::nullopt;
|
||||
|
||||
DLOG(INFO) << fmt::format(
|
||||
"Accepted a connection: socket {}, address '{}', family {}, port {}",
|
||||
s->fd(), s->endpoint().address(), s->endpoint().family(),
|
||||
s->endpoint().port());
|
||||
|
||||
if (!s->SetKeepAlive()) return std::experimental::nullopt;
|
||||
if (!s->SetNoDelay()) return std::experimental::nullopt;
|
||||
return s;
|
||||
}
|
||||
|
||||
Server<Session, Socket, SessionData> &server_;
|
||||
};
|
||||
|
||||
std::vector<std::unique_ptr<worker_t>> workers_;
|
||||
std::vector<std::thread> worker_threads_;
|
||||
std::atomic<bool> alive_{true};
|
||||
int idx_{0};
|
||||
|
||||
Socket socket_;
|
||||
Event event_;
|
||||
SessionData &session_data_;
|
||||
};
|
||||
|
||||
|
@ -1,16 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cstdio>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include "io/network/network_error.hpp"
|
||||
#include "io/network/stream_reader.hpp"
|
||||
#include "io/network/socket_event_dispatcher.hpp"
|
||||
#include "io/network/stream_buffer.hpp"
|
||||
#include "threading/sync/spinlock.hpp"
|
||||
|
||||
namespace communication {
|
||||
|
||||
@ -27,74 +31,125 @@ namespace communication {
|
||||
* represents a different protocol so the same network infrastructure
|
||||
* can be used for handling different protocols
|
||||
* @tparam Socket the input/output socket that should be used
|
||||
* @tparam SessionData the class with objects that will be forwarded to the session
|
||||
* @tparam SessionData the class with objects that will be forwarded to the
|
||||
* session
|
||||
*/
|
||||
template <typename Session, typename Socket, typename SessionData>
|
||||
class Worker
|
||||
|
||||
: public io::network::StreamReader<Worker<Session, Socket, SessionData>,
|
||||
Session> {
|
||||
class Worker {
|
||||
using StreamBuffer = io::network::StreamBuffer;
|
||||
|
||||
public:
|
||||
using uptr = std::unique_ptr<Worker<Session, Socket, SessionData>>;
|
||||
void AddConnection(Socket &&connection) {
|
||||
std::unique_lock<SpinLock> gurad(lock_);
|
||||
// Remember fd before moving connection into SessionListener.
|
||||
int fd = connection.fd();
|
||||
session_listeners_.push_back(
|
||||
std::make_unique<SessionSocketListener>(std::move(connection), *this));
|
||||
// We want to listen to an incoming event which is edge triggered and
|
||||
// we also want to listen on the hangup event.
|
||||
dispatcher_.AddListener(fd, *session_listeners_.back(),
|
||||
EPOLLIN | EPOLLRDHUP);
|
||||
}
|
||||
|
||||
Worker(SessionData &session_data) : session_data_(session_data) {}
|
||||
|
||||
Session &OnConnect(Socket &&socket) {
|
||||
DLOG(INFO) << "Accepting connection on socket " << socket.id();
|
||||
|
||||
// TODO fix session lifecycle handling
|
||||
// dangling pointers are not cool :)
|
||||
// TODO attach currently active Db
|
||||
return *(new Session(std::forward<Socket>(socket), session_data_));
|
||||
}
|
||||
|
||||
void OnError(Session &session) {
|
||||
LOG(ERROR) << "Error occured in this session";
|
||||
OnClose(session);
|
||||
}
|
||||
|
||||
void OnWaitTimeout() {}
|
||||
|
||||
void OnRead(Session &session) {
|
||||
DLOG(INFO) << "OnRead";
|
||||
|
||||
try {
|
||||
session.Execute();
|
||||
} catch (const std::exception &e) {
|
||||
LOG(ERROR) << "Error occured while executing statement. " << std::endl
|
||||
<< e.what();
|
||||
// TODO: report to client
|
||||
void Start(std::atomic<bool> &alive) {
|
||||
while (alive) {
|
||||
dispatcher_.WaitAndProcessEvents();
|
||||
}
|
||||
}
|
||||
|
||||
void OnClose(Session &session) {
|
||||
LOG(INFO) << fmt::format("Client {}:{} closed the connection.",
|
||||
session.socket_.endpoint().address(),
|
||||
session.socket_.endpoint().port())
|
||||
<< std::endl;
|
||||
// TODO: remove socket from epoll object
|
||||
session.Close();
|
||||
delete &session;
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
void OnException(Session &, Args &&...) {
|
||||
LOG(ERROR) << "Error occured in this session";
|
||||
|
||||
// TODO: Do something about it
|
||||
}
|
||||
|
||||
std::thread thread_;
|
||||
|
||||
void Start(std::atomic<bool> &alive) {
|
||||
thread_ = std::thread([&, this]() {
|
||||
while (alive) this->WaitAndProcessEvents();
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
// TODO: Think about ownership. Who should own socket session,
|
||||
// SessionSocketListener or Worker?
|
||||
class SessionSocketListener : public io::network::BaseListener {
|
||||
public:
|
||||
SessionSocketListener(Socket &&socket,
|
||||
Worker<Session, Socket, SessionData> &worker)
|
||||
: BaseListener(session_.socket_),
|
||||
session_(std::move(socket), worker.session_data_),
|
||||
worker_(worker) {}
|
||||
|
||||
void OnError() {
|
||||
LOG(ERROR) << "Error occured in this session";
|
||||
OnClose();
|
||||
}
|
||||
|
||||
void OnData() {
|
||||
DLOG(INFO) << "On data";
|
||||
|
||||
if (UNLIKELY(!session_.Alive())) {
|
||||
DLOG(WARNING) << "Calling OnClose because the stream isn't alive!";
|
||||
OnClose();
|
||||
return;
|
||||
}
|
||||
|
||||
// allocate the buffer to fill the data
|
||||
auto buf = session_.Allocate();
|
||||
|
||||
// read from the buffer at most buf.len bytes
|
||||
int len = session_.socket_.Read(buf.data, buf.len);
|
||||
|
||||
// check for read errors
|
||||
if (len == -1) {
|
||||
// this means we have read all available data
|
||||
if (LIKELY(errno == EAGAIN || errno == EWOULDBLOCK)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// some other error occurred, check errno
|
||||
OnError();
|
||||
return;
|
||||
}
|
||||
|
||||
// end of file, the client has closed the connection
|
||||
if (UNLIKELY(len == 0)) {
|
||||
DLOG(WARNING) << "Calling OnClose because the socket is closed!";
|
||||
OnClose();
|
||||
return;
|
||||
}
|
||||
|
||||
// notify the stream that it has new data
|
||||
session_.Written(len);
|
||||
|
||||
DLOG(INFO) << "OnRead";
|
||||
|
||||
try {
|
||||
session_.Execute();
|
||||
} catch (const std::exception &e) {
|
||||
LOG(ERROR) << "Error occured while executing statement. " << std::endl
|
||||
<< e.what();
|
||||
// TODO: report to client
|
||||
}
|
||||
// TODO: Should we even continue with this session if error occurs while
|
||||
// reading.
|
||||
}
|
||||
|
||||
void OnClose() {
|
||||
LOG(INFO) << fmt::format("Client {}:{} closed the connection.",
|
||||
session_.socket_.endpoint().address(),
|
||||
session_.socket_.endpoint().port())
|
||||
<< std::endl;
|
||||
session_.Close();
|
||||
std::unique_lock<SpinLock> gurad(worker_.lock_);
|
||||
auto it = std::find_if(
|
||||
worker_.session_listeners_.begin(), worker_.session_listeners_.end(),
|
||||
[&](const auto &l) { return l->session_.Id() == session_.Id(); });
|
||||
CHECK(it != worker_.session_listeners_.end())
|
||||
<< "Trying to remove session that is not found in worker's sessions";
|
||||
int i = it - worker_.session_listeners_.begin();
|
||||
swap(worker_.session_listeners_[i], worker_.session_listeners_.back());
|
||||
worker_.session_listeners_.pop_back();
|
||||
}
|
||||
|
||||
private:
|
||||
Session session_;
|
||||
Worker &worker_;
|
||||
};
|
||||
|
||||
SpinLock lock_;
|
||||
SessionData &session_data_;
|
||||
io::network::SocketEventDispatcher<SessionSocketListener> dispatcher_;
|
||||
std::vector<std::unique_ptr<SessionSocketListener>> session_listeners_;
|
||||
};
|
||||
}
|
||||
|
@ -1,5 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <errno.h>
|
||||
#include <fmt/format.h>
|
||||
#include <glog/logging.h>
|
||||
#include <malloc.h>
|
||||
#include <sys/epoll.h>
|
||||
|
||||
@ -9,11 +12,6 @@
|
||||
|
||||
namespace io::network {
|
||||
|
||||
class EpollError : utils::StacktraceException {
|
||||
public:
|
||||
using utils::StacktraceException::StacktraceException;
|
||||
};
|
||||
|
||||
/**
|
||||
* Wrapper class for epoll.
|
||||
* Creates an object that listens on file descriptor status changes.
|
||||
@ -25,25 +23,35 @@ class Epoll {
|
||||
|
||||
Epoll(int flags) {
|
||||
epoll_fd_ = epoll_create1(flags);
|
||||
|
||||
if (UNLIKELY(epoll_fd_ == -1))
|
||||
throw EpollError("Can't create epoll file descriptor");
|
||||
// epoll_create1 returns an error if there is a logical error in our code
|
||||
// (for example invalid flags) or if there is irrecoverable error. In both
|
||||
// cases it is best to terminate.
|
||||
CHECK(epoll_fd_ != -1) << "Error on epoll_create1, errno: " << errno
|
||||
<< ", message: " << strerror(errno);
|
||||
}
|
||||
|
||||
template <class Stream>
|
||||
void Add(Stream &stream, Event *event) {
|
||||
auto status = epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, stream, event);
|
||||
|
||||
if (UNLIKELY(status))
|
||||
throw EpollError("Can't add an event to epoll listener.");
|
||||
void Add(int fd, Event *event) {
|
||||
auto status = epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, event);
|
||||
// epoll_ctl can return an error on our logical error or on irrecoverable
|
||||
// error. There is a third possibility that some system limit is reached. In
|
||||
// that case we could return an erorr and close connection. Chances of
|
||||
// reaching system limit in normally working memgraph is extremely unlikely,
|
||||
// so it is correct to terminate even in that case.
|
||||
CHECK(!status) << "Error on epoll_ctl, errno: " << errno
|
||||
<< ", message: " << strerror(errno);
|
||||
}
|
||||
|
||||
int Wait(Event *events, int max_events, int timeout) {
|
||||
return epoll_wait(epoll_fd_, events, max_events, timeout);
|
||||
auto num_events = epoll_wait(epoll_fd_, events, max_events, timeout);
|
||||
// If this check fails there was logical error in our code.
|
||||
CHECK(num_events != -1 || errno == EINTR)
|
||||
<< "Error on epoll_wait, errno: " << errno
|
||||
<< ", message: " << strerror(errno);
|
||||
// num_events can be -1 if errno was EINTR (epoll_wait interrupted by signal
|
||||
// handler). We treat that as no events, so we return 0.
|
||||
return num_events == -1 ? 0 : num_events;
|
||||
}
|
||||
|
||||
int id() const { return epoll_fd_; }
|
||||
|
||||
private:
|
||||
int epoll_fd_;
|
||||
};
|
||||
|
@ -1,83 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include "io/network/epoll.hpp"
|
||||
#include "utils/crtp.hpp"
|
||||
|
||||
namespace io::network {
|
||||
|
||||
/**
|
||||
* This class listens to events on an epoll object and calls
|
||||
* callback functions to process them.
|
||||
*/
|
||||
template <class Derived, size_t max_events = 64, int wait_timeout = -1>
|
||||
class EventListener : public Crtp<Derived> {
|
||||
public:
|
||||
using Crtp<Derived>::derived;
|
||||
|
||||
EventListener(uint32_t flags = 0) : listener_(flags) {}
|
||||
|
||||
void WaitAndProcessEvents() {
|
||||
// TODO hardcoded a wait timeout because of thread joining
|
||||
// when you shutdown the server. This should be wait_timeout of the
|
||||
// template parameter and should almost never change from that.
|
||||
// thread joining should be resolved using a signal that interrupts
|
||||
// the system call.
|
||||
|
||||
// waits for an event/multiple events and returns a maximum of
|
||||
// max_events and stores them in the events array. it waits for
|
||||
// wait_timeout milliseconds. if wait_timeout is achieved, returns 0
|
||||
|
||||
auto n = listener_.Wait(events_, max_events, 200);
|
||||
|
||||
#ifndef NDEBUG
|
||||
#ifndef LOG_NO_TRACE
|
||||
DLOG_IF(INFO, n > 0) << "number of events: " << n;
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// go through all events and process them in order
|
||||
for (int i = 0; i < n; ++i) {
|
||||
auto &event = events_[i];
|
||||
|
||||
try {
|
||||
// hangup event
|
||||
if (UNLIKELY(event.events & EPOLLRDHUP)) {
|
||||
this->derived().OnCloseEvent(event);
|
||||
continue;
|
||||
}
|
||||
|
||||
// there was an error on the server side
|
||||
if (UNLIKELY(!(event.events & EPOLLIN) ||
|
||||
event.events & (EPOLLHUP | EPOLLERR))) {
|
||||
this->derived().OnErrorEvent(event);
|
||||
continue;
|
||||
}
|
||||
|
||||
// we have some data waiting to be read
|
||||
this->derived().OnDataEvent(event);
|
||||
} catch (const std::exception &e) {
|
||||
this->derived().OnExceptionEvent(
|
||||
event, "Error occured while processing event \n{}", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
// this will be optimized out :D
|
||||
if (wait_timeout < 0) return;
|
||||
|
||||
// if there was events, continue to wait on new events
|
||||
if (n != 0) return;
|
||||
|
||||
// wait timeout occurred and there were no events. if wait_timeout
|
||||
// is -1 there will never be any timeouts so client should provide
|
||||
// an empty function. in that case the conditional above and the
|
||||
// function call will be optimized out by the compiler
|
||||
this->derived().OnWaitTimeout();
|
||||
}
|
||||
|
||||
protected:
|
||||
Epoll listener_;
|
||||
Epoll::Event events_[max_events];
|
||||
};
|
||||
}
|
@ -22,19 +22,18 @@
|
||||
|
||||
namespace io::network {
|
||||
|
||||
Socket::Socket() : socket_(-1) {}
|
||||
|
||||
Socket::Socket(int sock, const NetworkEndpoint &endpoint)
|
||||
: socket_(sock), endpoint_(endpoint) {}
|
||||
|
||||
Socket::Socket(const Socket &s) : socket_(s.id()) {}
|
||||
|
||||
Socket::Socket(Socket &&other) { *this = std::forward<Socket>(other); }
|
||||
Socket::Socket(Socket &&other) {
|
||||
socket_ = other.socket_;
|
||||
endpoint_ = std::move(other.endpoint_);
|
||||
other.socket_ = -1;
|
||||
}
|
||||
|
||||
Socket &Socket::operator=(Socket &&other) {
|
||||
socket_ = other.socket_;
|
||||
endpoint_ = other.endpoint_;
|
||||
other.socket_ = -1;
|
||||
if (this != &other) {
|
||||
socket_ = other.socket_;
|
||||
endpoint_ = std::move(other.endpoint_);
|
||||
other.socket_ = -1;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
@ -52,7 +51,7 @@ void Socket::Close() {
|
||||
bool Socket::IsOpen() { return socket_ != -1; }
|
||||
|
||||
bool Socket::Connect(const NetworkEndpoint &endpoint) {
|
||||
if (UNLIKELY(socket_ != -1)) return false;
|
||||
if (socket_ != -1) return false;
|
||||
|
||||
auto info = AddrInfo::Get(endpoint.address(), endpoint.port_str());
|
||||
|
||||
@ -71,7 +70,7 @@ bool Socket::Connect(const NetworkEndpoint &endpoint) {
|
||||
}
|
||||
|
||||
bool Socket::Bind(const NetworkEndpoint &endpoint) {
|
||||
if (UNLIKELY(socket_ != -1)) return false;
|
||||
if (socket_ != -1) return false;
|
||||
|
||||
auto info = AddrInfo::Get(endpoint.address(), endpoint.port_str());
|
||||
|
||||
@ -163,7 +162,7 @@ bool Socket::SetTimeout(long sec, long usec) {
|
||||
|
||||
bool Socket::Listen(int backlog) { return listen(socket_, backlog) == 0; }
|
||||
|
||||
bool Socket::Accept(Socket *s) {
|
||||
std::experimental::optional<Socket> Socket::Accept() {
|
||||
sockaddr_storage addr;
|
||||
socklen_t addr_size = sizeof addr;
|
||||
char addr_decoded[INET6_ADDRSTRLEN];
|
||||
@ -172,7 +171,7 @@ bool Socket::Accept(Socket *s) {
|
||||
unsigned char family;
|
||||
|
||||
int sfd = accept(socket_, (struct sockaddr *)&addr, &addr_size);
|
||||
if (UNLIKELY(sfd == -1)) return false;
|
||||
if (UNLIKELY(sfd == -1)) return std::experimental::nullopt;
|
||||
|
||||
if (addr.ss_family == AF_INET) {
|
||||
addr_src = (void *)&(((sockaddr_in *)&addr)->sin_addr);
|
||||
@ -190,17 +189,12 @@ bool Socket::Accept(Socket *s) {
|
||||
try {
|
||||
endpoint = NetworkEndpoint(addr_decoded, port);
|
||||
} catch (NetworkEndpointException &e) {
|
||||
return false;
|
||||
return std::experimental::nullopt;
|
||||
}
|
||||
|
||||
*s = Socket(sfd, endpoint);
|
||||
|
||||
return true;
|
||||
return Socket(sfd, endpoint);
|
||||
}
|
||||
|
||||
Socket::operator int() { return socket_; }
|
||||
|
||||
int Socket::id() const { return socket_; }
|
||||
const NetworkEndpoint &Socket::endpoint() const { return endpoint_; }
|
||||
|
||||
bool Socket::Write(const std::string &str) {
|
||||
|
@ -1,9 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "io/network/network_endpoint.hpp"
|
||||
|
||||
#include <experimental/optional>
|
||||
#include <iostream>
|
||||
|
||||
#include "io/network/network_endpoint.hpp"
|
||||
|
||||
namespace io::network {
|
||||
|
||||
/**
|
||||
@ -14,10 +15,11 @@ namespace io::network {
|
||||
*/
|
||||
class Socket {
|
||||
public:
|
||||
Socket();
|
||||
Socket(const Socket &s);
|
||||
Socket(Socket &&other);
|
||||
Socket &operator=(Socket &&other);
|
||||
Socket() = default;
|
||||
Socket(const Socket &) = delete;
|
||||
Socket &operator=(const Socket &) = delete;
|
||||
Socket(Socket &&);
|
||||
Socket &operator=(Socket &&);
|
||||
~Socket();
|
||||
|
||||
/**
|
||||
@ -72,14 +74,9 @@ class Socket {
|
||||
* Accepts a new connection.
|
||||
* This function accepts a new connection on a listening socket.
|
||||
*
|
||||
* @param s Socket object that will be instantiated with the new connection
|
||||
*
|
||||
* @return accept success status:
|
||||
* true if a new connection was accepted and the socket 's' was
|
||||
* instantiated
|
||||
* false if a new connection accept failed
|
||||
* @return socket if accepted, nullopt otherwise.
|
||||
*/
|
||||
bool Accept(Socket *s);
|
||||
std::experimental::optional<Socket> Accept();
|
||||
|
||||
/**
|
||||
* Sets the socket to non-blocking.
|
||||
@ -122,14 +119,10 @@ class Socket {
|
||||
*/
|
||||
bool SetTimeout(long sec, long usec);
|
||||
|
||||
// TODO: this will be removed
|
||||
operator int();
|
||||
|
||||
/**
|
||||
* Returns the socket ID.
|
||||
* The socket ID is its unix file descriptor number.
|
||||
* Returns the socket file descriptor.
|
||||
*/
|
||||
int id() const;
|
||||
int fd() const { return socket_; }
|
||||
|
||||
/**
|
||||
* Returns the currently active endpoint of the socket.
|
||||
@ -167,9 +160,10 @@ class Socket {
|
||||
int Read(void *buffer, size_t len);
|
||||
|
||||
private:
|
||||
Socket(int sock, const NetworkEndpoint &endpoint);
|
||||
Socket(int fd, const NetworkEndpoint &endpoint)
|
||||
: socket_(fd), endpoint_(endpoint) {}
|
||||
|
||||
int socket_;
|
||||
int socket_ = -1;
|
||||
NetworkEndpoint endpoint_;
|
||||
};
|
||||
}
|
||||
|
106
src/io/network/socket_event_dispatcher.hpp
Normal file
106
src/io/network/socket_event_dispatcher.hpp
Normal file
@ -0,0 +1,106 @@
|
||||
#pragma once
|
||||
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include "io/network/epoll.hpp"
|
||||
#include "utils/crtp.hpp"
|
||||
|
||||
namespace io::network {
|
||||
|
||||
/**
|
||||
* This class listens to events on an epoll object and calls
|
||||
* callback functions to process them.
|
||||
*/
|
||||
|
||||
template <class Listener>
|
||||
class SocketEventDispatcher {
|
||||
public:
|
||||
SocketEventDispatcher(uint32_t flags = 0) : epoll_(flags) {}
|
||||
|
||||
void AddListener(int fd, Listener &listener, uint32_t events) {
|
||||
// Add the listener associated to fd file descriptor to epoll.
|
||||
epoll_event event;
|
||||
event.events = events;
|
||||
event.data.ptr = &listener;
|
||||
epoll_.Add(fd, &event);
|
||||
}
|
||||
|
||||
// Returns true if there was event before timeout.
|
||||
bool WaitAndProcessEvents() {
|
||||
// Waits for an event/multiple events and returns a maximum of max_events
|
||||
// and stores them in the events array. It waits for wait_timeout
|
||||
// milliseconds. If wait_timeout is achieved, returns 0.
|
||||
const auto n = epoll_.Wait(events_, kMaxEvents, 200);
|
||||
DLOG_IF(INFO, n > 0) << "number of events: " << n;
|
||||
|
||||
// Go through all events and process them in order.
|
||||
for (int i = 0; i < n; ++i) {
|
||||
auto &event = events_[i];
|
||||
|
||||
Listener &listener = *reinterpret_cast<Listener *>(event.data.ptr);
|
||||
|
||||
// TODO: revise this. Reported events will be combined so continue is not
|
||||
// probably what we want to do.
|
||||
try {
|
||||
// Hangup event.
|
||||
if (UNLIKELY(event.events & EPOLLRDHUP)) {
|
||||
listener.OnClose();
|
||||
continue;
|
||||
}
|
||||
|
||||
// There was an error on the server side.
|
||||
if (UNLIKELY(!(event.events & EPOLLIN) ||
|
||||
event.events & (EPOLLHUP | EPOLLERR))) {
|
||||
listener.OnError();
|
||||
continue;
|
||||
}
|
||||
|
||||
// We have some data waiting to be read.
|
||||
listener.OnData();
|
||||
} catch (const std::exception &e) {
|
||||
listener.OnException(e);
|
||||
}
|
||||
}
|
||||
|
||||
return n > 0;
|
||||
}
|
||||
|
||||
private:
|
||||
static const int kMaxEvents = 64;
|
||||
// TODO: epoll is really ridiculous here. We don't plan to handle thousands of
|
||||
// connections so ppoll would actually be better or (even plain nonblocking
|
||||
// socket).
|
||||
Epoll epoll_;
|
||||
Epoll::Event events_[kMaxEvents];
|
||||
};
|
||||
|
||||
/**
|
||||
* Implements Listener concept, suitable for inheritance.
|
||||
*/
|
||||
class BaseListener {
|
||||
public:
|
||||
BaseListener(Socket &socket) : socket_(socket) {}
|
||||
|
||||
void OnClose() { socket_.Close(); }
|
||||
|
||||
// If server is listening on socket and there is incoming connection OnData
|
||||
// event will be triggered.
|
||||
void OnData() {}
|
||||
|
||||
void OnException(const std::exception &e) {
|
||||
// TODO: this actually sounds quite bad, maybe we should close socket here
|
||||
// because we don'y know in which state Listener class is.
|
||||
LOG(ERROR) << "Exception was thrown while processing event on socket "
|
||||
<< socket_.fd() << " with message: " << e.what();
|
||||
}
|
||||
|
||||
void OnError() {
|
||||
LOG(ERROR) << "Error on server side occured in epoll";
|
||||
socket_.Close();
|
||||
}
|
||||
|
||||
protected:
|
||||
Socket &socket_;
|
||||
};
|
||||
|
||||
} // namespace io::network
|
@ -1,40 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "io/network/event_listener.hpp"
|
||||
|
||||
namespace io::network {
|
||||
|
||||
template <class Derived, class Stream, size_t max_events = 64,
|
||||
int wait_timeout = -1>
|
||||
class StreamListener : public EventListener<Derived, max_events, wait_timeout> {
|
||||
public:
|
||||
using EventListener<Derived, max_events, wait_timeout>::EventListener;
|
||||
|
||||
void Add(Stream &stream) {
|
||||
// add the stream to the event listener
|
||||
this->listener_.Add(stream.socket_, &stream.event_);
|
||||
}
|
||||
|
||||
void OnCloseEvent(Epoll::Event &event) {
|
||||
this->derived().OnClose(to_stream(event));
|
||||
}
|
||||
|
||||
void OnErrorEvent(Epoll::Event &event) {
|
||||
this->derived().OnError(to_stream(event));
|
||||
}
|
||||
|
||||
void OnDataEvent(Epoll::Event &event) {
|
||||
this->derived().OnData(to_stream(event));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
void OnExceptionEvent(Epoll::Event &event, Args &&... args) {
|
||||
this->derived().OnException(to_stream(event), args...);
|
||||
}
|
||||
|
||||
private:
|
||||
Stream &to_stream(Epoll::Event &event) {
|
||||
return *reinterpret_cast<Stream *>(event.data.ptr);
|
||||
}
|
||||
};
|
||||
}
|
@ -1,85 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "io/network/stream_buffer.hpp"
|
||||
#include "io/network/stream_listener.hpp"
|
||||
#include "utils/memory_literals.hpp"
|
||||
|
||||
namespace io::network {
|
||||
|
||||
/**
|
||||
* This class is used to get data from a socket that has been notified
|
||||
* with a data available event.
|
||||
*/
|
||||
template <class Derived, class Stream>
|
||||
class StreamReader : public StreamListener<Derived, Stream> {
|
||||
public:
|
||||
StreamReader(uint32_t flags = 0) : StreamListener<Derived, Stream>(flags) {}
|
||||
|
||||
bool Accept(Socket &socket) {
|
||||
DLOG(INFO) << "Accept";
|
||||
|
||||
// accept a connection from a socket
|
||||
Socket s;
|
||||
if (!socket.Accept(&s)) return false;
|
||||
|
||||
DLOG(INFO) << fmt::format(
|
||||
"Accepted a connection: socket {}, address '{}', family {}, port {}",
|
||||
s.id(), s.endpoint().address(), s.endpoint().family(),
|
||||
s.endpoint().port());
|
||||
|
||||
if (!s.SetKeepAlive()) return false;
|
||||
if (!s.SetNoDelay()) return false;
|
||||
|
||||
auto &stream = this->derived().OnConnect(std::move(s));
|
||||
|
||||
// we want to listen to an incoming event which is edge triggered and
|
||||
// we also want to listen on the hangup event
|
||||
stream.event_.events = EPOLLIN | EPOLLRDHUP;
|
||||
|
||||
// add the connection to the event listener
|
||||
this->Add(stream);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void OnData(Stream &stream) {
|
||||
DLOG(INFO) << "On data";
|
||||
|
||||
if (UNLIKELY(!stream.Alive())) {
|
||||
DLOG(WARNING) << "Calling OnClose because the stream isn't alive!";
|
||||
this->derived().OnClose(stream);
|
||||
return;
|
||||
}
|
||||
|
||||
// allocate the buffer to fill the data
|
||||
auto buf = stream.Allocate();
|
||||
|
||||
// read from the buffer at most buf.len bytes
|
||||
int len = stream.socket_.Read(buf.data, buf.len);
|
||||
|
||||
// check for read errors
|
||||
if (len == -1) {
|
||||
// this means we have read all available data
|
||||
if (LIKELY(errno == EAGAIN || errno == EWOULDBLOCK)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// some other error occurred, check errno
|
||||
this->derived().OnError(stream);
|
||||
return;
|
||||
}
|
||||
|
||||
// end of file, the client has closed the connection
|
||||
if (UNLIKELY(len == 0)) {
|
||||
DLOG(WARNING) << "Calling OnClose because the socket is closed!";
|
||||
this->derived().OnClose(stream);
|
||||
return;
|
||||
}
|
||||
|
||||
// notify the stream that it has new data
|
||||
stream.Written(len);
|
||||
|
||||
this->derived().OnRead(stream);
|
||||
}
|
||||
};
|
||||
}
|
@ -176,5 +176,5 @@ int main(int argc, char **argv) {
|
||||
// Start worker threads.
|
||||
server.Start(FLAGS_num_workers);
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
return 0;
|
||||
}
|
||||
|
@ -25,14 +25,13 @@ class TestData {};
|
||||
|
||||
class TestSession {
|
||||
public:
|
||||
TestSession(socket_t &&socket, TestData &data)
|
||||
: socket_(std::move(socket)) {
|
||||
TestSession(socket_t &&socket, TestData &) : socket_(std::move(socket)) {
|
||||
event_.data.ptr = this;
|
||||
}
|
||||
|
||||
bool Alive() { return socket_.IsOpen(); }
|
||||
|
||||
int Id() const { return socket_.id(); }
|
||||
int Id() const { return socket_.fd(); }
|
||||
|
||||
void Execute() {
|
||||
if (buffer_.size() < 2) return;
|
||||
@ -78,24 +77,24 @@ void client_run(int num, const char *interface, const char *port,
|
||||
socket_t socket;
|
||||
ASSERT_TRUE(socket.Connect(endpoint));
|
||||
ASSERT_TRUE(socket.SetTimeout(2, 0));
|
||||
DLOG(INFO) << "Socket create: " << socket.id();
|
||||
DLOG(INFO) << "Socket create: " << socket.fd();
|
||||
for (int len = lo; len <= hi; len += 100) {
|
||||
have = 0;
|
||||
head[0] = (len >> 8) & 0xff;
|
||||
head[1] = len & 0xff;
|
||||
ASSERT_TRUE(socket.Write(head, 2));
|
||||
ASSERT_TRUE(socket.Write(data, len));
|
||||
DLOG(INFO) << "Socket write: " << socket.id();
|
||||
DLOG(INFO) << "Socket write: " << socket.fd();
|
||||
while (have < len * REPLY) {
|
||||
read = socket.Read(buffer + have, SIZE);
|
||||
DLOG(INFO) << "Socket read: " << socket.id();
|
||||
DLOG(INFO) << "Socket read: " << socket.fd();
|
||||
if (read == -1) break;
|
||||
have += read;
|
||||
}
|
||||
for (int i = 0; i < REPLY; ++i)
|
||||
for (int j = 0; j < len; ++j) ASSERT_EQ(buffer[i * len + j], data[j]);
|
||||
}
|
||||
DLOG(INFO) << "Socket done: " << socket.id();
|
||||
DLOG(INFO) << "Socket done: " << socket.fd();
|
||||
socket.Close();
|
||||
}
|
||||
|
||||
|
@ -27,14 +27,13 @@ class TestData {};
|
||||
|
||||
class TestSession {
|
||||
public:
|
||||
TestSession(socket_t &&socket, TestData &data)
|
||||
: socket_(std::move(socket)) {
|
||||
TestSession(socket_t &&socket, TestData &) : socket_(std::move(socket)) {
|
||||
event_.data.ptr = this;
|
||||
}
|
||||
|
||||
bool Alive() { return socket_.IsOpen(); }
|
||||
|
||||
int Id() const { return socket_.id(); }
|
||||
int Id() const { return socket_.fd(); }
|
||||
|
||||
void Execute() { this->socket_.Write(buffer_.data(), buffer_.size()); }
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user