From 06e6ead4d281ff94906bcb3fd05151e2d159145c Mon Sep 17 00:00:00 2001 From: Antonio Andelic <antonio2368@users.noreply.github.com> Date: Thu, 27 Jan 2022 09:51:00 +0100 Subject: [PATCH] WSS support (#327) --- src/communication/context.cpp | 30 +++--- src/communication/context.hpp | 13 +-- src/communication/websocket/listener.cpp | 11 +- src/communication/websocket/listener.hpp | 4 +- src/communication/websocket/server.hpp | 7 +- src/communication/websocket/session.cpp | 122 +++++++++++++++++------ src/communication/websocket/session.hpp | 33 +++++- src/memgraph.cpp | 46 +++++---- 8 files changed, 176 insertions(+), 90 deletions(-) diff --git a/src/communication/context.cpp b/src/communication/context.cpp index 88534b9a5..a272723fd 100644 --- a/src/communication/context.cpp +++ b/src/communication/context.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -73,11 +73,9 @@ SSL_CTX *ClientContext::context() { return ctx_; } bool ClientContext::use_ssl() { return use_ssl_; } -ServerContext::ServerContext() : use_ssl_(false), ctx_(nullptr) {} - ServerContext::ServerContext(const std::string &key_file, const std::string &cert_file, const std::string &ca_file, bool verify_peer) - : use_ssl_(true), + : #if OPENSSL_VERSION_NUMBER < 0x10100000L ctx_(SSL_CTX_new(SSLv23_server_method())) #else @@ -110,43 +108,47 @@ ServerContext::ServerContext(const std::string &key_file, const std::string &cer SSL_CTX_set_client_CA_list(ctx_, ca_names); // Enable verification of the client certificate. + // NOLINTNEXTLINE(hicpp-signed-bitwise) SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); } } } -ServerContext::ServerContext(ServerContext &&other) noexcept : use_ssl_(other.use_ssl_), ctx_(other.ctx_) { - other.use_ssl_ = false; - other.ctx_ = nullptr; -} +ServerContext::ServerContext(ServerContext &&other) noexcept : ctx_(other.ctx_) { other.ctx_ = nullptr; } ServerContext &ServerContext::operator=(ServerContext &&other) noexcept { if (this == &other) return *this; // destroy my objects - if (use_ssl_) { + if (ctx_) { SSL_CTX_free(ctx_); } // move other objects to self - use_ssl_ = other.use_ssl_; ctx_ = other.ctx_; // reset other objects - other.use_ssl_ = false; other.ctx_ = nullptr; return *this; } ServerContext::~ServerContext() { - if (use_ssl_) { + if (ctx_) { SSL_CTX_free(ctx_); } } -SSL_CTX *ServerContext::context() { return ctx_; } +SSL_CTX *ServerContext::context() { + MG_ASSERT(ctx_); + return ctx_; +} +SSL_CTX *ServerContext::context_clone() { + MG_ASSERT(ctx_); + SSL_CTX_up_ref(ctx_); + return ctx_; +} -bool ServerContext::use_ssl() { return use_ssl_; } +bool ServerContext::use_ssl() { return ctx_ != nullptr; } } // namespace communication diff --git a/src/communication/context.hpp b/src/communication/context.hpp index d9b8f186a..05f7866ff 100644 --- a/src/communication/context.hpp +++ b/src/communication/context.hpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -69,11 +69,7 @@ class ClientContext final { */ class ServerContext final { public: - /** - * This constructor constructs a ServerContext that doesn't use SSL. - */ - ServerContext(); - + ServerContext() = default; /** * This constructor constructs a ServerContext that uses SSL. The parameters * `key_file` and `cert_file` can't be "" because when setting up a server it @@ -95,16 +91,15 @@ class ServerContext final { ServerContext(ServerContext &&other) noexcept; ServerContext &operator=(ServerContext &&other) noexcept; - // Destructor that handles ownership of the SSL object. ~ServerContext(); SSL_CTX *context(); + SSL_CTX *context_clone(); bool use_ssl(); private: - bool use_ssl_; - SSL_CTX *ctx_; + SSL_CTX *ctx_{nullptr}; }; } // namespace communication diff --git a/src/communication/websocket/listener.cpp b/src/communication/websocket/listener.cpp index 388621b19..167e59a02 100644 --- a/src/communication/websocket/listener.cpp +++ b/src/communication/websocket/listener.cpp @@ -27,8 +27,8 @@ void Listener::WriteToAll(std::shared_ptr<std::string> message) { } } -Listener::Listener(boost::asio::io_context &ioc, tcp::endpoint endpoint, SafeAuth auth) - : ioc_(ioc), acceptor_(ioc), auth_(auth) { +Listener::Listener(boost::asio::io_context &ioc, ServerContext *context, tcp::endpoint endpoint, SafeAuth auth) + : ioc_(ioc), context_(context), acceptor_(ioc), auth_(auth) { boost::beast::error_code ec; // Open the acceptor @@ -71,12 +71,15 @@ void Listener::OnAccept(boost::beast::error_code ec, tcp::socket socket) { return LogError(ec, "accept"); } - { + auto session = Session::Create(std::move(socket), *context_, auth_); + + if (session->Run()) { auto sessions_ptr = sessions_.Lock(); - sessions_ptr->emplace_back(Session::Create(std::move(socket), auth_))->Run(); // Clean disconnected clients std::erase_if(*sessions_ptr, [](const auto &elem) { return !elem->IsConnected(); }); + + sessions_ptr->emplace_back(std::move(session)); } DoAccept(); diff --git a/src/communication/websocket/listener.hpp b/src/communication/websocket/listener.hpp index e4b15bddf..c41c38a05 100644 --- a/src/communication/websocket/listener.hpp +++ b/src/communication/websocket/listener.hpp @@ -22,6 +22,7 @@ #include <boost/asio/strand.hpp> #include <boost/beast/core.hpp> +#include "communication/context.hpp" #include "communication/websocket/session.hpp" #include "utils/spin_lock.hpp" #include "utils/synchronized.hpp" @@ -41,12 +42,13 @@ class Listener : public std::enable_shared_from_this<Listener> { void WriteToAll(std::shared_ptr<std::string> message); private: - Listener(boost::asio::io_context &ioc, tcp::endpoint endpoint, SafeAuth auth); + Listener(boost::asio::io_context &ioc, ServerContext *context, tcp::endpoint endpoint, SafeAuth auth); void DoAccept(); void OnAccept(boost::beast::error_code ec, tcp::socket socket); boost::asio::io_context &ioc_; + ServerContext *context_; tcp::acceptor acceptor_; utils::Synchronized<std::list<std::shared_ptr<Session>>, utils::SpinLock> sessions_; SafeAuth auth_; diff --git a/src/communication/websocket/server.hpp b/src/communication/websocket/server.hpp index 81713cbe3..5be8bf29e 100644 --- a/src/communication/websocket/server.hpp +++ b/src/communication/websocket/server.hpp @@ -28,10 +28,9 @@ class Server final { using tcp = boost::asio::ip::tcp; public: - explicit Server(io::network::Endpoint endpoint, SafeAuth auth) - : ioc_{}, - listener_{Listener::Create(ioc_, tcp::endpoint{boost::asio::ip::make_address(endpoint.address), endpoint.port}, - auth)} {} + explicit Server(io::network::Endpoint endpoint, ServerContext *context, SafeAuth auth) + : listener_{Listener::Create( + ioc_, context, tcp::endpoint{boost::asio::ip::make_address(endpoint.address), endpoint.port}, auth)} {} Server(const Server &) = delete; Server(Server &&) = delete; diff --git a/src/communication/websocket/session.cpp b/src/communication/websocket/session.cpp index 376668d62..aa86955ae 100644 --- a/src/communication/websocket/session.cpp +++ b/src/communication/websocket/session.cpp @@ -19,6 +19,7 @@ #include <spdlog/spdlog.h> #include <boost/asio/bind_executor.hpp> #include <boost/beast/core/buffers_to_string.hpp> +#include <boost/beast/core/stream_traits.hpp> #include <json/json.hpp> #include "communication/context.hpp" @@ -26,27 +27,63 @@ namespace communication::websocket { namespace { -void LogError(boost::beast::error_code ec, const std::string_view what) { +void LogError(const boost::beast::error_code ec, const std::string_view what) { spdlog::warn("Websocket session failed on {}: {}", what, ec.message()); } } // namespace -void Session::Run() { - ws_.set_option(boost::beast::websocket::stream_base::timeout::suggested(boost::beast::role_type::server)); - - ws_.set_option(boost::beast::websocket::stream_base::decorator( - [](boost::beast::websocket::response_type &res) { res.set(boost::beast::http::field::server, "Memgraph WS"); })); - - // Accept the websocket handshake - boost::beast::error_code ec; - ws_.accept(ec); - if (ec) { - return LogError(ec, "accept"); +std::variant<Session::PlainWebSocket, Session::SSLWebSocket> Session::CreateWebSocket(tcp::socket &&socket, + ServerContext &context) { + if (context.use_ssl()) { + ssl_context_.emplace(context.context_clone()); + return Session::SSLWebSocket{std::move(socket), *ssl_context_}; } + + return Session::PlainWebSocket{std::move(socket)}; +} + +Session::Session(tcp::socket &&socket, ServerContext &context, SafeAuth auth) + : ws_(CreateWebSocket(std::move(socket), context)), strand_{boost::asio::make_strand(GetExecutor())}, auth_{auth} {} + +bool Session::Run() { + ExecuteForWebsocket([](auto &&ws) { + ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(boost::beast::role_type::server)); + + ws.set_option(boost::beast::websocket::stream_base::decorator([](boost::beast::websocket::response_type &res) { + res.set(boost::beast::http::field::server, "Memgraph WS"); + })); + }); + + if (auto *ssl_ws = std::get_if<SSLWebSocket>(&ws_); ssl_ws != nullptr) { + try { + boost::beast::get_lowest_layer(*ssl_ws).expires_after(std::chrono::seconds(30)); + ssl_ws->next_layer().handshake(boost::asio::ssl::stream_base::server); + } catch (const boost::system::system_error &e) { + spdlog::warn("Failed on SSL handshake: {}", e.what()); + return false; + } + } + + auto result = ExecuteForWebsocket([](auto &&ws) -> bool { + // Accept the websocket handshake + boost::beast::error_code ec; + ws.accept(ec); + if (ec) { + LogError(ec, "accept"); + return false; + } + return true; + }); + + if (!result) { + return false; + } + connected_.store(true, std::memory_order_relaxed); // run on the strand boost::asio::dispatch(strand_, [shared_this = shared_from_this()] { shared_this->DoRead(); }); + return true; } void Session::Write(std::shared_ptr<std::string> message) { @@ -54,7 +91,7 @@ void Session::Write(std::shared_ptr<std::string> message) { if (!shared_this->connected_.load(std::memory_order_relaxed)) { return; } - if (!shared_this->authenticated_) { + if (!shared_this->IsAuthenticated()) { return; } shared_this->messages_.push_back(std::move(message)); @@ -68,13 +105,15 @@ void Session::Write(std::shared_ptr<std::string> message) { bool Session::IsConnected() const { return connected_.load(std::memory_order_relaxed); } void Session::DoWrite() { - auto next_message = messages_.front(); - ws_.async_write( - boost::asio::buffer(*next_message), - boost::asio::bind_executor(strand_, [message_string = std::move(next_message), shared_this = shared_from_this()]( - boost::beast::error_code ec, const size_t bytes_transferred) { - shared_this->OnWrite(ec, bytes_transferred); - })); + ExecuteForWebsocket([this](auto &&ws) { + auto next_message = messages_.front(); + ws.async_write(boost::asio::buffer(*next_message), + boost::asio::bind_executor( + strand_, [message_string = std::move(next_message), shared_this = shared_from_this()]( + boost::beast::error_code ec, const size_t bytes_transferred) { + shared_this->OnWrite(ec, bytes_transferred); + })); + }); } void Session::OnWrite(boost::beast::error_code ec, size_t /*bytes_transferred*/) { @@ -84,7 +123,7 @@ void Session::OnWrite(boost::beast::error_code ec, size_t /*bytes_transferred*/) return LogError(ec, "write"); } if (close_) { - DoClose(); + DoShutdown(); return; } if (!messages_.empty()) { @@ -93,18 +132,19 @@ void Session::OnWrite(boost::beast::error_code ec, size_t /*bytes_transferred*/) } void Session::DoRead() { - ws_.async_read( - buffer_, boost::asio::bind_executor(strand_, [shared_this = shared_from_this()](boost::beast::error_code ec, - const size_t bytes_transferred) { - shared_this->OnRead(ec, bytes_transferred); - })); + ExecuteForWebsocket([this](auto &&ws) { + ws.async_read(buffer_, boost::asio::bind_executor(strand_, std::bind_front(&Session::OnRead, shared_from_this()))); + }); + ; } void Session::DoClose() { - ws_.async_close(boost::beast::websocket::close_code::normal, - boost::asio::bind_executor(strand_, [shared_this = shared_from_this()](boost::beast::error_code ec) { - shared_this->OnClose(ec); - })); + ExecuteForWebsocket([this](auto &&ws) mutable { + ws.async_close(boost::beast::websocket::close_code::normal, + boost::asio::bind_executor(strand_, [shared_this = shared_from_this()](boost::beast::error_code ec) { + shared_this->OnClose(ec); + })); + }); } void Session::OnClose(boost::beast::error_code ec) { @@ -128,12 +168,11 @@ utils::BasicResult<std::string> Session::Authorize(const nlohmann::json &creds) void Session::OnRead(const boost::beast::error_code ec, const size_t /*bytes_transferred*/) { if (ec == boost::beast::websocket::error::closed) { - messages_.clear(); - connected_.store(false, std::memory_order_relaxed); + DoShutdown(); return; } - if (!authenticated_ && auth_.HasAnyUsers()) { + if (!IsAuthenticated()) { auto response = nlohmann::json(); auto auth_failed = [this, &response](const std::string &message) { response["success"] = false; @@ -173,4 +212,21 @@ void Session::OnRead(const boost::beast::error_code ec, const size_t /*bytes_tra DoRead(); } +bool Session::IsAuthenticated() const { return authenticated_ || !auth_.HasAnyUsers(); } + +void Session::DoShutdown() { + std::visit(utils::Overloaded{[this](SSLWebSocket &ssl_ws) { + boost::beast::get_lowest_layer(ssl_ws).expires_after(std::chrono::seconds(30)); + ssl_ws.next_layer().async_shutdown( + [shared_this = shared_from_this()](boost::beast::error_code ec) { + if (ec) { + LogError(ec, "shutdown"); + } + shared_this->DoClose(); + }); + }, + [this](auto && /* plain_ws */) { DoClose(); }}, + ws_); +} + } // namespace communication::websocket diff --git a/src/communication/websocket/session.hpp b/src/communication/websocket/session.hpp index de1869c21..899daf366 100644 --- a/src/communication/websocket/session.hpp +++ b/src/communication/websocket/session.hpp @@ -15,17 +15,22 @@ #include <deque> #include <memory> +#include <optional> +#include <variant> #include <boost/asio/dispatch.hpp> #include <boost/asio/ip/tcp.hpp> #include <boost/asio/strand.hpp> #include <boost/beast/core/tcp_stream.hpp> +#include <boost/beast/ssl.hpp> #include <boost/beast/websocket.hpp> #include <json/json.hpp> +#include "communication/context.hpp" #include "communication/websocket/auth.hpp" #include "utils/result.hpp" #include "utils/synchronized.hpp" +#include "utils/variant_helpers.hpp" namespace communication::websocket { class Session : public std::enable_shared_from_this<Session> { @@ -37,13 +42,15 @@ class Session : public std::enable_shared_from_this<Session> { return std::shared_ptr<Session>{new Session{std::forward<Args>(args)...}}; } - void Run(); + bool Run(); void Write(std::shared_ptr<std::string> message); bool IsConnected() const; private: - explicit Session(tcp::socket &&socket, SafeAuth auth) - : ws_(std::move(socket)), strand_{boost::asio::make_strand(ws_.get_executor())}, auth_(auth) {} + using PlainWebSocket = boost::beast::websocket::stream<boost::beast::tcp_stream>; + using SSLWebSocket = boost::beast::websocket::stream<boost::beast::ssl_stream<boost::beast::tcp_stream>>; + + explicit Session(tcp::socket &&socket, ServerContext &context, SafeAuth auth); void DoWrite(); void OnWrite(boost::beast::error_code ec, size_t bytest_transferred); @@ -56,10 +63,26 @@ class Session : public std::enable_shared_from_this<Session> { utils::BasicResult<std::string> Authorize(const nlohmann::json &creds); - boost::beast::websocket::stream<boost::beast::tcp_stream> ws_; + bool IsAuthenticated() const; + + void DoShutdown(); + + auto GetExecutor() { + return std::visit(utils::Overloaded{[](auto &&ws) { return ws.get_executor(); }}, ws_); + } + + template <typename F> + decltype(auto) ExecuteForWebsocket(F &&fn) { + return std::visit(utils::Overloaded{std::forward<F>(fn)}, ws_); + } + + std::variant<PlainWebSocket, SSLWebSocket> CreateWebSocket(tcp::socket &&socket, ServerContext &context); + + std::optional<boost::asio::ssl::context> ssl_context_; + std::variant<PlainWebSocket, SSLWebSocket> ws_; boost::beast::flat_buffer buffer_; std::deque<std::shared_ptr<std::string>> messages_; - boost::asio::strand<decltype(ws_)::executor_type> strand_; + boost::asio::strand<PlainWebSocket::executor_type> strand_; std::atomic<bool> connected_{false}; bool authenticated_{false}; bool close_{false}; diff --git a/src/memgraph.cpp b/src/memgraph.cpp index ba992faf6..32a9741ea 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -349,39 +349,49 @@ DEFINE_VALIDATED_string(log_level, "WARNING", log_level_help_string.c_str(), { }); namespace { -void ParseLogLevel() { +spdlog::level::level_enum ParseLogLevel() { const auto log_level = StringToEnum<spdlog::level::level_enum>(FLAGS_log_level, log_level_mappings); MG_ASSERT(log_level, "Invalid log level"); - spdlog::set_level(*log_level); + return *log_level; } // 5 weeks * 7 days constexpr auto log_retention_count = 35; +void CreateLoggerFromSink(const auto &sinks, const auto log_level) { + auto logger = std::make_shared<spdlog::logger>("memgraph_log", sinks.begin(), sinks.end()); + logger->set_level(log_level); + logger->flush_on(spdlog::level::trace); + spdlog::set_default_logger(std::move(logger)); +} -void ConfigureLogging() { - std::vector<spdlog::sink_ptr> loggers; +void InitializeLogger() { + std::vector<spdlog::sink_ptr> sinks; if (FLAGS_also_log_to_stderr) { - loggers.emplace_back(std::make_shared<spdlog::sinks::stderr_color_sink_mt>()); + sinks.emplace_back(std::make_shared<spdlog::sinks::stderr_color_sink_mt>()); } if (!FLAGS_log_file.empty()) { // get local time - time_t current_time; + time_t current_time{0}; struct tm *local_time{nullptr}; time(¤t_time); local_time = localtime(¤t_time); - loggers.emplace_back(std::make_shared<spdlog::sinks::daily_file_sink_mt>( + sinks.emplace_back(std::make_shared<spdlog::sinks::daily_file_sink_mt>( FLAGS_log_file, local_time->tm_hour, local_time->tm_min, false, log_retention_count)); } - - spdlog::set_default_logger(std::make_shared<spdlog::logger>("memgraph_log", loggers.begin(), loggers.end())); - - spdlog::flush_on(spdlog::level::trace); - ParseLogLevel(); + CreateLoggerFromSink(sinks, ParseLogLevel()); } + +void AddLoggerSink(spdlog::sink_ptr new_sink) { + auto default_logger = spdlog::default_logger(); + auto sinks = default_logger->sinks(); + sinks.push_back(new_sink); + CreateLoggerFromSink(sinks, default_logger->level()); +} + } // namespace // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -992,7 +1002,7 @@ int main(int argc, char **argv) { LoadConfig("memgraph"); gflags::ParseCommandLineFlags(&argc, &argv, true); - ConfigureLogging(); + InitializeLogger(); // Unhandled exception handler init. std::set_terminate(&utils::TerminateHandler); @@ -1205,13 +1215,9 @@ int main(int argc, char **argv) { []() -> nlohmann::json { return query::plan::CallProcedure::GetAndResetCounters(); }); } - communication::websocket::Server websocket_server{{"0.0.0.0", 7444}, communication::websocket::SafeAuth{&auth}}; - - { - auto sinks = spdlog::default_logger()->sinks(); - sinks.push_back(websocket_server.GetLoggingSink()); - spdlog::set_default_logger(std::make_shared<spdlog::logger>("memgraph_log", sinks.begin(), sinks.end())); - } + communication::websocket::Server websocket_server{ + {"0.0.0.0", 7444}, &context, communication::websocket::SafeAuth{&auth}}; + AddLoggerSink(websocket_server.GetLoggingSink()); // Handler for regular termination signals auto shutdown = [&websocket_server, &server, &interpreter_context] {