diff --git a/src/communication/CMakeLists.txt b/src/communication/CMakeLists.txt index f782ede1c..b8e64e015 100644 --- a/src/communication/CMakeLists.txt +++ b/src/communication/CMakeLists.txt @@ -2,6 +2,9 @@ find_package(fmt REQUIRED) find_package(gflags REQUIRED) set(communication_src_files + websocket/server.cpp + websocket/listener.cpp + websocket/session.cpp bolt/v1/value.cpp buffer.cpp client.cpp @@ -9,8 +12,10 @@ set(communication_src_files helpers.cpp init.cpp) +find_package(Boost REQUIRED) + add_library(mg-communication STATIC ${communication_src_files}) -target_link_libraries(mg-communication Threads::Threads mg-utils mg-io fmt::fmt gflags) +target_link_libraries(mg-communication Boost::headers Threads::Threads mg-utils mg-io fmt::fmt gflags) find_package(OpenSSL REQUIRED) target_link_libraries(mg-communication ${OPENSSL_LIBRARIES}) diff --git a/src/communication/websocket/listener.cpp b/src/communication/websocket/listener.cpp new file mode 100644 index 000000000..d4888c849 --- /dev/null +++ b/src/communication/websocket/listener.cpp @@ -0,0 +1,82 @@ +// Copyright 2021 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "communication/websocket/listener.hpp" + +namespace communication::websocket { +namespace { +void LogError(boost::beast::error_code ec, const std::string_view what) { + spdlog::warn("Websocket listener failed on {}: {}", what, ec.message()); +} +} // namespace +void Listener::Run() { DoAccept(); } + +void Listener::WriteToAll(std::shared_ptr<std::string> message) { + auto sessions_ptr = sessions_.Lock(); + for (auto &session : *sessions_ptr) { + session->Write(message); + } +} + +Listener::Listener(boost::asio::io_context &ioc, tcp::endpoint endpoint) : ioc_(ioc), acceptor_(ioc) { + boost::beast::error_code ec; + + // Open the acceptor + acceptor_.open(endpoint.protocol(), ec); + if (ec) { + LogError(ec, "open"); + return; + } + + // Allow address reuse + acceptor_.set_option(boost::asio::socket_base::reuse_address(true), ec); + if (ec) { + LogError(ec, "set_option"); + return; + } + + // Bind to the server address + acceptor_.bind(endpoint, ec); + if (ec) { + LogError(ec, "bind"); + return; + } + + acceptor_.listen(boost::asio::socket_base::max_listen_connections, ec); + if (ec) { + LogError(ec, "listen"); + return; + } + + spdlog::info("WebSocket server is listening on {}:{}", endpoint.address(), endpoint.port()); +} + +void Listener::DoAccept() { + acceptor_.async_accept( + ioc_, [shared_this = shared_from_this()](auto ec, auto socket) { shared_this->OnAccept(ec, std::move(socket)); }); +} + +void Listener::OnAccept(boost::beast::error_code ec, tcp::socket socket) { + if (ec) { + return LogError(ec, "accept"); + } + + { + auto sessions_ptr = sessions_.Lock(); + sessions_ptr->emplace_back(Session::Create(std::move(socket)))->Run(); + + // Clean disconnected clients + std::erase_if(*sessions_ptr, [](const auto &elem) { return !elem->IsConnected(); }); + } + + DoAccept(); +} +} // namespace communication::websocket diff --git a/src/communication/websocket/listener.hpp b/src/communication/websocket/listener.hpp new file mode 100644 index 000000000..21c3b434e --- /dev/null +++ b/src/communication/websocket/listener.hpp @@ -0,0 +1,51 @@ +// Copyright 2021 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <list> +#include <memory> + +#include <spdlog/spdlog.h> +#include <boost/asio/io_context.hpp> +#include <boost/asio/ip/tcp.hpp> +#include <boost/asio/strand.hpp> +#include <boost/beast/core.hpp> + +#include "communication/websocket/session.hpp" +#include "utils/spin_lock.hpp" +#include "utils/synchronized.hpp" + +namespace communication::websocket { +class Listener : public std::enable_shared_from_this<Listener> { + using tcp = boost::asio::ip::tcp; + + public: + template <typename... Args> + static std::shared_ptr<Listener> Create(Args &&...args) { + return std::shared_ptr<Listener>{new Listener(std::forward<Args>(args)...)}; + } + + // Start accepting incoming connections + void Run(); + void WriteToAll(std::shared_ptr<std::string> message); + + private: + Listener(boost::asio::io_context &ioc, tcp::endpoint endpoint); + + void DoAccept(); + void OnAccept(boost::beast::error_code ec, tcp::socket socket); + + boost::asio::io_context &ioc_; + tcp::acceptor acceptor_; + utils::Synchronized<std::list<std::shared_ptr<Session>>, utils::SpinLock> sessions_; +}; +} // namespace communication::websocket diff --git a/src/communication/websocket/server.cpp b/src/communication/websocket/server.cpp new file mode 100644 index 000000000..e5dd550a1 --- /dev/null +++ b/src/communication/websocket/server.cpp @@ -0,0 +1,50 @@ +// Copyright 2021 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "communication/websocket/server.hpp" + +namespace communication::websocket { + +Server::~Server() { + MG_ASSERT(!background_thread_ || (ioc_.stopped() && !background_thread_->joinable()), + "Server wasn't shutdown properly"); +} + +void Server::Start() { + MG_ASSERT(!background_thread_, "The server was already started!"); + listener_->Run(); + background_thread_.emplace([this] { ioc_.run(); }); +} + +void Server::Shutdown() { ioc_.stop(); } + +void Server::AwaitShutdown() { + if (background_thread_ && background_thread_->joinable()) { + background_thread_->join(); + } +} + +bool Server::IsRunning() const { return background_thread_ && !ioc_.stopped(); } + +void Server::LoggingSink::sink_it_(const spdlog::details::log_msg &msg) { + const auto listener = listener_.lock(); + if (!listener) { + return; + } + using memory_buf_t = fmt::basic_memory_buffer<char, 250>; + memory_buf_t formatted; + base_sink<std::mutex>::formatter_->format(msg, formatted); + listener->WriteToAll(std::make_shared<std::string>(formatted.data(), formatted.size())); +} + +std::shared_ptr<Server::LoggingSink> Server::GetLoggingSink() { return std::make_shared<LoggingSink>(listener_); } + +} // namespace communication::websocket diff --git a/src/communication/websocket/server.hpp b/src/communication/websocket/server.hpp new file mode 100644 index 000000000..2caafa46c --- /dev/null +++ b/src/communication/websocket/server.hpp @@ -0,0 +1,66 @@ +// Copyright 2021 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <thread> + +#include <spdlog/sinks/base_sink.h> +#include <boost/asio/io_context.hpp> +#include <boost/asio/ip/tcp.hpp> + +#include "communication/websocket/listener.hpp" +#include "io/network/endpoint.hpp" + +namespace communication::websocket { + +class Server final { + using tcp = boost::asio::ip::tcp; + + public: + explicit Server(io::network::Endpoint endpoint) + : ioc_{}, + listener_{ + Listener::Create(ioc_, tcp::endpoint{boost::asio::ip::make_address(endpoint.address), endpoint.port})} {} + + Server(const Server &) = delete; + Server(Server &&) = delete; + Server &operator=(const Server &) = delete; + Server &operator=(Server &&) = delete; + + ~Server(); + + void Start(); + void Shutdown(); + void AwaitShutdown(); + bool IsRunning() const; + + class LoggingSink : public spdlog::sinks::base_sink<std::mutex> { + public: + explicit LoggingSink(std::weak_ptr<Listener> listener) : listener_(listener) {} + + private: + void sink_it_(const spdlog::details::log_msg &msg) override; + + void flush_() override {} + + std::weak_ptr<Listener> listener_; + }; + + std::shared_ptr<LoggingSink> GetLoggingSink(); + + private: + boost::asio::io_context ioc_; + + std::shared_ptr<Listener> listener_; + std::optional<std::thread> background_thread_; +}; +} // namespace communication::websocket diff --git a/src/communication/websocket/session.cpp b/src/communication/websocket/session.cpp new file mode 100644 index 000000000..c6e7c1b08 --- /dev/null +++ b/src/communication/websocket/session.cpp @@ -0,0 +1,102 @@ +// Copyright 2021 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#define BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT + +#include "communication/websocket/session.hpp" + +#include <boost/asio/bind_executor.hpp> + +#include "utils/logging.hpp" + +namespace communication::websocket { +namespace { +void LogError(boost::beast::error_code ec, const std::string_view what) { + spdlog::warn("Websocket session failed on {}: {}", what, ec.message()); +} +} // namespace + +void Session::Run() { + ws_.set_option(boost::beast::websocket::stream_base::timeout::suggested(boost::beast::role_type::server)); + + ws_.set_option(boost::beast::websocket::stream_base::decorator( + [](boost::beast::websocket::response_type &res) { res.set(boost::beast::http::field::server, "Memgraph WS"); })); + + // Accept the websocket handshake + boost::beast::error_code ec; + ws_.accept(ec); + if (ec) { + return LogError(ec, "accept"); + } + connected_.store(true, std::memory_order_relaxed); + + // run on the strand + boost::asio::dispatch(strand_, [shared_this = shared_from_this()] { shared_this->DoRead(); }); +} + +void Session::Write(std::shared_ptr<std::string> message) { + if (!connected_.load(std::memory_order_relaxed)) { + return; + } + boost::asio::dispatch(strand_, [message = std::move(message), shared_this = shared_from_this()]() mutable { + shared_this->messages_.push_back(std::move(message)); + + if (shared_this->messages_.size() > 1) { + return; + } + shared_this->DoWrite(); + }); +} + +bool Session::IsConnected() const { return connected_.load(std::memory_order_relaxed); } + +void Session::DoWrite() { + auto next_message = messages_.front(); + ws_.async_write( + boost::asio::buffer(*next_message), + boost::asio::bind_executor(strand_, [message_string = std::move(next_message), shared_this = shared_from_this()]( + boost::beast::error_code ec, const size_t bytes_transferred) { + shared_this->OnWrite(ec, bytes_transferred); + })); +} + +void Session::OnWrite(boost::beast::error_code ec, size_t /*bytes_transferred*/) { + messages_.pop_front(); + + if (ec) { + return LogError(ec, "write"); + } + + if (!messages_.empty()) { + DoWrite(); + } +} + +void Session::DoRead() { + ws_.async_read( + buffer_, boost::asio::bind_executor(strand_, [shared_this = shared_from_this()](boost::beast::error_code ec, + const size_t bytes_transferred) { + shared_this->OnRead(ec, bytes_transferred); + })); +} + +void Session::OnRead(boost::beast::error_code ec, size_t /*bytes_transferred*/) { + if (ec == boost::beast::websocket::error::closed) { + messages_.clear(); + connected_.store(false, std::memory_order_relaxed); + return; + } + + buffer_.consume(buffer_.size()); + DoRead(); +} + +} // namespace communication::websocket diff --git a/src/communication/websocket/session.hpp b/src/communication/websocket/session.hpp new file mode 100644 index 000000000..5d50d37a8 --- /dev/null +++ b/src/communication/websocket/session.hpp @@ -0,0 +1,52 @@ +// Copyright 2021 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <deque> +#include <memory> + +#include <boost/asio/dispatch.hpp> +#include <boost/asio/ip/tcp.hpp> +#include <boost/asio/strand.hpp> +#include <boost/beast/core/tcp_stream.hpp> +#include <boost/beast/websocket.hpp> + +namespace communication::websocket { +class Session : public std::enable_shared_from_this<Session> { + using tcp = boost::asio::ip::tcp; + + public: + template <typename... Args> + static std::shared_ptr<Session> Create(Args &&...args) { + return std::shared_ptr<Session>{new Session{std::forward<Args>(args)...}}; + } + + void Run(); + void Write(std::shared_ptr<std::string> message); + bool IsConnected() const; + + private: + explicit Session(tcp::socket &&socket) + : ws_(std::move(socket)), strand_{boost::asio::make_strand(ws_.get_executor())} {} + + void DoWrite(); + void OnWrite(boost::beast::error_code ec, size_t bytest_transferred); + void DoRead(); + void OnRead(boost::beast::error_code ec, size_t bytest_transferred); + + boost::beast::websocket::stream<boost::beast::tcp_stream> ws_; + boost::beast::flat_buffer buffer_; + std::deque<std::shared_ptr<std::string>> messages_; + boost::asio::strand<decltype(ws_)::executor_type> strand_; + std::atomic<bool> connected_{false}; +}; +} // namespace communication::websocket diff --git a/src/memgraph.cpp b/src/memgraph.cpp index d6f3eebc0..2a4b5eae4 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -29,8 +29,11 @@ #include <gflags/gflags.h> #include <spdlog/common.h> #include <spdlog/sinks/daily_file_sink.h> +#include <spdlog/sinks/dist_sink.h> #include <spdlog/sinks/stdout_color_sinks.h> +#include "communication/websocket/server.hpp" + #include "communication/bolt/v1/constants.hpp" #include "helpers.hpp" #include "py/py.hpp" @@ -43,6 +46,7 @@ #include "query/procedure/module.hpp" #include "query/procedure/py_module.hpp" #include "requests/requests.hpp" +#include "spdlog/spdlog.h" #include "storage/v2/isolation_level.hpp" #include "storage/v2/storage.hpp" #include "storage/v2/view.hpp" @@ -1200,8 +1204,16 @@ int main(int argc, char **argv) { []() -> nlohmann::json { return query::plan::CallProcedure::GetAndResetCounters(); }); } + communication::websocket::Server websocket_server{{"0.0.0.0", 7444}}; + + { + auto sinks = spdlog::default_logger()->sinks(); + sinks.push_back(websocket_server.GetLoggingSink()); + spdlog::set_default_logger(std::make_shared<spdlog::logger>("memgraph_log", sinks.begin(), sinks.end())); + } + // Handler for regular termination signals - auto shutdown = [&server, &interpreter_context] { + auto shutdown = [&websocket_server, &server, &interpreter_context] { // Server needs to be shutdown first and then the database. This prevents // a race condition when a transaction is accepted during server shutdown. server.Shutdown(); @@ -1209,11 +1221,17 @@ int main(int argc, char **argv) { // connections we tell the execution engine to stop processing all pending // queries. query::Shutdown(&interpreter_context); + websocket_server.Shutdown(); }; + InitSignalHandlers(shutdown); MG_ASSERT(server.Start(), "Couldn't start the Bolt server!"); + websocket_server.Start(); + server.AwaitShutdown(); + websocket_server.AwaitShutdown(); + query::procedure::gModuleRegistry.UnloadAllModules(); Py_END_ALLOW_THREADS;