Add multi-tenancy v1 (#952)

* Decouple BoltSession and communication::bolt::Session
* Add CREATE/USE/DROP DATABASE
* Add SHOW DATABASES
* Cover WebSocket session
* Simple session safety implemented via RWLock
* Storage symlinks for backward. compatibility
* Extend the audit log with the DB info
* Add auth part
* Add tenant recovery
This commit is contained in:
andrejtonev 2023-08-01 18:49:11 +02:00 committed by GitHub
parent fd819cd099
commit e8850549d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
114 changed files with 5927 additions and 1015 deletions

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
@ -116,12 +116,12 @@ Log::~Log() {
}
void Log::Record(const std::string &address, const std::string &username, const std::string &query,
const storage::PropertyValue &params) {
const storage::PropertyValue &params, const std::string &db) {
if (!started_.load(std::memory_order_relaxed)) return;
auto timestamp =
std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch())
.count();
buffer_->emplace(Item{timestamp, address, username, query, params});
buffer_->emplace(Item{timestamp, address, username, query, params, db});
}
void Log::ReopenLog() {
@ -136,8 +136,8 @@ void Log::Flush() {
for (uint64_t i = 0; i < buffer_size_; ++i) {
auto item = buffer_->pop();
if (!item) break;
log_.Write(fmt::format("{}.{:06d},{},{},{},{}\n", item->timestamp / 1000000, item->timestamp % 1000000,
item->address, item->username, utils::Escape(item->query),
log_.Write(fmt::format("{}.{:06d},{},{},{},{},{}\n", item->timestamp / 1000000, item->timestamp % 1000000,
item->address, item->username, item->db, utils::Escape(item->query),
utils::Escape(PropertyValueToJson(item->params).dump())));
}
log_.Sync();

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
@ -32,6 +32,7 @@ class Log {
std::string username;
std::string query;
storage::PropertyValue params;
std::string db;
};
public:
@ -51,7 +52,7 @@ class Log {
/// Adds an entry to the audit log. Thread-safe.
void Record(const std::string &address, const std::string &username, const std::string &query,
const storage::PropertyValue &params);
const storage::PropertyValue &params, const std::string &db);
/// Reopens the log file. Used for log file rotation. Thread-safe.
void ReopenLog();

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
@ -314,4 +314,57 @@ std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig)
return ret;
}
#ifdef MG_ENTERPRISE
bool Auth::GrantDatabaseToUser(const std::string &db, const std::string &name) {
auto user = GetUser(name);
if (user) {
if (db == kAllDatabases) {
user->db_access().GrantAll();
} else {
user->db_access().Add(db);
}
SaveUser(*user);
return true;
}
return false;
}
bool Auth::RevokeDatabaseFromUser(const std::string &db, const std::string &name) {
auto user = GetUser(name);
if (user) {
if (db == kAllDatabases) {
user->db_access().DenyAll();
} else {
user->db_access().Remove(db);
}
SaveUser(*user);
return true;
}
return false;
}
void Auth::DeleteDatabase(const std::string &db) {
for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) {
auto username = it->first.substr(kUserPrefix.size());
auto user = GetUser(username);
if (user) {
user->db_access().Delete(db);
SaveUser(*user);
}
}
}
bool Auth::SetMainDatabase(const std::string &db, const std::string &name) {
auto user = GetUser(name);
if (user) {
if (!user->db_access().SetDefault(db)) {
throw AuthException("Couldn't set default database '{}' for user '{}'!", db, name);
}
SaveUser(*user);
return true;
}
return false;
}
#endif
} // namespace memgraph::auth

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
@ -19,6 +19,9 @@
#include "utils/settings.hpp"
namespace memgraph::auth {
static const constexpr char *const kAllDatabases = "*";
/**
* This class serves as the main Authentication/Authorization storage.
* It provides functions for managing Users, Roles, Permissions and FineGrainedAccessPermissions.
@ -155,6 +158,46 @@ class Auth final {
*/
std::vector<User> AllUsersForRole(const std::string &rolename) const;
#ifdef MG_ENTERPRISE
/**
* @brief Revoke access to individual database for a user.
*
* @param db name of the database to revoke
* @param name user's username
* @return true on success
* @throw AuthException if unable to find or update the user
*/
bool RevokeDatabaseFromUser(const std::string &db, const std::string &name);
/**
* @brief Grant access to individual database for a user.
*
* @param db name of the database to revoke
* @param name user's username
* @return true on success
* @throw AuthException if unable to find or update the user
*/
bool GrantDatabaseToUser(const std::string &db, const std::string &name);
/**
* @brief Delete a database from all users.
*
* @param db name of the database to delete
* @throw AuthException if unable to read data
*/
void DeleteDatabase(const std::string &db);
/**
* @brief Set main database for an individual user.
*
* @param db name of the database to revoke
* @param name user's username
* @return true on success
* @throw AuthException if unable to find or update the user
*/
bool SetMainDatabase(const std::string &db, const std::string &name);
#endif
private:
// Even though the `kvstore::KVStore` class is guaranteed to be thread-safe,
// Auth is not thread-safe because modifying users and roles might require

View File

@ -15,8 +15,10 @@
#include "auth/crypto.hpp"
#include "auth/exceptions.hpp"
#include "dbms/constants.hpp"
#include "license/license.hpp"
#include "query/constants.hpp"
#include "spdlog/spdlog.h"
#include "utils/cast.hpp"
#include "utils/logging.hpp"
#include "utils/settings.hpp"
@ -35,18 +37,31 @@ namespace memgraph::auth {
namespace {
// Constant list of all available permissions.
const std::vector<Permission> kPermissionsAll = {Permission::MATCH, Permission::CREATE,
Permission::MERGE, Permission::DELETE,
Permission::SET, Permission::REMOVE,
Permission::INDEX, Permission::STATS,
Permission::CONSTRAINT, Permission::DUMP,
Permission::AUTH, Permission::REPLICATION,
Permission::DURABILITY, Permission::READ_FILE,
Permission::FREE_MEMORY, Permission::TRIGGER,
Permission::CONFIG, Permission::STREAM,
Permission::MODULE_READ, Permission::MODULE_WRITE,
Permission::WEBSOCKET, Permission::TRANSACTION_MANAGEMENT,
Permission::STORAGE_MODE};
const std::vector<Permission> kPermissionsAll = {Permission::MATCH,
Permission::CREATE,
Permission::MERGE,
Permission::DELETE,
Permission::SET,
Permission::REMOVE,
Permission::INDEX,
Permission::STATS,
Permission::CONSTRAINT,
Permission::DUMP,
Permission::AUTH,
Permission::REPLICATION,
Permission::DURABILITY,
Permission::READ_FILE,
Permission::FREE_MEMORY,
Permission::TRIGGER,
Permission::CONFIG,
Permission::STREAM,
Permission::MODULE_READ,
Permission::MODULE_WRITE,
Permission::WEBSOCKET,
Permission::TRANSACTION_MANAGEMENT,
Permission::STORAGE_MODE,
Permission::MULTI_DATABASE_EDIT,
Permission::MULTI_DATABASE_USE};
} // namespace
@ -98,6 +113,10 @@ std::string PermissionToString(Permission permission) {
return "TRANSACTION_MANAGEMENT";
case Permission::STORAGE_MODE:
return "STORAGE_MODE";
case Permission::MULTI_DATABASE_EDIT:
return "MULTI_DATABASE_EDIT";
case Permission::MULTI_DATABASE_USE:
return "MULTI_DATABASE_USE";
}
}
@ -464,6 +483,82 @@ bool operator==(const Role &first, const Role &second) {
return first.rolename_ == second.rolename_ && first.permissions_ == second.permissions_;
}
#ifdef MG_ENTERPRISE
void Databases::Add(const std::string &db) {
if (allow_all_) {
grants_dbs_.clear();
allow_all_ = false;
}
grants_dbs_.emplace(db);
denies_dbs_.erase(db);
}
void Databases::Remove(const std::string &db) {
denies_dbs_.emplace(db);
grants_dbs_.erase(db);
}
void Databases::Delete(const std::string &db) {
denies_dbs_.erase(db);
if (!allow_all_) {
grants_dbs_.erase(db);
}
// Reset if default deleted
if (default_db_ == db) {
default_db_ = "";
}
}
void Databases::GrantAll() {
allow_all_ = true;
grants_dbs_.clear();
denies_dbs_.clear();
}
void Databases::DenyAll() {
allow_all_ = false;
grants_dbs_.clear();
denies_dbs_.clear();
}
bool Databases::SetDefault(const std::string &db) {
if (!Contains(db)) return false;
default_db_ = db;
return true;
}
[[nodiscard]] bool Databases::Contains(const std::string &db) const {
return !denies_dbs_.contains(db) && (allow_all_ || grants_dbs_.contains(db));
}
const std::string &Databases::GetDefault() const {
if (!Contains(default_db_)) {
throw AuthException("No access to the set default database \"{}\".", default_db_);
}
return default_db_;
}
nlohmann::json Databases::Serialize() const {
nlohmann::json data = nlohmann::json::object();
data["grants"] = grants_dbs_;
data["denies"] = denies_dbs_;
data["allow_all"] = allow_all_;
data["default"] = default_db_;
return data;
}
Databases Databases::Deserialize(const nlohmann::json &data) {
if (!data.is_object()) {
throw AuthException("Couldn't load database data!");
}
if (!data["grants"].is_structured() || !data["denies"].is_structured() || !data["allow_all"].is_boolean() ||
!data["default"].is_string()) {
throw AuthException("Couldn't load database data!");
}
return {data["allow_all"], data["grants"], data["denies"], data["default"]};
}
#endif
User::User() {}
User::User(const std::string &username) : username_(utils::ToLowerCase(username)) {}
@ -472,11 +567,12 @@ User::User(const std::string &username, const std::string &password_hash, const
#ifdef MG_ENTERPRISE
User::User(const std::string &username, const std::string &password_hash, const Permissions &permissions,
FineGrainedAccessHandler fine_grained_access_handler)
FineGrainedAccessHandler fine_grained_access_handler, Databases db_access)
: username_(utils::ToLowerCase(username)),
password_hash_(password_hash),
permissions_(permissions),
fine_grained_access_handler_(std::move(fine_grained_access_handler)) {}
fine_grained_access_handler_(std::move(fine_grained_access_handler)),
database_access_(db_access) {}
#endif
bool User::CheckPassword(const std::string &password) {
@ -576,8 +672,10 @@ nlohmann::json User::Serialize() const {
#ifdef MG_ENTERPRISE
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
data["fine_grained_access_handler"] = fine_grained_access_handler_.Serialize();
data["databases"] = database_access_.Serialize();
} else {
data["fine_grained_access_handler"] = {};
data["databases"] = {};
}
#endif
// The role shouldn't be serialized here, it is stored as a foreign key.
@ -594,11 +692,20 @@ User User::Deserialize(const nlohmann::json &data) {
auto permissions = Permissions::Deserialize(data["permissions"]);
#ifdef MG_ENTERPRISE
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
Databases db_access;
if (data["databases"].is_structured()) {
db_access = Databases::Deserialize(data["databases"]);
} else {
// Back-compatibility
spdlog::warn("User without specified database access. Given access to the default database.");
db_access.Add(dbms::kDefaultDB);
db_access.SetDefault(dbms::kDefaultDB);
}
if (!data["fine_grained_access_handler"].is_object()) {
throw AuthException("Couldn't load user data!");
}
auto fine_grained_access_handler = FineGrainedAccessHandler::Deserialize(data["fine_grained_access_handler"]);
return {data["username"], data["password_hash"], permissions, fine_grained_access_handler};
return {data["username"], data["password_hash"], permissions, fine_grained_access_handler, db_access};
}
#endif
return {data["username"], data["password_hash"], permissions};

View File

@ -9,10 +9,13 @@
#pragma once
#include <optional>
#include <set>
#include <string>
#include <unordered_map>
#include <json/json.hpp>
#include "dbms/constants.hpp"
#include "utils/logging.hpp"
namespace memgraph::auth {
// These permissions must have values that are applicable for usage in a
@ -41,7 +44,9 @@ enum class Permission : uint64_t {
MODULE_WRITE = 1U << 19U,
WEBSOCKET = 1U << 20U,
TRANSACTION_MANAGEMENT = 1U << 21U,
STORAGE_MODE = 1U << 22U
STORAGE_MODE = 1U << 22U,
MULTI_DATABASE_EDIT = 1U << 23U,
MULTI_DATABASE_USE = 1U << 24U,
};
// clang-format on
@ -237,6 +242,85 @@ class Role final {
bool operator==(const Role &first, const Role &second);
#ifdef MG_ENTERPRISE
class Databases final {
public:
Databases() : grants_dbs_({dbms::kDefaultDB}), allow_all_(false), default_db_(dbms::kDefaultDB) {}
Databases(const Databases &) = default;
Databases &operator=(const Databases &) = default;
Databases(Databases &&) noexcept = default;
Databases &operator=(Databases &&) noexcept = default;
~Databases() = default;
/**
* @brief Add database to the list of granted access. @note allow_all_ will be false after execution
*
* @param db name of the database to grant access to
*/
void Add(const std::string &db);
/**
* @brief Remove database to the list of granted access.
* @note if allow_all_ is set, the flag will remain set and the
* database will be added to the set of denied databases.
*
* @param db name of the database to grant access to
*/
void Remove(const std::string &db);
/**
* @brief Called when database is dropped. Removes it from granted (if allow_all is false) and denied set.
* @note allow_all_ is not changed
*
* @param db name of the database to grant access to
*/
void Delete(const std::string &db);
/**
* @brief Set allow_all_ to true and clears grants and denied sets.
*/
void GrantAll();
/**
* @brief Set allow_all_ to false and clears grants and denied sets.
*/
void DenyAll();
/**
* @brief Set the default database.
*/
bool SetDefault(const std::string &db);
/**
* @brief Checks if access is grated to the database.
*
* @param db name of the database
* @return true if allow_all and not denied or granted
*/
bool Contains(const std::string &db) const;
bool GetAllowAll() const { return allow_all_; }
const std::set<std::string> &GetGrants() const { return grants_dbs_; }
const std::set<std::string> &GetDenies() const { return denies_dbs_; }
const std::string &GetDefault() const;
nlohmann::json Serialize() const;
/// @throw AuthException if unable to deserialize.
static Databases Deserialize(const nlohmann::json &data);
private:
Databases(bool allow_all, std::set<std::string> grant, std::set<std::string> deny,
const std::string &default_db = dbms::kDefaultDB)
: grants_dbs_(grant), denies_dbs_(deny), allow_all_(allow_all), default_db_(default_db) {}
std::set<std::string> grants_dbs_; //!< set of databases with granted access
std::set<std::string> denies_dbs_; //!< set of databases with denied access
bool allow_all_; //!< flag to allow access to everything (denied overrides this)
std::string default_db_; //!< user's default database
};
#endif
// TODO (mferencevic): Implement password expiry.
class User final {
public:
@ -246,7 +330,7 @@ class User final {
User(const std::string &username, const std::string &password_hash, const Permissions &permissions);
#ifdef MG_ENTERPRISE
User(const std::string &username, const std::string &password_hash, const Permissions &permissions,
FineGrainedAccessHandler fine_grained_access_handler);
FineGrainedAccessHandler fine_grained_access_handler, Databases db_access = {});
#endif
User(const User &) = default;
User &operator=(const User &) = default;
@ -279,6 +363,11 @@ class User final {
const Role *role() const;
#ifdef MG_ENTERPRISE
Databases &db_access() { return database_access_; }
const Databases &db_access() const { return database_access_; }
#endif
nlohmann::json Serialize() const;
/// @throw AuthException if unable to deserialize.
@ -292,6 +381,7 @@ class User final {
Permissions permissions_;
#ifdef MG_ENTERPRISE
FineGrainedAccessHandler fine_grained_access_handler_;
Databases database_access_;
#endif
std::optional<Role> role_;
};

View File

@ -11,6 +11,8 @@
#pragma once
#include <concepts>
#include <cstddef>
#include <optional>
#include <thread>
@ -24,8 +26,12 @@
#include "communication/bolt/v1/states/executing.hpp"
#include "communication/bolt/v1/states/handshake.hpp"
#include "communication/bolt/v1/states/init.hpp"
#include "communication/bolt/v1/value.hpp"
#include "dbms/constants.hpp"
#include "dbms/global.hpp"
#include "utils/exceptions.hpp"
#include "utils/logging.hpp"
#include "utils/uuid.hpp"
namespace memgraph::communication::bolt {
@ -48,14 +54,26 @@ class SessionException : public utils::BasicException {
* @tparam TOutputStream type of output stream that will be used
*/
template <typename TInputStream, typename TOutputStream>
class Session {
class Session : public dbms::SessionInterface {
public:
using TEncoder = Encoder<ChunkedEncoderBuffer<TOutputStream>>;
/**
* @brief Construct a new Session object
*
* @param input_stream stream to read from
* @param output_stream stream to write to
* @param impl a default high-level implementation to use (has to be defined)
*/
Session(TInputStream *input_stream, TOutputStream *output_stream)
: input_stream_(*input_stream), output_stream_(*output_stream) {}
: input_stream_(*input_stream), output_stream_(*output_stream), session_uuid_(utils::GenerateUUID()) {}
virtual ~Session() {}
virtual ~Session() = default;
Session(const Session &) = delete;
Session &operator=(const Session &) = delete;
Session(Session &&) noexcept = delete;
Session &operator=(Session &&) noexcept = delete;
/**
* Process the given `query` with `params`.
@ -66,6 +84,8 @@ class Session {
const std::string &query, const std::map<std::string, Value> &params,
const std::map<std::string, memgraph::communication::bolt::Value> &extra) = 0;
virtual void Configure(const std::map<std::string, memgraph::communication::bolt::Value> &run_time_info) = 0;
/**
* Put results of the processed query in the `encoder`.
*
@ -86,7 +106,7 @@ class Session {
*/
virtual std::map<std::string, Value> Discard(std::optional<int> n, std::optional<int> qid) = 0;
virtual void BeginTransaction(const std::map<std::string, memgraph::communication::bolt::Value> &) = 0;
virtual void BeginTransaction(const std::map<std::string, memgraph::communication::bolt::Value> &params) = 0;
virtual void CommitTransaction() = 0;
virtual void RollbackTransaction() = 0;
@ -99,7 +119,6 @@ class Session {
/** Return the name of the server that should be used for the Bolt INIT
* message. */
virtual std::optional<std::string> GetServerNameForInit() = 0;
/**
* Executes the session after data has been read into the buffer.
* Goes through the bolt states in order to execute commands from the client.
@ -161,8 +180,7 @@ class Session {
}
}
// TODO: Rethink if there is a way to hide some members. At the momement all
// of them are public.
// TODO: Rethink if there is a way to hide some members. At the momement all of them are public.
TInputStream &input_stream_;
TOutputStream &output_stream_;
@ -182,6 +200,9 @@ class Session {
Version version_;
std::string GetDatabaseName() const override = 0;
std::string UUID() const final { return session_uuid_; }
private:
void ClientFailureInvalidData() {
// Set the state to Close.
@ -197,6 +218,8 @@ class Session {
// of the session to trigger session cleanup and socket close.
throw SessionException("Something went wrong during session execution!");
}
const std::string session_uuid_; //!< unique identifier of the session (auto generated)
};
} // namespace memgraph::communication::bolt

View File

@ -11,6 +11,7 @@
#pragma once
#include <exception>
#include <map>
#include <optional>
#include <string>
@ -207,7 +208,7 @@ State HandleRunV1(TSession &session, const State state, const Marker marker) {
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
spdlog::debug("[Run] '{}'", query.ValueString());
spdlog::debug("[Run - {}] '{}'", session.GetDatabaseName(), query.ValueString());
try {
// Interpret can throw.
@ -265,7 +266,13 @@ State HandleRunV4(TSession &session, const State state, const Marker marker) {
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
spdlog::debug("[Run] '{}'", query.ValueString());
try {
session.Configure(extra.ValueMap());
} catch (const std::exception &e) {
return HandleFailure(session, e);
}
spdlog::debug("[Run - {}] '{}'", session.GetDatabaseName(), query.ValueString());
try {
// Interpret can throw.
@ -381,6 +388,7 @@ State HandleBegin(TSession &session, const State state, const Marker marker) {
}
try {
session.Configure(extra.ValueMap());
session.BeginTransaction(extra.ValueMap());
} catch (const std::exception &e) {
return HandleFailure(session, e);
@ -489,7 +497,7 @@ State HandleRoute(TSession &session, const Marker marker) {
template <typename TSession>
State HandleLogOff() {
// Not arguments sent, the user just needs to reauthenticate
// No arguments sent, the user just needs to reauthenticate
return State::Init;
}
} // namespace memgraph::communication::bolt

View File

@ -18,6 +18,7 @@
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/value.hpp"
#include "communication/exceptions.hpp"
#include "spdlog/spdlog.h"
#include "utils/likely.hpp"
#include "utils/logging.hpp"
@ -248,8 +249,9 @@ State StateInitRunV5(TSession &session, Marker marker, Signature signature) {
}
// Stay in Init
return State::Init;
}
} else if (signature == Signature::LogOn) {
if (signature == Signature::LogOn) {
if (marker != Marker::TinyStruct1) [[unlikely]] {
spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
spdlog::trace(
@ -273,11 +275,10 @@ State StateInitRunV5(TSession &session, Marker marker, Signature signature) {
return State::Close;
}
return State::Idle;
} else [[unlikely]] {
spdlog::trace("Expected Init signature, but received 0x{:02X}!", utils::UnderlyingCast(signature));
return State::Close;
}
spdlog::trace("Expected Init signature, but received 0x{:02X}!", utils::UnderlyingCast(signature));
return State::Close;
}
} // namespace details

View File

@ -27,11 +27,11 @@
namespace memgraph::communication::http {
template <class TRequestHandler, typename TSessionData>
class Listener final : public std::enable_shared_from_this<Listener<TRequestHandler, TSessionData>> {
template <class TRequestHandler, typename TSessionContext>
class Listener final : public std::enable_shared_from_this<Listener<TRequestHandler, TSessionContext>> {
using tcp = boost::asio::ip::tcp;
using SessionHandler = Session<TRequestHandler, TSessionData>;
using std::enable_shared_from_this<Listener<TRequestHandler, TSessionData>>::shared_from_this;
using SessionHandler = Session<TRequestHandler, TSessionContext>;
using std::enable_shared_from_this<Listener<TRequestHandler, TSessionContext>>::shared_from_this;
public:
Listener(const Listener &) = delete;
@ -50,8 +50,9 @@ class Listener final : public std::enable_shared_from_this<Listener<TRequestHand
tcp::endpoint GetEndpoint() const { return acceptor_.local_endpoint(); }
private:
Listener(boost::asio::io_context &ioc, TSessionData *data, ServerContext *context, tcp::endpoint endpoint)
: ioc_(ioc), data_(data), context_(context), acceptor_(ioc) {
Listener(boost::asio::io_context &ioc, TSessionContext *session_context, ServerContext *context,
tcp::endpoint endpoint)
: ioc_(ioc), session_context_(session_context), context_(context), acceptor_(ioc) {
boost::beast::error_code ec;
// Open the acceptor
@ -95,13 +96,13 @@ class Listener final : public std::enable_shared_from_this<Listener<TRequestHand
return LogError(ec, "accept");
}
SessionHandler::Create(std::move(socket), data_, *context_)->Run();
SessionHandler::Create(std::move(socket), session_context_, *context_)->Run();
DoAccept();
}
boost::asio::io_context &ioc_;
TSessionData *data_;
TSessionContext *session_context_;
ServerContext *context_;
tcp::acceptor acceptor_;
};

View File

@ -21,14 +21,15 @@
namespace memgraph::communication::http {
template <class TRequestHandler, typename TSessionData>
template <class TRequestHandler, typename TSessionContext>
class Server final {
using tcp = boost::asio::ip::tcp;
public:
explicit Server(io::network::Endpoint endpoint, TSessionData *data, ServerContext *context)
: listener_{Listener<TRequestHandler, TSessionData>::Create(
ioc_, data, context, tcp::endpoint{boost::asio::ip::make_address(endpoint.address), endpoint.port})} {}
explicit Server(io::network::Endpoint endpoint, TSessionContext *session_context, ServerContext *context)
: listener_{Listener<TRequestHandler, TSessionContext>::Create(
ioc_, session_context, context,
tcp::endpoint{boost::asio::ip::make_address(endpoint.address), endpoint.port})} {}
Server(const Server &) = delete;
Server(Server &&) = delete;
@ -59,7 +60,7 @@ class Server final {
private:
boost::asio::io_context ioc_;
std::shared_ptr<Listener<TRequestHandler, TSessionData>> listener_;
std::shared_ptr<Listener<TRequestHandler, TSessionContext>> listener_;
std::optional<std::thread> background_thread_;
};
} // namespace memgraph::communication::http

View File

@ -42,10 +42,10 @@ inline void LogError(boost::beast::error_code ec, const std::string_view what) {
spdlog::warn("HTTP session failed on {}: {}", what, ec.message());
}
template <class TRequestHandler, typename TSessionData>
class Session : public std::enable_shared_from_this<Session<TRequestHandler, TSessionData>> {
template <class TRequestHandler, typename TSessionContext>
class Session : public std::enable_shared_from_this<Session<TRequestHandler, TSessionContext>> {
using tcp = boost::asio::ip::tcp;
using std::enable_shared_from_this<Session<TRequestHandler, TSessionData>>::shared_from_this;
using std::enable_shared_from_this<Session<TRequestHandler, TSessionContext>>::shared_from_this;
public:
template <typename... Args>
@ -72,7 +72,7 @@ class Session : public std::enable_shared_from_this<Session<TRequestHandler, TSe
using PlainSocket = boost::beast::tcp_stream;
using SSLSocket = boost::beast::ssl_stream<boost::beast::tcp_stream>;
explicit Session(tcp::socket &&socket, TSessionData *data, ServerContext &context)
explicit Session(tcp::socket &&socket, TSessionContext *data, ServerContext &context)
: stream_(CreateSocket(std::move(socket), context)),
handler_(data),
strand_{boost::asio::make_strand(GetExecutor())} {}

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -39,7 +39,7 @@ namespace memgraph::communication {
* second, checks all sessions for expiration and shuts them down if they have
* expired.
*/
template <class TSession, class TSessionData>
template <class TSession, class TSessionContext>
class Listener final {
private:
// The maximum number of events handled per execution thread is 1. This is
@ -48,10 +48,10 @@ class Listener final {
// can take a long time.
static const int kMaxEvents = 1;
using SessionHandler = Session<TSession, TSessionData>;
using SessionHandler = Session<TSession, TSessionContext>;
public:
Listener(TSessionData *data, ServerContext *context, int inactivity_timeout_sec, const std::string &service_name,
Listener(TSessionContext *data, ServerContext *context, int inactivity_timeout_sec, const std::string &service_name,
size_t workers_count)
: data_(data),
alive_(false),
@ -259,7 +259,7 @@ class Listener final {
io::network::Epoll epoll_;
TSessionData *data_;
TSessionContext *data_;
utils::SpinLock lock_;
std::vector<std::unique_ptr<SessionHandler>> sessions_;

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -46,10 +46,10 @@ namespace memgraph::communication {
* @tparam TSession the server can handle different Sessions, each session
* represents a different protocol so the same network infrastructure
* can be used for handling different protocols
* @tparam TSessionData the class with objects that will be forwarded to the
* @tparam TSessionContext the class with objects that will be forwarded to the
* session
*/
template <typename TSession, typename TSessionData>
template <typename TSession, typename TSessionContext>
class Server final {
public:
using Socket = io::network::Socket;
@ -58,12 +58,12 @@ class Server final {
* Constructs and binds server to endpoint, operates on session data and
* invokes workers_count workers
*/
Server(const io::network::Endpoint &endpoint, TSessionData *session_data, ServerContext *context,
Server(const io::network::Endpoint &endpoint, TSessionContext *session_context, ServerContext *context,
int inactivity_timeout_sec, const std::string &service_name,
size_t workers_count = std::thread::hardware_concurrency())
: alive_(false),
endpoint_(endpoint),
listener_(session_data, context, inactivity_timeout_sec, service_name, workers_count),
listener_(session_context, context, inactivity_timeout_sec, service_name, workers_count),
service_name_(service_name) {}
~Server() {
@ -156,7 +156,7 @@ class Server final {
Socket socket_;
io::network::Endpoint endpoint_;
Listener<TSession, TSessionData> listener_;
Listener<TSession, TSessionContext> listener_;
const std::string service_name_;
};

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -69,10 +69,10 @@ class OutputStream final {
* sessions. It handles socket ownership, inactivity timeout and protocol
* wrapping.
*/
template <class TSession, class TSessionData>
template <class TSession, class TSessionContext>
class Session final {
public:
Session(io::network::Socket &&socket, TSessionData *data, ServerContext *context, int inactivity_timeout_sec)
Session(io::network::Socket &&socket, TSessionContext *data, ServerContext *context, int inactivity_timeout_sec)
: socket_(std::move(socket)),
output_stream_([this](const uint8_t *data, size_t len, bool have_more) { return Write(data, len, have_more); }),
session_(data, socket_.endpoint(), input_buffer_.read_end(), &output_stream_),

View File

@ -36,11 +36,11 @@
namespace memgraph::communication::v2 {
template <class TSession, class TSessionData>
class Listener final : public std::enable_shared_from_this<Listener<TSession, TSessionData>> {
template <class TSession, class TSessionContext>
class Listener final : public std::enable_shared_from_this<Listener<TSession, TSessionContext>> {
using tcp = boost::asio::ip::tcp;
using SessionHandler = Session<TSession, TSessionData>;
using std::enable_shared_from_this<Listener<TSession, TSessionData>>::shared_from_this;
using SessionHandler = Session<TSession, TSessionContext>;
using std::enable_shared_from_this<Listener<TSession, TSessionContext>>::shared_from_this;
public:
Listener(const Listener &) = delete;
@ -59,10 +59,10 @@ class Listener final : public std::enable_shared_from_this<Listener<TSession, TS
bool IsRunning() const noexcept { return alive_.load(std::memory_order_relaxed); }
private:
Listener(boost::asio::io_context &io_context, TSessionData *data, ServerContext *server_context,
Listener(boost::asio::io_context &io_context, TSessionContext *session_context, ServerContext *server_context,
tcp::endpoint &endpoint, const std::string_view service_name, const uint64_t inactivity_timeout_sec)
: io_context_(io_context),
data_(data),
session_context_(session_context),
server_context_(server_context),
acceptor_(io_context_),
endpoint_{endpoint},
@ -111,8 +111,8 @@ class Listener final : public std::enable_shared_from_this<Listener<TSession, TS
return OnError(ec, "accept");
}
auto session = SessionHandler::Create(std::move(socket), data_, *server_context_, endpoint_, inactivity_timeout_,
service_name_);
auto session = SessionHandler::Create(std::move(socket), session_context_, *server_context_, endpoint_,
inactivity_timeout_, service_name_);
session->Start();
DoAccept();
}
@ -123,7 +123,7 @@ class Listener final : public std::enable_shared_from_this<Listener<TSession, TS
}
boost::asio::io_context &io_context_;
TSessionData *data_;
TSessionContext *session_context_;
ServerContext *server_context_;
tcp::acceptor acceptor_;

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -60,27 +60,27 @@ using ServerEndpoint = boost::asio::ip::tcp::endpoint;
* @tparam TSession the server can handle different Sessions, each session
* represents a different protocol so the same network infrastructure
* can be used for handling different protocols
* @tparam TSessionData the class with objects that will be forwarded to the
* @tparam TSessionContext the class with objects that will be forwarded to the
* session
*/
template <typename TSession, typename TSessionData>
template <typename TSession, typename TSessionContext>
class Server final {
using ServerHandler = Server<TSession, TSessionData>;
using ServerHandler = Server<TSession, TSessionContext>;
public:
/**
* Constructs and binds server to endpoint, operates on session data and
* invokes workers_count workers
*/
Server(ServerEndpoint &endpoint, TSessionData *session_data, ServerContext *server_context,
Server(ServerEndpoint &endpoint, TSessionContext *session_context, ServerContext *server_context,
const int inactivity_timeout_sec, const std::string_view service_name,
size_t workers_count = std::thread::hardware_concurrency())
: endpoint_{endpoint},
service_name_{service_name},
context_thread_pool_{workers_count},
listener_{Listener<TSession, TSessionData>::Create(context_thread_pool_.GetIOContext(), session_data,
server_context, endpoint_, service_name_,
inactivity_timeout_sec)} {}
listener_{Listener<TSession, TSessionContext>::Create(context_thread_pool_.GetIOContext(), session_context,
server_context, endpoint_, service_name_,
inactivity_timeout_sec)} {}
~Server() { MG_ASSERT(!IsRunning(), "Server wasn't shutdown properly"); }
@ -122,7 +122,7 @@ class Server final {
std::string service_name_;
IOContextThreadPool context_thread_pool_;
std::shared_ptr<Listener<TSession, TSessionData>> listener_;
std::shared_ptr<Listener<TSession, TSessionContext>> listener_;
};
} // namespace memgraph::communication::v2

View File

@ -16,10 +16,12 @@
#include <cstdint>
#include <cstring>
#include <deque>
#include <exception>
#include <functional>
#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <variant>
@ -41,9 +43,11 @@
#include <boost/beast/websocket/rfc6455.hpp>
#include <boost/system/detail/error_code.hpp>
#include "communication/bolt/v1/session.hpp"
#include "communication/buffer.hpp"
#include "communication/context.hpp"
#include "communication/exceptions.hpp"
#include "dbms/global.hpp"
#include "utils/event_counter.hpp"
#include "utils/logging.hpp"
#include "utils/on_scope_exit.hpp"
@ -95,10 +99,10 @@ class OutputStream final {
* Websocket Sessions. It handles socket ownership, inactivity timeout and protocol
* wrapping.
*/
template <typename TSession, typename TSessionData>
class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TSession, TSessionData>> {
template <typename TSession, typename TSessionContext>
class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TSession, TSessionContext>> {
using WebSocket = boost::beast::websocket::stream<boost::beast::tcp_stream>;
using std::enable_shared_from_this<WebsocketSession<TSession, TSessionData>>::shared_from_this;
using std::enable_shared_from_this<WebsocketSession<TSession, TSessionContext>>::shared_from_this;
public:
template <typename... Args>
@ -106,6 +110,17 @@ class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TS
return std::shared_ptr<WebsocketSession>(new WebsocketSession(std::forward<Args>(args)...));
}
#ifdef MG_ENTERPRISE
~WebsocketSession() { session_context_->Delete(session_); }
#else
~WebsocketSession() = default;
#endif
WebsocketSession(const WebsocketSession &) = delete;
WebsocketSession &operator=(const WebsocketSession &) = delete;
WebsocketSession(WebsocketSession &&) noexcept = delete;
WebsocketSession &operator=(WebsocketSession &&) noexcept = delete;
// Start the asynchronous accept operation
template <class Body, class Allocator>
void DoAccept(boost::beast::http::request<Body, boost::beast::http::basic_fields<Allocator>> req) {
@ -151,15 +166,20 @@ class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TS
private:
// Take ownership of the socket
explicit WebsocketSession(tcp::socket &&socket, TSessionData *data, tcp::endpoint endpoint,
explicit WebsocketSession(tcp::socket &&socket, TSessionContext *session_context, tcp::endpoint endpoint,
std::string_view service_name)
: ws_(std::move(socket)),
strand_{boost::asio::make_strand(ws_.get_executor())},
output_stream_([this](const uint8_t *data, size_t len, bool /*have_more*/) { return Write(data, len); }),
session_(data, endpoint, input_buffer_.read_end(), &output_stream_),
session_{*session_context, endpoint, input_buffer_.read_end(), &output_stream_},
session_context_{session_context},
endpoint_{endpoint},
remote_endpoint_{ws_.next_layer().socket().remote_endpoint()},
service_name_{service_name} {}
service_name_{service_name} {
#ifdef MG_ENTERPRISE
session_context_->Register(session_);
#endif
}
void OnAccept(boost::beast::error_code ec) {
if (ec) {
@ -242,6 +262,7 @@ class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TS
communication::Buffer input_buffer_;
OutputStream output_stream_;
TSession session_;
TSessionContext *session_context_;
tcp::endpoint endpoint_;
tcp::endpoint remote_endpoint_;
std::string_view service_name_;
@ -253,11 +274,11 @@ class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TS
* Sessions. It handles socket ownership, inactivity timeout and protocol
* wrapping.
*/
template <typename TSession, typename TSessionData>
class Session final : public std::enable_shared_from_this<Session<TSession, TSessionData>> {
template <typename TSession, typename TSessionContext>
class Session final : public std::enable_shared_from_this<Session<TSession, TSessionContext>> {
using TCPSocket = tcp::socket;
using SSLSocket = boost::asio::ssl::stream<TCPSocket>;
using std::enable_shared_from_this<Session<TSession, TSessionData>>::shared_from_this;
using std::enable_shared_from_this<Session<TSession, TSessionContext>>::shared_from_this;
public:
template <typename... Args>
@ -265,11 +286,16 @@ class Session final : public std::enable_shared_from_this<Session<TSession, TSes
return std::shared_ptr<Session>(new Session(std::forward<Args>(args)...));
}
#ifdef MG_ENTERPRISE
~Session() { session_context_->Delete(session_); }
#else
~Session() = default;
#endif
Session(const Session &) = delete;
Session(Session &&) = delete;
Session &operator=(const Session &) = delete;
Session &operator=(Session &&) = delete;
~Session() = default;
bool Start() {
if (execution_active_) {
@ -334,18 +360,23 @@ class Session final : public std::enable_shared_from_this<Session<TSession, TSes
}
private:
explicit Session(tcp::socket &&socket, TSessionData *data, ServerContext &server_context, tcp::endpoint endpoint,
const std::chrono::seconds inactivity_timeout_sec, std::string_view service_name)
explicit Session(tcp::socket &&socket, TSessionContext *session_context, ServerContext &server_context,
tcp::endpoint endpoint, const std::chrono::seconds inactivity_timeout_sec,
std::string_view service_name)
: socket_(CreateSocket(std::move(socket), server_context)),
strand_{boost::asio::make_strand(GetExecutor())},
output_stream_([this](const uint8_t *data, size_t len, bool have_more) { return Write(data, len, have_more); }),
session_(data, endpoint, input_buffer_.read_end(), &output_stream_),
data_{data},
session_{*session_context, endpoint, input_buffer_.read_end(), &output_stream_},
session_context_{session_context},
endpoint_{endpoint},
remote_endpoint_{GetRemoteEndpoint()},
service_name_{service_name},
timeout_seconds_(inactivity_timeout_sec),
timeout_timer_(GetExecutor()) {
#ifdef MG_ENTERPRISE
// TODO Try to remove Register (see comment at SessionInterface declaration)
session_context_->Register(session_);
#endif
ExecuteForSocket([](auto &&socket) {
socket.lowest_layer().set_option(tcp::no_delay(true)); // enable PSH
socket.lowest_layer().set_option(boost::asio::socket_base::keep_alive(true)); // enable SO_KEEPALIVE
@ -396,7 +427,8 @@ class Session final : public std::enable_shared_from_this<Session<TSession, TSes
spdlog::info("Switching {} to websocket connection", remote_endpoint_);
if (std::holds_alternative<TCPSocket>(socket_)) {
auto sock = std::get<TCPSocket>(std::move(socket_));
WebsocketSession<TSession, TSessionData>::Create(std::move(sock), data_, endpoint_, service_name_)
WebsocketSession<TSession, TSessionContext>::Create(std::move(sock), session_context_, endpoint_,
service_name_)
->DoAccept(parser.release());
execution_active_ = false;
return;
@ -535,7 +567,7 @@ class Session final : public std::enable_shared_from_this<Session<TSession, TSes
communication::Buffer input_buffer_;
OutputStream output_stream_;
TSession session_;
TSessionData *data_;
TSessionContext *session_context_;
tcp::endpoint endpoint_;
tcp::endpoint remote_endpoint_;
std::string_view service_name_;

18
src/dbms/constants.hpp Normal file
View File

@ -0,0 +1,18 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
namespace memgraph::dbms {
constexpr static const char *kDefaultDB = "memgraph"; //!< Name of the default database
} // namespace memgraph::dbms

110
src/dbms/global.hpp Normal file
View File

@ -0,0 +1,110 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <concepts>
#include <cstdint>
#include <string>
#include "utils/exceptions.hpp"
namespace memgraph::dbms {
enum class DeleteError : uint8_t {
DEFAULT_DB,
USING,
NON_EXISTENT,
FAIL,
DISK_FAIL,
};
enum class NewError : uint8_t {
NO_CONFIGS,
EXISTS,
DEFUNCT,
GENERIC,
};
enum class SetForResult : uint8_t {
SUCCESS,
ALREADY_SET,
FAIL,
};
/**
* UnknownSession Exception
*
* Used to indicate that an unknown session was used.
*/
class UnknownSessionException : public utils::BasicException {
public:
using utils::BasicException::BasicException;
};
/**
* UnknownDatabase Exception
*
* Used to indicate that an unknown database was used.
*/
class UnknownDatabaseException : public utils::BasicException {
public:
using utils::BasicException::BasicException;
};
/**
* @brief Session interface used by the DBMS to handle the the active sessions.
* @todo Try to remove this dependency from SessionContextHandler. OnDelete could be removed, as it only does an assert.
* OnChange could be removed if SetFor returned the pointer and the called then handled the OnChange execution.
* However, the interface is very useful to decouple the interpreter's query execution and the sessions themselves.
*/
class SessionInterface {
public:
SessionInterface() = default;
virtual ~SessionInterface() = default;
SessionInterface(const SessionInterface &) = default;
SessionInterface &operator=(const SessionInterface &) = default;
SessionInterface(SessionInterface &&) noexcept = default;
SessionInterface &operator=(SessionInterface &&) noexcept = default;
/**
* @brief Return the unique string identifying the session.
*
* @return std::string
*/
virtual std::string UUID() const = 0;
/**
* @brief Return the currently active database.
*
* @return std::string
*/
virtual std::string GetDatabaseName() const = 0;
#ifdef MG_ENTERPRISE
/**
* @brief Gets called on database change.
*
* @return SetForResult enum (SUCCESS, ALREADY_SET or FAIL)
*/
virtual dbms::SetForResult OnChange(const std::string &) = 0;
/**
* @brief Callback that gets called on database delete (drop).
*
* @return true on success
*/
virtual bool OnDelete(const std::string &) = 0;
#endif
};
} // namespace memgraph::dbms

142
src/dbms/handler.hpp Normal file
View File

@ -0,0 +1,142 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <filesystem>
#include <memory>
#include <optional>
#include <string_view>
#include <unordered_map>
#include "global.hpp"
#include "utils/result.hpp"
#include "utils/sync_ptr.hpp"
namespace memgraph::dbms {
/**
* @brief Generic multi-database content handler.
*
* @tparam TContext
* @tparam TConfig
*/
template <typename TContext, typename TConfig>
class Handler {
public:
using NewResult = utils::BasicResult<NewError, std::shared_ptr<TContext>>;
/**
* @brief Empty Handler constructor.
*
*/
Handler() {}
/**
* @brief Generate a new context and corresponding configuration.
*
* @tparam T1 Variadic template of context constructor arguments
* @tparam T2 Variadic template of config constructor arguments
* @param name Name associated with the new context/config pair
* @param args1 Arguments passed (as a tuple) to the context constructor
* @param args2 Arguments passed (as a tuple) to the config constructor
* @return NewResult
*/
template <typename... T1, typename... T2>
NewResult New(std::string name, std::tuple<T1...> args1, std::tuple<T2...> args2) {
return New_(name, args1, args2, std::make_index_sequence<sizeof...(T1)>{},
std::make_index_sequence<sizeof...(T2)>{});
}
/**
* @brief Get pointer to context.
*
* @param name Name associated with the wanted context
* @return std::optional<std::shared_ptr<TContext>>
*/
std::optional<std::shared_ptr<TContext>> Get(const std::string &name) {
if (auto search = items_.find(name); search != items_.end()) {
return search->second.get();
}
return {};
}
/**
* @brief Get the config.
*
* @param name Name associated with the wanted config
* @return std::optional<TConfig>
*/
std::optional<TConfig> GetConfig(const std::string &name) const {
if (auto search = items_.find(name); search != items_.end()) {
return search->second.config();
}
return {};
}
/**
* @brief Delete the context/config pair associated with the name.
*
* @param name Name associated with the context/config pair to delete
* @return true on success
*/
bool Delete(const std::string &name) {
if (auto itr = items_.find(name); itr != items_.end()) {
itr->second.DestroyAndSync();
items_.erase(itr);
return true;
}
return false;
}
/**
* @brief Check if a name is already used.
*
* @param name Name to check
* @return true if a context/config pair is already associated with the name
*/
bool Has(const std::string &name) const { return items_.find(name) != items_.end(); }
auto begin() { return items_.begin(); }
auto end() { return items_.end(); }
auto begin() const { return items_.begin(); }
auto end() const { return items_.end(); }
auto cbegin() const { return items_.cbegin(); }
auto cend() const { return items_.cend(); }
private:
/**
* @brief Lower level handler that hides some ugly code.
*
* @tparam T1 Variadic template of context constructor arguments
* @tparam T2 Variadic template of config constructor arguments
* @tparam I1 List of indexes associated with the first tuple
* @tparam I2 List of indexes associated with the second tuple
*/
template <typename... T1, typename... T2, std::size_t... I1, std::size_t... I2>
NewResult New_(std::string name, std::tuple<T1...> &args1, std::tuple<T2...> &args2,
std::integer_sequence<std::size_t, I1...> /*not-used*/,
std::integer_sequence<std::size_t, I2...> /*not-used*/) {
// Make sure the emplace will succeed, since we don't want to create temporary objects that could break something
if (!Has(name)) {
auto [itr, _] = items_.emplace(std::piecewise_construct, std::forward_as_tuple(name),
std::forward_as_tuple(TConfig{std::forward<T1>(std::get<I1>(args1))...},
std::forward<T2>(std::get<I2>(args2))...));
return itr->second.get();
}
spdlog::info("Item with name \"{}\" already exists.", name);
return NewError::EXISTS;
}
std::unordered_map<std::string, utils::SyncPtr<TContext, TConfig>> items_; //!< map to all active items
};
} // namespace memgraph::dbms

106
src/dbms/interp_handler.hpp Normal file
View File

@ -0,0 +1,106 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#ifdef MG_ENTERPRISE
#include "global.hpp"
#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/interpreter.hpp"
#include "storage/v2/storage.hpp"
#include "handler.hpp"
namespace memgraph::dbms {
/**
* @brief Simple class that adds useful information to the query's InterpreterContext
*
* @tparam T Multi-database handler type
*/
template <typename T>
class ExpandedInterpContext : public query::InterpreterContext {
public:
template <typename... TArgs>
explicit ExpandedInterpContext(T &ref, TArgs &&...args)
: query::InterpreterContext(std::forward<TArgs>(args)...), sc_handler_(ref) {}
T &sc_handler_; //!< Multi-database/SessionContext handler (used in some queries)
};
/**
* @brief Simple structure that expands on the query's InterpreterConfig
*
*/
struct ExpandedInterpConfig {
storage::Config storage_config; //!< Storage configuration
query::InterpreterConfig interp_config; //!< Interpreter configuration
};
/**
* @brief Multi-database interpreter context handler
*
* @tparam TSCHandler High-level multi-database/SessionContext handler type
*/
template <typename TSCHandler>
class InterpContextHandler : public Handler<ExpandedInterpContext<TSCHandler>, ExpandedInterpConfig> {
public:
using InterpContextT = ExpandedInterpContext<TSCHandler>;
using HandlerT = Handler<InterpContextT, ExpandedInterpConfig>;
/**
* @brief Generate a new interpreter context associated with the passed name.
*
* @param name Name associating the new interpreter context
* @param sc_handler Multi-database/SessionContext handler used (some queries might use it)
* @param db Storage associated with the interpreter context
* @param config Interpreter's configuration
* @param dir Directory used by the interpreter
* @param auth_handler AuthQueryHandler used
* @param auth_checker AuthChecker used
* @return HandlerT::NewResult
*/
typename HandlerT::NewResult New(const std::string &name, TSCHandler &sc_handler, storage::Config storage_config,
const query::InterpreterConfig &interpreter_config,
query::AuthQueryHandler &auth_handler, query::AuthChecker &auth_checker) {
// Check if compatible with the existing interpreters
if (std::any_of(HandlerT::cbegin(), HandlerT::cend(), [&](const auto &elem) {
const auto &config = elem.second.config().storage_config;
return config.durability.storage_directory == storage_config.durability.storage_directory;
})) {
spdlog::info("Tried to generate a new context using claimed directory and/or storage.");
return NewError::EXISTS;
}
const auto dir = storage_config.durability.storage_directory;
storage_config.name = name; // Set storage id via config
return HandlerT::New(
name, std::forward_as_tuple(storage_config, interpreter_config),
std::forward_as_tuple(sc_handler, storage_config, interpreter_config, dir, &auth_handler, &auth_checker));
}
/**
* @brief All currently active storage.
*
* @return std::vector<std::string>
*/
std::vector<std::string> All() const {
std::vector<std::string> res;
res.reserve(std::distance(HandlerT::cbegin(), HandlerT::cend()));
std::for_each(HandlerT::cbegin(), HandlerT::cend(), [&](const auto &elem) { res.push_back(elem.first); });
return res;
}
};
} // namespace memgraph::dbms
#endif

View File

@ -0,0 +1,61 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "auth/auth.hpp"
#include "query/interpreter.hpp"
#include "storage/v2/storage.hpp"
#include "utils/synchronized.hpp"
#if MG_ENTERPRISE
#include "audit/log.hpp"
#endif
namespace memgraph::dbms {
/**
* @brief Structure encapsulating storage and interpreter context.
*
* @note Each session contains a copy.
*/
struct SessionContext {
// Explicit constructor here to ensure that pointers to all objects are
// supplied.
SessionContext(std::shared_ptr<memgraph::query::InterpreterContext> interpreter_context, std::string run,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth
#ifdef MG_ENTERPRISE
,
memgraph::audit::Log *audit_log
#endif
)
: interpreter_context(interpreter_context),
run_id(run),
auth(auth)
#ifdef MG_ENTERPRISE
,
audit_log(audit_log)
#endif
{
}
std::shared_ptr<memgraph::query::InterpreterContext> interpreter_context;
std::string run_id;
// std::shared_ptr<AuthContext> auth_context;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth;
#ifdef MG_ENTERPRISE
memgraph::audit::Log *audit_log;
#endif
};
} // namespace memgraph::dbms

View File

@ -0,0 +1,603 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <algorithm>
#include <concepts>
#include <cstdint>
#include <filesystem>
#include <memory>
#include <mutex>
#include <optional>
#include <ostream>
#include <stdexcept>
#include <system_error>
#include <unordered_map>
#include "constants.hpp"
#include "global.hpp"
#include "interp_handler.hpp"
#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/interpreter.hpp"
#include "session_context.hpp"
#include "spdlog/spdlog.h"
#include "storage/v2/durability/durability.hpp"
#include "storage/v2/durability/paths.hpp"
#include "utils/exceptions.hpp"
#include "utils/file.hpp"
#include "utils/logging.hpp"
#include "utils/result.hpp"
#include "utils/rw_lock.hpp"
#include "utils/synchronized.hpp"
#include "utils/uuid.hpp"
#include "handler.hpp"
namespace memgraph::dbms {
#ifdef MG_ENTERPRISE
using DeleteResult = utils::BasicResult<DeleteError>;
/**
* @brief Multi-database session contexts handler.
*/
class SessionContextHandler {
public:
using StorageT = storage::Storage;
using StorageConfigT = storage::Config;
using LockT = utils::RWLock;
using NewResultT = utils::BasicResult<NewError, SessionContext>;
struct Config {
StorageConfigT storage_config; //!< Storage configuration
query::InterpreterConfig interp_config; //!< Interpreter context configuration
std::function<void(utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *,
std::unique_ptr<query::AuthQueryHandler> &, std::unique_ptr<query::AuthChecker> &)>
glue_auth;
};
struct Statistics {
uint64_t num_vertex; //!< Sum of vertexes in every database
uint64_t num_edges; //!< Sum of edges in every database
uint64_t num_databases; //! number of isolated databases
};
/**
* @brief Initialize the handler.
*
* @param audit_log pointer to the audit logger (ENTERPRISE only)
* @param configs storage and interpreter configurations
* @param recovery_on_startup restore databases (and its content) and authentication data
*/
SessionContextHandler(memgraph::audit::Log &audit_log, Config configs, bool recovery_on_startup, bool delete_on_drop)
: lock_{utils::RWLock::Priority::READ},
default_configs_(configs),
run_id_{utils::GenerateUUID()},
audit_log_(&audit_log),
delete_on_drop_(delete_on_drop) {
const auto &root = configs.storage_config.durability.storage_directory;
utils::EnsureDirOrDie(root);
// Verify that the user that started the process is the same user that is
// the owner of the storage directory.
storage::durability::VerifyStorageDirectoryOwnerAndProcessUserOrDie(root);
// Create the lock file and open a handle to it. This will crash the
// database if it can't open the file for writing or if any other process is
// holding the file opened.
lock_file_path_ = root / ".lock";
lock_file_handle_.Open(lock_file_path_, utils::OutputFile::Mode::OVERWRITE_EXISTING);
MG_ASSERT(lock_file_handle_.AcquireLock(),
"Couldn't acquire lock on the storage directory {}"
"!\nAnother Memgraph process is currently running with the same "
"storage directory, please stop it first before starting this "
"process!",
root);
// TODO: Figure out if this is needed/wanted
// Clear auth database since we are not recovering
// if (!recovery_on_startup) {
// const auto &auth_dir = root / "auth";
// // Backup if auth present
// if (utils::DirExists(auth_dir)) {
// auto backup_dir = root / storage::durability::kBackupDirectory;
// std::error_code error_code;
// utils::EnsureDirOrDie(backup_dir);
// std::error_code ec;
// const auto now = std::chrono::system_clock::now();
// std::ostringstream os;
// os << now.time_since_epoch().count();
// std::filesystem::rename(auth_dir, backup_dir / ("auth-" + os.str()), ec);
// MG_ASSERT(!ec, "Couldn't backup auth directory because of: {}", ec.message());
// spdlog::warn(
// "Since Memgraph was not supposed to recover on startup the authentication files will be "
// "overwritten. To prevent important data loss, Memgraph has stored those files into .backup directory "
// "inside the storage directory.");
// }
// // Clear
// if (std::filesystem::exists(auth_dir)) {
// std::filesystem::remove_all(auth_dir);
// }
// }
// Lazy initialization of auth_
auth_ = std::make_unique<utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock>>(root / "auth");
configs.glue_auth(auth_.get(), auth_handler_, auth_checker_);
// TODO: Decouple storage config from dbms config
// TODO: Save individual db configs inside the kvstore and restore from there
storage::UpdatePaths(default_configs_->storage_config,
default_configs_->storage_config.durability.storage_directory / "databases");
const auto &db_dir = default_configs_->storage_config.durability.storage_directory;
const auto durability_dir = db_dir / ".durability";
utils::EnsureDirOrDie(db_dir);
utils::EnsureDirOrDie(durability_dir);
durability_ = std::make_unique<kvstore::KVStore>(durability_dir);
// Generate the default database
MG_ASSERT(!NewDefault_().HasError(), "Failed while creating the default DB.");
// Recover previous databases
if (recovery_on_startup) {
for (const auto &[name, _] : *durability_) {
if (name == kDefaultDB) continue; // Already set
spdlog::info("Restoring database {}.", name);
MG_ASSERT(!New_(name).HasError(), "Failed while creating database {}.", name);
spdlog::info("Database {} restored.", name);
}
} else { // Clear databases from the durability list and auth
auto locked_auth = auth_->Lock();
for (const auto &[name, _] : *durability_) {
if (name == kDefaultDB) continue;
locked_auth->DeleteDatabase(name);
durability_->Delete(name);
}
}
}
void Shutdown() {
for (auto &ic : interp_handler_) memgraph::query::Shutdown(ic.second.get().get());
}
/**
* @brief Create a new SessionContext associated with the "name" database
*
* @param name name of the database
* @return NewResultT context on success, error on failure
*/
NewResultT New(const std::string &name) {
std::lock_guard<LockT> wr(lock_);
return New_(name, name);
}
/**
* @brief Get the context associated with the "name" database
*
* @param name
* @return SessionContext
* @throw UnknownDatabaseException if getting unknown database
*/
SessionContext Get(const std::string &name) {
std::shared_lock<LockT> rd(lock_);
return Get_(name);
}
/**
* @brief Set the undelying database for a particular session.
*
* @param uuid unique session identifier
* @param db_name unique database name
* @return SetForResult enum
* @throws UnknownDatabaseException, UnknownSessionException or anything OnChange throws
*/
SetForResult SetFor(const std::string &uuid, const std::string &db_name) {
std::shared_lock<LockT> rd(lock_);
(void)Get_(
db_name); // throws if db doesn't exist (TODO: Better to pass it via OnChange - but injecting dependency)
try {
auto &s = sessions_.at(uuid);
return s.OnChange(db_name);
} catch (std::out_of_range &) {
throw UnknownSessionException("Unknown session \"{}\"", uuid);
}
}
/**
* @brief Set the undelying database from a session itself. SessionContext handler.
*
* @param db_name unique database name
* @param handler function that gets called in place with the appropriate SessionContext
* @return SetForResult enum
*/
template <typename THandler>
requires std::invocable<THandler, SessionContext> SetForResult SetInPlace(const std::string &db_name,
THandler handler) {
std::shared_lock<LockT> rd(lock_);
return handler(Get_(db_name));
}
/**
* @brief Call void handler under a shared lock.
*
* @param handler function that gets called in place
*/
template <typename THandler>
requires std::invocable<THandler>
void CallInPlace(THandler handler) {
std::shared_lock<LockT> rd(lock_);
handler();
}
/**
* @brief Register an active session (used to handle callbacks).
*
* @param session
* @return true on success
*/
bool Register(SessionInterface &session) {
std::lock_guard<LockT> wr(lock_);
auto [_, success] = sessions_.emplace(session.UUID(), session);
return success;
}
/**
* @brief Delete a session.
*
* @param session
*/
bool Delete(const SessionInterface &session) {
std::lock_guard<LockT> wr(lock_);
return sessions_.erase(session.UUID()) > 0;
}
/**
* @brief Delete database.
*
* @param db_name database name
* @return DeleteResult error on failure
*/
DeleteResult Delete(const std::string &db_name) {
std::lock_guard<LockT> wr(lock_);
if (db_name == kDefaultDB) {
// MSG cannot delete the default db
return DeleteError::DEFAULT_DB;
}
// Check if db exists
try {
auto sc = Get_(db_name);
// Check if a session is using the db
if (!sc.interpreter_context->interpreters->empty()) {
return DeleteError::USING;
}
} catch (UnknownDatabaseException &) {
return DeleteError::NON_EXISTENT;
}
// High level handlers
for (auto &[_, s] : sessions_) {
if (!s.OnDelete(db_name)) {
spdlog::error("Partial failure while deleting database \"{}\".", db_name);
defunct_dbs_.emplace(db_name);
return DeleteError::FAIL;
}
}
// Low level handlers
const auto storage_path = StorageDir_(db_name);
MG_ASSERT(storage_path, "Missing storage for {}", db_name);
if (!interp_handler_.Delete(db_name)) {
spdlog::error("Partial failure while deleting database \"{}\".", db_name);
defunct_dbs_.emplace(db_name);
return DeleteError::FAIL;
}
// Remove from auth
auth_->Lock()->DeleteDatabase(db_name);
// Remove from durability list
if (durability_) durability_->Delete(db_name);
// Delete disk storage
if (delete_on_drop_) {
std::error_code ec;
(void)std::filesystem::remove_all(*storage_path, ec);
if (ec) {
spdlog::error("Failed to clean disk while deleting database \"{}\".", db_name);
defunct_dbs_.emplace(db_name);
return DeleteError::DISK_FAIL;
}
}
// Delete from defunct_dbs_ (in case a second delete call was successful)
defunct_dbs_.erase(db_name);
return {}; // Success
}
/**
* @brief Set the default configurations.
*
* @param configs storage, interpreter and authorization configurations
*/
void SetDefaultConfigs(Config configs) {
std::lock_guard<LockT> wr(lock_);
default_configs_ = configs;
}
/**
* @brief Get the default configurations.
*
* @return std::optional<Config>
*/
std::optional<Config> GetDefaultConfigs() const {
std::shared_lock<LockT> rd(lock_);
return default_configs_;
}
/**
* @brief Return all active databases.
*
* @return std::vector<std::string>
*/
std::vector<std::string> All() const {
std::shared_lock<LockT> rd(lock_);
return interp_handler_.All();
}
/**
* @brief Return the number of vertex across all databases.
*
* @return uint64_t
*/
Statistics Info() const {
// TODO: Handle overflow
uint64_t nv = 0;
uint64_t ne = 0;
std::shared_lock<LockT> rd(lock_);
const uint64_t ndb = std::distance(interp_handler_.cbegin(), interp_handler_.cend());
for (const auto &ic : interp_handler_) {
const auto &info = ic.second.get()->db->GetInfo();
nv += info.vertex_count;
ne += info.edge_count;
}
return {nv, ne, ndb};
}
/**
* @brief Return the currently active database for a particular session.
*
* @param uuid session's unique identifier
* @return std::string name of the database
* @throw
*/
std::string Current(const std::string &uuid) const {
std::shared_lock<LockT> rd(lock_);
return sessions_.at(uuid).GetDatabaseName();
}
/**
* @brief Restore triggers for all currently defined databases.
* @note: Triggers can execute query procedures, so we need to reload the modules first and then the triggers
*/
void RestoreTriggers() {
std::lock_guard<LockT> wr(lock_);
for (auto &ic_itr : interp_handler_) {
auto ic = ic_itr.second.get();
spdlog::debug("Restoring trigger for database \"{}\"", ic->db->id());
auto storage_accessor = ic->db->Access();
auto dba = memgraph::query::DbAccessor{storage_accessor.get()};
ic->trigger_store.RestoreTriggers(&ic->ast_cache, &dba, ic->config.query, ic->auth_checker);
}
}
/**
* @brief Restore streams of all currently defined databases.
* @note: Stream transformations are using modules, they have to be restored after the query modules are loaded.
*/
void RestoreStreams() {
std::lock_guard<LockT> wr(lock_);
for (auto &ic_itr : interp_handler_) {
auto ic = ic_itr.second.get();
spdlog::debug("Restoring streams for database \"{}\"", ic->db->id());
ic->streams.RestoreStreams();
}
}
private:
std::optional<std::filesystem::path> StorageDir_(const std::string &name) const {
const auto conf = interp_handler_.GetConfig(name);
if (conf) {
return conf->storage_config.durability.storage_directory;
}
spdlog::debug("Failed to find storage dir for database \"{}\"", name);
return {};
}
/**
* @brief Create a new SessionContext associated with the "name" database
*
* @param name name of the database
* @return NewResultT context on success, error on failure
*/
NewResultT New_(const std::string &name) { return New_(name, name); }
/**
* @brief Create a new SessionContext associated with the "name" database
*
* @param name name of the database
* @param storage_subdir undelying RocksDB directory
* @return NewResultT context on success, error on failure
*/
NewResultT New_(const std::string &name, std::filesystem::path storage_subdir) {
if (default_configs_) {
auto storage = default_configs_->storage_config;
storage::UpdatePaths(storage, storage.durability.storage_directory / storage_subdir);
return New_(name, storage, default_configs_->interp_config);
}
spdlog::info("Trying to generate session context without any configurations.");
return NewError::NO_CONFIGS;
}
/**
* @brief Create a new SessionContext associated with the "name" database
*
* @param name name of the database
* @param storage_config storage configuration
* @param inter_config interpreter configuration
* @return NewResultT context on success, error on failure
*/
NewResultT New_(const std::string &name, StorageConfigT &storage_config, query::InterpreterConfig &inter_config/*,
const std::string &ah_flags*/) {
MG_ASSERT(auth_handler_, "No high level AuthQueryHandler has been supplied.");
MG_ASSERT(auth_checker_, "No high level AuthChecker has been supplied.");
if (defunct_dbs_.contains(name)) {
spdlog::warn("Failed to generate database due to the unknown state of the previously defunct database \"{}\".",
name);
return NewError::DEFUNCT;
}
auto new_interp = interp_handler_.New(name, *this, storage_config, inter_config, *auth_handler_, *auth_checker_);
if (new_interp.HasValue()) {
// Success
if (durability_) durability_->Put(name, "ok");
return SessionContext{new_interp.GetValue(), run_id_, auth_.get(), audit_log_};
}
return new_interp.GetError();
}
/**
* @brief Create a new SessionContext associated with the default database
*
* @return NewResultT context on success, error on failure
*/
NewResultT NewDefault_() {
// Create the default DB in the root (this is how it was done pre multi-tenancy)
auto res = New_(kDefaultDB, "..");
if (res.HasValue()) {
// For back-compatibility...
// Recreate the dbms layout for the default db and symlink to the root
const auto dir = StorageDir_(kDefaultDB);
MG_ASSERT(dir, "Failed to find storage path.");
const auto main_dir = *dir / "databases" / kDefaultDB;
if (!std::filesystem::exists(main_dir)) {
std::filesystem::create_directory(main_dir);
}
// Force link on-disk directories
const auto conf = interp_handler_.GetConfig(kDefaultDB);
MG_ASSERT(conf, "No configuration for the default database.");
const auto &tmp_conf = conf->storage_config.disk;
std::vector<std::filesystem::path> to_link{
tmp_conf.main_storage_directory, tmp_conf.label_index_directory,
tmp_conf.label_property_index_directory, tmp_conf.unique_constraints_directory,
tmp_conf.name_id_mapper_directory, tmp_conf.id_name_mapper_directory,
tmp_conf.durability_directory, tmp_conf.wal_directory,
};
// Add in-memory paths
// Some directories are redundant (skip those)
const std::vector<std::string> skip{".lock", "audit_log", "auth", "databases", "internal_modules", "settings"};
for (auto const &item : std::filesystem::directory_iterator{*dir}) {
const auto dir_name = std::filesystem::relative(item.path(), item.path().parent_path());
if (std::find(skip.begin(), skip.end(), dir_name) != skip.end()) continue;
to_link.push_back(item.path());
}
// Symlink to root dir
for (auto const &item : to_link) {
const auto dir_name = std::filesystem::relative(item, item.parent_path());
const auto link = main_dir / dir_name;
const auto to = std::filesystem::relative(item, main_dir);
if (!std::filesystem::is_symlink(link) && !std::filesystem::exists(link)) {
std::filesystem::create_directory_symlink(to, link);
} else { // Check existing link
std::error_code ec;
const auto test_link = std::filesystem::read_symlink(link, ec);
if (ec || test_link != to) {
MG_ASSERT(false,
"Memgraph storage directory incompatible with new version.\n"
"Please use a clean directory or remove \"{}\" and try again.",
link.string());
}
}
}
}
return res;
}
/**
* @brief Get the context associated with the "name" database
*
* @param name
* @return SessionContext
* @throw UnknownDatabaseException if trying to get unknown database
*/
SessionContext Get_(const std::string &name) {
auto interp = interp_handler_.Get(name);
if (interp) {
return SessionContext{*interp, run_id_, auth_.get(), audit_log_};
}
throw UnknownDatabaseException("Tried to retrieve an unknown database \"{}\".", name);
}
// Should storage objects ever be deleted?
mutable LockT lock_; //!< protective lock
std::filesystem::path lock_file_path_; //!< Lock file protecting the main storage
utils::OutputFile lock_file_handle_; //!< Handler the lock (crash if already open)
InterpContextHandler<SessionContextHandler> interp_handler_; //!< multi-tenancy interpreter handler
// AuthContextHandler auth_handler_; //!< multi-tenancy authorization handler (currently we use a single global
// auth)
std::unique_ptr<utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock>> auth_;
std::unique_ptr<query::AuthQueryHandler> auth_handler_;
std::unique_ptr<query::AuthChecker> auth_checker_;
std::optional<Config> default_configs_; //!< default storage and interpreter configurations
const std::string run_id_; //!< run's unique identifier (auto generated)
memgraph::audit::Log *audit_log_; //!< pointer to the audit logger
std::unordered_map<std::string, SessionInterface &> sessions_; //!< map of active/registered sessions
std::unique_ptr<kvstore::KVStore> durability_; //!< list of active dbs (pointer so we can postpone its creation)
std::set<std::string> defunct_dbs_; //!< Databases that are in an unknown state due to various failures
bool delete_on_drop_; //!< Flag defining if dropping storage also deletes its directory
public:
static SessionContextHandler &ExtractSCH(query::InterpreterContext *interpreter_context) {
return static_cast<typename decltype(interp_handler_)::InterpContextT *>(interpreter_context)->sc_handler_;
}
};
#else
/**
* @brief Initialize the handler.
*
* @param auth pointer to the authenticator
* @param configs storage and interpreter configurations
*/
static inline SessionContext Init(storage::Config &storage_config, query::InterpreterConfig &interp_config,
utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth,
query::AuthQueryHandler *auth_handler, query::AuthChecker *auth_checker) {
MG_ASSERT(auth, "Passed a nullptr auth");
MG_ASSERT(auth_handler, "Passed a nullptr auth_handler");
MG_ASSERT(auth_checker, "Passed a nullptr auth_checker");
storage_config.name = kDefaultDB;
auto interp_context = std::make_shared<query::InterpreterContext>(
storage_config, interp_config, storage_config.durability.storage_directory, auth_handler, auth_checker);
MG_ASSERT(interp_context, "Failed to construct main interpret context.");
return SessionContext{interp_context, utils::GenerateUUID(), auth};
}
#endif
} // namespace memgraph::dbms

View File

@ -62,6 +62,10 @@ auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege) {
return auth::Permission::STORAGE_MODE;
case query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT:
return auth::Permission::TRANSACTION_MANAGEMENT;
case query::AuthQuery::Privilege::MULTI_DATABASE_EDIT:
return auth::Permission::MULTI_DATABASE_EDIT;
case query::AuthQuery::Privilege::MULTI_DATABASE_USE:
return auth::Permission::MULTI_DATABASE_USE;
}
}

View File

@ -71,7 +71,8 @@ AuthChecker::AuthChecker(
: auth_(auth) {}
bool AuthChecker::IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) const {
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
const std::string &db_name) const {
std::optional<memgraph::auth::User> maybe_user;
{
auto locked_auth = auth_->ReadLock();
@ -83,7 +84,7 @@ bool AuthChecker::IsUserAuthorized(const std::optional<std::string> &username,
}
}
return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges);
return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges, db_name);
}
#ifdef MG_ENTERPRISE
@ -108,7 +109,13 @@ std::unique_ptr<memgraph::query::FineGrainedAuthChecker> AuthChecker::GetFineGra
#endif
bool AuthChecker::IsUserAuthorized(const memgraph::auth::User &user,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) {
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
const std::string &db_name) { // NOLINT
#ifdef MG_ENTERPRISE
if (!db_name.empty() && !user.db_access().Contains(db_name)) {
return false;
}
#endif
const auto user_permissions = user.GetPermissions();
return std::all_of(privileges.begin(), privileges.end(), [&user_permissions](const auto privilege) {
return user_permissions.Has(memgraph::glue::PrivilegeToPermission(privilege)) ==

View File

@ -25,7 +25,8 @@ class AuthChecker : public query::AuthChecker {
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth);
bool IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<query::AuthQuery::Privilege> &privileges) const override;
const std::vector<query::AuthQuery::Privilege> &privileges,
const std::string &db_name) const override;
#ifdef MG_ENTERPRISE
std::unique_ptr<memgraph::query::FineGrainedAuthChecker> GetFineGrainedAuthChecker(
@ -33,7 +34,8 @@ class AuthChecker : public query::AuthChecker {
#endif
[[nodiscard]] static bool IsUserAuthorized(const memgraph::auth::User &user,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges);
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
const std::string &db_name = "");
private:
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;

View File

@ -16,6 +16,7 @@
#include <fmt/format.h>
#include "auth/models.hpp"
#include "dbms/constants.hpp"
#include "glue/auth.hpp"
#include "license/license.hpp"
#include "query/constants.hpp"
@ -122,6 +123,29 @@ std::vector<std::vector<memgraph::query::TypedValue>> ShowRolePrivileges(
}
#ifdef MG_ENTERPRISE
std::vector<std::vector<memgraph::query::TypedValue>> ShowDatabasePrivileges(
const std::optional<memgraph::auth::User> &user) {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast() || !user) {
return {};
}
const auto &db = user->db_access();
const auto &allows = db.GetAllowAll();
const auto &grants = db.GetGrants();
const auto &denies = db.GetDenies();
std::vector<memgraph::query::TypedValue> res; // First element is a list of granted databases, second of revoked ones
if (allows) {
res.emplace_back("*");
} else {
std::vector<memgraph::query::TypedValue> grants_vec(grants.cbegin(), grants.cend());
res.emplace_back(std::move(grants_vec));
}
std::vector<memgraph::query::TypedValue> denies_vec(denies.cbegin(), denies.cend());
res.emplace_back(std::move(denies_vec));
return {res};
}
std::vector<FineGrainedPermissionForPrivilegeResult> GetFineGrainedPermissionForPrivilegeForUserOrRole(
const memgraph::auth::FineGrainedAccessPermissions &permissions, const std::string &permission_type,
const std::string &user_or_role) {
@ -268,6 +292,10 @@ bool AuthQueryHandler::CreateUser(const std::string &username, const std::option
}
#endif
);
#ifdef MG_ENTERPRISE
GrantDatabaseToUser(auth::kAllDatabases, username);
SetMainDatabase(username, dbms::kDefaultDB);
#endif
}
return user_added;
@ -319,6 +347,67 @@ bool AuthQueryHandler::CreateRole(const std::string &rolename) {
}
}
#ifdef MG_ENTERPRISE
bool AuthQueryHandler::RevokeDatabaseFromUser(const std::string &db, const std::string &username) {
if (!std::regex_match(username, name_regex_)) {
throw memgraph::query::QueryRuntimeException("Invalid user name.");
}
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) return false;
return locked_auth->RevokeDatabaseFromUser(db, username);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
bool AuthQueryHandler::GrantDatabaseToUser(const std::string &db, const std::string &username) {
if (!std::regex_match(username, name_regex_)) {
throw memgraph::query::QueryRuntimeException("Invalid user name.");
}
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) return false;
return locked_auth->GrantDatabaseToUser(db, username);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
std::vector<std::vector<memgraph::query::TypedValue>> AuthQueryHandler::GetDatabasePrivileges(
const std::string &username) {
if (!std::regex_match(username, name_regex_)) {
throw memgraph::query::QueryRuntimeException("Invalid user or role name.");
}
try {
auto locked_auth = auth_->ReadLock();
auto user = locked_auth->GetUser(username);
if (!user) {
throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist.", username);
}
return ShowDatabasePrivileges(user);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
bool AuthQueryHandler::SetMainDatabase(const std::string &db, const std::string &username) {
if (!std::regex_match(username, name_regex_)) {
throw memgraph::query::QueryRuntimeException("Invalid user name.");
}
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) return false;
return locked_auth->SetMainDatabase(db, username);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
#endif
bool AuthQueryHandler::DropRole(const std::string &rolename) {
if (!std::regex_match(rolename, name_regex_)) {
throw memgraph::query::QueryRuntimeException("Invalid role name.");

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -38,6 +38,16 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
void SetPassword(const std::string &username, const std::optional<std::string> &password) override;
#ifdef MG_ENTERPRISE
bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) override;
bool GrantDatabaseToUser(const std::string &db, const std::string &username) override;
std::vector<std::vector<memgraph::query::TypedValue>> GetDatabasePrivileges(const std::string &username) override;
bool SetMainDatabase(const std::string &db, const std::string &username) override;
#endif
bool CreateRole(const std::string &rolename) override;
bool DropRole(const std::string &rolename) override;

View File

@ -47,10 +47,10 @@ struct MetricsResponse {
std::vector<std::tuple<std::string, std::string, uint64_t>> event_histograms{};
};
template <typename TSessionData>
template <typename TSessionContext>
class MetricsService {
public:
explicit MetricsService(TSessionData *data) : db_(data->interpreter_context->db.get()) {}
explicit MetricsService(TSessionContext *session_context) : db_(session_context->interpreter_context->db.get()) {}
nlohmann::json GetMetricsJSON() {
auto response = GetMetrics();
@ -141,10 +141,10 @@ class MetricsService {
}
};
template <typename TSessionData>
template <typename TSessionContext>
class MetricsRequestHandler final {
public:
explicit MetricsRequestHandler(TSessionData *data) : service_(data) {
explicit MetricsRequestHandler(TSessionContext *session_context) : service_(session_context) {
spdlog::info("Basic request handler started!");
}
@ -206,6 +206,6 @@ class MetricsRequestHandler final {
}
private:
MetricsService<TSessionData> service_;
MetricsService<TSessionContext> service_;
};
} // namespace memgraph::http

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -36,7 +36,13 @@ KVStore::KVStore(std::filesystem::path storage) : pimpl_(std::make_unique<impl>(
pimpl_->db.reset(db);
}
KVStore::~KVStore() {}
KVStore::~KVStore() {
spdlog::debug("Destroying KVStore at {}", pimpl_->storage.string());
const auto sync = pimpl_->db->SyncWAL();
if (!sync.ok()) spdlog::error("KVStore sync failed!");
const auto close = pimpl_->db->Close();
if (!close.ok()) spdlog::error("KVStore close failed!");
}
KVStore::KVStore(KVStore &&other) { pimpl_ = std::move(other.pimpl_); }

View File

@ -40,6 +40,9 @@
#include "communication/http/server.hpp"
#include "communication/websocket/auth.hpp"
#include "communication/websocket/server.hpp"
#include "dbms/constants.hpp"
#include "dbms/global.hpp"
#include "dbms/session_context.hpp"
#include "glue/auth_checker.hpp"
#include "glue/auth_handler.hpp"
#include "helpers.hpp"
@ -98,6 +101,7 @@
#include "communication/init.hpp"
#include "communication/v2/server.hpp"
#include "communication/v2/session.hpp"
#include "dbms/session_context_handler.hpp"
#include "glue/communication.hpp"
#include "auth/auth.hpp"
@ -157,6 +161,10 @@ DEFINE_string(init_data_file, "", "Path to cypherl file that is used for creatin
// `mg_import_csv`. If you change it, make sure to change it there as well.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_string(data_directory, "mg_data", "Path to directory in which to save all permanent data.");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_bool(data_recovery_on_startup, false, "Controls whether the database recovers persisted data on startup.");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_uint64(memory_warning_threshold, 1024,
"Memory warning threshold, in MB. If Memgraph detects there is "
@ -174,8 +182,11 @@ DEFINE_VALIDATED_uint64(storage_gc_cycle_sec, 30, "Storage garbage collector int
// `mg_import_csv`. If you change it, make sure to change it there as well.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_bool(storage_properties_on_edges, false, "Controls whether edges have properties.");
// storage_recover_on_startup deprecated; use data_recovery_on_startup instead
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_bool(storage_recover_on_startup, false, "Controls whether the storage recovers persisted data on startup.");
DEFINE_HIDDEN_bool(storage_recover_on_startup, false,
"Controls whether the storage recovers persisted data on startup.");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_VALIDATED_uint64(storage_snapshot_interval_sec, 0,
"Storage snapshot creation interval (in seconds). Set "
@ -215,6 +226,12 @@ DEFINE_uint64(storage_recovery_thread_count,
memgraph::storage::Config::Durability().recovery_thread_count),
"The number of threads used to recover persisted data from disk.");
#ifdef MG_ENTERPRISE
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_bool(storage_delete_on_drop, true,
"If set to true the query 'DROP DATABASE x' will delete the underlying storage as well.");
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_bool(telemetry_enabled, false,
"Set to true to enable telemetry. We collect information about the "
@ -447,35 +464,6 @@ void AddLoggerSink(spdlog::sink_ptr new_sink) {
DEFINE_HIDDEN_string(license_key, "", "License key for Memgraph Enterprise.");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_HIDDEN_string(organization_name, "", "Organization name.");
/// Encapsulates Dbms and Interpreter that are passed through the network server
/// and worker to the session.
struct SessionData {
// Explicit constructor here to ensure that pointers to all objects are
// supplied.
#if MG_ENTERPRISE
SessionData(memgraph::query::InterpreterContext *interpreter_context,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth,
memgraph::audit::Log *audit_log)
: interpreter_context(interpreter_context), auth(auth), audit_log(audit_log) {}
memgraph::query::InterpreterContext *interpreter_context;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth;
memgraph::audit::Log *audit_log;
#else
SessionData(memgraph::query::InterpreterContext *interpreter_context,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth)
: interpreter_context(interpreter_context), auth(auth) {}
memgraph::query::InterpreterContext *interpreter_context;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth;
#endif
// NOTE: run_id should be const but that complicates code a lot.
std::optional<std::string> run_id;
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_string(auth_user_or_role_name_regex, memgraph::glue::kDefaultUserRoleRegex.data(),
"Set to the regular expression that each user or role name must fulfill.");
@ -498,7 +486,7 @@ void InitFromCypherlFile(memgraph::query::InterpreterContext &ctx, std::string c
interpreter.Pull(&stream, {}, results.qid);
if (audit_log) {
audit_log->Record("", "", line, {});
audit_log->Record("", "", line, {}, memgraph::dbms::kDefaultDB);
}
}
}
@ -529,41 +517,146 @@ auto ToQueryExtras(memgraph::communication::bolt::Value const &extra) -> memgrap
return memgraph::query::QueryExtras{std::move(metadata_pv), tx_timeout};
}
class BoltSession final : public memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
memgraph::communication::v2::OutputStream> {
class SessionHL final : public memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
memgraph::communication::v2::OutputStream> {
public:
BoltSession(SessionData *data, const memgraph::communication::v2::ServerEndpoint &endpoint,
memgraph::communication::v2::InputStream *input_stream,
memgraph::communication::v2::OutputStream *output_stream)
struct ContextWrapper {
explicit ContextWrapper(memgraph::dbms::SessionContext sc)
: session_context(sc),
interpreter(std::make_unique<memgraph::query::Interpreter>(session_context.interpreter_context.get())),
defunct_(false) {
session_context.interpreter_context->interpreters.WithLock(
[this](auto &interpreters) { interpreters.insert(interpreter.get()); });
}
~ContextWrapper() { Defunct(); }
void Defunct() {
if (!defunct_) {
session_context.interpreter_context->interpreters.WithLock(
[this](auto &interpreters) { interpreters.erase(interpreter.get()); });
defunct_ = true;
}
}
ContextWrapper(const ContextWrapper &) = delete;
ContextWrapper &operator=(const ContextWrapper &) = delete;
ContextWrapper(ContextWrapper &&in) noexcept
: session_context(std::move(in.session_context)),
interpreter(std::move(in.interpreter)),
defunct_(in.defunct_) {
in.defunct_ = true;
}
ContextWrapper &operator=(ContextWrapper &&in) noexcept {
if (this != &in) {
Defunct();
session_context = std::move(in.session_context);
interpreter = std::move(in.interpreter);
defunct_ = in.defunct_;
in.defunct_ = true;
}
return *this;
}
memgraph::query::InterpreterContext *interpreter_context() { return session_context.interpreter_context.get(); }
memgraph::query::Interpreter *interp() { return interpreter.get(); }
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth() const {
return session_context.auth;
}
#ifdef MG_ENTERPRISE
memgraph::audit::Log *audit_log() const { return session_context.audit_log; }
#endif
std::string run_id() const { return session_context.run_id; }
bool defunct() const { return defunct_; }
private:
memgraph::dbms::SessionContext session_context;
std::unique_ptr<memgraph::query::Interpreter> interpreter;
bool defunct_;
};
SessionHL(
#ifdef MG_ENTERPRISE
memgraph::dbms::SessionContextHandler &sc_handler,
#else
memgraph::dbms::SessionContext sc,
#endif
const memgraph::communication::v2::ServerEndpoint &endpoint,
memgraph::communication::v2::InputStream *input_stream, memgraph::communication::v2::OutputStream *output_stream,
const std::string &default_db = memgraph::dbms::kDefaultDB) // NOLINT
: memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
memgraph::communication::v2::OutputStream>(input_stream, output_stream),
interpreter_context_(data->interpreter_context),
interpreter_(data->interpreter_context),
auth_(data->auth),
#if MG_ENTERPRISE
audit_log_(data->audit_log),
#ifdef MG_ENTERPRISE
sc_handler_(sc_handler),
current_(sc_handler_.Get(default_db)),
#else
current_(sc),
#endif
interpreter_context_(current_.interpreter_context()),
interpreter_(current_.interp()),
auth_(current_.auth()),
#ifdef MG_ENTERPRISE
audit_log_(current_.audit_log()),
#endif
endpoint_(endpoint),
run_id_(data->run_id) {
run_id_(current_.run_id()) {
memgraph::metrics::IncrementCounter(memgraph::metrics::ActiveBoltSessions);
interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter_); });
}
~BoltSession() override {
memgraph::metrics::DecrementCounter(memgraph::metrics::ActiveBoltSessions);
interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.erase(&interpreter_); });
~SessionHL() override { memgraph::metrics::DecrementCounter(memgraph::metrics::ActiveBoltSessions); }
SessionHL(const SessionHL &) = delete;
SessionHL &operator=(const SessionHL &) = delete;
SessionHL(SessionHL &&) = delete;
SessionHL &operator=(SessionHL &&) = delete;
void Configure(const std::map<std::string, memgraph::communication::bolt::Value> &run_time_info) override {
#ifdef MG_ENTERPRISE
std::string db;
bool update = false;
// Check if user explicitly defined the database to use
if (run_time_info.contains("db")) {
const auto &db_info = run_time_info.at("db");
if (!db_info.IsString()) {
throw memgraph::communication::bolt::ClientError("Malformed database name.");
}
db = db_info.ValueString();
update = db != current_.interpreter_context()->db->id();
in_explicit_db_ = true;
// NOTE: Once in a transaction, the drivers stop explicitly sending the db and count on using it until commit
} else if (in_explicit_db_ && !interpreter_->in_explicit_transaction_) { // Just on a switch
db = GetDefaultDB();
update = db != current_.interpreter_context()->db->id();
in_explicit_db_ = false;
}
// Check if the underlying database needs to be updated
if (update) {
sc_handler_.SetInPlace(db, [this](auto new_sc) mutable {
const auto &db_name = new_sc.interpreter_context->db->id();
MultiDatabaseAuth(db_name);
try {
Update(ContextWrapper(new_sc));
return memgraph::dbms::SetForResult::SUCCESS;
} catch (memgraph::dbms::UnknownDatabaseException &e) {
throw memgraph::communication::bolt::ClientError("No database named \"{}\" found!", db_name);
}
});
}
#endif
}
using memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
memgraph::communication::v2::OutputStream>::TEncoder;
using TEncoder = memgraph::communication::bolt::Encoder<
memgraph::communication::bolt::ChunkedEncoderBuffer<memgraph::communication::v2::OutputStream>>;
void BeginTransaction(const std::map<std::string, memgraph::communication::bolt::Value> &extra) override {
interpreter_.BeginTransaction(ToQueryExtras(extra));
interpreter_->BeginTransaction(ToQueryExtras(extra));
}
void CommitTransaction() override { interpreter_.CommitTransaction(); }
void CommitTransaction() override { interpreter_->CommitTransaction(); }
void RollbackTransaction() override { interpreter_.RollbackTransaction(); }
void RollbackTransaction() override { interpreter_->RollbackTransaction(); }
std::pair<std::vector<std::string>, std::optional<int>> Interpret(
const std::string &query, const std::map<std::string, memgraph::communication::bolt::Value> &params,
@ -580,16 +673,22 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
#ifdef MG_ENTERPRISE
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
audit_log_->Record(endpoint_.address().to_string(), user_ ? *username : "", query,
memgraph::storage::PropertyValue(params_pv));
memgraph::storage::PropertyValue(params_pv), interpreter_context_->db->id());
}
#endif
try {
auto result = interpreter_.Prepare(query, params_pv, username, ToQueryExtras(extra));
if (user_ && !memgraph::glue::AuthChecker::IsUserAuthorized(*user_, result.privileges)) {
interpreter_.Abort();
auto result = interpreter_->Prepare(query, params_pv, username, ToQueryExtras(extra), UUID());
const std::string db_name = result.db ? *result.db : "";
if (user_ && !memgraph::glue::AuthChecker::IsUserAuthorized(*user_, result.privileges, db_name)) {
interpreter_->Abort();
if (db_name.empty()) {
throw memgraph::communication::bolt::ClientError(
"You are not authorized to execute this query! Please contact your database administrator.");
}
throw memgraph::communication::bolt::ClientError(
"You are not authorized to execute this query! Please contact "
"your database administrator.");
"You are not authorized to execute this query on database \"{}\"! Please contact your database "
"administrator.",
db_name);
}
return {result.headers, result.qid};
@ -604,7 +703,7 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
std::map<std::string, memgraph::communication::bolt::Value> Pull(TEncoder *encoder, std::optional<int> n,
std::optional<int> qid) override {
TypedValueResultStream stream(encoder, interpreter_context_->db.get());
TypedValueResultStream stream(encoder, interpreter_context_);
return PullResults(stream, n, qid);
}
@ -614,14 +713,26 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
return PullResults(stream, n, qid);
}
void Abort() override { interpreter_.Abort(); }
void Abort() override { interpreter_->Abort(); }
// Called during Init
// During Init, the user cannot choose the landing DB (switch is done during query execution)
bool Authenticate(const std::string &username, const std::string &password) override {
auto locked_auth = auth_->Lock();
if (!locked_auth->HasUsers()) {
return true;
}
user_ = locked_auth->Authenticate(username, password);
#ifdef MG_ENTERPRISE
if (user_.has_value()) {
const auto &db = user_->db_access().GetDefault();
// Check if the underlying database needs to be updated
if (db != current_.interpreter_context()->db->id()) {
const auto &res = sc_handler_.SetFor(UUID(), db);
return res == memgraph::dbms::SetForResult::SUCCESS || res == memgraph::dbms::SetForResult::ALREADY_SET;
}
}
#endif
return user_.has_value();
}
@ -630,12 +741,31 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
return FLAGS_bolt_server_name_for_init;
}
#ifdef MG_ENTERPRISE
memgraph::dbms::SetForResult OnChange(const std::string &db_name) override {
MultiDatabaseAuth(db_name);
if (db_name != current_.interpreter_context()->db->id()) {
UpdateAndDefunct(db_name); // Done during Pull, so we cannot just replace the current db
return memgraph::dbms::SetForResult::SUCCESS;
}
return memgraph::dbms::SetForResult::ALREADY_SET;
}
bool OnDelete(const std::string &db_name) override {
MG_ASSERT(current_.interpreter_context()->db->id() != db_name && (!defunct_ || defunct_->defunct()),
"Trying to delete a database while still in use.");
return true;
}
#endif
std::string GetDatabaseName() const override { return interpreter_context_->db->id(); }
private:
template <typename TStream>
std::map<std::string, memgraph::communication::bolt::Value> PullResults(TStream &stream, std::optional<int> n,
std::optional<int> qid) {
try {
const auto &summary = interpreter_.Pull(&stream, n, qid);
const auto &summary = interpreter_->Pull(&stream, n, qid);
std::map<std::string, memgraph::communication::bolt::Value> decoded_summary;
for (const auto &kv : summary) {
auto maybe_value =
@ -660,6 +790,11 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
decoded_summary.emplace("run_id", *run_id);
}
// Clean up previous session (session gets defunct when switching between databases)
if (defunct_) {
defunct_.reset();
}
return decoded_summary;
} catch (const memgraph::query::QueryException &e) {
// Wrap QueryException into ClientError, because we want to allow the
@ -668,17 +803,71 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
}
}
#ifdef MG_ENTERPRISE
/**
* @brief Update setup to the new database.
*
* @param db_name name of the target database
* @throws UnknownDatabaseException if handler cannot get it
*/
void UpdateAndDefunct(const std::string &db_name) { UpdateAndDefunct(ContextWrapper(sc_handler_.Get(db_name))); }
void UpdateAndDefunct(ContextWrapper &&cntxt) {
defunct_.emplace(std::move(current_));
Update(std::forward<ContextWrapper>(cntxt));
defunct_->Defunct();
}
void Update(const std::string &db_name) {
ContextWrapper tmp(sc_handler_.Get(db_name));
Update(std::move(tmp));
}
void Update(ContextWrapper &&cntxt) {
current_ = std::move(cntxt);
interpreter_ = current_.interp();
interpreter_->in_explicit_db_ = in_explicit_db_;
interpreter_context_ = current_.interpreter_context();
}
/**
* @brief Authenticate user on passed database.
*
* @param db database to check against
* @throws bolt::ClientError when user is not authorized
*/
void MultiDatabaseAuth(const std::string &db) {
if (user_ && !memgraph::glue::AuthChecker::IsUserAuthorized(*user_, {}, db)) {
throw memgraph::communication::bolt::ClientError(
"You are not authorized on the database \"{}\"! Please contact your database administrator.", db);
}
}
/**
* @brief Get the user's default database
*
* @return std::string
*/
std::string GetDefaultDB() {
if (user_.has_value()) {
return user_->db_access().GetDefault();
}
return memgraph::dbms::kDefaultDB;
}
#endif
/// Wrapper around TEncoder which converts TypedValue to Value
/// before forwarding the calls to original TEncoder.
class TypedValueResultStream {
public:
TypedValueResultStream(TEncoder *encoder, const memgraph::storage::Storage *db) : encoder_(encoder), db_(db) {}
TypedValueResultStream(TEncoder *encoder, memgraph::query::InterpreterContext *ic)
: encoder_(encoder), interpreter_context_(ic) {}
void Result(const std::vector<memgraph::query::TypedValue> &values) {
std::vector<memgraph::communication::bolt::Value> decoded_values;
decoded_values.reserve(values.size());
for (const auto &v : values) {
auto maybe_value = memgraph::glue::ToBoltValue(v, *db_, memgraph::storage::View::NEW);
auto maybe_value = memgraph::glue::ToBoltValue(v, *interpreter_context_->db, memgraph::storage::View::NEW);
if (maybe_value.HasError()) {
switch (maybe_value.GetError()) {
case memgraph::storage::Error::DELETED_OBJECT:
@ -699,25 +888,36 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
private:
TEncoder *encoder_;
// NOTE: Needed only for ToBoltValue conversions
const memgraph::storage::Storage *db_;
memgraph::query::InterpreterContext *interpreter_context_;
};
// NOTE: Needed only for ToBoltValue conversions
#ifdef MG_ENTERPRISE
memgraph::dbms::SessionContextHandler &sc_handler_;
#endif
ContextWrapper current_;
std::optional<ContextWrapper> defunct_;
memgraph::query::InterpreterContext *interpreter_context_;
memgraph::query::Interpreter interpreter_;
memgraph::query::Interpreter *interpreter_;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
std::optional<memgraph::auth::User> user_;
#ifdef MG_ENTERPRISE
memgraph::audit::Log *audit_log_;
bool in_explicit_db_{false}; //!< If true, the user has defined the database to use via metadata
#endif
memgraph::communication::v2::ServerEndpoint endpoint_;
// NOTE: run_id should be const but that complicates code a lot.
std::optional<std::string> run_id_;
};
using ServerT = memgraph::communication::v2::Server<BoltSession, SessionData>;
#ifdef MG_ENTERPRISE
using ServerT = memgraph::communication::v2::Server<SessionHL, memgraph::dbms::SessionContextHandler>;
#else
using ServerT = memgraph::communication::v2::Server<SessionHL, memgraph::dbms::SessionContext>;
#endif
using MonitoringServerT =
memgraph::communication::http::Server<memgraph::http::MetricsRequestHandler<SessionData>, SessionData>;
memgraph::communication::http::Server<memgraph::http::MetricsRequestHandler<memgraph::dbms::SessionContext>,
memgraph::dbms::SessionContext>;
using memgraph::communication::ServerContext;
// Needed to correctly handle memgraph destruction from a signal handler.
@ -880,10 +1080,6 @@ int main(int argc, char **argv) {
// Begin enterprise features initialization
// Auth
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth{data_directory /
"auth"};
#ifdef MG_ENTERPRISE
// Audit log
memgraph::audit::Log audit_log{data_directory / "audit", FLAGS_audit_buffer_size,
@ -907,7 +1103,7 @@ int main(int argc, char **argv) {
.interval = std::chrono::seconds(FLAGS_storage_gc_cycle_sec)},
.items = {.properties_on_edges = FLAGS_storage_properties_on_edges},
.durability = {.storage_directory = FLAGS_data_directory,
.recover_on_startup = FLAGS_storage_recover_on_startup,
.recover_on_startup = FLAGS_storage_recover_on_startup || FLAGS_data_recovery_on_startup,
.snapshot_retention_count = FLAGS_storage_snapshot_retention_count,
.wal_file_size_kibibytes = FLAGS_storage_wal_file_size_kib,
.wal_file_flush_every_n_tx = FLAGS_storage_wal_file_flush_every_n_tx,
@ -944,31 +1140,62 @@ int main(int argc, char **argv) {
db_config.durability.snapshot_interval = std::chrono::seconds(FLAGS_storage_snapshot_interval_sec);
}
memgraph::query::InterpreterContext interpreter_context{
db_config,
{.query = {.allow_load_csv = FLAGS_allow_load_csv},
.execution_timeout_sec = FLAGS_query_execution_timeout_sec,
.replication_replica_check_frequency = std::chrono::seconds(FLAGS_replication_replica_check_frequency_sec),
.default_kafka_bootstrap_servers = FLAGS_kafka_bootstrap_servers,
.default_pulsar_service_url = FLAGS_pulsar_service_url,
.stream_transaction_conflict_retries = FLAGS_stream_transaction_conflict_retries,
.stream_transaction_retry_interval = std::chrono::milliseconds(FLAGS_stream_transaction_retry_interval)},
FLAGS_data_directory};
// Default interpreter configuration
memgraph::query::InterpreterConfig interp_config{
.query = {.allow_load_csv = FLAGS_allow_load_csv},
.execution_timeout_sec = FLAGS_query_execution_timeout_sec,
.replication_replica_check_frequency = std::chrono::seconds(FLAGS_replication_replica_check_frequency_sec),
.default_kafka_bootstrap_servers = FLAGS_kafka_bootstrap_servers,
.default_pulsar_service_url = FLAGS_pulsar_service_url,
.stream_transaction_conflict_retries = FLAGS_stream_transaction_conflict_retries,
.stream_transaction_retry_interval = std::chrono::milliseconds(FLAGS_stream_transaction_retry_interval)};
auto auth_glue =
[flag = FLAGS_auth_user_or_role_name_regex](
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth,
std::unique_ptr<memgraph::query::AuthQueryHandler> &ah, std::unique_ptr<memgraph::query::AuthChecker> &ac) {
// Glue high level auth implementations to the query side
ah = std::make_unique<memgraph::glue::AuthQueryHandler>(auth, flag);
ac = std::make_unique<memgraph::glue::AuthChecker>(auth);
// Handle users passed via arguments
auto *maybe_username = std::getenv(kMgUser);
auto *maybe_password = std::getenv(kMgPassword);
auto *maybe_pass_file = std::getenv(kMgPassfile);
if (maybe_username && maybe_password) {
ah->CreateUser(maybe_username, maybe_password);
} else if (maybe_pass_file) {
const auto [username, password] = LoadUsernameAndPassword(maybe_pass_file);
if (!username.empty() && !password.empty()) {
ah->CreateUser(username, password);
}
}
};
#ifdef MG_ENTERPRISE
SessionData session_data{&interpreter_context, &auth, &audit_log};
// SessionContext handler (multi-tenancy)
memgraph::dbms::SessionContextHandler sc_handler(audit_log, {db_config, interp_config, auth_glue},
FLAGS_storage_recover_on_startup || FLAGS_data_recovery_on_startup,
FLAGS_storage_delete_on_drop);
// Just for current support... TODO remove
auto session_context = sc_handler.Get(memgraph::dbms::kDefaultDB);
#else
SessionData session_data{&interpreter_context, &auth};
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth_{data_directory /
"auth"};
std::unique_ptr<memgraph::query::AuthQueryHandler> auth_handler;
std::unique_ptr<memgraph::query::AuthChecker> auth_checker;
auth_glue(&auth_, auth_handler, auth_checker);
auto session_context = memgraph::dbms::Init(db_config, interp_config, &auth_, auth_handler.get(), auth_checker.get());
#endif
auto *auth = session_context.auth;
auto &interpreter_context = *session_context.interpreter_context; // TODO remove
memgraph::query::procedure::gModuleRegistry.SetModulesDirectory(query_modules_directories, FLAGS_data_directory);
memgraph::query::procedure::gModuleRegistry.UnloadAndLoadModulesFromDirectories();
memgraph::query::procedure::gCallableAliasMapper.LoadMapping(FLAGS_query_callable_mappings_path);
memgraph::glue::AuthQueryHandler auth_handler(&auth, FLAGS_auth_user_or_role_name_regex);
memgraph::glue::AuthChecker auth_checker{&auth};
interpreter_context.auth = &auth_handler;
interpreter_context.auth_checker = &auth_checker;
if (!FLAGS_init_file.empty()) {
spdlog::info("Running init file...");
#ifdef MG_ENTERPRISE
@ -982,18 +1209,10 @@ int main(int argc, char **argv) {
#endif
}
auto *maybe_username = std::getenv(kMgUser);
auto *maybe_password = std::getenv(kMgPassword);
auto *maybe_pass_file = std::getenv(kMgPassfile);
if (maybe_username && maybe_password) {
auth_handler.CreateUser(maybe_username, maybe_password);
} else if (maybe_pass_file) {
const auto [username, password] = LoadUsernameAndPassword(maybe_pass_file);
if (!username.empty() && !password.empty()) {
auth_handler.CreateUser(username, password);
}
}
#ifdef MG_ENTERPRISE
sc_handler.RestoreTriggers();
sc_handler.RestoreStreams();
#else
{
// Triggers can execute query procedures, so we need to reload the modules first and then
// the triggers
@ -1005,6 +1224,7 @@ int main(int argc, char **argv) {
// As the Stream transformations are using modules, they have to be restored after the query modules are loaded.
interpreter_context.streams.RestoreStreams();
#endif
ServerContext context;
std::string service_name = "Bolt";
@ -1016,25 +1236,35 @@ int main(int argc, char **argv) {
spdlog::warn(
memgraph::utils::MessageWithLink("Using non-secure Bolt connection (without SSL).", "https://memgr.ph/ssl"));
}
auto server_endpoint = memgraph::communication::v2::ServerEndpoint{
boost::asio::ip::address::from_string(FLAGS_bolt_address), static_cast<uint16_t>(FLAGS_bolt_port)};
ServerT server(server_endpoint, &session_data, &context, FLAGS_bolt_session_inactivity_timeout, service_name,
#ifdef MG_ENTERPRISE
ServerT server(server_endpoint, &sc_handler, &context, FLAGS_bolt_session_inactivity_timeout, service_name,
FLAGS_bolt_num_workers);
#else
ServerT server(server_endpoint, &session_context, &context, FLAGS_bolt_session_inactivity_timeout, service_name,
FLAGS_bolt_num_workers);
#endif
const auto run_id = memgraph::utils::GenerateUUID();
const auto machine_id = memgraph::utils::GetMachineId();
session_data.run_id = run_id;
const auto run_id = session_context.run_id; // For current compatibility
// Setup telemetry
static constexpr auto telemetry_server{"https://telemetry.memgraph.com/88b5e7e8-746a-11e8-9f85-538a9e9690cc/"};
std::optional<memgraph::telemetry::Telemetry> telemetry;
if (FLAGS_telemetry_enabled) {
telemetry.emplace(telemetry_server, data_directory / "telemetry", run_id, machine_id, std::chrono::minutes(10));
telemetry->AddCollector("storage", [db_ = interpreter_context.db.get()]() -> nlohmann::json {
auto info = db_->GetInfo();
#ifdef MG_ENTERPRISE
telemetry->AddCollector("storage", [&sc_handler]() -> nlohmann::json {
const auto &info = sc_handler.Info();
return {{"vertices", info.num_vertex}, {"edges", info.num_edges}, {"databases", info.num_databases}};
});
#else
telemetry->AddCollector("storage", [&interpreter_context]() -> nlohmann::json {
auto info = interpreter_context.db->GetInfo();
return {{"vertices", info.vertex_count}, {"edges", info.edge_count}};
});
#endif
telemetry->AddCollector("event_counters", []() -> nlohmann::json {
nlohmann::json ret;
for (size_t i = 0; i < memgraph::metrics::CounterEnd(); ++i) {
@ -1050,25 +1280,25 @@ int main(int argc, char **argv) {
memgraph::license::LicenseInfoSender license_info_sender(telemetry_server, run_id, machine_id, memory_limit,
memgraph::license::global_license_checker.GetLicenseInfo());
memgraph::communication::websocket::SafeAuth websocket_auth{&auth};
memgraph::communication::websocket::SafeAuth websocket_auth{auth};
memgraph::communication::websocket::Server websocket_server{
{FLAGS_monitoring_address, static_cast<uint16_t>(FLAGS_monitoring_port)}, &context, websocket_auth};
AddLoggerSink(websocket_server.GetLoggingSink());
MonitoringServerT metrics_server{
{FLAGS_metrics_address, static_cast<uint16_t>(FLAGS_metrics_port)}, &session_data, &context};
{FLAGS_metrics_address, static_cast<uint16_t>(FLAGS_metrics_port)}, &session_context, &context};
#ifdef MG_ENTERPRISE
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
// Handler for regular termination signals
auto shutdown = [&metrics_server, &websocket_server, &server, &interpreter_context] {
auto shutdown = [&metrics_server, &websocket_server, &server, &sc_handler] {
// Server needs to be shutdown first and then the database. This prevents
// a race condition when a transaction is accepted during server shutdown.
server.Shutdown();
// After the server is notified to stop accepting and processing
// connections we tell the execution engine to stop processing all pending
// queries.
memgraph::query::Shutdown(&interpreter_context);
sc_handler.Shutdown();
websocket_server.Shutdown();
metrics_server.Shutdown();

View File

@ -24,7 +24,8 @@ class AuthChecker {
virtual ~AuthChecker() = default;
[[nodiscard]] virtual bool IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<query::AuthQuery::Privilege> &privileges) const = 0;
const std::vector<query::AuthQuery::Privilege> &privileges,
const std::string &db_name) const = 0;
#ifdef MG_ENTERPRISE
[[nodiscard]] virtual std::unique_ptr<FineGrainedAuthChecker> GetFineGrainedAuthChecker(
@ -92,7 +93,8 @@ class AllowEverythingFineGrainedAuthChecker final : public query::FineGrainedAut
class AllowEverythingAuthChecker final : public query::AuthChecker {
public:
bool IsUserAuthorized(const std::optional<std::string> & /*username*/,
const std::vector<query::AuthQuery::Privilege> & /*privileges*/) const override {
const std::vector<query::AuthQuery::Privilege> & /*privileges*/,
const std::string & /*db*/) const override {
return true;
}

View File

@ -495,6 +495,8 @@ class DbAccessor final {
storage::IndicesInfo ListAllIndices() const { return accessor_->ListAllIndices(); }
storage::ConstraintsInfo ListAllConstraints() const { return accessor_->ListAllConstraints(); }
const std::string &id() const { return accessor_->id(); }
};
class SubgraphDbAccessor final {

View File

@ -324,4 +324,10 @@ class ConstraintsPersistenceException : public QueryException {
ConstraintsPersistenceException() : QueryException("Persisting constraints on disk failed.") {}
};
class MultiDatabaseQueryInMulticommandTxException : public QueryException {
public:
MultiDatabaseQueryInMulticommandTxException()
: QueryException("Multi-database queries are not allowed in multicommand transactions.") {}
};
} // namespace memgraph::query

View File

@ -279,4 +279,10 @@ constexpr utils::TypeInfo query::Exists::kType{utils::TypeId::AST_EXISTS, "Exist
constexpr utils::TypeInfo query::CallSubquery::kType{utils::TypeId::AST_CALL_SUBQUERY, "CallSubquery",
&query::Clause::kType};
constexpr utils::TypeInfo query::MultiDatabaseQuery::kType{utils::TypeId::AST_MULTI_DATABASE_QUERY,
"MultiDatabaseQuery", &query::Query::kType};
constexpr utils::TypeInfo query::ShowDatabasesQuery::kType{utils::TypeId::AST_SHOW_DATABASES, "ShowDatabasesQuery",
&query::Query::kType};
} // namespace memgraph

View File

@ -2778,7 +2778,11 @@ class AuthQuery : public memgraph::query::Query {
REVOKE_PRIVILEGE,
SHOW_PRIVILEGES,
SHOW_ROLE_FOR_USER,
SHOW_USERS_FOR_ROLE
SHOW_USERS_FOR_ROLE,
GRANT_DATABASE_TO_USER,
REVOKE_DATABASE_FROM_USER,
SHOW_DATABASE_PRIVILEGES,
SET_MAIN_DATABASE,
};
enum class Privilege {
@ -2804,7 +2808,9 @@ class AuthQuery : public memgraph::query::Query {
MODULE_WRITE,
WEBSOCKET,
STORAGE_MODE,
TRANSACTION_MANAGEMENT
TRANSACTION_MANAGEMENT,
MULTI_DATABASE_EDIT,
MULTI_DATABASE_USE,
};
enum class FineGrainedPrivilege { NOTHING, READ, UPDATE, CREATE_DELETE };
@ -2818,6 +2824,7 @@ class AuthQuery : public memgraph::query::Query {
std::string role_;
std::string user_or_role_;
memgraph::query::Expression *password_{nullptr};
std::string database_;
std::vector<memgraph::query::AuthQuery::Privilege> privileges_;
std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
label_privileges_;
@ -2831,6 +2838,7 @@ class AuthQuery : public memgraph::query::Query {
object->role_ = role_;
object->user_or_role_ = user_or_role_;
object->password_ = password_ ? password_->Clone(storage) : nullptr;
object->database_ = database_;
object->privileges_ = privileges_;
object->label_privileges_ = label_privileges_;
object->edge_type_privileges_ = edge_type_privileges_;
@ -2839,7 +2847,7 @@ class AuthQuery : public memgraph::query::Query {
protected:
AuthQuery(Action action, std::string user, std::string role, std::string user_or_role, Expression *password,
std::vector<Privilege> privileges,
std::string database, std::vector<Privilege> privileges,
std::vector<std::unordered_map<FineGrainedPrivilege, std::vector<std::string>>> label_privileges,
std::vector<std::unordered_map<FineGrainedPrivilege, std::vector<std::string>>> edge_type_privileges)
: action_(action),
@ -2847,6 +2855,7 @@ class AuthQuery : public memgraph::query::Query {
role_(role),
user_or_role_(user_or_role),
password_(password),
database_(database),
privileges_(privileges),
label_privileges_(label_privileges),
edge_type_privileges_(edge_type_privileges) {}
@ -2856,19 +2865,31 @@ class AuthQuery : public memgraph::query::Query {
};
/// Constant that holds all available privileges.
const std::vector<AuthQuery::Privilege> kPrivilegesAll = {
AuthQuery::Privilege::CREATE, AuthQuery::Privilege::DELETE,
AuthQuery::Privilege::MATCH, AuthQuery::Privilege::MERGE,
AuthQuery::Privilege::SET, AuthQuery::Privilege::REMOVE,
AuthQuery::Privilege::INDEX, AuthQuery::Privilege::STATS,
AuthQuery::Privilege::AUTH, AuthQuery::Privilege::CONSTRAINT,
AuthQuery::Privilege::DUMP, AuthQuery::Privilege::REPLICATION,
AuthQuery::Privilege::READ_FILE, AuthQuery::Privilege::DURABILITY,
AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER,
AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM,
AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE,
AuthQuery::Privilege::WEBSOCKET, AuthQuery::Privilege::TRANSACTION_MANAGEMENT,
AuthQuery::Privilege::STORAGE_MODE};
const std::vector<AuthQuery::Privilege> kPrivilegesAll = {AuthQuery::Privilege::CREATE,
AuthQuery::Privilege::DELETE,
AuthQuery::Privilege::MATCH,
AuthQuery::Privilege::MERGE,
AuthQuery::Privilege::SET,
AuthQuery::Privilege::REMOVE,
AuthQuery::Privilege::INDEX,
AuthQuery::Privilege::STATS,
AuthQuery::Privilege::AUTH,
AuthQuery::Privilege::CONSTRAINT,
AuthQuery::Privilege::DUMP,
AuthQuery::Privilege::REPLICATION,
AuthQuery::Privilege::READ_FILE,
AuthQuery::Privilege::DURABILITY,
AuthQuery::Privilege::FREE_MEMORY,
AuthQuery::Privilege::TRIGGER,
AuthQuery::Privilege::CONFIG,
AuthQuery::Privilege::STREAM,
AuthQuery::Privilege::MODULE_READ,
AuthQuery::Privilege::MODULE_WRITE,
AuthQuery::Privilege::WEBSOCKET,
AuthQuery::Privilege::TRANSACTION_MANAGEMENT,
AuthQuery::Privilege::STORAGE_MODE,
AuthQuery::Privilege::MULTI_DATABASE_EDIT,
AuthQuery::Privilege::MULTI_DATABASE_USE};
class InfoQuery : public memgraph::query::Query {
public:
@ -3446,5 +3467,38 @@ class CallSubquery : public memgraph::query::Clause {
friend class AstStorage;
};
class MultiDatabaseQuery : public memgraph::query::Query {
public:
static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; }
DEFVISITABLE(QueryVisitor<void>);
enum class Action { CREATE, USE, DROP };
memgraph::query::MultiDatabaseQuery::Action action_;
std::string db_name_;
MultiDatabaseQuery *Clone(AstStorage *storage) const override {
auto *object = storage->Create<MultiDatabaseQuery>();
object->action_ = action_;
object->db_name_ = db_name_;
return object;
}
};
class ShowDatabasesQuery : public memgraph::query::Query {
public:
static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; }
DEFVISITABLE(QueryVisitor<void>);
ShowDatabasesQuery *Clone(AstStorage *storage) const override {
auto *object = storage->Create<ShowDatabasesQuery>();
return object;
}
};
} // namespace query
} // namespace memgraph

View File

@ -11,6 +11,7 @@
#pragma once
#include "query/frontend/ast/ast.hpp"
#include "utils/visitor.hpp"
namespace memgraph::query {
@ -102,6 +103,8 @@ class CallSubquery;
class AnalyzeGraphQuery;
class TransactionQueueQuery;
class Exists;
class MultiDatabaseQuery;
class ShowDatabasesQuery;
using TreeCompositeVisitor = utils::CompositeVisitor<
SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator,
@ -139,6 +142,7 @@ class QueryVisitor
: public utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery, InfoQuery,
ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery, FreeMemoryQuery, TriggerQuery,
IsolationLevelQuery, CreateSnapshotQuery, StreamQuery, SettingQuery, VersionQuery,
ShowConfigQuery, TransactionQueueQuery, StorageModeQuery, AnalyzeGraphQuery> {};
ShowConfigQuery, TransactionQueueQuery, StorageModeQuery, AnalyzeGraphQuery,
MultiDatabaseQuery, ShowDatabasesQuery> {};
} // namespace memgraph::query

View File

@ -1272,7 +1272,7 @@ antlrcpp::Any CypherMainVisitor::visitUserOrRoleName(MemgraphCypher::UserOrRoleN
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitCreateRole(MemgraphCypher::CreateRoleContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::CREATE_ROLE;
auth->role_ = std::any_cast<std::string>(ctx->role->accept(this));
return auth;
@ -1282,7 +1282,7 @@ antlrcpp::Any CypherMainVisitor::visitCreateRole(MemgraphCypher::CreateRoleConte
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitDropRole(MemgraphCypher::DropRoleContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::DROP_ROLE;
auth->role_ = std::any_cast<std::string>(ctx->role->accept(this));
return auth;
@ -1292,7 +1292,7 @@ antlrcpp::Any CypherMainVisitor::visitDropRole(MemgraphCypher::DropRoleContext *
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitShowRoles(MemgraphCypher::ShowRolesContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SHOW_ROLES;
return auth;
}
@ -1301,7 +1301,7 @@ antlrcpp::Any CypherMainVisitor::visitShowRoles(MemgraphCypher::ShowRolesContext
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitCreateUser(MemgraphCypher::CreateUserContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::CREATE_USER;
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
if (ctx->password) {
@ -1317,7 +1317,7 @@ antlrcpp::Any CypherMainVisitor::visitCreateUser(MemgraphCypher::CreateUserConte
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitSetPassword(MemgraphCypher::SetPasswordContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SET_PASSWORD;
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
if (!ctx->password->StringLiteral() && !ctx->literal()->CYPHERNULL()) {
@ -1331,7 +1331,7 @@ antlrcpp::Any CypherMainVisitor::visitSetPassword(MemgraphCypher::SetPasswordCon
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitDropUser(MemgraphCypher::DropUserContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::DROP_USER;
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
return auth;
@ -1341,7 +1341,7 @@ antlrcpp::Any CypherMainVisitor::visitDropUser(MemgraphCypher::DropUserContext *
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitShowUsers(MemgraphCypher::ShowUsersContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SHOW_USERS;
return auth;
}
@ -1350,7 +1350,7 @@ antlrcpp::Any CypherMainVisitor::visitShowUsers(MemgraphCypher::ShowUsersContext
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitSetRole(MemgraphCypher::SetRoleContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SET_ROLE;
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
auth->role_ = std::any_cast<std::string>(ctx->role->accept(this));
@ -1361,7 +1361,7 @@ antlrcpp::Any CypherMainVisitor::visitSetRole(MemgraphCypher::SetRoleContext *ct
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitClearRole(MemgraphCypher::ClearRoleContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::CLEAR_ROLE;
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
return auth;
@ -1371,7 +1371,7 @@ antlrcpp::Any CypherMainVisitor::visitClearRole(MemgraphCypher::ClearRoleContext
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitGrantPrivilege(MemgraphCypher::GrantPrivilegeContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::GRANT_PRIVILEGE;
auth->user_or_role_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
if (ctx->grantPrivilegesList()) {
@ -1393,7 +1393,7 @@ antlrcpp::Any CypherMainVisitor::visitGrantPrivilege(MemgraphCypher::GrantPrivil
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitDenyPrivilege(MemgraphCypher::DenyPrivilegeContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::DENY_PRIVILEGE;
auth->user_or_role_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
if (ctx->privilegesList()) {
@ -1453,7 +1453,7 @@ antlrcpp::Any CypherMainVisitor::visitGrantPrivilegesList(MemgraphCypher::GrantP
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitRevokePrivilege(MemgraphCypher::RevokePrivilegeContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::REVOKE_PRIVILEGE;
auth->user_or_role_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
if (ctx->revokePrivilegesList()) {
@ -1526,6 +1526,16 @@ antlrcpp::Any CypherMainVisitor::visitEntitiesList(MemgraphCypher::EntitiesListC
return entities;
}
/**
* @return std::string
*/
antlrcpp::Any CypherMainVisitor::visitWildcardName(MemgraphCypher::WildcardNameContext *ctx) {
if (ctx->symbolicName()) {
return ctx->symbolicName()->accept(this);
}
return std::string("*");
}
/**
* @return AuthQuery::Privilege
*/
@ -1553,6 +1563,8 @@ antlrcpp::Any CypherMainVisitor::visitPrivilege(MemgraphCypher::PrivilegeContext
if (ctx->WEBSOCKET()) return AuthQuery::Privilege::WEBSOCKET;
if (ctx->TRANSACTION_MANAGEMENT()) return AuthQuery::Privilege::TRANSACTION_MANAGEMENT;
if (ctx->STORAGE_MODE()) return AuthQuery::Privilege::STORAGE_MODE;
if (ctx->MULTI_DATABASE_EDIT()) return AuthQuery::Privilege::MULTI_DATABASE_EDIT;
if (ctx->MULTI_DATABASE_USE()) return AuthQuery::Privilege::MULTI_DATABASE_USE;
LOG_FATAL("Should not get here - unknown privilege!");
}
@ -1580,7 +1592,7 @@ antlrcpp::Any CypherMainVisitor::visitEntityType(MemgraphCypher::EntityTypeConte
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitShowPrivileges(MemgraphCypher::ShowPrivilegesContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SHOW_PRIVILEGES;
auth->user_or_role_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
return auth;
@ -1590,7 +1602,7 @@ antlrcpp::Any CypherMainVisitor::visitShowPrivileges(MemgraphCypher::ShowPrivile
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitShowRoleForUser(MemgraphCypher::ShowRoleForUserContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SHOW_ROLE_FOR_USER;
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
return auth;
@ -1600,12 +1612,55 @@ antlrcpp::Any CypherMainVisitor::visitShowRoleForUser(MemgraphCypher::ShowRoleFo
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitShowUsersForRole(MemgraphCypher::ShowUsersForRoleContext *ctx) {
AuthQuery *auth = storage_->Create<AuthQuery>();
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SHOW_USERS_FOR_ROLE;
auth->role_ = std::any_cast<std::string>(ctx->role->accept(this));
return auth;
}
/**
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitGrantDatabaseToUser(MemgraphCypher::GrantDatabaseToUserContext *ctx) {
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::GRANT_DATABASE_TO_USER;
auth->database_ = std::any_cast<std::string>(ctx->wildcardName()->accept(this));
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
return auth;
}
/**
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitRevokeDatabaseFromUser(MemgraphCypher::RevokeDatabaseFromUserContext *ctx) {
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::REVOKE_DATABASE_FROM_USER;
auth->database_ = std::any_cast<std::string>(ctx->wildcardName()->accept(this));
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
return auth;
}
/**
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitShowDatabasePrivileges(MemgraphCypher::ShowDatabasePrivilegesContext *ctx) {
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SHOW_DATABASE_PRIVILEGES;
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
return auth;
}
/**
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitSetMainDatabase(MemgraphCypher::SetMainDatabaseContext *ctx) {
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SET_MAIN_DATABASE;
auth->database_ = std::any_cast<std::string>(ctx->db->accept(this));
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
return auth;
}
antlrcpp::Any CypherMainVisitor::visitCypherReturn(MemgraphCypher::CypherReturnContext *ctx) {
auto *return_clause = storage_->Create<Return>();
return_clause->body_ = std::any_cast<ReturnBody>(ctx->returnBody()->accept(this));
@ -2671,4 +2726,33 @@ PropertyIx CypherMainVisitor::AddProperty(const std::string &name) { return stor
EdgeTypeIx CypherMainVisitor::AddEdgeType(const std::string &name) { return storage_->GetEdgeTypeIx(name); }
antlrcpp::Any CypherMainVisitor::visitCreateDatabase(MemgraphCypher::CreateDatabaseContext *ctx) {
auto *mdb_query = storage_->Create<MultiDatabaseQuery>();
mdb_query->db_name_ = std::any_cast<std::string>(ctx->databaseName()->accept(this));
mdb_query->action_ = MultiDatabaseQuery::Action::CREATE;
query_ = mdb_query;
return mdb_query;
}
antlrcpp::Any CypherMainVisitor::visitUseDatabase(MemgraphCypher::UseDatabaseContext *ctx) {
auto *mdb_query = storage_->Create<MultiDatabaseQuery>();
mdb_query->db_name_ = std::any_cast<std::string>(ctx->databaseName()->accept(this));
mdb_query->action_ = MultiDatabaseQuery::Action::USE;
query_ = mdb_query;
return mdb_query;
}
antlrcpp::Any CypherMainVisitor::visitDropDatabase(MemgraphCypher::DropDatabaseContext *ctx) {
auto *mdb_query = storage_->Create<MultiDatabaseQuery>();
mdb_query->db_name_ = std::any_cast<std::string>(ctx->databaseName()->accept(this));
mdb_query->action_ = MultiDatabaseQuery::Action::DROP;
query_ = mdb_query;
return mdb_query;
}
antlrcpp::Any CypherMainVisitor::visitShowDatabases(MemgraphCypher::ShowDatabasesContext * /*ctx*/) {
query_ = storage_->Create<ShowDatabasesQuery>();
return query_;
}
} // namespace memgraph::query::frontend

View File

@ -524,6 +524,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
*/
antlrcpp::Any visitEntitiesList(MemgraphCypher::EntitiesListContext *ctx) override;
/**
* @return std::string
*/
antlrcpp::Any visitWildcardName(MemgraphCypher::WildcardNameContext *ctx) override;
/**
* @return AuthQuery::FineGrainedPrivilege
*/
@ -554,6 +559,26 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
*/
antlrcpp::Any visitShowUsersForRole(MemgraphCypher::ShowUsersForRoleContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitGrantDatabaseToUser(MemgraphCypher::GrantDatabaseToUserContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitRevokeDatabaseFromUser(MemgraphCypher::RevokeDatabaseFromUserContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitShowDatabasePrivileges(MemgraphCypher::ShowDatabasePrivilegesContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitSetMainDatabase(MemgraphCypher::SetMainDatabaseContext *ctx) override;
/**
* @return Return*
*/
@ -935,6 +960,26 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
*/
antlrcpp::Any visitCallSubquery(MemgraphCypher::CallSubqueryContext *ctx) override;
/**
* @return MultiDatabaseQuery*
*/
antlrcpp::Any visitCreateDatabase(MemgraphCypher::CreateDatabaseContext *ctx) override;
/**
* @return MultiDatabaseQuery*
*/
antlrcpp::Any visitUseDatabase(MemgraphCypher::UseDatabaseContext *ctx) override;
/**
* @return MultiDatabaseQuery*
*/
antlrcpp::Any visitDropDatabase(MemgraphCypher::DropDatabaseContext *ctx) override;
/**
* @return ShowDatabasesQuery*
*/
antlrcpp::Any visitShowDatabases(MemgraphCypher::ShowDatabasesContext *ctx) override;
public:
Query *query() { return query_; }
const static std::string kAnonPrefix;

View File

@ -107,6 +107,7 @@ memgraphCypherKeyword : cypherKeyword
| UNCOMMITTED
| UNLOCK
| UPDATE
| USE
| USER
| USERS
| VERSION
@ -140,6 +141,8 @@ query : cypherQuery
| versionQuery
| showConfigQuery
| transactionQueueQuery
| multiDatabaseQuery
| showDatabases
;
authQuery : createRole
@ -157,6 +160,10 @@ authQuery : createRole
| showPrivileges
| showRoleForUser
| showUsersForRole
| grantDatabaseToUser
| revokeDatabaseFromUser
| showDatabasePrivileges
| setMainDatabase
;
replicationQuery : setReplicationRole
@ -208,6 +215,10 @@ streamQuery : checkStream
| showStreams
;
databaseName : symbolicName ;
wildcardName : ASTERISK | symbolicName ;
settingQuery : setSetting
| showSetting
| showSettings
@ -265,6 +276,14 @@ denyPrivilege : DENY ( ALL PRIVILEGES | privileges=privilegesList ) TO userOrRol
revokePrivilege : REVOKE ( ALL PRIVILEGES | privileges=revokePrivilegesList ) FROM userOrRole=userOrRoleName ;
grantDatabaseToUser : GRANT DATABASE db=wildcardName TO user=symbolicName ;
revokeDatabaseFromUser : REVOKE DATABASE db=wildcardName FROM user=symbolicName ;
showDatabasePrivileges : SHOW DATABASE PRIVILEGES FOR user=symbolicName ;
setMainDatabase : SET MAIN DATABASE db=symbolicName FOR user=symbolicName ;
privilege : CREATE
| DELETE
| MATCH
@ -288,6 +307,8 @@ privilege : CREATE
| WEBSOCKET
| TRANSACTION_MANAGEMENT
| STORAGE_MODE
| MULTI_DATABASE_EDIT
| MULTI_DATABASE_USE
;
granularPrivilege : NOTHING | READ | UPDATE | CREATE_DELETE ;
@ -441,3 +462,16 @@ versionQuery : SHOW VERSION ;
transactionIdList : transactionId ( ',' transactionId )* ;
transactionId : literal ;
multiDatabaseQuery : createDatabase
| useDatabase
| dropDatabase
;
createDatabase : CREATE DATABASE databaseName ;
useDatabase : USE DATABASE databaseName ;
dropDatabase : DROP DATABASE databaseName ;
showDatabases: SHOW DATABASES ;

View File

@ -49,6 +49,7 @@ CSV : C S V ;
DATA : D A T A ;
DELIMITER : D E L I M I T E R ;
DATABASE : D A T A B A S E ;
DATABASES : D A T A B A S E S ;
DENY : D E N Y ;
DIRECTORY : D I R E C T O R Y ;
DROP : D R O P ;
@ -80,6 +81,8 @@ MAIN : M A I N ;
MODE : M O D E ;
MODULE_READ : M O D U L E UNDERSCORE R E A D ;
MODULE_WRITE : M O D U L E UNDERSCORE W R I T E ;
MULTI_DATABASE_EDIT : M U L T I UNDERSCORE D A T A B A S E UNDERSCORE E D I T ;
MULTI_DATABASE_USE : M U L T I UNDERSCORE D A T A B A S E UNDERSCORE U S E ;
NEXT : N E X T ;
NO : N O ;
NOTHING : N O T H I N G ;
@ -127,6 +130,7 @@ TRIGGERS : T R I G G E R S ;
UNCOMMITTED : U N C O M M I T T E D ;
UNLOCK : U N L O C K ;
UPDATE : U P D A T E ;
USE : U S E ;
USER : U S E R ;
USERS : U S E R S ;
VERSION : V E R S I O N ;

View File

@ -89,6 +89,22 @@ class PrivilegeExtractor : public QueryVisitor<void>, public HierarchicalTreeVis
void Visit(VersionQuery & /*version_query*/) override { AddPrivilege(AuthQuery::Privilege::STATS); }
void Visit(MultiDatabaseQuery &query) override {
switch (query.action_) {
case MultiDatabaseQuery::Action::CREATE:
case MultiDatabaseQuery::Action::DROP:
AddPrivilege(AuthQuery::Privilege::MULTI_DATABASE_EDIT);
break;
case MultiDatabaseQuery::Action::USE:
AddPrivilege(AuthQuery::Privilege::MULTI_DATABASE_USE);
break;
}
}
void Visit(ShowDatabasesQuery & /*unused*/) override {
AddPrivilege(AuthQuery::Privilege::MULTI_DATABASE_USE); /* OR EDIT */
}
bool PreVisit(Create & /*unused*/) override {
AddPrivilege(AuthQuery::Privilege::CREATE);
return false;

View File

@ -139,6 +139,7 @@ const trie::Trie kKeywords = {"union",
"false",
"reduce",
"coalesce",
"use",
"user",
"password",
"alter",
@ -159,6 +160,7 @@ const trie::Trie kKeywords = {"union",
"key",
"dump",
"database",
"databases",
"call",
"yield",
"memory",

View File

@ -24,13 +24,17 @@
#include <limits>
#include <memory>
#include <optional>
#include <stdexcept>
#include <thread>
#include <unordered_map>
#include <utility>
#include <variant>
#include "auth/auth.hpp"
#include "auth/models.hpp"
#include "csv/parsing.hpp"
#include "dbms/global.hpp"
#include "dbms/session_context_handler.hpp"
#include "glue/communication.hpp"
#include "license/license.hpp"
#include "memory/memory_control.hpp"
@ -105,6 +109,26 @@ template <typename>
constexpr auto kAlwaysFalse = false;
namespace {
template <typename T, typename K>
void Sort(std::vector<T, K> &vec) {
std::sort(vec.begin(), vec.end());
}
template <typename K>
void Sort(std::vector<TypedValue, K> &vec) {
std::sort(vec.begin(), vec.end(),
[](const TypedValue &lv, const TypedValue &rv) { return lv.ValueString() < rv.ValueString(); });
}
// NOLINTNEXTLINE (misc-unused-parameters)
bool Same(const TypedValue &lv, const TypedValue &rv) {
return TypedValue(lv).ValueString() == TypedValue(rv).ValueString();
}
bool Same(const TypedValue &lv, const std::string &rv) { return std::string(TypedValue(lv).ValueString()) == rv; }
// NOLINTNEXTLINE (misc-unused-parameters)
bool Same(const std::string &lv, const TypedValue &rv) { return lv == std::string(TypedValue(rv).ValueString()); }
bool Same(const std::string &lv, const std::string &rv) { return lv == rv; }
void UpdateTypeCount(const plan::ReadWriteTypeChecker::RWType type) {
switch (type) {
case plan::ReadWriteTypeChecker::RWType::R:
@ -333,8 +357,12 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler {
/// returns false if the replication role can't be set
/// @throw QueryRuntimeException if an error ocurred.
Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Parameters &parameters,
Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_context, const Parameters &parameters,
DbAccessor *db_accessor) {
AuthQueryHandler *auth = interpreter_context->auth;
#ifdef MG_ENTERPRISE
auto &sc_handler = memgraph::dbms::SessionContextHandler::ExtractSCH(interpreter_context);
#endif
// Empty frame for evaluation of password expression. This is OK since
// password should be either null or string literal and it's evaluation
// should not depend on frame.
@ -351,6 +379,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa
std::string username = auth_query->user_;
std::string rolename = auth_query->role_;
std::string user_or_role = auth_query->user_or_role_;
std::string database = auth_query->database_;
std::vector<AuthQuery::Privilege> privileges = auth_query->privileges_;
#ifdef MG_ENTERPRISE
std::vector<std::unordered_map<AuthQuery::FineGrainedPrivilege, std::vector<std::string>>> label_privileges =
@ -364,11 +393,20 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa
const auto license_check_result = license::global_license_checker.IsEnterpriseValid(utils::global_settings);
static const std::unordered_set enterprise_only_methods{
AuthQuery::Action::CREATE_ROLE, AuthQuery::Action::DROP_ROLE, AuthQuery::Action::SET_ROLE,
AuthQuery::Action::CLEAR_ROLE, AuthQuery::Action::GRANT_PRIVILEGE, AuthQuery::Action::DENY_PRIVILEGE,
AuthQuery::Action::REVOKE_PRIVILEGE, AuthQuery::Action::SHOW_PRIVILEGES, AuthQuery::Action::SHOW_USERS_FOR_ROLE,
AuthQuery::Action::SHOW_ROLE_FOR_USER};
static const std::unordered_set enterprise_only_methods{AuthQuery::Action::CREATE_ROLE,
AuthQuery::Action::DROP_ROLE,
AuthQuery::Action::SET_ROLE,
AuthQuery::Action::CLEAR_ROLE,
AuthQuery::Action::GRANT_PRIVILEGE,
AuthQuery::Action::DENY_PRIVILEGE,
AuthQuery::Action::REVOKE_PRIVILEGE,
AuthQuery::Action::SHOW_PRIVILEGES,
AuthQuery::Action::SHOW_USERS_FOR_ROLE,
AuthQuery::Action::SHOW_ROLE_FOR_USER,
AuthQuery::Action::GRANT_DATABASE_TO_USER,
AuthQuery::Action::REVOKE_DATABASE_FROM_USER,
AuthQuery::Action::SHOW_DATABASE_PRIVILEGES,
AuthQuery::Action::SET_MAIN_DATABASE};
if (license_check_result.HasError() && enterprise_only_methods.contains(auth_query->action_)) {
throw utils::BasicException(
@ -536,6 +574,73 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa
return rows;
};
return callback;
case AuthQuery::Action::GRANT_DATABASE_TO_USER:
#ifdef MG_ENTERPRISE
callback.fn = [auth, database, username, &sc_handler] { // NOLINT
try {
memgraph::dbms::SessionContext sc(nullptr, "", nullptr, nullptr);
if (database != memgraph::auth::kAllDatabases) {
sc = sc_handler.Get(database); // Will throw if databases doesn't exist and protect it during pull
}
if (!auth->GrantDatabaseToUser(database, username)) {
throw QueryRuntimeException("Failed to grant database {} to user {}.", database, username);
}
} catch (memgraph::dbms::UnknownDatabaseException &e) {
throw QueryRuntimeException(e.what());
}
#else
callback.fn = [] {
#endif
return std::vector<std::vector<TypedValue>>();
};
return callback;
case AuthQuery::Action::REVOKE_DATABASE_FROM_USER:
#ifdef MG_ENTERPRISE
callback.fn = [auth, database, username, &sc_handler] { // NOLINT
try {
memgraph::dbms::SessionContext sc(nullptr, "", nullptr, nullptr);
if (database != memgraph::auth::kAllDatabases) {
sc = sc_handler.Get(database); // Will throw if databases doesn't exist and protect it during pull
}
if (!auth->RevokeDatabaseFromUser(database, username)) {
throw QueryRuntimeException("Failed to revoke database {} from user {}.", database, username);
}
} catch (memgraph::dbms::UnknownDatabaseException &e) {
throw QueryRuntimeException(e.what());
}
#else
callback.fn = [] {
#endif
return std::vector<std::vector<TypedValue>>();
};
return callback;
case AuthQuery::Action::SHOW_DATABASE_PRIVILEGES:
callback.header = {"grants", "denies"};
callback.fn = [auth, username] { // NOLINT
#ifdef MG_ENTERPRISE
return auth->GetDatabasePrivileges(username);
#else
return std::vector<std::vector<TypedValue>>();
#endif
};
return callback;
case AuthQuery::Action::SET_MAIN_DATABASE:
#ifdef MG_ENTERPRISE
callback.fn = [auth, database, username, &sc_handler] { // NOLINT
try {
const auto sc = sc_handler.Get(database); // Will throw if databases doesn't exist and protect it during pull
if (!auth->SetMainDatabase(database, username)) {
throw QueryRuntimeException("Failed to set main database {} for user {}.", database, username);
}
} catch (memgraph::dbms::UnknownDatabaseException &e) {
throw QueryRuntimeException(e.what());
}
#else
callback.fn = [] {
#endif
return std::vector<std::vector<TypedValue>>();
};
return callback;
default:
break;
}
@ -1249,11 +1354,23 @@ bool IsWriteQueryOnMainMemoryReplica(storage::Storage *storage,
return false;
}
storage::replication::ReplicationRole GetReplicaRole(storage::Storage *storage) {
if (auto storage_mode = storage->GetStorageMode(); storage_mode == storage::StorageMode::IN_MEMORY_ANALYTICAL ||
storage_mode == storage::StorageMode::IN_MEMORY_TRANSACTIONAL) {
auto *mem_storage = static_cast<storage::InMemoryStorage *>(storage);
return mem_storage->GetReplicationRole();
}
return storage::replication::ReplicationRole::MAIN;
}
} // namespace
InterpreterContext::InterpreterContext(const storage::Config storage_config, const InterpreterConfig interpreter_config,
const std::filesystem::path &data_directory)
: trigger_store(data_directory / "triggers"),
const std::filesystem::path &data_directory, query::AuthQueryHandler *ah,
query::AuthChecker *ac)
: auth(ah),
auth_checker(ac),
trigger_store(data_directory / "triggers"),
config(interpreter_config),
streams{this, data_directory / "streams"} {
if (utils::DirExists(storage_config.disk.main_storage_directory)) {
@ -1264,8 +1381,11 @@ InterpreterContext::InterpreterContext(const storage::Config storage_config, con
}
InterpreterContext::InterpreterContext(std::unique_ptr<storage::Storage> db, InterpreterConfig interpreter_config,
const std::filesystem::path &data_directory)
const std::filesystem::path &data_directory, query::AuthQueryHandler *ah,
query::AuthChecker *ac)
: db(std::move(db)),
auth(ah),
auth_checker(ac),
trigger_store(data_directory / "triggers"),
config(interpreter_config),
streams{this, data_directory / "streams"} {}
@ -2027,7 +2147,7 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa
auto *auth_query = utils::Downcast<AuthQuery>(parsed_query.query);
auto callback = HandleAuthQuery(auth_query, interpreter_context->auth, parsed_query.parameters, dba);
auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters, dba);
SymbolTable symbol_table;
std::vector<Symbol> output_symbols;
@ -2668,7 +2788,7 @@ Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query,
ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD);
bool hasTransactionManagementPrivilege = interpreter_context->auth_checker->IsUserAuthorized(
username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT});
username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT}, "");
Callback callback;
switch (transaction_query->action_) {
@ -2770,6 +2890,7 @@ PreparedQuery PrepareInfoQuery(ParsedQuery parsed_query, bool in_explicit_transa
handler = [db, interpreter_isolation_level, next_transaction_isolation_level] {
auto info = db->GetInfo();
std::vector<std::vector<TypedValue>> results{
{TypedValue("name"), TypedValue(db->id())},
{TypedValue("vertex_count"), TypedValue(static_cast<int64_t>(info.vertex_count))},
{TypedValue("edge_count"), TypedValue(static_cast<int64_t>(info.edge_count))},
{TypedValue("average_degree"), TypedValue(info.average_degree)},
@ -3118,6 +3239,250 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_
RWType::NONE};
}
PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, bool in_explicit_transaction, bool in_explicit_db,
InterpreterContext *interpreter_context, const std::string &session_uuid) {
#ifdef MG_ENTERPRISE
if (!license::global_license_checker.IsEnterpriseValidFast()) {
throw QueryException("Trying to use enterprise feature without a valid license.");
}
// TODO: Remove once replicas support multi-tenant replication
if (GetReplicaRole(interpreter_context->db.get()) == storage::replication::ReplicationRole::REPLICA) {
throw QueryException("Query forbidden on the replica!");
}
if (in_explicit_transaction) {
throw MultiDatabaseQueryInMulticommandTxException();
}
auto *query = utils::Downcast<MultiDatabaseQuery>(parsed_query.query);
auto &sc_handler = memgraph::dbms::SessionContextHandler::ExtractSCH(interpreter_context);
switch (query->action_) {
case MultiDatabaseQuery::Action::CREATE:
return PreparedQuery{
{"STATUS"},
std::move(parsed_query.required_privileges),
[db_name = query->db_name_, session_uuid, &sc_handler](
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
std::vector<std::vector<TypedValue>> status;
std::string res;
const auto success = sc_handler.New(db_name);
if (success.HasError()) {
switch (success.GetError()) {
case dbms::NewError::EXISTS:
res = db_name + " already exists.";
break;
case dbms::NewError::DEFUNCT:
throw QueryRuntimeException(
"{} is defunct and in an unknown state. Try to delete it again or clean up storage and restart "
"Memgraph.",
db_name);
case dbms::NewError::GENERIC:
throw QueryRuntimeException("Failed while creating {}", db_name);
case dbms::NewError::NO_CONFIGS:
throw QueryRuntimeException("No configuration found while trying to create {}", db_name);
}
} else {
res = "Successfully created database " + db_name;
}
status.emplace_back(std::vector<TypedValue>{TypedValue(res)});
auto pull_plan = std::make_shared<PullPlanVector>(std::move(status));
if (pull_plan->Pull(stream, n)) {
return QueryHandlerResult::COMMIT;
}
return std::nullopt;
},
RWType::W,
"" // No target DB possible
};
case MultiDatabaseQuery::Action::USE:
if (in_explicit_db) {
throw QueryException("Database switching is prohibited if session explicitly defines the used database");
}
return PreparedQuery{{"STATUS"},
std::move(parsed_query.required_privileges),
[db_name = query->db_name_, session_uuid, &sc_handler](
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
std::vector<std::vector<TypedValue>> status;
std::string res;
memgraph::dbms::SetForResult set = memgraph::dbms::SetForResult::SUCCESS;
try {
set = sc_handler.SetFor(session_uuid, db_name);
} catch (const utils::BasicException &e) {
throw QueryRuntimeException(e.what());
}
switch (set) {
case dbms::SetForResult::SUCCESS:
res = "Using " + db_name;
break;
case dbms::SetForResult::ALREADY_SET:
res = "Already using " + db_name;
break;
case dbms::SetForResult::FAIL:
throw QueryRuntimeException("Failed to start using {}", db_name);
}
status.emplace_back(std::vector<TypedValue>{TypedValue(res)});
auto pull_plan = std::make_shared<PullPlanVector>(std::move(status));
if (pull_plan->Pull(stream, n)) {
return QueryHandlerResult::COMMIT;
}
return std::nullopt;
},
RWType::NONE,
query->db_name_};
case MultiDatabaseQuery::Action::DROP:
return PreparedQuery{
{"STATUS"},
std::move(parsed_query.required_privileges),
[db_name = query->db_name_, session_uuid, &sc_handler](
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
std::vector<std::vector<TypedValue>> status;
memgraph::dbms::DeleteResult success{};
try {
success = sc_handler.Delete(db_name);
} catch (const utils::BasicException &e) {
throw QueryRuntimeException(e.what());
}
if (success.HasError()) {
switch (success.GetError()) {
case dbms::DeleteError::DEFAULT_DB:
throw QueryRuntimeException("Cannot delete the default database.");
case dbms::DeleteError::NON_EXISTENT:
throw QueryRuntimeException("{} does not exist.", db_name);
case dbms::DeleteError::USING:
throw QueryRuntimeException("Cannot delete {}, it is currently being used.", db_name);
case dbms::DeleteError::FAIL:
throw QueryRuntimeException("Failed while deleting {}", db_name);
case dbms::DeleteError::DISK_FAIL:
throw QueryRuntimeException("Failed to clean storage of {}", db_name);
}
}
status.emplace_back(std::vector<TypedValue>{TypedValue("Successfully deleted " + db_name)});
auto pull_plan = std::make_shared<PullPlanVector>(std::move(status));
if (pull_plan->Pull(stream, n)) {
return QueryHandlerResult::COMMIT;
}
return std::nullopt;
},
RWType::W,
query->db_name_};
}
#else
throw QueryException("Query not supported.");
#endif
}
PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterContext *interpreter_context,
const std::string &session_uuid, std::map<std::string, TypedValue> *summary,
DbAccessor *dba, utils::MemoryResource *execution_memory,
const std::optional<std::string> &username,
std::atomic<TransactionStatus> *transaction_status,
std::shared_ptr<utils::AsyncTimer> tx_timer) {
#ifdef MG_ENTERPRISE
if (!license::global_license_checker.IsEnterpriseValidFast()) {
throw QueryException("Trying to use enterprise feature without a valid license.");
}
// TODO: Remove once replicas support multi-tenant replication
if (GetReplicaRole(interpreter_context->db.get()) == storage::replication::ReplicationRole::REPLICA) {
throw QueryException("SHOW DATABASES forbidden on the replica!");
}
auto &sc_handler = memgraph::dbms::SessionContextHandler::ExtractSCH(interpreter_context);
AuthQueryHandler *auth = interpreter_context->auth;
Callback callback;
callback.header = {"Name", "Current"};
callback.fn = [auth, session_uuid, &sc_handler, username]() mutable -> std::vector<std::vector<TypedValue>> {
std::vector<std::vector<TypedValue>> status;
const auto in_use = sc_handler.Current(session_uuid);
bool found_current = false;
auto gen_status = [&]<typename T, typename K>(T all, K denied) {
Sort(all);
Sort(denied);
status.reserve(all.size());
for (const auto &name : all) {
TypedValue use("");
if (!found_current && Same(name, in_use)) {
use = TypedValue("*");
found_current = true;
}
status.push_back({TypedValue(name), std::move(use)});
}
// No denied databases (no need to filter them out)
if (denied.empty()) return;
auto denied_itr = denied.begin();
auto iter = std::remove_if(status.begin(), status.end(), [&denied_itr, &denied](auto &in) -> bool {
while (denied_itr != denied.end() && denied_itr->ValueString() < in[0].ValueString()) ++denied_itr;
return (denied_itr != denied.end() && denied_itr->ValueString() == in[0].ValueString());
});
status.erase(iter, status.end());
};
if (!username) {
// No user, return all
gen_status(sc_handler.All(), std::vector<TypedValue>{});
} else {
// User has a subset of accessible dbs; this is synched with the SessionContextHandler
const auto &db_priv = auth->GetDatabasePrivileges(*username);
const auto &allowed = db_priv[0][0];
const auto &denied = db_priv[0][1].ValueList();
if (allowed.IsString() && allowed.ValueString() == auth::kAllDatabases) {
// All databases are allowed
gen_status(sc_handler.All(), denied);
} else {
gen_status(allowed.ValueList(), denied);
}
}
if (!found_current) throw QueryRuntimeException("Missing current database!");
return status;
};
SymbolTable symbol_table;
std::vector<Symbol> output_symbols;
for (const auto &column : callback.header) {
output_symbols.emplace_back(symbol_table.CreateSymbol(column, "false"));
}
auto plan = std::make_shared<CachedPlan>(std::make_unique<SingleNodeLogicalPlan>(
std::make_unique<plan::OutputTable>(output_symbols,
[fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }),
0.0, AstStorage{}, symbol_table));
auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
execution_memory, username, transaction_status, std::move(tx_timer));
return PreparedQuery{
callback.header, std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols),
summary](AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
if (pull_plan->Pull(stream, n, output_symbols, summary)) {
return callback.should_abort_query ? QueryHandlerResult::ABORT : QueryHandlerResult::COMMIT;
}
return std::nullopt;
},
RWType::NONE,
"" // No target DB
};
#else
throw QueryException("Query not supported.");
#endif
}
std::optional<uint64_t> Interpreter::GetTransactionId() const {
if (db_accessor_) {
return db_accessor_->GetTransactionId();
@ -3146,7 +3511,8 @@ void Interpreter::RollbackTransaction() {
Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
const std::map<std::string, storage::PropertyValue> &params,
const std::string *username, QueryExtras const &extras) {
const std::string *username, QueryExtras const &extras,
const std::string &session_uuid) {
std::shared_ptr<utils::AsyncTimer> current_timer;
if (!in_explicit_transaction_) {
query_executions_.clear();
@ -3177,7 +3543,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
in_explicit_transaction_ ? static_cast<int>(query_executions_.size() - 1) : std::optional<int>{};
query_execution->prepared_query.emplace(PrepareTransactionQuery(trimmed_query, extras));
return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid};
return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid, {}};
}
// Don't save BEGIN, COMMIT or ROLLBACK
@ -3329,6 +3695,14 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
} else if (utils::Downcast<TransactionQueueQuery>(parsed_query.query)) {
prepared_query = PrepareTransactionQueueQuery(std::move(parsed_query), username_, in_explicit_transaction_,
interpreter_context_, &*execution_db_accessor_);
} else if (utils::Downcast<MultiDatabaseQuery>(parsed_query.query)) {
prepared_query = PrepareMultiDatabaseQuery(std::move(parsed_query), in_explicit_transaction_, in_explicit_db_,
interpreter_context_, session_uuid);
} else if (utils::Downcast<ShowDatabasesQuery>(parsed_query.query)) {
prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, session_uuid,
&query_execution->summary, &*execution_db_accessor_,
&query_execution->execution_memory_with_exception, username_,
&transaction_status_, std::move(current_timer));
} else {
LOG_FATAL("Should not get here -- unknown query type!");
}
@ -3346,7 +3720,14 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
throw QueryException("Write query forbidden on the replica!");
}
return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid};
// Set the target db to the current db (some queries have different target from the current db)
if (!query_execution->prepared_query->db) {
query_execution->prepared_query->db = interpreter_context_->db->id();
}
query_execution->summary["db"] = *query_execution->prepared_query->db;
return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid,
query_execution->prepared_query->db};
} catch (const utils::BasicException &) {
memgraph::metrics::IncrementCounter(memgraph::metrics::FailedQuery);
AbortCommand(query_execution_ptr);

View File

@ -77,6 +77,24 @@ class AuthQueryHandler {
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetPassword(const std::string &username, const std::optional<std::string> &password) = 0;
#ifdef MG_ENTERPRISE
/// Return true if access revoked successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) = 0;
/// Return true if access granted successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username) = 0;
/// Returns database access rights for the user
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<std::vector<memgraph::query::TypedValue>> GetDatabasePrivileges(const std::string &username) = 0;
/// Return true if main database set successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool SetMainDatabase(const std::string &db, const std::string &username) = 0;
#endif
/// Return false if the role already exists.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool CreateRole(const std::string &rolename) = 0;
@ -202,6 +220,7 @@ struct PreparedQuery {
std::vector<AuthQuery::Privilege> privileges;
std::function<std::optional<QueryHandlerResult>(AnyStream *stream, std::optional<int> n)> query_handler;
plan::ReadWriteTypeChecker::RWType rw_type;
std::optional<std::string> db{};
};
/**
@ -223,10 +242,12 @@ class Interpreter;
/// TODO: andi decouple in a separate file why here?
struct InterpreterContext {
explicit InterpreterContext(storage::Config storage_config, InterpreterConfig interpreter_config,
const std::filesystem::path &data_directory);
const std::filesystem::path &data_directory, query::AuthQueryHandler *ah = nullptr,
query::AuthChecker *ac = nullptr);
InterpreterContext(std::unique_ptr<storage::Storage> db, InterpreterConfig interpreter_config,
const std::filesystem::path &data_directory);
const std::filesystem::path &data_directory, query::AuthQueryHandler *ah = nullptr,
query::AuthChecker *ac = nullptr);
std::unique_ptr<storage::Storage> db;
@ -240,8 +261,8 @@ struct InterpreterContext {
std::optional<double> tsc_frequency{utils::GetTSCFrequency()};
std::atomic<bool> is_shutting_down{false};
AuthQueryHandler *auth{nullptr};
AuthChecker *auth_checker{nullptr};
AuthQueryHandler *auth;
AuthChecker *auth_checker;
utils::SkipList<QueryCacheEntry> ast_cache;
utils::SkipList<PlanCacheEntry> plan_cache;
@ -272,10 +293,12 @@ class Interpreter final {
std::vector<std::string> headers;
std::vector<query::AuthQuery::Privilege> privileges;
std::optional<int> qid;
std::optional<std::string> db;
};
std::optional<std::string> username_;
bool in_explicit_transaction_{false};
bool in_explicit_db_{false};
bool expect_rollback_{false};
std::shared_ptr<utils::AsyncTimer> explicit_transaction_timer_{};
std::optional<std::map<std::string, storage::PropertyValue>> metadata_{}; //!< User defined transaction metadata
@ -289,7 +312,8 @@ class Interpreter final {
* @throw query::QueryException
*/
PrepareResult Prepare(const std::string &query, const std::map<std::string, storage::PropertyValue> &params,
const std::string *username, QueryExtras const &extras = {});
const std::string *username, QueryExtras const &extras = {},
const std::string &session_uuid = {});
/**
* Execute the last prepared query and stream *all* of the results into the

View File

@ -520,7 +520,7 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
spdlog::trace("Executing query '{}' in stream '{}'", query, stream_name);
auto prepare_result =
interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap(), nullptr);
if (!interpreter_context->auth_checker->IsUserAuthorized(owner, prepare_result.privileges)) {
if (!interpreter_context->auth_checker->IsUserAuthorized(owner, prepare_result.privileges, "")) {
throw StreamsException{
"Couldn't execute query '{}' for stream '{}' because the owner is not authorized to execute the "
"query!",

View File

@ -187,7 +187,7 @@ std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor,
trigger_plan_ = std::make_shared<TriggerPlan>(std::move(logical_plan), std::move(identifiers));
}
if (!auth_checker->IsUserAuthorized(owner_, parsed_statements_.required_privileges)) {
if (!auth_checker->IsUserAuthorized(owner_, parsed_statements_.required_privileges, "")) {
throw utils::BasicException("The owner of trigger '{}' is not authorized to execute the query!", name_);
}
return trigger_plan_;

View File

@ -16,9 +16,15 @@
#include <filesystem>
#include "storage/v2/isolation_level.hpp"
#include "storage/v2/transaction.hpp"
#include "utils/exceptions.hpp"
namespace memgraph::storage {
/// Exception used to signal configuration errors.
class StorageConfigException : public utils::BasicException {
using utils::BasicException::BasicException;
};
/// Pass this class to the \ref Storage constructor to change the behavior of
/// the storage. This class also defines the default behavior.
struct Config {
@ -62,15 +68,49 @@ struct Config {
} transaction;
struct DiskConfig {
std::filesystem::path main_storage_directory{"rocksdb_main_storage"};
std::filesystem::path label_index_directory{"rocksdb_label_index"};
std::filesystem::path label_property_index_directory{"rocksdb_label_property_index"};
std::filesystem::path unique_constraints_directory{"rocksdb_unique_constraints"};
std::filesystem::path name_id_mapper_directory{"rocksdb_name_id_mapper"};
std::filesystem::path id_name_mapper_directory{"rocksdb_id_name_mapper"};
std::filesystem::path durability_directory{"rocksdb_durability"};
std::filesystem::path wal_directory{"rocksdb_wal"};
std::filesystem::path main_storage_directory{"storage/rocksdb_main_storage"};
std::filesystem::path label_index_directory{"storage/rocksdb_label_index"};
std::filesystem::path label_property_index_directory{"storage/rocksdb_label_property_index"};
std::filesystem::path unique_constraints_directory{"storage/rocksdb_unique_constraints"};
std::filesystem::path name_id_mapper_directory{"storage/rocksdb_name_id_mapper"};
std::filesystem::path id_name_mapper_directory{"storage/rocksdb_id_name_mapper"};
std::filesystem::path durability_directory{"storage/rocksdb_durability"};
std::filesystem::path wal_directory{"storage/rocksdb_wal"};
} disk;
std::string name;
};
static inline void UpdatePaths(Config &config, const std::filesystem::path &storage_dir) {
auto contained = [](const auto &path, const auto &base) -> std::optional<std::filesystem::path> {
auto rel = std::filesystem::relative(path, base);
if (!rel.empty() && rel.native()[0] != '.') { // Contained
return rel;
}
return {};
};
const auto old_base =
std::filesystem::weakly_canonical(std::filesystem::absolute(config.durability.storage_directory));
config.durability.storage_directory = std::filesystem::weakly_canonical(std::filesystem::absolute(storage_dir));
auto UPDATE_PATH = [&](auto to_update) {
const auto old_path = std::filesystem::weakly_canonical(std::filesystem::absolute(to_update(config.disk)));
const auto contained_path = contained(old_path, old_base);
if (!contained_path) {
throw StorageConfigException("On-disk directories not contained in root.");
}
to_update(config.disk) = config.durability.storage_directory / *contained_path;
};
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::main_storage_directory));
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::label_index_directory));
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::label_property_index_directory));
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::unique_constraints_directory));
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::name_id_mapper_directory));
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::id_name_mapper_directory));
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::durability_directory));
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::wal_directory));
}
} // namespace memgraph::storage

View File

@ -50,7 +50,8 @@ Storage::Storage(Config config, StorageMode storage_mode)
isolation_level_(config.transaction.isolation_level),
storage_mode_(storage_mode),
indices_(&constraints_, config, storage_mode),
constraints_(config, storage_mode) {}
constraints_(config, storage_mode),
id_(config.name) {}
Storage::Accessor::Accessor(Storage *storage, IsolationLevel isolation_level, StorageMode storage_mode)
: storage_(storage),

View File

@ -74,6 +74,8 @@ class Storage {
virtual ~Storage() {}
const std::string &id() const { return id_; }
class Accessor {
public:
Accessor(Storage *storage, IsolationLevel isolation_level, StorageMode storage_mode);
@ -179,6 +181,8 @@ class Storage {
StorageMode GetCreationStorageMode() const;
const std::string &id() const { return storage_->id(); }
protected:
Storage *storage_;
std::shared_lock<utils::RWLock> storage_guard_;
@ -301,6 +305,7 @@ class Storage {
std::atomic<uint64_t> vertex_id_{0};
std::atomic<uint64_t> edge_id_{0};
const std::string id_; //!< High-level assigned ID
};
} // namespace memgraph::storage

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -27,6 +27,7 @@ inline uint64_t GetDirDiskUsage(const std::filesystem::path &path) {
uint64_t size = 0;
for (auto &p : std::filesystem::directory_iterator(path)) {
if (std::filesystem::is_symlink(p)) continue;
if (std::filesystem::is_directory(p)) {
size += GetDirDiskUsage(p);
} else if (std::filesystem::is_regular_file(p)) {

189
src/utils/sync_ptr.hpp Normal file
View File

@ -0,0 +1,189 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <memory>
#include "utils/exceptions.hpp"
namespace memgraph::utils {
/**
* @brief
*
* @tparam TContext
* @tparam TConfig
*/
template <typename TContext, typename TConfig = void>
struct SyncPtr {
/**
* @brief Construct a new synched pointer.
*
* @tparam TArgs variable templates used by the TContext constructor
* @param config Additional metadata associated with context
* @param args Arguments to pass to TContext constructor
*/
template <typename... TArgs>
explicit SyncPtr(TConfig config, TArgs &&...args)
: timeout_{1000}, config_{config}, ptr_{new TContext(std::forward<TArgs>(args)...), [this](TContext *ptr) {
this->OnDelete(ptr);
}} {}
~SyncPtr() = default;
SyncPtr(const SyncPtr &) = delete;
SyncPtr &operator=(const SyncPtr &) = delete;
SyncPtr(SyncPtr &&) noexcept = delete;
SyncPtr &operator=(SyncPtr &&) noexcept = delete;
/**
* @brief Destroy the synched pointer and wait for all copies to get destroyed.
*
*/
void DestroyAndSync() {
ptr_.reset();
SyncOnDelete();
}
/**
* @brief Get (copy) the underlying shared pointer.
*
* @return std::shared_ptr<TContext>
*/
std::shared_ptr<TContext> get() { return ptr_; }
std::shared_ptr<const TContext> get() const { return ptr_; }
/**
* @brief Return saved configuration (metadata)
*
* @return TConfig
*/
TConfig config() { return config_; }
const TConfig &config() const { return config_; }
void timeout(const std::chrono::milliseconds to) { timeout_ = to; }
std::chrono::milliseconds timeout() const { return timeout_; }
private:
/**
* @brief Block until OnDelete gets called.
*
*/
void SyncOnDelete() {
std::unique_lock<std::mutex> lock(in_use_mtx_);
if (!in_use_cv_.wait_for(lock, timeout_, [this] { return !in_use_; })) {
throw utils::BasicException("Syncronization timeout!");
}
}
/**
* @brief Custom destructor used to sync the shared_ptr release.
*
* @param p Pointer to the undelying object.
*/
void OnDelete(TContext *p) {
delete p;
{
std::lock_guard<std::mutex> lock(in_use_mtx_);
in_use_ = false;
}
in_use_cv_.notify_all();
}
bool in_use_{true}; //!< Flag used to signal sync
mutable std::mutex in_use_mtx_; //!< Mutex used in the cv sync
mutable std::condition_variable in_use_cv_; //!< cv used to signal a sync
std::chrono::milliseconds timeout_; //!< Synchronization timeout in ms
TConfig config_; //!< Additional metadata associated with the context
std::shared_ptr<TContext> ptr_; //!< Pointer being synced
};
template <typename TContext>
class SyncPtr<TContext, void> {
public:
/**
* @brief Construct a new synched pointer.
*
* @tparam TArgs variable templates used by the TContext constructor
* @param config Additional metadata associated with context
* @param args Arguments to pass to TContext constructor
*/
template <typename... TArgs>
explicit SyncPtr(TArgs &&...args)
: timeout_{1000}, ptr_{new TContext(std::forward<TArgs>(args)...), [this](TContext *ptr) {
this->OnDelete(ptr);
}} {}
~SyncPtr() = default;
SyncPtr(const SyncPtr &) = delete;
SyncPtr &operator=(const SyncPtr &) = delete;
SyncPtr(SyncPtr &&) noexcept = delete;
SyncPtr &operator=(SyncPtr &&) noexcept = delete;
/**
* @brief Destroy the synched pointer and wait for all copies to get destroyed.
*
*/
void DestroyAndSync() {
ptr_.reset();
SyncOnDelete();
}
/**
* @brief Get (copy) the underlying shared pointer.
*
* @return std::shared_ptr<TContext>
*/
std::shared_ptr<TContext> get() { return ptr_; }
std::shared_ptr<const TContext> get() const { return ptr_; }
void timeout(const std::chrono::milliseconds to) { timeout_ = to; }
std::chrono::milliseconds timeout() const { return timeout_; }
private:
/**
* @brief Block until OnDelete gets called.
*
*/
void SyncOnDelete() {
std::unique_lock<std::mutex> lock(in_use_mtx_);
if (!in_use_cv_.wait_for(lock, timeout_, [this] { return !in_use_; })) {
throw utils::BasicException("Syncronization timeout!");
}
}
/**
* @brief Custom destructor used to sync the shared_ptr release.
*
* @param p Pointer to the undelying object.
*/
void OnDelete(TContext *p) {
delete p;
{
std::lock_guard<std::mutex> lock(in_use_mtx_);
in_use_ = false;
}
in_use_cv_.notify_all();
}
bool in_use_{true}; //!< Flag used to signal sync
mutable std::mutex in_use_mtx_; //!< Mutex used in the cv sync
mutable std::condition_variable in_use_cv_; //!< cv used to signal a sync
std::chrono::milliseconds timeout_; //!< Synchronization timeout in ms
std::shared_ptr<TContext> ptr_; //!< Pointer being synced
};
} // namespace memgraph::utils

View File

@ -184,6 +184,8 @@ enum class TypeId : uint64_t {
AST_TRANSACTION_QUEUE_QUERY,
AST_EXISTS,
AST_CALL_SUBQUERY,
AST_MULTI_DATABASE_QUERY,
AST_SHOW_DATABASES,
// Symbol
SYMBOL,
};

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -30,10 +30,10 @@ TEST(Network, Server) {
std::cout << endpoint << std::endl;
// initialize server
TestData session_data;
TestData session_context;
int N = (std::thread::hardware_concurrency() + 1) / 2;
ContextT context;
ServerT server(endpoint, &session_data, &context, -1, "Test", N);
ServerT server(endpoint, &session_context, &context, -1, "Test", N);
ASSERT_TRUE(server.Start());
const auto &ep = server.endpoint();

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -32,9 +32,9 @@ TEST(Network, SessionLeak) {
Endpoint endpoint(interface, 0);
// initialize server
TestData session_data;
TestData session_context;
ContextT context;
ServerT server(endpoint, &session_data, &context, -1, "Test", 2);
ServerT server(endpoint, &session_context, &context, -1, "Test", 2);
ASSERT_TRUE(server.Start());
// start clients

View File

@ -25,9 +25,14 @@ def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}
def connect(**kwargs) -> mgclient.Connection:
connection = mgclient.connect(host="localhost", port=7687, **kwargs)
connection.autocommit = True
yield connection
cursor = connection.cursor()
execute_and_fetch_all(cursor, "USE DATABASE memgraph")
try:
execute_and_fetch_all(cursor, "DROP DATABASE clean")
except:
pass
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
yield connection
@pytest.fixture

View File

@ -21,6 +21,18 @@ QUERY_PLAN = "QUERY PLAN"
# ------------------------------------
@pytest.fixture(scope="function")
def multi_db(request, connect):
cursor = connect.cursor()
if request.param:
execute_and_fetch_all(cursor, "CREATE DATABASE clean")
execute_and_fetch_all(cursor, "USE DATABASE clean")
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
pass
yield connect
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
@pytest.mark.parametrize(
"delete_query",
[
@ -30,9 +42,9 @@ QUERY_PLAN = "QUERY PLAN"
"ANALYZE GRAPH ON LABELS :Label, :NONEXISTING DELETE STATISTICS",
],
)
def test_analyze_graph_delete_statistics(delete_query, connect):
def test_analyze_graph_delete_statistics(delete_query, multi_db):
"""Tests that all variants of delete queries work as expected."""
cursor = connect.cursor()
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id1: i}));")
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 50) | CREATE (n:Label {id2: i % 5}));")
execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);")
@ -62,6 +74,7 @@ def test_analyze_graph_delete_statistics(delete_query, connect):
execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);")
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
@pytest.mark.parametrize(
"analyze_query",
[
@ -71,11 +84,11 @@ def test_analyze_graph_delete_statistics(delete_query, connect):
"ANALYZE GRAPH ON LABELS :Label, :NONEXISTING",
],
)
def test_analyze_full_graph(analyze_query, connect):
def test_analyze_full_graph(analyze_query, multi_db):
"""Tests analyzing full graph and choosing better index based on the smaller average group size.
It also tests querying based on labels and that nothing bad will happen by providing non-existing label.
"""
cursor = connect.cursor()
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id1: i}));")
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 50) | CREATE (n:Label {id2: i % 5}));")
execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);")
@ -121,9 +134,10 @@ def test_analyze_full_graph(analyze_query, connect):
# -----------------------------
def test_cardinality_different_avg_group_size_uniform_dist(connect):
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_cardinality_different_avg_group_size_uniform_dist(multi_db):
"""Tests index optimization with indices both having uniform distribution but one has smaller avg. group size."""
cursor = connect.cursor()
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id1: i}));")
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id2: i % 20}));")
execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);")
@ -151,9 +165,10 @@ def test_cardinality_different_avg_group_size_uniform_dist(connect):
execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);")
def test_cardinality_same_avg_group_size_uniform_dist_diff_vertex_count(connect):
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_cardinality_same_avg_group_size_uniform_dist_diff_vertex_count(multi_db):
"""Tests index choosing where both indices have uniform key distribution with same avg. group size but one has less vertices."""
cursor = connect.cursor()
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id1: i}));")
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 50) | CREATE (n:Label {id2: i}));")
execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);")
@ -181,9 +196,10 @@ def test_cardinality_same_avg_group_size_uniform_dist_diff_vertex_count(connect)
execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);")
def test_large_diff_in_num_vertices_v1(connect):
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_large_diff_in_num_vertices_v1(multi_db):
"""Tests that when one index has > 10x vertices than the other one, it should be chosen no matter avg group size and uniform distribution."""
cursor = connect.cursor()
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 1000) | CREATE (n:Label {id1: i}));")
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 99) | CREATE (n:Label {id2: 1}));")
execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);")
@ -211,9 +227,10 @@ def test_large_diff_in_num_vertices_v1(connect):
execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);")
def test_large_diff_in_num_vertices_v2(connect):
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_large_diff_in_num_vertices_v2(multi_db):
"""Tests that when one index has > 10x vertices than the other one, it should be chosen no matter avg group size and uniform distribution."""
cursor = connect.cursor()
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 99) | CREATE (n:Label {id1: 1}));")
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 1000) | CREATE (n:Label {id2: i}));")
execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);")
@ -241,9 +258,10 @@ def test_large_diff_in_num_vertices_v2(connect):
execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);")
def test_same_avg_group_size_diff_distribution(connect):
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_same_avg_group_size_diff_distribution(multi_db):
"""Tests index choice decision based on key distribution."""
cursor = connect.cursor()
cursor = multi_db.cursor()
# Setup first key distribution
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 10) | CREATE (n:Label {id1: 1}));")
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 30) | CREATE (n:Label {id1: 2}));")

View File

@ -66,6 +66,11 @@ startup_config_dict = {
"Time in seconds after which inactive Bolt sessions will be closed.",
),
"data_directory": ("mg_data", "mg_data", "Path to directory in which to save all permanent data."),
"data_recovery_on_startup": (
"false",
"false",
"Controls whether the database recovers persisted data on startup.",
),
"isolation_level": (
"SNAPSHOT_ISOLATION",
"SNAPSHOT_ISOLATION",
@ -133,11 +138,6 @@ startup_config_dict = {
"The number of edges and vertices stored in a batch in a snapshot file.",
),
"storage_properties_on_edges": ("false", "true", "Controls whether edges have properties."),
"storage_recover_on_startup": (
"false",
"false",
"Controls whether the storage recovers persisted data on startup.",
),
"storage_recovery_thread_count": ("12", "12", "The number of threads used to recover persisted data from disk."),
"storage_snapshot_interval_sec": (
"0",
@ -157,6 +157,11 @@ startup_config_dict = {
"Issue a 'fsync' call after this amount of transactions are written to the WAL file. Set to 1 for fully synchronous operation.",
),
"storage_wal_file_size_kib": ("20480", "20480", "Minimum file size of each WAL file."),
"storage_delete_on_drop": (
"true",
"true",
"If set to true the query 'DROP DATABASE x' will delete the underlying storage as well.",
),
"stream_transaction_conflict_retries": (
"30",
"30",

View File

@ -16,6 +16,7 @@ import mgclient
import pytest
default_storage_info_dict = {
"name": "memgraph",
"vertex_count": 0,
"edge_count": 0,
"average_degree": 0,
@ -55,7 +56,7 @@ def test_does_default_config_match():
machine_dependent_configurations = ["memory_usage", "disk_usage", "memory_allocated", "allocation_limit"]
# Number of different data-points returned by SHOW STORAGE INFO
assert len(config) == 11
assert len(config) == 12
for conf in config:
conf_name = conf[0]

View File

@ -6,3 +6,4 @@ copy_fine_grained_access_e2e_python_files(common.py)
copy_fine_grained_access_e2e_python_files(create_delete_filtering_tests.py)
copy_fine_grained_access_e2e_python_files(edge_type_filtering_tests.py)
copy_fine_grained_access_e2e_python_files(path_filtering_tests.py)
copy_fine_grained_access_e2e_python_files(show_db.py)

View File

@ -1,4 +1,4 @@
# Copyright 2021 Memgraph Ltd.
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -12,6 +12,23 @@
import mgclient
def switch_db(cursor):
execute_and_fetch_all(cursor, "USE DATABASE clean;")
def create_multi_db(cursor, switch):
execute_and_fetch_all(cursor, "USE DATABASE memgraph;")
try:
execute_and_fetch_all(cursor, "DROP DATABASE clean;")
except:
pass
execute_and_fetch_all(cursor, "CREATE DATABASE clean;")
if switch:
switch_db(cursor)
reset_and_prepare(cursor)
execute_and_fetch_all(cursor, "USE DATABASE memgraph;")
def reset_and_prepare(admin_cursor):
execute_and_fetch_all(admin_cursor, "REVOKE LABELS * FROM user;")
execute_and_fetch_all(admin_cursor, "REVOKE EDGE_TYPES * FROM user;")

View File

@ -1,4 +1,4 @@
# Copyright 2022 Memgraph Ltd.
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -16,168 +16,224 @@ import pytest
from mgclient import DatabaseError
def test_create_node_all_labels_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_create_node_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
results = common.execute_and_fetch_all(user_connnection.cursor(), "CREATE (n:label1) RETURN n;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;")
assert len(results) == 1
def test_create_node_all_labels_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_create_node_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(user_connnection.cursor(), "CREATE (n:label1) RETURN n;")
common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;")
def test_create_node_specific_label_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_create_node_specific_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;")
results = common.execute_and_fetch_all(user_connnection.cursor(), "CREATE (n:label1) RETURN n;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;")
assert len(results) == 1
def test_create_node_specific_label_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_create_node_specific_label_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(user_connnection.cursor(), "CREATE (n:label1) RETURN n;")
common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;")
def test_delete_node_all_labels_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_delete_node_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) DELETE n;")
if switch:
common.switch_db(user_connection.cursor())
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;")
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) RETURN n;")
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) RETURN n;")
assert len(results) == 0
def test_delete_node_all_labels_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_delete_node_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) DELETE n")
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;")
def test_delete_node_specific_label_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_delete_node_specific_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete TO user;")
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) DELETE n;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;")
if switch:
common.switch_db(admin_connection.cursor())
results = common.execute_and_fetch_all(admin_connection.cursor(), "MATCH (n:test_delete) RETURN n;")
assert len(results) == 0
def test_delete_node_specific_label_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_delete_node_specific_label_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) DELETE n;")
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;")
def test_create_edge_all_labels_all_edge_types_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_create_edge_all_labels_all_edge_types_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
)
assert len(results) == 1
def test_create_edge_all_labels_all_edge_types_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_create_edge_all_labels_all_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
)
def test_create_edge_all_labels_denied_all_edge_types_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_create_edge_all_labels_denied_all_edge_types_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
)
def test_create_edge_all_labels_granted_all_edge_types_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_create_edge_all_labels_granted_all_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
)
def test_create_edge_all_labels_granted_specific_edge_types_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_create_edge_all_labels_granted_specific_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(
admin_connection.cursor(),
"GRANT UPDATE ON EDGE_TYPES :edge_type TO user;",
)
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
)
def test_create_edge_first_node_label_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_create_edge_first_node_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label2 TO user;")
common.execute_and_fetch_all(
@ -185,17 +241,21 @@ def test_create_edge_first_node_label_granted():
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;",
)
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
)
def test_create_edge_second_node_label_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_create_edge_second_node_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;")
common.execute_and_fetch_all(
@ -203,62 +263,78 @@ def test_create_edge_second_node_label_granted():
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;",
)
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
)
def test_delete_edge_all_labels_denied_all_edge_types_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_delete_edge_all_labels_denied_all_edge_types_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r",
)
def test_delete_edge_all_labels_granted_all_edge_types_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_delete_edge_all_labels_granted_all_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r",
)
def test_delete_edge_all_labels_granted_specific_edge_types_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_delete_edge_all_labels_granted_specific_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(
admin_connection.cursor(),
"GRANT UPDATE ON EDGE_TYPES :edge_type_delete TO user;",
)
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r",
)
def test_delete_edge_first_node_label_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_delete_edge_first_node_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_1 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete_2 TO user;")
common.execute_and_fetch_all(
@ -266,17 +342,21 @@ def test_delete_edge_first_node_label_granted():
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type_delete TO user;",
)
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r",
)
def test_delete_edge_second_node_label_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_delete_edge_second_node_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete_1 TO user;")
common.execute_and_fetch_all(
@ -284,159 +364,209 @@ def test_delete_edge_second_node_label_granted():
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type_delete TO user;",
)
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r",
)
def test_delete_node_with_edge_label_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_delete_node_with_edge_label_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(
admin_connection.cursor(),
"GRANT UPDATE ON LABELS :test_delete_1 TO user;",
)
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n) DETACH DELETE n;")
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n) DETACH DELETE n;")
def test_delete_node_with_edge_label_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_delete_node_with_edge_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(
admin_connection.cursor(),
"GRANT CREATE_DELETE ON LABELS :test_delete_1 TO user;",
)
common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n) DETACH DELETE n;")
if switch:
common.switch_db(user_connection.cursor())
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n) DETACH DELETE n;")
if switch:
common.switch_db(admin_connection.cursor())
results = common.execute_and_fetch_all(admin_connection.cursor(), "MATCH (n:test_delete_1) RETURN n;")
assert len(results) == 0
def test_merge_node_all_labels_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_merge_node_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
results = common.execute_and_fetch_all(user_connnection.cursor(), "MERGE (n:label1) RETURN n;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;")
assert len(results) == 1
def test_merge_node_all_labels_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_merge_node_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(user_connnection.cursor(), "MERGE (n:label1) RETURN n;")
common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;")
def test_merge_node_specific_label_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_merge_node_specific_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;")
results = common.execute_and_fetch_all(user_connnection.cursor(), "MERGE (n:label1) RETURN n;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;")
assert len(results) == 1
def test_merge_node_specific_label_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_merge_node_specific_label_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(user_connnection.cursor(), "MERGE (n:label1) RETURN n;")
common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;")
def test_merge_edge_all_labels_all_edge_types_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_merge_edge_all_labels_all_edge_types_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
)
assert len(results) == 1
def test_merge_edge_all_labels_all_edge_types_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_merge_edge_all_labels_all_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
)
def test_merge_edge_all_labels_denied_all_edge_types_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_merge_edge_all_labels_denied_all_edge_types_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
)
def test_merge_edge_all_labels_granted_all_edge_types_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_merge_edge_all_labels_granted_all_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
)
def test_merge_edge_all_labels_granted_specific_edge_types_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_merge_edge_all_labels_granted_specific_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(
admin_connection.cursor(),
"GRANT UPDATE ON EDGE_TYPES :edge_type TO user;",
)
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
)
def test_merge_edge_first_node_label_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_merge_edge_first_node_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label2 TO user;")
common.execute_and_fetch_all(
@ -444,17 +574,21 @@ def test_merge_edge_first_node_label_granted():
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;",
)
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
)
def test_merge_edge_second_node_label_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_merge_edge_second_node_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;")
common.execute_and_fetch_all(
@ -462,64 +596,86 @@ def test_merge_edge_second_node_label_granted():
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;",
)
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
)
def test_set_label_when_label_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_set_label_when_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :update_label_2 TO user;")
if switch:
common.switch_db(user_connection.cursor())
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (p:test_delete) SET p:update_label_2;")
def test_set_label_when_label_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_set_label_when_label_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :update_label_2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (p:test_delete) SET p:update_label_2;")
def test_remove_label_when_label_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_remove_label_when_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete TO user;")
if switch:
common.switch_db(user_connection.cursor())
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (p:test_delete) REMOVE p:test_delete;")
def test_remove_label_when_label_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_remove_label_when_label_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :update_label_2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete TO user;")
if switch:
common.switch_db(user_connection.cursor())
with pytest.raises(DatabaseError):
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (p:test_delete) REMOVE p:test_delete;")
def test_merge_nodes_pass_when_having_create_delete():
@pytest.mark.parametrize("switch", [False, True])
def test_merge_nodes_pass_when_having_create_delete(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(
user_connection.cursor(),
"UNWIND [{id: '1', lat: 10, lng: 10}, {id: '2', lat: 10, lng: 10}, {id: '3', lat: 10, lng: 10}] AS row MERGE (o:Location {id: row.id}) RETURN o;",

View File

@ -1,91 +1,116 @@
import common
import sys
import common
import pytest
def test_all_edge_types_all_labels_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_all_edge_types_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
assert len(results) == 3
def test_deny_all_edge_types_and_all_labels():
@pytest.mark.parametrize("switch", [False, True])
def test_deny_all_edge_types_and_all_labels(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
assert len(results) == 0
def test_revoke_all_edge_types_and_all_labels():
@pytest.mark.parametrize("switch", [False, True])
def test_revoke_all_edge_types_and_all_labels(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
assert len(results) == 0
def test_deny_edge_type():
@pytest.mark.parametrize("switch", [False, True])
def test_deny_edge_type(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1, :label2, :label3 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edgeType1 TO user;")
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
assert len(results) == 2
def test_denied_node_label():
@pytest.mark.parametrize("switch", [False, True])
def test_denied_node_label(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1,:label3 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType1, :edgeType2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label2 TO user;")
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
assert len(results) == 2
def test_denied_one_of_node_label():
@pytest.mark.parametrize("switch", [False, True])
def test_denied_one_of_node_label(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1,:label2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType1, :edgeType2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label3 TO user;")
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
assert len(results) == 1
def test_revoke_all_labels():
@pytest.mark.parametrize("switch", [False, True])
def test_revoke_all_labels(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
assert len(results) == 0
def test_revoke_all_edge_types():
@pytest.mark.parametrize("switch", [False, True])
def test_revoke_all_edge_types(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
assert len(results) == 0

View File

@ -1,22 +1,26 @@
import common
import sys
import common
import pytest
def test_weighted_shortest_path_all_edge_types_all_labels_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_weighted_shortest_path_all_edge_types_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
total_paths_results = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);",
)
path_result = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length,nodes(p);",
)
@ -47,24 +51,28 @@ def test_weighted_shortest_path_all_edge_types_all_labels_granted():
assert all(node.id in expected_path for node in path_result[0][1])
def test_weighted_shortest_path_all_edge_types_all_labels_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_weighted_shortest_path_all_edge_types_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(
user_connnection.cursor(), "MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN p;"
user_connection.cursor(), "MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN p;"
)
assert len(results) == 0
def test_weighted_shortest_path_denied_start():
@pytest.mark.parametrize("switch", [False, True])
def test_weighted_shortest_path_denied_start(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -73,17 +81,20 @@ def test_weighted_shortest_path_denied_start():
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
if switch:
common.switch_db(user_connection.cursor())
path_length_result = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length;",
)
assert len(path_length_result) == 0
def test_weighted_shortest_path_denied_destination():
@pytest.mark.parametrize("switch", [False, True])
def test_weighted_shortest_path_denied_destination(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -92,17 +103,20 @@ def test_weighted_shortest_path_denied_destination():
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
if switch:
common.switch_db(user_connection.cursor())
path_length_result = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length;",
)
assert len(path_length_result) == 0
def test_weighted_shortest_path_denied_label_1():
@pytest.mark.parametrize("switch", [False, True])
def test_weighted_shortest_path_denied_label_1(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -111,13 +125,15 @@ def test_weighted_shortest_path_denied_label_1():
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
total_paths_results = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);",
)
path_result = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length, nodes(p);",
)
@ -143,9 +159,10 @@ def test_weighted_shortest_path_denied_label_1():
assert all(node.id in expected_path for node in path_result[0][1])
def test_weighted_shortest_path_denied_edge_type_3():
@pytest.mark.parametrize("switch", [False, True])
def test_weighted_shortest_path_denied_edge_type_3(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
@ -154,13 +171,15 @@ def test_weighted_shortest_path_denied_edge_type_3():
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
if switch:
common.switch_db(user_connection.cursor())
path_result = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length, nodes(p);",
)
total_paths_results = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);",
)
@ -191,16 +210,19 @@ def test_weighted_shortest_path_denied_edge_type_3():
assert all(node.id in expected_path for node in path_result[0][1])
def test_dfs_all_edge_types_all_labels_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_dfs_all_edge_types_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_paths = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH path=(n:label0)-[* 1..3]->(m:label4) RETURN extract( node in nodes(path) | node.id);",
)
@ -210,22 +232,26 @@ def test_dfs_all_edge_types_all_labels_granted():
assert all(path[0] in expected_paths for path in source_destination_paths)
def test_dfs_all_edge_types_all_labels_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_dfs_all_edge_types_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
total_paths_results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH p=(n)-[*]->(m) RETURN p;")
if switch:
common.switch_db(user_connection.cursor())
total_paths_results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH p=(n)-[*]->(m) RETURN p;")
assert len(total_paths_results) == 0
def test_dfs_denied_start():
@pytest.mark.parametrize("switch", [False, True])
def test_dfs_denied_start(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -234,16 +260,19 @@ def test_dfs_denied_start():
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_path = common.execute_and_fetch_all(
user_connnection.cursor(), "MATCH p=(n:label0)-[*]->(m:label4) RETURN p;"
user_connection.cursor(), "MATCH p=(n:label0)-[*]->(m:label4) RETURN p;"
)
assert len(source_destination_path) == 0
def test_dfs_denied_destination():
@pytest.mark.parametrize("switch", [False, True])
def test_dfs_denied_destination(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -252,16 +281,19 @@ def test_dfs_denied_destination():
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_path = common.execute_and_fetch_all(
user_connnection.cursor(), "MATCH p=(n:label0)-[*]->(m:label4) RETURN p;"
user_connection.cursor(), "MATCH p=(n:label0)-[*]->(m:label4) RETURN p;"
)
assert len(source_destination_path) == 0
def test_dfs_denied_label_1():
@pytest.mark.parametrize("switch", [False, True])
def test_dfs_denied_label_1(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -269,8 +301,11 @@ def test_dfs_denied_label_1():
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_paths = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[* 1..3]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
)
@ -280,9 +315,10 @@ def test_dfs_denied_label_1():
assert all(path[0] in expected_paths for path in source_destination_paths)
def test_dfs_denied_edge_type_3():
@pytest.mark.parametrize("switch", [False, True])
def test_dfs_denied_edge_type_3(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
@ -292,8 +328,10 @@ def test_dfs_denied_edge_type_3():
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_path = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[r * 1..3]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
)
@ -303,16 +341,19 @@ def test_dfs_denied_edge_type_3():
assert source_destination_path[0][0] == expected_path
def test_bfs_sts_all_edge_types_all_labels_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_sts_all_edge_types_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_path = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
)
@ -322,24 +363,28 @@ def test_bfs_sts_all_edge_types_all_labels_granted():
assert source_destination_path[0][0] == expected_path
def test_bfs_sts_all_edge_types_all_labels_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_sts_all_edge_types_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
total_paths_results = common.execute_and_fetch_all(
user_connnection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n)-[r *BFS]->(m) RETURN p;"
user_connection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n)-[r *BFS]->(m) RETURN p;"
)
assert len(total_paths_results) == 0
def test_bfs_sts_denied_start():
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_sts_denied_start(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -348,16 +393,19 @@ def test_bfs_sts_denied_start():
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_path = common.execute_and_fetch_all(
user_connnection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
user_connection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
)
assert len(source_destination_path) == 0
def test_bfs_sts_denied_destination():
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_sts_denied_destination(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -366,16 +414,19 @@ def test_bfs_sts_denied_destination():
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_path = common.execute_and_fetch_all(
user_connnection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
user_connection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
)
assert len(source_destination_path) == 0
def test_bfs_sts_denied_label_1():
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_sts_denied_label_1(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -383,8 +434,11 @@ def test_bfs_sts_denied_label_1():
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_path = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
)
expected_path = [0, 2, 4, 5]
@ -393,9 +447,10 @@ def test_bfs_sts_denied_label_1():
assert source_destination_path[0][0] == expected_path
def test_bfs_sts_denied_edge_type_3():
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_sts_denied_edge_type_3(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
@ -404,8 +459,10 @@ def test_bfs_sts_denied_edge_type_3():
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_path = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
)
expected_path = [0, 2, 4, 5]
@ -414,16 +471,19 @@ def test_bfs_sts_denied_edge_type_3():
assert source_destination_path[0][0] == expected_path
def test_bfs_single_source_all_edge_types_all_labels_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_single_source_all_edge_types_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_path = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
)
@ -433,22 +493,26 @@ def test_bfs_single_source_all_edge_types_all_labels_granted():
assert source_destination_path[0][0] == expected_path
def test_bfs_single_source_all_edge_types_all_labels_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_single_source_all_edge_types_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
total_paths_results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH p=(n)-[r *BFS]->(m) RETURN p;")
if switch:
common.switch_db(user_connection.cursor())
total_paths_results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH p=(n)-[r *BFS]->(m) RETURN p;")
assert len(total_paths_results) == 0
def test_bfs_single_source_denied_start():
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_single_source_denied_start(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -457,16 +521,19 @@ def test_bfs_single_source_denied_start():
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_path = common.execute_and_fetch_all(
user_connnection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
user_connection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
)
assert len(source_destination_path) == 0
def test_bfs_single_source_denied_destination():
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_single_source_denied_destination(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -475,16 +542,19 @@ def test_bfs_single_source_denied_destination():
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_path = common.execute_and_fetch_all(
user_connnection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
user_connection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
)
assert len(source_destination_path) == 0
def test_bfs_single_source_denied_label_1():
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_single_source_denied_label_1(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -492,8 +562,11 @@ def test_bfs_single_source_denied_label_1():
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_path = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
)
@ -503,9 +576,10 @@ def test_bfs_single_source_denied_label_1():
assert source_destination_path[0][0] == expected_path
def test_bfs_single_source_denied_edge_type_3():
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_single_source_denied_edge_type_3(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
@ -514,8 +588,10 @@ def test_bfs_single_source_denied_edge_type_3():
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
if switch:
common.switch_db(user_connection.cursor())
source_destination_path = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
)
@ -525,20 +601,23 @@ def test_bfs_single_source_denied_edge_type_3():
assert source_destination_path[0][0] == expected_path
def test_all_shortest_paths_when_all_edge_types_all_labels_granted():
@pytest.mark.parametrize("switch", [False, True])
def test_all_shortest_paths_when_all_edge_types_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
total_paths_results = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);",
)
path_result = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length,nodes(p);",
)
@ -569,24 +648,28 @@ def test_all_shortest_paths_when_all_edge_types_all_labels_granted():
assert all(node.id in expected_path for node in path_result[0][1])
def test_all_shortest_paths_when_all_edge_types_all_labels_denied():
@pytest.mark.parametrize("switch", [False, True])
def test_all_shortest_paths_when_all_edge_types_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(
user_connnection.cursor(), "MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN p;"
user_connection.cursor(), "MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN p;"
)
assert len(results) == 0
def test_all_shortest_paths_when_denied_start():
@pytest.mark.parametrize("switch", [False, True])
def test_all_shortest_paths_when_denied_start(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -595,17 +678,20 @@ def test_all_shortest_paths_when_denied_start():
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
if switch:
common.switch_db(user_connection.cursor())
path_length_result = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length;",
)
assert len(path_length_result) == 0
def test_all_shortest_paths_when_denied_destination():
@pytest.mark.parametrize("switch", [False, True])
def test_all_shortest_paths_when_denied_destination(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -614,17 +700,20 @@ def test_all_shortest_paths_when_denied_destination():
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
if switch:
common.switch_db(user_connection.cursor())
path_length_result = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length;",
)
assert len(path_length_result) == 0
def test_all_shortest_paths_when_denied_label_1():
@pytest.mark.parametrize("switch", [False, True])
def test_all_shortest_paths_when_denied_label_1(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -633,13 +722,15 @@ def test_all_shortest_paths_when_denied_label_1():
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
if switch:
common.switch_db(user_connection.cursor())
total_paths_results = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);",
)
path_result = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length, nodes(p);",
)
@ -665,9 +756,10 @@ def test_all_shortest_paths_when_denied_label_1():
assert all(node.id in expected_path for node in path_result[0][1])
def test_all_shortest_paths_when_denied_edge_type_3():
@pytest.mark.parametrize("switch", [False, True])
def test_all_shortest_paths_when_denied_edge_type_3(switch):
admin_connection = common.connect(username="admin", password="test")
user_connnection = common.connect(username="user", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
@ -676,13 +768,15 @@ def test_all_shortest_paths_when_denied_edge_type_3():
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
if switch:
common.switch_db(user_connection.cursor())
path_result = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length, nodes(p);",
)
total_paths_results = common.execute_and_fetch_all(
user_connnection.cursor(),
user_connection.cursor(),
"MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);",
)

View File

@ -0,0 +1,36 @@
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import sys
import common
import pytest
from mgclient import DatabaseError
def test_show_databases_w_user():
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
user2_connection = common.connect(username="user2", password="test")
user3_connection = common.connect(username="user3", password="test")
assert common.execute_and_fetch_all(admin_connection.cursor(), "SHOW DATABASES") == [
("db1", ""),
("db2", ""),
("memgraph", "*"),
]
assert common.execute_and_fetch_all(user_connection.cursor(), "SHOW DATABASES") == [("db1", ""), ("memgraph", "*")]
assert common.execute_and_fetch_all(user2_connection.cursor(), "SHOW DATABASES") == [("db2", "*")]
assert common.execute_and_fetch_all(user3_connection.cursor(), "SHOW DATABASES") == [("db1", "*"), ("db2", "")]
if __name__ == "__main__":
sys.exit(pytest.main([__file__, "-rA"]))

View File

@ -9,7 +9,9 @@ create_delete_filtering_cluster: &create_delete_filtering_cluster
"CREATE USER admin IDENTIFIED BY 'test';",
"CREATE USER user IDENTIFIED BY 'test';",
"GRANT ALL PRIVILEGES TO admin;",
"GRANT DATABASE * TO admin;",
"GRANT ALL PRIVILEGES TO user;",
"GRANT DATABASE * TO user;",
]
edge_type_filtering_cluster: &edge_type_filtering_cluster
@ -22,7 +24,9 @@ edge_type_filtering_cluster: &edge_type_filtering_cluster
"CREATE USER admin IDENTIFIED BY 'test';",
"CREATE USER user IDENTIFIED BY 'test';",
"GRANT ALL PRIVILEGES TO admin;",
"GRANT DATABASE * TO admin;",
"GRANT ALL PRIVILEGES TO user;",
"GRANT DATABASE * TO user;",
"GRANT CREATE_DELETE ON LABELS * TO admin;",
"GRANT CREATE_DELETE ON EDGE_TYPES * TO admin;",
"MERGE (l1:label1 {name: 'test1'});",
@ -32,6 +36,17 @@ edge_type_filtering_cluster: &edge_type_filtering_cluster
"MATCH (l1:label1),(l3:label3) WHERE l1.name = 'test1' AND l3.name = 'test3' CREATE (l1)-[r:edgeType2]->(l3);",
"MERGE (mix:label3:label1 {name: 'test4'});",
"MATCH (l1:label1),(mix:label3) WHERE l1.name = 'test1' AND mix.name = 'test4' CREATE (l1)-[r:edgeType2]->(mix);",
"CREATE DATABASE clean;",
"USE DATABASE clean",
"MATCH (n) DETACH DELETE n;",
"MERGE (l1:label1 {name: 'test1'});",
"MERGE (l2:label2 {name: 'test2'});",
"MATCH (l1:label1),(l2:label2) WHERE l1.name = 'test1' AND l2.name = 'test2' CREATE (l1)-[r:edgeType1]->(l2);",
"MERGE (l3:label3 {name: 'test3'});",
"MATCH (l1:label1),(l3:label3) WHERE l1.name = 'test1' AND l3.name = 'test3' CREATE (l1)-[r:edgeType2]->(l3);",
"MERGE (mix:label3:label1 {name: 'test4'});",
"MATCH (l1:label1),(mix:label3) WHERE l1.name = 'test1' AND mix.name = 'test4' CREATE (l1)-[r:edgeType2]->(mix);",
"USE DATABASE memgraph",
]
validation_queries: []
@ -45,7 +60,9 @@ path_filtering_cluster: &path_filtering_cluster
"CREATE USER admin IDENTIFIED BY 'test';",
"CREATE USER user IDENTIFIED BY 'test';",
"GRANT ALL PRIVILEGES TO admin;",
"GRANT DATABASE * TO admin;",
"GRANT ALL PRIVILEGES TO user;",
"GRANT DATABASE * TO user;",
"MERGE (a:label0 {id: 0}) MERGE (b:label1 {id: 1}) CREATE (a)-[:edge_type_1 {weight: 6}]->(b);",
"MERGE (a:label0 {id: 0}) MERGE (b:label2 {id: 2}) CREATE (a)-[:edge_type_1 {weight: 14}]->(b);",
"MERGE (a:label1 {id: 1}) MERGE (b:label2 {id: 2}) CREATE (a)-[:edge_type_2 {weight: 1}]->(b);",
@ -56,6 +73,47 @@ path_filtering_cluster: &path_filtering_cluster
"MERGE (a:label3 {id: 4}) MERGE (b:label3 {id: 3}) CREATE (a)-[:edge_type_4 {weight: 1}]->(b);",
"MERGE (a:label3 {id: 3}) MERGE (b:label4 {id: 5}) CREATE (a)-[:edge_type_4 {weight: 14}]->(b);",
"MERGE (a:label3 {id: 4}) MERGE (b:label4 {id: 5}) CREATE (a)-[:edge_type_4 {weight: 8}]->(b);",
"CREATE DATABASE clean;",
"USE DATABASE clean",
"MATCH (n) DETACH DELETE n;",
"MERGE (a:label0 {id: 0}) MERGE (b:label1 {id: 1}) CREATE (a)-[:edge_type_1 {weight: 6}]->(b);",
"MERGE (a:label0 {id: 0}) MERGE (b:label2 {id: 2}) CREATE (a)-[:edge_type_1 {weight: 14}]->(b);",
"MERGE (a:label1 {id: 1}) MERGE (b:label2 {id: 2}) CREATE (a)-[:edge_type_2 {weight: 1}]->(b);",
"MERGE (a:label2 {id: 2}) MERGE (b:label3 {id: 4}) CREATE (a)-[:edge_type_2 {weight: 10}]->(b);",
"MERGE (a:label1 {id: 1}) MERGE (b:label3 {id: 3}) CREATE (a)-[:edge_type_3 {weight: 5}]->(b);",
"MERGE (a:label2 {id: 2}) MERGE (b:label3 {id: 3}) CREATE (a)-[:edge_type_3 {weight: 7}]->(b);",
"MERGE (a:label3 {id: 3}) MERGE (b:label3 {id: 4}) CREATE (a)-[:edge_type_4 {weight: 1}]->(b);",
"MERGE (a:label3 {id: 4}) MERGE (b:label3 {id: 3}) CREATE (a)-[:edge_type_4 {weight: 1}]->(b);",
"MERGE (a:label3 {id: 3}) MERGE (b:label4 {id: 5}) CREATE (a)-[:edge_type_4 {weight: 14}]->(b);",
"MERGE (a:label3 {id: 4}) MERGE (b:label4 {id: 5}) CREATE (a)-[:edge_type_4 {weight: 8}]->(b);",
"USE DATABASE memgraph",
]
show_databases_w_user: &show_databases_w_user
cluster:
main:
args: ["--bolt-port", "7687", "--log-level=TRACE"]
log_file: "fine_grained_access.log"
setup_queries:
[
"CREATE USER admin IDENTIFIED BY 'test';",
"CREATE USER user IDENTIFIED BY 'test';",
"CREATE USER user2 IDENTIFIED BY 'test';",
"CREATE USER user3 IDENTIFIED BY 'test';",
"CREATE DATABASE db1;",
"CREATE DATABASE db2;",
"GRANT ALL PRIVILEGES TO admin;",
"GRANT DATABASE * TO admin;",
"GRANT ALL PRIVILEGES TO user;",
"GRANT DATABASE db1 TO user;",
"GRANT ALL PRIVILEGES TO user2;",
"GRANT DATABASE db2 TO user2;",
"REVOKE DATABASE memgraph FROM user2;",
"SET MAIN DATABASE db2 FOR user2",
"GRANT ALL PRIVILEGES TO user3;",
"GRANT DATABASE * TO user3;",
"REVOKE DATABASE memgraph FROM user3;",
"SET MAIN DATABASE db1 FOR user3",
]
workloads:
@ -63,7 +121,6 @@ workloads:
binary: "tests/e2e/pytest_runner.sh"
args: ["fine_grained_access/create_delete_filtering_tests.py"]
<<: *create_delete_filtering_cluster
- name: "EdgeType filtering"
binary: "tests/e2e/pytest_runner.sh"
args: ["fine_grained_access/edge_type_filtering_tests.py"]
@ -72,3 +129,7 @@ workloads:
binary: "tests/e2e/pytest_runner.sh"
args: ["fine_grained_access/path_filtering_tests.py"]
<<: *path_filtering_cluster
- name: "Show databases with users"
binary: "tests/e2e/pytest_runner.sh"
args: ["fine_grained_access/show_db.py"]
<<: *show_databases_w_user

View File

@ -21,7 +21,14 @@ const driver = neo4j.driver(
neo4j.auth.basic("", "")
);
const neoSchema = new Neo4jGraphQL({ typeDefs, driver });
const neoSchema = new Neo4jGraphQL({
typeDefs, driver,
config: {
driverConfig: {
database: "memgraph",
},
}
});
neoSchema.getSchema().then((schema) => {
const server = new ApolloServer({

View File

@ -9,12 +9,16 @@
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <fmt/core.h>
#include <gflags/gflags.h>
#include <mgclient.hpp>
#include "query/exceptions.hpp"
#include "utils/logging.hpp"
#include "utils/timer.hpp"
#include <iostream>
DEFINE_uint64(bolt_port, 7687, "Bolt port");
DEFINE_uint64(timeout, 120, "Timeout seconds");
@ -53,16 +57,55 @@ bool IsDiskStorageMode(std::unique_ptr<mg::Client> &client) {
return false;
}
void CleanDatabase() {
auto client = GetClient();
void CleanDatabase(std::unique_ptr<mg::Client> &client) {
MG_ASSERT(client->Execute("MATCH (n) DETACH DELETE n;"));
client->DiscardAll();
}
void SetupCleanDB() {
auto client = GetClient();
MG_ASSERT(client->Execute("USE DATABASE memgraph;"));
client->DiscardAll();
try {
client->Execute("DROP DATABASE clean;");
client->DiscardAll();
} catch (const mg::ClientException &) {
// In case clean doesn't exist
}
MG_ASSERT(client->Execute("CREATE DATABASE clean;"));
client->DiscardAll();
MG_ASSERT(client->Execute("USE DATABASE clean;"));
client->DiscardAll();
CleanDatabase(client);
}
void SwitchToDB(const std::string &name, std::unique_ptr<mg::Client> &client) {
MG_ASSERT(client->Execute(fmt::format("USE DATABASE {};", name)));
client->DiscardAll();
}
void SwitchToCleanDB(std::unique_ptr<mg::Client> &client) { SwitchToDB("clean", client); }
void SwitchToSameDB(std::unique_ptr<mg::Client> &main, std::unique_ptr<mg::Client> &client) {
MG_ASSERT(main->Execute("SHOW DATABASES;"));
auto dbs = main->FetchAll();
MG_ASSERT(dbs, "Failed to show databases");
for (const auto &elem : *dbs) {
MG_ASSERT(elem.size(), "Show databases wrong output");
const auto &active = elem[1].ValueString();
if (active == "*") {
const auto &name = elem[0].ValueString();
SwitchToDB(std::string(name), client);
break;
}
}
}
void TestSnapshotIsolation(std::unique_ptr<mg::Client> &client) {
spdlog::info("Verifying SNAPSHOT ISOLATION");
auto creator = GetClient();
SwitchToSameDB(client, creator);
MG_ASSERT(client->BeginTransaction());
MG_ASSERT(creator->BeginTransaction());
@ -89,13 +132,14 @@ void TestSnapshotIsolation(std::unique_ptr<mg::Client> &client) {
"at a later point.",
current_vertex_count, 0);
MG_ASSERT(client->CommitTransaction());
CleanDatabase();
CleanDatabase(creator);
}
void TestReadCommitted(std::unique_ptr<mg::Client> &client) {
spdlog::info("Verifying READ COMMITTED");
auto creator = GetClient();
SwitchToSameDB(client, creator);
MG_ASSERT(client->BeginTransaction());
MG_ASSERT(creator->BeginTransaction());
@ -121,13 +165,14 @@ void TestReadCommitted(std::unique_ptr<mg::Client> &client) {
"from a committed transaction",
current_vertex_count, vertex_count);
MG_ASSERT(client->CommitTransaction());
CleanDatabase();
CleanDatabase(creator);
}
void TestReadUncommitted(std::unique_ptr<mg::Client> &client) {
spdlog::info("Verifying READ UNCOMMITTED");
auto creator = GetClient();
SwitchToSameDB(client, creator);
MG_ASSERT(client->BeginTransaction());
MG_ASSERT(creator->BeginTransaction());
@ -152,18 +197,23 @@ void TestReadUncommitted(std::unique_ptr<mg::Client> &client) {
"from a different transaction",
current_vertex_count, vertex_count);
MG_ASSERT(client->CommitTransaction());
CleanDatabase();
CleanDatabase(creator);
}
inline constexpr std::array isolation_levels{std::pair{"SNAPSHOT ISOLATION", &TestSnapshotIsolation},
std::pair{"READ COMMITTED", &TestReadCommitted},
std::pair{"READ UNCOMMITTED", &TestReadUncommitted}};
void TestGlobalIsolationLevel(bool isDiskStorage) {
void TestGlobalIsolationLevel(bool isDiskStorage, bool mdb = false) {
spdlog::info("\n\n----Test global isolation levels----\n");
auto first_client = GetClient();
auto second_client = GetClient();
if (mdb) {
SwitchToCleanDB(first_client);
SwitchToCleanDB(second_client);
}
for (const auto &[isolation_level, verification_function] : isolation_levels) {
spdlog::info("--------------------------");
@ -183,11 +233,17 @@ void TestGlobalIsolationLevel(bool isDiskStorage) {
}
}
void TestSessionIsolationLevel(bool isDiskStorage) {
void TestSessionIsolationLevel(bool isDiskStorage, bool mdb = false) {
spdlog::info("\n\n----Test session isolation levels----\n");
auto global_client = GetClient();
auto session_client = GetClient();
if (mdb) {
SwitchToCleanDB(global_client);
SwitchToCleanDB(session_client);
}
for (const auto &[global_isolation_level, global_verification_function] : isolation_levels) {
if (isDiskStorage && strcmp(global_isolation_level, "SNAPSHOT ISOLATION") != 0) {
spdlog::info("Skipping for disk storage unsupported global isolation level {}", global_isolation_level);
@ -218,11 +274,17 @@ void TestSessionIsolationLevel(bool isDiskStorage) {
}
// Priority of applying the isolation level from highest priority NEXT -> SESSION -> GLOBAL
void TestNextIsolationLevel(bool isDiskStorage) {
void TestNextIsolationLevel(bool isDiskStorage, bool mdb = false) {
spdlog::info("\n\n----Test next isolation levels----\n");
auto global_client = GetClient();
auto session_client = GetClient();
if (mdb) {
SwitchToCleanDB(global_client);
SwitchToCleanDB(session_client);
}
for (const auto &[global_isolation_level, global_verification_function] : isolation_levels) {
if (isDiskStorage && strcmp(global_isolation_level, "SNAPSHOT ISOLATION") != 0) {
spdlog::info("Skipping for disk storage unsupported global isolation level {}", global_isolation_level);
@ -289,10 +351,21 @@ int main(int argc, char **argv) {
auto client = GetClient();
bool isDiskStorage = IsDiskStorageMode(client);
client->DiscardAll();
bool multiDB = false;
TestGlobalIsolationLevel(isDiskStorage);
TestSessionIsolationLevel(isDiskStorage);
TestNextIsolationLevel(isDiskStorage);
// MultiDB tests
multiDB = true;
spdlog::info("--------------------------");
spdlog::info("---- RUNNING MULTI DB ----");
spdlog::info("--------------------------");
SetupCleanDB();
TestGlobalIsolationLevel(isDiskStorage, multiDB);
TestSessionIsolationLevel(isDiskStorage, multiDB);
TestNextIsolationLevel(isDiskStorage, multiDB);
return 0;
}

View File

@ -1,4 +1,4 @@
# Copyright 2021 Memgraph Ltd.
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -9,9 +9,10 @@
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import mgclient
import typing
import mgclient
def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]:
cursor.execute(query, params)
@ -24,6 +25,19 @@ def connect(**kwargs) -> mgclient.Connection:
return connection
def switch_db(cursor):
execute_and_fetch_all(cursor, "USE DATABASE clean;")
def create_multi_db(cursor):
execute_and_fetch_all(cursor, "USE DATABASE memgraph;")
try:
execute_and_fetch_all(cursor, "DROP DATABASE clean;")
except:
pass
execute_and_fetch_all(cursor, "CREATE DATABASE clean;")
def reset_permissions(admin_cursor: mgclient.Cursor, create_index: bool = False):
execute_and_fetch_all(admin_cursor, "REVOKE LABELS * FROM user;")
execute_and_fetch_all(admin_cursor, "REVOKE EDGE_TYPES * FROM user;")

View File

@ -1,4 +1,4 @@
# Copyright 2022 Memgraph Ltd.
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -9,15 +9,10 @@
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import pytest
import sys
from common import (
connect,
execute_and_fetch_all,
mgclient,
reset_create_delete_permissions,
)
import pytest
from common import *
AUTHORIZATION_ERROR_IDENTIFIER = "AuthorizationError"
@ -28,55 +23,83 @@ create_edge_query = "MATCH (n:create_delete_label_1), (m:create_delete_label_2)
delete_edge_query = "CALL create_delete.delete_edge() YIELD * RETURN *;"
def test_can_not_create_vertex_when_given_nothing():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_create_vertex_when_given_nothing(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, create_vertex_query)
def test_can_create_vertex_when_given_global_create_delete():
@pytest.mark.parametrize("switch", [False, True])
def test_can_create_vertex_when_given_global_create_delete(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, create_vertex_query)
len(result[0][0]) == 1
def test_can_not_create_vertex_when_given_global_read():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_create_vertex_when_given_global_read(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, create_vertex_query)
def test_can_not_create_vertex_when_given_global_update():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_create_vertex_when_given_global_update(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS :create_delete_label TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, create_vertex_query)
def test_can_add_vertex_label_when_given_create_delete():
@pytest.mark.parametrize("switch", [False, True])
def test_can_add_vertex_label_when_given_create_delete(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(
@ -85,14 +108,20 @@ def test_can_add_vertex_label_when_given_create_delete():
)
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, set_label_vertex_query)
assert "create_delete_label" in result[0][0]
assert "new_create_delete_label" in result[0][0]
def test_can_not_add_vertex_label_when_given_update():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_add_vertex_label_when_given_update(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(
@ -100,12 +129,18 @@ def test_can_not_add_vertex_label_when_given_update():
)
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, set_label_vertex_query)
def test_can_not_add_vertex_label_when_given_read():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_add_vertex_label_when_given_read(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(
@ -113,118 +148,178 @@ def test_can_not_add_vertex_label_when_given_read():
)
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, set_label_vertex_query)
def test_can_remove_vertex_label_when_given_create_delete():
@pytest.mark.parametrize("switch", [False, True])
def test_can_remove_vertex_label_when_given_create_delete(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS :create_delete_label TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, remove_label_vertex_query)
assert result[0][0] != ":create_delete_label"
def test_can_remove_vertex_label_when_given_global_create_delete():
@pytest.mark.parametrize("switch", [False, True])
def test_can_remove_vertex_label_when_given_global_create_delete(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, remove_label_vertex_query)
assert result[0][0] != ":create_delete_label"
def test_can_not_remove_vertex_label_when_given_update():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_remove_vertex_label_when_given_update(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS :create_delete_label TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, remove_label_vertex_query)
def test_can_not_remove_vertex_label_when_given_global_update():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_remove_vertex_label_when_given_global_update(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, remove_label_vertex_query)
def test_can_not_remove_vertex_label_when_given_read():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_remove_vertex_label_when_given_read(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :create_delete_label TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, remove_label_vertex_query)
def test_can_not_remove_vertex_label_when_given_global_read():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_remove_vertex_label_when_given_global_read(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, remove_label_vertex_query)
def test_can_not_create_edge_when_given_nothing():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_create_edge_when_given_nothing(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, create_edge_query)
def test_can_not_create_edge_when_given_read():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_create_edge_when_given_read(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES :new_create_delete_edge_type TO user")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, create_edge_query)
def test_can_not_create_edge_when_given_update():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_create_edge_when_given_update(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES :new_create_delete_edge_type TO user")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, create_edge_query)
def test_can_create_edge_when_given_create_delete():
@pytest.mark.parametrize("switch", [False, True])
def test_can_create_edge_when_given_create_delete(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(
@ -233,24 +328,36 @@ def test_can_create_edge_when_given_create_delete():
)
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
no_of_edges = execute_and_fetch_all(test_cursor, create_edge_query)
assert no_of_edges[0][0] == 2
def test_can_not_delete_edge_when_given_nothing():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_delete_edge_when_given_nothing(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, delete_edge_query)
def test_can_not_delete_edge_when_given_read():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_delete_edge_when_given_read(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(
@ -259,13 +366,19 @@ def test_can_not_delete_edge_when_given_read():
)
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, delete_edge_query)
def test_can_not_delete_edge_when_given_update():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_delete_edge_when_given_update(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(
@ -274,13 +387,19 @@ def test_can_not_delete_edge_when_given_update():
)
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
execute_and_fetch_all(test_cursor, delete_edge_query)
def test_can_delete_edge_when_given_create_delete():
@pytest.mark.parametrize("switch", [False, True])
def test_can_delete_edge_when_given_create_delete(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_create_delete_permissions(admin_cursor)
execute_and_fetch_all(
@ -289,6 +408,8 @@ def test_can_delete_edge_when_given_create_delete():
)
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
no_of_edges = execute_and_fetch_all(test_cursor, delete_edge_query)

View File

@ -1,4 +1,4 @@
# Copyright 2022 Memgraph Ltd.
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -9,12 +9,11 @@
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import pytest
import sys
from typing import List
from common import connect, execute_and_fetch_all, reset_permissions
import pytest
from common import *
match_query = "MATCH (n) RETURN n;"
match_by_id_query = "MATCH (n) WHERE ID(n) >= 0 RETURN n;"
@ -105,11 +104,16 @@ def get_user_cursor():
def execute_read_node_assertion(
operation_case: List[str], queries: List[str], create_index: bool, expected_size: int
operation_case: List[str], queries: List[str], create_index: bool, expected_size: int, switch: bool
) -> None:
admin_cursor = get_admin_cursor()
user_cursor = get_user_cursor()
if switch:
create_multi_db(admin_cursor)
switch_db(admin_cursor)
switch_db(user_cursor)
reset_permissions(admin_cursor, create_index)
for operation in operation_case:
@ -120,7 +124,8 @@ def execute_read_node_assertion(
assert len(results) == expected_size
def test_can_read_node_when_authorized():
@pytest.mark.parametrize("switch", [False, True])
def test_can_read_node_when_authorized(switch):
match_queries_without_index = [match_query, match_by_id_query]
match_queries_with_index = [
match_by_label_query,
@ -132,14 +137,15 @@ def test_can_read_node_when_authorized():
for expected_size, operation_case in zip(
read_node_without_index_operation_cases_expected_size, read_node_without_index_operation_cases
):
execute_read_node_assertion(operation_case, match_queries_without_index, False, expected_size)
execute_read_node_assertion(operation_case, match_queries_without_index, False, expected_size, switch)
for expected_size, operation_case in zip(
read_node_with_index_operation_cases_expected_sizes, read_node_with_index_operation_cases
):
execute_read_node_assertion(operation_case, match_queries_with_index, True, expected_size)
execute_read_node_assertion(operation_case, match_queries_with_index, True, expected_size, switch)
def test_can_not_read_node_when_authorized():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_read_node_when_authorized(switch):
match_queries_without_index = [match_query, match_by_id_query]
match_queries_with_index = [
match_by_label_query,
@ -151,11 +157,11 @@ def test_can_not_read_node_when_authorized():
for expected_size, operation_case in zip(
not_read_node_without_index_operation_cases_expected_sizes, not_read_node_without_index_operation_cases
):
execute_read_node_assertion(operation_case, match_queries_without_index, False, expected_size)
execute_read_node_assertion(operation_case, match_queries_without_index, False, expected_size, switch)
for expected_size, operation_case in zip(
not_read_node_with_index_operation_cases_expexted_sizes, not_read_node_with_index_operation_cases
):
execute_read_node_assertion(operation_case, match_queries_with_index, True, expected_size)
execute_read_node_assertion(operation_case, match_queries_with_index, True, expected_size, switch)
if __name__ == "__main__":

View File

@ -1,4 +1,4 @@
# Copyright 2022 Memgraph Ltd.
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -10,177 +10,260 @@
# licenses/APL.txt.
import sys
import pytest
from common import connect, execute_and_fetch_all, reset_permissions
from common import *
get_number_of_vertices_query = "CALL read.number_of_visible_nodes() YIELD nr_of_nodes RETURN nr_of_nodes;"
get_number_of_edges_query = "CALL read.number_of_visible_edges() YIELD nr_of_edges RETURN nr_of_edges;"
def test_can_read_vertex_through_c_api_when_given_grant_on_label():
@pytest.mark.parametrize("switch", [False, True])
def test_can_read_vertex_through_c_api_when_given_grant_on_label(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
assert result[0][0] == 1
def test_can_read_vertex_through_c_api_when_given_update_grant_on_label():
@pytest.mark.parametrize("switch", [False, True])
def test_can_read_vertex_through_c_api_when_given_update_grant_on_label(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS :read_label TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
assert result[0][0] == 1
def test_can_read_vertex_through_c_api_when_given_create_delete_grant_on_label():
@pytest.mark.parametrize("switch", [False, True])
def test_can_read_vertex_through_c_api_when_given_create_delete_grant_on_label(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS :read_label TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
assert result[0][0] == 1
def test_can_not_read_vertex_through_c_api_when_given_nothing():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_read_vertex_through_c_api_when_given_nothing(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
assert result[0][0] == 0
def test_can_not_read_vertex_through_c_api_when_given_deny_on_label():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_read_vertex_through_c_api_when_given_deny_on_label(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON LABELS :read_label TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
assert result[0][0] == 0
def test_can_read_partial_vertices_through_c_api_when_given_global_read_but_deny_on_label():
@pytest.mark.parametrize("switch", [False, True])
def test_can_read_partial_vertices_through_c_api_when_given_global_read_but_deny_on_label(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON LABELS :read_label TO user;")
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
assert result[0][0] == 2
def test_can_read_partial_vertices_through_c_api_when_given_global_update_but_deny_on_label():
@pytest.mark.parametrize("switch", [False, True])
def test_can_read_partial_vertices_through_c_api_when_given_global_update_but_deny_on_label(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON LABELS :read_label TO user;")
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
assert result[0][0] == 2
def test_can_read_partial_vertices_through_c_api_when_given_global_create_delete_but_deny_on_label():
@pytest.mark.parametrize("switch", [False, True])
def test_can_read_partial_vertices_through_c_api_when_given_global_create_delete_but_deny_on_label(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON LABELS :read_label TO user;")
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
assert result[0][0] == 2
def test_can_read_edge_through_c_api_when_given_grant_on_edge_type():
@pytest.mark.parametrize("switch", [False, True])
def test_can_read_edge_through_c_api_when_given_grant_on_edge_type(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES :read_edge_type TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
assert result[0][0] == 1
def test_can_not_read_edge_through_c_api_when_given_deny_on_edge_type():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_read_edge_through_c_api_when_given_deny_on_edge_type(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON EDGE_TYPES :read_edge_type TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
assert result[0][0] == 0
def test_can_read_edge_through_c_api_when_given_grant_on_edge_type():
@pytest.mark.parametrize("switch", [False, True])
def test_can_read_edge_through_c_api_when_given_grant_on_edge_type(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES :read_edge_type TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
assert result[0][0] == 1
def test_can_read_edge_through_c_api_when_given_update_on_edge_type():
@pytest.mark.parametrize("switch", [False, True])
def test_can_read_edge_through_c_api_when_given_update_on_edge_type(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES :read_edge_type TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
assert result[0][0] == 1
def test_can_read_edge_through_c_api_when_given_create_delete_on_edge_type():
@pytest.mark.parametrize("switch", [False, True])
def test_can_read_edge_through_c_api_when_given_create_delete_on_edge_type(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON EDGE_TYPES :read_edge_type TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
assert result[0][0] == 1
def test_can_not_read_edge_through_c_api_when_given_read_global_but_deny_on_edge_type():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_read_edge_through_c_api_when_given_read_global_but_deny_on_edge_type(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
@ -188,13 +271,19 @@ def test_can_not_read_edge_through_c_api_when_given_read_global_but_deny_on_edge
execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
assert result[0][0] == 0
def test_can_not_read_edge_through_c_api_when_given_update_global_but_deny_on_edge_type():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_read_edge_through_c_api_when_given_update_global_but_deny_on_edge_type(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
@ -202,13 +291,19 @@ def test_can_not_read_edge_through_c_api_when_given_update_global_but_deny_on_ed
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
assert result[0][0] == 0
def test_can_not_read_edge_through_c_api_when_given_create_delete_global_but_deny_on_edge_type():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_read_edge_through_c_api_when_given_create_delete_global_but_deny_on_edge_type(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
@ -216,6 +311,8 @@ def test_can_not_read_edge_through_c_api_when_given_create_delete_global_but_den
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
assert result[0][0] == 0

View File

@ -38,6 +38,8 @@ BASIC_PRIVILEGES = [
"MODULE_WRITE",
"TRANSACTION_MANAGEMENT",
"STORAGE_MODE",
"MULTI_DATABASE_EDIT",
"MULTI_DATABASE_USE",
]
@ -61,7 +63,7 @@ def test_lba_procedures_show_privileges_first_user():
cursor = connect(username="Josip", password="").cursor()
result = execute_and_fetch_all(cursor, "SHOW PRIVILEGES FOR Josip;")
assert len(result) == 32
assert len(result) == 34
fine_privilege_results = [res for res in result if res[0] not in BASIC_PRIVILEGES]

View File

@ -1,4 +1,4 @@
# Copyright 2022 Memgraph Ltd.
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -9,107 +9,149 @@
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import pytest
import sys
from common import (
connect,
execute_and_fetch_all,
reset_update_permissions,
)
import pytest
from common import *
set_vertex_property_query = "MATCH (n:update_label) CALL update.set_property(n) YIELD * RETURN n.prop;"
set_edge_property_query = "MATCH (n:update_label_1)-[r:update_edge_type]->(m:update_label_2) CALL update.set_property(r) YIELD * RETURN r.prop;"
def test_can_not_update_vertex_when_given_read():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_update_vertex_when_given_read(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_update_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, set_vertex_property_query)
assert result[0][0] == 1
def test_can_update_vertex_when_given_update_grant_on_label():
@pytest.mark.parametrize("switch", [False, True])
def test_can_update_vertex_when_given_update_grant_on_label(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_update_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS :update_label TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, set_vertex_property_query)
assert result[0][0] == 2
def test_can_update_vertex_when_given_create_delete_grant_on_label():
@pytest.mark.parametrize("switch", [False, True])
def test_can_update_vertex_when_given_create_delete_grant_on_label(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_update_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS :update_label TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, set_vertex_property_query)
assert result[0][0] == 2
def test_can_update_vertex_when_given_update_global_grant_on_label():
@pytest.mark.parametrize("switch", [False, True])
def test_can_update_vertex_when_given_update_global_grant_on_label(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_update_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, set_vertex_property_query)
assert result[0][0] == 2
def test_can_update_vertex_when_given_create_delete_global_grant_on_label():
@pytest.mark.parametrize("switch", [False, True])
def test_can_update_vertex_when_given_create_delete_global_grant_on_label(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_update_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, set_vertex_property_query)
assert result[0][0] == 2
def test_can_not_update_vertex_when_denied_update_and_granted_global_update_on_label():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_update_vertex_when_denied_update_and_granted_global_update_on_label(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_update_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label TO user;")
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, set_vertex_property_query)
assert result[0][0] == 1
def test_can_not_update_vertex_when_denied_update_and_granted_global_create_delete_on_label():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_update_vertex_when_denied_update_and_granted_global_create_delete_on_label(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_update_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label TO user;")
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, set_vertex_property_query)
assert result[0][0] == 1
def test_can_update_edge_when_given_update_grant_on_edge_type():
@pytest.mark.parametrize("switch", [False, True])
def test_can_update_edge_when_given_update_grant_on_edge_type(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_update_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;")
@ -117,13 +159,19 @@ def test_can_update_edge_when_given_update_grant_on_edge_type():
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES :update_edge_type TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, set_edge_property_query)
assert result[0][0] == 2
def test_can_not_update_edge_when_given_read_grant_on_edge_type():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_update_edge_when_given_read_grant_on_edge_type(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_update_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;")
@ -131,13 +179,19 @@ def test_can_not_update_edge_when_given_read_grant_on_edge_type():
execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES :update_edge_type TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, set_edge_property_query)
assert result[0][0] == 1
def test_can_update_edge_when_given_create_delete_grant_on_edge_type():
@pytest.mark.parametrize("switch", [False, True])
def test_can_update_edge_when_given_create_delete_grant_on_edge_type(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_update_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;")
@ -145,13 +199,19 @@ def test_can_update_edge_when_given_create_delete_grant_on_edge_type():
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON EDGE_TYPES :update_edge_type TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, set_edge_property_query)
assert result[0][0] == 2
def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_update_on_edge_type():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_update_on_edge_type(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_update_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;")
@ -160,13 +220,19 @@ def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_upd
execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, set_edge_property_query)
assert result[0][0] == 1
def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_create_delete_on_edge_type():
@pytest.mark.parametrize("switch", [False, True])
def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_create_delete_on_edge_type(switch):
admin_cursor = connect(username="admin", password="test").cursor()
create_multi_db(admin_cursor)
if switch:
switch_db(admin_cursor)
reset_update_permissions(admin_cursor)
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;")
@ -175,6 +241,8 @@ def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_cre
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES * TO user;")
test_cursor = connect(username="user", password="test").cursor()
if switch:
switch_db(test_cursor)
result = execute_and_fetch_all(test_cursor, set_edge_property_query)
assert result[0][0] == 1

View File

@ -6,8 +6,10 @@ read_query_modules_cluster: &read_query_modules_cluster
setup_queries:
- "CREATE USER admin IDENTIFIED BY 'test';"
- "GRANT ALL PRIVILEGES TO admin"
- "GRANT DATABASE * TO admin"
- "CREATE USER user IDENTIFIED BY 'test';"
- "GRANT ALL PRIVILEGES TO user"
- "GRANT DATABASE * TO user"
validation_queries: []
update_query_modules_cluster: &update_query_modules_cluster
@ -18,8 +20,10 @@ update_query_modules_cluster: &update_query_modules_cluster
setup_queries:
- "CREATE USER admin IDENTIFIED BY 'test';"
- "GRANT ALL PRIVILEGES TO admin"
- "GRANT DATABASE * TO admin"
- "CREATE USER user IDENTIFIED BY 'test';"
- "GRANT ALL PRIVILEGES TO user"
- "GRANT DATABASE * TO user"
validation_queries: []
show_privileges_cluster: &show_privileges_cluster
@ -67,8 +71,10 @@ read_permission_queries: &read_permission_queries
setup_queries:
- "CREATE USER admin IDENTIFIED BY 'test';"
- "GRANT ALL PRIVILEGES TO admin"
- "GRANT DATABASE * TO admin"
- "CREATE USER user IDENTIFIED BY 'test';"
- "GRANT ALL PRIVILEGES TO user"
- "GRANT DATABASE * TO user"
validation_queries: []
create_delete_query_modules_cluster: &create_delete_query_modules_cluster
@ -79,8 +85,10 @@ create_delete_query_modules_cluster: &create_delete_query_modules_cluster
setup_queries:
- "CREATE USER admin IDENTIFIED BY 'test';"
- "GRANT ALL PRIVILEGES TO admin;"
- "GRANT DATABASE * TO admin"
- "CREATE USER user IDENTIFIED BY 'test';"
- "GRANT ALL PRIVILEGES TO user;"
- "GRANT DATABASE * TO user"
validation_queries: []
update_permission_queries_cluster: &update_permission_queries_cluster
@ -91,8 +99,10 @@ update_permission_queries_cluster: &update_permission_queries_cluster
setup_queries:
- "CREATE USER admin IDENTIFIED BY 'test';"
- "GRANT ALL PRIVILEGES TO admin;"
- "GRANT DATABASE * TO admin"
- "CREATE USER user IDENTIFIED BY 'test'"
- "GRANT ALL PRIVILEGES TO user;"
- "GRANT DATABASE * TO user"
validation_queries: []
workloads:

View File

@ -1,4 +1,4 @@
# Copyright 2022 Memgraph Ltd.
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -9,16 +9,29 @@
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import typing
import mgclient
import sys
import typing
import mgclient
import pytest
from common import execute_and_fetch_all, has_n_result_row
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_return_argument(connection, function_type):
@pytest.fixture(scope="function")
def multi_db(request, connection):
cursor = connection.cursor()
if request.param:
execute_and_fetch_all(cursor, "CREATE DATABASE clean")
execute_and_fetch_all(cursor, "USE DATABASE clean")
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
pass
yield connection
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_return_argument(multi_db, function_type):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "CREATE (n:Label {id: 1});")
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 1)
result = execute_and_fetch_all(
@ -31,9 +44,10 @@ def test_return_argument(connection, function_type):
assert vertex.properties == {"id": 1}
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_return_optional_argument(connection, function_type):
cursor = connection.cursor()
def test_return_optional_argument(multi_db, function_type):
cursor = multi_db.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
result = execute_and_fetch_all(
cursor,
@ -44,9 +58,10 @@ def test_return_optional_argument(connection, function_type):
assert result == 42
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_return_optional_argument_no_arg(connection, function_type):
cursor = connection.cursor()
def test_return_optional_argument_no_arg(multi_db, function_type):
cursor = multi_db.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
result = execute_and_fetch_all(
cursor,
@ -57,9 +72,10 @@ def test_return_optional_argument_no_arg(connection, function_type):
assert result == 42
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_add_two_numbers(connection, function_type):
cursor = connection.cursor()
def test_add_two_numbers(multi_db, function_type):
cursor = multi_db.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
result = execute_and_fetch_all(
cursor,
@ -70,9 +86,10 @@ def test_add_two_numbers(connection, function_type):
assert result_sum == 6
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_return_null(connection, function_type):
cursor = connection.cursor()
def test_return_null(multi_db, function_type):
cursor = multi_db.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
result = execute_and_fetch_all(
cursor,
@ -82,9 +99,10 @@ def test_return_null(connection, function_type):
assert result_null is None
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_too_many_arguments(connection, function_type):
cursor = connection.cursor()
def test_too_many_arguments(multi_db, function_type):
cursor = multi_db.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
# Should raise too many arguments
with pytest.raises(mgclient.DatabaseError):
@ -94,9 +112,10 @@ def test_too_many_arguments(connection, function_type):
)
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_try_to_write(connection, function_type):
cursor = connection.cursor()
def test_try_to_write(multi_db, function_type):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "CREATE (n:Label {id: 1});")
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 1)
# Should raise non mutable
@ -106,9 +125,11 @@ def test_try_to_write(connection, function_type):
f"MATCH (n) RETURN {function_type}_write.try_to_write(n, 'property', 1);",
)
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_case_sensitivity(connection, function_type):
cursor = connection.cursor()
def test_case_sensitivity(multi_db, function_type):
cursor = multi_db.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
# Should raise function does not exist
with pytest.raises(mgclient.DatabaseError):

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -17,6 +17,7 @@
DEFINE_uint64(bolt_port, 7687, "Bolt port");
DEFINE_uint64(timeout, 120, "Timeout seconds");
DEFINE_bool(multi_db, false, "Run test in multi db environment");
int main(int argc, char **argv) {
google::SetUsageMessage("Memgraph E2E Memory Control");
@ -34,6 +35,15 @@ int main(int argc, char **argv) {
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
if (FLAGS_multi_db) {
client->Execute("CREATE DATABASE clean;");
client->DiscardAll();
client->Execute("USE DATABASE clean;");
client->DiscardAll();
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
}
const auto *create_query = "UNWIND range(1, 50) as u CREATE (n {string: \"Some longer string\"}) RETURN n;";
memgraph::utils::Timer timer;

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -17,6 +17,7 @@
DEFINE_uint64(bolt_port, 7687, "Bolt port");
DEFINE_uint64(timeout, 120, "Timeout seconds");
DEFINE_bool(multi_db, false, "Run test in multi db environment");
int main(int argc, char **argv) {
google::SetUsageMessage("Memgraph E2E Memory Limit For Global Allocators");
@ -31,6 +32,15 @@ int main(int argc, char **argv) {
LOG_FATAL("Failed to connect!");
}
if (FLAGS_multi_db) {
client->Execute("CREATE DATABASE clean;");
client->DiscardAll();
client->Execute("USE DATABASE clean;");
client->DiscardAll();
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
}
bool result = client->Execute("CALL libglobal_memory_limit.procedure() YIELD *");
MG_ASSERT(result == false);
return 0;

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -18,6 +18,7 @@
DEFINE_uint64(bolt_port, 7687, "Bolt port");
DEFINE_uint64(timeout, 120, "Timeout seconds");
DEFINE_bool(multi_db, false, "Run test in multi db environment");
int main(int argc, char **argv) {
google::SetUsageMessage("Memgraph E2E Memory Limit For Global Allocators");
@ -31,6 +32,16 @@ int main(int argc, char **argv) {
if (!client) {
LOG_FATAL("Failed to connect!");
}
if (FLAGS_multi_db) {
client->Execute("CREATE DATABASE clean;");
client->DiscardAll();
client->Execute("USE DATABASE clean;");
client->DiscardAll();
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
}
MG_ASSERT(client->Execute("CALL libglobal_memory_limit_proc.error() YIELD *"));
MG_ASSERT(std::invoke([&] {
try {

View File

@ -21,14 +21,31 @@ workloads:
args: ["--bolt-port", *bolt_port, "--timeout", "180"]
<<: *template_cluster
- name: "Memory control multi database"
binary: "tests/e2e/memory/memgraph__e2e__memory__control"
args: ["--bolt-port", *bolt_port, "--timeout", "180", "--multi-db", "true"]
<<: *template_cluster
- name: "Memory limit for modules upon loading"
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc"
args: ["--bolt-port", *bolt_port, "--timeout", "180"]
proc: "tests/e2e/memory/procedures/"
<<: *template_cluster
- name: "Memory limit for modules upon loading multi database"
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc"
args: ["--bolt-port", *bolt_port, "--timeout", "180", "--multi-db", "true"]
proc: "tests/e2e/memory/procedures/"
<<: *template_cluster
- name: "Memory limit for modules inside a procedure"
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc_proc"
args: ["--bolt-port", *bolt_port, "--timeout", "180"]
proc: "tests/e2e/memory/procedures/"
<<: *template_cluster
- name: "Memory limit for modules inside a procedure multi database"
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc_proc"
args: ["--bolt-port", *bolt_port, "--timeout", "180", "--multi-db", "true"]
proc: "tests/e2e/memory/procedures/"
<<: *template_cluster

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -21,6 +21,7 @@
DEFINE_uint64(bolt_port, 7687, "Bolt port");
DEFINE_uint64(timeout, 120, "Timeout seconds");
DEFINE_bool(multi_db, false, "Run test in multi db environment");
namespace {
auto GetClient() {
@ -181,6 +182,15 @@ int main(int argc, char **argv) {
mg::Client::Init();
auto client = GetClient();
if (FLAGS_multi_db) {
client->Execute("CREATE DATABASE clean;");
client->DiscardAll();
client->Execute("USE DATABASE clean;");
client->DiscardAll();
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
}
AssertQueryFails<mg::ClientException>(client, CreateModuleFileQuery("some.cpp", "some content"),
"mg.create_module_file: The specified file isn't in the supported format.");

View File

@ -12,3 +12,8 @@ workloads:
binary: "tests/e2e/module_file_manager/memgraph__e2e__module_file_manager"
args: ["--bolt-port", *bolt_port]
<<: *template_cluster
- name: "Module File Manager multi database"
binary: "tests/e2e/module_file_manager/memgraph__e2e__module_file_manager"
args: ["--bolt-port", *bolt_port, "--multi-db", "true"]
<<: *template_cluster

View File

@ -14,6 +14,19 @@ import typing
import mgclient
def switch_db(cursor):
execute_and_fetch_all(cursor, "USE DATABASE clean;")
def create_multi_db(cursor):
execute_and_fetch_all(cursor, "USE DATABASE memgraph;")
try:
execute_and_fetch_all(cursor, "DROP DATABASE clean;")
except:
pass
execute_and_fetch_all(cursor, "CREATE DATABASE clean;")
def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]:
cursor.execute(query, params)
return cursor.fetchall()

View File

@ -1,4 +1,4 @@
# Copyright 2022 Memgraph Ltd.
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -14,7 +14,7 @@ import os # To be removed
import sys
import pytest
from common import connect, execute_and_fetch_all
from common import connect, create_multi_db, execute_and_fetch_all, switch_db
COMMON_PATH_PREFIX_TEST1 = "procedures/mage/test_module"
COMMON_PATH_PREFIX_TEST2 = "procedures/new_test_module_utils"
@ -76,9 +76,13 @@ def postprocess_functions(path1: str, path2: str):
)
def test_mg_load_reload_submodule_root_utils():
@pytest.mark.parametrize("switch", [False, True])
def test_mg_load_reload_submodule_root_utils(switch):
"""Tests whether mg.load reloads content of some submodule code."""
cursor = connect().cursor()
if switch:
create_multi_db(cursor)
switch_db(cursor)
# First do a simple experiment
test_module_res = execute_and_fetch_all(cursor, "CALL new_test_module.test(10, 2) YIELD * RETURN *;")
try:
@ -101,9 +105,13 @@ def test_mg_load_reload_submodule_root_utils():
execute_and_fetch_all(cursor, "CALL mg.load('new_test_module');")
def test_mg_load_all_reload_submodule_root_utils():
@pytest.mark.parametrize("switch", [False, True])
def test_mg_load_all_reload_submodule_root_utils(switch):
"""Tests whether mg.load_all reloads content of some submodule code"""
cursor = connect().cursor()
if switch:
create_multi_db(cursor)
switch_db(cursor)
# First do a simple experiment
test_module_res = execute_and_fetch_all(cursor, "CALL new_test_module.test(10, 2) YIELD * RETURN *;")
try:
@ -126,9 +134,13 @@ def test_mg_load_all_reload_submodule_root_utils():
execute_and_fetch_all(cursor, "CALL mg.load_all();")
def test_mg_load_reload_submodule():
@pytest.mark.parametrize("switch", [False, True])
def test_mg_load_reload_submodule(switch):
"""Tests whether mg.load reloads content of some submodule code."""
cursor = connect().cursor()
if switch:
create_multi_db(cursor)
switch_db(cursor)
# First do a simple experiment
test_module_res = execute_and_fetch_all(cursor, "CALL test_module.test(10, 2) YIELD * RETURN *;")
try:
@ -151,9 +163,13 @@ def test_mg_load_reload_submodule():
execute_and_fetch_all(cursor, "CALL mg.load('test_module');")
def test_mg_load_all_reload_submodule():
@pytest.mark.parametrize("switch", [False, True])
def test_mg_load_all_reload_submodule(switch):
"""Tests whether mg.load_all reloads content of some submodule code"""
cursor = connect().cursor()
if switch:
create_multi_db(cursor)
switch_db(cursor)
# First do a simple experiment
test_module_res = execute_and_fetch_all(cursor, "CALL test_module.test(10, 2) YIELD * RETURN *;")
try:

View File

@ -12,7 +12,6 @@
import typing
import mgclient
import pytest
def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]:

View File

@ -12,7 +12,6 @@
import multiprocessing
import sys
import threading
import time
from typing import List
@ -56,12 +55,28 @@ def test_self_transaction():
assert len(results) == 1
def test_multitenant_transactions():
"""Tests that show transactions work on another database"""
test_cursor = connect().cursor()
execute_and_fetch_all(test_cursor, "CREATE DATABASE testing")
tx_connection = connect()
tx_cursor = tx_connection.cursor()
tx_process = multiprocessing.Process(
target=process_function, args=(tx_cursor, ["USE DATABASE testing", "MATCH (n) RETURN n"])
)
tx_process.start()
time.sleep(0.5)
show_transactions_test(test_cursor, 1)
# TODO Add SHOW TRANSACTIONS ON * that should return all transactions
def test_admin_has_one_transaction():
"""Creates admin and tests that he sees only one transaction."""
# a_cursor is used for creating admin user, simulates main thread
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
admin_cursor = connect(username="admin", password="").cursor()
process = multiprocessing.Process(target=show_transactions_test, args=(admin_cursor, 1))
process.start()
@ -74,6 +89,7 @@ def test_user_can_see_its_transaction():
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT ALL PRIVILEGES TO admin")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user")
user_cursor = connect(username="user", password="").cursor()
@ -89,6 +105,7 @@ def test_explicit_transaction_output():
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
admin_connection = connect(username="admin", password="")
admin_cursor = admin_connection.cursor()
# Admin starts running explicit transaction
@ -114,8 +131,10 @@ def test_superadmin_cannot_see_admin_can_see_admin():
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin1")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin1")
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin2")
# Admin starts running infinite query
admin_connection_1 = connect(username="admin1", password="")
admin_cursor_1 = admin_connection_1.cursor()
@ -153,6 +172,7 @@ def test_admin_sees_superadmin():
superadmin_cursor = superadmin_connection.cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
# Admin starts running infinite query
process = multiprocessing.Process(
target=process_function, args=(superadmin_cursor, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"])
@ -183,6 +203,7 @@ def test_admin_can_see_user_transaction():
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
# Admin starts running infinite query
admin_connection = connect(username="admin", password="")
@ -220,8 +241,10 @@ def test_user_cannot_see_admin_transaction():
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin1")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin1")
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin2")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
admin_connection_1 = connect(username="admin1", password="")
admin_cursor_1 = admin_connection_1.cursor()
@ -282,6 +305,7 @@ def test_admin_killing_multiple_non_existing_transactions():
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
# Connect with admin
admin_cursor = connect(username="admin", password="").cursor()
transactions_id = ["'1'", "'2'", "'3'"]
@ -298,6 +322,7 @@ def test_user_killing_some_transactions():
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT ALL PRIVILEGES TO admin")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user1")
execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user1")

View File

@ -24,11 +24,14 @@ def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}
def connect(**kwargs) -> mgclient.Connection:
connection = mgclient.connect(host="localhost", port=7687, **kwargs)
connection.autocommit = True
execute_and_fetch_all(connection.cursor(), "USE DATABASE memgraph")
try:
execute_and_fetch_all(connection.cursor(), "DROP DATABASE clean")
except:
pass
execute_and_fetch_all(connection.cursor(), "MATCH (n) DETACH DELETE n")
triggers_list = execute_and_fetch_all(connection.cursor(), "SHOW TRIGGERS;")
for trigger in triggers_list:
execute_and_fetch_all(connection.cursor(), f"DROP TRIGGER {trigger[0]}")
execute_and_fetch_all(connection.cursor(), "MATCH (n) DETACH DELETE n")
yield connection
for trigger in triggers_list:
execute_and_fetch_all(connection.cursor(), f"DROP TRIGGER {trigger[0]}")
execute_and_fetch_all(connection.cursor(), "MATCH (n) DETACH DELETE n")

View File

@ -1,4 +1,4 @@
# Copyright 2022 Memgraph Ltd.
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -16,13 +16,25 @@ import pytest
from common import connect, execute_and_fetch_all
@pytest.fixture(scope="function")
def multi_db(request, connect):
cursor = connect.cursor()
if request.param:
execute_and_fetch_all(cursor, "CREATE DATABASE clean")
execute_and_fetch_all(cursor, "USE DATABASE clean")
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
pass
yield connect
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
@pytest.mark.parametrize("ba_commit", ["BEFORE COMMIT", "AFTER COMMIT"])
def test_create_on_create(ba_commit, connect):
def test_create_on_create(ba_commit, multi_db):
"""
Args:
ba_commit (str): BEFORE OR AFTER commit
"""
cursor = connect.cursor()
cursor = multi_db.cursor()
QUERY_TRIGGER_CREATE = f"""
CREATE TRIGGER CreateTriggerEdgesCount
ON --> CREATE
@ -30,6 +42,7 @@ def test_create_on_create(ba_commit, connect):
EXECUTE
CREATE (n:CreatedEdge {{count: size(createdEdges)}})
"""
execute_and_fetch_all(cursor, QUERY_TRIGGER_CREATE)
execute_and_fetch_all(cursor, "CREATE (n:Node {id: 1})")
execute_and_fetch_all(cursor, "CREATE (n:Node {id: 2})")
@ -50,14 +63,22 @@ def test_create_on_create(ba_commit, connect):
# execute_and_fetch_all(cursor, "DROP TRIGGER CreateTriggerEdgesCount")
# execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
# check that there is no cross contamination between databases
nodes = execute_and_fetch_all(cursor, "SHOW DATABASES")
if len(nodes) == 2: # multi db mode
execute_and_fetch_all(cursor, "USE DATABASE memgraph")
created_edges = execute_and_fetch_all(cursor, "MATCH (n:CreatedEdge) RETURN n")
assert len(created_edges) == 0
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
@pytest.mark.parametrize("ba_commit", ["AFTER COMMIT", "BEFORE COMMIT"])
def test_create_on_delete(ba_commit, connect):
def test_create_on_delete(ba_commit, multi_db):
"""
Args:
ba_commit (str): BEFORE OR AFTER commit
"""
cursor = connect.cursor()
cursor = multi_db.cursor()
QUERY_TRIGGER_CREATE = f"""
CREATE TRIGGER DeleteTriggerEdgesCount
ON --> DELETE
@ -102,7 +123,15 @@ def test_create_on_delete(ba_commit, connect):
# execute_and_fetch_all(cursor, "DROP TRIGGER DeleteTriggerEdgesCount")
# execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")``
# check that there is no cross contamination between databases
nodes = execute_and_fetch_all(cursor, "SHOW DATABASES")
if len(nodes) == 2: # multi db mode
execute_and_fetch_all(cursor, "USE DATABASE memgraph")
created_edges = execute_and_fetch_all(cursor, "MATCH (n:CreatedEdge) RETURN n")
assert len(created_edges) == 0
# @pytest.mark.parametrize("multi_db", [False, True], indirect=True)
@pytest.mark.parametrize("ba_commit", ["BEFORE COMMIT", "AFTER COMMIT"])
def test_create_on_delete_explicit_transaction(ba_commit):
"""

View File

@ -1,4 +1,4 @@
# Copyright 2022 Memgraph Ltd.
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -9,13 +9,25 @@
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import typing
import mgclient
import sys
import typing
import mgclient
import pytest
from common import execute_and_fetch_all, has_n_result_row
@pytest.fixture(scope="function")
def multi_db(request, connection):
cursor = connection.cursor()
if request.param:
execute_and_fetch_all(cursor, "CREATE DATABASE clean")
execute_and_fetch_all(cursor, "USE DATABASE clean")
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
pass
yield connection
def create_subgraph(cursor):
execute_and_fetch_all(cursor, "CREATE (n:Person {id: 1});")
execute_and_fetch_all(cursor, "CREATE (n:Person {id: 2});")
@ -41,8 +53,9 @@ def create_smaller_subgraph(cursor):
execute_and_fetch_all(cursor, "MATCH (p:Person {id: 2}) MATCH (t:Team {id:6}) CREATE (p)-[:SUPPORTS]->(t);")
def test_is_callable(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_is_callable(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor)
@ -59,8 +72,9 @@ def test_is_callable(connection):
)
def test_incorrect_graph_argument_placement(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_incorrect_graph_argument_placement(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor)
@ -79,8 +93,9 @@ def test_incorrect_graph_argument_placement(connection):
)
def test_get_vertices(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_get_vertices(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor)
@ -97,8 +112,9 @@ def test_get_vertices(connection):
)
def test_get_out_edges(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_get_out_edges(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor)
@ -115,8 +131,9 @@ def test_get_out_edges(connection):
)
def test_get_in_edges(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_get_in_edges(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor)
@ -133,8 +150,9 @@ def test_get_in_edges(connection):
)
def test_get_2_hop_edges(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_get_2_hop_edges(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor)
@ -150,8 +168,9 @@ def test_get_2_hop_edges(connection):
)
def test_get_out_edges_vertex_id(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_get_out_edges_vertex_id(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor=cursor)
@ -168,8 +187,9 @@ def test_get_out_edges_vertex_id(connection):
)
def test_subgraph_get_path_vertices(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_subgraph_get_path_vertices(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor)
@ -185,8 +205,9 @@ def test_subgraph_get_path_vertices(connection):
)
def test_subgraph_get_path_edges(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_subgraph_get_path_edges(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor)
@ -202,8 +223,9 @@ def test_subgraph_get_path_edges(connection):
)
def test_subgraph_get_path_vertices_in_subgraph(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_subgraph_get_path_vertices_in_subgraph(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor)
assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6)
@ -218,8 +240,9 @@ def test_subgraph_get_path_vertices_in_subgraph(connection):
)
def test_subgraph_insert_vertex_get_vertices(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_subgraph_insert_vertex_get_vertices(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor)
assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6)
@ -234,8 +257,9 @@ def test_subgraph_insert_vertex_get_vertices(connection):
)
def test_subgraph_insert_edge_get_vertex_out_edges(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_subgraph_insert_edge_get_vertex_out_edges(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor)
assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6)
@ -250,8 +274,9 @@ def test_subgraph_insert_edge_get_vertex_out_edges(connection):
)
def test_subgraph_create_edge_both_vertices_not_in_projected_graph_error(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_subgraph_create_edge_both_vertices_not_in_projected_graph_error(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor)
assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6)
@ -267,8 +292,9 @@ def test_subgraph_create_edge_both_vertices_not_in_projected_graph_error(connect
)
def test_subgraph_remove_edge_get_vertex_out_edges(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_subgraph_remove_edge_get_vertex_out_edges(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor)
assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6)
@ -283,8 +309,9 @@ def test_subgraph_remove_edge_get_vertex_out_edges(connection):
)
def test_subgraph_remove_edge_not_in_subgraph_error(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_subgraph_remove_edge_not_in_subgraph_error(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_subgraph(cursor)
assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6)
@ -299,8 +326,9 @@ def test_subgraph_remove_edge_not_in_subgraph_error(connection):
)
def test_subgraph_remove_vertex_and_out_edges_get_vertices(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_subgraph_remove_vertex_and_out_edges_get_vertices(multi_db):
cursor = multi_db.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
create_smaller_subgraph(cursor)
assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 4)

View File

@ -1,4 +1,4 @@
# Copyright 2021 Memgraph Ltd.
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -9,17 +9,30 @@
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import typing
import mgclient
import sys
import typing
import mgclient
import pytest
from common import execute_and_fetch_all, has_one_result_row, has_n_result_row
from common import execute_and_fetch_all, has_n_result_row, has_one_result_row
def test_is_write(connection):
@pytest.fixture(scope="function")
def multi_db(request, connection):
cursor = connection.cursor()
if request.param:
execute_and_fetch_all(cursor, "CREATE DATABASE clean")
execute_and_fetch_all(cursor, "USE DATABASE clean")
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
pass
yield connection
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_is_write(multi_db):
is_write = 2
result_order = "name, signature, is_write"
cursor = connection.cursor()
cursor = multi_db.cursor()
for proc in execute_and_fetch_all(
cursor,
"CALL mg.procedures() YIELD * WITH name, signature, "
@ -41,8 +54,9 @@ def test_is_write(connection):
assert cursor.description[2].name == "is_write"
def test_single_vertex(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_single_vertex(multi_db):
cursor = multi_db.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
result = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")
vertex = result[0][0]
@ -93,8 +107,9 @@ def test_single_vertex(connection):
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
def test_single_edge(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_single_edge(multi_db):
cursor = multi_db.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
v1_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id
v2_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id
@ -134,8 +149,9 @@ def test_single_edge(connection):
assert has_n_result_row(cursor, "MATCH ()-[e]->() RETURN e", 0)
def test_detach_delete_vertex(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_detach_delete_vertex(multi_db):
cursor = multi_db.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
v1_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id
v2_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id
@ -156,8 +172,9 @@ def test_detach_delete_vertex(connection):
assert has_one_result_row(cursor, f"MATCH (n) WHERE id(n) = {v2_id} RETURN n")
def test_graph_mutability(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_graph_mutability(multi_db):
cursor = multi_db.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
v1_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id
v2_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id
@ -193,8 +210,9 @@ def test_graph_mutability(connection):
test_mutability(False)
def test_log_message(connection):
cursor = connection.cursor()
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
def test_log_message(multi_db):
cursor = multi_db.cursor()
success = execute_and_fetch_all(cursor, f"CALL read.log_message('message') YIELD success RETURN success")[0][0]
assert (success) is True

View File

@ -21,6 +21,7 @@ import sys
import tempfile
import time
DEFAULT_DB = "memgraph"
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
@ -37,26 +38,24 @@ QUERIES = [
("CREATE (n {name: $name})", {"name": 5, "leftover": 42}),
("MATCH (n), (m) CREATE (n)-[:e {when: $when}]->(m)", {"when": 42}),
("MATCH (n) RETURN n", {}),
(
"MATCH (n), (m {type: $type}) RETURN count(n), count(m)",
{"type": "dadada"}
),
("MATCH (n), (m {type: $type}) RETURN count(n), count(m)", {"type": "dadada"}),
(
"MERGE (n) ON CREATE SET n.created = timestamp() "
"ON MATCH SET n.lastSeen = timestamp() "
"RETURN n.name, n.created, n.lastSeen",
{}
),
(
"MATCH (n {value: $value}) SET n.value = 0 RETURN n",
{"value": "nandare!"}
{},
),
("MATCH (n {value: $value}) SET n.value = 0 RETURN n", {"value": "nandare!"}),
("MATCH (n), (m) SET n.value = m.value", {}),
("MATCH (n {test: $test}) REMOVE n.value", {"test": 48}),
("MATCH (n), (m) REMOVE n.value, m.value", {}),
("CREATE INDEX ON :User (id)", {}),
]
CREATE_DB_QUERIES = [
("CREATE DATABASE clean", {}),
]
def wait_for_server(port, delay=0.1):
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
@ -65,6 +64,13 @@ def wait_for_server(port, delay=0.1):
time.sleep(delay)
def gen_mt_queries(queries, db):
out = []
for query, params in queries:
out.append((db, query, params))
return out
def execute_test(memgraph_binary, tester_binary):
storage_directory = tempfile.TemporaryDirectory()
memgraph_args = [
@ -74,7 +80,8 @@ def execute_test(memgraph_binary, tester_binary):
storage_directory.name,
"--audit-enabled",
"--log-file=memgraph.log",
"--log-level=TRACE"]
"--log-level=TRACE",
]
# Start the memgraph binary
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
@ -90,17 +97,31 @@ def execute_test(memgraph_binary, tester_binary):
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
def execute_queries(queries):
for query, params in queries:
for db, query, params in queries:
print(query, params)
args = [tester_binary, "--query", query,
"--params-json", json.dumps(params)]
args = [tester_binary, "--query", query, "--use-db", db, "--params-json", json.dumps(params)]
subprocess.run(args).check_returncode()
# Test default db
mt_queries = gen_mt_queries(QUERIES, DEFAULT_DB)
# Execute all queries
print("\033[1;36m~~ Starting query execution ~~\033[0m")
execute_queries(QUERIES)
execute_queries(mt_queries)
print("\033[1;36m~~ Finished query execution ~~\033[0m\n")
# Test new db
print("\033[1;36m~~ Creating clean database ~~\033[0m")
mt_queries2 = gen_mt_queries(CREATE_DB_QUERIES, DEFAULT_DB)
execute_queries(mt_queries2)
print("\033[1;36m~~ Finished creating clean database ~~\033[0m\n")
# Execute all queries on clean database
mt_queries3 = gen_mt_queries(QUERIES, "clean")
print("\033[1;36m~~ Starting query execution on clean database ~~\033[0m")
execute_queries(mt_queries3)
print("\033[1;36m~~ Finished query execution on clean database ~~\033[0m\n")
# Shutdown the memgraph binary
memgraph.terminate()
@ -109,26 +130,37 @@ def execute_test(memgraph_binary, tester_binary):
# Verify the written log
print("\033[1;36m~~ Starting log verification ~~\033[0m")
with open(os.path.join(storage_directory.name, "audit", "audit.log")) as f:
reader = csv.reader(f, delimiter=',', doublequote=False,
escapechar='\\', lineterminator='\n',
quotechar='"', quoting=csv.QUOTE_MINIMAL,
skipinitialspace=False, strict=True)
reader = csv.reader(
f,
delimiter=",",
doublequote=False,
escapechar="\\",
lineterminator="\n",
quotechar='"',
quoting=csv.QUOTE_MINIMAL,
skipinitialspace=False,
strict=True,
)
queries = []
for line in reader:
timestamp, address, username, query, params = line
timestamp, address, username, database, query, params = line
params = json.loads(params)
queries.append((query, params))
print(query, params)
if query.startswith("USE DATABASE"):
continue # Skip all databases switching queries
queries.append((database, query, params))
print(database, query, params)
assert queries == QUERIES, "Logged queries don't match " \
"executed queries!"
# Combine all queries executed
all_queries = mt_queries
all_queries += mt_queries2
all_queries += mt_queries3
assert queries == all_queries, "Logged queries don't match " "executed queries!"
print("\033[1;36m~~ Finished log verification ~~\033[0m\n")
if __name__ == "__main__":
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
tester_binary = os.path.join(PROJECT_DIR, "build", "tests",
"integration", "audit", "tester")
tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "audit", "tester")
parser = argparse.ArgumentParser()
parser.add_argument("--memgraph", default=memgraph_binary)

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -9,6 +9,7 @@
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <fmt/core.h>
#include <gflags/gflags.h>
#include <json/json.hpp>
@ -25,6 +26,7 @@ DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
DEFINE_string(query, "", "Query to execute");
DEFINE_string(params_json, "{}", "Params for the query");
DEFINE_string(use_db, "memgraph", "Database to run the query against");
memgraph::communication::bolt::Value JsonToValue(const nlohmann::json &jv) {
memgraph::communication::bolt::Value ret;
@ -89,6 +91,7 @@ int main(int argc, char **argv) {
memgraph::communication::bolt::Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);
client.Execute(fmt::format("USE DATABASE {}", FLAGS_use_db), {});
client.Execute(FLAGS_query, JsonToValue(nlohmann::json::parse(FLAGS_params_json)).ValueMap());
return 0;

View File

@ -29,15 +29,8 @@ PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
QUERIES = [
# CREATE
(
"CREATE (n)",
("CREATE",)
),
(
"MATCH (n), (m) CREATE (n)-[:e]->(m)",
("CREATE", "MATCH")
),
("CREATE (n)", ("CREATE",)),
("MATCH (n), (m) CREATE (n)-[:e]->(m)", ("CREATE", "MATCH")),
# DELETE
(
"MATCH (n) DELETE n",
@ -47,116 +40,43 @@ QUERIES = [
"MATCH (n) DETACH DELETE n",
("DELETE", "MATCH"),
),
# MATCH
(
"MATCH (n) RETURN n",
("MATCH",)
),
(
"MATCH (n), (m) RETURN count(n), count(m)",
("MATCH",)
),
("MATCH (n) RETURN n", ("MATCH",)),
("MATCH (n), (m) RETURN count(n), count(m)", ("MATCH",)),
# MERGE
(
"MERGE (n) ON CREATE SET n.created = timestamp() "
"ON MATCH SET n.lastSeen = timestamp() "
"RETURN n.name, n.created, n.lastSeen",
("MERGE",)
("MERGE",),
),
# SET
(
"MATCH (n) SET n.value = 0 RETURN n",
("SET", "MATCH")
),
(
"MATCH (n), (m) SET n.value = m.value",
("SET", "MATCH")
),
("MATCH (n) SET n.value = 0 RETURN n", ("SET", "MATCH")),
("MATCH (n), (m) SET n.value = m.value", ("SET", "MATCH")),
# REMOVE
(
"MATCH (n) REMOVE n.value",
("REMOVE", "MATCH")
),
(
"MATCH (n), (m) REMOVE n.value, m.value",
("REMOVE", "MATCH")
),
("MATCH (n) REMOVE n.value", ("REMOVE", "MATCH")),
("MATCH (n), (m) REMOVE n.value, m.value", ("REMOVE", "MATCH")),
# INDEX
(
"CREATE INDEX ON :User (id)",
("INDEX",)
),
("CREATE INDEX ON :User (id)", ("INDEX",)),
# AUTH
(
"CREATE ROLE test_role",
("AUTH",)
),
(
"DROP ROLE test_role",
("AUTH",)
),
(
"SHOW ROLES",
("AUTH",)
),
(
"CREATE USER test_user",
("AUTH",)
),
(
"SET PASSWORD FOR test_user TO '1234'",
("AUTH",)
),
(
"DROP USER test_user",
("AUTH",)
),
(
"SHOW USERS",
("AUTH",)
),
(
"SET ROLE FOR test_user TO test_role",
("AUTH",)
),
(
"CLEAR ROLE FOR test_user",
("AUTH",)
),
(
"GRANT ALL PRIVILEGES TO test_user",
("AUTH",)
),
(
"DENY ALL PRIVILEGES TO test_user",
("AUTH",)
),
(
"REVOKE ALL PRIVILEGES FROM test_user",
("AUTH",)
),
(
"SHOW PRIVILEGES FOR test_user",
("AUTH",)
),
(
"SHOW ROLE FOR test_user",
("AUTH",)
),
(
"SHOW USERS FOR test_role",
("AUTH",)
),
("CREATE ROLE test_role", ("AUTH",)),
("DROP ROLE test_role", ("AUTH",)),
("SHOW ROLES", ("AUTH",)),
("CREATE USER test_user", ("AUTH",)),
("SET PASSWORD FOR test_user TO '1234'", ("AUTH",)),
("DROP USER test_user", ("AUTH",)),
("SHOW USERS", ("AUTH",)),
("SET ROLE FOR test_user TO test_role", ("AUTH",)),
("CLEAR ROLE FOR test_user", ("AUTH",)),
("GRANT ALL PRIVILEGES TO test_user", ("AUTH",)),
("DENY ALL PRIVILEGES TO test_user", ("AUTH",)),
("REVOKE ALL PRIVILEGES FROM test_user", ("AUTH",)),
("SHOW PRIVILEGES FOR test_user", ("AUTH",)),
("SHOW ROLE FOR test_user", ("AUTH",)),
("SHOW USERS FOR test_role", ("AUTH",)),
]
UNAUTHORIZED_ERROR = "You are not authorized to execute this query! Please " \
"contact your database administrator."
UNAUTHORIZED_ERROR = r"^You are not authorized to execute this query.*?Please contact your database administrator\."
def wait_for_server(port, delay=0.1):
@ -166,8 +86,16 @@ def wait_for_server(port, delay=0.1):
time.sleep(delay)
def execute_tester(binary, queries, should_fail=False, failure_message="",
username="", password="", check_failure=True):
def execute_tester(
binary,
queries,
should_fail=False,
failure_message="",
username="",
password="",
check_failure=True,
connection_should_fail=False,
):
args = [binary, "--username", username, "--password", password]
if should_fail:
args.append("--should-fail")
@ -175,6 +103,8 @@ def execute_tester(binary, queries, should_fail=False, failure_message="",
args.extend(["--failure-message", failure_message])
if check_failure:
args.append("--check-failure")
if connection_should_fail:
args.append("--connection-should-fail")
args.extend(queries)
subprocess.run(args).check_returncode()
@ -200,18 +130,31 @@ def check_permissions(query_perms, user_perms):
def execute_test(memgraph_binary, tester_binary, checker_binary):
storage_directory = tempfile.TemporaryDirectory()
memgraph_args = [memgraph_binary,
"--data-directory", storage_directory.name]
memgraph_args = [memgraph_binary, "--data-directory", storage_directory.name]
def execute_admin_queries(queries):
return execute_tester(tester_binary, queries, should_fail=False,
check_failure=True, username="admin",
password="admin")
return execute_tester(
tester_binary, queries, should_fail=False, check_failure=True, username="admin", password="admin"
)
def execute_user_queries(queries, should_fail=False, failure_message="",
check_failure=True):
return execute_tester(tester_binary, queries, should_fail,
failure_message, "user", "user", check_failure)
def execute_user_queries(
queries,
should_fail=False,
failure_message="",
check_failure=True,
username="user",
connection_should_fail=False,
):
return execute_tester(
tester_binary,
queries,
should_fail,
failure_message,
username,
"user",
check_failure,
connection_should_fail,
)
# Start the memgraph binary
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
@ -226,12 +169,33 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
# Prepare the multi database environment
execute_admin_queries(
[
"CREATE DATABASE db1",
"CREATE DATABASE db2",
]
)
# Prepare all users
execute_admin_queries([
"CREATE USER ADmin IDENTIFIED BY 'admin'",
"GRANT ALL PRIVILEGES TO admIN",
"CREATE USER usEr IDENTIFIED BY 'user'",
])
execute_admin_queries(
[
"CREATE USER ADmin IDENTIFIED BY 'admin'",
"GRANT ALL PRIVILEGES TO admIN",
"GRANT DATABASE * TO admin",
"CREATE USER usEr IDENTIFIED BY 'user'",
"GRANT DATABASE db1 TO user",
"GRANT DATABASE db2 TO user",
"CREATE USER useR2 IDENTIFIED BY 'user'",
"GRANT DATABASE db2 TO user2",
"REVOKE DATABASE memgraph FROM user2",
"SET MAIN DATABASE db2 FOR user2",
"CREATE USER user3 IDENTIFIED BY 'user'",
"GRANT ALL PRIVILEGES TO user3",
"GRANT DATABASE * TO user3",
"REVOKE DATABASE memgraph FROM user3",
]
)
# Find all existing permissions
permissions = set()
@ -241,14 +205,99 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
# Run the test with all combinations of permissions
print("\033[1;36m~~ Starting query test ~~\033[0m")
for db in ["memgraph", "db1"]:
print("\033[1;36m~~ Running against db {} ~~\033[0m".format(db))
execute_user_queries(["USE DATABASE {}".format(db)], should_fail=True, failure_message=UNAUTHORIZED_ERROR)
execute_admin_queries(["GRANT MULTI_DATABASE_USE TO User"])
execute_user_queries(["USE DATABASE {}".format(db)], check_failure=False, failure_message=UNAUTHORIZED_ERROR)
for mask in range(0, 2 ** len(permissions)):
user_perms = get_permissions(permissions, mask)
print("\033[1;34m~~ Checking queries with privileges: ", ", ".join(user_perms), " ~~\033[0m")
admin_queries = ["REVOKE ALL PRIVILEGES FROM uSer"]
if len(user_perms) > 0:
admin_queries.append("GRANT {} TO User".format(", ".join(user_perms)))
execute_admin_queries(admin_queries)
authorized, unauthorized = [], []
for query, query_perms in QUERIES:
if check_permissions(query_perms, user_perms):
authorized.append(query)
else:
unauthorized.append(query)
execute_user_queries(authorized, check_failure=False, failure_message=UNAUTHORIZED_ERROR)
execute_user_queries(unauthorized, should_fail=True, failure_message=UNAUTHORIZED_ERROR)
print("\033[1;36m~~ Finished query test ~~\033[0m\n")
# Run the user/role permissions test
print("\033[1;36m~~ Starting permissions test ~~\033[0m")
execute_admin_queries(
[
"CREATE ROLE roLe",
"REVOKE ALL PRIVILEGES FROM uSeR",
]
)
execute_checker(checker_binary, [])
for db in ["memgraph", "db1"]:
print("\033[1;36m~~ Running against db {} ~~\033[0m".format(db))
execute_user_queries(["USE DATABASE {}".format(db)], should_fail=True, failure_message=UNAUTHORIZED_ERROR)
execute_admin_queries(["GRANT MULTI_DATABASE_USE TO User"])
execute_user_queries(["USE DATABASE {}".format(db)], check_failure=False, failure_message=UNAUTHORIZED_ERROR)
execute_admin_queries(["REVOKE MULTI_DATABASE_USE FROM User"])
for user_perm in ["GRANT", "DENY", "REVOKE"]:
for role_perm in ["GRANT", "DENY", "REVOKE"]:
for mapped in [True, False]:
print(
"\033[1;34m~~ Checking permissions with user ",
user_perm,
", role ",
role_perm,
"user mapped to role:",
mapped,
" ~~\033[0m",
)
if mapped:
execute_admin_queries(["SET ROLE FOR USER TO roLE"])
else:
execute_admin_queries(["CLEAR ROLE FOR user"])
user_prep = "FROM" if user_perm == "REVOKE" else "TO"
role_prep = "FROM" if role_perm == "REVOKE" else "TO"
execute_admin_queries(
[
"{} MATCH {} user".format(user_perm, user_prep),
"{} MATCH {} rOLe".format(role_perm, role_prep),
]
)
expected = []
perms = [user_perm, role_perm] if mapped else [user_perm]
if "DENY" in perms:
expected = ["MATCH", "DENY"]
elif "GRANT" in perms:
expected = ["MATCH", "GRANT"]
if len(expected) > 0:
details = []
if user_perm == "GRANT":
details.append("GRANTED TO USER")
elif user_perm == "DENY":
details.append("DENIED TO USER")
if mapped:
if role_perm == "GRANT":
details.append("GRANTED TO ROLE")
elif role_perm == "DENY":
details.append("DENIED TO ROLE")
expected.append(", ".join(details))
execute_checker(checker_binary, expected)
print("\033[1;36m~~ Finished permissions test ~~\033[0m\n")
# Check database access
# user has access to every db (with global privileges) <- tested above
# user2 has access only to db2 (and it set to default)
# user3 has access only to db2, but the default db is set to default (shouldn't even connect)
print("\033[1;36m~~ Checking privileges with custom default db ~~\033[0m\n")
for mask in range(0, 2 ** len(permissions)):
user_perms = get_permissions(permissions, mask)
print("\033[1;34m~~ Checking queries with privileges: ",
", ".join(user_perms), " ~~\033[0m")
admin_queries = ["REVOKE ALL PRIVILEGES FROM uSer"]
print("\033[1;34m~~ Checking queries with privileges: ", ", ".join(user_perms), " ~~\033[0m")
admin_queries = ["REVOKE ALL PRIVILEGES FROM uSer2"]
if len(user_perms) > 0:
admin_queries.append(
"GRANT {} TO User".format(", ".join(user_perms)))
admin_queries.append("GRANT {} TO User2".format(", ".join(user_perms)))
execute_admin_queries(admin_queries)
authorized, unauthorized = [], []
for query, query_perms in QUERIES:
@ -256,55 +305,26 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
authorized.append(query)
else:
unauthorized.append(query)
execute_user_queries(authorized, check_failure=False,
failure_message=UNAUTHORIZED_ERROR)
execute_user_queries(unauthorized, should_fail=True,
failure_message=UNAUTHORIZED_ERROR)
print("\033[1;36m~~ Finished query test ~~\033[0m\n")
execute_user_queries(authorized, check_failure=False, failure_message=UNAUTHORIZED_ERROR, username="user2")
execute_user_queries(unauthorized, should_fail=True, failure_message=UNAUTHORIZED_ERROR, username="user2")
print("\033[1;36m~~ Finished custom default db checks ~~\033[0m\n")
# Run the user/role permissions test
print("\033[1;36m~~ Starting permissions test ~~\033[0m")
execute_admin_queries([
"CREATE ROLE roLe",
"REVOKE ALL PRIVILEGES FROM uSeR",
])
execute_checker(checker_binary, [])
for user_perm in ["GRANT", "DENY", "REVOKE"]:
for role_perm in ["GRANT", "DENY", "REVOKE"]:
for mapped in [True, False]:
print("\033[1;34m~~ Checking permissions with user ",
user_perm, ", role ", role_perm,
"user mapped to role:", mapped, " ~~\033[0m")
if mapped:
execute_admin_queries(["SET ROLE FOR USER TO roLE"])
else:
execute_admin_queries(["CLEAR ROLE FOR user"])
user_prep = "FROM" if user_perm == "REVOKE" else "TO"
role_prep = "FROM" if role_perm == "REVOKE" else "TO"
execute_admin_queries([
"{} MATCH {} user".format(user_perm, user_prep),
"{} MATCH {} rOLe".format(role_perm, role_prep)
])
expected = []
perms = [user_perm, role_perm] if mapped else [user_perm]
if "DENY" in perms:
expected = ["MATCH", "DENY"]
elif "GRANT" in perms:
expected = ["MATCH", "GRANT"]
if len(expected) > 0:
details = []
if user_perm == "GRANT":
details.append("GRANTED TO USER")
elif user_perm == "DENY":
details.append("DENIED TO USER")
if mapped:
if role_perm == "GRANT":
details.append("GRANTED TO ROLE")
elif role_perm == "DENY":
details.append("DENIED TO ROLE")
expected.append(", ".join(details))
execute_checker(checker_binary, expected)
print("\033[1;36m~~ Finished permissions test ~~\033[0m\n")
print("\033[1;36m~~ Checking connections and database switching ~~\033[0m\n")
for db in ["memgraph", "db1"]:
print("\033[1;36m~~ Running against db {} ~~\033[0m".format(db))
execute_admin_queries(["GRANT {} TO User2".format("MULTI_DATABASE_USE")])
execute_user_queries(
["USE DATABASE {}".format(db)], should_fail=True, failure_message=UNAUTHORIZED_ERROR, username="user2"
)
print("\033[1;36m~~ Running with user3 (shouldn't even connect) ~~\033[0m")
execute_admin_queries(["GRANT {} TO User3".format("MULTI_DATABASE_USE")])
execute_user_queries(
["USE DATABASE db2"],
connection_should_fail=True,
failure_message="Couldn't communicate with the server!",
username="user3",
)
print("\033[1;36m~~ Finished checking connections and database switching ~~\033[0m\n")
# Shutdown the memgraph binary
memgraph.terminate()
@ -313,10 +333,8 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
if __name__ == "__main__":
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
tester_binary = os.path.join(PROJECT_DIR, "build", "tests",
"integration", "auth", "tester")
checker_binary = os.path.join(PROJECT_DIR, "build", "tests",
"integration", "auth", "checker")
tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "auth", "tester")
checker_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "auth", "checker")
parser = argparse.ArgumentParser()
parser.add_argument("--memgraph", default=memgraph_binary)

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -9,6 +9,8 @@
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <regex>
#include <gflags/gflags.h>
#include "communication/bolt/client.hpp"
@ -23,6 +25,7 @@ DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
DEFINE_bool(check_failure, false, "Set to true to enable failure checking.");
DEFINE_bool(should_fail, false, "Set to true to expect a failure.");
DEFINE_bool(connection_should_fail, false, "Set to true to expect a connection failure.");
DEFINE_string(failure_message, "", "Set to the expected failure message.");
/**
@ -40,7 +43,26 @@ int main(int argc, char **argv) {
memgraph::communication::ClientContext context(FLAGS_use_ssl);
memgraph::communication::bolt::Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);
std::regex re(FLAGS_failure_message);
try {
client.Connect(endpoint, FLAGS_username, FLAGS_password);
} catch (const memgraph::communication::bolt::ClientFatalException &e) {
if (FLAGS_connection_should_fail) {
if (!FLAGS_failure_message.empty() && !std::regex_match(e.what(), re)) {
LOG_FATAL(
"The connection should have failed with an error message of '{}'' but "
"instead it failed with '{}'",
FLAGS_failure_message, e.what());
}
return 0;
} else {
LOG_FATAL(
"The connection shoudn't have failed but it failed with an "
"error message '{}'",
e.what());
}
}
for (int i = 1; i < argc; ++i) {
std::string query(argv[i]);
@ -48,7 +70,7 @@ int main(int argc, char **argv) {
client.Execute(query, {});
} catch (const memgraph::communication::bolt::ClientQueryException &e) {
if (!FLAGS_check_failure) {
if (!FLAGS_failure_message.empty() && e.what() == FLAGS_failure_message) {
if (!FLAGS_failure_message.empty() && std::regex_match(e.what(), re)) {
LOG_FATAL(
"The query should have succeeded or failed with an error "
"message that isn't equal to '{}' but it failed with that error "
@ -58,7 +80,7 @@ int main(int argc, char **argv) {
continue;
}
if (FLAGS_should_fail) {
if (!FLAGS_failure_message.empty() && e.what() != FLAGS_failure_message) {
if (!FLAGS_failure_message.empty() && !std::regex_match(e.what(), re)) {
LOG_FATAL(
"The query should have failed with an error message of '{}'' but "
"instead it failed with '{}'",

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -22,6 +22,7 @@ DEFINE_int32(port, 7687, "Server port");
DEFINE_string(username, "admin", "Username for the database");
DEFINE_string(password, "admin", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
DEFINE_string(use_db, "memgraph", "Database to run the query against");
/**
* Verifies that user 'user' has privileges that are given as positional
@ -38,6 +39,7 @@ int main(int argc, char **argv) {
memgraph::communication::bolt::Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);
client.Execute(fmt::format("USE DATABASE {}", FLAGS_use_db), {});
try {
std::string query(argv[1]);

View File

@ -23,7 +23,7 @@ from typing import List
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
UNAUTHORIZED_ERROR = "You are not authorized to execute this query! Please " "contact your database administrator."
UNAUTHORIZED_ERROR = r"^You are not authorized to execute this query.*?Please contact your database administrator\."
def wait_for_server(port, delay=0.1):
@ -47,8 +47,10 @@ def execute_tester(
subprocess.run(args).check_returncode()
def execute_filtering(binary: str, queries: List[str], expected: int, username: str = "", password: str = "") -> None:
args = [binary, "--username", username, "--password", password]
def execute_filtering(
binary: str, queries: List[str], expected: int, username: str = "", password: str = "", db: str = "memgraph"
) -> None:
args = [binary, "--username", username, "--password", password, "--use-db", db]
args.extend(queries)
args.append(str(expected))
@ -82,35 +84,48 @@ def execute_test(memgraph_binary: str, tester_binary: str, filtering_binary: str
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
# Prepare all users
execute_admin_queries(
[
"CREATE USER admin IDENTIFIED BY 'admin'",
"GRANT ALL PRIVILEGES TO admin",
"CREATE USER user IDENTIFIED BY 'user'",
"GRANT ALL PRIVILEGES TO user",
"GRANT LABELS :label1, :label2, :label3 TO user",
"GRANT EDGE_TYPES :edgeType1, :edgeType2 TO user",
"MERGE (l1:label1 {name: 'test1'})",
"MERGE (l2:label2 {name: 'test2'})",
"MATCH (l1:label1),(l2:label2) WHERE l1.name = 'test1' AND l2.name = 'test2' CREATE (l1)-[r:edgeType1]->(l2)",
"MERGE (l3:label3 {name: 'test3'})",
"MATCH (l1:label1),(l3:label3) WHERE l1.name = 'test1' AND l3.name = 'test3' CREATE (l1)-[r:edgeType2]->(l3)",
"MERGE (mix:label3:label1 {name: 'test4'})",
"MATCH (l1:label1),(mix:label3) WHERE l1.name = 'test1' AND mix.name = 'test4' CREATE (l1)-[r:edgeType2]->(mix)",
]
)
def setup_user():
execute_admin_queries(
[
"CREATE USER admin IDENTIFIED BY 'admin'",
"GRANT ALL PRIVILEGES TO admin",
"CREATE USER user IDENTIFIED BY 'user'",
"GRANT ALL PRIVILEGES TO user",
"GRANT LABELS :label1, :label2, :label3 TO user",
"GRANT EDGE_TYPES :edgeType1, :edgeType2 TO user",
]
)
def db_setup():
execute_admin_queries(
[
"MERGE (l1:label1 {name: 'test1'})",
"MERGE (l2:label2 {name: 'test2'})",
"MATCH (l1:label1),(l2:label2) WHERE l1.name = 'test1' AND l2.name = 'test2' CREATE (l1)-[r:edgeType1]->(l2)",
"MERGE (l3:label3 {name: 'test3'})",
"MATCH (l1:label1),(l3:label3) WHERE l1.name = 'test1' AND l3.name = 'test3' CREATE (l1)-[r:edgeType2]->(l3)",
"MERGE (mix:label3:label1 {name: 'test4'})",
"MATCH (l1:label1),(mix:label3) WHERE l1.name = 'test1' AND mix.name = 'test4' CREATE (l1)-[r:edgeType2]->(mix)",
]
)
db_setup() # default db setup
execute_admin_queries(["CREATE DATABASE db1", "USE DATABASE db1"])
db_setup() # db1 setup
# Run the test with all combinations of permissions
print("\033[1;36m~~ Starting edge filtering test ~~\033[0m")
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 3, "user", "user")
execute_admin_queries(["DENY EDGE_TYPES :edgeType1 TO user"])
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 2, "user", "user")
execute_admin_queries(["GRANT EDGE_TYPES :edgeType1 TO user", "DENY LABELS :label3 TO user"])
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 1, "user", "user")
execute_admin_queries(["DENY LABELS :label1 TO user"])
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 0, "user", "user")
execute_admin_queries(["REVOKE LABELS * FROM user", "REVOKE EDGE_TYPES * FROM user"])
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 0, "user", "user")
for db in ["memgraph", "db1"]:
setup_user()
# Run the test with all combinations of permissions
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 3, "user", "user", db)
execute_admin_queries(["DENY EDGE_TYPES :edgeType1 TO user"])
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 2, "user", "user", db)
execute_admin_queries(["GRANT EDGE_TYPES :edgeType1 TO user", "DENY LABELS :label3 TO user"])
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 1, "user", "user", db)
execute_admin_queries(["DENY LABELS :label1 TO user"])
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 0, "user", "user", db)
execute_admin_queries(["REVOKE LABELS * FROM user", "REVOKE EDGE_TYPES * FROM user"])
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 0, "user", "user", db)
print("\033[1;36m~~ Finished edge filtering test ~~\033[0m\n")

View File

@ -14,10 +14,14 @@ while ! nc -z -w 1 127.0.0.1 7687; do
sleep 0.5
done
# Start the test.
# Start the test on default db.
$binary_dir/tests/integration/transactions/tester
code=$?
# Start the test on another db.
$binary_dir/tests/integration/transactions/tester --use-db db1
code2=$?
# Shutdown the memgraph process.
kill $pid
wait $pid
@ -30,4 +34,12 @@ if [ $code_mg -ne 0 ]; then
fi
# Exit with the exitcode of the test.
exit $code
if [ $code -ne 0 ]; then
echo "Default database tests failed!"
exit $code
fi
if [ $code2 -ne 0 ]; then
echo "Non default database tests failed!"
exit $code2
fi

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -11,6 +11,7 @@
#include <iostream>
#include <fmt/core.h>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
@ -24,12 +25,16 @@ DEFINE_int32(port, 7687, "Server port");
DEFINE_string(username, "", "Username for the database");
DEFINE_string(password, "", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
DEFINE_string(use_db, "memgraph", "Database to run the query against");
using namespace memgraph::communication::bolt;
class BoltClient : public ::testing::Test {
protected:
virtual void SetUp() { client_.Connect(endpoint_, FLAGS_username, FLAGS_password); }
virtual void SetUp() {
client_.Connect(endpoint_, FLAGS_username, FLAGS_password);
Execute("CREATE DATABASE db1");
}
virtual void TearDown() {}
@ -90,6 +95,15 @@ const std::string kCommitInvalid =
"Transaction can't be committed because there was a previous error. Please "
"invoke a rollback instead.";
TEST_F(BoltClient, SelectDB) { Execute(fmt::format("USE DATABASE {}", FLAGS_use_db)); }
TEST_F(BoltClient, SelectDBUnderTx) {
EXPECT_TRUE(Execute("begin"));
EXPECT_THROW(Execute("USE DATABASE memgraph", "Multi-database queries are not allowed in multicommand transactions."),
ClientQueryException);
EXPECT_FALSE(TransactionActive());
}
TEST_F(BoltClient, CommitWithoutTransaction) {
EXPECT_THROW(Execute("commit", kNoCurrentTransactionToCommit), ClientQueryException);
EXPECT_FALSE(TransactionActive());

View File

@ -36,7 +36,7 @@ int main(int argc, char *argv[]) {
memgraph::query::Interpreter interpreter{&interpreter_context};
ResultStreamFaker stream(interpreter_context.db.get());
auto [header, _, qid] = interpreter.Prepare(argv[1], {}, nullptr);
auto [header, _1, qid, _2] = interpreter.Prepare(argv[1], {}, nullptr);
stream.Header(header);
auto summary = interpreter.PullAll(&stream);
stream.Summary(summary);

Some files were not shown because too many files have changed in this diff Show More