From 7b78665cd896a063bdc81af1b168d6b69d6d55d0 Mon Sep 17 00:00:00 2001 From: Jure Bajic Date: Wed, 27 Apr 2022 10:13:16 +0200 Subject: [PATCH] Implement Bolt over WebSocket with asio * Replace server implementation with asio * Add support for bolt over WebSocket --- CMakeLists.txt | 3 +- src/communication/context.cpp | 10 +- src/communication/v2/listener.hpp | 135 ++++++ src/communication/v2/pool.hpp | 68 +++ src/communication/v2/server.hpp | 128 ++++++ src/communication/v2/session.hpp | 513 +++++++++++++++++++++++ src/communication/websocket/listener.hpp | 2 - src/communication/websocket/server.hpp | 2 - src/communication/websocket/session.hpp | 2 - src/memgraph.cpp | 34 +- 10 files changed, 872 insertions(+), 25 deletions(-) create mode 100644 src/communication/v2/listener.hpp create mode 100644 src/communication/v2/pool.hpp create mode 100644 src/communication/v2/server.hpp create mode 100644 src/communication/v2/session.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index dc5fc204f..916389424 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -184,7 +184,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall \ -Werror=switch -Werror=switch-bool -Werror=return-type \ -Werror=return-stack-address \ - -Wno-c99-designator") + -Wno-c99-designator \ + -DBOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT") # Don't omit frame pointer in RelWithDebInfo, for additional callchain debug. set(CMAKE_CXX_FLAGS_RELWITHDEBINFO diff --git a/src/communication/context.cpp b/src/communication/context.cpp index b3e58824b..53cb4586b 100644 --- a/src/communication/context.cpp +++ b/src/communication/context.cpp @@ -78,14 +78,18 @@ bool ClientContext::use_ssl() { return use_ssl_; } ServerContext::ServerContext(const std::string &key_file, const std::string &cert_file, const std::string &ca_file, bool verify_peer) { - ctx_.emplace(boost::asio::ssl::context::tls_server); + namespace ssl = boost::asio::ssl; + ctx_.emplace(ssl::context::tls_server); + // NOLINTNEXTLINE(hicpp-signed-bitwise) + ctx_->set_options(ssl::context::default_workarounds | ssl::context::no_sslv2 | ssl::context::no_sslv3 | + ssl::context::single_dh_use); ctx_->set_default_verify_paths(); // TODO: add support for encrypted private keys // TODO: add certificate revocation list (CRL) boost::system::error_code ec; ctx_->use_certificate_chain_file(cert_file, ec); MG_ASSERT(!ec, "Couldn't load server certificate from file: {}", cert_file); - ctx_->use_private_key_file(key_file, boost::asio::ssl::context::pem, ec); + ctx_->use_private_key_file(key_file, ssl::context::pem, ec); MG_ASSERT(!ec, "Couldn't load server private key from file: {}", key_file); ctx_->set_options(SSL_OP_NO_SSLv3, ec); @@ -100,7 +104,7 @@ ServerContext::ServerContext(const std::string &key_file, const std::string &cer if (verify_peer) { // Enable verification of the client certificate. // NOLINTNEXTLINE(hicpp-signed-bitwise) - ctx_->set_verify_mode(boost::asio::ssl::verify_peer | boost::asio::ssl::verify_fail_if_no_peer_cert, ec); + ctx_->set_verify_mode(ssl::verify_peer | ssl::verify_fail_if_no_peer_cert, ec); MG_ASSERT(!ec, "Setting SSL verification mode failed!"); } } diff --git a/src/communication/v2/listener.hpp b/src/communication/v2/listener.hpp new file mode 100644 index 000000000..bd26bea12 --- /dev/null +++ b/src/communication/v2/listener.hpp @@ -0,0 +1,135 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "communication/context.hpp" +#include "communication/v2/pool.hpp" +#include "communication/v2/session.hpp" +#include "utils/spin_lock.hpp" +#include "utils/synchronized.hpp" + +namespace memgraph::communication::v2 { + +template +class Listener final : public std::enable_shared_from_this> { + using tcp = boost::asio::ip::tcp; + using SessionHandler = Session; + using std::enable_shared_from_this>::shared_from_this; + + public: + Listener(const Listener &) = delete; + Listener(Listener &&) = delete; + Listener &operator=(const Listener &) = delete; + Listener &operator=(Listener &&) = delete; + ~Listener() {} + + template + static std::shared_ptr Create(Args &&...args) { + return std::shared_ptr{new Listener(std::forward(args)...)}; + } + + void Start() { DoAccept(); } + + bool IsRunning() const noexcept { return alive_.load(std::memory_order_relaxed); } + + private: + Listener(boost::asio::io_context &io_context, TSessionData *data, ServerContext *server_context, + tcp::endpoint &endpoint, const std::string_view service_name, const uint64_t inactivity_timeout_sec) + : io_context_(io_context), + data_(data), + server_context_(server_context), + acceptor_(io_context_), + endpoint_{endpoint}, + service_name_{service_name}, + inactivity_timeout_{inactivity_timeout_sec} { + boost::system::error_code ec; + // Open the acceptor + acceptor_.open(endpoint.protocol(), ec); + if (ec) { + OnError(ec, "open"); + return; + } + + // Allow address reuse + acceptor_.set_option(boost::asio::socket_base::reuse_address(true), ec); + if (ec) { + OnError(ec, "set_option"); + return; + } + + // Bind to the server address + acceptor_.bind(endpoint, ec); + if (ec) { + spdlog::error( + utils::MessageWithLink("Cannot bind to socket on endpoint {}.", endpoint, "https://memgr.ph/socket")); + OnError(ec, "bind"); + return; + } + + acceptor_.listen(boost::asio::socket_base::max_listen_connections, ec); + if (ec) { + OnError(ec, "listen"); + return; + } + } + + void DoAccept() { + acceptor_.async_accept(io_context_, + [shared_this = shared_from_this()](auto ec, boost::asio::ip::tcp::socket &&socket) { + shared_this->OnAccept(ec, std::move(socket)); + }); + } + + void OnAccept(boost::system::error_code ec, tcp::socket socket) { + if (ec) { + return OnError(ec, "accept"); + } + + auto session = SessionHandler::Create(std::move(socket), data_, *server_context_, endpoint_, inactivity_timeout_, + service_name_); + session->Start(); + DoAccept(); + } + + void OnError(const boost::system::error_code &ec, const std::string_view what) { + spdlog::error("Listener failed on {}: {}", what, ec.message()); + alive_.store(false, std::memory_order_relaxed); + } + + boost::asio::io_context &io_context_; + TSessionData *data_; + ServerContext *server_context_; + tcp::acceptor acceptor_; + + tcp::endpoint endpoint_; + std::string_view service_name_; + std::chrono::seconds inactivity_timeout_; + + std::atomic alive_; +}; +} // namespace memgraph::communication::v2 diff --git a/src/communication/v2/pool.hpp b/src/communication/v2/pool.hpp new file mode 100644 index 000000000..f29675aef --- /dev/null +++ b/src/communication/v2/pool.hpp @@ -0,0 +1,68 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include + +#include +#include + +#include "utils/logging.hpp" + +namespace memgraph::communication::v2 { + +class IOContextThreadPool final { + private: + using IOContext = boost::asio::io_context; + using IOContextGuard = boost::asio::executor_work_guard; + + public: + explicit IOContextThreadPool(size_t pool_size) : guard_{io_context_.get_executor()}, pool_size_{pool_size} { + MG_ASSERT(pool_size != 0, "Pool size must be greater then 0!"); + } + + IOContextThreadPool(const IOContextThreadPool &) = delete; + IOContextThreadPool &operator=(const IOContextThreadPool &) = delete; + IOContextThreadPool(IOContextThreadPool &&) = delete; + IOContextThreadPool &operator=(IOContextThreadPool &&) = delete; + ~IOContextThreadPool() = default; + + void Run() { + background_threads_.reserve(pool_size_); + for (size_t i = 0; i < pool_size_; ++i) { + background_threads_.emplace_back([this]() { io_context_.run(); }); + } + running_ = true; + } + + void Shutdown() { + io_context_.stop(); + running_ = false; + } + + void AwaitShutdown() { background_threads_.clear(); } + + bool IsRunning() const noexcept { return running_; } + + IOContext &GetIOContext() noexcept { return io_context_; } + + private: + /// The pool of io_context. + IOContext io_context_; + IOContextGuard guard_; + size_t pool_size_; + std::vector background_threads_; + bool running_{false}; +}; +} // namespace memgraph::communication::v2 diff --git a/src/communication/v2/server.hpp b/src/communication/v2/server.hpp new file mode 100644 index 000000000..0165035b4 --- /dev/null +++ b/src/communication/v2/server.hpp @@ -0,0 +1,128 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "communication/context.hpp" +#include "communication/init.hpp" +#include "communication/v2/listener.hpp" +#include "communication/v2/pool.hpp" +#include "utils/logging.hpp" +#include "utils/message.hpp" +#include "utils/thread.hpp" + +namespace memgraph::communication::v2 { + +using Socket = boost::asio::ip::tcp::socket; +using ServerEndpoint = boost::asio::ip::tcp::endpoint; +/** + * Communication server. + * + * Listens for incoming connections on the server port and assigns them to the + * connection listener. The listener and session are implemented using asio + * async model. Currently the implemented model is thread per core model + * opposed to io_context per core. The reasoning for opting for the former model + * is the robustness to the multiple resource demanding queries that can be split + * across multiple threads, and then a single thread would not block io_context, + * unlike in the latter model where it is possible that thread that accepts + * request is being blocked by demanding query. + * All logic is contained within handlers that are being dispatched + * on a single strand per session. The only exception is write which is + * synchronous since the nature of the clients conenction is synchronous as + * well. + * + * Current Server architecture: + * incoming connection -> server -> listener -> session + + * + * @tparam TSession the server can handle different Sessions, each session + * represents a different protocol so the same network infrastructure + * can be used for handling different protocols + * @tparam TSessionData the class with objects that will be forwarded to the + * session + */ +template +class Server final { + using ServerHandler = Server; + + public: + /** + * Constructs and binds server to endpoint, operates on session data and + * invokes workers_count workers + */ + Server(ServerEndpoint &endpoint, TSessionData *session_data, ServerContext *server_context, + const int inactivity_timeout_sec, const std::string_view service_name, + size_t workers_count = std::thread::hardware_concurrency()) + : endpoint_{endpoint}, + service_name_{service_name}, + context_thread_pool_{workers_count}, + listener_{Listener::Create(context_thread_pool_.GetIOContext(), session_data, + server_context, endpoint_, service_name_, + inactivity_timeout_sec)} {} + + ~Server() { MG_ASSERT(!IsRunning(), "Server wasn't shutdown properly"); } + + Server(const Server &) = delete; + Server(Server &&) = delete; + Server &operator=(const Server &) = delete; + Server &operator=(Server &&) = delete; + + const auto &Endpoint() const { + MG_ASSERT(IsRunning(), "You can't get the server endpoint when it's not running!"); + return endpoint_; + } + + bool Start() { + if (IsRunning()) { + spdlog::error("The server is already running"); + return false; + } + listener_->Start(); + + spdlog::info("{} server is fully armed and operational", service_name_); + spdlog::info("{} listening on {}", service_name_, endpoint_.address()); + context_thread_pool_.Run(); + + return true; + } + + void Shutdown() { + context_thread_pool_.Shutdown(); + spdlog::info("{} shutting down...", service_name_); + } + + void AwaitShutdown() { context_thread_pool_.AwaitShutdown(); } + + bool IsRunning() const noexcept { return context_thread_pool_.IsRunning() && listener_->IsRunning(); } + + private: + ServerEndpoint endpoint_; + std::string service_name_; + + IOContextThreadPool context_thread_pool_; + std::shared_ptr> listener_; +}; + +} // namespace memgraph::communication::v2 diff --git a/src/communication/v2/session.hpp b/src/communication/v2/session.hpp new file mode 100644 index 000000000..0add8244d --- /dev/null +++ b/src/communication/v2/session.hpp @@ -0,0 +1,513 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "communication/context.hpp" +#include "communication/exceptions.hpp" +#include "utils/logging.hpp" +#include "utils/variant_helpers.hpp" + +namespace memgraph::communication::v2 { + +/** + * This is used to provide input to user Sessions. All Sessions used with the + * network stack should use this class as their input stream. + */ +using InputStream = communication::Buffer::ReadEnd; +using tcp = boost::asio::ip::tcp; + +/** + * This is used to provide output from user Sessions. All Sessions used with the + * network stack should use this class for their output stream. + */ +class OutputStream final { + public: + explicit OutputStream(std::function write_function) + : write_function_(write_function) {} + + OutputStream(const OutputStream &) = delete; + OutputStream(OutputStream &&) = delete; + OutputStream &operator=(const OutputStream &) = delete; + OutputStream &operator=(OutputStream &&) = delete; + ~OutputStream() = default; + + bool Write(const uint8_t *data, size_t len, bool have_more = false) { return write_function_(data, len, have_more); } + + bool Write(const std::string &str, bool have_more = false) { + return Write(reinterpret_cast(str.data()), str.size(), have_more); + } + + private: + std::function write_function_; +}; + +/** + * This class is used internally in the communication stack to handle all user + * Websocket Sessions. It handles socket ownership, inactivity timeout and protocol + * wrapping. + */ +template +class WebsocketSession : public std::enable_shared_from_this> { + using WebSocket = boost::beast::websocket::stream; + using std::enable_shared_from_this>::shared_from_this; + + public: + template + static std::shared_ptr Create(Args &&...args) { + return std::shared_ptr(new WebsocketSession(std::forward(args)...)); + } + + // Start the asynchronous accept operation + template + void DoAccept(boost::beast::http::request> req) { + execution_active_ = true; + // Set suggested timeout settings for the websocket + ws_.set_option(boost::beast::websocket::stream_base::timeout::suggested(boost::beast::role_type::server)); + boost::asio::socket_base::keep_alive option(true); + + // Set a decorator to change the Server of the handshake + ws_.set_option(boost::beast::websocket::stream_base::decorator([](boost::beast::websocket::response_type &res) { + res.set(boost::beast::http::field::server, std::string("Memgraph Bolt WS")); + res.set(boost::beast::http::field::sec_websocket_protocol, "binary"); + })); + ws_.binary(true); + + // Accept the websocket handshake + ws_.async_accept( + req, boost::asio::bind_executor(strand_, std::bind_front(&WebsocketSession::OnAccept, shared_from_this()))); + } + + bool Write(const uint8_t *data, size_t len) { + if (!IsConnected()) { + return false; + } + + boost::system::error_code ec; + ws_.write(boost::asio::buffer(data, len), ec); + if (ec) { + OnError(ec, "write"); + return false; + } + return true; + } + + private: + // Take ownership of the socket + explicit WebsocketSession(tcp::socket &&socket, TSessionData *data, tcp::endpoint endpoint, + std::string_view service_name) + : ws_(std::move(socket)), + strand_{boost::asio::make_strand(ws_.get_executor())}, + output_stream_([this](const uint8_t *data, size_t len, bool /*have_more*/) { return Write(data, len); }), + session_(data, endpoint, input_buffer_.read_end(), &output_stream_), + endpoint_{endpoint}, + remote_endpoint_{ws_.next_layer().socket().remote_endpoint()}, + service_name_{service_name} {} + + void OnAccept(boost::beast::error_code ec) { + if (ec) { + return OnError(ec, "accept"); + } + + // Read a message + DoRead(); + } + + void DoRead() { + if (!IsConnected()) { + return; + } + // Read a message into our buffer + auto buffer = input_buffer_.write_end()->Allocate(); + ws_.async_read_some( + boost::asio::buffer(buffer.data, buffer.len), + boost::asio::bind_executor(strand_, std::bind_front(&WebsocketSession::OnRead, shared_from_this()))); + } + + void OnRead(const boost::system::error_code &ec, [[maybe_unused]] const size_t bytes_transferred) { + // This indicates that the WebsocketSession was closed + if (ec == boost::beast::websocket::error::closed) { + return; + } + if (ec) { + OnError(ec, "read"); + } + input_buffer_.write_end()->Written(bytes_transferred); + + try { + session_.Execute(); + DoRead(); + } catch (const SessionClosedException &e) { + spdlog::info("{} client {}:{} closed the connection.", service_name_, remote_endpoint_.address(), + remote_endpoint_.port()); + DoClose(); + } catch (const std::exception &e) { + spdlog::error( + "Exception was thrown while processing event in {} session " + "associated with {}:{}", + service_name_, remote_endpoint_.address(), remote_endpoint_.port()); + spdlog::debug("Exception message: {}", e.what()); + DoClose(); + } + } + + void OnError(const boost::system::error_code &ec, const std::string_view action) { + spdlog::error("Websocket Bolt session error: {} on {}", ec.message(), action); + + DoClose(); + } + + void DoClose() { + ws_.async_close( + boost::beast::websocket::close_code::normal, + boost::asio::bind_executor( + strand_, [shared_this = shared_from_this()](boost::beast::error_code ec) { shared_this->OnClose(ec); })); + } + + void OnClose(const boost::system::error_code &ec) { + if (!IsConnected()) { + return; + } + if (ec) { + return OnError(ec, "close"); + } + } + + bool IsConnected() const { return ws_.is_open() && execution_active_; } + + WebSocket ws_; + boost::asio::strand strand_; + + communication::Buffer input_buffer_; + OutputStream output_stream_; + TSession session_; + tcp::endpoint endpoint_; + tcp::endpoint remote_endpoint_; + std::string_view service_name_; + bool execution_active_{false}; +}; + +/** + * This class is used internally in the communication stack to handle all user + * Sessions. It handles socket ownership, inactivity timeout and protocol + * wrapping. + */ +template +class Session final : public std::enable_shared_from_this> { + using TCPSocket = tcp::socket; + using SSLSocket = boost::asio::ssl::stream; + using std::enable_shared_from_this>::shared_from_this; + + public: + template + static std::shared_ptr Create(Args &&...args) { + return std::shared_ptr(new Session(std::forward(args)...)); + } + + Session(const Session &) = delete; + Session(Session &&) = delete; + Session &operator=(const Session &) = delete; + Session &operator=(Session &&) = delete; + ~Session() { + if (IsConnected()) { + spdlog::error("Session: Destructor called while execution is active"); + } + } + + bool Start() { + if (execution_active_) { + return false; + } + execution_active_ = true; + timeout_timer_.async_wait(boost::asio::bind_executor(strand_, std::bind(&Session::OnTimeout, shared_from_this()))); + + if (std::holds_alternative(socket_)) { + boost::asio::dispatch(strand_, [shared_this = shared_from_this()] { shared_this->DoHandshake(); }); + } else { + boost::asio::dispatch(strand_, [shared_this = shared_from_this()] { shared_this->DoRead(); }); + } + return true; + } + + bool Write(const uint8_t *data, size_t len, bool have_more = false) { + if (!IsConnected()) { + return false; + } + return std::visit( + utils::Overloaded{[shared_this = shared_from_this(), data, len, have_more](TCPSocket &socket) mutable { + boost::system::error_code ec; + while (len > 0) { + const auto sent = socket.send(boost::asio::buffer(data, len), + MSG_NOSIGNAL | (have_more ? MSG_MORE : 0), ec); + if (ec) { + shared_this->OnError(ec); + return false; + } + data += sent; + len -= sent; + } + return true; + }, + [shared_this = shared_from_this(), data, len](SSLSocket &socket) mutable { + boost::system::error_code ec; + while (len > 0) { + const auto sent = socket.write_some(boost::asio::buffer(data, len), ec); + if (ec) { + shared_this->OnError(ec); + return false; + } + data += sent; + len -= sent; + } + return true; + }}, + socket_); + } + + bool IsConnected() const { + return std::visit([this](const auto &socket) { return execution_active_ && socket.lowest_layer().is_open(); }, + socket_); + } + + private: + explicit Session(tcp::socket &&socket, TSessionData *data, ServerContext &server_context, tcp::endpoint endpoint, + const std::chrono::seconds inactivity_timeout_sec, std::string_view service_name) + : socket_(CreateSocket(std::move(socket), server_context)), + strand_{boost::asio::make_strand(GetExecutor())}, + output_stream_([this](const uint8_t *data, size_t len, bool have_more) { return Write(data, len, have_more); }), + session_(data, endpoint, input_buffer_.read_end(), &output_stream_), + data_{data}, + endpoint_{endpoint}, + remote_endpoint_{GetRemoteEndpoint()}, + service_name_{service_name}, + timeout_seconds_(inactivity_timeout_sec), + timeout_timer_(GetExecutor()) { + ExecuteForSocket([](auto &&socket) { + socket.lowest_layer().set_option(tcp::no_delay(true)); // enable PSH + socket.lowest_layer().set_option(boost::asio::socket_base::keep_alive(true)); // enable SO_KEEPALIVE + socket.lowest_layer().non_blocking(false); + }); + timeout_timer_.expires_at(boost::asio::steady_timer::time_point::max()); + spdlog::info("Accepted a connection from {}:", service_name_, remote_endpoint_.address(), remote_endpoint_.port()); + } + + void DoRead() { + if (!IsConnected()) { + return; + } + timeout_timer_.expires_after(timeout_seconds_); + ExecuteForSocket([this](auto &&socket) { + auto buffer = input_buffer_.write_end()->Allocate(); + socket.async_read_some( + boost::asio::buffer(buffer.data, buffer.len), + boost::asio::bind_executor(strand_, std::bind_front(&Session::OnRead, shared_from_this()))); + }); + } + + bool IsWebsocketUpgrade(boost::beast::http::request_parser &parser) { + boost::system::error_code error_code_parsing; + parser.put(boost::asio::buffer(input_buffer_.read_end()->data(), input_buffer_.read_end()->size()), + error_code_parsing); + if (error_code_parsing) { + return false; + } + + return boost::beast::websocket::is_upgrade(parser.get()); + } + + void OnRead(const boost::system::error_code &ec, const size_t bytes_transferred) { + if (ec) { + return OnError(ec); + } + input_buffer_.write_end()->Written(bytes_transferred); + + // Can be a websocket connection only on the first read, since it is not + // expected from clients to upgrade from tcp to websocket + if (!has_received_msg_) { + has_received_msg_ = true; + boost::beast::http::request_parser parser; + + if (IsWebsocketUpgrade(parser)) { + spdlog::info("Switching {} to websocket connection", remote_endpoint_); + if (std::holds_alternative(socket_)) { + auto sock = std::get(std::move(socket_)); + WebsocketSession::Create(std::move(sock), data_, endpoint_, service_name_) + ->DoAccept(parser.release()); + execution_active_ = false; + return; + } + spdlog::error("Error while upgrading connection to websocket"); + DoShutdown(); + } + } + + try { + session_.Execute(); + DoRead(); + } catch (const SessionClosedException &e) { + spdlog::info("{} client {}:{} closed the connection.", service_name_, remote_endpoint_.address(), + remote_endpoint_.port()); + DoShutdown(); + } catch (const std::exception &e) { + spdlog::error( + "Exception was thrown while processing event in {} session " + "associated with {}:{}", + service_name_, remote_endpoint_.address(), remote_endpoint_.port()); + spdlog::debug("Exception message: {}", e.what()); + DoShutdown(); + } + } + + void OnError(const boost::system::error_code &ec) { + if (ec == boost::asio::error::operation_aborted) { + return; + } + execution_active_ = false; + + if (ec == boost::asio::error::eof) { + spdlog::info("Session closed by peer"); + } else { + spdlog::error("Session error: {}", ec.message()); + } + + DoShutdown(); + } + + void DoShutdown() { + if (!IsConnected()) { + return; + } + execution_active_ = false; + timeout_timer_.cancel(); + ExecuteForSocket([](auto &socket) { + boost::system::error_code ec; + auto &lowest_layer = socket.lowest_layer(); + lowest_layer.shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec); + if (ec) { + spdlog::error("Session shutdown failed: {}", ec.what()); + } + lowest_layer.close(); + }); + } + + void DoHandshake() { + if (!IsConnected()) { + return; + } + if (auto *socket = std::get_if(&socket_); socket) { + socket->async_handshake( + boost::asio::ssl::stream_base::server, + boost::asio::bind_executor(strand_, std::bind_front(&Session::OnHandshake, shared_from_this()))); + } + } + + void OnHandshake(const boost::system::error_code &ec) { + if (ec) { + return OnError(ec); + } + DoRead(); + } + + void OnClose(const boost::system::error_code &ec) { + if (ec) { + return OnError(ec); + } + } + + void OnTimeout() { + if (!IsConnected()) { + return; + } + // Check whether the deadline has passed. We compare the deadline against + // the current time since a new asynchronous operation may have moved the + // deadline before this actor had a chance to run. + if (timeout_timer_.expiry() <= boost::asio::steady_timer::clock_type::now()) { + // The deadline has passed. Stop the session. The other actors will + // terminate as soon as possible. + spdlog::info("Shutting down session after {} of inactivity", timeout_seconds_); + DoShutdown(); + } else { + // Put the actor back to sleep. + timeout_timer_.async_wait( + boost::asio::bind_executor(strand_, std::bind(&Session::OnTimeout, shared_from_this()))); + } + } + + std::variant CreateSocket(tcp::socket &&socket, ServerContext &context) { + if (context.use_ssl()) { + ssl_context_.emplace(context.context_clone()); + return SSLSocket{std::move(socket), *ssl_context_}; + } + + return TCPSocket{std::move(socket)}; + } + + auto GetExecutor() { + return std::visit(utils::Overloaded{[](auto &&socket) { return socket.get_executor(); }}, socket_); + } + + auto GetRemoteEndpoint() const { + return std::visit(utils::Overloaded{[](const auto &socket) { return socket.lowest_layer().remote_endpoint(); }}, + socket_); + } + + template + decltype(auto) ExecuteForSocket(F &&fun) { + return std::visit(utils::Overloaded{std::forward(fun)}, socket_); + } + + std::variant socket_; + std::optional> ssl_context_; + boost::asio::strand strand_; + + communication::Buffer input_buffer_; + OutputStream output_stream_; + TSession session_; + TSessionData *data_; + tcp::endpoint endpoint_; + tcp::endpoint remote_endpoint_; + std::string_view service_name_; + std::chrono::seconds timeout_seconds_; + boost::asio::steady_timer timeout_timer_; + bool execution_active_{false}; + bool has_received_msg_{false}; +}; +} // namespace memgraph::communication::v2 diff --git a/src/communication/websocket/listener.hpp b/src/communication/websocket/listener.hpp index 9a0ba0dc0..0ad19f960 100644 --- a/src/communication/websocket/listener.hpp +++ b/src/communication/websocket/listener.hpp @@ -11,8 +11,6 @@ #pragma once -#define BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT - #include #include diff --git a/src/communication/websocket/server.hpp b/src/communication/websocket/server.hpp index 18c377a7d..0853d3ebc 100644 --- a/src/communication/websocket/server.hpp +++ b/src/communication/websocket/server.hpp @@ -11,8 +11,6 @@ #pragma once -#define BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT - #include #include diff --git a/src/communication/websocket/session.hpp b/src/communication/websocket/session.hpp index 232ef5e15..0e5c92aa3 100644 --- a/src/communication/websocket/session.hpp +++ b/src/communication/websocket/session.hpp @@ -11,8 +11,6 @@ #pragma once -#define BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT - #include #include #include diff --git a/src/memgraph.cpp b/src/memgraph.cpp index bad6069ff..0554eca8b 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -81,8 +81,8 @@ #include "communication/bolt/v1/exceptions.hpp" #include "communication/bolt/v1/session.hpp" #include "communication/init.hpp" -#include "communication/server.hpp" -#include "communication/session.hpp" +#include "communication/v2/server.hpp" +#include "communication/v2/session.hpp" #include "glue/communication.hpp" #include "auth/auth.hpp" @@ -842,13 +842,14 @@ class AuthChecker final : public memgraph::query::AuthChecker { memgraph::utils::Synchronized *auth_; }; -class BoltSession final : public memgraph::communication::bolt::Session { +class BoltSession final : public memgraph::communication::bolt::Session { public: - BoltSession(SessionData *data, const memgraph::io::network::Endpoint &endpoint, - memgraph::communication::InputStream *input_stream, memgraph::communication::OutputStream *output_stream) - : memgraph::communication::bolt::Session(input_stream, output_stream), + BoltSession(SessionData *data, const memgraph::communication::v2::ServerEndpoint &endpoint, + memgraph::communication::v2::InputStream *input_stream, + memgraph::communication::v2::OutputStream *output_stream) + : memgraph::communication::bolt::Session(input_stream, output_stream), db_(data->db), interpreter_(data->interpreter_context), auth_(data->auth), @@ -858,8 +859,8 @@ class BoltSession final : public memgraph::communication::bolt::Session::TEncoder; + using memgraph::communication::bolt::Session::TEncoder; void BeginTransaction() override { interpreter_.BeginTransaction(); } @@ -877,7 +878,8 @@ class BoltSession final : public memgraph::communication::bolt::SessionRecord(endpoint_.address, user_ ? *username : "", query, memgraph::storage::PropertyValue(params_pv)); + audit_log_->Record(endpoint_.address().to_string(), user_ ? *username : "", query, + memgraph::storage::PropertyValue(params_pv)); } #endif try { @@ -996,10 +998,10 @@ class BoltSession final : public memgraph::communication::bolt::Session; +using ServerT = memgraph::communication::v2::Server; using memgraph::communication::ServerContext; // Needed to correctly handle memgraph destruction from a signal handler. @@ -1241,8 +1243,10 @@ int main(int argc, char **argv) { memgraph::utils::MessageWithLink("Using non-secure Bolt connection (without SSL).", "https://memgr.ph/ssl")); } - ServerT server({FLAGS_bolt_address, static_cast(FLAGS_bolt_port)}, &session_data, &context, - FLAGS_bolt_session_inactivity_timeout, service_name, FLAGS_bolt_num_workers); + auto server_endpoint = memgraph::communication::v2::ServerEndpoint{ + boost::asio::ip::address::from_string(FLAGS_bolt_address), static_cast(FLAGS_bolt_port)}; + ServerT server(server_endpoint, &session_data, &context, FLAGS_bolt_session_inactivity_timeout, service_name, + FLAGS_bolt_num_workers); // Setup telemetry std::optional telemetry;