Websocket unit tests (#334)
This commit is contained in:
parent
06e6ead4d2
commit
bd2c30fddc
@ -19,15 +19,24 @@
|
||||
|
||||
namespace communication::websocket {
|
||||
|
||||
class SafeAuth {
|
||||
class AuthenticationInterface {
|
||||
public:
|
||||
virtual bool Authenticate(const std::string &username, const std::string &password) const = 0;
|
||||
|
||||
virtual bool HasUserPermission(const std::string &username, auth::Permission permission) const = 0;
|
||||
|
||||
virtual bool HasAnyUsers() const = 0;
|
||||
};
|
||||
|
||||
class SafeAuth : public AuthenticationInterface {
|
||||
public:
|
||||
explicit SafeAuth(utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth) : auth_{auth} {}
|
||||
|
||||
bool Authenticate(const std::string &username, const std::string &password) const;
|
||||
bool Authenticate(const std::string &username, const std::string &password) const override;
|
||||
|
||||
bool HasUserPermission(const std::string &username, auth::Permission permission) const;
|
||||
bool HasUserPermission(const std::string &username, auth::Permission permission) const override;
|
||||
|
||||
bool HasAnyUsers() const;
|
||||
bool HasAnyUsers() const override;
|
||||
|
||||
private:
|
||||
utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_;
|
||||
|
@ -27,7 +27,10 @@ void Listener::WriteToAll(std::shared_ptr<std::string> message) {
|
||||
}
|
||||
}
|
||||
|
||||
Listener::Listener(boost::asio::io_context &ioc, ServerContext *context, tcp::endpoint endpoint, SafeAuth auth)
|
||||
boost::asio::ip::tcp::endpoint Listener::GetEndpoint() const { return acceptor_.local_endpoint(); };
|
||||
|
||||
Listener::Listener(boost::asio::io_context &ioc, ServerContext *context, tcp::endpoint endpoint,
|
||||
AuthenticationInterface &auth)
|
||||
: ioc_(ioc), context_(context), acceptor_(ioc), auth_(auth) {
|
||||
boost::beast::error_code ec;
|
||||
|
||||
|
@ -40,9 +40,10 @@ class Listener : public std::enable_shared_from_this<Listener> {
|
||||
// Start accepting incoming connections
|
||||
void Run();
|
||||
void WriteToAll(std::shared_ptr<std::string> message);
|
||||
tcp::endpoint GetEndpoint() const;
|
||||
|
||||
private:
|
||||
Listener(boost::asio::io_context &ioc, ServerContext *context, tcp::endpoint endpoint, SafeAuth auth);
|
||||
Listener(boost::asio::io_context &ioc, ServerContext *context, tcp::endpoint endpoint, AuthenticationInterface &auth);
|
||||
|
||||
void DoAccept();
|
||||
void OnAccept(boost::beast::error_code ec, tcp::socket socket);
|
||||
@ -51,6 +52,6 @@ class Listener : public std::enable_shared_from_this<Listener> {
|
||||
ServerContext *context_;
|
||||
tcp::acceptor acceptor_;
|
||||
utils::Synchronized<std::list<std::shared_ptr<Session>>, utils::SpinLock> sessions_;
|
||||
SafeAuth auth_;
|
||||
AuthenticationInterface &auth_;
|
||||
};
|
||||
} // namespace communication::websocket
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2021 Memgraph Ltd.
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -36,6 +36,8 @@ void Server::AwaitShutdown() {
|
||||
|
||||
bool Server::IsRunning() const { return background_thread_ && !ioc_.stopped(); }
|
||||
|
||||
boost::asio::ip::tcp::endpoint Server::GetEndpoint() const { return listener_->GetEndpoint(); };
|
||||
|
||||
namespace {
|
||||
class QuoteEscapeFormatter : public spdlog::custom_flag_formatter {
|
||||
public:
|
||||
|
@ -28,7 +28,7 @@ class Server final {
|
||||
using tcp = boost::asio::ip::tcp;
|
||||
|
||||
public:
|
||||
explicit Server(io::network::Endpoint endpoint, ServerContext *context, SafeAuth auth)
|
||||
explicit Server(io::network::Endpoint endpoint, ServerContext *context, AuthenticationInterface &auth)
|
||||
: listener_{Listener::Create(
|
||||
ioc_, context, tcp::endpoint{boost::asio::ip::make_address(endpoint.address), endpoint.port}, auth)} {}
|
||||
|
||||
@ -43,6 +43,7 @@ class Server final {
|
||||
void Shutdown();
|
||||
void AwaitShutdown();
|
||||
bool IsRunning() const;
|
||||
tcp::endpoint GetEndpoint() const;
|
||||
|
||||
class LoggingSink : public spdlog::sinks::base_sink<std::mutex> {
|
||||
public:
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include <json/json.hpp>
|
||||
|
||||
#include "communication/context.hpp"
|
||||
#include "communication/websocket/auth.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
|
||||
namespace communication::websocket {
|
||||
@ -42,7 +43,7 @@ std::variant<Session::PlainWebSocket, Session::SSLWebSocket> Session::CreateWebS
|
||||
return Session::PlainWebSocket{std::move(socket)};
|
||||
}
|
||||
|
||||
Session::Session(tcp::socket &&socket, ServerContext &context, SafeAuth auth)
|
||||
Session::Session(tcp::socket &&socket, ServerContext &context, AuthenticationInterface &auth)
|
||||
: ws_(CreateWebSocket(std::move(socket), context)), strand_{boost::asio::make_strand(GetExecutor())}, auth_{auth} {}
|
||||
|
||||
bool Session::Run() {
|
||||
|
@ -50,7 +50,7 @@ class Session : public std::enable_shared_from_this<Session> {
|
||||
using PlainWebSocket = boost::beast::websocket::stream<boost::beast::tcp_stream>;
|
||||
using SSLWebSocket = boost::beast::websocket::stream<boost::beast::ssl_stream<boost::beast::tcp_stream>>;
|
||||
|
||||
explicit Session(tcp::socket &&socket, ServerContext &context, SafeAuth auth);
|
||||
explicit Session(tcp::socket &&socket, ServerContext &context, AuthenticationInterface &auth);
|
||||
|
||||
void DoWrite();
|
||||
void OnWrite(boost::beast::error_code ec, size_t bytest_transferred);
|
||||
@ -61,10 +61,10 @@ class Session : public std::enable_shared_from_this<Session> {
|
||||
void DoClose();
|
||||
void OnClose(boost::beast::error_code ec);
|
||||
|
||||
utils::BasicResult<std::string> Authorize(const nlohmann::json &creds);
|
||||
|
||||
bool IsAuthenticated() const;
|
||||
|
||||
utils::BasicResult<std::string> Authorize(const nlohmann::json &creds);
|
||||
|
||||
void DoShutdown();
|
||||
|
||||
auto GetExecutor() {
|
||||
@ -86,6 +86,6 @@ class Session : public std::enable_shared_from_this<Session> {
|
||||
std::atomic<bool> connected_{false};
|
||||
bool authenticated_{false};
|
||||
bool close_{false};
|
||||
SafeAuth auth_;
|
||||
AuthenticationInterface &auth_;
|
||||
};
|
||||
} // namespace communication::websocket
|
||||
|
@ -1215,8 +1215,8 @@ int main(int argc, char **argv) {
|
||||
[]() -> nlohmann::json { return query::plan::CallProcedure::GetAndResetCounters(); });
|
||||
}
|
||||
|
||||
communication::websocket::Server websocket_server{
|
||||
{"0.0.0.0", 7444}, &context, communication::websocket::SafeAuth{&auth}};
|
||||
communication::websocket::SafeAuth websocket_auth{&auth};
|
||||
communication::websocket::Server websocket_server{{"0.0.0.0", 7444}, &context, websocket_auth};
|
||||
AddLoggerSink(websocket_server.GetLoggingSink());
|
||||
|
||||
// Handler for regular termination signals
|
||||
|
@ -368,3 +368,8 @@ add_custom_target(test_lcp ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/test_lcp)
|
||||
add_test(test_lcp ${CMAKE_CURRENT_BINARY_DIR}/test_lcp)
|
||||
add_dependencies(memgraph__unit test_lcp)
|
||||
|
||||
# Test websocket
|
||||
find_package(Boost REQUIRED)
|
||||
|
||||
add_unit_test(websocket.cpp)
|
||||
target_link_libraries(${test_prefix}websocket mg-communication Boost::headers)
|
||||
|
305
tests/unit/websocket.cpp
Normal file
305
tests/unit/websocket.cpp
Normal file
@ -0,0 +1,305 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt."""
|
||||
|
||||
#define BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <thread>
|
||||
|
||||
#include <fmt/core.h>
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <boost/asio/connect.hpp>
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <boost/beast/core.hpp>
|
||||
#include <boost/beast/core/buffers_to_string.hpp>
|
||||
#include <boost/beast/websocket.hpp>
|
||||
|
||||
#include "communication/websocket/auth.hpp"
|
||||
#include "communication/websocket/server.hpp"
|
||||
|
||||
namespace beast = boost::beast;
|
||||
namespace http = beast::http;
|
||||
namespace websocket = beast::websocket;
|
||||
namespace net = boost::asio;
|
||||
using tcp = boost::asio::ip::tcp;
|
||||
|
||||
constexpr auto kResponseSuccess{"success"};
|
||||
constexpr auto kResponseMessage{"message"};
|
||||
|
||||
struct MockAuth : public communication::websocket::AuthenticationInterface {
|
||||
MockAuth() = default;
|
||||
|
||||
bool Authenticate(const std::string & /*username*/, const std::string & /*password*/) const override {
|
||||
return authentication;
|
||||
}
|
||||
|
||||
bool HasUserPermission(const std::string & /*username*/, auth::Permission /*permission*/) const override {
|
||||
return authorization;
|
||||
}
|
||||
|
||||
bool HasAnyUsers() const override { return has_any_users; }
|
||||
|
||||
bool authentication{true};
|
||||
bool authorization{true};
|
||||
bool has_any_users{true};
|
||||
};
|
||||
|
||||
class WebSocketServerTest : public ::testing::Test {
|
||||
public:
|
||||
protected:
|
||||
WebSocketServerTest() : websocket_server{{"0.0.0.0", 0}, &context, auth} {
|
||||
EXPECT_NO_THROW(websocket_server.Start());
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
EXPECT_NO_THROW(websocket_server.Shutdown());
|
||||
EXPECT_NO_THROW(websocket_server.AwaitShutdown());
|
||||
}
|
||||
|
||||
std::string ServerPort() const { return std::to_string(websocket_server.GetEndpoint().port()); }
|
||||
|
||||
std::string ServerAddress() const { return websocket_server.GetEndpoint().address().to_string(); }
|
||||
|
||||
MockAuth auth;
|
||||
communication::ServerContext context{};
|
||||
communication::websocket::Server websocket_server;
|
||||
};
|
||||
|
||||
class Client {
|
||||
public:
|
||||
~Client() { ws_.close(websocket::close_code::normal); }
|
||||
|
||||
void Connect(const std::string &host, const std::string &port) {
|
||||
tcp::resolver resolver{ioc_};
|
||||
auto endpoint_ = resolver.resolve(host, port);
|
||||
auto ep = net::connect(ws_.next_layer(), endpoint_);
|
||||
const auto server = fmt::format("{}:{}", host, ep.port());
|
||||
ws_.set_option(websocket::stream_base::decorator([](websocket::request_type &req) {
|
||||
req.set(http::field::user_agent, std::string(BOOST_BEAST_VERSION_STRING) + " websocket-client-coro");
|
||||
}));
|
||||
|
||||
// Perform the websocket handshake
|
||||
ws_.handshake(host, "/");
|
||||
}
|
||||
|
||||
void Write(const std::string &msg) { ws_.write(net::buffer(msg)); }
|
||||
|
||||
std::string Read() {
|
||||
ws_.read(buffer_);
|
||||
const std::string response = beast::buffers_to_string(buffer_.data());
|
||||
buffer_.consume(buffer_.size());
|
||||
return response;
|
||||
}
|
||||
|
||||
private:
|
||||
net::io_context ioc_{};
|
||||
websocket::stream<tcp::socket> ws_{ioc_};
|
||||
beast::flat_buffer buffer_;
|
||||
};
|
||||
|
||||
TEST(WebSocketServer, WebsocketWorkflow) {
|
||||
/**
|
||||
* Notice how there is no port management for the clients
|
||||
* and the servers, that is because when using "0.0.0.0" as address and
|
||||
* and 0 as port number we delegate port assignment to the OS
|
||||
* and it is the keeper of all available port numbers and
|
||||
* assigns them automatically.
|
||||
*/
|
||||
MockAuth auth{};
|
||||
communication::ServerContext context{};
|
||||
communication::websocket::Server websocket_server({"0.0.0.0", 0}, &context, auth);
|
||||
const auto port = websocket_server.GetEndpoint().port();
|
||||
|
||||
SCOPED_TRACE(fmt::format("Checking port number different then 0: {}", port));
|
||||
EXPECT_NE(port, 0);
|
||||
EXPECT_NO_THROW(websocket_server.Start());
|
||||
EXPECT_TRUE(websocket_server.IsRunning());
|
||||
|
||||
EXPECT_NO_THROW(websocket_server.Shutdown());
|
||||
EXPECT_FALSE(websocket_server.IsRunning());
|
||||
|
||||
EXPECT_NO_THROW(websocket_server.AwaitShutdown());
|
||||
EXPECT_FALSE(websocket_server.IsRunning());
|
||||
}
|
||||
|
||||
TEST_F(WebSocketServerTest, WebsocketConnection) {
|
||||
{
|
||||
auto client = Client{};
|
||||
EXPECT_NO_THROW(client.Connect("0.0.0.0", ServerPort()));
|
||||
}
|
||||
|
||||
websocket_server.Shutdown();
|
||||
websocket_server.AwaitShutdown();
|
||||
}
|
||||
|
||||
TEST_F(WebSocketServerTest, WebsocketLogging) {
|
||||
auth.has_any_users = false;
|
||||
// Set up the websocket logger as one of the defaults for spdlog
|
||||
{
|
||||
auto default_logger = spdlog::default_logger();
|
||||
auto sinks = default_logger->sinks();
|
||||
sinks.push_back(websocket_server.GetLoggingSink());
|
||||
|
||||
auto logger = std::make_shared<spdlog::logger>("memgraph_log", sinks.begin(), sinks.end());
|
||||
logger->set_level(default_logger->level());
|
||||
logger->flush_on(spdlog::level::trace);
|
||||
spdlog::set_default_logger(std::move(logger));
|
||||
}
|
||||
{
|
||||
auto client = Client();
|
||||
client.Connect(ServerAddress(), ServerPort());
|
||||
|
||||
auto log_message = [](spdlog::level::level_enum level, std::string_view message) {
|
||||
spdlog::log(level, message);
|
||||
spdlog::default_logger()->flush();
|
||||
};
|
||||
auto log_and_check = [log_message, &client](spdlog::level::level_enum level, std::string_view message,
|
||||
std::string_view log_level_received) {
|
||||
std::thread(log_message, level, message).detach();
|
||||
const auto received_message = client.Read();
|
||||
EXPECT_EQ(received_message, fmt::format("{{\"event\": \"log\", \"level\": \"{}\", \"message\": \"{}\"}}\n",
|
||||
log_level_received, message));
|
||||
};
|
||||
|
||||
log_and_check(spdlog::level::err, "Sending error message!", "error");
|
||||
log_and_check(spdlog::level::warn, "Sending warn message!", "warning");
|
||||
log_and_check(spdlog::level::info, "Sending info message!", "info");
|
||||
log_and_check(spdlog::level::trace, "Sending trace message!", "trace");
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(WebSocketServerTest, WebsocketAuthenticationParsingError) {
|
||||
constexpr auto auth_fail = "Cannot parse JSON for WebSocket authentication";
|
||||
|
||||
{
|
||||
SCOPED_TRACE("Checking handling of first request parsing error.");
|
||||
auto client = Client();
|
||||
EXPECT_NO_THROW(client.Connect(ServerAddress(), ServerPort()));
|
||||
EXPECT_NO_THROW(client.Write("Test"));
|
||||
const auto response = nlohmann::json::parse(client.Read());
|
||||
const auto message_header = response[kResponseMessage].get<std::string>();
|
||||
const auto message_first_part = message_header.substr(0, message_header.find(": "));
|
||||
|
||||
EXPECT_FALSE(response[kResponseSuccess]);
|
||||
EXPECT_EQ(message_first_part, auth_fail);
|
||||
}
|
||||
{
|
||||
SCOPED_TRACE("Checking handling of JSON parsing error.");
|
||||
auto client = Client();
|
||||
EXPECT_NO_THROW(client.Connect(ServerAddress(), ServerPort()));
|
||||
const std::string json_without_comma = R"({"username": "user" "password": "123"})";
|
||||
EXPECT_NO_THROW(client.Write(json_without_comma));
|
||||
const auto response = nlohmann::json::parse(client.Read());
|
||||
const auto message_header = response[kResponseMessage].get<std::string>();
|
||||
const auto message_first_part = message_header.substr(0, message_header.find(": "));
|
||||
|
||||
EXPECT_FALSE(response[kResponseSuccess]);
|
||||
EXPECT_EQ(message_first_part, auth_fail);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(WebSocketServerTest, WebsocketAuthenticationWhenAuthPasses) {
|
||||
constexpr auto auth_success = R"({"message":"User has been successfully authenticated!","success":true})";
|
||||
|
||||
{
|
||||
SCOPED_TRACE("Checking successful authentication response.");
|
||||
auto client = Client();
|
||||
EXPECT_NO_THROW(client.Connect(ServerAddress(), ServerPort()));
|
||||
EXPECT_NO_THROW(client.Write(R"({"username": "user", "password": "123"})"));
|
||||
const auto response = client.Read();
|
||||
|
||||
EXPECT_EQ(response, auth_success);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(WebSocketServerTest, WebsocketAuthenticationWithMultipleAttempts) {
|
||||
constexpr auto auth_success = R"({"message":"User has been successfully authenticated!","success":true})";
|
||||
constexpr auto auth_fail = "Cannot parse JSON for WebSocket authentication";
|
||||
|
||||
{
|
||||
SCOPED_TRACE("Checking multiple authentication tries from same client");
|
||||
auto client = Client();
|
||||
EXPECT_NO_THROW(client.Connect(ServerAddress(), ServerPort()));
|
||||
EXPECT_NO_THROW(client.Write(R"({"username": "user" "password": "123"})"));
|
||||
|
||||
{
|
||||
const auto response = nlohmann::json::parse(client.Read());
|
||||
const auto message_header = response[kResponseMessage].get<std::string>();
|
||||
const auto message_first_part = message_header.substr(0, message_header.find(": "));
|
||||
|
||||
EXPECT_FALSE(response[kResponseSuccess]);
|
||||
EXPECT_EQ(message_first_part, auth_fail);
|
||||
}
|
||||
{
|
||||
EXPECT_NO_THROW(client.Connect(ServerAddress(), ServerPort()));
|
||||
EXPECT_NO_THROW(client.Write(R"({"username": "user", "password": "123"})"));
|
||||
const auto response = client.Read();
|
||||
EXPECT_EQ(response, auth_success);
|
||||
}
|
||||
}
|
||||
{
|
||||
SCOPED_TRACE("Checking multiple authentication tries from different clients");
|
||||
auto client1 = Client();
|
||||
auto client2 = Client();
|
||||
|
||||
EXPECT_NO_THROW(client1.Connect(ServerAddress(), ServerPort()));
|
||||
EXPECT_NO_THROW(client2.Connect(ServerAddress(), ServerPort()));
|
||||
|
||||
EXPECT_NO_THROW(client1.Write(R"({"username": "user" "password": "123"})"));
|
||||
EXPECT_NO_THROW(client2.Write(R"({"username": "user", "password": "123"})"));
|
||||
|
||||
{
|
||||
const auto response = nlohmann::json::parse(client1.Read());
|
||||
const auto message_header = response[kResponseMessage].get<std::string>();
|
||||
const auto message_first_part = message_header.substr(0, message_header.find(": "));
|
||||
|
||||
EXPECT_FALSE(response[kResponseSuccess]);
|
||||
EXPECT_EQ(message_first_part, auth_fail);
|
||||
}
|
||||
{
|
||||
const auto response = client2.Read();
|
||||
EXPECT_EQ(response, auth_success);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(WebSocketServerTest, WebsocketAuthenticationFails) {
|
||||
auth.authentication = false;
|
||||
|
||||
constexpr auto auth_fail = R"({"message":"Authentication failed!","success":false})";
|
||||
{
|
||||
auto client = Client();
|
||||
EXPECT_NO_THROW(client.Connect(ServerAddress(), ServerPort()));
|
||||
EXPECT_NO_THROW(client.Write(R"({"username": "user", "password": "123"})"));
|
||||
|
||||
const auto response = client.Read();
|
||||
EXPECT_EQ(response, auth_fail);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
TEST_F(WebSocketServerTest, WebsocketAuthorizationFails) {
|
||||
auth.authorization = false;
|
||||
constexpr auto auth_fail = R"({"message":"Authorization failed!","success":false})";
|
||||
|
||||
{
|
||||
auto client = Client();
|
||||
EXPECT_NO_THROW(client.Connect(ServerAddress(), ServerPort()));
|
||||
EXPECT_NO_THROW(client.Write(R"({"username": "user", "password": "123"})"));
|
||||
|
||||
const auto response = client.Read();
|
||||
EXPECT_EQ(response, auth_fail);
|
||||
}
|
||||
}
|
||||
#endif
|
Loading…
Reference in New Issue
Block a user