WSS support (#327)

This commit is contained in:
Antonio Andelic 2022-01-27 09:51:00 +01:00 committed by Antonio Andelic
parent 728b37080d
commit 06e6ead4d2
8 changed files with 176 additions and 90 deletions

View File

@ -1,4 +1,4 @@
// Copyright 2021 Memgraph Ltd.
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -73,11 +73,9 @@ SSL_CTX *ClientContext::context() { return ctx_; }
bool ClientContext::use_ssl() { return use_ssl_; }
ServerContext::ServerContext() : use_ssl_(false), ctx_(nullptr) {}
ServerContext::ServerContext(const std::string &key_file, const std::string &cert_file, const std::string &ca_file,
bool verify_peer)
: use_ssl_(true),
:
#if OPENSSL_VERSION_NUMBER < 0x10100000L
ctx_(SSL_CTX_new(SSLv23_server_method()))
#else
@ -110,43 +108,47 @@ ServerContext::ServerContext(const std::string &key_file, const std::string &cer
SSL_CTX_set_client_CA_list(ctx_, ca_names);
// Enable verification of the client certificate.
// NOLINTNEXTLINE(hicpp-signed-bitwise)
SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
}
}
}
ServerContext::ServerContext(ServerContext &&other) noexcept : use_ssl_(other.use_ssl_), ctx_(other.ctx_) {
other.use_ssl_ = false;
other.ctx_ = nullptr;
}
ServerContext::ServerContext(ServerContext &&other) noexcept : ctx_(other.ctx_) { other.ctx_ = nullptr; }
ServerContext &ServerContext::operator=(ServerContext &&other) noexcept {
if (this == &other) return *this;
// destroy my objects
if (use_ssl_) {
if (ctx_) {
SSL_CTX_free(ctx_);
}
// move other objects to self
use_ssl_ = other.use_ssl_;
ctx_ = other.ctx_;
// reset other objects
other.use_ssl_ = false;
other.ctx_ = nullptr;
return *this;
}
ServerContext::~ServerContext() {
if (use_ssl_) {
if (ctx_) {
SSL_CTX_free(ctx_);
}
}
SSL_CTX *ServerContext::context() { return ctx_; }
SSL_CTX *ServerContext::context() {
MG_ASSERT(ctx_);
return ctx_;
}
SSL_CTX *ServerContext::context_clone() {
MG_ASSERT(ctx_);
SSL_CTX_up_ref(ctx_);
return ctx_;
}
bool ServerContext::use_ssl() { return use_ssl_; }
bool ServerContext::use_ssl() { return ctx_ != nullptr; }
} // namespace communication

View File

@ -1,4 +1,4 @@
// Copyright 2021 Memgraph Ltd.
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -69,11 +69,7 @@ class ClientContext final {
*/
class ServerContext final {
public:
/**
* This constructor constructs a ServerContext that doesn't use SSL.
*/
ServerContext();
ServerContext() = default;
/**
* This constructor constructs a ServerContext that uses SSL. The parameters
* `key_file` and `cert_file` can't be "" because when setting up a server it
@ -95,16 +91,15 @@ class ServerContext final {
ServerContext(ServerContext &&other) noexcept;
ServerContext &operator=(ServerContext &&other) noexcept;
// Destructor that handles ownership of the SSL object.
~ServerContext();
SSL_CTX *context();
SSL_CTX *context_clone();
bool use_ssl();
private:
bool use_ssl_;
SSL_CTX *ctx_;
SSL_CTX *ctx_{nullptr};
};
} // namespace communication

View File

@ -27,8 +27,8 @@ void Listener::WriteToAll(std::shared_ptr<std::string> message) {
}
}
Listener::Listener(boost::asio::io_context &ioc, tcp::endpoint endpoint, SafeAuth auth)
: ioc_(ioc), acceptor_(ioc), auth_(auth) {
Listener::Listener(boost::asio::io_context &ioc, ServerContext *context, tcp::endpoint endpoint, SafeAuth auth)
: ioc_(ioc), context_(context), acceptor_(ioc), auth_(auth) {
boost::beast::error_code ec;
// Open the acceptor
@ -71,12 +71,15 @@ void Listener::OnAccept(boost::beast::error_code ec, tcp::socket socket) {
return LogError(ec, "accept");
}
{
auto session = Session::Create(std::move(socket), *context_, auth_);
if (session->Run()) {
auto sessions_ptr = sessions_.Lock();
sessions_ptr->emplace_back(Session::Create(std::move(socket), auth_))->Run();
// Clean disconnected clients
std::erase_if(*sessions_ptr, [](const auto &elem) { return !elem->IsConnected(); });
sessions_ptr->emplace_back(std::move(session));
}
DoAccept();

View File

@ -22,6 +22,7 @@
#include <boost/asio/strand.hpp>
#include <boost/beast/core.hpp>
#include "communication/context.hpp"
#include "communication/websocket/session.hpp"
#include "utils/spin_lock.hpp"
#include "utils/synchronized.hpp"
@ -41,12 +42,13 @@ class Listener : public std::enable_shared_from_this<Listener> {
void WriteToAll(std::shared_ptr<std::string> message);
private:
Listener(boost::asio::io_context &ioc, tcp::endpoint endpoint, SafeAuth auth);
Listener(boost::asio::io_context &ioc, ServerContext *context, tcp::endpoint endpoint, SafeAuth auth);
void DoAccept();
void OnAccept(boost::beast::error_code ec, tcp::socket socket);
boost::asio::io_context &ioc_;
ServerContext *context_;
tcp::acceptor acceptor_;
utils::Synchronized<std::list<std::shared_ptr<Session>>, utils::SpinLock> sessions_;
SafeAuth auth_;

View File

@ -28,10 +28,9 @@ class Server final {
using tcp = boost::asio::ip::tcp;
public:
explicit Server(io::network::Endpoint endpoint, SafeAuth auth)
: ioc_{},
listener_{Listener::Create(ioc_, tcp::endpoint{boost::asio::ip::make_address(endpoint.address), endpoint.port},
auth)} {}
explicit Server(io::network::Endpoint endpoint, ServerContext *context, SafeAuth auth)
: listener_{Listener::Create(
ioc_, context, tcp::endpoint{boost::asio::ip::make_address(endpoint.address), endpoint.port}, auth)} {}
Server(const Server &) = delete;
Server(Server &&) = delete;

View File

@ -19,6 +19,7 @@
#include <spdlog/spdlog.h>
#include <boost/asio/bind_executor.hpp>
#include <boost/beast/core/buffers_to_string.hpp>
#include <boost/beast/core/stream_traits.hpp>
#include <json/json.hpp>
#include "communication/context.hpp"
@ -26,27 +27,63 @@
namespace communication::websocket {
namespace {
void LogError(boost::beast::error_code ec, const std::string_view what) {
void LogError(const boost::beast::error_code ec, const std::string_view what) {
spdlog::warn("Websocket session failed on {}: {}", what, ec.message());
}
} // namespace
void Session::Run() {
ws_.set_option(boost::beast::websocket::stream_base::timeout::suggested(boost::beast::role_type::server));
ws_.set_option(boost::beast::websocket::stream_base::decorator(
[](boost::beast::websocket::response_type &res) { res.set(boost::beast::http::field::server, "Memgraph WS"); }));
// Accept the websocket handshake
boost::beast::error_code ec;
ws_.accept(ec);
if (ec) {
return LogError(ec, "accept");
std::variant<Session::PlainWebSocket, Session::SSLWebSocket> Session::CreateWebSocket(tcp::socket &&socket,
ServerContext &context) {
if (context.use_ssl()) {
ssl_context_.emplace(context.context_clone());
return Session::SSLWebSocket{std::move(socket), *ssl_context_};
}
return Session::PlainWebSocket{std::move(socket)};
}
Session::Session(tcp::socket &&socket, ServerContext &context, SafeAuth auth)
: ws_(CreateWebSocket(std::move(socket), context)), strand_{boost::asio::make_strand(GetExecutor())}, auth_{auth} {}
bool Session::Run() {
ExecuteForWebsocket([](auto &&ws) {
ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(boost::beast::role_type::server));
ws.set_option(boost::beast::websocket::stream_base::decorator([](boost::beast::websocket::response_type &res) {
res.set(boost::beast::http::field::server, "Memgraph WS");
}));
});
if (auto *ssl_ws = std::get_if<SSLWebSocket>(&ws_); ssl_ws != nullptr) {
try {
boost::beast::get_lowest_layer(*ssl_ws).expires_after(std::chrono::seconds(30));
ssl_ws->next_layer().handshake(boost::asio::ssl::stream_base::server);
} catch (const boost::system::system_error &e) {
spdlog::warn("Failed on SSL handshake: {}", e.what());
return false;
}
}
auto result = ExecuteForWebsocket([](auto &&ws) -> bool {
// Accept the websocket handshake
boost::beast::error_code ec;
ws.accept(ec);
if (ec) {
LogError(ec, "accept");
return false;
}
return true;
});
if (!result) {
return false;
}
connected_.store(true, std::memory_order_relaxed);
// run on the strand
boost::asio::dispatch(strand_, [shared_this = shared_from_this()] { shared_this->DoRead(); });
return true;
}
void Session::Write(std::shared_ptr<std::string> message) {
@ -54,7 +91,7 @@ void Session::Write(std::shared_ptr<std::string> message) {
if (!shared_this->connected_.load(std::memory_order_relaxed)) {
return;
}
if (!shared_this->authenticated_) {
if (!shared_this->IsAuthenticated()) {
return;
}
shared_this->messages_.push_back(std::move(message));
@ -68,13 +105,15 @@ void Session::Write(std::shared_ptr<std::string> message) {
bool Session::IsConnected() const { return connected_.load(std::memory_order_relaxed); }
void Session::DoWrite() {
auto next_message = messages_.front();
ws_.async_write(
boost::asio::buffer(*next_message),
boost::asio::bind_executor(strand_, [message_string = std::move(next_message), shared_this = shared_from_this()](
boost::beast::error_code ec, const size_t bytes_transferred) {
shared_this->OnWrite(ec, bytes_transferred);
}));
ExecuteForWebsocket([this](auto &&ws) {
auto next_message = messages_.front();
ws.async_write(boost::asio::buffer(*next_message),
boost::asio::bind_executor(
strand_, [message_string = std::move(next_message), shared_this = shared_from_this()](
boost::beast::error_code ec, const size_t bytes_transferred) {
shared_this->OnWrite(ec, bytes_transferred);
}));
});
}
void Session::OnWrite(boost::beast::error_code ec, size_t /*bytes_transferred*/) {
@ -84,7 +123,7 @@ void Session::OnWrite(boost::beast::error_code ec, size_t /*bytes_transferred*/)
return LogError(ec, "write");
}
if (close_) {
DoClose();
DoShutdown();
return;
}
if (!messages_.empty()) {
@ -93,18 +132,19 @@ void Session::OnWrite(boost::beast::error_code ec, size_t /*bytes_transferred*/)
}
void Session::DoRead() {
ws_.async_read(
buffer_, boost::asio::bind_executor(strand_, [shared_this = shared_from_this()](boost::beast::error_code ec,
const size_t bytes_transferred) {
shared_this->OnRead(ec, bytes_transferred);
}));
ExecuteForWebsocket([this](auto &&ws) {
ws.async_read(buffer_, boost::asio::bind_executor(strand_, std::bind_front(&Session::OnRead, shared_from_this())));
});
;
}
void Session::DoClose() {
ws_.async_close(boost::beast::websocket::close_code::normal,
boost::asio::bind_executor(strand_, [shared_this = shared_from_this()](boost::beast::error_code ec) {
shared_this->OnClose(ec);
}));
ExecuteForWebsocket([this](auto &&ws) mutable {
ws.async_close(boost::beast::websocket::close_code::normal,
boost::asio::bind_executor(strand_, [shared_this = shared_from_this()](boost::beast::error_code ec) {
shared_this->OnClose(ec);
}));
});
}
void Session::OnClose(boost::beast::error_code ec) {
@ -128,12 +168,11 @@ utils::BasicResult<std::string> Session::Authorize(const nlohmann::json &creds)
void Session::OnRead(const boost::beast::error_code ec, const size_t /*bytes_transferred*/) {
if (ec == boost::beast::websocket::error::closed) {
messages_.clear();
connected_.store(false, std::memory_order_relaxed);
DoShutdown();
return;
}
if (!authenticated_ && auth_.HasAnyUsers()) {
if (!IsAuthenticated()) {
auto response = nlohmann::json();
auto auth_failed = [this, &response](const std::string &message) {
response["success"] = false;
@ -173,4 +212,21 @@ void Session::OnRead(const boost::beast::error_code ec, const size_t /*bytes_tra
DoRead();
}
bool Session::IsAuthenticated() const { return authenticated_ || !auth_.HasAnyUsers(); }
void Session::DoShutdown() {
std::visit(utils::Overloaded{[this](SSLWebSocket &ssl_ws) {
boost::beast::get_lowest_layer(ssl_ws).expires_after(std::chrono::seconds(30));
ssl_ws.next_layer().async_shutdown(
[shared_this = shared_from_this()](boost::beast::error_code ec) {
if (ec) {
LogError(ec, "shutdown");
}
shared_this->DoClose();
});
},
[this](auto && /* plain_ws */) { DoClose(); }},
ws_);
}
} // namespace communication::websocket

View File

@ -15,17 +15,22 @@
#include <deque>
#include <memory>
#include <optional>
#include <variant>
#include <boost/asio/dispatch.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/strand.hpp>
#include <boost/beast/core/tcp_stream.hpp>
#include <boost/beast/ssl.hpp>
#include <boost/beast/websocket.hpp>
#include <json/json.hpp>
#include "communication/context.hpp"
#include "communication/websocket/auth.hpp"
#include "utils/result.hpp"
#include "utils/synchronized.hpp"
#include "utils/variant_helpers.hpp"
namespace communication::websocket {
class Session : public std::enable_shared_from_this<Session> {
@ -37,13 +42,15 @@ class Session : public std::enable_shared_from_this<Session> {
return std::shared_ptr<Session>{new Session{std::forward<Args>(args)...}};
}
void Run();
bool Run();
void Write(std::shared_ptr<std::string> message);
bool IsConnected() const;
private:
explicit Session(tcp::socket &&socket, SafeAuth auth)
: ws_(std::move(socket)), strand_{boost::asio::make_strand(ws_.get_executor())}, auth_(auth) {}
using PlainWebSocket = boost::beast::websocket::stream<boost::beast::tcp_stream>;
using SSLWebSocket = boost::beast::websocket::stream<boost::beast::ssl_stream<boost::beast::tcp_stream>>;
explicit Session(tcp::socket &&socket, ServerContext &context, SafeAuth auth);
void DoWrite();
void OnWrite(boost::beast::error_code ec, size_t bytest_transferred);
@ -56,10 +63,26 @@ class Session : public std::enable_shared_from_this<Session> {
utils::BasicResult<std::string> Authorize(const nlohmann::json &creds);
boost::beast::websocket::stream<boost::beast::tcp_stream> ws_;
bool IsAuthenticated() const;
void DoShutdown();
auto GetExecutor() {
return std::visit(utils::Overloaded{[](auto &&ws) { return ws.get_executor(); }}, ws_);
}
template <typename F>
decltype(auto) ExecuteForWebsocket(F &&fn) {
return std::visit(utils::Overloaded{std::forward<F>(fn)}, ws_);
}
std::variant<PlainWebSocket, SSLWebSocket> CreateWebSocket(tcp::socket &&socket, ServerContext &context);
std::optional<boost::asio::ssl::context> ssl_context_;
std::variant<PlainWebSocket, SSLWebSocket> ws_;
boost::beast::flat_buffer buffer_;
std::deque<std::shared_ptr<std::string>> messages_;
boost::asio::strand<decltype(ws_)::executor_type> strand_;
boost::asio::strand<PlainWebSocket::executor_type> strand_;
std::atomic<bool> connected_{false};
bool authenticated_{false};
bool close_{false};

View File

@ -349,39 +349,49 @@ DEFINE_VALIDATED_string(log_level, "WARNING", log_level_help_string.c_str(), {
});
namespace {
void ParseLogLevel() {
spdlog::level::level_enum ParseLogLevel() {
const auto log_level = StringToEnum<spdlog::level::level_enum>(FLAGS_log_level, log_level_mappings);
MG_ASSERT(log_level, "Invalid log level");
spdlog::set_level(*log_level);
return *log_level;
}
// 5 weeks * 7 days
constexpr auto log_retention_count = 35;
void CreateLoggerFromSink(const auto &sinks, const auto log_level) {
auto logger = std::make_shared<spdlog::logger>("memgraph_log", sinks.begin(), sinks.end());
logger->set_level(log_level);
logger->flush_on(spdlog::level::trace);
spdlog::set_default_logger(std::move(logger));
}
void ConfigureLogging() {
std::vector<spdlog::sink_ptr> loggers;
void InitializeLogger() {
std::vector<spdlog::sink_ptr> sinks;
if (FLAGS_also_log_to_stderr) {
loggers.emplace_back(std::make_shared<spdlog::sinks::stderr_color_sink_mt>());
sinks.emplace_back(std::make_shared<spdlog::sinks::stderr_color_sink_mt>());
}
if (!FLAGS_log_file.empty()) {
// get local time
time_t current_time;
time_t current_time{0};
struct tm *local_time{nullptr};
time(&current_time);
local_time = localtime(&current_time);
loggers.emplace_back(std::make_shared<spdlog::sinks::daily_file_sink_mt>(
sinks.emplace_back(std::make_shared<spdlog::sinks::daily_file_sink_mt>(
FLAGS_log_file, local_time->tm_hour, local_time->tm_min, false, log_retention_count));
}
spdlog::set_default_logger(std::make_shared<spdlog::logger>("memgraph_log", loggers.begin(), loggers.end()));
spdlog::flush_on(spdlog::level::trace);
ParseLogLevel();
CreateLoggerFromSink(sinks, ParseLogLevel());
}
void AddLoggerSink(spdlog::sink_ptr new_sink) {
auto default_logger = spdlog::default_logger();
auto sinks = default_logger->sinks();
sinks.push_back(new_sink);
CreateLoggerFromSink(sinks, default_logger->level());
}
} // namespace
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@ -992,7 +1002,7 @@ int main(int argc, char **argv) {
LoadConfig("memgraph");
gflags::ParseCommandLineFlags(&argc, &argv, true);
ConfigureLogging();
InitializeLogger();
// Unhandled exception handler init.
std::set_terminate(&utils::TerminateHandler);
@ -1205,13 +1215,9 @@ int main(int argc, char **argv) {
[]() -> nlohmann::json { return query::plan::CallProcedure::GetAndResetCounters(); });
}
communication::websocket::Server websocket_server{{"0.0.0.0", 7444}, communication::websocket::SafeAuth{&auth}};
{
auto sinks = spdlog::default_logger()->sinks();
sinks.push_back(websocket_server.GetLoggingSink());
spdlog::set_default_logger(std::make_shared<spdlog::logger>("memgraph_log", sinks.begin(), sinks.end()));
}
communication::websocket::Server websocket_server{
{"0.0.0.0", 7444}, &context, communication::websocket::SafeAuth{&auth}};
AddLoggerSink(websocket_server.GetLoggingSink());
// Handler for regular termination signals
auto shutdown = [&websocket_server, &server, &interpreter_context] {