// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // License, and you may not use this file except in compliance with the Business Source License. // // As of the Change Date specified in that file, in accordance with // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "utils/logging.hpp" namespace beast = boost::beast; namespace http = beast::http; namespace websocket = beast::websocket; namespace net = boost::asio; using tcp = boost::asio::ip::tcp; namespace ssl = boost::asio::ssl; constexpr std::array kSupportedLogLevels{"debug", "trace", "info", "warning", "error", "critical"}; struct Credentials { std::string_view username; std::string_view passsword; }; inline void Fail(beast::error_code ec, char const *what) { std::cerr << what << ": " << ec.message() << "\n"; } inline std::string GetAuthenticationJSON(const Credentials &creds) { nlohmann::json json_creds; json_creds["username"] = creds.username; json_creds["password"] = creds.passsword; return json_creds.dump(); } template class Session : public std::enable_shared_from_this> { using std::enable_shared_from_this>::shared_from_this; public: explicit Session(net::io_context &ioc, ssl::context &ctx, std::vector &expected_messages) requires(ssl) : resolver_(net::make_strand(ioc)), ws_(net::make_strand(ioc), ctx), received_messages_{expected_messages} {} explicit Session(net::io_context &ioc, std::vector &expected_messages) requires(!ssl) : resolver_(net::make_strand(ioc)), ws_(net::make_strand(ioc)), received_messages_{expected_messages} {} template explicit Session(Credentials creds, Args &&...args) : Session(std::forward(args)...) { creds_.emplace(creds); } void Run(std::string_view host, std::string_view port) { host_ = host; resolver_.async_resolve(host, port, beast::bind_front_handler(&Session::OnResolve, shared_from_this())); } void OnResolve(beast::error_code ec, tcp::resolver::results_type results) { if (ec) { return Fail(ec, "resolve"); } beast::get_lowest_layer(ws_).expires_after(std::chrono::seconds(30)); beast::get_lowest_layer(ws_).async_connect(results, beast::bind_front_handler(&Session::OnConnect, shared_from_this())); } void OnConnect(beast::error_code ec, tcp::resolver::results_type::endpoint_type ep) { if (ec) { return Fail(ec, "connect"); } host_ = fmt::format("{}:{}", host_, ep.port()); beast::get_lowest_layer(ws_).expires_after(std::chrono::seconds(30)); if constexpr (ssl) { if (!SSL_set_tlsext_host_name(ws_.next_layer().native_handle(), host_.c_str())) { ec = beast::error_code(static_cast(::ERR_get_error()), net::error::get_ssl_category()); return Fail(ec, "connect"); } ws_.next_layer().async_handshake(ssl::stream_base::client, beast::bind_front_handler(&Session::OnSSLHandshake, shared_from_this())); } else { ws_.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); ws_.set_option(websocket::stream_base::decorator([](websocket::request_type &req) { req.set(http::field::user_agent, std::string(BOOST_BEAST_VERSION_STRING) + " websocket-client-async"); })); host_ = fmt::format("{}:{}", host_, ep.port()); ws_.async_handshake(host_, "/", beast::bind_front_handler(&Session::OnHandshake, shared_from_this())); } } void OnHandshake(beast::error_code ec) { if (ec) { return Fail(ec, "handshake"); } if (creds_) { ws_.async_write(net::buffer(GetAuthenticationJSON(*creds_)), beast::bind_front_handler(&Session::OnWrite, shared_from_this())); } else { ws_.async_read(buffer_, beast::bind_front_handler(&Session::OnRead, shared_from_this())); } } void OnSSLHandshake(beast::error_code ec) { if (ec) { return Fail(ec, "ssl_handshake"); } beast::get_lowest_layer(ws_).expires_never(); ws_.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); ws_.set_option(websocket::stream_base::decorator([](websocket::request_type &req) { req.set(http::field::user_agent, std::string(BOOST_BEAST_VERSION_STRING) + " websocket-client-async-ssl"); })); ws_.async_handshake(host_, "/", beast::bind_front_handler(&Session::OnHandshake, shared_from_this())); } void OnWrite(beast::error_code ec, std::size_t bytes_transferred) { boost::ignore_unused(bytes_transferred); if (ec) { return Fail(ec, "write"); } ws_.async_read(buffer_, beast::bind_front_handler(&Session::OnRead, shared_from_this())); } void OnRead(beast::error_code ec, std::size_t bytes_transferred) { boost::ignore_unused(bytes_transferred); if (ec) { return Fail(ec, "read"); } received_messages_.push_back(boost::beast::buffers_to_string(buffer_.data())); buffer_.clear(); ws_.async_read(buffer_, beast::bind_front_handler(&Session::OnRead, shared_from_this())); } void OnClose(beast::error_code ec) { if (ec) { return Fail(ec, "close"); } } private: using InternalStream = std::conditional_t, beast::tcp_stream>; tcp::resolver resolver_; websocket::stream ws_; beast::flat_buffer buffer_; std::string host_; std::vector &received_messages_; std::optional creds_{std::nullopt}; }; std::unique_ptr GetBoltClient(const uint16_t bolt_port, const bool use_ssl) { auto client = mg::Client::Connect({.host = "127.0.0.1", .port = bolt_port, .use_ssl = use_ssl}); MG_ASSERT(client, "Failed to connect!"); return client; } inline void CleanDatabase(std::unique_ptr &client) { MG_ASSERT(client->Execute("MATCH (n) DETACH DELETE n;")); client->DiscardAll(); } inline void AddUser(std::unique_ptr &client) { MG_ASSERT(client->Execute("CREATE USER test IDENTIFIED BY 'testing';")); client->DiscardAll(); } inline void AddVertex(std::unique_ptr &client) { MG_ASSERT(client->Execute("CREATE ();")); client->DiscardAll(); } inline void AddConnectedVertices(std::unique_ptr &client) { MG_ASSERT(client->Execute("CREATE ()-[:TO]->();")); client->DiscardAll(); } inline void RunQueries(std::unique_ptr &mg_client) { CleanDatabase(mg_client); AddVertex(mg_client); AddVertex(mg_client); AddVertex(mg_client); AddConnectedVertices(mg_client); CleanDatabase(mg_client); } inline void AssertAuthMessage(auto &json_message, const bool success = true) { MG_ASSERT(json_message.at("message").is_string(), "Event is not a string!"); MG_ASSERT(json_message.at("success").is_boolean(), "Success is not a boolean!"); MG_ASSERT(json_message.at("success").template get() == success, "Success does not match expected!"); } inline void AssertLogMessage(const std::string &log_message) { const auto json_message = nlohmann::json::parse(log_message); if (json_message.contains("success")) { spdlog::info("Received auth message: {}", json_message.dump()); AssertAuthMessage(json_message); return; } MG_ASSERT(json_message.at("event").is_string(), "Event is not a string!"); MG_ASSERT(json_message.at("event").get() == "log", "Event is not equal to `log`!"); MG_ASSERT(json_message.at("level").is_string(), "Level is not a string!"); MG_ASSERT(std::ranges::count(kSupportedLogLevels, json_message.at("level")) == 1); MG_ASSERT(json_message.at("message").is_string(), "Message is not a string!"); } template void TestWebsocketWithoutAnyUsers(std::unique_ptr &mg_client) { spdlog::info("Starting websocket connection without any users."); auto websocket_client = TWebsocketClient(); websocket_client.Connect("127.0.0.1", "7444"); RunQueries(mg_client); std::this_thread::sleep_for(std::chrono::seconds(1)); websocket_client.Close(); websocket_client.AwaitClose(); const auto received_messages = websocket_client.GetReceivedMessages(); spdlog::info("Received {} messages.", received_messages.size()); MG_ASSERT(!received_messages.empty(), "There are no received messages!"); std::ranges::for_each(received_messages, AssertLogMessage); spdlog::info("Finishing websocket connection without any users."); } template void TestWebsocketWithAuthentication(std::unique_ptr &mg_client) { spdlog::info("Starting websocket connection with users."); AddUser(mg_client); std::this_thread::sleep_for(std::chrono::seconds(1)); auto websocket_client = TWebsocketClient({"test", "testing"}); websocket_client.Connect("127.0.0.1", "7444"); RunQueries(mg_client); std::this_thread::sleep_for(std::chrono::seconds(1)); websocket_client.Close(); websocket_client.AwaitClose(); const auto received_messages = websocket_client.GetReceivedMessages(); spdlog::info("Received {} messages.", received_messages.size()); MG_ASSERT(!received_messages.empty(), "There are no received messages!"); std::ranges::for_each(received_messages, AssertLogMessage); spdlog::info("Finishing websocket connection with users."); } template void TestWebsocketWithoutBeingAuthorized(std::unique_ptr &mg_client) { spdlog::info("Starting websocket connection with users but without being authenticated."); std::this_thread::sleep_for(std::chrono::seconds(1)); auto websocket_client = TWebsocketClient({"wrong", "credentials"}); websocket_client.Connect("127.0.0.1", "7444"); RunQueries(mg_client); std::this_thread::sleep_for(std::chrono::seconds(1)); websocket_client.Close(); websocket_client.AwaitClose(); const auto received_messages = websocket_client.GetReceivedMessages(); spdlog::info("Received {} messages.", received_messages.size()); MG_ASSERT(received_messages.size() == 1, "There must be only one message received!"); if (!received_messages.empty()) { auto json_message = nlohmann::json::parse(received_messages[0]); AssertAuthMessage(json_message, false); } spdlog::info("Finishing websocket connection with users but without being authenticated."); } template void RunTestCases(std::unique_ptr &mg_client) { TestWebsocketWithoutAnyUsers(mg_client); TestWebsocketWithAuthentication(mg_client); TestWebsocketWithoutBeingAuthorized(mg_client); }