Implement Bolt over WebSocket with asio

* Replace server implementation with asio

* Add support for bolt over WebSocket
This commit is contained in:
Jure Bajic 2022-04-27 10:13:16 +02:00 committed by GitHub
parent 4abaf27765
commit 7b78665cd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 872 additions and 25 deletions

View File

@ -184,7 +184,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall \
-Werror=switch -Werror=switch-bool -Werror=return-type \
-Werror=return-stack-address \
-Wno-c99-designator")
-Wno-c99-designator \
-DBOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT")
# Don't omit frame pointer in RelWithDebInfo, for additional callchain debug.
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO

View File

@ -78,14 +78,18 @@ bool ClientContext::use_ssl() { return use_ssl_; }
ServerContext::ServerContext(const std::string &key_file, const std::string &cert_file, const std::string &ca_file,
bool verify_peer) {
ctx_.emplace(boost::asio::ssl::context::tls_server);
namespace ssl = boost::asio::ssl;
ctx_.emplace(ssl::context::tls_server);
// NOLINTNEXTLINE(hicpp-signed-bitwise)
ctx_->set_options(ssl::context::default_workarounds | ssl::context::no_sslv2 | ssl::context::no_sslv3 |
ssl::context::single_dh_use);
ctx_->set_default_verify_paths();
// TODO: add support for encrypted private keys
// TODO: add certificate revocation list (CRL)
boost::system::error_code ec;
ctx_->use_certificate_chain_file(cert_file, ec);
MG_ASSERT(!ec, "Couldn't load server certificate from file: {}", cert_file);
ctx_->use_private_key_file(key_file, boost::asio::ssl::context::pem, ec);
ctx_->use_private_key_file(key_file, ssl::context::pem, ec);
MG_ASSERT(!ec, "Couldn't load server private key from file: {}", key_file);
ctx_->set_options(SSL_OP_NO_SSLv3, ec);
@ -100,7 +104,7 @@ ServerContext::ServerContext(const std::string &key_file, const std::string &cer
if (verify_peer) {
// Enable verification of the client certificate.
// NOLINTNEXTLINE(hicpp-signed-bitwise)
ctx_->set_verify_mode(boost::asio::ssl::verify_peer | boost::asio::ssl::verify_fail_if_no_peer_cert, ec);
ctx_->set_verify_mode(ssl::verify_peer | ssl::verify_fail_if_no_peer_cert, ec);
MG_ASSERT(!ec, "Setting SSL verification mode failed!");
}
}

View File

@ -0,0 +1,135 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <atomic>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string_view>
#include <thread>
#include <vector>
#include <spdlog/spdlog.h>
#include <boost/asio/io_context.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/strand.hpp>
#include <boost/beast/core.hpp>
#include <boost/system/detail/error_code.hpp>
#include "communication/context.hpp"
#include "communication/v2/pool.hpp"
#include "communication/v2/session.hpp"
#include "utils/spin_lock.hpp"
#include "utils/synchronized.hpp"
namespace memgraph::communication::v2 {
template <class TSession, class TSessionData>
class Listener final : public std::enable_shared_from_this<Listener<TSession, TSessionData>> {
using tcp = boost::asio::ip::tcp;
using SessionHandler = Session<TSession, TSessionData>;
using std::enable_shared_from_this<Listener<TSession, TSessionData>>::shared_from_this;
public:
Listener(const Listener &) = delete;
Listener(Listener &&) = delete;
Listener &operator=(const Listener &) = delete;
Listener &operator=(Listener &&) = delete;
~Listener() {}
template <typename... Args>
static std::shared_ptr<Listener> Create(Args &&...args) {
return std::shared_ptr<Listener>{new Listener(std::forward<Args>(args)...)};
}
void Start() { DoAccept(); }
bool IsRunning() const noexcept { return alive_.load(std::memory_order_relaxed); }
private:
Listener(boost::asio::io_context &io_context, TSessionData *data, ServerContext *server_context,
tcp::endpoint &endpoint, const std::string_view service_name, const uint64_t inactivity_timeout_sec)
: io_context_(io_context),
data_(data),
server_context_(server_context),
acceptor_(io_context_),
endpoint_{endpoint},
service_name_{service_name},
inactivity_timeout_{inactivity_timeout_sec} {
boost::system::error_code ec;
// Open the acceptor
acceptor_.open(endpoint.protocol(), ec);
if (ec) {
OnError(ec, "open");
return;
}
// Allow address reuse
acceptor_.set_option(boost::asio::socket_base::reuse_address(true), ec);
if (ec) {
OnError(ec, "set_option");
return;
}
// Bind to the server address
acceptor_.bind(endpoint, ec);
if (ec) {
spdlog::error(
utils::MessageWithLink("Cannot bind to socket on endpoint {}.", endpoint, "https://memgr.ph/socket"));
OnError(ec, "bind");
return;
}
acceptor_.listen(boost::asio::socket_base::max_listen_connections, ec);
if (ec) {
OnError(ec, "listen");
return;
}
}
void DoAccept() {
acceptor_.async_accept(io_context_,
[shared_this = shared_from_this()](auto ec, boost::asio::ip::tcp::socket &&socket) {
shared_this->OnAccept(ec, std::move(socket));
});
}
void OnAccept(boost::system::error_code ec, tcp::socket socket) {
if (ec) {
return OnError(ec, "accept");
}
auto session = SessionHandler::Create(std::move(socket), data_, *server_context_, endpoint_, inactivity_timeout_,
service_name_);
session->Start();
DoAccept();
}
void OnError(const boost::system::error_code &ec, const std::string_view what) {
spdlog::error("Listener failed on {}: {}", what, ec.message());
alive_.store(false, std::memory_order_relaxed);
}
boost::asio::io_context &io_context_;
TSessionData *data_;
ServerContext *server_context_;
tcp::acceptor acceptor_;
tcp::endpoint endpoint_;
std::string_view service_name_;
std::chrono::seconds inactivity_timeout_;
std::atomic<bool> alive_;
};
} // namespace memgraph::communication::v2

View File

@ -0,0 +1,68 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <cstddef>
#include <thread>
#include <vector>
#include <boost/asio/executor_work_guard.hpp>
#include <boost/asio/io_context.hpp>
#include "utils/logging.hpp"
namespace memgraph::communication::v2 {
class IOContextThreadPool final {
private:
using IOContext = boost::asio::io_context;
using IOContextGuard = boost::asio::executor_work_guard<boost::asio::io_context::executor_type>;
public:
explicit IOContextThreadPool(size_t pool_size) : guard_{io_context_.get_executor()}, pool_size_{pool_size} {
MG_ASSERT(pool_size != 0, "Pool size must be greater then 0!");
}
IOContextThreadPool(const IOContextThreadPool &) = delete;
IOContextThreadPool &operator=(const IOContextThreadPool &) = delete;
IOContextThreadPool(IOContextThreadPool &&) = delete;
IOContextThreadPool &operator=(IOContextThreadPool &&) = delete;
~IOContextThreadPool() = default;
void Run() {
background_threads_.reserve(pool_size_);
for (size_t i = 0; i < pool_size_; ++i) {
background_threads_.emplace_back([this]() { io_context_.run(); });
}
running_ = true;
}
void Shutdown() {
io_context_.stop();
running_ = false;
}
void AwaitShutdown() { background_threads_.clear(); }
bool IsRunning() const noexcept { return running_; }
IOContext &GetIOContext() noexcept { return io_context_; }
private:
/// The pool of io_context.
IOContext io_context_;
IOContextGuard guard_;
size_t pool_size_;
std::vector<std::jthread> background_threads_;
bool running_{false};
};
} // namespace memgraph::communication::v2

View File

@ -0,0 +1,128 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <memory>
#include <optional>
#include <string>
#include <thread>
#include <vector>
#include <fmt/format.h>
#include <boost/asio/io_context.hpp>
#include <boost/asio/ip/address.hpp>
#include <boost/asio/ip/tcp.hpp>
#include "communication/context.hpp"
#include "communication/init.hpp"
#include "communication/v2/listener.hpp"
#include "communication/v2/pool.hpp"
#include "utils/logging.hpp"
#include "utils/message.hpp"
#include "utils/thread.hpp"
namespace memgraph::communication::v2 {
using Socket = boost::asio::ip::tcp::socket;
using ServerEndpoint = boost::asio::ip::tcp::endpoint;
/**
* Communication server.
*
* Listens for incoming connections on the server port and assigns them to the
* connection listener. The listener and session are implemented using asio
* async model. Currently the implemented model is thread per core model
* opposed to io_context per core. The reasoning for opting for the former model
* is the robustness to the multiple resource demanding queries that can be split
* across multiple threads, and then a single thread would not block io_context,
* unlike in the latter model where it is possible that thread that accepts
* request is being blocked by demanding query.
* All logic is contained within handlers that are being dispatched
* on a single strand per session. The only exception is write which is
* synchronous since the nature of the clients conenction is synchronous as
* well.
*
* Current Server architecture:
* incoming connection -> server -> listener -> session
*
* @tparam TSession the server can handle different Sessions, each session
* represents a different protocol so the same network infrastructure
* can be used for handling different protocols
* @tparam TSessionData the class with objects that will be forwarded to the
* session
*/
template <typename TSession, typename TSessionData>
class Server final {
using ServerHandler = Server<TSession, TSessionData>;
public:
/**
* Constructs and binds server to endpoint, operates on session data and
* invokes workers_count workers
*/
Server(ServerEndpoint &endpoint, TSessionData *session_data, ServerContext *server_context,
const int inactivity_timeout_sec, const std::string_view service_name,
size_t workers_count = std::thread::hardware_concurrency())
: endpoint_{endpoint},
service_name_{service_name},
context_thread_pool_{workers_count},
listener_{Listener<TSession, TSessionData>::Create(context_thread_pool_.GetIOContext(), session_data,
server_context, endpoint_, service_name_,
inactivity_timeout_sec)} {}
~Server() { MG_ASSERT(!IsRunning(), "Server wasn't shutdown properly"); }
Server(const Server &) = delete;
Server(Server &&) = delete;
Server &operator=(const Server &) = delete;
Server &operator=(Server &&) = delete;
const auto &Endpoint() const {
MG_ASSERT(IsRunning(), "You can't get the server endpoint when it's not running!");
return endpoint_;
}
bool Start() {
if (IsRunning()) {
spdlog::error("The server is already running");
return false;
}
listener_->Start();
spdlog::info("{} server is fully armed and operational", service_name_);
spdlog::info("{} listening on {}", service_name_, endpoint_.address());
context_thread_pool_.Run();
return true;
}
void Shutdown() {
context_thread_pool_.Shutdown();
spdlog::info("{} shutting down...", service_name_);
}
void AwaitShutdown() { context_thread_pool_.AwaitShutdown(); }
bool IsRunning() const noexcept { return context_thread_pool_.IsRunning() && listener_->IsRunning(); }
private:
ServerEndpoint endpoint_;
std::string service_name_;
IOContextThreadPool context_thread_pool_;
std::shared_ptr<Listener<TSession, TSessionData>> listener_;
};
} // namespace memgraph::communication::v2

View File

@ -0,0 +1,513 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <deque>
#include <functional>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <variant>
#include <spdlog/spdlog.h>
#include <boost/asio/bind_executor.hpp>
#include <boost/asio/buffer.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/read.hpp>
#include <boost/asio/socket_base.hpp>
#include <boost/asio/ssl/stream.hpp>
#include <boost/asio/ssl/stream_base.hpp>
#include <boost/asio/steady_timer.hpp>
#include <boost/asio/strand.hpp>
#include <boost/asio/system_context.hpp>
#include <boost/asio/write.hpp>
#include <boost/beast/core/tcp_stream.hpp>
#include <boost/beast/http.hpp>
#include <boost/beast/websocket.hpp>
#include <boost/beast/websocket/rfc6455.hpp>
#include <boost/system/detail/error_code.hpp>
#include "communication/context.hpp"
#include "communication/exceptions.hpp"
#include "utils/logging.hpp"
#include "utils/variant_helpers.hpp"
namespace memgraph::communication::v2 {
/**
* This is used to provide input to user Sessions. All Sessions used with the
* network stack should use this class as their input stream.
*/
using InputStream = communication::Buffer::ReadEnd;
using tcp = boost::asio::ip::tcp;
/**
* This is used to provide output from user Sessions. All Sessions used with the
* network stack should use this class for their output stream.
*/
class OutputStream final {
public:
explicit OutputStream(std::function<bool(const uint8_t *, size_t, bool)> write_function)
: write_function_(write_function) {}
OutputStream(const OutputStream &) = delete;
OutputStream(OutputStream &&) = delete;
OutputStream &operator=(const OutputStream &) = delete;
OutputStream &operator=(OutputStream &&) = delete;
~OutputStream() = default;
bool Write(const uint8_t *data, size_t len, bool have_more = false) { return write_function_(data, len, have_more); }
bool Write(const std::string &str, bool have_more = false) {
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(), have_more);
}
private:
std::function<bool(const uint8_t *, size_t, bool)> write_function_;
};
/**
* This class is used internally in the communication stack to handle all user
* Websocket Sessions. It handles socket ownership, inactivity timeout and protocol
* wrapping.
*/
template <typename TSession, typename TSessionData>
class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TSession, TSessionData>> {
using WebSocket = boost::beast::websocket::stream<boost::beast::tcp_stream>;
using std::enable_shared_from_this<WebsocketSession<TSession, TSessionData>>::shared_from_this;
public:
template <typename... Args>
static std::shared_ptr<WebsocketSession> Create(Args &&...args) {
return std::shared_ptr<WebsocketSession>(new WebsocketSession(std::forward<Args>(args)...));
}
// Start the asynchronous accept operation
template <class Body, class Allocator>
void DoAccept(boost::beast::http::request<Body, boost::beast::http::basic_fields<Allocator>> req) {
execution_active_ = true;
// Set suggested timeout settings for the websocket
ws_.set_option(boost::beast::websocket::stream_base::timeout::suggested(boost::beast::role_type::server));
boost::asio::socket_base::keep_alive option(true);
// Set a decorator to change the Server of the handshake
ws_.set_option(boost::beast::websocket::stream_base::decorator([](boost::beast::websocket::response_type &res) {
res.set(boost::beast::http::field::server, std::string("Memgraph Bolt WS"));
res.set(boost::beast::http::field::sec_websocket_protocol, "binary");
}));
ws_.binary(true);
// Accept the websocket handshake
ws_.async_accept(
req, boost::asio::bind_executor(strand_, std::bind_front(&WebsocketSession::OnAccept, shared_from_this())));
}
bool Write(const uint8_t *data, size_t len) {
if (!IsConnected()) {
return false;
}
boost::system::error_code ec;
ws_.write(boost::asio::buffer(data, len), ec);
if (ec) {
OnError(ec, "write");
return false;
}
return true;
}
private:
// Take ownership of the socket
explicit WebsocketSession(tcp::socket &&socket, TSessionData *data, tcp::endpoint endpoint,
std::string_view service_name)
: ws_(std::move(socket)),
strand_{boost::asio::make_strand(ws_.get_executor())},
output_stream_([this](const uint8_t *data, size_t len, bool /*have_more*/) { return Write(data, len); }),
session_(data, endpoint, input_buffer_.read_end(), &output_stream_),
endpoint_{endpoint},
remote_endpoint_{ws_.next_layer().socket().remote_endpoint()},
service_name_{service_name} {}
void OnAccept(boost::beast::error_code ec) {
if (ec) {
return OnError(ec, "accept");
}
// Read a message
DoRead();
}
void DoRead() {
if (!IsConnected()) {
return;
}
// Read a message into our buffer
auto buffer = input_buffer_.write_end()->Allocate();
ws_.async_read_some(
boost::asio::buffer(buffer.data, buffer.len),
boost::asio::bind_executor(strand_, std::bind_front(&WebsocketSession::OnRead, shared_from_this())));
}
void OnRead(const boost::system::error_code &ec, [[maybe_unused]] const size_t bytes_transferred) {
// This indicates that the WebsocketSession was closed
if (ec == boost::beast::websocket::error::closed) {
return;
}
if (ec) {
OnError(ec, "read");
}
input_buffer_.write_end()->Written(bytes_transferred);
try {
session_.Execute();
DoRead();
} catch (const SessionClosedException &e) {
spdlog::info("{} client {}:{} closed the connection.", service_name_, remote_endpoint_.address(),
remote_endpoint_.port());
DoClose();
} catch (const std::exception &e) {
spdlog::error(
"Exception was thrown while processing event in {} session "
"associated with {}:{}",
service_name_, remote_endpoint_.address(), remote_endpoint_.port());
spdlog::debug("Exception message: {}", e.what());
DoClose();
}
}
void OnError(const boost::system::error_code &ec, const std::string_view action) {
spdlog::error("Websocket Bolt session error: {} on {}", ec.message(), action);
DoClose();
}
void DoClose() {
ws_.async_close(
boost::beast::websocket::close_code::normal,
boost::asio::bind_executor(
strand_, [shared_this = shared_from_this()](boost::beast::error_code ec) { shared_this->OnClose(ec); }));
}
void OnClose(const boost::system::error_code &ec) {
if (!IsConnected()) {
return;
}
if (ec) {
return OnError(ec, "close");
}
}
bool IsConnected() const { return ws_.is_open() && execution_active_; }
WebSocket ws_;
boost::asio::strand<WebSocket::executor_type> strand_;
communication::Buffer input_buffer_;
OutputStream output_stream_;
TSession session_;
tcp::endpoint endpoint_;
tcp::endpoint remote_endpoint_;
std::string_view service_name_;
bool execution_active_{false};
};
/**
* This class is used internally in the communication stack to handle all user
* Sessions. It handles socket ownership, inactivity timeout and protocol
* wrapping.
*/
template <typename TSession, typename TSessionData>
class Session final : public std::enable_shared_from_this<Session<TSession, TSessionData>> {
using TCPSocket = tcp::socket;
using SSLSocket = boost::asio::ssl::stream<TCPSocket>;
using std::enable_shared_from_this<Session<TSession, TSessionData>>::shared_from_this;
public:
template <typename... Args>
static std::shared_ptr<Session> Create(Args &&...args) {
return std::shared_ptr<Session>(new Session(std::forward<Args>(args)...));
}
Session(const Session &) = delete;
Session(Session &&) = delete;
Session &operator=(const Session &) = delete;
Session &operator=(Session &&) = delete;
~Session() {
if (IsConnected()) {
spdlog::error("Session: Destructor called while execution is active");
}
}
bool Start() {
if (execution_active_) {
return false;
}
execution_active_ = true;
timeout_timer_.async_wait(boost::asio::bind_executor(strand_, std::bind(&Session::OnTimeout, shared_from_this())));
if (std::holds_alternative<SSLSocket>(socket_)) {
boost::asio::dispatch(strand_, [shared_this = shared_from_this()] { shared_this->DoHandshake(); });
} else {
boost::asio::dispatch(strand_, [shared_this = shared_from_this()] { shared_this->DoRead(); });
}
return true;
}
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
if (!IsConnected()) {
return false;
}
return std::visit(
utils::Overloaded{[shared_this = shared_from_this(), data, len, have_more](TCPSocket &socket) mutable {
boost::system::error_code ec;
while (len > 0) {
const auto sent = socket.send(boost::asio::buffer(data, len),
MSG_NOSIGNAL | (have_more ? MSG_MORE : 0), ec);
if (ec) {
shared_this->OnError(ec);
return false;
}
data += sent;
len -= sent;
}
return true;
},
[shared_this = shared_from_this(), data, len](SSLSocket &socket) mutable {
boost::system::error_code ec;
while (len > 0) {
const auto sent = socket.write_some(boost::asio::buffer(data, len), ec);
if (ec) {
shared_this->OnError(ec);
return false;
}
data += sent;
len -= sent;
}
return true;
}},
socket_);
}
bool IsConnected() const {
return std::visit([this](const auto &socket) { return execution_active_ && socket.lowest_layer().is_open(); },
socket_);
}
private:
explicit Session(tcp::socket &&socket, TSessionData *data, ServerContext &server_context, tcp::endpoint endpoint,
const std::chrono::seconds inactivity_timeout_sec, std::string_view service_name)
: socket_(CreateSocket(std::move(socket), server_context)),
strand_{boost::asio::make_strand(GetExecutor())},
output_stream_([this](const uint8_t *data, size_t len, bool have_more) { return Write(data, len, have_more); }),
session_(data, endpoint, input_buffer_.read_end(), &output_stream_),
data_{data},
endpoint_{endpoint},
remote_endpoint_{GetRemoteEndpoint()},
service_name_{service_name},
timeout_seconds_(inactivity_timeout_sec),
timeout_timer_(GetExecutor()) {
ExecuteForSocket([](auto &&socket) {
socket.lowest_layer().set_option(tcp::no_delay(true)); // enable PSH
socket.lowest_layer().set_option(boost::asio::socket_base::keep_alive(true)); // enable SO_KEEPALIVE
socket.lowest_layer().non_blocking(false);
});
timeout_timer_.expires_at(boost::asio::steady_timer::time_point::max());
spdlog::info("Accepted a connection from {}:", service_name_, remote_endpoint_.address(), remote_endpoint_.port());
}
void DoRead() {
if (!IsConnected()) {
return;
}
timeout_timer_.expires_after(timeout_seconds_);
ExecuteForSocket([this](auto &&socket) {
auto buffer = input_buffer_.write_end()->Allocate();
socket.async_read_some(
boost::asio::buffer(buffer.data, buffer.len),
boost::asio::bind_executor(strand_, std::bind_front(&Session::OnRead, shared_from_this())));
});
}
bool IsWebsocketUpgrade(boost::beast::http::request_parser<boost::beast::http::string_body> &parser) {
boost::system::error_code error_code_parsing;
parser.put(boost::asio::buffer(input_buffer_.read_end()->data(), input_buffer_.read_end()->size()),
error_code_parsing);
if (error_code_parsing) {
return false;
}
return boost::beast::websocket::is_upgrade(parser.get());
}
void OnRead(const boost::system::error_code &ec, const size_t bytes_transferred) {
if (ec) {
return OnError(ec);
}
input_buffer_.write_end()->Written(bytes_transferred);
// Can be a websocket connection only on the first read, since it is not
// expected from clients to upgrade from tcp to websocket
if (!has_received_msg_) {
has_received_msg_ = true;
boost::beast::http::request_parser<boost::beast::http::string_body> parser;
if (IsWebsocketUpgrade(parser)) {
spdlog::info("Switching {} to websocket connection", remote_endpoint_);
if (std::holds_alternative<TCPSocket>(socket_)) {
auto sock = std::get<TCPSocket>(std::move(socket_));
WebsocketSession<TSession, TSessionData>::Create(std::move(sock), data_, endpoint_, service_name_)
->DoAccept(parser.release());
execution_active_ = false;
return;
}
spdlog::error("Error while upgrading connection to websocket");
DoShutdown();
}
}
try {
session_.Execute();
DoRead();
} catch (const SessionClosedException &e) {
spdlog::info("{} client {}:{} closed the connection.", service_name_, remote_endpoint_.address(),
remote_endpoint_.port());
DoShutdown();
} catch (const std::exception &e) {
spdlog::error(
"Exception was thrown while processing event in {} session "
"associated with {}:{}",
service_name_, remote_endpoint_.address(), remote_endpoint_.port());
spdlog::debug("Exception message: {}", e.what());
DoShutdown();
}
}
void OnError(const boost::system::error_code &ec) {
if (ec == boost::asio::error::operation_aborted) {
return;
}
execution_active_ = false;
if (ec == boost::asio::error::eof) {
spdlog::info("Session closed by peer");
} else {
spdlog::error("Session error: {}", ec.message());
}
DoShutdown();
}
void DoShutdown() {
if (!IsConnected()) {
return;
}
execution_active_ = false;
timeout_timer_.cancel();
ExecuteForSocket([](auto &socket) {
boost::system::error_code ec;
auto &lowest_layer = socket.lowest_layer();
lowest_layer.shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec);
if (ec) {
spdlog::error("Session shutdown failed: {}", ec.what());
}
lowest_layer.close();
});
}
void DoHandshake() {
if (!IsConnected()) {
return;
}
if (auto *socket = std::get_if<SSLSocket>(&socket_); socket) {
socket->async_handshake(
boost::asio::ssl::stream_base::server,
boost::asio::bind_executor(strand_, std::bind_front(&Session::OnHandshake, shared_from_this())));
}
}
void OnHandshake(const boost::system::error_code &ec) {
if (ec) {
return OnError(ec);
}
DoRead();
}
void OnClose(const boost::system::error_code &ec) {
if (ec) {
return OnError(ec);
}
}
void OnTimeout() {
if (!IsConnected()) {
return;
}
// Check whether the deadline has passed. We compare the deadline against
// the current time since a new asynchronous operation may have moved the
// deadline before this actor had a chance to run.
if (timeout_timer_.expiry() <= boost::asio::steady_timer::clock_type::now()) {
// The deadline has passed. Stop the session. The other actors will
// terminate as soon as possible.
spdlog::info("Shutting down session after {} of inactivity", timeout_seconds_);
DoShutdown();
} else {
// Put the actor back to sleep.
timeout_timer_.async_wait(
boost::asio::bind_executor(strand_, std::bind(&Session::OnTimeout, shared_from_this())));
}
}
std::variant<TCPSocket, SSLSocket> CreateSocket(tcp::socket &&socket, ServerContext &context) {
if (context.use_ssl()) {
ssl_context_.emplace(context.context_clone());
return SSLSocket{std::move(socket), *ssl_context_};
}
return TCPSocket{std::move(socket)};
}
auto GetExecutor() {
return std::visit(utils::Overloaded{[](auto &&socket) { return socket.get_executor(); }}, socket_);
}
auto GetRemoteEndpoint() const {
return std::visit(utils::Overloaded{[](const auto &socket) { return socket.lowest_layer().remote_endpoint(); }},
socket_);
}
template <typename F>
decltype(auto) ExecuteForSocket(F &&fun) {
return std::visit(utils::Overloaded{std::forward<F>(fun)}, socket_);
}
std::variant<TCPSocket, SSLSocket> socket_;
std::optional<std::reference_wrapper<boost::asio::ssl::context>> ssl_context_;
boost::asio::strand<tcp::socket::executor_type> strand_;
communication::Buffer input_buffer_;
OutputStream output_stream_;
TSession session_;
TSessionData *data_;
tcp::endpoint endpoint_;
tcp::endpoint remote_endpoint_;
std::string_view service_name_;
std::chrono::seconds timeout_seconds_;
boost::asio::steady_timer timeout_timer_;
bool execution_active_{false};
bool has_received_msg_{false};
};
} // namespace memgraph::communication::v2

