Implement Bolt over WebSocket with asio
* Replace server implementation with asio * Add support for bolt over WebSocket
This commit is contained in:
parent
4abaf27765
commit
7b78665cd8
@ -184,7 +184,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall \
|
||||
-Werror=switch -Werror=switch-bool -Werror=return-type \
|
||||
-Werror=return-stack-address \
|
||||
-Wno-c99-designator")
|
||||
-Wno-c99-designator \
|
||||
-DBOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT")
|
||||
|
||||
# Don't omit frame pointer in RelWithDebInfo, for additional callchain debug.
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO
|
||||
|
@ -78,14 +78,18 @@ bool ClientContext::use_ssl() { return use_ssl_; }
|
||||
|
||||
ServerContext::ServerContext(const std::string &key_file, const std::string &cert_file, const std::string &ca_file,
|
||||
bool verify_peer) {
|
||||
ctx_.emplace(boost::asio::ssl::context::tls_server);
|
||||
namespace ssl = boost::asio::ssl;
|
||||
ctx_.emplace(ssl::context::tls_server);
|
||||
// NOLINTNEXTLINE(hicpp-signed-bitwise)
|
||||
ctx_->set_options(ssl::context::default_workarounds | ssl::context::no_sslv2 | ssl::context::no_sslv3 |
|
||||
ssl::context::single_dh_use);
|
||||
ctx_->set_default_verify_paths();
|
||||
// TODO: add support for encrypted private keys
|
||||
// TODO: add certificate revocation list (CRL)
|
||||
boost::system::error_code ec;
|
||||
ctx_->use_certificate_chain_file(cert_file, ec);
|
||||
MG_ASSERT(!ec, "Couldn't load server certificate from file: {}", cert_file);
|
||||
ctx_->use_private_key_file(key_file, boost::asio::ssl::context::pem, ec);
|
||||
ctx_->use_private_key_file(key_file, ssl::context::pem, ec);
|
||||
MG_ASSERT(!ec, "Couldn't load server private key from file: {}", key_file);
|
||||
|
||||
ctx_->set_options(SSL_OP_NO_SSLv3, ec);
|
||||
@ -100,7 +104,7 @@ ServerContext::ServerContext(const std::string &key_file, const std::string &cer
|
||||
if (verify_peer) {
|
||||
// Enable verification of the client certificate.
|
||||
// NOLINTNEXTLINE(hicpp-signed-bitwise)
|
||||
ctx_->set_verify_mode(boost::asio::ssl::verify_peer | boost::asio::ssl::verify_fail_if_no_peer_cert, ec);
|
||||
ctx_->set_verify_mode(ssl::verify_peer | ssl::verify_fail_if_no_peer_cert, ec);
|
||||
MG_ASSERT(!ec, "Setting SSL verification mode failed!");
|
||||
}
|
||||
}
|
||||
|
135
src/communication/v2/listener.hpp
Normal file
135
src/communication/v2/listener.hpp
Normal file
@ -0,0 +1,135 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string_view>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <boost/asio/io_context.hpp>
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <boost/asio/strand.hpp>
|
||||
#include <boost/beast/core.hpp>
|
||||
#include <boost/system/detail/error_code.hpp>
|
||||
|
||||
#include "communication/context.hpp"
|
||||
#include "communication/v2/pool.hpp"
|
||||
#include "communication/v2/session.hpp"
|
||||
#include "utils/spin_lock.hpp"
|
||||
#include "utils/synchronized.hpp"
|
||||
|
||||
namespace memgraph::communication::v2 {
|
||||
|
||||
template <class TSession, class TSessionData>
|
||||
class Listener final : public std::enable_shared_from_this<Listener<TSession, TSessionData>> {
|
||||
using tcp = boost::asio::ip::tcp;
|
||||
using SessionHandler = Session<TSession, TSessionData>;
|
||||
using std::enable_shared_from_this<Listener<TSession, TSessionData>>::shared_from_this;
|
||||
|
||||
public:
|
||||
Listener(const Listener &) = delete;
|
||||
Listener(Listener &&) = delete;
|
||||
Listener &operator=(const Listener &) = delete;
|
||||
Listener &operator=(Listener &&) = delete;
|
||||
~Listener() {}
|
||||
|
||||
template <typename... Args>
|
||||
static std::shared_ptr<Listener> Create(Args &&...args) {
|
||||
return std::shared_ptr<Listener>{new Listener(std::forward<Args>(args)...)};
|
||||
}
|
||||
|
||||
void Start() { DoAccept(); }
|
||||
|
||||
bool IsRunning() const noexcept { return alive_.load(std::memory_order_relaxed); }
|
||||
|
||||
private:
|
||||
Listener(boost::asio::io_context &io_context, TSessionData *data, ServerContext *server_context,
|
||||
tcp::endpoint &endpoint, const std::string_view service_name, const uint64_t inactivity_timeout_sec)
|
||||
: io_context_(io_context),
|
||||
data_(data),
|
||||
server_context_(server_context),
|
||||
acceptor_(io_context_),
|
||||
endpoint_{endpoint},
|
||||
service_name_{service_name},
|
||||
inactivity_timeout_{inactivity_timeout_sec} {
|
||||
boost::system::error_code ec;
|
||||
// Open the acceptor
|
||||
acceptor_.open(endpoint.protocol(), ec);
|
||||
if (ec) {
|
||||
OnError(ec, "open");
|
||||
return;
|
||||
}
|
||||
|
||||
// Allow address reuse
|
||||
acceptor_.set_option(boost::asio::socket_base::reuse_address(true), ec);
|
||||
if (ec) {
|
||||
OnError(ec, "set_option");
|
||||
return;
|
||||
}
|
||||
|
||||
// Bind to the server address
|
||||
acceptor_.bind(endpoint, ec);
|
||||
if (ec) {
|
||||
spdlog::error(
|
||||
utils::MessageWithLink("Cannot bind to socket on endpoint {}.", endpoint, "https://memgr.ph/socket"));
|
||||
OnError(ec, "bind");
|
||||
return;
|
||||
}
|
||||
|
||||
acceptor_.listen(boost::asio::socket_base::max_listen_connections, ec);
|
||||
if (ec) {
|
||||
OnError(ec, "listen");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void DoAccept() {
|
||||
acceptor_.async_accept(io_context_,
|
||||
[shared_this = shared_from_this()](auto ec, boost::asio::ip::tcp::socket &&socket) {
|
||||
shared_this->OnAccept(ec, std::move(socket));
|
||||
});
|
||||
}
|
||||
|
||||
void OnAccept(boost::system::error_code ec, tcp::socket socket) {
|
||||
if (ec) {
|
||||
return OnError(ec, "accept");
|
||||
}
|
||||
|
||||
auto session = SessionHandler::Create(std::move(socket), data_, *server_context_, endpoint_, inactivity_timeout_,
|
||||
service_name_);
|
||||
session->Start();
|
||||
DoAccept();
|
||||
}
|
||||
|
||||
void OnError(const boost::system::error_code &ec, const std::string_view what) {
|
||||
spdlog::error("Listener failed on {}: {}", what, ec.message());
|
||||
alive_.store(false, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
boost::asio::io_context &io_context_;
|
||||
TSessionData *data_;
|
||||
ServerContext *server_context_;
|
||||
tcp::acceptor acceptor_;
|
||||
|
||||
tcp::endpoint endpoint_;
|
||||
std::string_view service_name_;
|
||||
std::chrono::seconds inactivity_timeout_;
|
||||
|
||||
std::atomic<bool> alive_;
|
||||
};
|
||||
} // namespace memgraph::communication::v2
|
68
src/communication/v2/pool.hpp
Normal file
68
src/communication/v2/pool.hpp
Normal file
@ -0,0 +1,68 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include <boost/asio/executor_work_guard.hpp>
|
||||
#include <boost/asio/io_context.hpp>
|
||||
|
||||
#include "utils/logging.hpp"
|
||||
|
||||
namespace memgraph::communication::v2 {
|
||||
|
||||
class IOContextThreadPool final {
|
||||
private:
|
||||
using IOContext = boost::asio::io_context;
|
||||
using IOContextGuard = boost::asio::executor_work_guard<boost::asio::io_context::executor_type>;
|
||||
|
||||
public:
|
||||
explicit IOContextThreadPool(size_t pool_size) : guard_{io_context_.get_executor()}, pool_size_{pool_size} {
|
||||
MG_ASSERT(pool_size != 0, "Pool size must be greater then 0!");
|
||||
}
|
||||
|
||||
IOContextThreadPool(const IOContextThreadPool &) = delete;
|
||||
IOContextThreadPool &operator=(const IOContextThreadPool &) = delete;
|
||||
IOContextThreadPool(IOContextThreadPool &&) = delete;
|
||||
IOContextThreadPool &operator=(IOContextThreadPool &&) = delete;
|
||||
~IOContextThreadPool() = default;
|
||||
|
||||
void Run() {
|
||||
background_threads_.reserve(pool_size_);
|
||||
for (size_t i = 0; i < pool_size_; ++i) {
|
||||
background_threads_.emplace_back([this]() { io_context_.run(); });
|
||||
}
|
||||
running_ = true;
|
||||
}
|
||||
|
||||
void Shutdown() {
|
||||
io_context_.stop();
|
||||
running_ = false;
|
||||
}
|
||||
|
||||
void AwaitShutdown() { background_threads_.clear(); }
|
||||
|
||||
bool IsRunning() const noexcept { return running_; }
|
||||
|
||||
IOContext &GetIOContext() noexcept { return io_context_; }
|
||||
|
||||
private:
|
||||
/// The pool of io_context.
|
||||
IOContext io_context_;
|
||||
IOContextGuard guard_;
|
||||
size_t pool_size_;
|
||||
std::vector<std::jthread> background_threads_;
|
||||
bool running_{false};
|
||||
};
|
||||
} // namespace memgraph::communication::v2
|
128
src/communication/v2/server.hpp
Normal file
128
src/communication/v2/server.hpp
Normal file
@ -0,0 +1,128 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <boost/asio/io_context.hpp>
|
||||
#include <boost/asio/ip/address.hpp>
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
|
||||
#include "communication/context.hpp"
|
||||
#include "communication/init.hpp"
|
||||
#include "communication/v2/listener.hpp"
|
||||
#include "communication/v2/pool.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/message.hpp"
|
||||
#include "utils/thread.hpp"
|
||||
|
||||
namespace memgraph::communication::v2 {
|
||||
|
||||
using Socket = boost::asio::ip::tcp::socket;
|
||||
using ServerEndpoint = boost::asio::ip::tcp::endpoint;
|
||||
/**
|
||||
* Communication server.
|
||||
*
|
||||
* Listens for incoming connections on the server port and assigns them to the
|
||||
* connection listener. The listener and session are implemented using asio
|
||||
* async model. Currently the implemented model is thread per core model
|
||||
* opposed to io_context per core. The reasoning for opting for the former model
|
||||
* is the robustness to the multiple resource demanding queries that can be split
|
||||
* across multiple threads, and then a single thread would not block io_context,
|
||||
* unlike in the latter model where it is possible that thread that accepts
|
||||
* request is being blocked by demanding query.
|
||||
* All logic is contained within handlers that are being dispatched
|
||||
* on a single strand per session. The only exception is write which is
|
||||
* synchronous since the nature of the clients conenction is synchronous as
|
||||
* well.
|
||||
*
|
||||
* Current Server architecture:
|
||||
* incoming connection -> server -> listener -> session
|
||||
|
||||
*
|
||||
* @tparam TSession the server can handle different Sessions, each session
|
||||
* represents a different protocol so the same network infrastructure
|
||||
* can be used for handling different protocols
|
||||
* @tparam TSessionData the class with objects that will be forwarded to the
|
||||
* session
|
||||
*/
|
||||
template <typename TSession, typename TSessionData>
|
||||
class Server final {
|
||||
using ServerHandler = Server<TSession, TSessionData>;
|
||||
|
||||
public:
|
||||
/**
|
||||
* Constructs and binds server to endpoint, operates on session data and
|
||||
* invokes workers_count workers
|
||||
*/
|
||||
Server(ServerEndpoint &endpoint, TSessionData *session_data, ServerContext *server_context,
|
||||
const int inactivity_timeout_sec, const std::string_view service_name,
|
||||
size_t workers_count = std::thread::hardware_concurrency())
|
||||
: endpoint_{endpoint},
|
||||
service_name_{service_name},
|
||||
context_thread_pool_{workers_count},
|
||||
listener_{Listener<TSession, TSessionData>::Create(context_thread_pool_.GetIOContext(), session_data,
|
||||
server_context, endpoint_, service_name_,
|
||||
inactivity_timeout_sec)} {}
|
||||
|
||||
~Server() { MG_ASSERT(!IsRunning(), "Server wasn't shutdown properly"); }
|
||||
|
||||
Server(const Server &) = delete;
|
||||
Server(Server &&) = delete;
|
||||
Server &operator=(const Server &) = delete;
|
||||
Server &operator=(Server &&) = delete;
|
||||
|
||||
const auto &Endpoint() const {
|
||||
MG_ASSERT(IsRunning(), "You can't get the server endpoint when it's not running!");
|
||||
return endpoint_;
|
||||
}
|
||||
|
||||
bool Start() {
|
||||
if (IsRunning()) {
|
||||
spdlog::error("The server is already running");
|
||||
return false;
|
||||
}
|
||||
listener_->Start();
|
||||
|
||||
spdlog::info("{} server is fully armed and operational", service_name_);
|
||||
spdlog::info("{} listening on {}", service_name_, endpoint_.address());
|
||||
context_thread_pool_.Run();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void Shutdown() {
|
||||
context_thread_pool_.Shutdown();
|
||||
spdlog::info("{} shutting down...", service_name_);
|
||||
}
|
||||
|
||||
void AwaitShutdown() { context_thread_pool_.AwaitShutdown(); }
|
||||
|
||||
bool IsRunning() const noexcept { return context_thread_pool_.IsRunning() && listener_->IsRunning(); }
|
||||
|
||||
private:
|
||||
ServerEndpoint endpoint_;
|
||||
std::string service_name_;
|
||||
|
||||
IOContextThreadPool context_thread_pool_;
|
||||
std::shared_ptr<Listener<TSession, TSessionData>> listener_;
|
||||
};
|
||||
|
||||
} // namespace memgraph::communication::v2
|
513
src/communication/v2/session.hpp
Normal file
513
src/communication/v2/session.hpp
Normal file
@ -0,0 +1,513 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <chrono>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <deque>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <boost/asio/bind_executor.hpp>
|
||||
#include <boost/asio/buffer.hpp>
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <boost/asio/read.hpp>
|
||||
#include <boost/asio/socket_base.hpp>
|
||||
#include <boost/asio/ssl/stream.hpp>
|
||||
#include <boost/asio/ssl/stream_base.hpp>
|
||||
#include <boost/asio/steady_timer.hpp>
|
||||
#include <boost/asio/strand.hpp>
|
||||
#include <boost/asio/system_context.hpp>
|
||||
#include <boost/asio/write.hpp>
|
||||
#include <boost/beast/core/tcp_stream.hpp>
|
||||
#include <boost/beast/http.hpp>
|
||||
#include <boost/beast/websocket.hpp>
|
||||
#include <boost/beast/websocket/rfc6455.hpp>
|
||||
#include <boost/system/detail/error_code.hpp>
|
||||
|
||||
#include "communication/context.hpp"
|
||||
#include "communication/exceptions.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/variant_helpers.hpp"
|
||||
|
||||
namespace memgraph::communication::v2 {
|
||||
|
||||
/**
|
||||
* This is used to provide input to user Sessions. All Sessions used with the
|
||||
* network stack should use this class as their input stream.
|
||||
*/
|
||||
using InputStream = communication::Buffer::ReadEnd;
|
||||
using tcp = boost::asio::ip::tcp;
|
||||
|
||||
/**
|
||||
* This is used to provide output from user Sessions. All Sessions used with the
|
||||
* network stack should use this class for their output stream.
|
||||
*/
|
||||
class OutputStream final {
|
||||
public:
|
||||
explicit OutputStream(std::function<bool(const uint8_t *, size_t, bool)> write_function)
|
||||
: write_function_(write_function) {}
|
||||
|
||||
OutputStream(const OutputStream &) = delete;
|
||||
OutputStream(OutputStream &&) = delete;
|
||||
OutputStream &operator=(const OutputStream &) = delete;
|
||||
OutputStream &operator=(OutputStream &&) = delete;
|
||||
~OutputStream() = default;
|
||||
|
||||
bool Write(const uint8_t *data, size_t len, bool have_more = false) { return write_function_(data, len, have_more); }
|
||||
|
||||
bool Write(const std::string &str, bool have_more = false) {
|
||||
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(), have_more);
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<bool(const uint8_t *, size_t, bool)> write_function_;
|
||||
};
|
||||
|
||||
/**
|
||||
* This class is used internally in the communication stack to handle all user
|
||||
* Websocket Sessions. It handles socket ownership, inactivity timeout and protocol
|
||||
* wrapping.
|
||||
*/
|
||||
template <typename TSession, typename TSessionData>
|
||||
class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TSession, TSessionData>> {
|
||||
using WebSocket = boost::beast::websocket::stream<boost::beast::tcp_stream>;
|
||||
using std::enable_shared_from_this<WebsocketSession<TSession, TSessionData>>::shared_from_this;
|
||||
|
||||
public:
|
||||
template <typename... Args>
|
||||
static std::shared_ptr<WebsocketSession> Create(Args &&...args) {
|
||||
return std::shared_ptr<WebsocketSession>(new WebsocketSession(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
// Start the asynchronous accept operation
|
||||
template <class Body, class Allocator>
|
||||
void DoAccept(boost::beast::http::request<Body, boost::beast::http::basic_fields<Allocator>> req) {
|
||||
execution_active_ = true;
|
||||
// Set suggested timeout settings for the websocket
|
||||
ws_.set_option(boost::beast::websocket::stream_base::timeout::suggested(boost::beast::role_type::server));
|
||||
boost::asio::socket_base::keep_alive option(true);
|
||||
|
||||
// Set a decorator to change the Server of the handshake
|
||||
ws_.set_option(boost::beast::websocket::stream_base::decorator([](boost::beast::websocket::response_type &res) {
|
||||
res.set(boost::beast::http::field::server, std::string("Memgraph Bolt WS"));
|
||||
res.set(boost::beast::http::field::sec_websocket_protocol, "binary");
|
||||
}));
|
||||
ws_.binary(true);
|
||||
|
||||
// Accept the websocket handshake
|
||||
ws_.async_accept(
|
||||
req, boost::asio::bind_executor(strand_, std::bind_front(&WebsocketSession::OnAccept, shared_from_this())));
|
||||
}
|
||||
|
||||
bool Write(const uint8_t *data, size_t len) {
|
||||
if (!IsConnected()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
boost::system::error_code ec;
|
||||
ws_.write(boost::asio::buffer(data, len), ec);
|
||||
if (ec) {
|
||||
OnError(ec, "write");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
// Take ownership of the socket
|
||||
explicit WebsocketSession(tcp::socket &&socket, TSessionData *data, tcp::endpoint endpoint,
|
||||
std::string_view service_name)
|
||||
: ws_(std::move(socket)),
|
||||
strand_{boost::asio::make_strand(ws_.get_executor())},
|
||||
output_stream_([this](const uint8_t *data, size_t len, bool /*have_more*/) { return Write(data, len); }),
|
||||
session_(data, endpoint, input_buffer_.read_end(), &output_stream_),
|
||||
endpoint_{endpoint},
|
||||
remote_endpoint_{ws_.next_layer().socket().remote_endpoint()},
|
||||
service_name_{service_name} {}
|
||||
|
||||
void OnAccept(boost::beast::error_code ec) {
|
||||
if (ec) {
|
||||
return OnError(ec, "accept");
|
||||
}
|
||||
|
||||
// Read a message
|
||||
DoRead();
|
||||
}
|
||||
|
||||
void DoRead() {
|
||||
if (!IsConnected()) {
|
||||
return;
|
||||
}
|
||||
// Read a message into our buffer
|
||||
auto buffer = input_buffer_.write_end()->Allocate();
|
||||
ws_.async_read_some(
|
||||
boost::asio::buffer(buffer.data, buffer.len),
|
||||
boost::asio::bind_executor(strand_, std::bind_front(&WebsocketSession::OnRead, shared_from_this())));
|
||||
}
|
||||
|
||||
void OnRead(const boost::system::error_code &ec, [[maybe_unused]] const size_t bytes_transferred) {
|
||||
// This indicates that the WebsocketSession was closed
|
||||
if (ec == boost::beast::websocket::error::closed) {
|
||||
return;
|
||||
}
|
||||
if (ec) {
|
||||
OnError(ec, "read");
|
||||
}
|
||||
input_buffer_.write_end()->Written(bytes_transferred);
|
||||
|
||||
try {
|
||||
session_.Execute();
|
||||
DoRead();
|
||||
} catch (const SessionClosedException &e) {
|
||||
spdlog::info("{} client {}:{} closed the connection.", service_name_, remote_endpoint_.address(),
|
||||
remote_endpoint_.port());
|
||||
DoClose();
|
||||
} catch (const std::exception &e) {
|
||||
spdlog::error(
|
||||
"Exception was thrown while processing event in {} session "
|
||||
"associated with {}:{}",
|
||||
service_name_, remote_endpoint_.address(), remote_endpoint_.port());
|
||||
spdlog::debug("Exception message: {}", e.what());
|
||||
DoClose();
|
||||
}
|
||||
}
|
||||
|
||||
void OnError(const boost::system::error_code &ec, const std::string_view action) {
|
||||
spdlog::error("Websocket Bolt session error: {} on {}", ec.message(), action);
|
||||
|
||||
DoClose();
|
||||
}
|
||||
|
||||
void DoClose() {
|
||||
ws_.async_close(
|
||||
boost::beast::websocket::close_code::normal,
|
||||
boost::asio::bind_executor(
|
||||
strand_, [shared_this = shared_from_this()](boost::beast::error_code ec) { shared_this->OnClose(ec); }));
|
||||
}
|
||||
|
||||
void OnClose(const boost::system::error_code &ec) {
|
||||
if (!IsConnected()) {
|
||||
return;
|
||||
}
|
||||
if (ec) {
|
||||
return OnError(ec, "close");
|
||||
}
|
||||
}
|
||||
|
||||
bool IsConnected() const { return ws_.is_open() && execution_active_; }
|
||||
|
||||
WebSocket ws_;
|
||||
boost::asio::strand<WebSocket::executor_type> strand_;
|
||||
|
||||
communication::Buffer input_buffer_;
|
||||
OutputStream output_stream_;
|
||||
TSession session_;
|
||||
tcp::endpoint endpoint_;
|
||||
tcp::endpoint remote_endpoint_;
|
||||
std::string_view service_name_;
|
||||
bool execution_active_{false};
|
||||
};
|
||||
|
||||
/**
|
||||
* This class is used internally in the communication stack to handle all user
|
||||
* Sessions. It handles socket ownership, inactivity timeout and protocol
|
||||
* wrapping.
|
||||
*/
|
||||
template <typename TSession, typename TSessionData>
|
||||
class Session final : public std::enable_shared_from_this<Session<TSession, TSessionData>> {
|
||||
using TCPSocket = tcp::socket;
|
||||
using SSLSocket = boost::asio::ssl::stream<TCPSocket>;
|
||||
using std::enable_shared_from_this<Session<TSession, TSessionData>>::shared_from_this;
|
||||
|
||||
public:
|
||||
template <typename... Args>
|
||||
static std::shared_ptr<Session> Create(Args &&...args) {
|
||||
return std::shared_ptr<Session>(new Session(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
Session(const Session &) = delete;
|
||||
Session(Session &&) = delete;
|
||||
Session &operator=(const Session &) = delete;
|
||||
Session &operator=(Session &&) = delete;
|
||||
~Session() {
|
||||
if (IsConnected()) {
|
||||
spdlog::error("Session: Destructor called while execution is active");
|
||||
}
|
||||
}
|
||||
|
||||
bool Start() {
|
||||
if (execution_active_) {
|
||||
return false;
|
||||
}
|
||||
execution_active_ = true;
|
||||
timeout_timer_.async_wait(boost::asio::bind_executor(strand_, std::bind(&Session::OnTimeout, shared_from_this())));
|
||||
|
||||
if (std::holds_alternative<SSLSocket>(socket_)) {
|
||||
boost::asio::dispatch(strand_, [shared_this = shared_from_this()] { shared_this->DoHandshake(); });
|
||||
} else {
|
||||
boost::asio::dispatch(strand_, [shared_this = shared_from_this()] { shared_this->DoRead(); });
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
|
||||
if (!IsConnected()) {
|
||||
return false;
|
||||
}
|
||||
return std::visit(
|
||||
utils::Overloaded{[shared_this = shared_from_this(), data, len, have_more](TCPSocket &socket) mutable {
|
||||
boost::system::error_code ec;
|
||||
while (len > 0) {
|
||||
const auto sent = socket.send(boost::asio::buffer(data, len),
|
||||
MSG_NOSIGNAL | (have_more ? MSG_MORE : 0), ec);
|
||||
if (ec) {
|
||||
shared_this->OnError(ec);
|
||||
return false;
|
||||
}
|
||||
data += sent;
|
||||
len -= sent;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
[shared_this = shared_from_this(), data, len](SSLSocket &socket) mutable {
|
||||
boost::system::error_code ec;
|
||||
while (len > 0) {
|
||||
const auto sent = socket.write_some(boost::asio::buffer(data, len), ec);
|
||||
if (ec) {
|
||||
shared_this->OnError(ec);
|
||||
return false;
|
||||
}
|
||||
data += sent;
|
||||
len -= sent;
|
||||
}
|
||||
return true;
|
||||
}},
|
||||
socket_);
|
||||
}
|
||||
|
||||
bool IsConnected() const {
|
||||
return std::visit([this](const auto &socket) { return execution_active_ && socket.lowest_layer().is_open(); },
|
||||
socket_);
|
||||
}
|
||||
|
||||
private:
|
||||
explicit Session(tcp::socket &&socket, TSessionData *data, ServerContext &server_context, tcp::endpoint endpoint,
|
||||
const std::chrono::seconds inactivity_timeout_sec, std::string_view service_name)
|
||||
: socket_(CreateSocket(std::move(socket), server_context)),
|
||||
strand_{boost::asio::make_strand(GetExecutor())},
|
||||
output_stream_([this](const uint8_t *data, size_t len, bool have_more) { return Write(data, len, have_more); }),
|
||||
session_(data, endpoint, input_buffer_.read_end(), &output_stream_),
|
||||
data_{data},
|
||||
endpoint_{endpoint},
|
||||
remote_endpoint_{GetRemoteEndpoint()},
|
||||
service_name_{service_name},
|
||||
timeout_seconds_(inactivity_timeout_sec),
|
||||
timeout_timer_(GetExecutor()) {
|
||||
ExecuteForSocket([](auto &&socket) {
|
||||
socket.lowest_layer().set_option(tcp::no_delay(true)); // enable PSH
|
||||
socket.lowest_layer().set_option(boost::asio::socket_base::keep_alive(true)); // enable SO_KEEPALIVE
|
||||
socket.lowest_layer().non_blocking(false);
|
||||
});
|
||||
timeout_timer_.expires_at(boost::asio::steady_timer::time_point::max());
|
||||
spdlog::info("Accepted a connection from {}:", service_name_, remote_endpoint_.address(), remote_endpoint_.port());
|
||||
}
|
||||
|
||||
void DoRead() {
|
||||
if (!IsConnected()) {
|
||||
return;
|
||||
}
|
||||
timeout_timer_.expires_after(timeout_seconds_);
|
||||
ExecuteForSocket([this](auto &&socket) {
|
||||
auto buffer = input_buffer_.write_end()->Allocate();
|
||||
socket.async_read_some(
|
||||
boost::asio::buffer(buffer.data, buffer.len),
|
||||
boost::asio::bind_executor(strand_, std::bind_front(&Session::OnRead, shared_from_this())));
|
||||
});
|
||||
}
|
||||
|
||||
bool IsWebsocketUpgrade(boost::beast::http::request_parser<boost::beast::http::string_body> &parser) {
|
||||
boost::system::error_code error_code_parsing;
|
||||
parser.put(boost::asio::buffer(input_buffer_.read_end()->data(), input_buffer_.read_end()->size()),
|
||||
error_code_parsing);
|
||||
if (error_code_parsing) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return boost::beast::websocket::is_upgrade(parser.get());
|
||||
}
|
||||
|
||||
void OnRead(const boost::system::error_code &ec, const size_t bytes_transferred) {
|
||||
if (ec) {
|
||||
return OnError(ec);
|
||||
}
|
||||
input_buffer_.write_end()->Written(bytes_transferred);
|
||||
|
||||
// Can be a websocket connection only on the first read, since it is not
|
||||
// expected from clients to upgrade from tcp to websocket
|
||||
if (!has_received_msg_) {
|
||||
has_received_msg_ = true;
|
||||
boost::beast::http::request_parser<boost::beast::http::string_body> parser;
|
||||
|
||||
if (IsWebsocketUpgrade(parser)) {
|
||||
spdlog::info("Switching {} to websocket connection", remote_endpoint_);
|
||||
if (std::holds_alternative<TCPSocket>(socket_)) {
|
||||
auto sock = std::get<TCPSocket>(std::move(socket_));
|
||||
WebsocketSession<TSession, TSessionData>::Create(std::move(sock), data_, endpoint_, service_name_)
|
||||
->DoAccept(parser.release());
|
||||
execution_active_ = false;
|
||||
return;
|
||||
}
|
||||
spdlog::error("Error while upgrading connection to websocket");
|
||||
DoShutdown();
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
session_.Execute();
|
||||
DoRead();
|
||||
} catch (const SessionClosedException &e) {
|
||||
spdlog::info("{} client {}:{} closed the connection.", service_name_, remote_endpoint_.address(),
|
||||
remote_endpoint_.port());
|
||||
DoShutdown();
|
||||
} catch (const std::exception &e) {
|
||||
spdlog::error(
|
||||
"Exception was thrown while processing event in {} session "
|
||||
"associated with {}:{}",
|
||||
service_name_, remote_endpoint_.address(), remote_endpoint_.port());
|
||||
spdlog::debug("Exception message: {}", e.what());
|
||||
DoShutdown();
|
||||
}
|
||||
}
|
||||
|
||||
void OnError(const boost::system::error_code &ec) {
|
||||
if (ec == boost::asio::error::operation_aborted) {
|
||||
return;
|
||||
}
|
||||
execution_active_ = false;
|
||||
|
||||
if (ec == boost::asio::error::eof) {
|
||||
spdlog::info("Session closed by peer");
|
||||
} else {
|
||||
spdlog::error("Session error: {}", ec.message());
|
||||
}
|
||||
|
||||
DoShutdown();
|
||||
}
|
||||
|
||||
void DoShutdown() {
|
||||
if (!IsConnected()) {
|
||||
return;
|
||||
}
|
||||
execution_active_ = false;
|
||||
timeout_timer_.cancel();
|
||||
ExecuteForSocket([](auto &socket) {
|
||||
boost::system::error_code ec;
|
||||
auto &lowest_layer = socket.lowest_layer();
|
||||
lowest_layer.shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec);
|
||||
if (ec) {
|
||||
spdlog::error("Session shutdown failed: {}", ec.what());
|
||||
}
|
||||
lowest_layer.close();
|
||||
});
|
||||
}
|
||||
|
||||
void DoHandshake() {
|
||||
if (!IsConnected()) {
|
||||
return;
|
||||
}
|
||||
if (auto *socket = std::get_if<SSLSocket>(&socket_); socket) {
|
||||
socket->async_handshake(
|
||||
boost::asio::ssl::stream_base::server,
|
||||
boost::asio::bind_executor(strand_, std::bind_front(&Session::OnHandshake, shared_from_this())));
|
||||
}
|
||||
}
|
||||
|
||||
void OnHandshake(const boost::system::error_code &ec) {
|
||||
if (ec) {
|
||||
return OnError(ec);
|
||||
}
|
||||
DoRead();
|
||||
}
|
||||
|
||||
void OnClose(const boost::system::error_code &ec) {
|
||||
if (ec) {
|
||||
return OnError(ec);
|
||||
}
|
||||
}
|
||||
|
||||
void OnTimeout() {
|
||||
if (!IsConnected()) {
|
||||
return;
|
||||
}
|
||||
// Check whether the deadline has passed. We compare the deadline against
|
||||
// the current time since a new asynchronous operation may have moved the
|
||||
// deadline before this actor had a chance to run.
|
||||
if (timeout_timer_.expiry() <= boost::asio::steady_timer::clock_type::now()) {
|
||||
// The deadline has passed. Stop the session. The other actors will
|
||||
// terminate as soon as possible.
|
||||
spdlog::info("Shutting down session after {} of inactivity", timeout_seconds_);
|
||||
DoShutdown();
|
||||
} else {
|
||||
// Put the actor back to sleep.
|
||||
timeout_timer_.async_wait(
|
||||
boost::asio::bind_executor(strand_, std::bind(&Session::OnTimeout, shared_from_this())));
|
||||
}
|
||||
}
|
||||
|
||||
std::variant<TCPSocket, SSLSocket> CreateSocket(tcp::socket &&socket, ServerContext &context) {
|
||||
if (context.use_ssl()) {
|
||||
ssl_context_.emplace(context.context_clone());
|
||||
return SSLSocket{std::move(socket), *ssl_context_};
|
||||
}
|
||||
|
||||
return TCPSocket{std::move(socket)};
|
||||
}
|
||||
|
||||
auto GetExecutor() {
|
||||
return std::visit(utils::Overloaded{[](auto &&socket) { return socket.get_executor(); }}, socket_);
|
||||
}
|
||||
|
||||
auto GetRemoteEndpoint() const {
|
||||
return std::visit(utils::Overloaded{[](const auto &socket) { return socket.lowest_layer().remote_endpoint(); }},
|
||||
socket_);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
decltype(auto) ExecuteForSocket(F &&fun) {
|
||||
return std::visit(utils::Overloaded{std::forward<F>(fun)}, socket_);
|
||||
}
|
||||
|
||||
std::variant<TCPSocket, SSLSocket> socket_;
|
||||
std::optional<std::reference_wrapper<boost::asio::ssl::context>> ssl_context_;
|
||||
boost::asio::strand<tcp::socket::executor_type> strand_;
|
||||
|
||||
communication::Buffer input_buffer_;
|
||||
OutputStream output_stream_;
|
||||
TSession session_;
|
||||
TSessionData *data_;
|
||||
tcp::endpoint endpoint_;
|
||||
tcp::endpoint remote_endpoint_;
|
||||
std::string_view service_name_;
|
||||
std::chrono::seconds timeout_seconds_;
|
||||
boost::asio::steady_timer timeout_timer_;
|
||||
bool execution_active_{false};
|
||||
bool has_received_msg_{false};
|
||||
};
|
||||
} // namespace memgraph::communication::v2
|
@ -11,8 +11,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#define BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT
|
||||
|
||||
#include <list>
|
||||
#include <memory>
|
||||
|
||||
|
@ -11,8 +11,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#define BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT
|
||||
|
||||
#include <thread>
|
||||
|
||||
#include <spdlog/sinks/base_sink.h>
|
||||
|
@ -11,8 +11,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#define BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
@ -81,8 +81,8 @@
|
||||
#include "communication/bolt/v1/exceptions.hpp"
|
||||
#include "communication/bolt/v1/session.hpp"
|
||||
#include "communication/init.hpp"
|
||||
#include "communication/server.hpp"
|
||||
#include "communication/session.hpp"
|
||||
#include "communication/v2/server.hpp"
|
||||
#include "communication/v2/session.hpp"
|
||||
#include "glue/communication.hpp"
|
||||
|
||||
#include "auth/auth.hpp"
|
||||
@ -842,13 +842,14 @@ class AuthChecker final : public memgraph::query::AuthChecker {
|
||||
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
|
||||
};
|
||||
|
||||
class BoltSession final : public memgraph::communication::bolt::Session<memgraph::communication::InputStream,
|
||||
memgraph::communication::OutputStream> {
|
||||
class BoltSession final : public memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
|
||||
memgraph::communication::v2::OutputStream> {
|
||||
public:
|
||||
BoltSession(SessionData *data, const memgraph::io::network::Endpoint &endpoint,
|
||||
memgraph::communication::InputStream *input_stream, memgraph::communication::OutputStream *output_stream)
|
||||
: memgraph::communication::bolt::Session<memgraph::communication::InputStream,
|
||||
memgraph::communication::OutputStream>(input_stream, output_stream),
|
||||
BoltSession(SessionData *data, const memgraph::communication::v2::ServerEndpoint &endpoint,
|
||||
memgraph::communication::v2::InputStream *input_stream,
|
||||
memgraph::communication::v2::OutputStream *output_stream)
|
||||
: memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
|
||||
memgraph::communication::v2::OutputStream>(input_stream, output_stream),
|
||||
db_(data->db),
|
||||
interpreter_(data->interpreter_context),
|
||||
auth_(data->auth),
|
||||
@ -858,8 +859,8 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
|
||||
endpoint_(endpoint) {
|
||||
}
|
||||
|
||||
using memgraph::communication::bolt::Session<memgraph::communication::InputStream,
|
||||
memgraph::communication::OutputStream>::TEncoder;
|
||||
using memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
|
||||
memgraph::communication::v2::OutputStream>::TEncoder;
|
||||
|
||||
void BeginTransaction() override { interpreter_.BeginTransaction(); }
|
||||
|
||||
@ -877,7 +878,8 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
|
||||
}
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (memgraph::utils::license::global_license_checker.IsValidLicenseFast()) {
|
||||
audit_log_->Record(endpoint_.address, user_ ? *username : "", query, memgraph::storage::PropertyValue(params_pv));
|
||||
audit_log_->Record(endpoint_.address().to_string(), user_ ? *username : "", query,
|
||||
memgraph::storage::PropertyValue(params_pv));
|
||||
}
|
||||
#endif
|
||||
try {
|
||||
@ -996,10 +998,10 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
|
||||
#ifdef MG_ENTERPRISE
|
||||
memgraph::audit::Log *audit_log_;
|
||||
#endif
|
||||
memgraph::io::network::Endpoint endpoint_;
|
||||
memgraph::communication::v2::ServerEndpoint endpoint_;
|
||||
};
|
||||
|
||||
using ServerT = memgraph::communication::Server<BoltSession, SessionData>;
|
||||
using ServerT = memgraph::communication::v2::Server<BoltSession, SessionData>;
|
||||
using memgraph::communication::ServerContext;
|
||||
|
||||
// Needed to correctly handle memgraph destruction from a signal handler.
|
||||
@ -1241,8 +1243,10 @@ int main(int argc, char **argv) {
|
||||
memgraph::utils::MessageWithLink("Using non-secure Bolt connection (without SSL).", "https://memgr.ph/ssl"));
|
||||
}
|
||||
|
||||
ServerT server({FLAGS_bolt_address, static_cast<uint16_t>(FLAGS_bolt_port)}, &session_data, &context,
|
||||
FLAGS_bolt_session_inactivity_timeout, service_name, FLAGS_bolt_num_workers);
|
||||
auto server_endpoint = memgraph::communication::v2::ServerEndpoint{
|
||||
boost::asio::ip::address::from_string(FLAGS_bolt_address), static_cast<uint16_t>(FLAGS_bolt_port)};
|
||||
ServerT server(server_endpoint, &session_data, &context, FLAGS_bolt_session_inactivity_timeout, service_name,
|
||||
FLAGS_bolt_num_workers);
|
||||
|
||||
// Setup telemetry
|
||||
std::optional<memgraph::telemetry::Telemetry> telemetry;
|
||||
|
Loading…
Reference in New Issue
Block a user