diff --git a/src/auth/CMakeLists.txt b/src/auth/CMakeLists.txt index f70d60d36..5e38fe927 100644 --- a/src/auth/CMakeLists.txt +++ b/src/auth/CMakeLists.txt @@ -11,7 +11,7 @@ find_package(gflags REQUIRED) add_library(mg-auth STATIC ${auth_src_files}) target_link_libraries(mg-auth json libbcrypt gflags fmt::fmt) -target_link_libraries(mg-auth mg-utils mg-kvstore) +target_link_libraries(mg-auth mg-utils mg-kvstore mg-license ) target_link_libraries(mg-auth ${Seccomp_LIBRARIES}) target_include_directories(mg-auth SYSTEM PRIVATE ${Seccomp_INCLUDE_DIRS}) diff --git a/src/auth/models.cpp b/src/auth/models.cpp index 5eaa042a6..4e706c8b6 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -30,6 +30,14 @@ DEFINE_string(auth_password_strength_regex, default_password_regex.data(), namespace auth { +// Constant list of all available permissions. +constexpr std::array kPermissionsAll = { + Permission::MATCH, Permission::CREATE, Permission::MERGE, Permission::DELETE, + Permission::SET, Permission::REMOVE, Permission::INDEX, Permission::STATS, + Permission::CONSTRAINT, Permission::DUMP, Permission::AUTH, Permission::REPLICATION, + Permission::DURABILITY, Permission::READ_FILE, Permission::FREE_MEMORY, Permission::TRIGGER, + Permission::CONFIG, Permission::STREAM, Permission::WEBSOCKET}; + std::string PermissionToString(Permission permission) { switch (permission) { case Permission::MATCH: @@ -72,6 +80,8 @@ std::string PermissionToString(Permission permission) { return "MODULE_READ"; case Permission::MODULE_WRITE: return "MODULE_WRITE"; + case Permission::WEBSOCKET: + return "WEBSOCKET"; } } diff --git a/src/auth/models.hpp b/src/auth/models.hpp index e4affc236..d23dceb2c 100644 --- a/src/auth/models.hpp +++ b/src/auth/models.hpp @@ -38,7 +38,8 @@ enum class Permission : uint64_t { AUTH = 1U << 16U, STREAM = 1U << 17U, MODULE_READ = 1U << 18U, - MODULE_WRITE = 1U << 19U + MODULE_WRITE = 1U << 19U, + WEBSOCKET = 1U << 20U }; // clang-format on @@ -48,7 +49,8 @@ const std::vector kPermissionsAll = { Permission::SET, Permission::REMOVE, Permission::INDEX, Permission::STATS, Permission::CONSTRAINT, Permission::DUMP, Permission::AUTH, Permission::REPLICATION, Permission::DURABILITY, Permission::READ_FILE, Permission::FREE_MEMORY, Permission::TRIGGER, - Permission::CONFIG, Permission::STREAM, Permission::MODULE_READ, Permission::MODULE_WRITE}; + Permission::CONFIG, Permission::STREAM, Permission::MODULE_READ, Permission::MODULE_WRITE, + Permission::WEBSOCKET}; // Function that converts a permission to its string representation. std::string PermissionToString(Permission permission); diff --git a/src/communication/CMakeLists.txt b/src/communication/CMakeLists.txt index b8e64e015..ad316aad0 100644 --- a/src/communication/CMakeLists.txt +++ b/src/communication/CMakeLists.txt @@ -2,6 +2,7 @@ find_package(fmt REQUIRED) find_package(gflags REQUIRED) set(communication_src_files + websocket/auth.cpp websocket/server.cpp websocket/listener.cpp websocket/session.cpp @@ -15,7 +16,7 @@ set(communication_src_files find_package(Boost REQUIRED) add_library(mg-communication STATIC ${communication_src_files}) -target_link_libraries(mg-communication Boost::headers Threads::Threads mg-utils mg-io fmt::fmt gflags) +target_link_libraries(mg-communication Boost::headers Threads::Threads mg-utils mg-io mg-auth fmt::fmt gflags) find_package(OpenSSL REQUIRED) target_link_libraries(mg-communication ${OPENSSL_LIBRARIES}) diff --git a/src/communication/websocket/auth.cpp b/src/communication/websocket/auth.cpp new file mode 100644 index 000000000..cee3cad2d --- /dev/null +++ b/src/communication/websocket/auth.cpp @@ -0,0 +1,30 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "communication/websocket/auth.hpp" + +#include + +namespace communication::websocket { + +bool SafeAuth::Authenticate(const std::string &username, const std::string &password) const { + return auth_->Lock()->Authenticate(username, password).has_value(); +} + +bool SafeAuth::HasUserPermission(const std::string &username, const auth::Permission permission) const { + if (const auto user = auth_->Lock()->GetUser(username); user) { + return user->GetPermissions().Has(permission) == auth::PermissionLevel::GRANT; + } + return false; +} + +bool SafeAuth::HasAnyUsers() const { return auth_->ReadLock()->HasUsers(); } +} // namespace communication::websocket \ No newline at end of file diff --git a/src/communication/websocket/auth.hpp b/src/communication/websocket/auth.hpp new file mode 100644 index 000000000..8c1dade44 --- /dev/null +++ b/src/communication/websocket/auth.hpp @@ -0,0 +1,35 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include + +#include "auth/auth.hpp" +#include "utils/spin_lock.hpp" +#include "utils/synchronized.hpp" + +namespace communication::websocket { + +class SafeAuth { + public: + explicit SafeAuth(utils::Synchronized *auth) : auth_{auth} {} + + bool Authenticate(const std::string &username, const std::string &password) const; + + bool HasUserPermission(const std::string &username, auth::Permission permission) const; + + bool HasAnyUsers() const; + + private: + utils::Synchronized *auth_; +}; +} // namespace communication::websocket \ No newline at end of file diff --git a/src/communication/websocket/listener.cpp b/src/communication/websocket/listener.cpp index d4888c849..388621b19 100644 --- a/src/communication/websocket/listener.cpp +++ b/src/communication/websocket/listener.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -17,6 +17,7 @@ void LogError(boost::beast::error_code ec, const std::string_view what) { spdlog::warn("Websocket listener failed on {}: {}", what, ec.message()); } } // namespace + void Listener::Run() { DoAccept(); } void Listener::WriteToAll(std::shared_ptr message) { @@ -26,7 +27,8 @@ void Listener::WriteToAll(std::shared_ptr message) { } } -Listener::Listener(boost::asio::io_context &ioc, tcp::endpoint endpoint) : ioc_(ioc), acceptor_(ioc) { +Listener::Listener(boost::asio::io_context &ioc, tcp::endpoint endpoint, SafeAuth auth) + : ioc_(ioc), acceptor_(ioc), auth_(auth) { boost::beast::error_code ec; // Open the acceptor @@ -71,7 +73,7 @@ void Listener::OnAccept(boost::beast::error_code ec, tcp::socket socket) { { auto sessions_ptr = sessions_.Lock(); - sessions_ptr->emplace_back(Session::Create(std::move(socket)))->Run(); + sessions_ptr->emplace_back(Session::Create(std::move(socket), auth_))->Run(); // Clean disconnected clients std::erase_if(*sessions_ptr, [](const auto &elem) { return !elem->IsConnected(); }); diff --git a/src/communication/websocket/listener.hpp b/src/communication/websocket/listener.hpp index 92e80bd8d..e4b15bddf 100644 --- a/src/communication/websocket/listener.hpp +++ b/src/communication/websocket/listener.hpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -41,7 +41,7 @@ class Listener : public std::enable_shared_from_this { void WriteToAll(std::shared_ptr message); private: - Listener(boost::asio::io_context &ioc, tcp::endpoint endpoint); + Listener(boost::asio::io_context &ioc, tcp::endpoint endpoint, SafeAuth auth); void DoAccept(); void OnAccept(boost::beast::error_code ec, tcp::socket socket); @@ -49,5 +49,6 @@ class Listener : public std::enable_shared_from_this { boost::asio::io_context &ioc_; tcp::acceptor acceptor_; utils::Synchronized>, utils::SpinLock> sessions_; + SafeAuth auth_; }; } // namespace communication::websocket diff --git a/src/communication/websocket/server.hpp b/src/communication/websocket/server.hpp index bc5a2ad04..81713cbe3 100644 --- a/src/communication/websocket/server.hpp +++ b/src/communication/websocket/server.hpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -28,10 +28,10 @@ class Server final { using tcp = boost::asio::ip::tcp; public: - explicit Server(io::network::Endpoint endpoint) + explicit Server(io::network::Endpoint endpoint, SafeAuth auth) : ioc_{}, - listener_{ - Listener::Create(ioc_, tcp::endpoint{boost::asio::ip::make_address(endpoint.address), endpoint.port})} {} + listener_{Listener::Create(ioc_, tcp::endpoint{boost::asio::ip::make_address(endpoint.address), endpoint.port}, + auth)} {} Server(const Server &) = delete; Server(Server &&) = delete; diff --git a/src/communication/websocket/session.cpp b/src/communication/websocket/session.cpp index ca9aadec3..fc56688eb 100644 --- a/src/communication/websocket/session.cpp +++ b/src/communication/websocket/session.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -11,8 +11,17 @@ #include "communication/websocket/session.hpp" -#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "communication/context.hpp" #include "utils/logging.hpp" namespace communication::websocket { @@ -41,12 +50,14 @@ void Session::Run() { } void Session::Write(std::shared_ptr message) { - if (!connected_.load(std::memory_order_relaxed)) { - return; - } boost::asio::dispatch(strand_, [message = std::move(message), shared_this = shared_from_this()]() mutable { + if (!shared_this->connected_.load(std::memory_order_relaxed)) { + return; + } + if (!shared_this->authenticated_) { + return; + } shared_this->messages_.push_back(std::move(message)); - if (shared_this->messages_.size() > 1) { return; } @@ -72,7 +83,10 @@ void Session::OnWrite(boost::beast::error_code ec, size_t /*bytes_transferred*/) if (ec) { return LogError(ec, "write"); } - + if (close_) { + DoClose(); + return; + } if (!messages_.empty()) { DoWrite(); } @@ -86,14 +100,76 @@ void Session::DoRead() { })); } -void Session::OnRead(boost::beast::error_code ec, size_t /*bytes_transferred*/) { +void Session::DoClose() { + ws_.async_close(boost::beast::websocket::close_code::normal, + boost::asio::bind_executor(strand_, [shared_this = shared_from_this()](boost::beast::error_code ec) { + shared_this->OnClose(ec); + })); +} + +void Session::OnClose(boost::beast::error_code ec) { + if (ec) { + return LogError(ec, "close"); + } + connected_.store(false, std::memory_order_relaxed); +} + +utils::BasicResult Session::Authorize(const nlohmann::json &creds) { + if (!auth_.Authenticate(creds.at("username").get(), creds.at("password").get())) { + return {"Authentication failed!"}; + } +#ifdef MG_ENTERPRISE + if (auth_.HasUserPermission(creds.at("username").get(), auth::Permission::WEBSOCKET)) { + return {"Authorization failed!"}; + } +#endif + return {}; +} + +void Session::OnRead(const boost::beast::error_code ec, const size_t /*bytes_transferred*/) { if (ec == boost::beast::websocket::error::closed) { messages_.clear(); connected_.store(false, std::memory_order_relaxed); return; } - buffer_.consume(buffer_.size()); + if (!authenticated_ && auth_.HasAnyUsers()) { + auto response = nlohmann::json(); + auto auth_failed = [this, &response](const std::string &message) { + response["success"] = false; + response["message"] = message; + MG_ASSERT(messages_.empty()); + messages_.push_back(make_shared(response.dump())); + close_ = true; + DoWrite(); + }; + try { + const auto creds = nlohmann::json::parse(boost::beast::buffers_to_string(buffer_.data())); + buffer_.consume(buffer_.size()); + + if (const auto result = Authorize(creds); result.HasError()) { + std::invoke(auth_failed, result.GetError()); + return; + } + response["success"] = true; + response["message"] = "User has been successfully authenticated!"; + MG_ASSERT(messages_.empty()); + messages_.push_back(make_shared(response.dump())); + DoWrite(); + authenticated_ = true; + } catch (const nlohmann::json::out_of_range &out_of_range) { + const auto err_msg = fmt::format("Invalid JSON for authentication received: {}!", out_of_range.what()); + spdlog::error(err_msg); + std::invoke(auth_failed, err_msg); + return; + } catch (const nlohmann::json::parse_error &parse_error) { + const auto err_msg = fmt::format("Cannot parse JSON for WebSocket authentication: {}!", parse_error.what()); + spdlog::error(err_msg); + std::invoke(auth_failed, err_msg); + return; + } + } + DoRead(); } diff --git a/src/communication/websocket/session.hpp b/src/communication/websocket/session.hpp index d588eae1c..de1869c21 100644 --- a/src/communication/websocket/session.hpp +++ b/src/communication/websocket/session.hpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -21,6 +21,11 @@ #include #include #include +#include + +#include "communication/websocket/auth.hpp" +#include "utils/result.hpp" +#include "utils/synchronized.hpp" namespace communication::websocket { class Session : public std::enable_shared_from_this { @@ -37,18 +42,27 @@ class Session : public std::enable_shared_from_this { bool IsConnected() const; private: - explicit Session(tcp::socket &&socket) - : ws_(std::move(socket)), strand_{boost::asio::make_strand(ws_.get_executor())} {} + explicit Session(tcp::socket &&socket, SafeAuth auth) + : ws_(std::move(socket)), strand_{boost::asio::make_strand(ws_.get_executor())}, auth_(auth) {} void DoWrite(); void OnWrite(boost::beast::error_code ec, size_t bytest_transferred); + void DoRead(); void OnRead(boost::beast::error_code ec, size_t bytest_transferred); + void DoClose(); + void OnClose(boost::beast::error_code ec); + + utils::BasicResult Authorize(const nlohmann::json &creds); + boost::beast::websocket::stream ws_; boost::beast::flat_buffer buffer_; std::deque> messages_; boost::asio::strand strand_; std::atomic connected_{false}; + bool authenticated_{false}; + bool close_{false}; + SafeAuth auth_; }; } // namespace communication::websocket diff --git a/src/glue/auth.cpp b/src/glue/auth.cpp index 990877cd2..5e8c203a0 100644 --- a/src/glue/auth.cpp +++ b/src/glue/auth.cpp @@ -55,6 +55,8 @@ auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege) { return auth::Permission::MODULE_READ; case query::AuthQuery::Privilege::MODULE_WRITE: return auth::Permission::MODULE_WRITE; + case query::AuthQuery::Privilege::WEBSOCKET: + return auth::Permission::WEBSOCKET; } } } // namespace glue diff --git a/src/memgraph.cpp b/src/memgraph.cpp index 2a4b5eae4..ba992faf6 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -32,6 +32,7 @@ #include #include +#include "communication/websocket/auth.hpp" #include "communication/websocket/server.hpp" #include "communication/bolt/v1/constants.hpp" @@ -1204,7 +1205,7 @@ int main(int argc, char **argv) { []() -> nlohmann::json { return query::plan::CallProcedure::GetAndResetCounters(); }); } - communication::websocket::Server websocket_server{{"0.0.0.0", 7444}}; + communication::websocket::Server websocket_server{{"0.0.0.0", 7444}, communication::websocket::SafeAuth{&auth}}; { auto sinks = spdlog::default_logger()->sinks(); diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index f50aba5d4..7fd575cf8 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -2247,7 +2247,8 @@ cpp<# (:serialize)) (lcp:define-enum privilege (create delete match merge set remove index stats auth constraint - dump replication durability read_file free_memory trigger config stream module_read module_write) + dump replication durability read_file free_memory trigger config stream module_read module_write + websocket) (:serialize)) #>cpp AuthQuery() = default; @@ -2288,7 +2289,11 @@ const std::vector kPrivilegesAll = { AuthQuery::Privilege::DURABILITY, AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER, AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, +<<<<<<< HEAD AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE}; +======= + AuthQuery::Privilege::WEBSOCKET}; +>>>>>>> e15495b7 (Add websocket authentication (#322)) cpp<# (lcp:define-class info-query (query) diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 6ac33cd25..af7abef35 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -1332,6 +1332,7 @@ antlrcpp::Any CypherMainVisitor::visitPrivilege(MemgraphCypher::PrivilegeContext if (ctx->STREAM()) return AuthQuery::Privilege::STREAM; if (ctx->MODULE_READ()) return AuthQuery::Privilege::MODULE_READ; if (ctx->MODULE_WRITE()) return AuthQuery::Privilege::MODULE_WRITE; + if (ctx->WEBSOCKET()) return AuthQuery::Privilege::WEBSOCKET; LOG_FATAL("Should not get here - unknown privilege!"); } diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index 02de4493c..e4c13c342 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -241,6 +241,7 @@ privilege : CREATE | STREAM | MODULE_READ | MODULE_WRITE + | WEBSOCKET ; privilegeList : privilege ( ',' privilege )* ; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index a45651c66..7e95c295f 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -112,3 +112,4 @@ UPDATE : U P D A T E ; USER : U S E R ; USERS : U S E R S ; VERSION : V E R S I O N ; +WEBSOCKET : W E B S O C K E T ; diff --git a/src/query/frontend/stripped_lexer_constants.hpp b/src/query/frontend/stripped_lexer_constants.hpp index 39b765971..f69107c03 100644 --- a/src/query/frontend/stripped_lexer_constants.hpp +++ b/src/query/frontend/stripped_lexer_constants.hpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -203,7 +203,8 @@ const trie::Trie kKeywords = {"union", "kafka", "pulsar", "service_url", - "version"}; + "version", + "websocket"}; // Unicode codepoints that are allowed at the start of the unescaped name. const std::bitset kUnescapedNameAllowedStarts( diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index 9af86aa3a..e1aff9a19 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -2172,6 +2172,8 @@ TEST_P(CypherMainVisitorTest, GrantPrivilege) { {AuthQuery::Privilege::CONFIG}); check_auth_query(&ast_generator, "GRANT STREAM TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, {AuthQuery::Privilege::STREAM}); + check_auth_query(&ast_generator, "GRANT WEBSOCKET TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::WEBSOCKET}); } TEST_P(CypherMainVisitorTest, DenyPrivilege) { @@ -2206,6 +2208,8 @@ TEST_P(CypherMainVisitorTest, DenyPrivilege) { {AuthQuery::Privilege::CONSTRAINT}); check_auth_query(&ast_generator, "DENY DUMP TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, {AuthQuery::Privilege::DUMP}); + check_auth_query(&ast_generator, "DENY WEBSOCKET TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::WEBSOCKET}); } TEST_P(CypherMainVisitorTest, RevokePrivilege) { @@ -2242,6 +2246,8 @@ TEST_P(CypherMainVisitorTest, RevokePrivilege) { {}, {AuthQuery::Privilege::CONSTRAINT}); check_auth_query(&ast_generator, "REVOKE DUMP FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, {AuthQuery::Privilege::DUMP}); + check_auth_query(&ast_generator, "REVOKE WEBSOCKET FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", + {}, {AuthQuery::Privilege::WEBSOCKET}); } TEST_P(CypherMainVisitorTest, ShowPrivileges) {