View File

@ -11,8 +11,6 @@
#pragma once
#define BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT
#include <list>
#include <memory>

View File

@ -11,8 +11,6 @@
#pragma once
#define BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT
#include <thread>
#include <spdlog/sinks/base_sink.h>

View File

@ -11,8 +11,6 @@
#pragma once
#define BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT
#include <deque>
#include <memory>
#include <optional>

View File

@ -81,8 +81,8 @@
#include "communication/bolt/v1/exceptions.hpp"
#include "communication/bolt/v1/session.hpp"
#include "communication/init.hpp"
#include "communication/server.hpp"
#include "communication/session.hpp"
#include "communication/v2/server.hpp"
#include "communication/v2/session.hpp"
#include "glue/communication.hpp"
#include "auth/auth.hpp"
@ -842,13 +842,14 @@ class AuthChecker final : public memgraph::query::AuthChecker {
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
};
class BoltSession final : public memgraph::communication::bolt::Session<memgraph::communication::InputStream,
memgraph::communication::OutputStream> {
class BoltSession final : public memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
memgraph::communication::v2::OutputStream> {
public:
BoltSession(SessionData *data, const memgraph::io::network::Endpoint &endpoint,
memgraph::communication::InputStream *input_stream, memgraph::communication::OutputStream *output_stream)
: memgraph::communication::bolt::Session<memgraph::communication::InputStream,
memgraph::communication::OutputStream>(input_stream, output_stream),
BoltSession(SessionData *data, const memgraph::communication::v2::ServerEndpoint &endpoint,
memgraph::communication::v2::InputStream *input_stream,
memgraph::communication::v2::OutputStream *output_stream)
: memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
memgraph::communication::v2::OutputStream>(input_stream, output_stream),
db_(data->db),
interpreter_(data->interpreter_context),
auth_(data->auth),
@ -858,8 +859,8 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
endpoint_(endpoint) {
}
using memgraph::communication::bolt::Session<memgraph::communication::InputStream,
memgraph::communication::OutputStream>::TEncoder;
using memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
memgraph::communication::v2::OutputStream>::TEncoder;
void BeginTransaction() override { interpreter_.BeginTransaction(); }
@ -877,7 +878,8 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
}
#ifdef MG_ENTERPRISE
if (memgraph::utils::license::global_license_checker.IsValidLicenseFast()) {
audit_log_->Record(endpoint_.address, user_ ? *username : "", query, memgraph::storage::PropertyValue(params_pv));
audit_log_->Record(endpoint_.address().to_string(), user_ ? *username : "", query,
memgraph::storage::PropertyValue(params_pv));
}
#endif
try {
@ -996,10 +998,10 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
#ifdef MG_ENTERPRISE
memgraph::audit::Log *audit_log_;
#endif
memgraph::io::network::Endpoint endpoint_;
memgraph::communication::v2::ServerEndpoint endpoint_;
};
using ServerT = memgraph::communication::Server<BoltSession, SessionData>;
using ServerT = memgraph::communication::v2::Server<BoltSession, SessionData>;
using memgraph::communication::ServerContext;
// Needed to correctly handle memgraph destruction from a signal handler.
@ -1241,8 +1243,10 @@ int main(int argc, char **argv) {
memgraph::utils::MessageWithLink("Using non-secure Bolt connection (without SSL).", "https://memgr.ph/ssl"));
}
ServerT server({FLAGS_bolt_address, static_cast<uint16_t>(FLAGS_bolt_port)}, &session_data, &context,
FLAGS_bolt_session_inactivity_timeout, service_name, FLAGS_bolt_num_workers);
auto server_endpoint = memgraph::communication::v2::ServerEndpoint{
boost::asio::ip::address::from_string(FLAGS_bolt_address), static_cast<uint16_t>(FLAGS_bolt_port)};
ServerT server(server_endpoint, &session_data, &context, FLAGS_bolt_session_inactivity_timeout, service_name,
FLAGS_bolt_num_workers);
// Setup telemetry
std::optional<memgraph::telemetry::Telemetry> telemetry;