Add multi-tenancy v1 (#952)
* Decouple BoltSession and communication::bolt::Session * Add CREATE/USE/DROP DATABASE * Add SHOW DATABASES * Cover WebSocket session * Simple session safety implemented via RWLock * Storage symlinks for backward. compatibility * Extend the audit log with the DB info * Add auth part * Add tenant recovery
This commit is contained in:
parent
fd819cd099
commit
e8850549d2
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
|
||||
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
|
||||
@ -116,12 +116,12 @@ Log::~Log() {
|
||||
}
|
||||
|
||||
void Log::Record(const std::string &address, const std::string &username, const std::string &query,
|
||||
const storage::PropertyValue ¶ms) {
|
||||
const storage::PropertyValue ¶ms, const std::string &db) {
|
||||
if (!started_.load(std::memory_order_relaxed)) return;
|
||||
auto timestamp =
|
||||
std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch())
|
||||
.count();
|
||||
buffer_->emplace(Item{timestamp, address, username, query, params});
|
||||
buffer_->emplace(Item{timestamp, address, username, query, params, db});
|
||||
}
|
||||
|
||||
void Log::ReopenLog() {
|
||||
@ -136,8 +136,8 @@ void Log::Flush() {
|
||||
for (uint64_t i = 0; i < buffer_size_; ++i) {
|
||||
auto item = buffer_->pop();
|
||||
if (!item) break;
|
||||
log_.Write(fmt::format("{}.{:06d},{},{},{},{}\n", item->timestamp / 1000000, item->timestamp % 1000000,
|
||||
item->address, item->username, utils::Escape(item->query),
|
||||
log_.Write(fmt::format("{}.{:06d},{},{},{},{},{}\n", item->timestamp / 1000000, item->timestamp % 1000000,
|
||||
item->address, item->username, item->db, utils::Escape(item->query),
|
||||
utils::Escape(PropertyValueToJson(item->params).dump())));
|
||||
}
|
||||
log_.Sync();
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
|
||||
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
|
||||
@ -32,6 +32,7 @@ class Log {
|
||||
std::string username;
|
||||
std::string query;
|
||||
storage::PropertyValue params;
|
||||
std::string db;
|
||||
};
|
||||
|
||||
public:
|
||||
@ -51,7 +52,7 @@ class Log {
|
||||
|
||||
/// Adds an entry to the audit log. Thread-safe.
|
||||
void Record(const std::string &address, const std::string &username, const std::string &query,
|
||||
const storage::PropertyValue ¶ms);
|
||||
const storage::PropertyValue ¶ms, const std::string &db);
|
||||
|
||||
/// Reopens the log file. Used for log file rotation. Thread-safe.
|
||||
void ReopenLog();
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
|
||||
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
|
||||
@ -314,4 +314,57 @@ std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig)
|
||||
return ret;
|
||||
}
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
bool Auth::GrantDatabaseToUser(const std::string &db, const std::string &name) {
|
||||
auto user = GetUser(name);
|
||||
if (user) {
|
||||
if (db == kAllDatabases) {
|
||||
user->db_access().GrantAll();
|
||||
} else {
|
||||
user->db_access().Add(db);
|
||||
}
|
||||
SaveUser(*user);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Auth::RevokeDatabaseFromUser(const std::string &db, const std::string &name) {
|
||||
auto user = GetUser(name);
|
||||
if (user) {
|
||||
if (db == kAllDatabases) {
|
||||
user->db_access().DenyAll();
|
||||
} else {
|
||||
user->db_access().Remove(db);
|
||||
}
|
||||
SaveUser(*user);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void Auth::DeleteDatabase(const std::string &db) {
|
||||
for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) {
|
||||
auto username = it->first.substr(kUserPrefix.size());
|
||||
auto user = GetUser(username);
|
||||
if (user) {
|
||||
user->db_access().Delete(db);
|
||||
SaveUser(*user);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool Auth::SetMainDatabase(const std::string &db, const std::string &name) {
|
||||
auto user = GetUser(name);
|
||||
if (user) {
|
||||
if (!user->db_access().SetDefault(db)) {
|
||||
throw AuthException("Couldn't set default database '{}' for user '{}'!", db, name);
|
||||
}
|
||||
SaveUser(*user);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace memgraph::auth
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
|
||||
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
|
||||
@ -19,6 +19,9 @@
|
||||
#include "utils/settings.hpp"
|
||||
|
||||
namespace memgraph::auth {
|
||||
|
||||
static const constexpr char *const kAllDatabases = "*";
|
||||
|
||||
/**
|
||||
* This class serves as the main Authentication/Authorization storage.
|
||||
* It provides functions for managing Users, Roles, Permissions and FineGrainedAccessPermissions.
|
||||
@ -155,6 +158,46 @@ class Auth final {
|
||||
*/
|
||||
std::vector<User> AllUsersForRole(const std::string &rolename) const;
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
/**
|
||||
* @brief Revoke access to individual database for a user.
|
||||
*
|
||||
* @param db name of the database to revoke
|
||||
* @param name user's username
|
||||
* @return true on success
|
||||
* @throw AuthException if unable to find or update the user
|
||||
*/
|
||||
bool RevokeDatabaseFromUser(const std::string &db, const std::string &name);
|
||||
|
||||
/**
|
||||
* @brief Grant access to individual database for a user.
|
||||
*
|
||||
* @param db name of the database to revoke
|
||||
* @param name user's username
|
||||
* @return true on success
|
||||
* @throw AuthException if unable to find or update the user
|
||||
*/
|
||||
bool GrantDatabaseToUser(const std::string &db, const std::string &name);
|
||||
|
||||
/**
|
||||
* @brief Delete a database from all users.
|
||||
*
|
||||
* @param db name of the database to delete
|
||||
* @throw AuthException if unable to read data
|
||||
*/
|
||||
void DeleteDatabase(const std::string &db);
|
||||
|
||||
/**
|
||||
* @brief Set main database for an individual user.
|
||||
*
|
||||
* @param db name of the database to revoke
|
||||
* @param name user's username
|
||||
* @return true on success
|
||||
* @throw AuthException if unable to find or update the user
|
||||
*/
|
||||
bool SetMainDatabase(const std::string &db, const std::string &name);
|
||||
#endif
|
||||
|
||||
private:
|
||||
// Even though the `kvstore::KVStore` class is guaranteed to be thread-safe,
|
||||
// Auth is not thread-safe because modifying users and roles might require
|
||||
|
@ -15,8 +15,10 @@
|
||||
|
||||
#include "auth/crypto.hpp"
|
||||
#include "auth/exceptions.hpp"
|
||||
#include "dbms/constants.hpp"
|
||||
#include "license/license.hpp"
|
||||
#include "query/constants.hpp"
|
||||
#include "spdlog/spdlog.h"
|
||||
#include "utils/cast.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/settings.hpp"
|
||||
@ -35,18 +37,31 @@ namespace memgraph::auth {
|
||||
namespace {
|
||||
|
||||
// Constant list of all available permissions.
|
||||
const std::vector<Permission> kPermissionsAll = {Permission::MATCH, Permission::CREATE,
|
||||
Permission::MERGE, Permission::DELETE,
|
||||
Permission::SET, Permission::REMOVE,
|
||||
Permission::INDEX, Permission::STATS,
|
||||
Permission::CONSTRAINT, Permission::DUMP,
|
||||
Permission::AUTH, Permission::REPLICATION,
|
||||
Permission::DURABILITY, Permission::READ_FILE,
|
||||
Permission::FREE_MEMORY, Permission::TRIGGER,
|
||||
Permission::CONFIG, Permission::STREAM,
|
||||
Permission::MODULE_READ, Permission::MODULE_WRITE,
|
||||
Permission::WEBSOCKET, Permission::TRANSACTION_MANAGEMENT,
|
||||
Permission::STORAGE_MODE};
|
||||
const std::vector<Permission> kPermissionsAll = {Permission::MATCH,
|
||||
Permission::CREATE,
|
||||
Permission::MERGE,
|
||||
Permission::DELETE,
|
||||
Permission::SET,
|
||||
Permission::REMOVE,
|
||||
Permission::INDEX,
|
||||
Permission::STATS,
|
||||
Permission::CONSTRAINT,
|
||||
Permission::DUMP,
|
||||
Permission::AUTH,
|
||||
Permission::REPLICATION,
|
||||
Permission::DURABILITY,
|
||||
Permission::READ_FILE,
|
||||
Permission::FREE_MEMORY,
|
||||
Permission::TRIGGER,
|
||||
Permission::CONFIG,
|
||||
Permission::STREAM,
|
||||
Permission::MODULE_READ,
|
||||
Permission::MODULE_WRITE,
|
||||
Permission::WEBSOCKET,
|
||||
Permission::TRANSACTION_MANAGEMENT,
|
||||
Permission::STORAGE_MODE,
|
||||
Permission::MULTI_DATABASE_EDIT,
|
||||
Permission::MULTI_DATABASE_USE};
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -98,6 +113,10 @@ std::string PermissionToString(Permission permission) {
|
||||
return "TRANSACTION_MANAGEMENT";
|
||||
case Permission::STORAGE_MODE:
|
||||
return "STORAGE_MODE";
|
||||
case Permission::MULTI_DATABASE_EDIT:
|
||||
return "MULTI_DATABASE_EDIT";
|
||||
case Permission::MULTI_DATABASE_USE:
|
||||
return "MULTI_DATABASE_USE";
|
||||
}
|
||||
}
|
||||
|
||||
@ -464,6 +483,82 @@ bool operator==(const Role &first, const Role &second) {
|
||||
return first.rolename_ == second.rolename_ && first.permissions_ == second.permissions_;
|
||||
}
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
void Databases::Add(const std::string &db) {
|
||||
if (allow_all_) {
|
||||
grants_dbs_.clear();
|
||||
allow_all_ = false;
|
||||
}
|
||||
grants_dbs_.emplace(db);
|
||||
denies_dbs_.erase(db);
|
||||
}
|
||||
|
||||
void Databases::Remove(const std::string &db) {
|
||||
denies_dbs_.emplace(db);
|
||||
grants_dbs_.erase(db);
|
||||
}
|
||||
|
||||
void Databases::Delete(const std::string &db) {
|
||||
denies_dbs_.erase(db);
|
||||
if (!allow_all_) {
|
||||
grants_dbs_.erase(db);
|
||||
}
|
||||
// Reset if default deleted
|
||||
if (default_db_ == db) {
|
||||
default_db_ = "";
|
||||
}
|
||||
}
|
||||
|
||||
void Databases::GrantAll() {
|
||||
allow_all_ = true;
|
||||
grants_dbs_.clear();
|
||||
denies_dbs_.clear();
|
||||
}
|
||||
|
||||
void Databases::DenyAll() {
|
||||
allow_all_ = false;
|
||||
grants_dbs_.clear();
|
||||
denies_dbs_.clear();
|
||||
}
|
||||
|
||||
bool Databases::SetDefault(const std::string &db) {
|
||||
if (!Contains(db)) return false;
|
||||
default_db_ = db;
|
||||
return true;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool Databases::Contains(const std::string &db) const {
|
||||
return !denies_dbs_.contains(db) && (allow_all_ || grants_dbs_.contains(db));
|
||||
}
|
||||
|
||||
const std::string &Databases::GetDefault() const {
|
||||
if (!Contains(default_db_)) {
|
||||
throw AuthException("No access to the set default database \"{}\".", default_db_);
|
||||
}
|
||||
return default_db_;
|
||||
}
|
||||
|
||||
nlohmann::json Databases::Serialize() const {
|
||||
nlohmann::json data = nlohmann::json::object();
|
||||
data["grants"] = grants_dbs_;
|
||||
data["denies"] = denies_dbs_;
|
||||
data["allow_all"] = allow_all_;
|
||||
data["default"] = default_db_;
|
||||
return data;
|
||||
}
|
||||
|
||||
Databases Databases::Deserialize(const nlohmann::json &data) {
|
||||
if (!data.is_object()) {
|
||||
throw AuthException("Couldn't load database data!");
|
||||
}
|
||||
if (!data["grants"].is_structured() || !data["denies"].is_structured() || !data["allow_all"].is_boolean() ||
|
||||
!data["default"].is_string()) {
|
||||
throw AuthException("Couldn't load database data!");
|
||||
}
|
||||
return {data["allow_all"], data["grants"], data["denies"], data["default"]};
|
||||
}
|
||||
#endif
|
||||
|
||||
User::User() {}
|
||||
|
||||
User::User(const std::string &username) : username_(utils::ToLowerCase(username)) {}
|
||||
@ -472,11 +567,12 @@ User::User(const std::string &username, const std::string &password_hash, const
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
User::User(const std::string &username, const std::string &password_hash, const Permissions &permissions,
|
||||
FineGrainedAccessHandler fine_grained_access_handler)
|
||||
FineGrainedAccessHandler fine_grained_access_handler, Databases db_access)
|
||||
: username_(utils::ToLowerCase(username)),
|
||||
password_hash_(password_hash),
|
||||
permissions_(permissions),
|
||||
fine_grained_access_handler_(std::move(fine_grained_access_handler)) {}
|
||||
fine_grained_access_handler_(std::move(fine_grained_access_handler)),
|
||||
database_access_(db_access) {}
|
||||
#endif
|
||||
|
||||
bool User::CheckPassword(const std::string &password) {
|
||||
@ -576,8 +672,10 @@ nlohmann::json User::Serialize() const {
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
|
||||
data["fine_grained_access_handler"] = fine_grained_access_handler_.Serialize();
|
||||
data["databases"] = database_access_.Serialize();
|
||||
} else {
|
||||
data["fine_grained_access_handler"] = {};
|
||||
data["databases"] = {};
|
||||
}
|
||||
#endif
|
||||
// The role shouldn't be serialized here, it is stored as a foreign key.
|
||||
@ -594,11 +692,20 @@ User User::Deserialize(const nlohmann::json &data) {
|
||||
auto permissions = Permissions::Deserialize(data["permissions"]);
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
|
||||
Databases db_access;
|
||||
if (data["databases"].is_structured()) {
|
||||
db_access = Databases::Deserialize(data["databases"]);
|
||||
} else {
|
||||
// Back-compatibility
|
||||
spdlog::warn("User without specified database access. Given access to the default database.");
|
||||
db_access.Add(dbms::kDefaultDB);
|
||||
db_access.SetDefault(dbms::kDefaultDB);
|
||||
}
|
||||
if (!data["fine_grained_access_handler"].is_object()) {
|
||||
throw AuthException("Couldn't load user data!");
|
||||
}
|
||||
auto fine_grained_access_handler = FineGrainedAccessHandler::Deserialize(data["fine_grained_access_handler"]);
|
||||
return {data["username"], data["password_hash"], permissions, fine_grained_access_handler};
|
||||
return {data["username"], data["password_hash"], permissions, fine_grained_access_handler, db_access};
|
||||
}
|
||||
#endif
|
||||
return {data["username"], data["password_hash"], permissions};
|
||||
|
@ -9,10 +9,13 @@
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <json/json.hpp>
|
||||
#include "dbms/constants.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
|
||||
namespace memgraph::auth {
|
||||
// These permissions must have values that are applicable for usage in a
|
||||
@ -41,7 +44,9 @@ enum class Permission : uint64_t {
|
||||
MODULE_WRITE = 1U << 19U,
|
||||
WEBSOCKET = 1U << 20U,
|
||||
TRANSACTION_MANAGEMENT = 1U << 21U,
|
||||
STORAGE_MODE = 1U << 22U
|
||||
STORAGE_MODE = 1U << 22U,
|
||||
MULTI_DATABASE_EDIT = 1U << 23U,
|
||||
MULTI_DATABASE_USE = 1U << 24U,
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
@ -237,6 +242,85 @@ class Role final {
|
||||
|
||||
bool operator==(const Role &first, const Role &second);
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
class Databases final {
|
||||
public:
|
||||
Databases() : grants_dbs_({dbms::kDefaultDB}), allow_all_(false), default_db_(dbms::kDefaultDB) {}
|
||||
|
||||
Databases(const Databases &) = default;
|
||||
Databases &operator=(const Databases &) = default;
|
||||
Databases(Databases &&) noexcept = default;
|
||||
Databases &operator=(Databases &&) noexcept = default;
|
||||
~Databases() = default;
|
||||
|
||||
/**
|
||||
* @brief Add database to the list of granted access. @note allow_all_ will be false after execution
|
||||
*
|
||||
* @param db name of the database to grant access to
|
||||
*/
|
||||
void Add(const std::string &db);
|
||||
|
||||
/**
|
||||
* @brief Remove database to the list of granted access.
|
||||
* @note if allow_all_ is set, the flag will remain set and the
|
||||
* database will be added to the set of denied databases.
|
||||
*
|
||||
* @param db name of the database to grant access to
|
||||
*/
|
||||
void Remove(const std::string &db);
|
||||
|
||||
/**
|
||||
* @brief Called when database is dropped. Removes it from granted (if allow_all is false) and denied set.
|
||||
* @note allow_all_ is not changed
|
||||
*
|
||||
* @param db name of the database to grant access to
|
||||
*/
|
||||
void Delete(const std::string &db);
|
||||
|
||||
/**
|
||||
* @brief Set allow_all_ to true and clears grants and denied sets.
|
||||
*/
|
||||
void GrantAll();
|
||||
|
||||
/**
|
||||
* @brief Set allow_all_ to false and clears grants and denied sets.
|
||||
*/
|
||||
void DenyAll();
|
||||
|
||||
/**
|
||||
* @brief Set the default database.
|
||||
*/
|
||||
bool SetDefault(const std::string &db);
|
||||
|
||||
/**
|
||||
* @brief Checks if access is grated to the database.
|
||||
*
|
||||
* @param db name of the database
|
||||
* @return true if allow_all and not denied or granted
|
||||
*/
|
||||
bool Contains(const std::string &db) const;
|
||||
|
||||
bool GetAllowAll() const { return allow_all_; }
|
||||
const std::set<std::string> &GetGrants() const { return grants_dbs_; }
|
||||
const std::set<std::string> &GetDenies() const { return denies_dbs_; }
|
||||
const std::string &GetDefault() const;
|
||||
|
||||
nlohmann::json Serialize() const;
|
||||
/// @throw AuthException if unable to deserialize.
|
||||
static Databases Deserialize(const nlohmann::json &data);
|
||||
|
||||
private:
|
||||
Databases(bool allow_all, std::set<std::string> grant, std::set<std::string> deny,
|
||||
const std::string &default_db = dbms::kDefaultDB)
|
||||
: grants_dbs_(grant), denies_dbs_(deny), allow_all_(allow_all), default_db_(default_db) {}
|
||||
|
||||
std::set<std::string> grants_dbs_; //!< set of databases with granted access
|
||||
std::set<std::string> denies_dbs_; //!< set of databases with denied access
|
||||
bool allow_all_; //!< flag to allow access to everything (denied overrides this)
|
||||
std::string default_db_; //!< user's default database
|
||||
};
|
||||
#endif
|
||||
|
||||
// TODO (mferencevic): Implement password expiry.
|
||||
class User final {
|
||||
public:
|
||||
@ -246,7 +330,7 @@ class User final {
|
||||
User(const std::string &username, const std::string &password_hash, const Permissions &permissions);
|
||||
#ifdef MG_ENTERPRISE
|
||||
User(const std::string &username, const std::string &password_hash, const Permissions &permissions,
|
||||
FineGrainedAccessHandler fine_grained_access_handler);
|
||||
FineGrainedAccessHandler fine_grained_access_handler, Databases db_access = {});
|
||||
#endif
|
||||
User(const User &) = default;
|
||||
User &operator=(const User &) = default;
|
||||
@ -279,6 +363,11 @@ class User final {
|
||||
|
||||
const Role *role() const;
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
Databases &db_access() { return database_access_; }
|
||||
const Databases &db_access() const { return database_access_; }
|
||||
#endif
|
||||
|
||||
nlohmann::json Serialize() const;
|
||||
|
||||
/// @throw AuthException if unable to deserialize.
|
||||
@ -292,6 +381,7 @@ class User final {
|
||||
Permissions permissions_;
|
||||
#ifdef MG_ENTERPRISE
|
||||
FineGrainedAccessHandler fine_grained_access_handler_;
|
||||
Databases database_access_;
|
||||
#endif
|
||||
std::optional<Role> role_;
|
||||
};
|
||||
|
@ -11,6 +11,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <thread>
|
||||
|
||||
@ -24,8 +26,12 @@
|
||||
#include "communication/bolt/v1/states/executing.hpp"
|
||||
#include "communication/bolt/v1/states/handshake.hpp"
|
||||
#include "communication/bolt/v1/states/init.hpp"
|
||||
#include "communication/bolt/v1/value.hpp"
|
||||
#include "dbms/constants.hpp"
|
||||
#include "dbms/global.hpp"
|
||||
#include "utils/exceptions.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/uuid.hpp"
|
||||
|
||||
namespace memgraph::communication::bolt {
|
||||
|
||||
@ -48,14 +54,26 @@ class SessionException : public utils::BasicException {
|
||||
* @tparam TOutputStream type of output stream that will be used
|
||||
*/
|
||||
template <typename TInputStream, typename TOutputStream>
|
||||
class Session {
|
||||
class Session : public dbms::SessionInterface {
|
||||
public:
|
||||
using TEncoder = Encoder<ChunkedEncoderBuffer<TOutputStream>>;
|
||||
|
||||
/**
|
||||
* @brief Construct a new Session object
|
||||
*
|
||||
* @param input_stream stream to read from
|
||||
* @param output_stream stream to write to
|
||||
* @param impl a default high-level implementation to use (has to be defined)
|
||||
*/
|
||||
Session(TInputStream *input_stream, TOutputStream *output_stream)
|
||||
: input_stream_(*input_stream), output_stream_(*output_stream) {}
|
||||
: input_stream_(*input_stream), output_stream_(*output_stream), session_uuid_(utils::GenerateUUID()) {}
|
||||
|
||||
virtual ~Session() {}
|
||||
virtual ~Session() = default;
|
||||
|
||||
Session(const Session &) = delete;
|
||||
Session &operator=(const Session &) = delete;
|
||||
Session(Session &&) noexcept = delete;
|
||||
Session &operator=(Session &&) noexcept = delete;
|
||||
|
||||
/**
|
||||
* Process the given `query` with `params`.
|
||||
@ -66,6 +84,8 @@ class Session {
|
||||
const std::string &query, const std::map<std::string, Value> ¶ms,
|
||||
const std::map<std::string, memgraph::communication::bolt::Value> &extra) = 0;
|
||||
|
||||
virtual void Configure(const std::map<std::string, memgraph::communication::bolt::Value> &run_time_info) = 0;
|
||||
|
||||
/**
|
||||
* Put results of the processed query in the `encoder`.
|
||||
*
|
||||
@ -86,7 +106,7 @@ class Session {
|
||||
*/
|
||||
virtual std::map<std::string, Value> Discard(std::optional<int> n, std::optional<int> qid) = 0;
|
||||
|
||||
virtual void BeginTransaction(const std::map<std::string, memgraph::communication::bolt::Value> &) = 0;
|
||||
virtual void BeginTransaction(const std::map<std::string, memgraph::communication::bolt::Value> ¶ms) = 0;
|
||||
virtual void CommitTransaction() = 0;
|
||||
virtual void RollbackTransaction() = 0;
|
||||
|
||||
@ -99,7 +119,6 @@ class Session {
|
||||
/** Return the name of the server that should be used for the Bolt INIT
|
||||
* message. */
|
||||
virtual std::optional<std::string> GetServerNameForInit() = 0;
|
||||
|
||||
/**
|
||||
* Executes the session after data has been read into the buffer.
|
||||
* Goes through the bolt states in order to execute commands from the client.
|
||||
@ -161,8 +180,7 @@ class Session {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Rethink if there is a way to hide some members. At the momement all
|
||||
// of them are public.
|
||||
// TODO: Rethink if there is a way to hide some members. At the momement all of them are public.
|
||||
TInputStream &input_stream_;
|
||||
TOutputStream &output_stream_;
|
||||
|
||||
@ -182,6 +200,9 @@ class Session {
|
||||
|
||||
Version version_;
|
||||
|
||||
std::string GetDatabaseName() const override = 0;
|
||||
std::string UUID() const final { return session_uuid_; }
|
||||
|
||||
private:
|
||||
void ClientFailureInvalidData() {
|
||||
// Set the state to Close.
|
||||
@ -197,6 +218,8 @@ class Session {
|
||||
// of the session to trigger session cleanup and socket close.
|
||||
throw SessionException("Something went wrong during session execution!");
|
||||
}
|
||||
|
||||
const std::string session_uuid_; //!< unique identifier of the session (auto generated)
|
||||
};
|
||||
|
||||
} // namespace memgraph::communication::bolt
|
||||
|
@ -11,6 +11,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <exception>
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
@ -207,7 +208,7 @@ State HandleRunV1(TSession &session, const State state, const Marker marker) {
|
||||
|
||||
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
|
||||
|
||||
spdlog::debug("[Run] '{}'", query.ValueString());
|
||||
spdlog::debug("[Run - {}] '{}'", session.GetDatabaseName(), query.ValueString());
|
||||
|
||||
try {
|
||||
// Interpret can throw.
|
||||
@ -265,7 +266,13 @@ State HandleRunV4(TSession &session, const State state, const Marker marker) {
|
||||
|
||||
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
|
||||
|
||||
spdlog::debug("[Run] '{}'", query.ValueString());
|
||||
try {
|
||||
session.Configure(extra.ValueMap());
|
||||
} catch (const std::exception &e) {
|
||||
return HandleFailure(session, e);
|
||||
}
|
||||
|
||||
spdlog::debug("[Run - {}] '{}'", session.GetDatabaseName(), query.ValueString());
|
||||
|
||||
try {
|
||||
// Interpret can throw.
|
||||
@ -381,6 +388,7 @@ State HandleBegin(TSession &session, const State state, const Marker marker) {
|
||||
}
|
||||
|
||||
try {
|
||||
session.Configure(extra.ValueMap());
|
||||
session.BeginTransaction(extra.ValueMap());
|
||||
} catch (const std::exception &e) {
|
||||
return HandleFailure(session, e);
|
||||
@ -489,7 +497,7 @@ State HandleRoute(TSession &session, const Marker marker) {
|
||||
|
||||
template <typename TSession>
|
||||
State HandleLogOff() {
|
||||
// Not arguments sent, the user just needs to reauthenticate
|
||||
// No arguments sent, the user just needs to reauthenticate
|
||||
return State::Init;
|
||||
}
|
||||
} // namespace memgraph::communication::bolt
|
||||
|
@ -18,6 +18,7 @@
|
||||
#include "communication/bolt/v1/state.hpp"
|
||||
#include "communication/bolt/v1/value.hpp"
|
||||
#include "communication/exceptions.hpp"
|
||||
#include "spdlog/spdlog.h"
|
||||
#include "utils/likely.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
|
||||
@ -248,8 +249,9 @@ State StateInitRunV5(TSession &session, Marker marker, Signature signature) {
|
||||
}
|
||||
// Stay in Init
|
||||
return State::Init;
|
||||
}
|
||||
|
||||
} else if (signature == Signature::LogOn) {
|
||||
if (signature == Signature::LogOn) {
|
||||
if (marker != Marker::TinyStruct1) [[unlikely]] {
|
||||
spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
|
||||
spdlog::trace(
|
||||
@ -273,11 +275,10 @@ State StateInitRunV5(TSession &session, Marker marker, Signature signature) {
|
||||
return State::Close;
|
||||
}
|
||||
return State::Idle;
|
||||
|
||||
} else [[unlikely]] {
|
||||
spdlog::trace("Expected Init signature, but received 0x{:02X}!", utils::UnderlyingCast(signature));
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
spdlog::trace("Expected Init signature, but received 0x{:02X}!", utils::UnderlyingCast(signature));
|
||||
return State::Close;
|
||||
}
|
||||
} // namespace details
|
||||
|
||||
|
@ -27,11 +27,11 @@
|
||||
|
||||
namespace memgraph::communication::http {
|
||||
|
||||
template <class TRequestHandler, typename TSessionData>
|
||||
class Listener final : public std::enable_shared_from_this<Listener<TRequestHandler, TSessionData>> {
|
||||
template <class TRequestHandler, typename TSessionContext>
|
||||
class Listener final : public std::enable_shared_from_this<Listener<TRequestHandler, TSessionContext>> {
|
||||
using tcp = boost::asio::ip::tcp;
|
||||
using SessionHandler = Session<TRequestHandler, TSessionData>;
|
||||
using std::enable_shared_from_this<Listener<TRequestHandler, TSessionData>>::shared_from_this;
|
||||
using SessionHandler = Session<TRequestHandler, TSessionContext>;
|
||||
using std::enable_shared_from_this<Listener<TRequestHandler, TSessionContext>>::shared_from_this;
|
||||
|
||||
public:
|
||||
Listener(const Listener &) = delete;
|
||||
@ -50,8 +50,9 @@ class Listener final : public std::enable_shared_from_this<Listener<TRequestHand
|
||||
tcp::endpoint GetEndpoint() const { return acceptor_.local_endpoint(); }
|
||||
|
||||
private:
|
||||
Listener(boost::asio::io_context &ioc, TSessionData *data, ServerContext *context, tcp::endpoint endpoint)
|
||||
: ioc_(ioc), data_(data), context_(context), acceptor_(ioc) {
|
||||
Listener(boost::asio::io_context &ioc, TSessionContext *session_context, ServerContext *context,
|
||||
tcp::endpoint endpoint)
|
||||
: ioc_(ioc), session_context_(session_context), context_(context), acceptor_(ioc) {
|
||||
boost::beast::error_code ec;
|
||||
|
||||
// Open the acceptor
|
||||
@ -95,13 +96,13 @@ class Listener final : public std::enable_shared_from_this<Listener<TRequestHand
|
||||
return LogError(ec, "accept");
|
||||
}
|
||||
|
||||
SessionHandler::Create(std::move(socket), data_, *context_)->Run();
|
||||
SessionHandler::Create(std::move(socket), session_context_, *context_)->Run();
|
||||
|
||||
DoAccept();
|
||||
}
|
||||
|
||||
boost::asio::io_context &ioc_;
|
||||
TSessionData *data_;
|
||||
TSessionContext *session_context_;
|
||||
ServerContext *context_;
|
||||
tcp::acceptor acceptor_;
|
||||
};
|
||||
|
@ -21,14 +21,15 @@
|
||||
|
||||
namespace memgraph::communication::http {
|
||||
|
||||
template <class TRequestHandler, typename TSessionData>
|
||||
template <class TRequestHandler, typename TSessionContext>
|
||||
class Server final {
|
||||
using tcp = boost::asio::ip::tcp;
|
||||
|
||||
public:
|
||||
explicit Server(io::network::Endpoint endpoint, TSessionData *data, ServerContext *context)
|
||||
: listener_{Listener<TRequestHandler, TSessionData>::Create(
|
||||
ioc_, data, context, tcp::endpoint{boost::asio::ip::make_address(endpoint.address), endpoint.port})} {}
|
||||
explicit Server(io::network::Endpoint endpoint, TSessionContext *session_context, ServerContext *context)
|
||||
: listener_{Listener<TRequestHandler, TSessionContext>::Create(
|
||||
ioc_, session_context, context,
|
||||
tcp::endpoint{boost::asio::ip::make_address(endpoint.address), endpoint.port})} {}
|
||||
|
||||
Server(const Server &) = delete;
|
||||
Server(Server &&) = delete;
|
||||
@ -59,7 +60,7 @@ class Server final {
|
||||
private:
|
||||
boost::asio::io_context ioc_;
|
||||
|
||||
std::shared_ptr<Listener<TRequestHandler, TSessionData>> listener_;
|
||||
std::shared_ptr<Listener<TRequestHandler, TSessionContext>> listener_;
|
||||
std::optional<std::thread> background_thread_;
|
||||
};
|
||||
} // namespace memgraph::communication::http
|
||||
|
@ -42,10 +42,10 @@ inline void LogError(boost::beast::error_code ec, const std::string_view what) {
|
||||
spdlog::warn("HTTP session failed on {}: {}", what, ec.message());
|
||||
}
|
||||
|
||||
template <class TRequestHandler, typename TSessionData>
|
||||
class Session : public std::enable_shared_from_this<Session<TRequestHandler, TSessionData>> {
|
||||
template <class TRequestHandler, typename TSessionContext>
|
||||
class Session : public std::enable_shared_from_this<Session<TRequestHandler, TSessionContext>> {
|
||||
using tcp = boost::asio::ip::tcp;
|
||||
using std::enable_shared_from_this<Session<TRequestHandler, TSessionData>>::shared_from_this;
|
||||
using std::enable_shared_from_this<Session<TRequestHandler, TSessionContext>>::shared_from_this;
|
||||
|
||||
public:
|
||||
template <typename... Args>
|
||||
@ -72,7 +72,7 @@ class Session : public std::enable_shared_from_this<Session<TRequestHandler, TSe
|
||||
using PlainSocket = boost::beast::tcp_stream;
|
||||
using SSLSocket = boost::beast::ssl_stream<boost::beast::tcp_stream>;
|
||||
|
||||
explicit Session(tcp::socket &&socket, TSessionData *data, ServerContext &context)
|
||||
explicit Session(tcp::socket &&socket, TSessionContext *data, ServerContext &context)
|
||||
: stream_(CreateSocket(std::move(socket), context)),
|
||||
handler_(data),
|
||||
strand_{boost::asio::make_strand(GetExecutor())} {}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -39,7 +39,7 @@ namespace memgraph::communication {
|
||||
* second, checks all sessions for expiration and shuts them down if they have
|
||||
* expired.
|
||||
*/
|
||||
template <class TSession, class TSessionData>
|
||||
template <class TSession, class TSessionContext>
|
||||
class Listener final {
|
||||
private:
|
||||
// The maximum number of events handled per execution thread is 1. This is
|
||||
@ -48,10 +48,10 @@ class Listener final {
|
||||
// can take a long time.
|
||||
static const int kMaxEvents = 1;
|
||||
|
||||
using SessionHandler = Session<TSession, TSessionData>;
|
||||
using SessionHandler = Session<TSession, TSessionContext>;
|
||||
|
||||
public:
|
||||
Listener(TSessionData *data, ServerContext *context, int inactivity_timeout_sec, const std::string &service_name,
|
||||
Listener(TSessionContext *data, ServerContext *context, int inactivity_timeout_sec, const std::string &service_name,
|
||||
size_t workers_count)
|
||||
: data_(data),
|
||||
alive_(false),
|
||||
@ -259,7 +259,7 @@ class Listener final {
|
||||
|
||||
io::network::Epoll epoll_;
|
||||
|
||||
TSessionData *data_;
|
||||
TSessionContext *data_;
|
||||
|
||||
utils::SpinLock lock_;
|
||||
std::vector<std::unique_ptr<SessionHandler>> sessions_;
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -46,10 +46,10 @@ namespace memgraph::communication {
|
||||
* @tparam TSession the server can handle different Sessions, each session
|
||||
* represents a different protocol so the same network infrastructure
|
||||
* can be used for handling different protocols
|
||||
* @tparam TSessionData the class with objects that will be forwarded to the
|
||||
* @tparam TSessionContext the class with objects that will be forwarded to the
|
||||
* session
|
||||
*/
|
||||
template <typename TSession, typename TSessionData>
|
||||
template <typename TSession, typename TSessionContext>
|
||||
class Server final {
|
||||
public:
|
||||
using Socket = io::network::Socket;
|
||||
@ -58,12 +58,12 @@ class Server final {
|
||||
* Constructs and binds server to endpoint, operates on session data and
|
||||
* invokes workers_count workers
|
||||
*/
|
||||
Server(const io::network::Endpoint &endpoint, TSessionData *session_data, ServerContext *context,
|
||||
Server(const io::network::Endpoint &endpoint, TSessionContext *session_context, ServerContext *context,
|
||||
int inactivity_timeout_sec, const std::string &service_name,
|
||||
size_t workers_count = std::thread::hardware_concurrency())
|
||||
: alive_(false),
|
||||
endpoint_(endpoint),
|
||||
listener_(session_data, context, inactivity_timeout_sec, service_name, workers_count),
|
||||
listener_(session_context, context, inactivity_timeout_sec, service_name, workers_count),
|
||||
service_name_(service_name) {}
|
||||
|
||||
~Server() {
|
||||
@ -156,7 +156,7 @@ class Server final {
|
||||
|
||||
Socket socket_;
|
||||
io::network::Endpoint endpoint_;
|
||||
Listener<TSession, TSessionData> listener_;
|
||||
Listener<TSession, TSessionContext> listener_;
|
||||
|
||||
const std::string service_name_;
|
||||
};
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -69,10 +69,10 @@ class OutputStream final {
|
||||
* sessions. It handles socket ownership, inactivity timeout and protocol
|
||||
* wrapping.
|
||||
*/
|
||||
template <class TSession, class TSessionData>
|
||||
template <class TSession, class TSessionContext>
|
||||
class Session final {
|
||||
public:
|
||||
Session(io::network::Socket &&socket, TSessionData *data, ServerContext *context, int inactivity_timeout_sec)
|
||||
Session(io::network::Socket &&socket, TSessionContext *data, ServerContext *context, int inactivity_timeout_sec)
|
||||
: socket_(std::move(socket)),
|
||||
output_stream_([this](const uint8_t *data, size_t len, bool have_more) { return Write(data, len, have_more); }),
|
||||
session_(data, socket_.endpoint(), input_buffer_.read_end(), &output_stream_),
|
||||
|
@ -36,11 +36,11 @@
|
||||
|
||||
namespace memgraph::communication::v2 {
|
||||
|
||||
template <class TSession, class TSessionData>
|
||||
class Listener final : public std::enable_shared_from_this<Listener<TSession, TSessionData>> {
|
||||
template <class TSession, class TSessionContext>
|
||||
class Listener final : public std::enable_shared_from_this<Listener<TSession, TSessionContext>> {
|
||||
using tcp = boost::asio::ip::tcp;
|
||||
using SessionHandler = Session<TSession, TSessionData>;
|
||||
using std::enable_shared_from_this<Listener<TSession, TSessionData>>::shared_from_this;
|
||||
using SessionHandler = Session<TSession, TSessionContext>;
|
||||
using std::enable_shared_from_this<Listener<TSession, TSessionContext>>::shared_from_this;
|
||||
|
||||
public:
|
||||
Listener(const Listener &) = delete;
|
||||
@ -59,10 +59,10 @@ class Listener final : public std::enable_shared_from_this<Listener<TSession, TS
|
||||
bool IsRunning() const noexcept { return alive_.load(std::memory_order_relaxed); }
|
||||
|
||||
private:
|
||||
Listener(boost::asio::io_context &io_context, TSessionData *data, ServerContext *server_context,
|
||||
Listener(boost::asio::io_context &io_context, TSessionContext *session_context, ServerContext *server_context,
|
||||
tcp::endpoint &endpoint, const std::string_view service_name, const uint64_t inactivity_timeout_sec)
|
||||
: io_context_(io_context),
|
||||
data_(data),
|
||||
session_context_(session_context),
|
||||
server_context_(server_context),
|
||||
acceptor_(io_context_),
|
||||
endpoint_{endpoint},
|
||||
@ -111,8 +111,8 @@ class Listener final : public std::enable_shared_from_this<Listener<TSession, TS
|
||||
return OnError(ec, "accept");
|
||||
}
|
||||
|
||||
auto session = SessionHandler::Create(std::move(socket), data_, *server_context_, endpoint_, inactivity_timeout_,
|
||||
service_name_);
|
||||
auto session = SessionHandler::Create(std::move(socket), session_context_, *server_context_, endpoint_,
|
||||
inactivity_timeout_, service_name_);
|
||||
session->Start();
|
||||
DoAccept();
|
||||
}
|
||||
@ -123,7 +123,7 @@ class Listener final : public std::enable_shared_from_this<Listener<TSession, TS
|
||||
}
|
||||
|
||||
boost::asio::io_context &io_context_;
|
||||
TSessionData *data_;
|
||||
TSessionContext *session_context_;
|
||||
ServerContext *server_context_;
|
||||
tcp::acceptor acceptor_;
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -60,27 +60,27 @@ using ServerEndpoint = boost::asio::ip::tcp::endpoint;
|
||||
* @tparam TSession the server can handle different Sessions, each session
|
||||
* represents a different protocol so the same network infrastructure
|
||||
* can be used for handling different protocols
|
||||
* @tparam TSessionData the class with objects that will be forwarded to the
|
||||
* @tparam TSessionContext the class with objects that will be forwarded to the
|
||||
* session
|
||||
*/
|
||||
template <typename TSession, typename TSessionData>
|
||||
template <typename TSession, typename TSessionContext>
|
||||
class Server final {
|
||||
using ServerHandler = Server<TSession, TSessionData>;
|
||||
using ServerHandler = Server<TSession, TSessionContext>;
|
||||
|
||||
public:
|
||||
/**
|
||||
* Constructs and binds server to endpoint, operates on session data and
|
||||
* invokes workers_count workers
|
||||
*/
|
||||
Server(ServerEndpoint &endpoint, TSessionData *session_data, ServerContext *server_context,
|
||||
Server(ServerEndpoint &endpoint, TSessionContext *session_context, ServerContext *server_context,
|
||||
const int inactivity_timeout_sec, const std::string_view service_name,
|
||||
size_t workers_count = std::thread::hardware_concurrency())
|
||||
: endpoint_{endpoint},
|
||||
service_name_{service_name},
|
||||
context_thread_pool_{workers_count},
|
||||
listener_{Listener<TSession, TSessionData>::Create(context_thread_pool_.GetIOContext(), session_data,
|
||||
server_context, endpoint_, service_name_,
|
||||
inactivity_timeout_sec)} {}
|
||||
listener_{Listener<TSession, TSessionContext>::Create(context_thread_pool_.GetIOContext(), session_context,
|
||||
server_context, endpoint_, service_name_,
|
||||
inactivity_timeout_sec)} {}
|
||||
|
||||
~Server() { MG_ASSERT(!IsRunning(), "Server wasn't shutdown properly"); }
|
||||
|
||||
@ -122,7 +122,7 @@ class Server final {
|
||||
std::string service_name_;
|
||||
|
||||
IOContextThreadPool context_thread_pool_;
|
||||
std::shared_ptr<Listener<TSession, TSessionData>> listener_;
|
||||
std::shared_ptr<Listener<TSession, TSessionContext>> listener_;
|
||||
};
|
||||
|
||||
} // namespace memgraph::communication::v2
|
||||
|
@ -16,10 +16,12 @@
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <deque>
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
@ -41,9 +43,11 @@
|
||||
#include <boost/beast/websocket/rfc6455.hpp>
|
||||
#include <boost/system/detail/error_code.hpp>
|
||||
|
||||
#include "communication/bolt/v1/session.hpp"
|
||||
#include "communication/buffer.hpp"
|
||||
#include "communication/context.hpp"
|
||||
#include "communication/exceptions.hpp"
|
||||
#include "dbms/global.hpp"
|
||||
#include "utils/event_counter.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/on_scope_exit.hpp"
|
||||
@ -95,10 +99,10 @@ class OutputStream final {
|
||||
* Websocket Sessions. It handles socket ownership, inactivity timeout and protocol
|
||||
* wrapping.
|
||||
*/
|
||||
template <typename TSession, typename TSessionData>
|
||||
class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TSession, TSessionData>> {
|
||||
template <typename TSession, typename TSessionContext>
|
||||
class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TSession, TSessionContext>> {
|
||||
using WebSocket = boost::beast::websocket::stream<boost::beast::tcp_stream>;
|
||||
using std::enable_shared_from_this<WebsocketSession<TSession, TSessionData>>::shared_from_this;
|
||||
using std::enable_shared_from_this<WebsocketSession<TSession, TSessionContext>>::shared_from_this;
|
||||
|
||||
public:
|
||||
template <typename... Args>
|
||||
@ -106,6 +110,17 @@ class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TS
|
||||
return std::shared_ptr<WebsocketSession>(new WebsocketSession(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
~WebsocketSession() { session_context_->Delete(session_); }
|
||||
#else
|
||||
~WebsocketSession() = default;
|
||||
#endif
|
||||
|
||||
WebsocketSession(const WebsocketSession &) = delete;
|
||||
WebsocketSession &operator=(const WebsocketSession &) = delete;
|
||||
WebsocketSession(WebsocketSession &&) noexcept = delete;
|
||||
WebsocketSession &operator=(WebsocketSession &&) noexcept = delete;
|
||||
|
||||
// Start the asynchronous accept operation
|
||||
template <class Body, class Allocator>
|
||||
void DoAccept(boost::beast::http::request<Body, boost::beast::http::basic_fields<Allocator>> req) {
|
||||
@ -151,15 +166,20 @@ class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TS
|
||||
|
||||
private:
|
||||
// Take ownership of the socket
|
||||
explicit WebsocketSession(tcp::socket &&socket, TSessionData *data, tcp::endpoint endpoint,
|
||||
explicit WebsocketSession(tcp::socket &&socket, TSessionContext *session_context, tcp::endpoint endpoint,
|
||||
std::string_view service_name)
|
||||
: ws_(std::move(socket)),
|
||||
strand_{boost::asio::make_strand(ws_.get_executor())},
|
||||
output_stream_([this](const uint8_t *data, size_t len, bool /*have_more*/) { return Write(data, len); }),
|
||||
session_(data, endpoint, input_buffer_.read_end(), &output_stream_),
|
||||
session_{*session_context, endpoint, input_buffer_.read_end(), &output_stream_},
|
||||
session_context_{session_context},
|
||||
endpoint_{endpoint},
|
||||
remote_endpoint_{ws_.next_layer().socket().remote_endpoint()},
|
||||
service_name_{service_name} {}
|
||||
service_name_{service_name} {
|
||||
#ifdef MG_ENTERPRISE
|
||||
session_context_->Register(session_);
|
||||
#endif
|
||||
}
|
||||
|
||||
void OnAccept(boost::beast::error_code ec) {
|
||||
if (ec) {
|
||||
@ -242,6 +262,7 @@ class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TS
|
||||
communication::Buffer input_buffer_;
|
||||
OutputStream output_stream_;
|
||||
TSession session_;
|
||||
TSessionContext *session_context_;
|
||||
tcp::endpoint endpoint_;
|
||||
tcp::endpoint remote_endpoint_;
|
||||
std::string_view service_name_;
|
||||
@ -253,11 +274,11 @@ class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TS
|
||||
* Sessions. It handles socket ownership, inactivity timeout and protocol
|
||||
* wrapping.
|
||||
*/
|
||||
template <typename TSession, typename TSessionData>
|
||||
class Session final : public std::enable_shared_from_this<Session<TSession, TSessionData>> {
|
||||
template <typename TSession, typename TSessionContext>
|
||||
class Session final : public std::enable_shared_from_this<Session<TSession, TSessionContext>> {
|
||||
using TCPSocket = tcp::socket;
|
||||
using SSLSocket = boost::asio::ssl::stream<TCPSocket>;
|
||||
using std::enable_shared_from_this<Session<TSession, TSessionData>>::shared_from_this;
|
||||
using std::enable_shared_from_this<Session<TSession, TSessionContext>>::shared_from_this;
|
||||
|
||||
public:
|
||||
template <typename... Args>
|
||||
@ -265,11 +286,16 @@ class Session final : public std::enable_shared_from_this<Session<TSession, TSes
|
||||
return std::shared_ptr<Session>(new Session(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
~Session() { session_context_->Delete(session_); }
|
||||
#else
|
||||
~Session() = default;
|
||||
#endif
|
||||
|
||||
Session(const Session &) = delete;
|
||||
Session(Session &&) = delete;
|
||||
Session &operator=(const Session &) = delete;
|
||||
Session &operator=(Session &&) = delete;
|
||||
~Session() = default;
|
||||
|
||||
bool Start() {
|
||||
if (execution_active_) {
|
||||
@ -334,18 +360,23 @@ class Session final : public std::enable_shared_from_this<Session<TSession, TSes
|
||||
}
|
||||
|
||||
private:
|
||||
explicit Session(tcp::socket &&socket, TSessionData *data, ServerContext &server_context, tcp::endpoint endpoint,
|
||||
const std::chrono::seconds inactivity_timeout_sec, std::string_view service_name)
|
||||
explicit Session(tcp::socket &&socket, TSessionContext *session_context, ServerContext &server_context,
|
||||
tcp::endpoint endpoint, const std::chrono::seconds inactivity_timeout_sec,
|
||||
std::string_view service_name)
|
||||
: socket_(CreateSocket(std::move(socket), server_context)),
|
||||
strand_{boost::asio::make_strand(GetExecutor())},
|
||||
output_stream_([this](const uint8_t *data, size_t len, bool have_more) { return Write(data, len, have_more); }),
|
||||
session_(data, endpoint, input_buffer_.read_end(), &output_stream_),
|
||||
data_{data},
|
||||
session_{*session_context, endpoint, input_buffer_.read_end(), &output_stream_},
|
||||
session_context_{session_context},
|
||||
endpoint_{endpoint},
|
||||
remote_endpoint_{GetRemoteEndpoint()},
|
||||
service_name_{service_name},
|
||||
timeout_seconds_(inactivity_timeout_sec),
|
||||
timeout_timer_(GetExecutor()) {
|
||||
#ifdef MG_ENTERPRISE
|
||||
// TODO Try to remove Register (see comment at SessionInterface declaration)
|
||||
session_context_->Register(session_);
|
||||
#endif
|
||||
ExecuteForSocket([](auto &&socket) {
|
||||
socket.lowest_layer().set_option(tcp::no_delay(true)); // enable PSH
|
||||
socket.lowest_layer().set_option(boost::asio::socket_base::keep_alive(true)); // enable SO_KEEPALIVE
|
||||
@ -396,7 +427,8 @@ class Session final : public std::enable_shared_from_this<Session<TSession, TSes
|
||||
spdlog::info("Switching {} to websocket connection", remote_endpoint_);
|
||||
if (std::holds_alternative<TCPSocket>(socket_)) {
|
||||
auto sock = std::get<TCPSocket>(std::move(socket_));
|
||||
WebsocketSession<TSession, TSessionData>::Create(std::move(sock), data_, endpoint_, service_name_)
|
||||
WebsocketSession<TSession, TSessionContext>::Create(std::move(sock), session_context_, endpoint_,
|
||||
service_name_)
|
||||
->DoAccept(parser.release());
|
||||
execution_active_ = false;
|
||||
return;
|
||||
@ -535,7 +567,7 @@ class Session final : public std::enable_shared_from_this<Session<TSession, TSes
|
||||
communication::Buffer input_buffer_;
|
||||
OutputStream output_stream_;
|
||||
TSession session_;
|
||||
TSessionData *data_;
|
||||
TSessionContext *session_context_;
|
||||
tcp::endpoint endpoint_;
|
||||
tcp::endpoint remote_endpoint_;
|
||||
std::string_view service_name_;
|
||||
|
18
src/dbms/constants.hpp
Normal file
18
src/dbms/constants.hpp
Normal file
@ -0,0 +1,18 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace memgraph::dbms {
|
||||
|
||||
constexpr static const char *kDefaultDB = "memgraph"; //!< Name of the default database
|
||||
|
||||
} // namespace memgraph::dbms
|
110
src/dbms/global.hpp
Normal file
110
src/dbms/global.hpp
Normal file
@ -0,0 +1,110 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "utils/exceptions.hpp"
|
||||
|
||||
namespace memgraph::dbms {
|
||||
|
||||
enum class DeleteError : uint8_t {
|
||||
DEFAULT_DB,
|
||||
USING,
|
||||
NON_EXISTENT,
|
||||
FAIL,
|
||||
DISK_FAIL,
|
||||
};
|
||||
|
||||
enum class NewError : uint8_t {
|
||||
NO_CONFIGS,
|
||||
EXISTS,
|
||||
DEFUNCT,
|
||||
GENERIC,
|
||||
};
|
||||
|
||||
enum class SetForResult : uint8_t {
|
||||
SUCCESS,
|
||||
ALREADY_SET,
|
||||
FAIL,
|
||||
};
|
||||
|
||||
/**
|
||||
* UnknownSession Exception
|
||||
*
|
||||
* Used to indicate that an unknown session was used.
|
||||
*/
|
||||
class UnknownSessionException : public utils::BasicException {
|
||||
public:
|
||||
using utils::BasicException::BasicException;
|
||||
};
|
||||
|
||||
/**
|
||||
* UnknownDatabase Exception
|
||||
*
|
||||
* Used to indicate that an unknown database was used.
|
||||
*/
|
||||
class UnknownDatabaseException : public utils::BasicException {
|
||||
public:
|
||||
using utils::BasicException::BasicException;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Session interface used by the DBMS to handle the the active sessions.
|
||||
* @todo Try to remove this dependency from SessionContextHandler. OnDelete could be removed, as it only does an assert.
|
||||
* OnChange could be removed if SetFor returned the pointer and the called then handled the OnChange execution.
|
||||
* However, the interface is very useful to decouple the interpreter's query execution and the sessions themselves.
|
||||
*/
|
||||
class SessionInterface {
|
||||
public:
|
||||
SessionInterface() = default;
|
||||
virtual ~SessionInterface() = default;
|
||||
|
||||
SessionInterface(const SessionInterface &) = default;
|
||||
SessionInterface &operator=(const SessionInterface &) = default;
|
||||
SessionInterface(SessionInterface &&) noexcept = default;
|
||||
SessionInterface &operator=(SessionInterface &&) noexcept = default;
|
||||
|
||||
/**
|
||||
* @brief Return the unique string identifying the session.
|
||||
*
|
||||
* @return std::string
|
||||
*/
|
||||
virtual std::string UUID() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Return the currently active database.
|
||||
*
|
||||
* @return std::string
|
||||
*/
|
||||
virtual std::string GetDatabaseName() const = 0;
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
/**
|
||||
* @brief Gets called on database change.
|
||||
*
|
||||
* @return SetForResult enum (SUCCESS, ALREADY_SET or FAIL)
|
||||
*/
|
||||
virtual dbms::SetForResult OnChange(const std::string &) = 0;
|
||||
|
||||
/**
|
||||
* @brief Callback that gets called on database delete (drop).
|
||||
*
|
||||
* @return true on success
|
||||
*/
|
||||
virtual bool OnDelete(const std::string &) = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace memgraph::dbms
|
142
src/dbms/handler.hpp
Normal file
142
src/dbms/handler.hpp
Normal file
@ -0,0 +1,142 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <filesystem>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "global.hpp"
|
||||
#include "utils/result.hpp"
|
||||
#include "utils/sync_ptr.hpp"
|
||||
|
||||
namespace memgraph::dbms {
|
||||
|
||||
/**
|
||||
* @brief Generic multi-database content handler.
|
||||
*
|
||||
* @tparam TContext
|
||||
* @tparam TConfig
|
||||
*/
|
||||
template <typename TContext, typename TConfig>
|
||||
class Handler {
|
||||
public:
|
||||
using NewResult = utils::BasicResult<NewError, std::shared_ptr<TContext>>;
|
||||
|
||||
/**
|
||||
* @brief Empty Handler constructor.
|
||||
*
|
||||
*/
|
||||
Handler() {}
|
||||
|
||||
/**
|
||||
* @brief Generate a new context and corresponding configuration.
|
||||
*
|
||||
* @tparam T1 Variadic template of context constructor arguments
|
||||
* @tparam T2 Variadic template of config constructor arguments
|
||||
* @param name Name associated with the new context/config pair
|
||||
* @param args1 Arguments passed (as a tuple) to the context constructor
|
||||
* @param args2 Arguments passed (as a tuple) to the config constructor
|
||||
* @return NewResult
|
||||
*/
|
||||
template <typename... T1, typename... T2>
|
||||
NewResult New(std::string name, std::tuple<T1...> args1, std::tuple<T2...> args2) {
|
||||
return New_(name, args1, args2, std::make_index_sequence<sizeof...(T1)>{},
|
||||
std::make_index_sequence<sizeof...(T2)>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get pointer to context.
|
||||
*
|
||||
* @param name Name associated with the wanted context
|
||||
* @return std::optional<std::shared_ptr<TContext>>
|
||||
*/
|
||||
std::optional<std::shared_ptr<TContext>> Get(const std::string &name) {
|
||||
if (auto search = items_.find(name); search != items_.end()) {
|
||||
return search->second.get();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the config.
|
||||
*
|
||||
* @param name Name associated with the wanted config
|
||||
* @return std::optional<TConfig>
|
||||
*/
|
||||
std::optional<TConfig> GetConfig(const std::string &name) const {
|
||||
if (auto search = items_.find(name); search != items_.end()) {
|
||||
return search->second.config();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Delete the context/config pair associated with the name.
|
||||
*
|
||||
* @param name Name associated with the context/config pair to delete
|
||||
* @return true on success
|
||||
*/
|
||||
bool Delete(const std::string &name) {
|
||||
if (auto itr = items_.find(name); itr != items_.end()) {
|
||||
itr->second.DestroyAndSync();
|
||||
items_.erase(itr);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check if a name is already used.
|
||||
*
|
||||
* @param name Name to check
|
||||
* @return true if a context/config pair is already associated with the name
|
||||
*/
|
||||
bool Has(const std::string &name) const { return items_.find(name) != items_.end(); }
|
||||
|
||||
auto begin() { return items_.begin(); }
|
||||
auto end() { return items_.end(); }
|
||||
auto begin() const { return items_.begin(); }
|
||||
auto end() const { return items_.end(); }
|
||||
auto cbegin() const { return items_.cbegin(); }
|
||||
auto cend() const { return items_.cend(); }
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Lower level handler that hides some ugly code.
|
||||
*
|
||||
* @tparam T1 Variadic template of context constructor arguments
|
||||
* @tparam T2 Variadic template of config constructor arguments
|
||||
* @tparam I1 List of indexes associated with the first tuple
|
||||
* @tparam I2 List of indexes associated with the second tuple
|
||||
*/
|
||||
template <typename... T1, typename... T2, std::size_t... I1, std::size_t... I2>
|
||||
NewResult New_(std::string name, std::tuple<T1...> &args1, std::tuple<T2...> &args2,
|
||||
std::integer_sequence<std::size_t, I1...> /*not-used*/,
|
||||
std::integer_sequence<std::size_t, I2...> /*not-used*/) {
|
||||
// Make sure the emplace will succeed, since we don't want to create temporary objects that could break something
|
||||
if (!Has(name)) {
|
||||
auto [itr, _] = items_.emplace(std::piecewise_construct, std::forward_as_tuple(name),
|
||||
std::forward_as_tuple(TConfig{std::forward<T1>(std::get<I1>(args1))...},
|
||||
std::forward<T2>(std::get<I2>(args2))...));
|
||||
return itr->second.get();
|
||||
}
|
||||
spdlog::info("Item with name \"{}\" already exists.", name);
|
||||
return NewError::EXISTS;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, utils::SyncPtr<TContext, TConfig>> items_; //!< map to all active items
|
||||
};
|
||||
|
||||
} // namespace memgraph::dbms
|
106
src/dbms/interp_handler.hpp
Normal file
106
src/dbms/interp_handler.hpp
Normal file
@ -0,0 +1,106 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
|
||||
#include "global.hpp"
|
||||
#include "query/auth_checker.hpp"
|
||||
#include "query/config.hpp"
|
||||
#include "query/interpreter.hpp"
|
||||
#include "storage/v2/storage.hpp"
|
||||
|
||||
#include "handler.hpp"
|
||||
|
||||
namespace memgraph::dbms {
|
||||
|
||||
/**
|
||||
* @brief Simple class that adds useful information to the query's InterpreterContext
|
||||
*
|
||||
* @tparam T Multi-database handler type
|
||||
*/
|
||||
template <typename T>
|
||||
class ExpandedInterpContext : public query::InterpreterContext {
|
||||
public:
|
||||
template <typename... TArgs>
|
||||
explicit ExpandedInterpContext(T &ref, TArgs &&...args)
|
||||
: query::InterpreterContext(std::forward<TArgs>(args)...), sc_handler_(ref) {}
|
||||
|
||||
T &sc_handler_; //!< Multi-database/SessionContext handler (used in some queries)
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Simple structure that expands on the query's InterpreterConfig
|
||||
*
|
||||
*/
|
||||
struct ExpandedInterpConfig {
|
||||
storage::Config storage_config; //!< Storage configuration
|
||||
query::InterpreterConfig interp_config; //!< Interpreter configuration
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Multi-database interpreter context handler
|
||||
*
|
||||
* @tparam TSCHandler High-level multi-database/SessionContext handler type
|
||||
*/
|
||||
template <typename TSCHandler>
|
||||
class InterpContextHandler : public Handler<ExpandedInterpContext<TSCHandler>, ExpandedInterpConfig> {
|
||||
public:
|
||||
using InterpContextT = ExpandedInterpContext<TSCHandler>;
|
||||
using HandlerT = Handler<InterpContextT, ExpandedInterpConfig>;
|
||||
|
||||
/**
|
||||
* @brief Generate a new interpreter context associated with the passed name.
|
||||
*
|
||||
* @param name Name associating the new interpreter context
|
||||
* @param sc_handler Multi-database/SessionContext handler used (some queries might use it)
|
||||
* @param db Storage associated with the interpreter context
|
||||
* @param config Interpreter's configuration
|
||||
* @param dir Directory used by the interpreter
|
||||
* @param auth_handler AuthQueryHandler used
|
||||
* @param auth_checker AuthChecker used
|
||||
* @return HandlerT::NewResult
|
||||
*/
|
||||
typename HandlerT::NewResult New(const std::string &name, TSCHandler &sc_handler, storage::Config storage_config,
|
||||
const query::InterpreterConfig &interpreter_config,
|
||||
query::AuthQueryHandler &auth_handler, query::AuthChecker &auth_checker) {
|
||||
// Check if compatible with the existing interpreters
|
||||
if (std::any_of(HandlerT::cbegin(), HandlerT::cend(), [&](const auto &elem) {
|
||||
const auto &config = elem.second.config().storage_config;
|
||||
return config.durability.storage_directory == storage_config.durability.storage_directory;
|
||||
})) {
|
||||
spdlog::info("Tried to generate a new context using claimed directory and/or storage.");
|
||||
return NewError::EXISTS;
|
||||
}
|
||||
const auto dir = storage_config.durability.storage_directory;
|
||||
storage_config.name = name; // Set storage id via config
|
||||
return HandlerT::New(
|
||||
name, std::forward_as_tuple(storage_config, interpreter_config),
|
||||
std::forward_as_tuple(sc_handler, storage_config, interpreter_config, dir, &auth_handler, &auth_checker));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief All currently active storage.
|
||||
*
|
||||
* @return std::vector<std::string>
|
||||
*/
|
||||
std::vector<std::string> All() const {
|
||||
std::vector<std::string> res;
|
||||
res.reserve(std::distance(HandlerT::cbegin(), HandlerT::cend()));
|
||||
std::for_each(HandlerT::cbegin(), HandlerT::cend(), [&](const auto &elem) { res.push_back(elem.first); });
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace memgraph::dbms
|
||||
|
||||
#endif
|
61
src/dbms/session_context.hpp
Normal file
61
src/dbms/session_context.hpp
Normal file
@ -0,0 +1,61 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "auth/auth.hpp"
|
||||
#include "query/interpreter.hpp"
|
||||
#include "storage/v2/storage.hpp"
|
||||
#include "utils/synchronized.hpp"
|
||||
|
||||
#if MG_ENTERPRISE
|
||||
#include "audit/log.hpp"
|
||||
#endif
|
||||
namespace memgraph::dbms {
|
||||
|
||||
/**
|
||||
* @brief Structure encapsulating storage and interpreter context.
|
||||
*
|
||||
* @note Each session contains a copy.
|
||||
*/
|
||||
struct SessionContext {
|
||||
// Explicit constructor here to ensure that pointers to all objects are
|
||||
// supplied.
|
||||
|
||||
SessionContext(std::shared_ptr<memgraph::query::InterpreterContext> interpreter_context, std::string run,
|
||||
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth
|
||||
#ifdef MG_ENTERPRISE
|
||||
,
|
||||
memgraph::audit::Log *audit_log
|
||||
#endif
|
||||
)
|
||||
: interpreter_context(interpreter_context),
|
||||
run_id(run),
|
||||
auth(auth)
|
||||
#ifdef MG_ENTERPRISE
|
||||
,
|
||||
audit_log(audit_log)
|
||||
#endif
|
||||
{
|
||||
}
|
||||
|
||||
std::shared_ptr<memgraph::query::InterpreterContext> interpreter_context;
|
||||
std::string run_id;
|
||||
|
||||
// std::shared_ptr<AuthContext> auth_context;
|
||||
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth;
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
memgraph::audit::Log *audit_log;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace memgraph::dbms
|
603
src/dbms/session_context_handler.hpp
Normal file
603
src/dbms/session_context_handler.hpp
Normal file
@ -0,0 +1,603 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <concepts>
|
||||
#include <cstdint>
|
||||
#include <filesystem>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <stdexcept>
|
||||
#include <system_error>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "constants.hpp"
|
||||
#include "global.hpp"
|
||||
#include "interp_handler.hpp"
|
||||
#include "query/auth_checker.hpp"
|
||||
#include "query/config.hpp"
|
||||
#include "query/interpreter.hpp"
|
||||
#include "session_context.hpp"
|
||||
#include "spdlog/spdlog.h"
|
||||
#include "storage/v2/durability/durability.hpp"
|
||||
#include "storage/v2/durability/paths.hpp"
|
||||
#include "utils/exceptions.hpp"
|
||||
#include "utils/file.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/result.hpp"
|
||||
#include "utils/rw_lock.hpp"
|
||||
#include "utils/synchronized.hpp"
|
||||
#include "utils/uuid.hpp"
|
||||
|
||||
#include "handler.hpp"
|
||||
|
||||
namespace memgraph::dbms {
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
|
||||
using DeleteResult = utils::BasicResult<DeleteError>;
|
||||
|
||||
/**
|
||||
* @brief Multi-database session contexts handler.
|
||||
*/
|
||||
class SessionContextHandler {
|
||||
public:
|
||||
using StorageT = storage::Storage;
|
||||
using StorageConfigT = storage::Config;
|
||||
using LockT = utils::RWLock;
|
||||
using NewResultT = utils::BasicResult<NewError, SessionContext>;
|
||||
|
||||
struct Config {
|
||||
StorageConfigT storage_config; //!< Storage configuration
|
||||
query::InterpreterConfig interp_config; //!< Interpreter context configuration
|
||||
std::function<void(utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *,
|
||||
std::unique_ptr<query::AuthQueryHandler> &, std::unique_ptr<query::AuthChecker> &)>
|
||||
glue_auth;
|
||||
};
|
||||
|
||||
struct Statistics {
|
||||
uint64_t num_vertex; //!< Sum of vertexes in every database
|
||||
uint64_t num_edges; //!< Sum of edges in every database
|
||||
uint64_t num_databases; //! number of isolated databases
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Initialize the handler.
|
||||
*
|
||||
* @param audit_log pointer to the audit logger (ENTERPRISE only)
|
||||
* @param configs storage and interpreter configurations
|
||||
* @param recovery_on_startup restore databases (and its content) and authentication data
|
||||
*/
|
||||
SessionContextHandler(memgraph::audit::Log &audit_log, Config configs, bool recovery_on_startup, bool delete_on_drop)
|
||||
: lock_{utils::RWLock::Priority::READ},
|
||||
default_configs_(configs),
|
||||
run_id_{utils::GenerateUUID()},
|
||||
audit_log_(&audit_log),
|
||||
delete_on_drop_(delete_on_drop) {
|
||||
const auto &root = configs.storage_config.durability.storage_directory;
|
||||
utils::EnsureDirOrDie(root);
|
||||
// Verify that the user that started the process is the same user that is
|
||||
// the owner of the storage directory.
|
||||
storage::durability::VerifyStorageDirectoryOwnerAndProcessUserOrDie(root);
|
||||
|
||||
// Create the lock file and open a handle to it. This will crash the
|
||||
// database if it can't open the file for writing or if any other process is
|
||||
// holding the file opened.
|
||||
lock_file_path_ = root / ".lock";
|
||||
lock_file_handle_.Open(lock_file_path_, utils::OutputFile::Mode::OVERWRITE_EXISTING);
|
||||
MG_ASSERT(lock_file_handle_.AcquireLock(),
|
||||
"Couldn't acquire lock on the storage directory {}"
|
||||
"!\nAnother Memgraph process is currently running with the same "
|
||||
"storage directory, please stop it first before starting this "
|
||||
"process!",
|
||||
root);
|
||||
|
||||
// TODO: Figure out if this is needed/wanted
|
||||
// Clear auth database since we are not recovering
|
||||
// if (!recovery_on_startup) {
|
||||
// const auto &auth_dir = root / "auth";
|
||||
// // Backup if auth present
|
||||
// if (utils::DirExists(auth_dir)) {
|
||||
// auto backup_dir = root / storage::durability::kBackupDirectory;
|
||||
// std::error_code error_code;
|
||||
// utils::EnsureDirOrDie(backup_dir);
|
||||
// std::error_code ec;
|
||||
// const auto now = std::chrono::system_clock::now();
|
||||
// std::ostringstream os;
|
||||
// os << now.time_since_epoch().count();
|
||||
// std::filesystem::rename(auth_dir, backup_dir / ("auth-" + os.str()), ec);
|
||||
// MG_ASSERT(!ec, "Couldn't backup auth directory because of: {}", ec.message());
|
||||
// spdlog::warn(
|
||||
// "Since Memgraph was not supposed to recover on startup the authentication files will be "
|
||||
// "overwritten. To prevent important data loss, Memgraph has stored those files into .backup directory "
|
||||
// "inside the storage directory.");
|
||||
// }
|
||||
|
||||
// // Clear
|
||||
// if (std::filesystem::exists(auth_dir)) {
|
||||
// std::filesystem::remove_all(auth_dir);
|
||||
// }
|
||||
// }
|
||||
|
||||
// Lazy initialization of auth_
|
||||
auth_ = std::make_unique<utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock>>(root / "auth");
|
||||
configs.glue_auth(auth_.get(), auth_handler_, auth_checker_);
|
||||
|
||||
// TODO: Decouple storage config from dbms config
|
||||
// TODO: Save individual db configs inside the kvstore and restore from there
|
||||
storage::UpdatePaths(default_configs_->storage_config,
|
||||
default_configs_->storage_config.durability.storage_directory / "databases");
|
||||
const auto &db_dir = default_configs_->storage_config.durability.storage_directory;
|
||||
const auto durability_dir = db_dir / ".durability";
|
||||
utils::EnsureDirOrDie(db_dir);
|
||||
utils::EnsureDirOrDie(durability_dir);
|
||||
durability_ = std::make_unique<kvstore::KVStore>(durability_dir);
|
||||
|
||||
// Generate the default database
|
||||
MG_ASSERT(!NewDefault_().HasError(), "Failed while creating the default DB.");
|
||||
|
||||
// Recover previous databases
|
||||
if (recovery_on_startup) {
|
||||
for (const auto &[name, _] : *durability_) {
|
||||
if (name == kDefaultDB) continue; // Already set
|
||||
spdlog::info("Restoring database {}.", name);
|
||||
MG_ASSERT(!New_(name).HasError(), "Failed while creating database {}.", name);
|
||||
spdlog::info("Database {} restored.", name);
|
||||
}
|
||||
} else { // Clear databases from the durability list and auth
|
||||
auto locked_auth = auth_->Lock();
|
||||
for (const auto &[name, _] : *durability_) {
|
||||
if (name == kDefaultDB) continue;
|
||||
locked_auth->DeleteDatabase(name);
|
||||
durability_->Delete(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Shutdown() {
|
||||
for (auto &ic : interp_handler_) memgraph::query::Shutdown(ic.second.get().get());
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create a new SessionContext associated with the "name" database
|
||||
*
|
||||
* @param name name of the database
|
||||
* @return NewResultT context on success, error on failure
|
||||
*/
|
||||
NewResultT New(const std::string &name) {
|
||||
std::lock_guard<LockT> wr(lock_);
|
||||
return New_(name, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the context associated with the "name" database
|
||||
*
|
||||
* @param name
|
||||
* @return SessionContext
|
||||
* @throw UnknownDatabaseException if getting unknown database
|
||||
*/
|
||||
SessionContext Get(const std::string &name) {
|
||||
std::shared_lock<LockT> rd(lock_);
|
||||
return Get_(name);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set the undelying database for a particular session.
|
||||
*
|
||||
* @param uuid unique session identifier
|
||||
* @param db_name unique database name
|
||||
* @return SetForResult enum
|
||||
* @throws UnknownDatabaseException, UnknownSessionException or anything OnChange throws
|
||||
*/
|
||||
SetForResult SetFor(const std::string &uuid, const std::string &db_name) {
|
||||
std::shared_lock<LockT> rd(lock_);
|
||||
(void)Get_(
|
||||
db_name); // throws if db doesn't exist (TODO: Better to pass it via OnChange - but injecting dependency)
|
||||
try {
|
||||
auto &s = sessions_.at(uuid);
|
||||
return s.OnChange(db_name);
|
||||
} catch (std::out_of_range &) {
|
||||
throw UnknownSessionException("Unknown session \"{}\"", uuid);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set the undelying database from a session itself. SessionContext handler.
|
||||
*
|
||||
* @param db_name unique database name
|
||||
* @param handler function that gets called in place with the appropriate SessionContext
|
||||
* @return SetForResult enum
|
||||
*/
|
||||
template <typename THandler>
|
||||
requires std::invocable<THandler, SessionContext> SetForResult SetInPlace(const std::string &db_name,
|
||||
THandler handler) {
|
||||
std::shared_lock<LockT> rd(lock_);
|
||||
return handler(Get_(db_name));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Call void handler under a shared lock.
|
||||
*
|
||||
* @param handler function that gets called in place
|
||||
*/
|
||||
template <typename THandler>
|
||||
requires std::invocable<THandler>
|
||||
void CallInPlace(THandler handler) {
|
||||
std::shared_lock<LockT> rd(lock_);
|
||||
handler();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Register an active session (used to handle callbacks).
|
||||
*
|
||||
* @param session
|
||||
* @return true on success
|
||||
*/
|
||||
bool Register(SessionInterface &session) {
|
||||
std::lock_guard<LockT> wr(lock_);
|
||||
auto [_, success] = sessions_.emplace(session.UUID(), session);
|
||||
return success;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Delete a session.
|
||||
*
|
||||
* @param session
|
||||
*/
|
||||
bool Delete(const SessionInterface &session) {
|
||||
std::lock_guard<LockT> wr(lock_);
|
||||
return sessions_.erase(session.UUID()) > 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Delete database.
|
||||
*
|
||||
* @param db_name database name
|
||||
* @return DeleteResult error on failure
|
||||
*/
|
||||
DeleteResult Delete(const std::string &db_name) {
|
||||
std::lock_guard<LockT> wr(lock_);
|
||||
if (db_name == kDefaultDB) {
|
||||
// MSG cannot delete the default db
|
||||
return DeleteError::DEFAULT_DB;
|
||||
}
|
||||
// Check if db exists
|
||||
try {
|
||||
auto sc = Get_(db_name);
|
||||
// Check if a session is using the db
|
||||
if (!sc.interpreter_context->interpreters->empty()) {
|
||||
return DeleteError::USING;
|
||||
}
|
||||
} catch (UnknownDatabaseException &) {
|
||||
return DeleteError::NON_EXISTENT;
|
||||
}
|
||||
|
||||
// High level handlers
|
||||
for (auto &[_, s] : sessions_) {
|
||||
if (!s.OnDelete(db_name)) {
|
||||
spdlog::error("Partial failure while deleting database \"{}\".", db_name);
|
||||
defunct_dbs_.emplace(db_name);
|
||||
return DeleteError::FAIL;
|
||||
}
|
||||
}
|
||||
|
||||
// Low level handlers
|
||||
const auto storage_path = StorageDir_(db_name);
|
||||
MG_ASSERT(storage_path, "Missing storage for {}", db_name);
|
||||
if (!interp_handler_.Delete(db_name)) {
|
||||
spdlog::error("Partial failure while deleting database \"{}\".", db_name);
|
||||
defunct_dbs_.emplace(db_name);
|
||||
return DeleteError::FAIL;
|
||||
}
|
||||
|
||||
// Remove from auth
|
||||
auth_->Lock()->DeleteDatabase(db_name);
|
||||
// Remove from durability list
|
||||
if (durability_) durability_->Delete(db_name);
|
||||
|
||||
// Delete disk storage
|
||||
if (delete_on_drop_) {
|
||||
std::error_code ec;
|
||||
(void)std::filesystem::remove_all(*storage_path, ec);
|
||||
if (ec) {
|
||||
spdlog::error("Failed to clean disk while deleting database \"{}\".", db_name);
|
||||
defunct_dbs_.emplace(db_name);
|
||||
return DeleteError::DISK_FAIL;
|
||||
}
|
||||
}
|
||||
|
||||
// Delete from defunct_dbs_ (in case a second delete call was successful)
|
||||
defunct_dbs_.erase(db_name);
|
||||
|
||||
return {}; // Success
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set the default configurations.
|
||||
*
|
||||
* @param configs storage, interpreter and authorization configurations
|
||||
*/
|
||||
void SetDefaultConfigs(Config configs) {
|
||||
std::lock_guard<LockT> wr(lock_);
|
||||
default_configs_ = configs;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the default configurations.
|
||||
*
|
||||
* @return std::optional<Config>
|
||||
*/
|
||||
std::optional<Config> GetDefaultConfigs() const {
|
||||
std::shared_lock<LockT> rd(lock_);
|
||||
return default_configs_;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return all active databases.
|
||||
*
|
||||
* @return std::vector<std::string>
|
||||
*/
|
||||
std::vector<std::string> All() const {
|
||||
std::shared_lock<LockT> rd(lock_);
|
||||
return interp_handler_.All();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the number of vertex across all databases.
|
||||
*
|
||||
* @return uint64_t
|
||||
*/
|
||||
Statistics Info() const {
|
||||
// TODO: Handle overflow
|
||||
uint64_t nv = 0;
|
||||
uint64_t ne = 0;
|
||||
std::shared_lock<LockT> rd(lock_);
|
||||
const uint64_t ndb = std::distance(interp_handler_.cbegin(), interp_handler_.cend());
|
||||
for (const auto &ic : interp_handler_) {
|
||||
const auto &info = ic.second.get()->db->GetInfo();
|
||||
nv += info.vertex_count;
|
||||
ne += info.edge_count;
|
||||
}
|
||||
return {nv, ne, ndb};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the currently active database for a particular session.
|
||||
*
|
||||
* @param uuid session's unique identifier
|
||||
* @return std::string name of the database
|
||||
* @throw
|
||||
*/
|
||||
std::string Current(const std::string &uuid) const {
|
||||
std::shared_lock<LockT> rd(lock_);
|
||||
return sessions_.at(uuid).GetDatabaseName();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Restore triggers for all currently defined databases.
|
||||
* @note: Triggers can execute query procedures, so we need to reload the modules first and then the triggers
|
||||
*/
|
||||
void RestoreTriggers() {
|
||||
std::lock_guard<LockT> wr(lock_);
|
||||
for (auto &ic_itr : interp_handler_) {
|
||||
auto ic = ic_itr.second.get();
|
||||
spdlog::debug("Restoring trigger for database \"{}\"", ic->db->id());
|
||||
auto storage_accessor = ic->db->Access();
|
||||
auto dba = memgraph::query::DbAccessor{storage_accessor.get()};
|
||||
ic->trigger_store.RestoreTriggers(&ic->ast_cache, &dba, ic->config.query, ic->auth_checker);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Restore streams of all currently defined databases.
|
||||
* @note: Stream transformations are using modules, they have to be restored after the query modules are loaded.
|
||||
*/
|
||||
void RestoreStreams() {
|
||||
std::lock_guard<LockT> wr(lock_);
|
||||
for (auto &ic_itr : interp_handler_) {
|
||||
auto ic = ic_itr.second.get();
|
||||
spdlog::debug("Restoring streams for database \"{}\"", ic->db->id());
|
||||
ic->streams.RestoreStreams();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::optional<std::filesystem::path> StorageDir_(const std::string &name) const {
|
||||
const auto conf = interp_handler_.GetConfig(name);
|
||||
if (conf) {
|
||||
return conf->storage_config.durability.storage_directory;
|
||||
}
|
||||
spdlog::debug("Failed to find storage dir for database \"{}\"", name);
|
||||
return {};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create a new SessionContext associated with the "name" database
|
||||
*
|
||||
* @param name name of the database
|
||||
* @return NewResultT context on success, error on failure
|
||||
*/
|
||||
NewResultT New_(const std::string &name) { return New_(name, name); }
|
||||
|
||||
/**
|
||||
* @brief Create a new SessionContext associated with the "name" database
|
||||
*
|
||||
* @param name name of the database
|
||||
* @param storage_subdir undelying RocksDB directory
|
||||
* @return NewResultT context on success, error on failure
|
||||
*/
|
||||
NewResultT New_(const std::string &name, std::filesystem::path storage_subdir) {
|
||||
if (default_configs_) {
|
||||
auto storage = default_configs_->storage_config;
|
||||
storage::UpdatePaths(storage, storage.durability.storage_directory / storage_subdir);
|
||||
return New_(name, storage, default_configs_->interp_config);
|
||||
}
|
||||
spdlog::info("Trying to generate session context without any configurations.");
|
||||
return NewError::NO_CONFIGS;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create a new SessionContext associated with the "name" database
|
||||
*
|
||||
* @param name name of the database
|
||||
* @param storage_config storage configuration
|
||||
* @param inter_config interpreter configuration
|
||||
* @return NewResultT context on success, error on failure
|
||||
*/
|
||||
NewResultT New_(const std::string &name, StorageConfigT &storage_config, query::InterpreterConfig &inter_config/*,
|
||||
const std::string &ah_flags*/) {
|
||||
MG_ASSERT(auth_handler_, "No high level AuthQueryHandler has been supplied.");
|
||||
MG_ASSERT(auth_checker_, "No high level AuthChecker has been supplied.");
|
||||
|
||||
if (defunct_dbs_.contains(name)) {
|
||||
spdlog::warn("Failed to generate database due to the unknown state of the previously defunct database \"{}\".",
|
||||
name);
|
||||
return NewError::DEFUNCT;
|
||||
}
|
||||
|
||||
auto new_interp = interp_handler_.New(name, *this, storage_config, inter_config, *auth_handler_, *auth_checker_);
|
||||
|
||||
if (new_interp.HasValue()) {
|
||||
// Success
|
||||
if (durability_) durability_->Put(name, "ok");
|
||||
return SessionContext{new_interp.GetValue(), run_id_, auth_.get(), audit_log_};
|
||||
}
|
||||
return new_interp.GetError();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create a new SessionContext associated with the default database
|
||||
*
|
||||
* @return NewResultT context on success, error on failure
|
||||
*/
|
||||
NewResultT NewDefault_() {
|
||||
// Create the default DB in the root (this is how it was done pre multi-tenancy)
|
||||
auto res = New_(kDefaultDB, "..");
|
||||
if (res.HasValue()) {
|
||||
// For back-compatibility...
|
||||
// Recreate the dbms layout for the default db and symlink to the root
|
||||
const auto dir = StorageDir_(kDefaultDB);
|
||||
MG_ASSERT(dir, "Failed to find storage path.");
|
||||
const auto main_dir = *dir / "databases" / kDefaultDB;
|
||||
|
||||
if (!std::filesystem::exists(main_dir)) {
|
||||
std::filesystem::create_directory(main_dir);
|
||||
}
|
||||
|
||||
// Force link on-disk directories
|
||||
const auto conf = interp_handler_.GetConfig(kDefaultDB);
|
||||
MG_ASSERT(conf, "No configuration for the default database.");
|
||||
const auto &tmp_conf = conf->storage_config.disk;
|
||||
std::vector<std::filesystem::path> to_link{
|
||||
tmp_conf.main_storage_directory, tmp_conf.label_index_directory,
|
||||
tmp_conf.label_property_index_directory, tmp_conf.unique_constraints_directory,
|
||||
tmp_conf.name_id_mapper_directory, tmp_conf.id_name_mapper_directory,
|
||||
tmp_conf.durability_directory, tmp_conf.wal_directory,
|
||||
};
|
||||
|
||||
// Add in-memory paths
|
||||
// Some directories are redundant (skip those)
|
||||
const std::vector<std::string> skip{".lock", "audit_log", "auth", "databases", "internal_modules", "settings"};
|
||||
for (auto const &item : std::filesystem::directory_iterator{*dir}) {
|
||||
const auto dir_name = std::filesystem::relative(item.path(), item.path().parent_path());
|
||||
if (std::find(skip.begin(), skip.end(), dir_name) != skip.end()) continue;
|
||||
to_link.push_back(item.path());
|
||||
}
|
||||
|
||||
// Symlink to root dir
|
||||
for (auto const &item : to_link) {
|
||||
const auto dir_name = std::filesystem::relative(item, item.parent_path());
|
||||
const auto link = main_dir / dir_name;
|
||||
const auto to = std::filesystem::relative(item, main_dir);
|
||||
if (!std::filesystem::is_symlink(link) && !std::filesystem::exists(link)) {
|
||||
std::filesystem::create_directory_symlink(to, link);
|
||||
} else { // Check existing link
|
||||
std::error_code ec;
|
||||
const auto test_link = std::filesystem::read_symlink(link, ec);
|
||||
if (ec || test_link != to) {
|
||||
MG_ASSERT(false,
|
||||
"Memgraph storage directory incompatible with new version.\n"
|
||||
"Please use a clean directory or remove \"{}\" and try again.",
|
||||
link.string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the context associated with the "name" database
|
||||
*
|
||||
* @param name
|
||||
* @return SessionContext
|
||||
* @throw UnknownDatabaseException if trying to get unknown database
|
||||
*/
|
||||
SessionContext Get_(const std::string &name) {
|
||||
auto interp = interp_handler_.Get(name);
|
||||
if (interp) {
|
||||
return SessionContext{*interp, run_id_, auth_.get(), audit_log_};
|
||||
}
|
||||
throw UnknownDatabaseException("Tried to retrieve an unknown database \"{}\".", name);
|
||||
}
|
||||
|
||||
// Should storage objects ever be deleted?
|
||||
mutable LockT lock_; //!< protective lock
|
||||
std::filesystem::path lock_file_path_; //!< Lock file protecting the main storage
|
||||
utils::OutputFile lock_file_handle_; //!< Handler the lock (crash if already open)
|
||||
InterpContextHandler<SessionContextHandler> interp_handler_; //!< multi-tenancy interpreter handler
|
||||
// AuthContextHandler auth_handler_; //!< multi-tenancy authorization handler (currently we use a single global
|
||||
// auth)
|
||||
std::unique_ptr<utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock>> auth_;
|
||||
std::unique_ptr<query::AuthQueryHandler> auth_handler_;
|
||||
std::unique_ptr<query::AuthChecker> auth_checker_;
|
||||
std::optional<Config> default_configs_; //!< default storage and interpreter configurations
|
||||
const std::string run_id_; //!< run's unique identifier (auto generated)
|
||||
memgraph::audit::Log *audit_log_; //!< pointer to the audit logger
|
||||
std::unordered_map<std::string, SessionInterface &> sessions_; //!< map of active/registered sessions
|
||||
std::unique_ptr<kvstore::KVStore> durability_; //!< list of active dbs (pointer so we can postpone its creation)
|
||||
|
||||
std::set<std::string> defunct_dbs_; //!< Databases that are in an unknown state due to various failures
|
||||
bool delete_on_drop_; //!< Flag defining if dropping storage also deletes its directory
|
||||
public:
|
||||
static SessionContextHandler &ExtractSCH(query::InterpreterContext *interpreter_context) {
|
||||
return static_cast<typename decltype(interp_handler_)::InterpContextT *>(interpreter_context)->sc_handler_;
|
||||
}
|
||||
};
|
||||
|
||||
#else
|
||||
/**
|
||||
* @brief Initialize the handler.
|
||||
*
|
||||
* @param auth pointer to the authenticator
|
||||
* @param configs storage and interpreter configurations
|
||||
*/
|
||||
static inline SessionContext Init(storage::Config &storage_config, query::InterpreterConfig &interp_config,
|
||||
utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth,
|
||||
query::AuthQueryHandler *auth_handler, query::AuthChecker *auth_checker) {
|
||||
MG_ASSERT(auth, "Passed a nullptr auth");
|
||||
MG_ASSERT(auth_handler, "Passed a nullptr auth_handler");
|
||||
MG_ASSERT(auth_checker, "Passed a nullptr auth_checker");
|
||||
|
||||
storage_config.name = kDefaultDB;
|
||||
auto interp_context = std::make_shared<query::InterpreterContext>(
|
||||
storage_config, interp_config, storage_config.durability.storage_directory, auth_handler, auth_checker);
|
||||
MG_ASSERT(interp_context, "Failed to construct main interpret context.");
|
||||
|
||||
return SessionContext{interp_context, utils::GenerateUUID(), auth};
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace memgraph::dbms
|
@ -62,6 +62,10 @@ auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege) {
|
||||
return auth::Permission::STORAGE_MODE;
|
||||
case query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT:
|
||||
return auth::Permission::TRANSACTION_MANAGEMENT;
|
||||
case query::AuthQuery::Privilege::MULTI_DATABASE_EDIT:
|
||||
return auth::Permission::MULTI_DATABASE_EDIT;
|
||||
case query::AuthQuery::Privilege::MULTI_DATABASE_USE:
|
||||
return auth::Permission::MULTI_DATABASE_USE;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -71,7 +71,8 @@ AuthChecker::AuthChecker(
|
||||
: auth_(auth) {}
|
||||
|
||||
bool AuthChecker::IsUserAuthorized(const std::optional<std::string> &username,
|
||||
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) const {
|
||||
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
|
||||
const std::string &db_name) const {
|
||||
std::optional<memgraph::auth::User> maybe_user;
|
||||
{
|
||||
auto locked_auth = auth_->ReadLock();
|
||||
@ -83,7 +84,7 @@ bool AuthChecker::IsUserAuthorized(const std::optional<std::string> &username,
|
||||
}
|
||||
}
|
||||
|
||||
return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges);
|
||||
return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges, db_name);
|
||||
}
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
@ -108,7 +109,13 @@ std::unique_ptr<memgraph::query::FineGrainedAuthChecker> AuthChecker::GetFineGra
|
||||
#endif
|
||||
|
||||
bool AuthChecker::IsUserAuthorized(const memgraph::auth::User &user,
|
||||
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) {
|
||||
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
|
||||
const std::string &db_name) { // NOLINT
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (!db_name.empty() && !user.db_access().Contains(db_name)) {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
const auto user_permissions = user.GetPermissions();
|
||||
return std::all_of(privileges.begin(), privileges.end(), [&user_permissions](const auto privilege) {
|
||||
return user_permissions.Has(memgraph::glue::PrivilegeToPermission(privilege)) ==
|
||||
|
@ -25,7 +25,8 @@ class AuthChecker : public query::AuthChecker {
|
||||
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth);
|
||||
|
||||
bool IsUserAuthorized(const std::optional<std::string> &username,
|
||||
const std::vector<query::AuthQuery::Privilege> &privileges) const override;
|
||||
const std::vector<query::AuthQuery::Privilege> &privileges,
|
||||
const std::string &db_name) const override;
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
std::unique_ptr<memgraph::query::FineGrainedAuthChecker> GetFineGrainedAuthChecker(
|
||||
@ -33,7 +34,8 @@ class AuthChecker : public query::AuthChecker {
|
||||
|
||||
#endif
|
||||
[[nodiscard]] static bool IsUserAuthorized(const memgraph::auth::User &user,
|
||||
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges);
|
||||
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
|
||||
const std::string &db_name = "");
|
||||
|
||||
private:
|
||||
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
|
||||
|
@ -16,6 +16,7 @@
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "auth/models.hpp"
|
||||
#include "dbms/constants.hpp"
|
||||
#include "glue/auth.hpp"
|
||||
#include "license/license.hpp"
|
||||
#include "query/constants.hpp"
|
||||
@ -122,6 +123,29 @@ std::vector<std::vector<memgraph::query::TypedValue>> ShowRolePrivileges(
|
||||
}
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
std::vector<std::vector<memgraph::query::TypedValue>> ShowDatabasePrivileges(
|
||||
const std::optional<memgraph::auth::User> &user) {
|
||||
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast() || !user) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const auto &db = user->db_access();
|
||||
const auto &allows = db.GetAllowAll();
|
||||
const auto &grants = db.GetGrants();
|
||||
const auto &denies = db.GetDenies();
|
||||
|
||||
std::vector<memgraph::query::TypedValue> res; // First element is a list of granted databases, second of revoked ones
|
||||
if (allows) {
|
||||
res.emplace_back("*");
|
||||
} else {
|
||||
std::vector<memgraph::query::TypedValue> grants_vec(grants.cbegin(), grants.cend());
|
||||
res.emplace_back(std::move(grants_vec));
|
||||
}
|
||||
std::vector<memgraph::query::TypedValue> denies_vec(denies.cbegin(), denies.cend());
|
||||
res.emplace_back(std::move(denies_vec));
|
||||
return {res};
|
||||
}
|
||||
|
||||
std::vector<FineGrainedPermissionForPrivilegeResult> GetFineGrainedPermissionForPrivilegeForUserOrRole(
|
||||
const memgraph::auth::FineGrainedAccessPermissions &permissions, const std::string &permission_type,
|
||||
const std::string &user_or_role) {
|
||||
@ -268,6 +292,10 @@ bool AuthQueryHandler::CreateUser(const std::string &username, const std::option
|
||||
}
|
||||
#endif
|
||||
);
|
||||
#ifdef MG_ENTERPRISE
|
||||
GrantDatabaseToUser(auth::kAllDatabases, username);
|
||||
SetMainDatabase(username, dbms::kDefaultDB);
|
||||
#endif
|
||||
}
|
||||
|
||||
return user_added;
|
||||
@ -319,6 +347,67 @@ bool AuthQueryHandler::CreateRole(const std::string &rolename) {
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
bool AuthQueryHandler::RevokeDatabaseFromUser(const std::string &db, const std::string &username) {
|
||||
if (!std::regex_match(username, name_regex_)) {
|
||||
throw memgraph::query::QueryRuntimeException("Invalid user name.");
|
||||
}
|
||||
try {
|
||||
auto locked_auth = auth_->Lock();
|
||||
auto user = locked_auth->GetUser(username);
|
||||
if (!user) return false;
|
||||
return locked_auth->RevokeDatabaseFromUser(db, username);
|
||||
} catch (const memgraph::auth::AuthException &e) {
|
||||
throw memgraph::query::QueryRuntimeException(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
bool AuthQueryHandler::GrantDatabaseToUser(const std::string &db, const std::string &username) {
|
||||
if (!std::regex_match(username, name_regex_)) {
|
||||
throw memgraph::query::QueryRuntimeException("Invalid user name.");
|
||||
}
|
||||
try {
|
||||
auto locked_auth = auth_->Lock();
|
||||
auto user = locked_auth->GetUser(username);
|
||||
if (!user) return false;
|
||||
return locked_auth->GrantDatabaseToUser(db, username);
|
||||
} catch (const memgraph::auth::AuthException &e) {
|
||||
throw memgraph::query::QueryRuntimeException(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<memgraph::query::TypedValue>> AuthQueryHandler::GetDatabasePrivileges(
|
||||
const std::string &username) {
|
||||
if (!std::regex_match(username, name_regex_)) {
|
||||
throw memgraph::query::QueryRuntimeException("Invalid user or role name.");
|
||||
}
|
||||
try {
|
||||
auto locked_auth = auth_->ReadLock();
|
||||
auto user = locked_auth->GetUser(username);
|
||||
if (!user) {
|
||||
throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist.", username);
|
||||
}
|
||||
return ShowDatabasePrivileges(user);
|
||||
} catch (const memgraph::auth::AuthException &e) {
|
||||
throw memgraph::query::QueryRuntimeException(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
bool AuthQueryHandler::SetMainDatabase(const std::string &db, const std::string &username) {
|
||||
if (!std::regex_match(username, name_regex_)) {
|
||||
throw memgraph::query::QueryRuntimeException("Invalid user name.");
|
||||
}
|
||||
try {
|
||||
auto locked_auth = auth_->Lock();
|
||||
auto user = locked_auth->GetUser(username);
|
||||
if (!user) return false;
|
||||
return locked_auth->SetMainDatabase(db, username);
|
||||
} catch (const memgraph::auth::AuthException &e) {
|
||||
throw memgraph::query::QueryRuntimeException(e.what());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
bool AuthQueryHandler::DropRole(const std::string &rolename) {
|
||||
if (!std::regex_match(rolename, name_regex_)) {
|
||||
throw memgraph::query::QueryRuntimeException("Invalid role name.");
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -38,6 +38,16 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
|
||||
|
||||
void SetPassword(const std::string &username, const std::optional<std::string> &password) override;
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) override;
|
||||
|
||||
bool GrantDatabaseToUser(const std::string &db, const std::string &username) override;
|
||||
|
||||
std::vector<std::vector<memgraph::query::TypedValue>> GetDatabasePrivileges(const std::string &username) override;
|
||||
|
||||
bool SetMainDatabase(const std::string &db, const std::string &username) override;
|
||||
#endif
|
||||
|
||||
bool CreateRole(const std::string &rolename) override;
|
||||
|
||||
bool DropRole(const std::string &rolename) override;
|
||||
|
@ -47,10 +47,10 @@ struct MetricsResponse {
|
||||
std::vector<std::tuple<std::string, std::string, uint64_t>> event_histograms{};
|
||||
};
|
||||
|
||||
template <typename TSessionData>
|
||||
template <typename TSessionContext>
|
||||
class MetricsService {
|
||||
public:
|
||||
explicit MetricsService(TSessionData *data) : db_(data->interpreter_context->db.get()) {}
|
||||
explicit MetricsService(TSessionContext *session_context) : db_(session_context->interpreter_context->db.get()) {}
|
||||
|
||||
nlohmann::json GetMetricsJSON() {
|
||||
auto response = GetMetrics();
|
||||
@ -141,10 +141,10 @@ class MetricsService {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TSessionData>
|
||||
template <typename TSessionContext>
|
||||
class MetricsRequestHandler final {
|
||||
public:
|
||||
explicit MetricsRequestHandler(TSessionData *data) : service_(data) {
|
||||
explicit MetricsRequestHandler(TSessionContext *session_context) : service_(session_context) {
|
||||
spdlog::info("Basic request handler started!");
|
||||
}
|
||||
|
||||
@ -206,6 +206,6 @@ class MetricsRequestHandler final {
|
||||
}
|
||||
|
||||
private:
|
||||
MetricsService<TSessionData> service_;
|
||||
MetricsService<TSessionContext> service_;
|
||||
};
|
||||
} // namespace memgraph::http
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -36,7 +36,13 @@ KVStore::KVStore(std::filesystem::path storage) : pimpl_(std::make_unique<impl>(
|
||||
pimpl_->db.reset(db);
|
||||
}
|
||||
|
||||
KVStore::~KVStore() {}
|
||||
KVStore::~KVStore() {
|
||||
spdlog::debug("Destroying KVStore at {}", pimpl_->storage.string());
|
||||
const auto sync = pimpl_->db->SyncWAL();
|
||||
if (!sync.ok()) spdlog::error("KVStore sync failed!");
|
||||
const auto close = pimpl_->db->Close();
|
||||
if (!close.ok()) spdlog::error("KVStore close failed!");
|
||||
}
|
||||
|
||||
KVStore::KVStore(KVStore &&other) { pimpl_ = std::move(other.pimpl_); }
|
||||
|
||||
|
452
src/memgraph.cpp
452
src/memgraph.cpp
@ -40,6 +40,9 @@
|
||||
#include "communication/http/server.hpp"
|
||||
#include "communication/websocket/auth.hpp"
|
||||
#include "communication/websocket/server.hpp"
|
||||
#include "dbms/constants.hpp"
|
||||
#include "dbms/global.hpp"
|
||||
#include "dbms/session_context.hpp"
|
||||
#include "glue/auth_checker.hpp"
|
||||
#include "glue/auth_handler.hpp"
|
||||
#include "helpers.hpp"
|
||||
@ -98,6 +101,7 @@
|
||||
#include "communication/init.hpp"
|
||||
#include "communication/v2/server.hpp"
|
||||
#include "communication/v2/session.hpp"
|
||||
#include "dbms/session_context_handler.hpp"
|
||||
#include "glue/communication.hpp"
|
||||
|
||||
#include "auth/auth.hpp"
|
||||
@ -157,6 +161,10 @@ DEFINE_string(init_data_file, "", "Path to cypherl file that is used for creatin
|
||||
// `mg_import_csv`. If you change it, make sure to change it there as well.
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_string(data_directory, "mg_data", "Path to directory in which to save all permanent data.");
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_bool(data_recovery_on_startup, false, "Controls whether the database recovers persisted data on startup.");
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_uint64(memory_warning_threshold, 1024,
|
||||
"Memory warning threshold, in MB. If Memgraph detects there is "
|
||||
@ -174,8 +182,11 @@ DEFINE_VALIDATED_uint64(storage_gc_cycle_sec, 30, "Storage garbage collector int
|
||||
// `mg_import_csv`. If you change it, make sure to change it there as well.
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_bool(storage_properties_on_edges, false, "Controls whether edges have properties.");
|
||||
|
||||
// storage_recover_on_startup deprecated; use data_recovery_on_startup instead
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_bool(storage_recover_on_startup, false, "Controls whether the storage recovers persisted data on startup.");
|
||||
DEFINE_HIDDEN_bool(storage_recover_on_startup, false,
|
||||
"Controls whether the storage recovers persisted data on startup.");
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_VALIDATED_uint64(storage_snapshot_interval_sec, 0,
|
||||
"Storage snapshot creation interval (in seconds). Set "
|
||||
@ -215,6 +226,12 @@ DEFINE_uint64(storage_recovery_thread_count,
|
||||
memgraph::storage::Config::Durability().recovery_thread_count),
|
||||
"The number of threads used to recover persisted data from disk.");
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_bool(storage_delete_on_drop, true,
|
||||
"If set to true the query 'DROP DATABASE x' will delete the underlying storage as well.");
|
||||
#endif
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_bool(telemetry_enabled, false,
|
||||
"Set to true to enable telemetry. We collect information about the "
|
||||
@ -447,35 +464,6 @@ void AddLoggerSink(spdlog::sink_ptr new_sink) {
|
||||
DEFINE_HIDDEN_string(license_key, "", "License key for Memgraph Enterprise.");
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_HIDDEN_string(organization_name, "", "Organization name.");
|
||||
|
||||
/// Encapsulates Dbms and Interpreter that are passed through the network server
|
||||
/// and worker to the session.
|
||||
struct SessionData {
|
||||
// Explicit constructor here to ensure that pointers to all objects are
|
||||
// supplied.
|
||||
#if MG_ENTERPRISE
|
||||
|
||||
SessionData(memgraph::query::InterpreterContext *interpreter_context,
|
||||
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth,
|
||||
memgraph::audit::Log *audit_log)
|
||||
: interpreter_context(interpreter_context), auth(auth), audit_log(audit_log) {}
|
||||
memgraph::query::InterpreterContext *interpreter_context;
|
||||
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth;
|
||||
memgraph::audit::Log *audit_log;
|
||||
|
||||
#else
|
||||
|
||||
SessionData(memgraph::query::InterpreterContext *interpreter_context,
|
||||
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth)
|
||||
: interpreter_context(interpreter_context), auth(auth) {}
|
||||
memgraph::query::InterpreterContext *interpreter_context;
|
||||
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth;
|
||||
|
||||
#endif
|
||||
// NOTE: run_id should be const but that complicates code a lot.
|
||||
std::optional<std::string> run_id;
|
||||
};
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_string(auth_user_or_role_name_regex, memgraph::glue::kDefaultUserRoleRegex.data(),
|
||||
"Set to the regular expression that each user or role name must fulfill.");
|
||||
@ -498,7 +486,7 @@ void InitFromCypherlFile(memgraph::query::InterpreterContext &ctx, std::string c
|
||||
interpreter.Pull(&stream, {}, results.qid);
|
||||
|
||||
if (audit_log) {
|
||||
audit_log->Record("", "", line, {});
|
||||
audit_log->Record("", "", line, {}, memgraph::dbms::kDefaultDB);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -529,41 +517,146 @@ auto ToQueryExtras(memgraph::communication::bolt::Value const &extra) -> memgrap
|
||||
return memgraph::query::QueryExtras{std::move(metadata_pv), tx_timeout};
|
||||
}
|
||||
|
||||
class BoltSession final : public memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
|
||||
memgraph::communication::v2::OutputStream> {
|
||||
class SessionHL final : public memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
|
||||
memgraph::communication::v2::OutputStream> {
|
||||
public:
|
||||
BoltSession(SessionData *data, const memgraph::communication::v2::ServerEndpoint &endpoint,
|
||||
memgraph::communication::v2::InputStream *input_stream,
|
||||
memgraph::communication::v2::OutputStream *output_stream)
|
||||
struct ContextWrapper {
|
||||
explicit ContextWrapper(memgraph::dbms::SessionContext sc)
|
||||
: session_context(sc),
|
||||
interpreter(std::make_unique<memgraph::query::Interpreter>(session_context.interpreter_context.get())),
|
||||
defunct_(false) {
|
||||
session_context.interpreter_context->interpreters.WithLock(
|
||||
[this](auto &interpreters) { interpreters.insert(interpreter.get()); });
|
||||
}
|
||||
~ContextWrapper() { Defunct(); }
|
||||
|
||||
void Defunct() {
|
||||
if (!defunct_) {
|
||||
session_context.interpreter_context->interpreters.WithLock(
|
||||
[this](auto &interpreters) { interpreters.erase(interpreter.get()); });
|
||||
defunct_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
ContextWrapper(const ContextWrapper &) = delete;
|
||||
ContextWrapper &operator=(const ContextWrapper &) = delete;
|
||||
|
||||
ContextWrapper(ContextWrapper &&in) noexcept
|
||||
: session_context(std::move(in.session_context)),
|
||||
interpreter(std::move(in.interpreter)),
|
||||
defunct_(in.defunct_) {
|
||||
in.defunct_ = true;
|
||||
}
|
||||
|
||||
ContextWrapper &operator=(ContextWrapper &&in) noexcept {
|
||||
if (this != &in) {
|
||||
Defunct();
|
||||
session_context = std::move(in.session_context);
|
||||
interpreter = std::move(in.interpreter);
|
||||
defunct_ = in.defunct_;
|
||||
in.defunct_ = true;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
memgraph::query::InterpreterContext *interpreter_context() { return session_context.interpreter_context.get(); }
|
||||
memgraph::query::Interpreter *interp() { return interpreter.get(); }
|
||||
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth() const {
|
||||
return session_context.auth;
|
||||
}
|
||||
#ifdef MG_ENTERPRISE
|
||||
memgraph::audit::Log *audit_log() const { return session_context.audit_log; }
|
||||
#endif
|
||||
std::string run_id() const { return session_context.run_id; }
|
||||
bool defunct() const { return defunct_; }
|
||||
|
||||
private:
|
||||
memgraph::dbms::SessionContext session_context;
|
||||
std::unique_ptr<memgraph::query::Interpreter> interpreter;
|
||||
bool defunct_;
|
||||
};
|
||||
|
||||
SessionHL(
|
||||
#ifdef MG_ENTERPRISE
|
||||
memgraph::dbms::SessionContextHandler &sc_handler,
|
||||
#else
|
||||
memgraph::dbms::SessionContext sc,
|
||||
#endif
|
||||
const memgraph::communication::v2::ServerEndpoint &endpoint,
|
||||
memgraph::communication::v2::InputStream *input_stream, memgraph::communication::v2::OutputStream *output_stream,
|
||||
const std::string &default_db = memgraph::dbms::kDefaultDB) // NOLINT
|
||||
: memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
|
||||
memgraph::communication::v2::OutputStream>(input_stream, output_stream),
|
||||
interpreter_context_(data->interpreter_context),
|
||||
interpreter_(data->interpreter_context),
|
||||
auth_(data->auth),
|
||||
#if MG_ENTERPRISE
|
||||
audit_log_(data->audit_log),
|
||||
#ifdef MG_ENTERPRISE
|
||||
sc_handler_(sc_handler),
|
||||
current_(sc_handler_.Get(default_db)),
|
||||
#else
|
||||
current_(sc),
|
||||
#endif
|
||||
interpreter_context_(current_.interpreter_context()),
|
||||
interpreter_(current_.interp()),
|
||||
auth_(current_.auth()),
|
||||
#ifdef MG_ENTERPRISE
|
||||
audit_log_(current_.audit_log()),
|
||||
#endif
|
||||
endpoint_(endpoint),
|
||||
run_id_(data->run_id) {
|
||||
run_id_(current_.run_id()) {
|
||||
memgraph::metrics::IncrementCounter(memgraph::metrics::ActiveBoltSessions);
|
||||
interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter_); });
|
||||
}
|
||||
|
||||
~BoltSession() override {
|
||||
memgraph::metrics::DecrementCounter(memgraph::metrics::ActiveBoltSessions);
|
||||
interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.erase(&interpreter_); });
|
||||
~SessionHL() override { memgraph::metrics::DecrementCounter(memgraph::metrics::ActiveBoltSessions); }
|
||||
|
||||
SessionHL(const SessionHL &) = delete;
|
||||
SessionHL &operator=(const SessionHL &) = delete;
|
||||
SessionHL(SessionHL &&) = delete;
|
||||
SessionHL &operator=(SessionHL &&) = delete;
|
||||
|
||||
void Configure(const std::map<std::string, memgraph::communication::bolt::Value> &run_time_info) override {
|
||||
#ifdef MG_ENTERPRISE
|
||||
std::string db;
|
||||
bool update = false;
|
||||
// Check if user explicitly defined the database to use
|
||||
if (run_time_info.contains("db")) {
|
||||
const auto &db_info = run_time_info.at("db");
|
||||
if (!db_info.IsString()) {
|
||||
throw memgraph::communication::bolt::ClientError("Malformed database name.");
|
||||
}
|
||||
db = db_info.ValueString();
|
||||
update = db != current_.interpreter_context()->db->id();
|
||||
in_explicit_db_ = true;
|
||||
// NOTE: Once in a transaction, the drivers stop explicitly sending the db and count on using it until commit
|
||||
} else if (in_explicit_db_ && !interpreter_->in_explicit_transaction_) { // Just on a switch
|
||||
db = GetDefaultDB();
|
||||
update = db != current_.interpreter_context()->db->id();
|
||||
in_explicit_db_ = false;
|
||||
}
|
||||
|
||||
// Check if the underlying database needs to be updated
|
||||
if (update) {
|
||||
sc_handler_.SetInPlace(db, [this](auto new_sc) mutable {
|
||||
const auto &db_name = new_sc.interpreter_context->db->id();
|
||||
MultiDatabaseAuth(db_name);
|
||||
try {
|
||||
Update(ContextWrapper(new_sc));
|
||||
return memgraph::dbms::SetForResult::SUCCESS;
|
||||
} catch (memgraph::dbms::UnknownDatabaseException &e) {
|
||||
throw memgraph::communication::bolt::ClientError("No database named \"{}\" found!", db_name);
|
||||
}
|
||||
});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
using memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
|
||||
memgraph::communication::v2::OutputStream>::TEncoder;
|
||||
using TEncoder = memgraph::communication::bolt::Encoder<
|
||||
memgraph::communication::bolt::ChunkedEncoderBuffer<memgraph::communication::v2::OutputStream>>;
|
||||
|
||||
void BeginTransaction(const std::map<std::string, memgraph::communication::bolt::Value> &extra) override {
|
||||
interpreter_.BeginTransaction(ToQueryExtras(extra));
|
||||
interpreter_->BeginTransaction(ToQueryExtras(extra));
|
||||
}
|
||||
|
||||
void CommitTransaction() override { interpreter_.CommitTransaction(); }
|
||||
void CommitTransaction() override { interpreter_->CommitTransaction(); }
|
||||
|
||||
void RollbackTransaction() override { interpreter_.RollbackTransaction(); }
|
||||
void RollbackTransaction() override { interpreter_->RollbackTransaction(); }
|
||||
|
||||
std::pair<std::vector<std::string>, std::optional<int>> Interpret(
|
||||
const std::string &query, const std::map<std::string, memgraph::communication::bolt::Value> ¶ms,
|
||||
@ -580,16 +673,22 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
|
||||
audit_log_->Record(endpoint_.address().to_string(), user_ ? *username : "", query,
|
||||
memgraph::storage::PropertyValue(params_pv));
|
||||
memgraph::storage::PropertyValue(params_pv), interpreter_context_->db->id());
|
||||
}
|
||||
#endif
|
||||
try {
|
||||
auto result = interpreter_.Prepare(query, params_pv, username, ToQueryExtras(extra));
|
||||
if (user_ && !memgraph::glue::AuthChecker::IsUserAuthorized(*user_, result.privileges)) {
|
||||
interpreter_.Abort();
|
||||
auto result = interpreter_->Prepare(query, params_pv, username, ToQueryExtras(extra), UUID());
|
||||
const std::string db_name = result.db ? *result.db : "";
|
||||
if (user_ && !memgraph::glue::AuthChecker::IsUserAuthorized(*user_, result.privileges, db_name)) {
|
||||
interpreter_->Abort();
|
||||
if (db_name.empty()) {
|
||||
throw memgraph::communication::bolt::ClientError(
|
||||
"You are not authorized to execute this query! Please contact your database administrator.");
|
||||
}
|
||||
throw memgraph::communication::bolt::ClientError(
|
||||
"You are not authorized to execute this query! Please contact "
|
||||
"your database administrator.");
|
||||
"You are not authorized to execute this query on database \"{}\"! Please contact your database "
|
||||
"administrator.",
|
||||
db_name);
|
||||
}
|
||||
return {result.headers, result.qid};
|
||||
|
||||
@ -604,7 +703,7 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
|
||||
|
||||
std::map<std::string, memgraph::communication::bolt::Value> Pull(TEncoder *encoder, std::optional<int> n,
|
||||
std::optional<int> qid) override {
|
||||
TypedValueResultStream stream(encoder, interpreter_context_->db.get());
|
||||
TypedValueResultStream stream(encoder, interpreter_context_);
|
||||
return PullResults(stream, n, qid);
|
||||
}
|
||||
|
||||
@ -614,14 +713,26 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
|
||||
return PullResults(stream, n, qid);
|
||||
}
|
||||
|
||||
void Abort() override { interpreter_.Abort(); }
|
||||
void Abort() override { interpreter_->Abort(); }
|
||||
|
||||
// Called during Init
|
||||
// During Init, the user cannot choose the landing DB (switch is done during query execution)
|
||||
bool Authenticate(const std::string &username, const std::string &password) override {
|
||||
auto locked_auth = auth_->Lock();
|
||||
if (!locked_auth->HasUsers()) {
|
||||
return true;
|
||||
}
|
||||
user_ = locked_auth->Authenticate(username, password);
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (user_.has_value()) {
|
||||
const auto &db = user_->db_access().GetDefault();
|
||||
// Check if the underlying database needs to be updated
|
||||
if (db != current_.interpreter_context()->db->id()) {
|
||||
const auto &res = sc_handler_.SetFor(UUID(), db);
|
||||
return res == memgraph::dbms::SetForResult::SUCCESS || res == memgraph::dbms::SetForResult::ALREADY_SET;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return user_.has_value();
|
||||
}
|
||||
|
||||
@ -630,12 +741,31 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
|
||||
return FLAGS_bolt_server_name_for_init;
|
||||
}
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
memgraph::dbms::SetForResult OnChange(const std::string &db_name) override {
|
||||
MultiDatabaseAuth(db_name);
|
||||
if (db_name != current_.interpreter_context()->db->id()) {
|
||||
UpdateAndDefunct(db_name); // Done during Pull, so we cannot just replace the current db
|
||||
return memgraph::dbms::SetForResult::SUCCESS;
|
||||
}
|
||||
return memgraph::dbms::SetForResult::ALREADY_SET;
|
||||
}
|
||||
|
||||
bool OnDelete(const std::string &db_name) override {
|
||||
MG_ASSERT(current_.interpreter_context()->db->id() != db_name && (!defunct_ || defunct_->defunct()),
|
||||
"Trying to delete a database while still in use.");
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
std::string GetDatabaseName() const override { return interpreter_context_->db->id(); }
|
||||
|
||||
private:
|
||||
template <typename TStream>
|
||||
std::map<std::string, memgraph::communication::bolt::Value> PullResults(TStream &stream, std::optional<int> n,
|
||||
std::optional<int> qid) {
|
||||
try {
|
||||
const auto &summary = interpreter_.Pull(&stream, n, qid);
|
||||
const auto &summary = interpreter_->Pull(&stream, n, qid);
|
||||
std::map<std::string, memgraph::communication::bolt::Value> decoded_summary;
|
||||
for (const auto &kv : summary) {
|
||||
auto maybe_value =
|
||||
@ -660,6 +790,11 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
|
||||
decoded_summary.emplace("run_id", *run_id);
|
||||
}
|
||||
|
||||
// Clean up previous session (session gets defunct when switching between databases)
|
||||
if (defunct_) {
|
||||
defunct_.reset();
|
||||
}
|
||||
|
||||
return decoded_summary;
|
||||
} catch (const memgraph::query::QueryException &e) {
|
||||
// Wrap QueryException into ClientError, because we want to allow the
|
||||
@ -668,17 +803,71 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
/**
|
||||
* @brief Update setup to the new database.
|
||||
*
|
||||
* @param db_name name of the target database
|
||||
* @throws UnknownDatabaseException if handler cannot get it
|
||||
*/
|
||||
void UpdateAndDefunct(const std::string &db_name) { UpdateAndDefunct(ContextWrapper(sc_handler_.Get(db_name))); }
|
||||
|
||||
void UpdateAndDefunct(ContextWrapper &&cntxt) {
|
||||
defunct_.emplace(std::move(current_));
|
||||
Update(std::forward<ContextWrapper>(cntxt));
|
||||
defunct_->Defunct();
|
||||
}
|
||||
|
||||
void Update(const std::string &db_name) {
|
||||
ContextWrapper tmp(sc_handler_.Get(db_name));
|
||||
Update(std::move(tmp));
|
||||
}
|
||||
|
||||
void Update(ContextWrapper &&cntxt) {
|
||||
current_ = std::move(cntxt);
|
||||
interpreter_ = current_.interp();
|
||||
interpreter_->in_explicit_db_ = in_explicit_db_;
|
||||
interpreter_context_ = current_.interpreter_context();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Authenticate user on passed database.
|
||||
*
|
||||
* @param db database to check against
|
||||
* @throws bolt::ClientError when user is not authorized
|
||||
*/
|
||||
void MultiDatabaseAuth(const std::string &db) {
|
||||
if (user_ && !memgraph::glue::AuthChecker::IsUserAuthorized(*user_, {}, db)) {
|
||||
throw memgraph::communication::bolt::ClientError(
|
||||
"You are not authorized on the database \"{}\"! Please contact your database administrator.", db);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the user's default database
|
||||
*
|
||||
* @return std::string
|
||||
*/
|
||||
std::string GetDefaultDB() {
|
||||
if (user_.has_value()) {
|
||||
return user_->db_access().GetDefault();
|
||||
}
|
||||
return memgraph::dbms::kDefaultDB;
|
||||
}
|
||||
#endif
|
||||
|
||||
/// Wrapper around TEncoder which converts TypedValue to Value
|
||||
/// before forwarding the calls to original TEncoder.
|
||||
class TypedValueResultStream {
|
||||
public:
|
||||
TypedValueResultStream(TEncoder *encoder, const memgraph::storage::Storage *db) : encoder_(encoder), db_(db) {}
|
||||
TypedValueResultStream(TEncoder *encoder, memgraph::query::InterpreterContext *ic)
|
||||
: encoder_(encoder), interpreter_context_(ic) {}
|
||||
|
||||
void Result(const std::vector<memgraph::query::TypedValue> &values) {
|
||||
std::vector<memgraph::communication::bolt::Value> decoded_values;
|
||||
decoded_values.reserve(values.size());
|
||||
for (const auto &v : values) {
|
||||
auto maybe_value = memgraph::glue::ToBoltValue(v, *db_, memgraph::storage::View::NEW);
|
||||
auto maybe_value = memgraph::glue::ToBoltValue(v, *interpreter_context_->db, memgraph::storage::View::NEW);
|
||||
if (maybe_value.HasError()) {
|
||||
switch (maybe_value.GetError()) {
|
||||
case memgraph::storage::Error::DELETED_OBJECT:
|
||||
@ -699,25 +888,36 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
|
||||
private:
|
||||
TEncoder *encoder_;
|
||||
// NOTE: Needed only for ToBoltValue conversions
|
||||
const memgraph::storage::Storage *db_;
|
||||
memgraph::query::InterpreterContext *interpreter_context_;
|
||||
};
|
||||
|
||||
// NOTE: Needed only for ToBoltValue conversions
|
||||
#ifdef MG_ENTERPRISE
|
||||
memgraph::dbms::SessionContextHandler &sc_handler_;
|
||||
#endif
|
||||
ContextWrapper current_;
|
||||
std::optional<ContextWrapper> defunct_;
|
||||
|
||||
memgraph::query::InterpreterContext *interpreter_context_;
|
||||
memgraph::query::Interpreter interpreter_;
|
||||
memgraph::query::Interpreter *interpreter_;
|
||||
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
|
||||
std::optional<memgraph::auth::User> user_;
|
||||
#ifdef MG_ENTERPRISE
|
||||
memgraph::audit::Log *audit_log_;
|
||||
bool in_explicit_db_{false}; //!< If true, the user has defined the database to use via metadata
|
||||
#endif
|
||||
memgraph::communication::v2::ServerEndpoint endpoint_;
|
||||
// NOTE: run_id should be const but that complicates code a lot.
|
||||
std::optional<std::string> run_id_;
|
||||
};
|
||||
|
||||
using ServerT = memgraph::communication::v2::Server<BoltSession, SessionData>;
|
||||
#ifdef MG_ENTERPRISE
|
||||
using ServerT = memgraph::communication::v2::Server<SessionHL, memgraph::dbms::SessionContextHandler>;
|
||||
#else
|
||||
using ServerT = memgraph::communication::v2::Server<SessionHL, memgraph::dbms::SessionContext>;
|
||||
#endif
|
||||
using MonitoringServerT =
|
||||
memgraph::communication::http::Server<memgraph::http::MetricsRequestHandler<SessionData>, SessionData>;
|
||||
memgraph::communication::http::Server<memgraph::http::MetricsRequestHandler<memgraph::dbms::SessionContext>,
|
||||
memgraph::dbms::SessionContext>;
|
||||
using memgraph::communication::ServerContext;
|
||||
|
||||
// Needed to correctly handle memgraph destruction from a signal handler.
|
||||
@ -880,10 +1080,6 @@ int main(int argc, char **argv) {
|
||||
|
||||
// Begin enterprise features initialization
|
||||
|
||||
// Auth
|
||||
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth{data_directory /
|
||||
"auth"};
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
// Audit log
|
||||
memgraph::audit::Log audit_log{data_directory / "audit", FLAGS_audit_buffer_size,
|
||||
@ -907,7 +1103,7 @@ int main(int argc, char **argv) {
|
||||
.interval = std::chrono::seconds(FLAGS_storage_gc_cycle_sec)},
|
||||
.items = {.properties_on_edges = FLAGS_storage_properties_on_edges},
|
||||
.durability = {.storage_directory = FLAGS_data_directory,
|
||||
.recover_on_startup = FLAGS_storage_recover_on_startup,
|
||||
.recover_on_startup = FLAGS_storage_recover_on_startup || FLAGS_data_recovery_on_startup,
|
||||
.snapshot_retention_count = FLAGS_storage_snapshot_retention_count,
|
||||
.wal_file_size_kibibytes = FLAGS_storage_wal_file_size_kib,
|
||||
.wal_file_flush_every_n_tx = FLAGS_storage_wal_file_flush_every_n_tx,
|
||||
@ -944,31 +1140,62 @@ int main(int argc, char **argv) {
|
||||
db_config.durability.snapshot_interval = std::chrono::seconds(FLAGS_storage_snapshot_interval_sec);
|
||||
}
|
||||
|
||||
memgraph::query::InterpreterContext interpreter_context{
|
||||
db_config,
|
||||
{.query = {.allow_load_csv = FLAGS_allow_load_csv},
|
||||
.execution_timeout_sec = FLAGS_query_execution_timeout_sec,
|
||||
.replication_replica_check_frequency = std::chrono::seconds(FLAGS_replication_replica_check_frequency_sec),
|
||||
.default_kafka_bootstrap_servers = FLAGS_kafka_bootstrap_servers,
|
||||
.default_pulsar_service_url = FLAGS_pulsar_service_url,
|
||||
.stream_transaction_conflict_retries = FLAGS_stream_transaction_conflict_retries,
|
||||
.stream_transaction_retry_interval = std::chrono::milliseconds(FLAGS_stream_transaction_retry_interval)},
|
||||
FLAGS_data_directory};
|
||||
// Default interpreter configuration
|
||||
memgraph::query::InterpreterConfig interp_config{
|
||||
.query = {.allow_load_csv = FLAGS_allow_load_csv},
|
||||
.execution_timeout_sec = FLAGS_query_execution_timeout_sec,
|
||||
.replication_replica_check_frequency = std::chrono::seconds(FLAGS_replication_replica_check_frequency_sec),
|
||||
.default_kafka_bootstrap_servers = FLAGS_kafka_bootstrap_servers,
|
||||
.default_pulsar_service_url = FLAGS_pulsar_service_url,
|
||||
.stream_transaction_conflict_retries = FLAGS_stream_transaction_conflict_retries,
|
||||
.stream_transaction_retry_interval = std::chrono::milliseconds(FLAGS_stream_transaction_retry_interval)};
|
||||
|
||||
auto auth_glue =
|
||||
[flag = FLAGS_auth_user_or_role_name_regex](
|
||||
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth,
|
||||
std::unique_ptr<memgraph::query::AuthQueryHandler> &ah, std::unique_ptr<memgraph::query::AuthChecker> &ac) {
|
||||
// Glue high level auth implementations to the query side
|
||||
ah = std::make_unique<memgraph::glue::AuthQueryHandler>(auth, flag);
|
||||
ac = std::make_unique<memgraph::glue::AuthChecker>(auth);
|
||||
// Handle users passed via arguments
|
||||
auto *maybe_username = std::getenv(kMgUser);
|
||||
auto *maybe_password = std::getenv(kMgPassword);
|
||||
auto *maybe_pass_file = std::getenv(kMgPassfile);
|
||||
if (maybe_username && maybe_password) {
|
||||
ah->CreateUser(maybe_username, maybe_password);
|
||||
} else if (maybe_pass_file) {
|
||||
const auto [username, password] = LoadUsernameAndPassword(maybe_pass_file);
|
||||
if (!username.empty() && !password.empty()) {
|
||||
ah->CreateUser(username, password);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
SessionData session_data{&interpreter_context, &auth, &audit_log};
|
||||
// SessionContext handler (multi-tenancy)
|
||||
memgraph::dbms::SessionContextHandler sc_handler(audit_log, {db_config, interp_config, auth_glue},
|
||||
FLAGS_storage_recover_on_startup || FLAGS_data_recovery_on_startup,
|
||||
FLAGS_storage_delete_on_drop);
|
||||
// Just for current support... TODO remove
|
||||
auto session_context = sc_handler.Get(memgraph::dbms::kDefaultDB);
|
||||
#else
|
||||
SessionData session_data{&interpreter_context, &auth};
|
||||
|
||||
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth_{data_directory /
|
||||
"auth"};
|
||||
std::unique_ptr<memgraph::query::AuthQueryHandler> auth_handler;
|
||||
std::unique_ptr<memgraph::query::AuthChecker> auth_checker;
|
||||
auth_glue(&auth_, auth_handler, auth_checker);
|
||||
auto session_context = memgraph::dbms::Init(db_config, interp_config, &auth_, auth_handler.get(), auth_checker.get());
|
||||
|
||||
#endif
|
||||
|
||||
auto *auth = session_context.auth;
|
||||
auto &interpreter_context = *session_context.interpreter_context; // TODO remove
|
||||
|
||||
memgraph::query::procedure::gModuleRegistry.SetModulesDirectory(query_modules_directories, FLAGS_data_directory);
|
||||
memgraph::query::procedure::gModuleRegistry.UnloadAndLoadModulesFromDirectories();
|
||||
memgraph::query::procedure::gCallableAliasMapper.LoadMapping(FLAGS_query_callable_mappings_path);
|
||||
|
||||
memgraph::glue::AuthQueryHandler auth_handler(&auth, FLAGS_auth_user_or_role_name_regex);
|
||||
memgraph::glue::AuthChecker auth_checker{&auth};
|
||||
interpreter_context.auth = &auth_handler;
|
||||
interpreter_context.auth_checker = &auth_checker;
|
||||
|
||||
if (!FLAGS_init_file.empty()) {
|
||||
spdlog::info("Running init file...");
|
||||
#ifdef MG_ENTERPRISE
|
||||
@ -982,18 +1209,10 @@ int main(int argc, char **argv) {
|
||||
#endif
|
||||
}
|
||||
|
||||
auto *maybe_username = std::getenv(kMgUser);
|
||||
auto *maybe_password = std::getenv(kMgPassword);
|
||||
auto *maybe_pass_file = std::getenv(kMgPassfile);
|
||||
if (maybe_username && maybe_password) {
|
||||
auth_handler.CreateUser(maybe_username, maybe_password);
|
||||
} else if (maybe_pass_file) {
|
||||
const auto [username, password] = LoadUsernameAndPassword(maybe_pass_file);
|
||||
if (!username.empty() && !password.empty()) {
|
||||
auth_handler.CreateUser(username, password);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
sc_handler.RestoreTriggers();
|
||||
sc_handler.RestoreStreams();
|
||||
#else
|
||||
{
|
||||
// Triggers can execute query procedures, so we need to reload the modules first and then
|
||||
// the triggers
|
||||
@ -1005,6 +1224,7 @@ int main(int argc, char **argv) {
|
||||
|
||||
// As the Stream transformations are using modules, they have to be restored after the query modules are loaded.
|
||||
interpreter_context.streams.RestoreStreams();
|
||||
#endif
|
||||
|
||||
ServerContext context;
|
||||
std::string service_name = "Bolt";
|
||||
@ -1016,25 +1236,35 @@ int main(int argc, char **argv) {
|
||||
spdlog::warn(
|
||||
memgraph::utils::MessageWithLink("Using non-secure Bolt connection (without SSL).", "https://memgr.ph/ssl"));
|
||||
}
|
||||
|
||||
auto server_endpoint = memgraph::communication::v2::ServerEndpoint{
|
||||
boost::asio::ip::address::from_string(FLAGS_bolt_address), static_cast<uint16_t>(FLAGS_bolt_port)};
|
||||
ServerT server(server_endpoint, &session_data, &context, FLAGS_bolt_session_inactivity_timeout, service_name,
|
||||
#ifdef MG_ENTERPRISE
|
||||
ServerT server(server_endpoint, &sc_handler, &context, FLAGS_bolt_session_inactivity_timeout, service_name,
|
||||
FLAGS_bolt_num_workers);
|
||||
#else
|
||||
ServerT server(server_endpoint, &session_context, &context, FLAGS_bolt_session_inactivity_timeout, service_name,
|
||||
FLAGS_bolt_num_workers);
|
||||
#endif
|
||||
|
||||
const auto run_id = memgraph::utils::GenerateUUID();
|
||||
const auto machine_id = memgraph::utils::GetMachineId();
|
||||
session_data.run_id = run_id;
|
||||
const auto run_id = session_context.run_id; // For current compatibility
|
||||
|
||||
// Setup telemetry
|
||||
static constexpr auto telemetry_server{"https://telemetry.memgraph.com/88b5e7e8-746a-11e8-9f85-538a9e9690cc/"};
|
||||
std::optional<memgraph::telemetry::Telemetry> telemetry;
|
||||
if (FLAGS_telemetry_enabled) {
|
||||
telemetry.emplace(telemetry_server, data_directory / "telemetry", run_id, machine_id, std::chrono::minutes(10));
|
||||
telemetry->AddCollector("storage", [db_ = interpreter_context.db.get()]() -> nlohmann::json {
|
||||
auto info = db_->GetInfo();
|
||||
#ifdef MG_ENTERPRISE
|
||||
telemetry->AddCollector("storage", [&sc_handler]() -> nlohmann::json {
|
||||
const auto &info = sc_handler.Info();
|
||||
return {{"vertices", info.num_vertex}, {"edges", info.num_edges}, {"databases", info.num_databases}};
|
||||
});
|
||||
#else
|
||||
telemetry->AddCollector("storage", [&interpreter_context]() -> nlohmann::json {
|
||||
auto info = interpreter_context.db->GetInfo();
|
||||
return {{"vertices", info.vertex_count}, {"edges", info.edge_count}};
|
||||
});
|
||||
#endif
|
||||
telemetry->AddCollector("event_counters", []() -> nlohmann::json {
|
||||
nlohmann::json ret;
|
||||
for (size_t i = 0; i < memgraph::metrics::CounterEnd(); ++i) {
|
||||
@ -1050,25 +1280,25 @@ int main(int argc, char **argv) {
|
||||
memgraph::license::LicenseInfoSender license_info_sender(telemetry_server, run_id, machine_id, memory_limit,
|
||||
memgraph::license::global_license_checker.GetLicenseInfo());
|
||||
|
||||
memgraph::communication::websocket::SafeAuth websocket_auth{&auth};
|
||||
memgraph::communication::websocket::SafeAuth websocket_auth{auth};
|
||||
memgraph::communication::websocket::Server websocket_server{
|
||||
{FLAGS_monitoring_address, static_cast<uint16_t>(FLAGS_monitoring_port)}, &context, websocket_auth};
|
||||
AddLoggerSink(websocket_server.GetLoggingSink());
|
||||
|
||||
MonitoringServerT metrics_server{
|
||||
{FLAGS_metrics_address, static_cast<uint16_t>(FLAGS_metrics_port)}, &session_data, &context};
|
||||
{FLAGS_metrics_address, static_cast<uint16_t>(FLAGS_metrics_port)}, &session_context, &context};
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
|
||||
// Handler for regular termination signals
|
||||
auto shutdown = [&metrics_server, &websocket_server, &server, &interpreter_context] {
|
||||
auto shutdown = [&metrics_server, &websocket_server, &server, &sc_handler] {
|
||||
// Server needs to be shutdown first and then the database. This prevents
|
||||
// a race condition when a transaction is accepted during server shutdown.
|
||||
server.Shutdown();
|
||||
// After the server is notified to stop accepting and processing
|
||||
// connections we tell the execution engine to stop processing all pending
|
||||
// queries.
|
||||
memgraph::query::Shutdown(&interpreter_context);
|
||||
sc_handler.Shutdown();
|
||||
|
||||
websocket_server.Shutdown();
|
||||
metrics_server.Shutdown();
|
||||
|
@ -24,7 +24,8 @@ class AuthChecker {
|
||||
virtual ~AuthChecker() = default;
|
||||
|
||||
[[nodiscard]] virtual bool IsUserAuthorized(const std::optional<std::string> &username,
|
||||
const std::vector<query::AuthQuery::Privilege> &privileges) const = 0;
|
||||
const std::vector<query::AuthQuery::Privilege> &privileges,
|
||||
const std::string &db_name) const = 0;
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
[[nodiscard]] virtual std::unique_ptr<FineGrainedAuthChecker> GetFineGrainedAuthChecker(
|
||||
@ -92,7 +93,8 @@ class AllowEverythingFineGrainedAuthChecker final : public query::FineGrainedAut
|
||||
class AllowEverythingAuthChecker final : public query::AuthChecker {
|
||||
public:
|
||||
bool IsUserAuthorized(const std::optional<std::string> & /*username*/,
|
||||
const std::vector<query::AuthQuery::Privilege> & /*privileges*/) const override {
|
||||
const std::vector<query::AuthQuery::Privilege> & /*privileges*/,
|
||||
const std::string & /*db*/) const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -495,6 +495,8 @@ class DbAccessor final {
|
||||
storage::IndicesInfo ListAllIndices() const { return accessor_->ListAllIndices(); }
|
||||
|
||||
storage::ConstraintsInfo ListAllConstraints() const { return accessor_->ListAllConstraints(); }
|
||||
|
||||
const std::string &id() const { return accessor_->id(); }
|
||||
};
|
||||
|
||||
class SubgraphDbAccessor final {
|
||||
|
@ -324,4 +324,10 @@ class ConstraintsPersistenceException : public QueryException {
|
||||
ConstraintsPersistenceException() : QueryException("Persisting constraints on disk failed.") {}
|
||||
};
|
||||
|
||||
class MultiDatabaseQueryInMulticommandTxException : public QueryException {
|
||||
public:
|
||||
MultiDatabaseQueryInMulticommandTxException()
|
||||
: QueryException("Multi-database queries are not allowed in multicommand transactions.") {}
|
||||
};
|
||||
|
||||
} // namespace memgraph::query
|
||||
|
@ -279,4 +279,10 @@ constexpr utils::TypeInfo query::Exists::kType{utils::TypeId::AST_EXISTS, "Exist
|
||||
|
||||
constexpr utils::TypeInfo query::CallSubquery::kType{utils::TypeId::AST_CALL_SUBQUERY, "CallSubquery",
|
||||
&query::Clause::kType};
|
||||
|
||||
constexpr utils::TypeInfo query::MultiDatabaseQuery::kType{utils::TypeId::AST_MULTI_DATABASE_QUERY,
|
||||
"MultiDatabaseQuery", &query::Query::kType};
|
||||
|
||||
constexpr utils::TypeInfo query::ShowDatabasesQuery::kType{utils::TypeId::AST_SHOW_DATABASES, "ShowDatabasesQuery",
|
||||
&query::Query::kType};
|
||||
} // namespace memgraph
|
||||
|
@ -2778,7 +2778,11 @@ class AuthQuery : public memgraph::query::Query {
|
||||
REVOKE_PRIVILEGE,
|
||||
SHOW_PRIVILEGES,
|
||||
SHOW_ROLE_FOR_USER,
|
||||
SHOW_USERS_FOR_ROLE
|
||||
SHOW_USERS_FOR_ROLE,
|
||||
GRANT_DATABASE_TO_USER,
|
||||
REVOKE_DATABASE_FROM_USER,
|
||||
SHOW_DATABASE_PRIVILEGES,
|
||||
SET_MAIN_DATABASE,
|
||||
};
|
||||
|
||||
enum class Privilege {
|
||||
@ -2804,7 +2808,9 @@ class AuthQuery : public memgraph::query::Query {
|
||||
MODULE_WRITE,
|
||||
WEBSOCKET,
|
||||
STORAGE_MODE,
|
||||
TRANSACTION_MANAGEMENT
|
||||
TRANSACTION_MANAGEMENT,
|
||||
MULTI_DATABASE_EDIT,
|
||||
MULTI_DATABASE_USE,
|
||||
};
|
||||
|
||||
enum class FineGrainedPrivilege { NOTHING, READ, UPDATE, CREATE_DELETE };
|
||||
@ -2818,6 +2824,7 @@ class AuthQuery : public memgraph::query::Query {
|
||||
std::string role_;
|
||||
std::string user_or_role_;
|
||||
memgraph::query::Expression *password_{nullptr};
|
||||
std::string database_;
|
||||
std::vector<memgraph::query::AuthQuery::Privilege> privileges_;
|
||||
std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
|
||||
label_privileges_;
|
||||
@ -2831,6 +2838,7 @@ class AuthQuery : public memgraph::query::Query {
|
||||
object->role_ = role_;
|
||||
object->user_or_role_ = user_or_role_;
|
||||
object->password_ = password_ ? password_->Clone(storage) : nullptr;
|
||||
object->database_ = database_;
|
||||
object->privileges_ = privileges_;
|
||||
object->label_privileges_ = label_privileges_;
|
||||
object->edge_type_privileges_ = edge_type_privileges_;
|
||||
@ -2839,7 +2847,7 @@ class AuthQuery : public memgraph::query::Query {
|
||||
|
||||
protected:
|
||||
AuthQuery(Action action, std::string user, std::string role, std::string user_or_role, Expression *password,
|
||||
std::vector<Privilege> privileges,
|
||||
std::string database, std::vector<Privilege> privileges,
|
||||
std::vector<std::unordered_map<FineGrainedPrivilege, std::vector<std::string>>> label_privileges,
|
||||
std::vector<std::unordered_map<FineGrainedPrivilege, std::vector<std::string>>> edge_type_privileges)
|
||||
: action_(action),
|
||||
@ -2847,6 +2855,7 @@ class AuthQuery : public memgraph::query::Query {
|
||||
role_(role),
|
||||
user_or_role_(user_or_role),
|
||||
password_(password),
|
||||
database_(database),
|
||||
privileges_(privileges),
|
||||
label_privileges_(label_privileges),
|
||||
edge_type_privileges_(edge_type_privileges) {}
|
||||
@ -2856,19 +2865,31 @@ class AuthQuery : public memgraph::query::Query {
|
||||
};
|
||||
|
||||
/// Constant that holds all available privileges.
|
||||
const std::vector<AuthQuery::Privilege> kPrivilegesAll = {
|
||||
AuthQuery::Privilege::CREATE, AuthQuery::Privilege::DELETE,
|
||||
AuthQuery::Privilege::MATCH, AuthQuery::Privilege::MERGE,
|
||||
AuthQuery::Privilege::SET, AuthQuery::Privilege::REMOVE,
|
||||
AuthQuery::Privilege::INDEX, AuthQuery::Privilege::STATS,
|
||||
AuthQuery::Privilege::AUTH, AuthQuery::Privilege::CONSTRAINT,
|
||||
AuthQuery::Privilege::DUMP, AuthQuery::Privilege::REPLICATION,
|
||||
AuthQuery::Privilege::READ_FILE, AuthQuery::Privilege::DURABILITY,
|
||||
AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER,
|
||||
AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM,
|
||||
AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE,
|
||||
AuthQuery::Privilege::WEBSOCKET, AuthQuery::Privilege::TRANSACTION_MANAGEMENT,
|
||||
AuthQuery::Privilege::STORAGE_MODE};
|
||||
const std::vector<AuthQuery::Privilege> kPrivilegesAll = {AuthQuery::Privilege::CREATE,
|
||||
AuthQuery::Privilege::DELETE,
|
||||
AuthQuery::Privilege::MATCH,
|
||||
AuthQuery::Privilege::MERGE,
|
||||
AuthQuery::Privilege::SET,
|
||||
AuthQuery::Privilege::REMOVE,
|
||||
AuthQuery::Privilege::INDEX,
|
||||
AuthQuery::Privilege::STATS,
|
||||
AuthQuery::Privilege::AUTH,
|
||||
AuthQuery::Privilege::CONSTRAINT,
|
||||
AuthQuery::Privilege::DUMP,
|
||||
AuthQuery::Privilege::REPLICATION,
|
||||
AuthQuery::Privilege::READ_FILE,
|
||||
AuthQuery::Privilege::DURABILITY,
|
||||
AuthQuery::Privilege::FREE_MEMORY,
|
||||
AuthQuery::Privilege::TRIGGER,
|
||||
AuthQuery::Privilege::CONFIG,
|
||||
AuthQuery::Privilege::STREAM,
|
||||
AuthQuery::Privilege::MODULE_READ,
|
||||
AuthQuery::Privilege::MODULE_WRITE,
|
||||
AuthQuery::Privilege::WEBSOCKET,
|
||||
AuthQuery::Privilege::TRANSACTION_MANAGEMENT,
|
||||
AuthQuery::Privilege::STORAGE_MODE,
|
||||
AuthQuery::Privilege::MULTI_DATABASE_EDIT,
|
||||
AuthQuery::Privilege::MULTI_DATABASE_USE};
|
||||
|
||||
class InfoQuery : public memgraph::query::Query {
|
||||
public:
|
||||
@ -3446,5 +3467,38 @@ class CallSubquery : public memgraph::query::Clause {
|
||||
friend class AstStorage;
|
||||
};
|
||||
|
||||
class MultiDatabaseQuery : public memgraph::query::Query {
|
||||
public:
|
||||
static const utils::TypeInfo kType;
|
||||
const utils::TypeInfo &GetTypeInfo() const override { return kType; }
|
||||
|
||||
DEFVISITABLE(QueryVisitor<void>);
|
||||
|
||||
enum class Action { CREATE, USE, DROP };
|
||||
|
||||
memgraph::query::MultiDatabaseQuery::Action action_;
|
||||
std::string db_name_;
|
||||
|
||||
MultiDatabaseQuery *Clone(AstStorage *storage) const override {
|
||||
auto *object = storage->Create<MultiDatabaseQuery>();
|
||||
object->action_ = action_;
|
||||
object->db_name_ = db_name_;
|
||||
return object;
|
||||
}
|
||||
};
|
||||
|
||||
class ShowDatabasesQuery : public memgraph::query::Query {
|
||||
public:
|
||||
static const utils::TypeInfo kType;
|
||||
const utils::TypeInfo &GetTypeInfo() const override { return kType; }
|
||||
|
||||
DEFVISITABLE(QueryVisitor<void>);
|
||||
|
||||
ShowDatabasesQuery *Clone(AstStorage *storage) const override {
|
||||
auto *object = storage->Create<ShowDatabasesQuery>();
|
||||
return object;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace query
|
||||
} // namespace memgraph
|
||||
|
@ -11,6 +11,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "utils/visitor.hpp"
|
||||
|
||||
namespace memgraph::query {
|
||||
@ -102,6 +103,8 @@ class CallSubquery;
|
||||
class AnalyzeGraphQuery;
|
||||
class TransactionQueueQuery;
|
||||
class Exists;
|
||||
class MultiDatabaseQuery;
|
||||
class ShowDatabasesQuery;
|
||||
|
||||
using TreeCompositeVisitor = utils::CompositeVisitor<
|
||||
SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator,
|
||||
@ -139,6 +142,7 @@ class QueryVisitor
|
||||
: public utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery, InfoQuery,
|
||||
ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery, FreeMemoryQuery, TriggerQuery,
|
||||
IsolationLevelQuery, CreateSnapshotQuery, StreamQuery, SettingQuery, VersionQuery,
|
||||
ShowConfigQuery, TransactionQueueQuery, StorageModeQuery, AnalyzeGraphQuery> {};
|
||||
ShowConfigQuery, TransactionQueueQuery, StorageModeQuery, AnalyzeGraphQuery,
|
||||
MultiDatabaseQuery, ShowDatabasesQuery> {};
|
||||
|
||||
} // namespace memgraph::query
|
||||
|
@ -1272,7 +1272,7 @@ antlrcpp::Any CypherMainVisitor::visitUserOrRoleName(MemgraphCypher::UserOrRoleN
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitCreateRole(MemgraphCypher::CreateRoleContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::CREATE_ROLE;
|
||||
auth->role_ = std::any_cast<std::string>(ctx->role->accept(this));
|
||||
return auth;
|
||||
@ -1282,7 +1282,7 @@ antlrcpp::Any CypherMainVisitor::visitCreateRole(MemgraphCypher::CreateRoleConte
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitDropRole(MemgraphCypher::DropRoleContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::DROP_ROLE;
|
||||
auth->role_ = std::any_cast<std::string>(ctx->role->accept(this));
|
||||
return auth;
|
||||
@ -1292,7 +1292,7 @@ antlrcpp::Any CypherMainVisitor::visitDropRole(MemgraphCypher::DropRoleContext *
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitShowRoles(MemgraphCypher::ShowRolesContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::SHOW_ROLES;
|
||||
return auth;
|
||||
}
|
||||
@ -1301,7 +1301,7 @@ antlrcpp::Any CypherMainVisitor::visitShowRoles(MemgraphCypher::ShowRolesContext
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitCreateUser(MemgraphCypher::CreateUserContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::CREATE_USER;
|
||||
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
|
||||
if (ctx->password) {
|
||||
@ -1317,7 +1317,7 @@ antlrcpp::Any CypherMainVisitor::visitCreateUser(MemgraphCypher::CreateUserConte
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitSetPassword(MemgraphCypher::SetPasswordContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::SET_PASSWORD;
|
||||
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
|
||||
if (!ctx->password->StringLiteral() && !ctx->literal()->CYPHERNULL()) {
|
||||
@ -1331,7 +1331,7 @@ antlrcpp::Any CypherMainVisitor::visitSetPassword(MemgraphCypher::SetPasswordCon
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitDropUser(MemgraphCypher::DropUserContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::DROP_USER;
|
||||
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
|
||||
return auth;
|
||||
@ -1341,7 +1341,7 @@ antlrcpp::Any CypherMainVisitor::visitDropUser(MemgraphCypher::DropUserContext *
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitShowUsers(MemgraphCypher::ShowUsersContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::SHOW_USERS;
|
||||
return auth;
|
||||
}
|
||||
@ -1350,7 +1350,7 @@ antlrcpp::Any CypherMainVisitor::visitShowUsers(MemgraphCypher::ShowUsersContext
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitSetRole(MemgraphCypher::SetRoleContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::SET_ROLE;
|
||||
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
|
||||
auth->role_ = std::any_cast<std::string>(ctx->role->accept(this));
|
||||
@ -1361,7 +1361,7 @@ antlrcpp::Any CypherMainVisitor::visitSetRole(MemgraphCypher::SetRoleContext *ct
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitClearRole(MemgraphCypher::ClearRoleContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::CLEAR_ROLE;
|
||||
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
|
||||
return auth;
|
||||
@ -1371,7 +1371,7 @@ antlrcpp::Any CypherMainVisitor::visitClearRole(MemgraphCypher::ClearRoleContext
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitGrantPrivilege(MemgraphCypher::GrantPrivilegeContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::GRANT_PRIVILEGE;
|
||||
auth->user_or_role_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
|
||||
if (ctx->grantPrivilegesList()) {
|
||||
@ -1393,7 +1393,7 @@ antlrcpp::Any CypherMainVisitor::visitGrantPrivilege(MemgraphCypher::GrantPrivil
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitDenyPrivilege(MemgraphCypher::DenyPrivilegeContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::DENY_PRIVILEGE;
|
||||
auth->user_or_role_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
|
||||
if (ctx->privilegesList()) {
|
||||
@ -1453,7 +1453,7 @@ antlrcpp::Any CypherMainVisitor::visitGrantPrivilegesList(MemgraphCypher::GrantP
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitRevokePrivilege(MemgraphCypher::RevokePrivilegeContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::REVOKE_PRIVILEGE;
|
||||
auth->user_or_role_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
|
||||
if (ctx->revokePrivilegesList()) {
|
||||
@ -1526,6 +1526,16 @@ antlrcpp::Any CypherMainVisitor::visitEntitiesList(MemgraphCypher::EntitiesListC
|
||||
return entities;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return std::string
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitWildcardName(MemgraphCypher::WildcardNameContext *ctx) {
|
||||
if (ctx->symbolicName()) {
|
||||
return ctx->symbolicName()->accept(this);
|
||||
}
|
||||
return std::string("*");
|
||||
}
|
||||
|
||||
/**
|
||||
* @return AuthQuery::Privilege
|
||||
*/
|
||||
@ -1553,6 +1563,8 @@ antlrcpp::Any CypherMainVisitor::visitPrivilege(MemgraphCypher::PrivilegeContext
|
||||
if (ctx->WEBSOCKET()) return AuthQuery::Privilege::WEBSOCKET;
|
||||
if (ctx->TRANSACTION_MANAGEMENT()) return AuthQuery::Privilege::TRANSACTION_MANAGEMENT;
|
||||
if (ctx->STORAGE_MODE()) return AuthQuery::Privilege::STORAGE_MODE;
|
||||
if (ctx->MULTI_DATABASE_EDIT()) return AuthQuery::Privilege::MULTI_DATABASE_EDIT;
|
||||
if (ctx->MULTI_DATABASE_USE()) return AuthQuery::Privilege::MULTI_DATABASE_USE;
|
||||
LOG_FATAL("Should not get here - unknown privilege!");
|
||||
}
|
||||
|
||||
@ -1580,7 +1592,7 @@ antlrcpp::Any CypherMainVisitor::visitEntityType(MemgraphCypher::EntityTypeConte
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitShowPrivileges(MemgraphCypher::ShowPrivilegesContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::SHOW_PRIVILEGES;
|
||||
auth->user_or_role_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
|
||||
return auth;
|
||||
@ -1590,7 +1602,7 @@ antlrcpp::Any CypherMainVisitor::visitShowPrivileges(MemgraphCypher::ShowPrivile
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitShowRoleForUser(MemgraphCypher::ShowRoleForUserContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::SHOW_ROLE_FOR_USER;
|
||||
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
|
||||
return auth;
|
||||
@ -1600,12 +1612,55 @@ antlrcpp::Any CypherMainVisitor::visitShowRoleForUser(MemgraphCypher::ShowRoleFo
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitShowUsersForRole(MemgraphCypher::ShowUsersForRoleContext *ctx) {
|
||||
AuthQuery *auth = storage_->Create<AuthQuery>();
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::SHOW_USERS_FOR_ROLE;
|
||||
auth->role_ = std::any_cast<std::string>(ctx->role->accept(this));
|
||||
return auth;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitGrantDatabaseToUser(MemgraphCypher::GrantDatabaseToUserContext *ctx) {
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::GRANT_DATABASE_TO_USER;
|
||||
auth->database_ = std::any_cast<std::string>(ctx->wildcardName()->accept(this));
|
||||
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
|
||||
return auth;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitRevokeDatabaseFromUser(MemgraphCypher::RevokeDatabaseFromUserContext *ctx) {
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::REVOKE_DATABASE_FROM_USER;
|
||||
auth->database_ = std::any_cast<std::string>(ctx->wildcardName()->accept(this));
|
||||
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
|
||||
return auth;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitShowDatabasePrivileges(MemgraphCypher::ShowDatabasePrivilegesContext *ctx) {
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::SHOW_DATABASE_PRIVILEGES;
|
||||
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
|
||||
return auth;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any CypherMainVisitor::visitSetMainDatabase(MemgraphCypher::SetMainDatabaseContext *ctx) {
|
||||
auto *auth = storage_->Create<AuthQuery>();
|
||||
auth->action_ = AuthQuery::Action::SET_MAIN_DATABASE;
|
||||
auth->database_ = std::any_cast<std::string>(ctx->db->accept(this));
|
||||
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
|
||||
return auth;
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitCypherReturn(MemgraphCypher::CypherReturnContext *ctx) {
|
||||
auto *return_clause = storage_->Create<Return>();
|
||||
return_clause->body_ = std::any_cast<ReturnBody>(ctx->returnBody()->accept(this));
|
||||
@ -2671,4 +2726,33 @@ PropertyIx CypherMainVisitor::AddProperty(const std::string &name) { return stor
|
||||
|
||||
EdgeTypeIx CypherMainVisitor::AddEdgeType(const std::string &name) { return storage_->GetEdgeTypeIx(name); }
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitCreateDatabase(MemgraphCypher::CreateDatabaseContext *ctx) {
|
||||
auto *mdb_query = storage_->Create<MultiDatabaseQuery>();
|
||||
mdb_query->db_name_ = std::any_cast<std::string>(ctx->databaseName()->accept(this));
|
||||
mdb_query->action_ = MultiDatabaseQuery::Action::CREATE;
|
||||
query_ = mdb_query;
|
||||
return mdb_query;
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitUseDatabase(MemgraphCypher::UseDatabaseContext *ctx) {
|
||||
auto *mdb_query = storage_->Create<MultiDatabaseQuery>();
|
||||
mdb_query->db_name_ = std::any_cast<std::string>(ctx->databaseName()->accept(this));
|
||||
mdb_query->action_ = MultiDatabaseQuery::Action::USE;
|
||||
query_ = mdb_query;
|
||||
return mdb_query;
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitDropDatabase(MemgraphCypher::DropDatabaseContext *ctx) {
|
||||
auto *mdb_query = storage_->Create<MultiDatabaseQuery>();
|
||||
mdb_query->db_name_ = std::any_cast<std::string>(ctx->databaseName()->accept(this));
|
||||
mdb_query->action_ = MultiDatabaseQuery::Action::DROP;
|
||||
query_ = mdb_query;
|
||||
return mdb_query;
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitShowDatabases(MemgraphCypher::ShowDatabasesContext * /*ctx*/) {
|
||||
query_ = storage_->Create<ShowDatabasesQuery>();
|
||||
return query_;
|
||||
}
|
||||
|
||||
} // namespace memgraph::query::frontend
|
||||
|
@ -524,6 +524,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
*/
|
||||
antlrcpp::Any visitEntitiesList(MemgraphCypher::EntitiesListContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return std::string
|
||||
*/
|
||||
antlrcpp::Any visitWildcardName(MemgraphCypher::WildcardNameContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery::FineGrainedPrivilege
|
||||
*/
|
||||
@ -554,6 +559,26 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
*/
|
||||
antlrcpp::Any visitShowUsersForRole(MemgraphCypher::ShowUsersForRoleContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any visitGrantDatabaseToUser(MemgraphCypher::GrantDatabaseToUserContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any visitRevokeDatabaseFromUser(MemgraphCypher::RevokeDatabaseFromUserContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any visitShowDatabasePrivileges(MemgraphCypher::ShowDatabasePrivilegesContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any visitSetMainDatabase(MemgraphCypher::SetMainDatabaseContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return Return*
|
||||
*/
|
||||
@ -935,6 +960,26 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
*/
|
||||
antlrcpp::Any visitCallSubquery(MemgraphCypher::CallSubqueryContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return MultiDatabaseQuery*
|
||||
*/
|
||||
antlrcpp::Any visitCreateDatabase(MemgraphCypher::CreateDatabaseContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return MultiDatabaseQuery*
|
||||
*/
|
||||
antlrcpp::Any visitUseDatabase(MemgraphCypher::UseDatabaseContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return MultiDatabaseQuery*
|
||||
*/
|
||||
antlrcpp::Any visitDropDatabase(MemgraphCypher::DropDatabaseContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return ShowDatabasesQuery*
|
||||
*/
|
||||
antlrcpp::Any visitShowDatabases(MemgraphCypher::ShowDatabasesContext *ctx) override;
|
||||
|
||||
public:
|
||||
Query *query() { return query_; }
|
||||
const static std::string kAnonPrefix;
|
||||
|
@ -107,6 +107,7 @@ memgraphCypherKeyword : cypherKeyword
|
||||
| UNCOMMITTED
|
||||
| UNLOCK
|
||||
| UPDATE
|
||||
| USE
|
||||
| USER
|
||||
| USERS
|
||||
| VERSION
|
||||
@ -140,6 +141,8 @@ query : cypherQuery
|
||||
| versionQuery
|
||||
| showConfigQuery
|
||||
| transactionQueueQuery
|
||||
| multiDatabaseQuery
|
||||
| showDatabases
|
||||
;
|
||||
|
||||
authQuery : createRole
|
||||
@ -157,6 +160,10 @@ authQuery : createRole
|
||||
| showPrivileges
|
||||
| showRoleForUser
|
||||
| showUsersForRole
|
||||
| grantDatabaseToUser
|
||||
| revokeDatabaseFromUser
|
||||
| showDatabasePrivileges
|
||||
| setMainDatabase
|
||||
;
|
||||
|
||||
replicationQuery : setReplicationRole
|
||||
@ -208,6 +215,10 @@ streamQuery : checkStream
|
||||
| showStreams
|
||||
;
|
||||
|
||||
databaseName : symbolicName ;
|
||||
|
||||
wildcardName : ASTERISK | symbolicName ;
|
||||
|
||||
settingQuery : setSetting
|
||||
| showSetting
|
||||
| showSettings
|
||||
@ -265,6 +276,14 @@ denyPrivilege : DENY ( ALL PRIVILEGES | privileges=privilegesList ) TO userOrRol
|
||||
|
||||
revokePrivilege : REVOKE ( ALL PRIVILEGES | privileges=revokePrivilegesList ) FROM userOrRole=userOrRoleName ;
|
||||
|
||||
grantDatabaseToUser : GRANT DATABASE db=wildcardName TO user=symbolicName ;
|
||||
|
||||
revokeDatabaseFromUser : REVOKE DATABASE db=wildcardName FROM user=symbolicName ;
|
||||
|
||||
showDatabasePrivileges : SHOW DATABASE PRIVILEGES FOR user=symbolicName ;
|
||||
|
||||
setMainDatabase : SET MAIN DATABASE db=symbolicName FOR user=symbolicName ;
|
||||
|
||||
privilege : CREATE
|
||||
| DELETE
|
||||
| MATCH
|
||||
@ -288,6 +307,8 @@ privilege : CREATE
|
||||
| WEBSOCKET
|
||||
| TRANSACTION_MANAGEMENT
|
||||
| STORAGE_MODE
|
||||
| MULTI_DATABASE_EDIT
|
||||
| MULTI_DATABASE_USE
|
||||
;
|
||||
|
||||
granularPrivilege : NOTHING | READ | UPDATE | CREATE_DELETE ;
|
||||
@ -441,3 +462,16 @@ versionQuery : SHOW VERSION ;
|
||||
transactionIdList : transactionId ( ',' transactionId )* ;
|
||||
|
||||
transactionId : literal ;
|
||||
|
||||
multiDatabaseQuery : createDatabase
|
||||
| useDatabase
|
||||
| dropDatabase
|
||||
;
|
||||
|
||||
createDatabase : CREATE DATABASE databaseName ;
|
||||
|
||||
useDatabase : USE DATABASE databaseName ;
|
||||
|
||||
dropDatabase : DROP DATABASE databaseName ;
|
||||
|
||||
showDatabases: SHOW DATABASES ;
|
||||
|
@ -49,6 +49,7 @@ CSV : C S V ;
|
||||
DATA : D A T A ;
|
||||
DELIMITER : D E L I M I T E R ;
|
||||
DATABASE : D A T A B A S E ;
|
||||
DATABASES : D A T A B A S E S ;
|
||||
DENY : D E N Y ;
|
||||
DIRECTORY : D I R E C T O R Y ;
|
||||
DROP : D R O P ;
|
||||
@ -80,6 +81,8 @@ MAIN : M A I N ;
|
||||
MODE : M O D E ;
|
||||
MODULE_READ : M O D U L E UNDERSCORE R E A D ;
|
||||
MODULE_WRITE : M O D U L E UNDERSCORE W R I T E ;
|
||||
MULTI_DATABASE_EDIT : M U L T I UNDERSCORE D A T A B A S E UNDERSCORE E D I T ;
|
||||
MULTI_DATABASE_USE : M U L T I UNDERSCORE D A T A B A S E UNDERSCORE U S E ;
|
||||
NEXT : N E X T ;
|
||||
NO : N O ;
|
||||
NOTHING : N O T H I N G ;
|
||||
@ -127,6 +130,7 @@ TRIGGERS : T R I G G E R S ;
|
||||
UNCOMMITTED : U N C O M M I T T E D ;
|
||||
UNLOCK : U N L O C K ;
|
||||
UPDATE : U P D A T E ;
|
||||
USE : U S E ;
|
||||
USER : U S E R ;
|
||||
USERS : U S E R S ;
|
||||
VERSION : V E R S I O N ;
|
||||
|
@ -89,6 +89,22 @@ class PrivilegeExtractor : public QueryVisitor<void>, public HierarchicalTreeVis
|
||||
|
||||
void Visit(VersionQuery & /*version_query*/) override { AddPrivilege(AuthQuery::Privilege::STATS); }
|
||||
|
||||
void Visit(MultiDatabaseQuery &query) override {
|
||||
switch (query.action_) {
|
||||
case MultiDatabaseQuery::Action::CREATE:
|
||||
case MultiDatabaseQuery::Action::DROP:
|
||||
AddPrivilege(AuthQuery::Privilege::MULTI_DATABASE_EDIT);
|
||||
break;
|
||||
case MultiDatabaseQuery::Action::USE:
|
||||
AddPrivilege(AuthQuery::Privilege::MULTI_DATABASE_USE);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Visit(ShowDatabasesQuery & /*unused*/) override {
|
||||
AddPrivilege(AuthQuery::Privilege::MULTI_DATABASE_USE); /* OR EDIT */
|
||||
}
|
||||
|
||||
bool PreVisit(Create & /*unused*/) override {
|
||||
AddPrivilege(AuthQuery::Privilege::CREATE);
|
||||
return false;
|
||||
|
@ -139,6 +139,7 @@ const trie::Trie kKeywords = {"union",
|
||||
"false",
|
||||
"reduce",
|
||||
"coalesce",
|
||||
"use",
|
||||
"user",
|
||||
"password",
|
||||
"alter",
|
||||
@ -159,6 +160,7 @@ const trie::Trie kKeywords = {"union",
|
||||
"key",
|
||||
"dump",
|
||||
"database",
|
||||
"databases",
|
||||
"call",
|
||||
"yield",
|
||||
"memory",
|
||||
|
@ -24,13 +24,17 @@
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
#include "auth/auth.hpp"
|
||||
#include "auth/models.hpp"
|
||||
#include "csv/parsing.hpp"
|
||||
#include "dbms/global.hpp"
|
||||
#include "dbms/session_context_handler.hpp"
|
||||
#include "glue/communication.hpp"
|
||||
#include "license/license.hpp"
|
||||
#include "memory/memory_control.hpp"
|
||||
@ -105,6 +109,26 @@ template <typename>
|
||||
constexpr auto kAlwaysFalse = false;
|
||||
|
||||
namespace {
|
||||
template <typename T, typename K>
|
||||
void Sort(std::vector<T, K> &vec) {
|
||||
std::sort(vec.begin(), vec.end());
|
||||
}
|
||||
|
||||
template <typename K>
|
||||
void Sort(std::vector<TypedValue, K> &vec) {
|
||||
std::sort(vec.begin(), vec.end(),
|
||||
[](const TypedValue &lv, const TypedValue &rv) { return lv.ValueString() < rv.ValueString(); });
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE (misc-unused-parameters)
|
||||
bool Same(const TypedValue &lv, const TypedValue &rv) {
|
||||
return TypedValue(lv).ValueString() == TypedValue(rv).ValueString();
|
||||
}
|
||||
bool Same(const TypedValue &lv, const std::string &rv) { return std::string(TypedValue(lv).ValueString()) == rv; }
|
||||
// NOLINTNEXTLINE (misc-unused-parameters)
|
||||
bool Same(const std::string &lv, const TypedValue &rv) { return lv == std::string(TypedValue(rv).ValueString()); }
|
||||
bool Same(const std::string &lv, const std::string &rv) { return lv == rv; }
|
||||
|
||||
void UpdateTypeCount(const plan::ReadWriteTypeChecker::RWType type) {
|
||||
switch (type) {
|
||||
case plan::ReadWriteTypeChecker::RWType::R:
|
||||
@ -333,8 +357,12 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler {
|
||||
/// returns false if the replication role can't be set
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
|
||||
Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Parameters ¶meters,
|
||||
Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_context, const Parameters ¶meters,
|
||||
DbAccessor *db_accessor) {
|
||||
AuthQueryHandler *auth = interpreter_context->auth;
|
||||
#ifdef MG_ENTERPRISE
|
||||
auto &sc_handler = memgraph::dbms::SessionContextHandler::ExtractSCH(interpreter_context);
|
||||
#endif
|
||||
// Empty frame for evaluation of password expression. This is OK since
|
||||
// password should be either null or string literal and it's evaluation
|
||||
// should not depend on frame.
|
||||
@ -351,6 +379,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa
|
||||
std::string username = auth_query->user_;
|
||||
std::string rolename = auth_query->role_;
|
||||
std::string user_or_role = auth_query->user_or_role_;
|
||||
std::string database = auth_query->database_;
|
||||
std::vector<AuthQuery::Privilege> privileges = auth_query->privileges_;
|
||||
#ifdef MG_ENTERPRISE
|
||||
std::vector<std::unordered_map<AuthQuery::FineGrainedPrivilege, std::vector<std::string>>> label_privileges =
|
||||
@ -364,11 +393,20 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa
|
||||
|
||||
const auto license_check_result = license::global_license_checker.IsEnterpriseValid(utils::global_settings);
|
||||
|
||||
static const std::unordered_set enterprise_only_methods{
|
||||
AuthQuery::Action::CREATE_ROLE, AuthQuery::Action::DROP_ROLE, AuthQuery::Action::SET_ROLE,
|
||||
AuthQuery::Action::CLEAR_ROLE, AuthQuery::Action::GRANT_PRIVILEGE, AuthQuery::Action::DENY_PRIVILEGE,
|
||||
AuthQuery::Action::REVOKE_PRIVILEGE, AuthQuery::Action::SHOW_PRIVILEGES, AuthQuery::Action::SHOW_USERS_FOR_ROLE,
|
||||
AuthQuery::Action::SHOW_ROLE_FOR_USER};
|
||||
static const std::unordered_set enterprise_only_methods{AuthQuery::Action::CREATE_ROLE,
|
||||
AuthQuery::Action::DROP_ROLE,
|
||||
AuthQuery::Action::SET_ROLE,
|
||||
AuthQuery::Action::CLEAR_ROLE,
|
||||
AuthQuery::Action::GRANT_PRIVILEGE,
|
||||
AuthQuery::Action::DENY_PRIVILEGE,
|
||||
AuthQuery::Action::REVOKE_PRIVILEGE,
|
||||
AuthQuery::Action::SHOW_PRIVILEGES,
|
||||
AuthQuery::Action::SHOW_USERS_FOR_ROLE,
|
||||
AuthQuery::Action::SHOW_ROLE_FOR_USER,
|
||||
AuthQuery::Action::GRANT_DATABASE_TO_USER,
|
||||
AuthQuery::Action::REVOKE_DATABASE_FROM_USER,
|
||||
AuthQuery::Action::SHOW_DATABASE_PRIVILEGES,
|
||||
AuthQuery::Action::SET_MAIN_DATABASE};
|
||||
|
||||
if (license_check_result.HasError() && enterprise_only_methods.contains(auth_query->action_)) {
|
||||
throw utils::BasicException(
|
||||
@ -536,6 +574,73 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa
|
||||
return rows;
|
||||
};
|
||||
return callback;
|
||||
case AuthQuery::Action::GRANT_DATABASE_TO_USER:
|
||||
#ifdef MG_ENTERPRISE
|
||||
callback.fn = [auth, database, username, &sc_handler] { // NOLINT
|
||||
try {
|
||||
memgraph::dbms::SessionContext sc(nullptr, "", nullptr, nullptr);
|
||||
if (database != memgraph::auth::kAllDatabases) {
|
||||
sc = sc_handler.Get(database); // Will throw if databases doesn't exist and protect it during pull
|
||||
}
|
||||
if (!auth->GrantDatabaseToUser(database, username)) {
|
||||
throw QueryRuntimeException("Failed to grant database {} to user {}.", database, username);
|
||||
}
|
||||
} catch (memgraph::dbms::UnknownDatabaseException &e) {
|
||||
throw QueryRuntimeException(e.what());
|
||||
}
|
||||
#else
|
||||
callback.fn = [] {
|
||||
#endif
|
||||
return std::vector<std::vector<TypedValue>>();
|
||||
};
|
||||
return callback;
|
||||
case AuthQuery::Action::REVOKE_DATABASE_FROM_USER:
|
||||
#ifdef MG_ENTERPRISE
|
||||
callback.fn = [auth, database, username, &sc_handler] { // NOLINT
|
||||
try {
|
||||
memgraph::dbms::SessionContext sc(nullptr, "", nullptr, nullptr);
|
||||
if (database != memgraph::auth::kAllDatabases) {
|
||||
sc = sc_handler.Get(database); // Will throw if databases doesn't exist and protect it during pull
|
||||
}
|
||||
if (!auth->RevokeDatabaseFromUser(database, username)) {
|
||||
throw QueryRuntimeException("Failed to revoke database {} from user {}.", database, username);
|
||||
}
|
||||
} catch (memgraph::dbms::UnknownDatabaseException &e) {
|
||||
throw QueryRuntimeException(e.what());
|
||||
}
|
||||
#else
|
||||
callback.fn = [] {
|
||||
#endif
|
||||
return std::vector<std::vector<TypedValue>>();
|
||||
};
|
||||
return callback;
|
||||
case AuthQuery::Action::SHOW_DATABASE_PRIVILEGES:
|
||||
callback.header = {"grants", "denies"};
|
||||
callback.fn = [auth, username] { // NOLINT
|
||||
#ifdef MG_ENTERPRISE
|
||||
return auth->GetDatabasePrivileges(username);
|
||||
#else
|
||||
return std::vector<std::vector<TypedValue>>();
|
||||
#endif
|
||||
};
|
||||
return callback;
|
||||
case AuthQuery::Action::SET_MAIN_DATABASE:
|
||||
#ifdef MG_ENTERPRISE
|
||||
callback.fn = [auth, database, username, &sc_handler] { // NOLINT
|
||||
try {
|
||||
const auto sc = sc_handler.Get(database); // Will throw if databases doesn't exist and protect it during pull
|
||||
if (!auth->SetMainDatabase(database, username)) {
|
||||
throw QueryRuntimeException("Failed to set main database {} for user {}.", database, username);
|
||||
}
|
||||
} catch (memgraph::dbms::UnknownDatabaseException &e) {
|
||||
throw QueryRuntimeException(e.what());
|
||||
}
|
||||
#else
|
||||
callback.fn = [] {
|
||||
#endif
|
||||
return std::vector<std::vector<TypedValue>>();
|
||||
};
|
||||
return callback;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -1249,11 +1354,23 @@ bool IsWriteQueryOnMainMemoryReplica(storage::Storage *storage,
|
||||
return false;
|
||||
}
|
||||
|
||||
storage::replication::ReplicationRole GetReplicaRole(storage::Storage *storage) {
|
||||
if (auto storage_mode = storage->GetStorageMode(); storage_mode == storage::StorageMode::IN_MEMORY_ANALYTICAL ||
|
||||
storage_mode == storage::StorageMode::IN_MEMORY_TRANSACTIONAL) {
|
||||
auto *mem_storage = static_cast<storage::InMemoryStorage *>(storage);
|
||||
return mem_storage->GetReplicationRole();
|
||||
}
|
||||
return storage::replication::ReplicationRole::MAIN;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
InterpreterContext::InterpreterContext(const storage::Config storage_config, const InterpreterConfig interpreter_config,
|
||||
const std::filesystem::path &data_directory)
|
||||
: trigger_store(data_directory / "triggers"),
|
||||
const std::filesystem::path &data_directory, query::AuthQueryHandler *ah,
|
||||
query::AuthChecker *ac)
|
||||
: auth(ah),
|
||||
auth_checker(ac),
|
||||
trigger_store(data_directory / "triggers"),
|
||||
config(interpreter_config),
|
||||
streams{this, data_directory / "streams"} {
|
||||
if (utils::DirExists(storage_config.disk.main_storage_directory)) {
|
||||
@ -1264,8 +1381,11 @@ InterpreterContext::InterpreterContext(const storage::Config storage_config, con
|
||||
}
|
||||
|
||||
InterpreterContext::InterpreterContext(std::unique_ptr<storage::Storage> db, InterpreterConfig interpreter_config,
|
||||
const std::filesystem::path &data_directory)
|
||||
const std::filesystem::path &data_directory, query::AuthQueryHandler *ah,
|
||||
query::AuthChecker *ac)
|
||||
: db(std::move(db)),
|
||||
auth(ah),
|
||||
auth_checker(ac),
|
||||
trigger_store(data_directory / "triggers"),
|
||||
config(interpreter_config),
|
||||
streams{this, data_directory / "streams"} {}
|
||||
@ -2027,7 +2147,7 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa
|
||||
|
||||
auto *auth_query = utils::Downcast<AuthQuery>(parsed_query.query);
|
||||
|
||||
auto callback = HandleAuthQuery(auth_query, interpreter_context->auth, parsed_query.parameters, dba);
|
||||
auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters, dba);
|
||||
|
||||
SymbolTable symbol_table;
|
||||
std::vector<Symbol> output_symbols;
|
||||
@ -2668,7 +2788,7 @@ Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query,
|
||||
ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD);
|
||||
|
||||
bool hasTransactionManagementPrivilege = interpreter_context->auth_checker->IsUserAuthorized(
|
||||
username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT});
|
||||
username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT}, "");
|
||||
|
||||
Callback callback;
|
||||
switch (transaction_query->action_) {
|
||||
@ -2770,6 +2890,7 @@ PreparedQuery PrepareInfoQuery(ParsedQuery parsed_query, bool in_explicit_transa
|
||||
handler = [db, interpreter_isolation_level, next_transaction_isolation_level] {
|
||||
auto info = db->GetInfo();
|
||||
std::vector<std::vector<TypedValue>> results{
|
||||
{TypedValue("name"), TypedValue(db->id())},
|
||||
{TypedValue("vertex_count"), TypedValue(static_cast<int64_t>(info.vertex_count))},
|
||||
{TypedValue("edge_count"), TypedValue(static_cast<int64_t>(info.edge_count))},
|
||||
{TypedValue("average_degree"), TypedValue(info.average_degree)},
|
||||
@ -3118,6 +3239,250 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_
|
||||
RWType::NONE};
|
||||
}
|
||||
|
||||
PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, bool in_explicit_transaction, bool in_explicit_db,
|
||||
InterpreterContext *interpreter_context, const std::string &session_uuid) {
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (!license::global_license_checker.IsEnterpriseValidFast()) {
|
||||
throw QueryException("Trying to use enterprise feature without a valid license.");
|
||||
}
|
||||
// TODO: Remove once replicas support multi-tenant replication
|
||||
if (GetReplicaRole(interpreter_context->db.get()) == storage::replication::ReplicationRole::REPLICA) {
|
||||
throw QueryException("Query forbidden on the replica!");
|
||||
}
|
||||
if (in_explicit_transaction) {
|
||||
throw MultiDatabaseQueryInMulticommandTxException();
|
||||
}
|
||||
|
||||
auto *query = utils::Downcast<MultiDatabaseQuery>(parsed_query.query);
|
||||
auto &sc_handler = memgraph::dbms::SessionContextHandler::ExtractSCH(interpreter_context);
|
||||
|
||||
switch (query->action_) {
|
||||
case MultiDatabaseQuery::Action::CREATE:
|
||||
return PreparedQuery{
|
||||
{"STATUS"},
|
||||
std::move(parsed_query.required_privileges),
|
||||
[db_name = query->db_name_, session_uuid, &sc_handler](
|
||||
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
|
||||
std::vector<std::vector<TypedValue>> status;
|
||||
std::string res;
|
||||
|
||||
const auto success = sc_handler.New(db_name);
|
||||
if (success.HasError()) {
|
||||
switch (success.GetError()) {
|
||||
case dbms::NewError::EXISTS:
|
||||
res = db_name + " already exists.";
|
||||
break;
|
||||
case dbms::NewError::DEFUNCT:
|
||||
throw QueryRuntimeException(
|
||||
"{} is defunct and in an unknown state. Try to delete it again or clean up storage and restart "
|
||||
"Memgraph.",
|
||||
db_name);
|
||||
case dbms::NewError::GENERIC:
|
||||
throw QueryRuntimeException("Failed while creating {}", db_name);
|
||||
case dbms::NewError::NO_CONFIGS:
|
||||
throw QueryRuntimeException("No configuration found while trying to create {}", db_name);
|
||||
}
|
||||
} else {
|
||||
res = "Successfully created database " + db_name;
|
||||
}
|
||||
status.emplace_back(std::vector<TypedValue>{TypedValue(res)});
|
||||
auto pull_plan = std::make_shared<PullPlanVector>(std::move(status));
|
||||
if (pull_plan->Pull(stream, n)) {
|
||||
return QueryHandlerResult::COMMIT;
|
||||
}
|
||||
return std::nullopt;
|
||||
},
|
||||
RWType::W,
|
||||
"" // No target DB possible
|
||||
};
|
||||
|
||||
case MultiDatabaseQuery::Action::USE:
|
||||
if (in_explicit_db) {
|
||||
throw QueryException("Database switching is prohibited if session explicitly defines the used database");
|
||||
}
|
||||
return PreparedQuery{{"STATUS"},
|
||||
std::move(parsed_query.required_privileges),
|
||||
[db_name = query->db_name_, session_uuid, &sc_handler](
|
||||
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
|
||||
std::vector<std::vector<TypedValue>> status;
|
||||
std::string res;
|
||||
|
||||
memgraph::dbms::SetForResult set = memgraph::dbms::SetForResult::SUCCESS;
|
||||
|
||||
try {
|
||||
set = sc_handler.SetFor(session_uuid, db_name);
|
||||
} catch (const utils::BasicException &e) {
|
||||
throw QueryRuntimeException(e.what());
|
||||
}
|
||||
|
||||
switch (set) {
|
||||
case dbms::SetForResult::SUCCESS:
|
||||
res = "Using " + db_name;
|
||||
break;
|
||||
case dbms::SetForResult::ALREADY_SET:
|
||||
res = "Already using " + db_name;
|
||||
break;
|
||||
case dbms::SetForResult::FAIL:
|
||||
throw QueryRuntimeException("Failed to start using {}", db_name);
|
||||
}
|
||||
|
||||
status.emplace_back(std::vector<TypedValue>{TypedValue(res)});
|
||||
auto pull_plan = std::make_shared<PullPlanVector>(std::move(status));
|
||||
if (pull_plan->Pull(stream, n)) {
|
||||
return QueryHandlerResult::COMMIT;
|
||||
}
|
||||
return std::nullopt;
|
||||
},
|
||||
RWType::NONE,
|
||||
query->db_name_};
|
||||
|
||||
case MultiDatabaseQuery::Action::DROP:
|
||||
return PreparedQuery{
|
||||
{"STATUS"},
|
||||
std::move(parsed_query.required_privileges),
|
||||
[db_name = query->db_name_, session_uuid, &sc_handler](
|
||||
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
|
||||
std::vector<std::vector<TypedValue>> status;
|
||||
|
||||
memgraph::dbms::DeleteResult success{};
|
||||
|
||||
try {
|
||||
success = sc_handler.Delete(db_name);
|
||||
} catch (const utils::BasicException &e) {
|
||||
throw QueryRuntimeException(e.what());
|
||||
}
|
||||
|
||||
if (success.HasError()) {
|
||||
switch (success.GetError()) {
|
||||
case dbms::DeleteError::DEFAULT_DB:
|
||||
throw QueryRuntimeException("Cannot delete the default database.");
|
||||
case dbms::DeleteError::NON_EXISTENT:
|
||||
throw QueryRuntimeException("{} does not exist.", db_name);
|
||||
case dbms::DeleteError::USING:
|
||||
throw QueryRuntimeException("Cannot delete {}, it is currently being used.", db_name);
|
||||
case dbms::DeleteError::FAIL:
|
||||
throw QueryRuntimeException("Failed while deleting {}", db_name);
|
||||
case dbms::DeleteError::DISK_FAIL:
|
||||
throw QueryRuntimeException("Failed to clean storage of {}", db_name);
|
||||
}
|
||||
}
|
||||
|
||||
status.emplace_back(std::vector<TypedValue>{TypedValue("Successfully deleted " + db_name)});
|
||||
auto pull_plan = std::make_shared<PullPlanVector>(std::move(status));
|
||||
if (pull_plan->Pull(stream, n)) {
|
||||
return QueryHandlerResult::COMMIT;
|
||||
}
|
||||
return std::nullopt;
|
||||
},
|
||||
RWType::W,
|
||||
query->db_name_};
|
||||
}
|
||||
#else
|
||||
throw QueryException("Query not supported.");
|
||||
#endif
|
||||
}
|
||||
|
||||
PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterContext *interpreter_context,
|
||||
const std::string &session_uuid, std::map<std::string, TypedValue> *summary,
|
||||
DbAccessor *dba, utils::MemoryResource *execution_memory,
|
||||
const std::optional<std::string> &username,
|
||||
std::atomic<TransactionStatus> *transaction_status,
|
||||
std::shared_ptr<utils::AsyncTimer> tx_timer) {
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (!license::global_license_checker.IsEnterpriseValidFast()) {
|
||||
throw QueryException("Trying to use enterprise feature without a valid license.");
|
||||
}
|
||||
// TODO: Remove once replicas support multi-tenant replication
|
||||
if (GetReplicaRole(interpreter_context->db.get()) == storage::replication::ReplicationRole::REPLICA) {
|
||||
throw QueryException("SHOW DATABASES forbidden on the replica!");
|
||||
}
|
||||
|
||||
auto &sc_handler = memgraph::dbms::SessionContextHandler::ExtractSCH(interpreter_context);
|
||||
AuthQueryHandler *auth = interpreter_context->auth;
|
||||
|
||||
Callback callback;
|
||||
callback.header = {"Name", "Current"};
|
||||
callback.fn = [auth, session_uuid, &sc_handler, username]() mutable -> std::vector<std::vector<TypedValue>> {
|
||||
std::vector<std::vector<TypedValue>> status;
|
||||
const auto in_use = sc_handler.Current(session_uuid);
|
||||
bool found_current = false;
|
||||
|
||||
auto gen_status = [&]<typename T, typename K>(T all, K denied) {
|
||||
Sort(all);
|
||||
Sort(denied);
|
||||
|
||||
status.reserve(all.size());
|
||||
for (const auto &name : all) {
|
||||
TypedValue use("");
|
||||
if (!found_current && Same(name, in_use)) {
|
||||
use = TypedValue("*");
|
||||
found_current = true;
|
||||
}
|
||||
status.push_back({TypedValue(name), std::move(use)});
|
||||
}
|
||||
|
||||
// No denied databases (no need to filter them out)
|
||||
if (denied.empty()) return;
|
||||
|
||||
auto denied_itr = denied.begin();
|
||||
auto iter = std::remove_if(status.begin(), status.end(), [&denied_itr, &denied](auto &in) -> bool {
|
||||
while (denied_itr != denied.end() && denied_itr->ValueString() < in[0].ValueString()) ++denied_itr;
|
||||
return (denied_itr != denied.end() && denied_itr->ValueString() == in[0].ValueString());
|
||||
});
|
||||
status.erase(iter, status.end());
|
||||
};
|
||||
|
||||
if (!username) {
|
||||
// No user, return all
|
||||
gen_status(sc_handler.All(), std::vector<TypedValue>{});
|
||||
} else {
|
||||
// User has a subset of accessible dbs; this is synched with the SessionContextHandler
|
||||
const auto &db_priv = auth->GetDatabasePrivileges(*username);
|
||||
const auto &allowed = db_priv[0][0];
|
||||
const auto &denied = db_priv[0][1].ValueList();
|
||||
if (allowed.IsString() && allowed.ValueString() == auth::kAllDatabases) {
|
||||
// All databases are allowed
|
||||
gen_status(sc_handler.All(), denied);
|
||||
} else {
|
||||
gen_status(allowed.ValueList(), denied);
|
||||
}
|
||||
}
|
||||
|
||||
if (!found_current) throw QueryRuntimeException("Missing current database!");
|
||||
return status;
|
||||
};
|
||||
|
||||
SymbolTable symbol_table;
|
||||
std::vector<Symbol> output_symbols;
|
||||
for (const auto &column : callback.header) {
|
||||
output_symbols.emplace_back(symbol_table.CreateSymbol(column, "false"));
|
||||
}
|
||||
|
||||
auto plan = std::make_shared<CachedPlan>(std::make_unique<SingleNodeLogicalPlan>(
|
||||
std::make_unique<plan::OutputTable>(output_symbols,
|
||||
[fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }),
|
||||
0.0, AstStorage{}, symbol_table));
|
||||
|
||||
auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
|
||||
execution_memory, username, transaction_status, std::move(tx_timer));
|
||||
|
||||
return PreparedQuery{
|
||||
callback.header, std::move(parsed_query.required_privileges),
|
||||
[pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols),
|
||||
summary](AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
|
||||
if (pull_plan->Pull(stream, n, output_symbols, summary)) {
|
||||
return callback.should_abort_query ? QueryHandlerResult::ABORT : QueryHandlerResult::COMMIT;
|
||||
}
|
||||
return std::nullopt;
|
||||
},
|
||||
RWType::NONE,
|
||||
"" // No target DB
|
||||
};
|
||||
#else
|
||||
throw QueryException("Query not supported.");
|
||||
#endif
|
||||
}
|
||||
|
||||
std::optional<uint64_t> Interpreter::GetTransactionId() const {
|
||||
if (db_accessor_) {
|
||||
return db_accessor_->GetTransactionId();
|
||||
@ -3146,7 +3511,8 @@ void Interpreter::RollbackTransaction() {
|
||||
|
||||
Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
|
||||
const std::map<std::string, storage::PropertyValue> ¶ms,
|
||||
const std::string *username, QueryExtras const &extras) {
|
||||
const std::string *username, QueryExtras const &extras,
|
||||
const std::string &session_uuid) {
|
||||
std::shared_ptr<utils::AsyncTimer> current_timer;
|
||||
if (!in_explicit_transaction_) {
|
||||
query_executions_.clear();
|
||||
@ -3177,7 +3543,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
|
||||
in_explicit_transaction_ ? static_cast<int>(query_executions_.size() - 1) : std::optional<int>{};
|
||||
|
||||
query_execution->prepared_query.emplace(PrepareTransactionQuery(trimmed_query, extras));
|
||||
return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid};
|
||||
return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid, {}};
|
||||
}
|
||||
|
||||
// Don't save BEGIN, COMMIT or ROLLBACK
|
||||
@ -3329,6 +3695,14 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
|
||||
} else if (utils::Downcast<TransactionQueueQuery>(parsed_query.query)) {
|
||||
prepared_query = PrepareTransactionQueueQuery(std::move(parsed_query), username_, in_explicit_transaction_,
|
||||
interpreter_context_, &*execution_db_accessor_);
|
||||
} else if (utils::Downcast<MultiDatabaseQuery>(parsed_query.query)) {
|
||||
prepared_query = PrepareMultiDatabaseQuery(std::move(parsed_query), in_explicit_transaction_, in_explicit_db_,
|
||||
interpreter_context_, session_uuid);
|
||||
} else if (utils::Downcast<ShowDatabasesQuery>(parsed_query.query)) {
|
||||
prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, session_uuid,
|
||||
&query_execution->summary, &*execution_db_accessor_,
|
||||
&query_execution->execution_memory_with_exception, username_,
|
||||
&transaction_status_, std::move(current_timer));
|
||||
} else {
|
||||
LOG_FATAL("Should not get here -- unknown query type!");
|
||||
}
|
||||
@ -3346,7 +3720,14 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
|
||||
throw QueryException("Write query forbidden on the replica!");
|
||||
}
|
||||
|
||||
return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid};
|
||||
// Set the target db to the current db (some queries have different target from the current db)
|
||||
if (!query_execution->prepared_query->db) {
|
||||
query_execution->prepared_query->db = interpreter_context_->db->id();
|
||||
}
|
||||
query_execution->summary["db"] = *query_execution->prepared_query->db;
|
||||
|
||||
return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid,
|
||||
query_execution->prepared_query->db};
|
||||
} catch (const utils::BasicException &) {
|
||||
memgraph::metrics::IncrementCounter(memgraph::metrics::FailedQuery);
|
||||
AbortCommand(query_execution_ptr);
|
||||
|
@ -77,6 +77,24 @@ class AuthQueryHandler {
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual void SetPassword(const std::string &username, const std::optional<std::string> &password) = 0;
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
/// Return true if access revoked successfully
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) = 0;
|
||||
|
||||
/// Return true if access granted successfully
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username) = 0;
|
||||
|
||||
/// Returns database access rights for the user
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual std::vector<std::vector<memgraph::query::TypedValue>> GetDatabasePrivileges(const std::string &username) = 0;
|
||||
|
||||
/// Return true if main database set successfully
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual bool SetMainDatabase(const std::string &db, const std::string &username) = 0;
|
||||
#endif
|
||||
|
||||
/// Return false if the role already exists.
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual bool CreateRole(const std::string &rolename) = 0;
|
||||
@ -202,6 +220,7 @@ struct PreparedQuery {
|
||||
std::vector<AuthQuery::Privilege> privileges;
|
||||
std::function<std::optional<QueryHandlerResult>(AnyStream *stream, std::optional<int> n)> query_handler;
|
||||
plan::ReadWriteTypeChecker::RWType rw_type;
|
||||
std::optional<std::string> db{};
|
||||
};
|
||||
|
||||
/**
|
||||
@ -223,10 +242,12 @@ class Interpreter;
|
||||
/// TODO: andi decouple in a separate file why here?
|
||||
struct InterpreterContext {
|
||||
explicit InterpreterContext(storage::Config storage_config, InterpreterConfig interpreter_config,
|
||||
const std::filesystem::path &data_directory);
|
||||
const std::filesystem::path &data_directory, query::AuthQueryHandler *ah = nullptr,
|
||||
query::AuthChecker *ac = nullptr);
|
||||
|
||||
InterpreterContext(std::unique_ptr<storage::Storage> db, InterpreterConfig interpreter_config,
|
||||
const std::filesystem::path &data_directory);
|
||||
const std::filesystem::path &data_directory, query::AuthQueryHandler *ah = nullptr,
|
||||
query::AuthChecker *ac = nullptr);
|
||||
|
||||
std::unique_ptr<storage::Storage> db;
|
||||
|
||||
@ -240,8 +261,8 @@ struct InterpreterContext {
|
||||
std::optional<double> tsc_frequency{utils::GetTSCFrequency()};
|
||||
std::atomic<bool> is_shutting_down{false};
|
||||
|
||||
AuthQueryHandler *auth{nullptr};
|
||||
AuthChecker *auth_checker{nullptr};
|
||||
AuthQueryHandler *auth;
|
||||
AuthChecker *auth_checker;
|
||||
|
||||
utils::SkipList<QueryCacheEntry> ast_cache;
|
||||
utils::SkipList<PlanCacheEntry> plan_cache;
|
||||
@ -272,10 +293,12 @@ class Interpreter final {
|
||||
std::vector<std::string> headers;
|
||||
std::vector<query::AuthQuery::Privilege> privileges;
|
||||
std::optional<int> qid;
|
||||
std::optional<std::string> db;
|
||||
};
|
||||
|
||||
std::optional<std::string> username_;
|
||||
bool in_explicit_transaction_{false};
|
||||
bool in_explicit_db_{false};
|
||||
bool expect_rollback_{false};
|
||||
std::shared_ptr<utils::AsyncTimer> explicit_transaction_timer_{};
|
||||
std::optional<std::map<std::string, storage::PropertyValue>> metadata_{}; //!< User defined transaction metadata
|
||||
@ -289,7 +312,8 @@ class Interpreter final {
|
||||
* @throw query::QueryException
|
||||
*/
|
||||
PrepareResult Prepare(const std::string &query, const std::map<std::string, storage::PropertyValue> ¶ms,
|
||||
const std::string *username, QueryExtras const &extras = {});
|
||||
const std::string *username, QueryExtras const &extras = {},
|
||||
const std::string &session_uuid = {});
|
||||
|
||||
/**
|
||||
* Execute the last prepared query and stream *all* of the results into the
|
||||
|
@ -520,7 +520,7 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
|
||||
spdlog::trace("Executing query '{}' in stream '{}'", query, stream_name);
|
||||
auto prepare_result =
|
||||
interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap(), nullptr);
|
||||
if (!interpreter_context->auth_checker->IsUserAuthorized(owner, prepare_result.privileges)) {
|
||||
if (!interpreter_context->auth_checker->IsUserAuthorized(owner, prepare_result.privileges, "")) {
|
||||
throw StreamsException{
|
||||
"Couldn't execute query '{}' for stream '{}' because the owner is not authorized to execute the "
|
||||
"query!",
|
||||
|
@ -187,7 +187,7 @@ std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor,
|
||||
|
||||
trigger_plan_ = std::make_shared<TriggerPlan>(std::move(logical_plan), std::move(identifiers));
|
||||
}
|
||||
if (!auth_checker->IsUserAuthorized(owner_, parsed_statements_.required_privileges)) {
|
||||
if (!auth_checker->IsUserAuthorized(owner_, parsed_statements_.required_privileges, "")) {
|
||||
throw utils::BasicException("The owner of trigger '{}' is not authorized to execute the query!", name_);
|
||||
}
|
||||
return trigger_plan_;
|
||||
|
@ -16,9 +16,15 @@
|
||||
#include <filesystem>
|
||||
#include "storage/v2/isolation_level.hpp"
|
||||
#include "storage/v2/transaction.hpp"
|
||||
#include "utils/exceptions.hpp"
|
||||
|
||||
namespace memgraph::storage {
|
||||
|
||||
/// Exception used to signal configuration errors.
|
||||
class StorageConfigException : public utils::BasicException {
|
||||
using utils::BasicException::BasicException;
|
||||
};
|
||||
|
||||
/// Pass this class to the \ref Storage constructor to change the behavior of
|
||||
/// the storage. This class also defines the default behavior.
|
||||
struct Config {
|
||||
@ -62,15 +68,49 @@ struct Config {
|
||||
} transaction;
|
||||
|
||||
struct DiskConfig {
|
||||
std::filesystem::path main_storage_directory{"rocksdb_main_storage"};
|
||||
std::filesystem::path label_index_directory{"rocksdb_label_index"};
|
||||
std::filesystem::path label_property_index_directory{"rocksdb_label_property_index"};
|
||||
std::filesystem::path unique_constraints_directory{"rocksdb_unique_constraints"};
|
||||
std::filesystem::path name_id_mapper_directory{"rocksdb_name_id_mapper"};
|
||||
std::filesystem::path id_name_mapper_directory{"rocksdb_id_name_mapper"};
|
||||
std::filesystem::path durability_directory{"rocksdb_durability"};
|
||||
std::filesystem::path wal_directory{"rocksdb_wal"};
|
||||
std::filesystem::path main_storage_directory{"storage/rocksdb_main_storage"};
|
||||
std::filesystem::path label_index_directory{"storage/rocksdb_label_index"};
|
||||
std::filesystem::path label_property_index_directory{"storage/rocksdb_label_property_index"};
|
||||
std::filesystem::path unique_constraints_directory{"storage/rocksdb_unique_constraints"};
|
||||
std::filesystem::path name_id_mapper_directory{"storage/rocksdb_name_id_mapper"};
|
||||
std::filesystem::path id_name_mapper_directory{"storage/rocksdb_id_name_mapper"};
|
||||
std::filesystem::path durability_directory{"storage/rocksdb_durability"};
|
||||
std::filesystem::path wal_directory{"storage/rocksdb_wal"};
|
||||
} disk;
|
||||
|
||||
std::string name;
|
||||
};
|
||||
|
||||
static inline void UpdatePaths(Config &config, const std::filesystem::path &storage_dir) {
|
||||
auto contained = [](const auto &path, const auto &base) -> std::optional<std::filesystem::path> {
|
||||
auto rel = std::filesystem::relative(path, base);
|
||||
if (!rel.empty() && rel.native()[0] != '.') { // Contained
|
||||
return rel;
|
||||
}
|
||||
return {};
|
||||
};
|
||||
|
||||
const auto old_base =
|
||||
std::filesystem::weakly_canonical(std::filesystem::absolute(config.durability.storage_directory));
|
||||
config.durability.storage_directory = std::filesystem::weakly_canonical(std::filesystem::absolute(storage_dir));
|
||||
|
||||
auto UPDATE_PATH = [&](auto to_update) {
|
||||
const auto old_path = std::filesystem::weakly_canonical(std::filesystem::absolute(to_update(config.disk)));
|
||||
const auto contained_path = contained(old_path, old_base);
|
||||
if (!contained_path) {
|
||||
throw StorageConfigException("On-disk directories not contained in root.");
|
||||
}
|
||||
to_update(config.disk) = config.durability.storage_directory / *contained_path;
|
||||
};
|
||||
|
||||
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::main_storage_directory));
|
||||
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::label_index_directory));
|
||||
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::label_property_index_directory));
|
||||
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::unique_constraints_directory));
|
||||
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::name_id_mapper_directory));
|
||||
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::id_name_mapper_directory));
|
||||
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::durability_directory));
|
||||
UPDATE_PATH(std::mem_fn(&Config::DiskConfig::wal_directory));
|
||||
}
|
||||
|
||||
} // namespace memgraph::storage
|
||||
|
@ -50,7 +50,8 @@ Storage::Storage(Config config, StorageMode storage_mode)
|
||||
isolation_level_(config.transaction.isolation_level),
|
||||
storage_mode_(storage_mode),
|
||||
indices_(&constraints_, config, storage_mode),
|
||||
constraints_(config, storage_mode) {}
|
||||
constraints_(config, storage_mode),
|
||||
id_(config.name) {}
|
||||
|
||||
Storage::Accessor::Accessor(Storage *storage, IsolationLevel isolation_level, StorageMode storage_mode)
|
||||
: storage_(storage),
|
||||
|
@ -74,6 +74,8 @@ class Storage {
|
||||
|
||||
virtual ~Storage() {}
|
||||
|
||||
const std::string &id() const { return id_; }
|
||||
|
||||
class Accessor {
|
||||
public:
|
||||
Accessor(Storage *storage, IsolationLevel isolation_level, StorageMode storage_mode);
|
||||
@ -179,6 +181,8 @@ class Storage {
|
||||
|
||||
StorageMode GetCreationStorageMode() const;
|
||||
|
||||
const std::string &id() const { return storage_->id(); }
|
||||
|
||||
protected:
|
||||
Storage *storage_;
|
||||
std::shared_lock<utils::RWLock> storage_guard_;
|
||||
@ -301,6 +305,7 @@ class Storage {
|
||||
|
||||
std::atomic<uint64_t> vertex_id_{0};
|
||||
std::atomic<uint64_t> edge_id_{0};
|
||||
const std::string id_; //!< High-level assigned ID
|
||||
};
|
||||
|
||||
} // namespace memgraph::storage
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -27,6 +27,7 @@ inline uint64_t GetDirDiskUsage(const std::filesystem::path &path) {
|
||||
|
||||
uint64_t size = 0;
|
||||
for (auto &p : std::filesystem::directory_iterator(path)) {
|
||||
if (std::filesystem::is_symlink(p)) continue;
|
||||
if (std::filesystem::is_directory(p)) {
|
||||
size += GetDirDiskUsage(p);
|
||||
} else if (std::filesystem::is_regular_file(p)) {
|
||||
|
189
src/utils/sync_ptr.hpp
Normal file
189
src/utils/sync_ptr.hpp
Normal file
@ -0,0 +1,189 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <chrono>
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include "utils/exceptions.hpp"
|
||||
|
||||
namespace memgraph::utils {
|
||||
|
||||
/**
|
||||
* @brief
|
||||
*
|
||||
* @tparam TContext
|
||||
* @tparam TConfig
|
||||
*/
|
||||
template <typename TContext, typename TConfig = void>
|
||||
struct SyncPtr {
|
||||
/**
|
||||
* @brief Construct a new synched pointer.
|
||||
*
|
||||
* @tparam TArgs variable templates used by the TContext constructor
|
||||
* @param config Additional metadata associated with context
|
||||
* @param args Arguments to pass to TContext constructor
|
||||
*/
|
||||
template <typename... TArgs>
|
||||
explicit SyncPtr(TConfig config, TArgs &&...args)
|
||||
: timeout_{1000}, config_{config}, ptr_{new TContext(std::forward<TArgs>(args)...), [this](TContext *ptr) {
|
||||
this->OnDelete(ptr);
|
||||
}} {}
|
||||
|
||||
~SyncPtr() = default;
|
||||
|
||||
SyncPtr(const SyncPtr &) = delete;
|
||||
SyncPtr &operator=(const SyncPtr &) = delete;
|
||||
SyncPtr(SyncPtr &&) noexcept = delete;
|
||||
SyncPtr &operator=(SyncPtr &&) noexcept = delete;
|
||||
|
||||
/**
|
||||
* @brief Destroy the synched pointer and wait for all copies to get destroyed.
|
||||
*
|
||||
*/
|
||||
void DestroyAndSync() {
|
||||
ptr_.reset();
|
||||
SyncOnDelete();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get (copy) the underlying shared pointer.
|
||||
*
|
||||
* @return std::shared_ptr<TContext>
|
||||
*/
|
||||
std::shared_ptr<TContext> get() { return ptr_; }
|
||||
std::shared_ptr<const TContext> get() const { return ptr_; }
|
||||
|
||||
/**
|
||||
* @brief Return saved configuration (metadata)
|
||||
*
|
||||
* @return TConfig
|
||||
*/
|
||||
TConfig config() { return config_; }
|
||||
const TConfig &config() const { return config_; }
|
||||
|
||||
void timeout(const std::chrono::milliseconds to) { timeout_ = to; }
|
||||
std::chrono::milliseconds timeout() const { return timeout_; }
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Block until OnDelete gets called.
|
||||
*
|
||||
*/
|
||||
void SyncOnDelete() {
|
||||
std::unique_lock<std::mutex> lock(in_use_mtx_);
|
||||
if (!in_use_cv_.wait_for(lock, timeout_, [this] { return !in_use_; })) {
|
||||
throw utils::BasicException("Syncronization timeout!");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Custom destructor used to sync the shared_ptr release.
|
||||
*
|
||||
* @param p Pointer to the undelying object.
|
||||
*/
|
||||
void OnDelete(TContext *p) {
|
||||
delete p;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(in_use_mtx_);
|
||||
in_use_ = false;
|
||||
}
|
||||
in_use_cv_.notify_all();
|
||||
}
|
||||
|
||||
bool in_use_{true}; //!< Flag used to signal sync
|
||||
mutable std::mutex in_use_mtx_; //!< Mutex used in the cv sync
|
||||
mutable std::condition_variable in_use_cv_; //!< cv used to signal a sync
|
||||
std::chrono::milliseconds timeout_; //!< Synchronization timeout in ms
|
||||
TConfig config_; //!< Additional metadata associated with the context
|
||||
std::shared_ptr<TContext> ptr_; //!< Pointer being synced
|
||||
};
|
||||
|
||||
template <typename TContext>
|
||||
class SyncPtr<TContext, void> {
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new synched pointer.
|
||||
*
|
||||
* @tparam TArgs variable templates used by the TContext constructor
|
||||
* @param config Additional metadata associated with context
|
||||
* @param args Arguments to pass to TContext constructor
|
||||
*/
|
||||
template <typename... TArgs>
|
||||
explicit SyncPtr(TArgs &&...args)
|
||||
: timeout_{1000}, ptr_{new TContext(std::forward<TArgs>(args)...), [this](TContext *ptr) {
|
||||
this->OnDelete(ptr);
|
||||
}} {}
|
||||
|
||||
~SyncPtr() = default;
|
||||
|
||||
SyncPtr(const SyncPtr &) = delete;
|
||||
SyncPtr &operator=(const SyncPtr &) = delete;
|
||||
SyncPtr(SyncPtr &&) noexcept = delete;
|
||||
SyncPtr &operator=(SyncPtr &&) noexcept = delete;
|
||||
|
||||
/**
|
||||
* @brief Destroy the synched pointer and wait for all copies to get destroyed.
|
||||
*
|
||||
*/
|
||||
void DestroyAndSync() {
|
||||
ptr_.reset();
|
||||
SyncOnDelete();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get (copy) the underlying shared pointer.
|
||||
*
|
||||
* @return std::shared_ptr<TContext>
|
||||
*/
|
||||
std::shared_ptr<TContext> get() { return ptr_; }
|
||||
std::shared_ptr<const TContext> get() const { return ptr_; }
|
||||
|
||||
void timeout(const std::chrono::milliseconds to) { timeout_ = to; }
|
||||
std::chrono::milliseconds timeout() const { return timeout_; }
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Block until OnDelete gets called.
|
||||
*
|
||||
*/
|
||||
void SyncOnDelete() {
|
||||
std::unique_lock<std::mutex> lock(in_use_mtx_);
|
||||
if (!in_use_cv_.wait_for(lock, timeout_, [this] { return !in_use_; })) {
|
||||
throw utils::BasicException("Syncronization timeout!");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Custom destructor used to sync the shared_ptr release.
|
||||
*
|
||||
* @param p Pointer to the undelying object.
|
||||
*/
|
||||
void OnDelete(TContext *p) {
|
||||
delete p;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(in_use_mtx_);
|
||||
in_use_ = false;
|
||||
}
|
||||
in_use_cv_.notify_all();
|
||||
}
|
||||
|
||||
bool in_use_{true}; //!< Flag used to signal sync
|
||||
mutable std::mutex in_use_mtx_; //!< Mutex used in the cv sync
|
||||
mutable std::condition_variable in_use_cv_; //!< cv used to signal a sync
|
||||
std::chrono::milliseconds timeout_; //!< Synchronization timeout in ms
|
||||
std::shared_ptr<TContext> ptr_; //!< Pointer being synced
|
||||
};
|
||||
|
||||
} // namespace memgraph::utils
|
@ -184,6 +184,8 @@ enum class TypeId : uint64_t {
|
||||
AST_TRANSACTION_QUEUE_QUERY,
|
||||
AST_EXISTS,
|
||||
AST_CALL_SUBQUERY,
|
||||
AST_MULTI_DATABASE_QUERY,
|
||||
AST_SHOW_DATABASES,
|
||||
// Symbol
|
||||
SYMBOL,
|
||||
};
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -30,10 +30,10 @@ TEST(Network, Server) {
|
||||
std::cout << endpoint << std::endl;
|
||||
|
||||
// initialize server
|
||||
TestData session_data;
|
||||
TestData session_context;
|
||||
int N = (std::thread::hardware_concurrency() + 1) / 2;
|
||||
ContextT context;
|
||||
ServerT server(endpoint, &session_data, &context, -1, "Test", N);
|
||||
ServerT server(endpoint, &session_context, &context, -1, "Test", N);
|
||||
ASSERT_TRUE(server.Start());
|
||||
|
||||
const auto &ep = server.endpoint();
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -32,9 +32,9 @@ TEST(Network, SessionLeak) {
|
||||
Endpoint endpoint(interface, 0);
|
||||
|
||||
// initialize server
|
||||
TestData session_data;
|
||||
TestData session_context;
|
||||
ContextT context;
|
||||
ServerT server(endpoint, &session_data, &context, -1, "Test", 2);
|
||||
ServerT server(endpoint, &session_context, &context, -1, "Test", 2);
|
||||
ASSERT_TRUE(server.Start());
|
||||
|
||||
// start clients
|
||||
|
@ -25,9 +25,14 @@ def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}
|
||||
def connect(**kwargs) -> mgclient.Connection:
|
||||
connection = mgclient.connect(host="localhost", port=7687, **kwargs)
|
||||
connection.autocommit = True
|
||||
yield connection
|
||||
cursor = connection.cursor()
|
||||
execute_and_fetch_all(cursor, "USE DATABASE memgraph")
|
||||
try:
|
||||
execute_and_fetch_all(cursor, "DROP DATABASE clean")
|
||||
except:
|
||||
pass
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
|
||||
yield connection
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -21,6 +21,18 @@ QUERY_PLAN = "QUERY PLAN"
|
||||
# ------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def multi_db(request, connect):
|
||||
cursor = connect.cursor()
|
||||
if request.param:
|
||||
execute_and_fetch_all(cursor, "CREATE DATABASE clean")
|
||||
execute_and_fetch_all(cursor, "USE DATABASE clean")
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
|
||||
pass
|
||||
yield connect
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
@pytest.mark.parametrize(
|
||||
"delete_query",
|
||||
[
|
||||
@ -30,9 +42,9 @@ QUERY_PLAN = "QUERY PLAN"
|
||||
"ANALYZE GRAPH ON LABELS :Label, :NONEXISTING DELETE STATISTICS",
|
||||
],
|
||||
)
|
||||
def test_analyze_graph_delete_statistics(delete_query, connect):
|
||||
def test_analyze_graph_delete_statistics(delete_query, multi_db):
|
||||
"""Tests that all variants of delete queries work as expected."""
|
||||
cursor = connect.cursor()
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id1: i}));")
|
||||
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 50) | CREATE (n:Label {id2: i % 5}));")
|
||||
execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);")
|
||||
@ -62,6 +74,7 @@ def test_analyze_graph_delete_statistics(delete_query, connect):
|
||||
execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
@pytest.mark.parametrize(
|
||||
"analyze_query",
|
||||
[
|
||||
@ -71,11 +84,11 @@ def test_analyze_graph_delete_statistics(delete_query, connect):
|
||||
"ANALYZE GRAPH ON LABELS :Label, :NONEXISTING",
|
||||
],
|
||||
)
|
||||
def test_analyze_full_graph(analyze_query, connect):
|
||||
def test_analyze_full_graph(analyze_query, multi_db):
|
||||
"""Tests analyzing full graph and choosing better index based on the smaller average group size.
|
||||
It also tests querying based on labels and that nothing bad will happen by providing non-existing label.
|
||||
"""
|
||||
cursor = connect.cursor()
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id1: i}));")
|
||||
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 50) | CREATE (n:Label {id2: i % 5}));")
|
||||
execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);")
|
||||
@ -121,9 +134,10 @@ def test_analyze_full_graph(analyze_query, connect):
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def test_cardinality_different_avg_group_size_uniform_dist(connect):
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_cardinality_different_avg_group_size_uniform_dist(multi_db):
|
||||
"""Tests index optimization with indices both having uniform distribution but one has smaller avg. group size."""
|
||||
cursor = connect.cursor()
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id1: i}));")
|
||||
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id2: i % 20}));")
|
||||
execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);")
|
||||
@ -151,9 +165,10 @@ def test_cardinality_different_avg_group_size_uniform_dist(connect):
|
||||
execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);")
|
||||
|
||||
|
||||
def test_cardinality_same_avg_group_size_uniform_dist_diff_vertex_count(connect):
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_cardinality_same_avg_group_size_uniform_dist_diff_vertex_count(multi_db):
|
||||
"""Tests index choosing where both indices have uniform key distribution with same avg. group size but one has less vertices."""
|
||||
cursor = connect.cursor()
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id1: i}));")
|
||||
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 50) | CREATE (n:Label {id2: i}));")
|
||||
execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);")
|
||||
@ -181,9 +196,10 @@ def test_cardinality_same_avg_group_size_uniform_dist_diff_vertex_count(connect)
|
||||
execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);")
|
||||
|
||||
|
||||
def test_large_diff_in_num_vertices_v1(connect):
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_large_diff_in_num_vertices_v1(multi_db):
|
||||
"""Tests that when one index has > 10x vertices than the other one, it should be chosen no matter avg group size and uniform distribution."""
|
||||
cursor = connect.cursor()
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 1000) | CREATE (n:Label {id1: i}));")
|
||||
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 99) | CREATE (n:Label {id2: 1}));")
|
||||
execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);")
|
||||
@ -211,9 +227,10 @@ def test_large_diff_in_num_vertices_v1(connect):
|
||||
execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);")
|
||||
|
||||
|
||||
def test_large_diff_in_num_vertices_v2(connect):
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_large_diff_in_num_vertices_v2(multi_db):
|
||||
"""Tests that when one index has > 10x vertices than the other one, it should be chosen no matter avg group size and uniform distribution."""
|
||||
cursor = connect.cursor()
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 99) | CREATE (n:Label {id1: 1}));")
|
||||
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 1000) | CREATE (n:Label {id2: i}));")
|
||||
execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);")
|
||||
@ -241,9 +258,10 @@ def test_large_diff_in_num_vertices_v2(connect):
|
||||
execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);")
|
||||
|
||||
|
||||
def test_same_avg_group_size_diff_distribution(connect):
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_same_avg_group_size_diff_distribution(multi_db):
|
||||
"""Tests index choice decision based on key distribution."""
|
||||
cursor = connect.cursor()
|
||||
cursor = multi_db.cursor()
|
||||
# Setup first key distribution
|
||||
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 10) | CREATE (n:Label {id1: 1}));")
|
||||
execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 30) | CREATE (n:Label {id1: 2}));")
|
||||
|
@ -66,6 +66,11 @@ startup_config_dict = {
|
||||
"Time in seconds after which inactive Bolt sessions will be closed.",
|
||||
),
|
||||
"data_directory": ("mg_data", "mg_data", "Path to directory in which to save all permanent data."),
|
||||
"data_recovery_on_startup": (
|
||||
"false",
|
||||
"false",
|
||||
"Controls whether the database recovers persisted data on startup.",
|
||||
),
|
||||
"isolation_level": (
|
||||
"SNAPSHOT_ISOLATION",
|
||||
"SNAPSHOT_ISOLATION",
|
||||
@ -133,11 +138,6 @@ startup_config_dict = {
|
||||
"The number of edges and vertices stored in a batch in a snapshot file.",
|
||||
),
|
||||
"storage_properties_on_edges": ("false", "true", "Controls whether edges have properties."),
|
||||
"storage_recover_on_startup": (
|
||||
"false",
|
||||
"false",
|
||||
"Controls whether the storage recovers persisted data on startup.",
|
||||
),
|
||||
"storage_recovery_thread_count": ("12", "12", "The number of threads used to recover persisted data from disk."),
|
||||
"storage_snapshot_interval_sec": (
|
||||
"0",
|
||||
@ -157,6 +157,11 @@ startup_config_dict = {
|
||||
"Issue a 'fsync' call after this amount of transactions are written to the WAL file. Set to 1 for fully synchronous operation.",
|
||||
),
|
||||
"storage_wal_file_size_kib": ("20480", "20480", "Minimum file size of each WAL file."),
|
||||
"storage_delete_on_drop": (
|
||||
"true",
|
||||
"true",
|
||||
"If set to true the query 'DROP DATABASE x' will delete the underlying storage as well.",
|
||||
),
|
||||
"stream_transaction_conflict_retries": (
|
||||
"30",
|
||||
"30",
|
||||
|
@ -16,6 +16,7 @@ import mgclient
|
||||
import pytest
|
||||
|
||||
default_storage_info_dict = {
|
||||
"name": "memgraph",
|
||||
"vertex_count": 0,
|
||||
"edge_count": 0,
|
||||
"average_degree": 0,
|
||||
@ -55,7 +56,7 @@ def test_does_default_config_match():
|
||||
machine_dependent_configurations = ["memory_usage", "disk_usage", "memory_allocated", "allocation_limit"]
|
||||
|
||||
# Number of different data-points returned by SHOW STORAGE INFO
|
||||
assert len(config) == 11
|
||||
assert len(config) == 12
|
||||
|
||||
for conf in config:
|
||||
conf_name = conf[0]
|
||||
|
@ -6,3 +6,4 @@ copy_fine_grained_access_e2e_python_files(common.py)
|
||||
copy_fine_grained_access_e2e_python_files(create_delete_filtering_tests.py)
|
||||
copy_fine_grained_access_e2e_python_files(edge_type_filtering_tests.py)
|
||||
copy_fine_grained_access_e2e_python_files(path_filtering_tests.py)
|
||||
copy_fine_grained_access_e2e_python_files(show_db.py)
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2021 Memgraph Ltd.
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -12,6 +12,23 @@
|
||||
import mgclient
|
||||
|
||||
|
||||
def switch_db(cursor):
|
||||
execute_and_fetch_all(cursor, "USE DATABASE clean;")
|
||||
|
||||
|
||||
def create_multi_db(cursor, switch):
|
||||
execute_and_fetch_all(cursor, "USE DATABASE memgraph;")
|
||||
try:
|
||||
execute_and_fetch_all(cursor, "DROP DATABASE clean;")
|
||||
except:
|
||||
pass
|
||||
execute_and_fetch_all(cursor, "CREATE DATABASE clean;")
|
||||
if switch:
|
||||
switch_db(cursor)
|
||||
reset_and_prepare(cursor)
|
||||
execute_and_fetch_all(cursor, "USE DATABASE memgraph;")
|
||||
|
||||
|
||||
def reset_and_prepare(admin_cursor):
|
||||
execute_and_fetch_all(admin_cursor, "REVOKE LABELS * FROM user;")
|
||||
execute_and_fetch_all(admin_cursor, "REVOKE EDGE_TYPES * FROM user;")
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -16,168 +16,224 @@ import pytest
|
||||
from mgclient import DatabaseError
|
||||
|
||||
|
||||
def test_create_node_all_labels_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_create_node_all_labels_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
results = common.execute_and_fetch_all(user_connnection.cursor(), "CREATE (n:label1) RETURN n;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;")
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
|
||||
def test_create_node_all_labels_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_create_node_all_labels_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(user_connnection.cursor(), "CREATE (n:label1) RETURN n;")
|
||||
common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;")
|
||||
|
||||
|
||||
def test_create_node_specific_label_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_create_node_specific_label_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;")
|
||||
results = common.execute_and_fetch_all(user_connnection.cursor(), "CREATE (n:label1) RETURN n;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;")
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
|
||||
def test_create_node_specific_label_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_create_node_specific_label_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(user_connnection.cursor(), "CREATE (n:label1) RETURN n;")
|
||||
common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;")
|
||||
|
||||
|
||||
def test_delete_node_all_labels_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_delete_node_all_labels_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) DELETE n;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;")
|
||||
|
||||
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) RETURN n;")
|
||||
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) RETURN n;")
|
||||
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_delete_node_all_labels_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_delete_node_all_labels_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) DELETE n")
|
||||
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;")
|
||||
|
||||
|
||||
def test_delete_node_specific_label_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_delete_node_specific_label_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete TO user;")
|
||||
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) DELETE n;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(admin_connection.cursor())
|
||||
results = common.execute_and_fetch_all(admin_connection.cursor(), "MATCH (n:test_delete) RETURN n;")
|
||||
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_delete_node_specific_label_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_delete_node_specific_label_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) DELETE n;")
|
||||
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;")
|
||||
|
||||
|
||||
def test_create_edge_all_labels_all_edge_types_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_create_edge_all_labels_all_edge_types_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
|
||||
def test_create_edge_all_labels_all_edge_types_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_create_edge_all_labels_all_edge_types_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
|
||||
)
|
||||
|
||||
|
||||
def test_create_edge_all_labels_denied_all_edge_types_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_create_edge_all_labels_denied_all_edge_types_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
|
||||
)
|
||||
|
||||
|
||||
def test_create_edge_all_labels_granted_all_edge_types_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_create_edge_all_labels_granted_all_edge_types_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
|
||||
)
|
||||
|
||||
|
||||
def test_create_edge_all_labels_granted_specific_edge_types_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_create_edge_all_labels_granted_specific_edge_types_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(
|
||||
admin_connection.cursor(),
|
||||
"GRANT UPDATE ON EDGE_TYPES :edge_type TO user;",
|
||||
)
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
|
||||
)
|
||||
|
||||
|
||||
def test_create_edge_first_node_label_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_create_edge_first_node_label_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label2 TO user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -185,17 +241,21 @@ def test_create_edge_first_node_label_granted():
|
||||
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;",
|
||||
)
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
|
||||
)
|
||||
|
||||
|
||||
def test_create_edge_second_node_label_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_create_edge_second_node_label_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label2 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -203,62 +263,78 @@ def test_create_edge_second_node_label_granted():
|
||||
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;",
|
||||
)
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_edge_all_labels_denied_all_edge_types_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_delete_edge_all_labels_denied_all_edge_types_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_edge_all_labels_granted_all_edge_types_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_delete_edge_all_labels_granted_all_edge_types_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_edge_all_labels_granted_specific_edge_types_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_delete_edge_all_labels_granted_specific_edge_types_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(
|
||||
admin_connection.cursor(),
|
||||
"GRANT UPDATE ON EDGE_TYPES :edge_type_delete TO user;",
|
||||
)
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_edge_first_node_label_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_delete_edge_first_node_label_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_1 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete_2 TO user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -266,17 +342,21 @@ def test_delete_edge_first_node_label_granted():
|
||||
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type_delete TO user;",
|
||||
)
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_edge_second_node_label_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_delete_edge_second_node_label_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_2 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete_1 TO user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -284,159 +364,209 @@ def test_delete_edge_second_node_label_granted():
|
||||
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type_delete TO user;",
|
||||
)
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_node_with_edge_label_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_delete_node_with_edge_label_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(
|
||||
admin_connection.cursor(),
|
||||
"GRANT UPDATE ON LABELS :test_delete_1 TO user;",
|
||||
)
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n) DETACH DELETE n;")
|
||||
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n) DETACH DELETE n;")
|
||||
|
||||
|
||||
def test_delete_node_with_edge_label_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_delete_node_with_edge_label_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(
|
||||
admin_connection.cursor(),
|
||||
"GRANT CREATE_DELETE ON LABELS :test_delete_1 TO user;",
|
||||
)
|
||||
|
||||
common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n) DETACH DELETE n;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n) DETACH DELETE n;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(admin_connection.cursor())
|
||||
results = common.execute_and_fetch_all(admin_connection.cursor(), "MATCH (n:test_delete_1) RETURN n;")
|
||||
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_merge_node_all_labels_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_merge_node_all_labels_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
results = common.execute_and_fetch_all(user_connnection.cursor(), "MERGE (n:label1) RETURN n;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;")
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
|
||||
def test_merge_node_all_labels_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_merge_node_all_labels_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(user_connnection.cursor(), "MERGE (n:label1) RETURN n;")
|
||||
common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;")
|
||||
|
||||
|
||||
def test_merge_node_specific_label_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_merge_node_specific_label_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;")
|
||||
results = common.execute_and_fetch_all(user_connnection.cursor(), "MERGE (n:label1) RETURN n;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;")
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
|
||||
def test_merge_node_specific_label_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_merge_node_specific_label_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(user_connnection.cursor(), "MERGE (n:label1) RETURN n;")
|
||||
common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;")
|
||||
|
||||
|
||||
def test_merge_edge_all_labels_all_edge_types_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_merge_edge_all_labels_all_edge_types_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
|
||||
def test_merge_edge_all_labels_all_edge_types_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_merge_edge_all_labels_all_edge_types_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
|
||||
)
|
||||
|
||||
|
||||
def test_merge_edge_all_labels_denied_all_edge_types_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_merge_edge_all_labels_denied_all_edge_types_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
|
||||
)
|
||||
|
||||
|
||||
def test_merge_edge_all_labels_granted_all_edge_types_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_merge_edge_all_labels_granted_all_edge_types_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
|
||||
)
|
||||
|
||||
|
||||
def test_merge_edge_all_labels_granted_specific_edge_types_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_merge_edge_all_labels_granted_specific_edge_types_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(
|
||||
admin_connection.cursor(),
|
||||
"GRANT UPDATE ON EDGE_TYPES :edge_type TO user;",
|
||||
)
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
|
||||
)
|
||||
|
||||
|
||||
def test_merge_edge_first_node_label_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_merge_edge_first_node_label_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label2 TO user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -444,17 +574,21 @@ def test_merge_edge_first_node_label_granted():
|
||||
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;",
|
||||
)
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
|
||||
)
|
||||
|
||||
|
||||
def test_merge_edge_second_node_label_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_merge_edge_second_node_label_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label2 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -462,64 +596,86 @@ def test_merge_edge_second_node_label_granted():
|
||||
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;",
|
||||
)
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;",
|
||||
)
|
||||
|
||||
|
||||
def test_set_label_when_label_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_set_label_when_label_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :update_label_2 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (p:test_delete) SET p:update_label_2;")
|
||||
|
||||
|
||||
def test_set_label_when_label_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_set_label_when_label_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :update_label_2 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (p:test_delete) SET p:update_label_2;")
|
||||
|
||||
|
||||
def test_remove_label_when_label_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_remove_label_when_label_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (p:test_delete) REMOVE p:test_delete;")
|
||||
|
||||
|
||||
def test_remove_label_when_label_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_remove_label_when_label_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :update_label_2 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
with pytest.raises(DatabaseError):
|
||||
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (p:test_delete) REMOVE p:test_delete;")
|
||||
|
||||
|
||||
def test_merge_nodes_pass_when_having_create_delete():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_merge_nodes_pass_when_having_create_delete(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
|
||||
common.reset_and_prepare(admin_connection.cursor())
|
||||
common.create_multi_db(admin_connection.cursor(), switch)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(
|
||||
user_connection.cursor(),
|
||||
"UNWIND [{id: '1', lat: 10, lng: 10}, {id: '2', lat: 10, lng: 10}, {id: '3', lat: 10, lng: 10}] AS row MERGE (o:Location {id: row.id}) RETURN o;",
|
||||
|
@ -1,91 +1,116 @@
|
||||
import common
|
||||
import sys
|
||||
|
||||
import common
|
||||
import pytest
|
||||
|
||||
|
||||
def test_all_edge_types_all_labels_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_all_edge_types_all_labels_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
|
||||
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
|
||||
assert len(results) == 3
|
||||
|
||||
|
||||
def test_deny_all_edge_types_and_all_labels():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_deny_all_edge_types_and_all_labels(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
|
||||
|
||||
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_revoke_all_edge_types_and_all_labels():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_revoke_all_edge_types_and_all_labels(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
|
||||
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_deny_edge_type():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_deny_edge_type(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1, :label2, :label3 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType2 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edgeType1 TO user;")
|
||||
|
||||
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
|
||||
def test_denied_node_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_denied_node_label(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1,:label3 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType1, :edgeType2 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label2 TO user;")
|
||||
|
||||
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
|
||||
def test_denied_one_of_node_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_denied_one_of_node_label(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1,:label2 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType1, :edgeType2 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label3 TO user;")
|
||||
|
||||
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
|
||||
def test_revoke_all_labels():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_revoke_all_labels(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_revoke_all_edge_types():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_revoke_all_edge_types(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
|
||||
|
||||
assert len(results) == 0
|
||||
|
||||
|
@ -1,22 +1,26 @@
|
||||
import common
|
||||
import sys
|
||||
|
||||
import common
|
||||
import pytest
|
||||
|
||||
|
||||
def test_weighted_shortest_path_all_edge_types_all_labels_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_weighted_shortest_path_all_edge_types_all_labels_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
total_paths_results = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);",
|
||||
)
|
||||
path_result = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length,nodes(p);",
|
||||
)
|
||||
|
||||
@ -47,24 +51,28 @@ def test_weighted_shortest_path_all_edge_types_all_labels_granted():
|
||||
assert all(node.id in expected_path for node in path_result[0][1])
|
||||
|
||||
|
||||
def test_weighted_shortest_path_all_edge_types_all_labels_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_weighted_shortest_path_all_edge_types_all_labels_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(), "MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN p;"
|
||||
user_connection.cursor(), "MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN p;"
|
||||
)
|
||||
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_weighted_shortest_path_denied_start():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_weighted_shortest_path_denied_start(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -73,17 +81,20 @@ def test_weighted_shortest_path_denied_start():
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
path_length_result = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length;",
|
||||
)
|
||||
|
||||
assert len(path_length_result) == 0
|
||||
|
||||
|
||||
def test_weighted_shortest_path_denied_destination():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_weighted_shortest_path_denied_destination(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -92,17 +103,20 @@ def test_weighted_shortest_path_denied_destination():
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
path_length_result = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length;",
|
||||
)
|
||||
|
||||
assert len(path_length_result) == 0
|
||||
|
||||
|
||||
def test_weighted_shortest_path_denied_label_1():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_weighted_shortest_path_denied_label_1(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -111,13 +125,15 @@ def test_weighted_shortest_path_denied_label_1():
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
total_paths_results = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);",
|
||||
)
|
||||
|
||||
path_result = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length, nodes(p);",
|
||||
)
|
||||
|
||||
@ -143,9 +159,10 @@ def test_weighted_shortest_path_denied_label_1():
|
||||
assert all(node.id in expected_path for node in path_result[0][1])
|
||||
|
||||
|
||||
def test_weighted_shortest_path_denied_edge_type_3():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_weighted_shortest_path_denied_edge_type_3(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
|
||||
@ -154,13 +171,15 @@ def test_weighted_shortest_path_denied_edge_type_3():
|
||||
)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
path_result = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length, nodes(p);",
|
||||
)
|
||||
|
||||
total_paths_results = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);",
|
||||
)
|
||||
|
||||
@ -191,16 +210,19 @@ def test_weighted_shortest_path_denied_edge_type_3():
|
||||
assert all(node.id in expected_path for node in path_result[0][1])
|
||||
|
||||
|
||||
def test_dfs_all_edge_types_all_labels_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_dfs_all_edge_types_all_labels_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_paths = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH path=(n:label0)-[* 1..3]->(m:label4) RETURN extract( node in nodes(path) | node.id);",
|
||||
)
|
||||
|
||||
@ -210,22 +232,26 @@ def test_dfs_all_edge_types_all_labels_granted():
|
||||
assert all(path[0] in expected_paths for path in source_destination_paths)
|
||||
|
||||
|
||||
def test_dfs_all_edge_types_all_labels_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_dfs_all_edge_types_all_labels_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
|
||||
|
||||
total_paths_results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH p=(n)-[*]->(m) RETURN p;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
total_paths_results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH p=(n)-[*]->(m) RETURN p;")
|
||||
|
||||
assert len(total_paths_results) == 0
|
||||
|
||||
|
||||
def test_dfs_denied_start():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_dfs_denied_start(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -234,16 +260,19 @@ def test_dfs_denied_start():
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_path = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(), "MATCH p=(n:label0)-[*]->(m:label4) RETURN p;"
|
||||
user_connection.cursor(), "MATCH p=(n:label0)-[*]->(m:label4) RETURN p;"
|
||||
)
|
||||
|
||||
assert len(source_destination_path) == 0
|
||||
|
||||
|
||||
def test_dfs_denied_destination():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_dfs_denied_destination(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -252,16 +281,19 @@ def test_dfs_denied_destination():
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_path = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(), "MATCH p=(n:label0)-[*]->(m:label4) RETURN p;"
|
||||
user_connection.cursor(), "MATCH p=(n:label0)-[*]->(m:label4) RETURN p;"
|
||||
)
|
||||
|
||||
assert len(source_destination_path) == 0
|
||||
|
||||
|
||||
def test_dfs_denied_label_1():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_dfs_denied_label_1(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -269,8 +301,11 @@ def test_dfs_denied_label_1():
|
||||
)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_paths = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[* 1..3]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
|
||||
)
|
||||
|
||||
@ -280,9 +315,10 @@ def test_dfs_denied_label_1():
|
||||
assert all(path[0] in expected_paths for path in source_destination_paths)
|
||||
|
||||
|
||||
def test_dfs_denied_edge_type_3():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_dfs_denied_edge_type_3(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
|
||||
@ -292,8 +328,10 @@ def test_dfs_denied_edge_type_3():
|
||||
)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_path = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[r * 1..3]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
|
||||
)
|
||||
|
||||
@ -303,16 +341,19 @@ def test_dfs_denied_edge_type_3():
|
||||
assert source_destination_path[0][0] == expected_path
|
||||
|
||||
|
||||
def test_bfs_sts_all_edge_types_all_labels_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_bfs_sts_all_edge_types_all_labels_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_path = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
|
||||
)
|
||||
|
||||
@ -322,24 +363,28 @@ def test_bfs_sts_all_edge_types_all_labels_granted():
|
||||
assert source_destination_path[0][0] == expected_path
|
||||
|
||||
|
||||
def test_bfs_sts_all_edge_types_all_labels_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_bfs_sts_all_edge_types_all_labels_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
total_paths_results = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n)-[r *BFS]->(m) RETURN p;"
|
||||
user_connection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n)-[r *BFS]->(m) RETURN p;"
|
||||
)
|
||||
|
||||
assert len(total_paths_results) == 0
|
||||
|
||||
|
||||
def test_bfs_sts_denied_start():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_bfs_sts_denied_start(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -348,16 +393,19 @@ def test_bfs_sts_denied_start():
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_path = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
|
||||
user_connection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
|
||||
)
|
||||
|
||||
assert len(source_destination_path) == 0
|
||||
|
||||
|
||||
def test_bfs_sts_denied_destination():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_bfs_sts_denied_destination(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -366,16 +414,19 @@ def test_bfs_sts_denied_destination():
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_path = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
|
||||
user_connection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
|
||||
)
|
||||
|
||||
assert len(source_destination_path) == 0
|
||||
|
||||
|
||||
def test_bfs_sts_denied_label_1():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_bfs_sts_denied_label_1(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -383,8 +434,11 @@ def test_bfs_sts_denied_label_1():
|
||||
)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_path = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
|
||||
)
|
||||
expected_path = [0, 2, 4, 5]
|
||||
@ -393,9 +447,10 @@ def test_bfs_sts_denied_label_1():
|
||||
assert source_destination_path[0][0] == expected_path
|
||||
|
||||
|
||||
def test_bfs_sts_denied_edge_type_3():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_bfs_sts_denied_edge_type_3(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
|
||||
@ -404,8 +459,10 @@ def test_bfs_sts_denied_edge_type_3():
|
||||
)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_path = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
|
||||
)
|
||||
expected_path = [0, 2, 4, 5]
|
||||
@ -414,16 +471,19 @@ def test_bfs_sts_denied_edge_type_3():
|
||||
assert source_destination_path[0][0] == expected_path
|
||||
|
||||
|
||||
def test_bfs_single_source_all_edge_types_all_labels_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_bfs_single_source_all_edge_types_all_labels_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_path = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
|
||||
)
|
||||
|
||||
@ -433,22 +493,26 @@ def test_bfs_single_source_all_edge_types_all_labels_granted():
|
||||
assert source_destination_path[0][0] == expected_path
|
||||
|
||||
|
||||
def test_bfs_single_source_all_edge_types_all_labels_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_bfs_single_source_all_edge_types_all_labels_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
|
||||
|
||||
total_paths_results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH p=(n)-[r *BFS]->(m) RETURN p;")
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
total_paths_results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH p=(n)-[r *BFS]->(m) RETURN p;")
|
||||
|
||||
assert len(total_paths_results) == 0
|
||||
|
||||
|
||||
def test_bfs_single_source_denied_start():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_bfs_single_source_denied_start(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -457,16 +521,19 @@ def test_bfs_single_source_denied_start():
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_path = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
|
||||
user_connection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
|
||||
)
|
||||
|
||||
assert len(source_destination_path) == 0
|
||||
|
||||
|
||||
def test_bfs_single_source_denied_destination():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_bfs_single_source_denied_destination(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -475,16 +542,19 @@ def test_bfs_single_source_denied_destination():
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_path = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
|
||||
user_connection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;"
|
||||
)
|
||||
|
||||
assert len(source_destination_path) == 0
|
||||
|
||||
|
||||
def test_bfs_single_source_denied_label_1():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_bfs_single_source_denied_label_1(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -492,8 +562,11 @@ def test_bfs_single_source_denied_label_1():
|
||||
)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_path = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
|
||||
)
|
||||
|
||||
@ -503,9 +576,10 @@ def test_bfs_single_source_denied_label_1():
|
||||
assert source_destination_path[0][0] == expected_path
|
||||
|
||||
|
||||
def test_bfs_single_source_denied_edge_type_3():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_bfs_single_source_denied_edge_type_3(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
|
||||
@ -514,8 +588,10 @@ def test_bfs_single_source_denied_edge_type_3():
|
||||
)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
source_destination_path = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);",
|
||||
)
|
||||
|
||||
@ -525,20 +601,23 @@ def test_bfs_single_source_denied_edge_type_3():
|
||||
assert source_destination_path[0][0] == expected_path
|
||||
|
||||
|
||||
def test_all_shortest_paths_when_all_edge_types_all_labels_granted():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_all_shortest_paths_when_all_edge_types_all_labels_granted(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
total_paths_results = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);",
|
||||
)
|
||||
path_result = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length,nodes(p);",
|
||||
)
|
||||
|
||||
@ -569,24 +648,28 @@ def test_all_shortest_paths_when_all_edge_types_all_labels_granted():
|
||||
assert all(node.id in expected_path for node in path_result[0][1])
|
||||
|
||||
|
||||
def test_all_shortest_paths_when_all_edge_types_all_labels_denied():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_all_shortest_paths_when_all_edge_types_all_labels_denied(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
results = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(), "MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN p;"
|
||||
user_connection.cursor(), "MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN p;"
|
||||
)
|
||||
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_all_shortest_paths_when_denied_start():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_all_shortest_paths_when_denied_start(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -595,17 +678,20 @@ def test_all_shortest_paths_when_denied_start():
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
path_length_result = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length;",
|
||||
)
|
||||
|
||||
assert len(path_length_result) == 0
|
||||
|
||||
|
||||
def test_all_shortest_paths_when_denied_destination():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_all_shortest_paths_when_denied_destination(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -614,17 +700,20 @@ def test_all_shortest_paths_when_denied_destination():
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
path_length_result = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length;",
|
||||
)
|
||||
|
||||
assert len(path_length_result) == 0
|
||||
|
||||
|
||||
def test_all_shortest_paths_when_denied_label_1():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_all_shortest_paths_when_denied_label_1(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(
|
||||
@ -633,13 +722,15 @@ def test_all_shortest_paths_when_denied_label_1():
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
total_paths_results = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);",
|
||||
)
|
||||
|
||||
path_result = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length, nodes(p);",
|
||||
)
|
||||
|
||||
@ -665,9 +756,10 @@ def test_all_shortest_paths_when_denied_label_1():
|
||||
assert all(node.id in expected_path for node in path_result[0][1])
|
||||
|
||||
|
||||
def test_all_shortest_paths_when_denied_edge_type_3():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_all_shortest_paths_when_denied_edge_type_3(switch):
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connnection = common.connect(username="user", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
|
||||
@ -676,13 +768,15 @@ def test_all_shortest_paths_when_denied_edge_type_3():
|
||||
)
|
||||
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
|
||||
|
||||
if switch:
|
||||
common.switch_db(user_connection.cursor())
|
||||
path_result = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length, nodes(p);",
|
||||
)
|
||||
|
||||
total_paths_results = common.execute_and_fetch_all(
|
||||
user_connnection.cursor(),
|
||||
user_connection.cursor(),
|
||||
"MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);",
|
||||
)
|
||||
|
||||
|
36
tests/e2e/fine_grained_access/show_db.py
Normal file
36
tests/e2e/fine_grained_access/show_db.py
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
# License, and you may not use this file except in compliance with the Business Source License.
|
||||
#
|
||||
# As of the Change Date specified in that file, in accordance with
|
||||
# the Business Source License, use of this software will be governed
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import sys
|
||||
|
||||
import common
|
||||
import pytest
|
||||
from mgclient import DatabaseError
|
||||
|
||||
|
||||
def test_show_databases_w_user():
|
||||
admin_connection = common.connect(username="admin", password="test")
|
||||
user_connection = common.connect(username="user", password="test")
|
||||
user2_connection = common.connect(username="user2", password="test")
|
||||
user3_connection = common.connect(username="user3", password="test")
|
||||
|
||||
assert common.execute_and_fetch_all(admin_connection.cursor(), "SHOW DATABASES") == [
|
||||
("db1", ""),
|
||||
("db2", ""),
|
||||
("memgraph", "*"),
|
||||
]
|
||||
assert common.execute_and_fetch_all(user_connection.cursor(), "SHOW DATABASES") == [("db1", ""), ("memgraph", "*")]
|
||||
assert common.execute_and_fetch_all(user2_connection.cursor(), "SHOW DATABASES") == [("db2", "*")]
|
||||
assert common.execute_and_fetch_all(user3_connection.cursor(), "SHOW DATABASES") == [("db1", "*"), ("db2", "")]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__, "-rA"]))
|
@ -9,7 +9,9 @@ create_delete_filtering_cluster: &create_delete_filtering_cluster
|
||||
"CREATE USER admin IDENTIFIED BY 'test';",
|
||||
"CREATE USER user IDENTIFIED BY 'test';",
|
||||
"GRANT ALL PRIVILEGES TO admin;",
|
||||
"GRANT DATABASE * TO admin;",
|
||||
"GRANT ALL PRIVILEGES TO user;",
|
||||
"GRANT DATABASE * TO user;",
|
||||
]
|
||||
|
||||
edge_type_filtering_cluster: &edge_type_filtering_cluster
|
||||
@ -22,7 +24,9 @@ edge_type_filtering_cluster: &edge_type_filtering_cluster
|
||||
"CREATE USER admin IDENTIFIED BY 'test';",
|
||||
"CREATE USER user IDENTIFIED BY 'test';",
|
||||
"GRANT ALL PRIVILEGES TO admin;",
|
||||
"GRANT DATABASE * TO admin;",
|
||||
"GRANT ALL PRIVILEGES TO user;",
|
||||
"GRANT DATABASE * TO user;",
|
||||
"GRANT CREATE_DELETE ON LABELS * TO admin;",
|
||||
"GRANT CREATE_DELETE ON EDGE_TYPES * TO admin;",
|
||||
"MERGE (l1:label1 {name: 'test1'});",
|
||||
@ -32,6 +36,17 @@ edge_type_filtering_cluster: &edge_type_filtering_cluster
|
||||
"MATCH (l1:label1),(l3:label3) WHERE l1.name = 'test1' AND l3.name = 'test3' CREATE (l1)-[r:edgeType2]->(l3);",
|
||||
"MERGE (mix:label3:label1 {name: 'test4'});",
|
||||
"MATCH (l1:label1),(mix:label3) WHERE l1.name = 'test1' AND mix.name = 'test4' CREATE (l1)-[r:edgeType2]->(mix);",
|
||||
"CREATE DATABASE clean;",
|
||||
"USE DATABASE clean",
|
||||
"MATCH (n) DETACH DELETE n;",
|
||||
"MERGE (l1:label1 {name: 'test1'});",
|
||||
"MERGE (l2:label2 {name: 'test2'});",
|
||||
"MATCH (l1:label1),(l2:label2) WHERE l1.name = 'test1' AND l2.name = 'test2' CREATE (l1)-[r:edgeType1]->(l2);",
|
||||
"MERGE (l3:label3 {name: 'test3'});",
|
||||
"MATCH (l1:label1),(l3:label3) WHERE l1.name = 'test1' AND l3.name = 'test3' CREATE (l1)-[r:edgeType2]->(l3);",
|
||||
"MERGE (mix:label3:label1 {name: 'test4'});",
|
||||
"MATCH (l1:label1),(mix:label3) WHERE l1.name = 'test1' AND mix.name = 'test4' CREATE (l1)-[r:edgeType2]->(mix);",
|
||||
"USE DATABASE memgraph",
|
||||
]
|
||||
validation_queries: []
|
||||
|
||||
@ -45,7 +60,9 @@ path_filtering_cluster: &path_filtering_cluster
|
||||
"CREATE USER admin IDENTIFIED BY 'test';",
|
||||
"CREATE USER user IDENTIFIED BY 'test';",
|
||||
"GRANT ALL PRIVILEGES TO admin;",
|
||||
"GRANT DATABASE * TO admin;",
|
||||
"GRANT ALL PRIVILEGES TO user;",
|
||||
"GRANT DATABASE * TO user;",
|
||||
"MERGE (a:label0 {id: 0}) MERGE (b:label1 {id: 1}) CREATE (a)-[:edge_type_1 {weight: 6}]->(b);",
|
||||
"MERGE (a:label0 {id: 0}) MERGE (b:label2 {id: 2}) CREATE (a)-[:edge_type_1 {weight: 14}]->(b);",
|
||||
"MERGE (a:label1 {id: 1}) MERGE (b:label2 {id: 2}) CREATE (a)-[:edge_type_2 {weight: 1}]->(b);",
|
||||
@ -56,6 +73,47 @@ path_filtering_cluster: &path_filtering_cluster
|
||||
"MERGE (a:label3 {id: 4}) MERGE (b:label3 {id: 3}) CREATE (a)-[:edge_type_4 {weight: 1}]->(b);",
|
||||
"MERGE (a:label3 {id: 3}) MERGE (b:label4 {id: 5}) CREATE (a)-[:edge_type_4 {weight: 14}]->(b);",
|
||||
"MERGE (a:label3 {id: 4}) MERGE (b:label4 {id: 5}) CREATE (a)-[:edge_type_4 {weight: 8}]->(b);",
|
||||
"CREATE DATABASE clean;",
|
||||
"USE DATABASE clean",
|
||||
"MATCH (n) DETACH DELETE n;",
|
||||
"MERGE (a:label0 {id: 0}) MERGE (b:label1 {id: 1}) CREATE (a)-[:edge_type_1 {weight: 6}]->(b);",
|
||||
"MERGE (a:label0 {id: 0}) MERGE (b:label2 {id: 2}) CREATE (a)-[:edge_type_1 {weight: 14}]->(b);",
|
||||
"MERGE (a:label1 {id: 1}) MERGE (b:label2 {id: 2}) CREATE (a)-[:edge_type_2 {weight: 1}]->(b);",
|
||||
"MERGE (a:label2 {id: 2}) MERGE (b:label3 {id: 4}) CREATE (a)-[:edge_type_2 {weight: 10}]->(b);",
|
||||
"MERGE (a:label1 {id: 1}) MERGE (b:label3 {id: 3}) CREATE (a)-[:edge_type_3 {weight: 5}]->(b);",
|
||||
"MERGE (a:label2 {id: 2}) MERGE (b:label3 {id: 3}) CREATE (a)-[:edge_type_3 {weight: 7}]->(b);",
|
||||
"MERGE (a:label3 {id: 3}) MERGE (b:label3 {id: 4}) CREATE (a)-[:edge_type_4 {weight: 1}]->(b);",
|
||||
"MERGE (a:label3 {id: 4}) MERGE (b:label3 {id: 3}) CREATE (a)-[:edge_type_4 {weight: 1}]->(b);",
|
||||
"MERGE (a:label3 {id: 3}) MERGE (b:label4 {id: 5}) CREATE (a)-[:edge_type_4 {weight: 14}]->(b);",
|
||||
"MERGE (a:label3 {id: 4}) MERGE (b:label4 {id: 5}) CREATE (a)-[:edge_type_4 {weight: 8}]->(b);",
|
||||
"USE DATABASE memgraph",
|
||||
]
|
||||
|
||||
show_databases_w_user: &show_databases_w_user
|
||||
cluster:
|
||||
main:
|
||||
args: ["--bolt-port", "7687", "--log-level=TRACE"]
|
||||
log_file: "fine_grained_access.log"
|
||||
setup_queries:
|
||||
[
|
||||
"CREATE USER admin IDENTIFIED BY 'test';",
|
||||
"CREATE USER user IDENTIFIED BY 'test';",
|
||||
"CREATE USER user2 IDENTIFIED BY 'test';",
|
||||
"CREATE USER user3 IDENTIFIED BY 'test';",
|
||||
"CREATE DATABASE db1;",
|
||||
"CREATE DATABASE db2;",
|
||||
"GRANT ALL PRIVILEGES TO admin;",
|
||||
"GRANT DATABASE * TO admin;",
|
||||
"GRANT ALL PRIVILEGES TO user;",
|
||||
"GRANT DATABASE db1 TO user;",
|
||||
"GRANT ALL PRIVILEGES TO user2;",
|
||||
"GRANT DATABASE db2 TO user2;",
|
||||
"REVOKE DATABASE memgraph FROM user2;",
|
||||
"SET MAIN DATABASE db2 FOR user2",
|
||||
"GRANT ALL PRIVILEGES TO user3;",
|
||||
"GRANT DATABASE * TO user3;",
|
||||
"REVOKE DATABASE memgraph FROM user3;",
|
||||
"SET MAIN DATABASE db1 FOR user3",
|
||||
]
|
||||
|
||||
workloads:
|
||||
@ -63,7 +121,6 @@ workloads:
|
||||
binary: "tests/e2e/pytest_runner.sh"
|
||||
args: ["fine_grained_access/create_delete_filtering_tests.py"]
|
||||
<<: *create_delete_filtering_cluster
|
||||
|
||||
- name: "EdgeType filtering"
|
||||
binary: "tests/e2e/pytest_runner.sh"
|
||||
args: ["fine_grained_access/edge_type_filtering_tests.py"]
|
||||
@ -72,3 +129,7 @@ workloads:
|
||||
binary: "tests/e2e/pytest_runner.sh"
|
||||
args: ["fine_grained_access/path_filtering_tests.py"]
|
||||
<<: *path_filtering_cluster
|
||||
- name: "Show databases with users"
|
||||
binary: "tests/e2e/pytest_runner.sh"
|
||||
args: ["fine_grained_access/show_db.py"]
|
||||
<<: *show_databases_w_user
|
||||
|
@ -21,7 +21,14 @@ const driver = neo4j.driver(
|
||||
neo4j.auth.basic("", "")
|
||||
);
|
||||
|
||||
const neoSchema = new Neo4jGraphQL({ typeDefs, driver });
|
||||
const neoSchema = new Neo4jGraphQL({
|
||||
typeDefs, driver,
|
||||
config: {
|
||||
driverConfig: {
|
||||
database: "memgraph",
|
||||
},
|
||||
}
|
||||
});
|
||||
|
||||
neoSchema.getSchema().then((schema) => {
|
||||
const server = new ApolloServer({
|
||||
|
@ -9,12 +9,16 @@
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <fmt/core.h>
|
||||
#include <gflags/gflags.h>
|
||||
#include <mgclient.hpp>
|
||||
|
||||
#include "query/exceptions.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/timer.hpp"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
DEFINE_uint64(bolt_port, 7687, "Bolt port");
|
||||
DEFINE_uint64(timeout, 120, "Timeout seconds");
|
||||
|
||||
@ -53,16 +57,55 @@ bool IsDiskStorageMode(std::unique_ptr<mg::Client> &client) {
|
||||
return false;
|
||||
}
|
||||
|
||||
void CleanDatabase() {
|
||||
auto client = GetClient();
|
||||
void CleanDatabase(std::unique_ptr<mg::Client> &client) {
|
||||
MG_ASSERT(client->Execute("MATCH (n) DETACH DELETE n;"));
|
||||
client->DiscardAll();
|
||||
}
|
||||
|
||||
void SetupCleanDB() {
|
||||
auto client = GetClient();
|
||||
MG_ASSERT(client->Execute("USE DATABASE memgraph;"));
|
||||
client->DiscardAll();
|
||||
try {
|
||||
client->Execute("DROP DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
} catch (const mg::ClientException &) {
|
||||
// In case clean doesn't exist
|
||||
}
|
||||
MG_ASSERT(client->Execute("CREATE DATABASE clean;"));
|
||||
client->DiscardAll();
|
||||
MG_ASSERT(client->Execute("USE DATABASE clean;"));
|
||||
client->DiscardAll();
|
||||
CleanDatabase(client);
|
||||
}
|
||||
|
||||
void SwitchToDB(const std::string &name, std::unique_ptr<mg::Client> &client) {
|
||||
MG_ASSERT(client->Execute(fmt::format("USE DATABASE {};", name)));
|
||||
client->DiscardAll();
|
||||
}
|
||||
|
||||
void SwitchToCleanDB(std::unique_ptr<mg::Client> &client) { SwitchToDB("clean", client); }
|
||||
|
||||
void SwitchToSameDB(std::unique_ptr<mg::Client> &main, std::unique_ptr<mg::Client> &client) {
|
||||
MG_ASSERT(main->Execute("SHOW DATABASES;"));
|
||||
auto dbs = main->FetchAll();
|
||||
MG_ASSERT(dbs, "Failed to show databases");
|
||||
for (const auto &elem : *dbs) {
|
||||
MG_ASSERT(elem.size(), "Show databases wrong output");
|
||||
const auto &active = elem[1].ValueString();
|
||||
if (active == "*") {
|
||||
const auto &name = elem[0].ValueString();
|
||||
SwitchToDB(std::string(name), client);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TestSnapshotIsolation(std::unique_ptr<mg::Client> &client) {
|
||||
spdlog::info("Verifying SNAPSHOT ISOLATION");
|
||||
|
||||
auto creator = GetClient();
|
||||
SwitchToSameDB(client, creator);
|
||||
|
||||
MG_ASSERT(client->BeginTransaction());
|
||||
MG_ASSERT(creator->BeginTransaction());
|
||||
@ -89,13 +132,14 @@ void TestSnapshotIsolation(std::unique_ptr<mg::Client> &client) {
|
||||
"at a later point.",
|
||||
current_vertex_count, 0);
|
||||
MG_ASSERT(client->CommitTransaction());
|
||||
CleanDatabase();
|
||||
CleanDatabase(creator);
|
||||
}
|
||||
|
||||
void TestReadCommitted(std::unique_ptr<mg::Client> &client) {
|
||||
spdlog::info("Verifying READ COMMITTED");
|
||||
|
||||
auto creator = GetClient();
|
||||
SwitchToSameDB(client, creator);
|
||||
|
||||
MG_ASSERT(client->BeginTransaction());
|
||||
MG_ASSERT(creator->BeginTransaction());
|
||||
@ -121,13 +165,14 @@ void TestReadCommitted(std::unique_ptr<mg::Client> &client) {
|
||||
"from a committed transaction",
|
||||
current_vertex_count, vertex_count);
|
||||
MG_ASSERT(client->CommitTransaction());
|
||||
CleanDatabase();
|
||||
CleanDatabase(creator);
|
||||
}
|
||||
|
||||
void TestReadUncommitted(std::unique_ptr<mg::Client> &client) {
|
||||
spdlog::info("Verifying READ UNCOMMITTED");
|
||||
|
||||
auto creator = GetClient();
|
||||
SwitchToSameDB(client, creator);
|
||||
|
||||
MG_ASSERT(client->BeginTransaction());
|
||||
MG_ASSERT(creator->BeginTransaction());
|
||||
@ -152,18 +197,23 @@ void TestReadUncommitted(std::unique_ptr<mg::Client> &client) {
|
||||
"from a different transaction",
|
||||
current_vertex_count, vertex_count);
|
||||
MG_ASSERT(client->CommitTransaction());
|
||||
CleanDatabase();
|
||||
CleanDatabase(creator);
|
||||
}
|
||||
|
||||
inline constexpr std::array isolation_levels{std::pair{"SNAPSHOT ISOLATION", &TestSnapshotIsolation},
|
||||
std::pair{"READ COMMITTED", &TestReadCommitted},
|
||||
std::pair{"READ UNCOMMITTED", &TestReadUncommitted}};
|
||||
|
||||
void TestGlobalIsolationLevel(bool isDiskStorage) {
|
||||
void TestGlobalIsolationLevel(bool isDiskStorage, bool mdb = false) {
|
||||
spdlog::info("\n\n----Test global isolation levels----\n");
|
||||
auto first_client = GetClient();
|
||||
auto second_client = GetClient();
|
||||
|
||||
if (mdb) {
|
||||
SwitchToCleanDB(first_client);
|
||||
SwitchToCleanDB(second_client);
|
||||
}
|
||||
|
||||
for (const auto &[isolation_level, verification_function] : isolation_levels) {
|
||||
spdlog::info("--------------------------");
|
||||
|
||||
@ -183,11 +233,17 @@ void TestGlobalIsolationLevel(bool isDiskStorage) {
|
||||
}
|
||||
}
|
||||
|
||||
void TestSessionIsolationLevel(bool isDiskStorage) {
|
||||
void TestSessionIsolationLevel(bool isDiskStorage, bool mdb = false) {
|
||||
spdlog::info("\n\n----Test session isolation levels----\n");
|
||||
|
||||
auto global_client = GetClient();
|
||||
auto session_client = GetClient();
|
||||
|
||||
if (mdb) {
|
||||
SwitchToCleanDB(global_client);
|
||||
SwitchToCleanDB(session_client);
|
||||
}
|
||||
|
||||
for (const auto &[global_isolation_level, global_verification_function] : isolation_levels) {
|
||||
if (isDiskStorage && strcmp(global_isolation_level, "SNAPSHOT ISOLATION") != 0) {
|
||||
spdlog::info("Skipping for disk storage unsupported global isolation level {}", global_isolation_level);
|
||||
@ -218,11 +274,17 @@ void TestSessionIsolationLevel(bool isDiskStorage) {
|
||||
}
|
||||
|
||||
// Priority of applying the isolation level from highest priority NEXT -> SESSION -> GLOBAL
|
||||
void TestNextIsolationLevel(bool isDiskStorage) {
|
||||
void TestNextIsolationLevel(bool isDiskStorage, bool mdb = false) {
|
||||
spdlog::info("\n\n----Test next isolation levels----\n");
|
||||
|
||||
auto global_client = GetClient();
|
||||
auto session_client = GetClient();
|
||||
|
||||
if (mdb) {
|
||||
SwitchToCleanDB(global_client);
|
||||
SwitchToCleanDB(session_client);
|
||||
}
|
||||
|
||||
for (const auto &[global_isolation_level, global_verification_function] : isolation_levels) {
|
||||
if (isDiskStorage && strcmp(global_isolation_level, "SNAPSHOT ISOLATION") != 0) {
|
||||
spdlog::info("Skipping for disk storage unsupported global isolation level {}", global_isolation_level);
|
||||
@ -289,10 +351,21 @@ int main(int argc, char **argv) {
|
||||
auto client = GetClient();
|
||||
bool isDiskStorage = IsDiskStorageMode(client);
|
||||
client->DiscardAll();
|
||||
bool multiDB = false;
|
||||
|
||||
TestGlobalIsolationLevel(isDiskStorage);
|
||||
TestSessionIsolationLevel(isDiskStorage);
|
||||
TestNextIsolationLevel(isDiskStorage);
|
||||
|
||||
// MultiDB tests
|
||||
multiDB = true;
|
||||
spdlog::info("--------------------------");
|
||||
spdlog::info("---- RUNNING MULTI DB ----");
|
||||
spdlog::info("--------------------------");
|
||||
SetupCleanDB();
|
||||
TestGlobalIsolationLevel(isDiskStorage, multiDB);
|
||||
TestSessionIsolationLevel(isDiskStorage, multiDB);
|
||||
TestNextIsolationLevel(isDiskStorage, multiDB);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2021 Memgraph Ltd.
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -9,9 +9,10 @@
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import mgclient
|
||||
import typing
|
||||
|
||||
import mgclient
|
||||
|
||||
|
||||
def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]:
|
||||
cursor.execute(query, params)
|
||||
@ -24,6 +25,19 @@ def connect(**kwargs) -> mgclient.Connection:
|
||||
return connection
|
||||
|
||||
|
||||
def switch_db(cursor):
|
||||
execute_and_fetch_all(cursor, "USE DATABASE clean;")
|
||||
|
||||
|
||||
def create_multi_db(cursor):
|
||||
execute_and_fetch_all(cursor, "USE DATABASE memgraph;")
|
||||
try:
|
||||
execute_and_fetch_all(cursor, "DROP DATABASE clean;")
|
||||
except:
|
||||
pass
|
||||
execute_and_fetch_all(cursor, "CREATE DATABASE clean;")
|
||||
|
||||
|
||||
def reset_permissions(admin_cursor: mgclient.Cursor, create_index: bool = False):
|
||||
execute_and_fetch_all(admin_cursor, "REVOKE LABELS * FROM user;")
|
||||
execute_and_fetch_all(admin_cursor, "REVOKE EDGE_TYPES * FROM user;")
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -9,15 +9,10 @@
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
from common import (
|
||||
connect,
|
||||
execute_and_fetch_all,
|
||||
mgclient,
|
||||
reset_create_delete_permissions,
|
||||
)
|
||||
import pytest
|
||||
from common import *
|
||||
|
||||
AUTHORIZATION_ERROR_IDENTIFIER = "AuthorizationError"
|
||||
|
||||
@ -28,55 +23,83 @@ create_edge_query = "MATCH (n:create_delete_label_1), (m:create_delete_label_2)
|
||||
delete_edge_query = "CALL create_delete.delete_edge() YIELD * RETURN *;"
|
||||
|
||||
|
||||
def test_can_not_create_vertex_when_given_nothing():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_create_vertex_when_given_nothing(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, create_vertex_query)
|
||||
|
||||
|
||||
def test_can_create_vertex_when_given_global_create_delete():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_create_vertex_when_given_global_create_delete(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
result = execute_and_fetch_all(test_cursor, create_vertex_query)
|
||||
|
||||
len(result[0][0]) == 1
|
||||
|
||||
|
||||
def test_can_not_create_vertex_when_given_global_read():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_create_vertex_when_given_global_read(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, create_vertex_query)
|
||||
|
||||
|
||||
def test_can_not_create_vertex_when_given_global_update():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_create_vertex_when_given_global_update(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS :create_delete_label TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, create_vertex_query)
|
||||
|
||||
|
||||
def test_can_add_vertex_label_when_given_create_delete():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_add_vertex_label_when_given_create_delete(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(
|
||||
@ -85,14 +108,20 @@ def test_can_add_vertex_label_when_given_create_delete():
|
||||
)
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, set_label_vertex_query)
|
||||
|
||||
assert "create_delete_label" in result[0][0]
|
||||
assert "new_create_delete_label" in result[0][0]
|
||||
|
||||
|
||||
def test_can_not_add_vertex_label_when_given_update():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_add_vertex_label_when_given_update(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(
|
||||
@ -100,12 +129,18 @@ def test_can_not_add_vertex_label_when_given_update():
|
||||
)
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, set_label_vertex_query)
|
||||
|
||||
|
||||
def test_can_not_add_vertex_label_when_given_read():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_add_vertex_label_when_given_read(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(
|
||||
@ -113,118 +148,178 @@ def test_can_not_add_vertex_label_when_given_read():
|
||||
)
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, set_label_vertex_query)
|
||||
|
||||
|
||||
def test_can_remove_vertex_label_when_given_create_delete():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_remove_vertex_label_when_given_create_delete(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS :create_delete_label TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, remove_label_vertex_query)
|
||||
|
||||
assert result[0][0] != ":create_delete_label"
|
||||
|
||||
|
||||
def test_can_remove_vertex_label_when_given_global_create_delete():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_remove_vertex_label_when_given_global_create_delete(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, remove_label_vertex_query)
|
||||
|
||||
assert result[0][0] != ":create_delete_label"
|
||||
|
||||
|
||||
def test_can_not_remove_vertex_label_when_given_update():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_remove_vertex_label_when_given_update(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS :create_delete_label TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, remove_label_vertex_query)
|
||||
|
||||
|
||||
def test_can_not_remove_vertex_label_when_given_global_update():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_remove_vertex_label_when_given_global_update(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, remove_label_vertex_query)
|
||||
|
||||
|
||||
def test_can_not_remove_vertex_label_when_given_read():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_remove_vertex_label_when_given_read(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :create_delete_label TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, remove_label_vertex_query)
|
||||
|
||||
|
||||
def test_can_not_remove_vertex_label_when_given_global_read():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_remove_vertex_label_when_given_global_read(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, remove_label_vertex_query)
|
||||
|
||||
|
||||
def test_can_not_create_edge_when_given_nothing():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_create_edge_when_given_nothing(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, create_edge_query)
|
||||
|
||||
|
||||
def test_can_not_create_edge_when_given_read():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_create_edge_when_given_read(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES :new_create_delete_edge_type TO user")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, create_edge_query)
|
||||
|
||||
|
||||
def test_can_not_create_edge_when_given_update():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_create_edge_when_given_update(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES :new_create_delete_edge_type TO user")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, create_edge_query)
|
||||
|
||||
|
||||
def test_can_create_edge_when_given_create_delete():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_create_edge_when_given_create_delete(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(
|
||||
@ -233,24 +328,36 @@ def test_can_create_edge_when_given_create_delete():
|
||||
)
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
no_of_edges = execute_and_fetch_all(test_cursor, create_edge_query)
|
||||
|
||||
assert no_of_edges[0][0] == 2
|
||||
|
||||
|
||||
def test_can_not_delete_edge_when_given_nothing():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_delete_edge_when_given_nothing(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, delete_edge_query)
|
||||
|
||||
|
||||
def test_can_not_delete_edge_when_given_read():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_delete_edge_when_given_read(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(
|
||||
@ -259,13 +366,19 @@ def test_can_not_delete_edge_when_given_read():
|
||||
)
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, delete_edge_query)
|
||||
|
||||
|
||||
def test_can_not_delete_edge_when_given_update():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_delete_edge_when_given_update(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(
|
||||
@ -274,13 +387,19 @@ def test_can_not_delete_edge_when_given_update():
|
||||
)
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER):
|
||||
execute_and_fetch_all(test_cursor, delete_edge_query)
|
||||
|
||||
|
||||
def test_can_delete_edge_when_given_create_delete():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_delete_edge_when_given_create_delete(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_create_delete_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(
|
||||
@ -289,6 +408,8 @@ def test_can_delete_edge_when_given_create_delete():
|
||||
)
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
|
||||
no_of_edges = execute_and_fetch_all(test_cursor, delete_edge_query)
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -9,12 +9,11 @@
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
from typing import List
|
||||
|
||||
from common import connect, execute_and_fetch_all, reset_permissions
|
||||
import pytest
|
||||
from common import *
|
||||
|
||||
match_query = "MATCH (n) RETURN n;"
|
||||
match_by_id_query = "MATCH (n) WHERE ID(n) >= 0 RETURN n;"
|
||||
@ -105,11 +104,16 @@ def get_user_cursor():
|
||||
|
||||
|
||||
def execute_read_node_assertion(
|
||||
operation_case: List[str], queries: List[str], create_index: bool, expected_size: int
|
||||
operation_case: List[str], queries: List[str], create_index: bool, expected_size: int, switch: bool
|
||||
) -> None:
|
||||
admin_cursor = get_admin_cursor()
|
||||
user_cursor = get_user_cursor()
|
||||
|
||||
if switch:
|
||||
create_multi_db(admin_cursor)
|
||||
switch_db(admin_cursor)
|
||||
switch_db(user_cursor)
|
||||
|
||||
reset_permissions(admin_cursor, create_index)
|
||||
|
||||
for operation in operation_case:
|
||||
@ -120,7 +124,8 @@ def execute_read_node_assertion(
|
||||
assert len(results) == expected_size
|
||||
|
||||
|
||||
def test_can_read_node_when_authorized():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_read_node_when_authorized(switch):
|
||||
match_queries_without_index = [match_query, match_by_id_query]
|
||||
match_queries_with_index = [
|
||||
match_by_label_query,
|
||||
@ -132,14 +137,15 @@ def test_can_read_node_when_authorized():
|
||||
for expected_size, operation_case in zip(
|
||||
read_node_without_index_operation_cases_expected_size, read_node_without_index_operation_cases
|
||||
):
|
||||
execute_read_node_assertion(operation_case, match_queries_without_index, False, expected_size)
|
||||
execute_read_node_assertion(operation_case, match_queries_without_index, False, expected_size, switch)
|
||||
for expected_size, operation_case in zip(
|
||||
read_node_with_index_operation_cases_expected_sizes, read_node_with_index_operation_cases
|
||||
):
|
||||
execute_read_node_assertion(operation_case, match_queries_with_index, True, expected_size)
|
||||
execute_read_node_assertion(operation_case, match_queries_with_index, True, expected_size, switch)
|
||||
|
||||
|
||||
def test_can_not_read_node_when_authorized():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_read_node_when_authorized(switch):
|
||||
match_queries_without_index = [match_query, match_by_id_query]
|
||||
match_queries_with_index = [
|
||||
match_by_label_query,
|
||||
@ -151,11 +157,11 @@ def test_can_not_read_node_when_authorized():
|
||||
for expected_size, operation_case in zip(
|
||||
not_read_node_without_index_operation_cases_expected_sizes, not_read_node_without_index_operation_cases
|
||||
):
|
||||
execute_read_node_assertion(operation_case, match_queries_without_index, False, expected_size)
|
||||
execute_read_node_assertion(operation_case, match_queries_without_index, False, expected_size, switch)
|
||||
for expected_size, operation_case in zip(
|
||||
not_read_node_with_index_operation_cases_expexted_sizes, not_read_node_with_index_operation_cases
|
||||
):
|
||||
execute_read_node_assertion(operation_case, match_queries_with_index, True, expected_size)
|
||||
execute_read_node_assertion(operation_case, match_queries_with_index, True, expected_size, switch)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -10,177 +10,260 @@
|
||||
# licenses/APL.txt.
|
||||
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from common import connect, execute_and_fetch_all, reset_permissions
|
||||
from common import *
|
||||
|
||||
get_number_of_vertices_query = "CALL read.number_of_visible_nodes() YIELD nr_of_nodes RETURN nr_of_nodes;"
|
||||
get_number_of_edges_query = "CALL read.number_of_visible_edges() YIELD nr_of_edges RETURN nr_of_edges;"
|
||||
|
||||
|
||||
def test_can_read_vertex_through_c_api_when_given_grant_on_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_read_vertex_through_c_api_when_given_grant_on_label(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
|
||||
|
||||
assert result[0][0] == 1
|
||||
|
||||
|
||||
def test_can_read_vertex_through_c_api_when_given_update_grant_on_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_read_vertex_through_c_api_when_given_update_grant_on_label(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS :read_label TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
|
||||
|
||||
assert result[0][0] == 1
|
||||
|
||||
|
||||
def test_can_read_vertex_through_c_api_when_given_create_delete_grant_on_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_read_vertex_through_c_api_when_given_create_delete_grant_on_label(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS :read_label TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
|
||||
|
||||
assert result[0][0] == 1
|
||||
|
||||
|
||||
def test_can_not_read_vertex_through_c_api_when_given_nothing():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_read_vertex_through_c_api_when_given_nothing(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
|
||||
|
||||
assert result[0][0] == 0
|
||||
|
||||
|
||||
def test_can_not_read_vertex_through_c_api_when_given_deny_on_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_read_vertex_through_c_api_when_given_deny_on_label(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON LABELS :read_label TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
|
||||
|
||||
assert result[0][0] == 0
|
||||
|
||||
|
||||
def test_can_read_partial_vertices_through_c_api_when_given_global_read_but_deny_on_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_read_partial_vertices_through_c_api_when_given_global_read_but_deny_on_label(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON LABELS :read_label TO user;")
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
|
||||
|
||||
assert result[0][0] == 2
|
||||
|
||||
|
||||
def test_can_read_partial_vertices_through_c_api_when_given_global_update_but_deny_on_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_read_partial_vertices_through_c_api_when_given_global_update_but_deny_on_label(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON LABELS :read_label TO user;")
|
||||
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
|
||||
|
||||
assert result[0][0] == 2
|
||||
|
||||
|
||||
def test_can_read_partial_vertices_through_c_api_when_given_global_create_delete_but_deny_on_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_read_partial_vertices_through_c_api_when_given_global_create_delete_but_deny_on_label(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON LABELS :read_label TO user;")
|
||||
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query)
|
||||
|
||||
assert result[0][0] == 2
|
||||
|
||||
|
||||
def test_can_read_edge_through_c_api_when_given_grant_on_edge_type():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_read_edge_through_c_api_when_given_grant_on_edge_type(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES :read_edge_type TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
|
||||
|
||||
assert result[0][0] == 1
|
||||
|
||||
|
||||
def test_can_not_read_edge_through_c_api_when_given_deny_on_edge_type():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_read_edge_through_c_api_when_given_deny_on_edge_type(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
|
||||
execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON EDGE_TYPES :read_edge_type TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
|
||||
|
||||
assert result[0][0] == 0
|
||||
|
||||
|
||||
def test_can_read_edge_through_c_api_when_given_grant_on_edge_type():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_read_edge_through_c_api_when_given_grant_on_edge_type(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES :read_edge_type TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
|
||||
|
||||
assert result[0][0] == 1
|
||||
|
||||
|
||||
def test_can_read_edge_through_c_api_when_given_update_on_edge_type():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_read_edge_through_c_api_when_given_update_on_edge_type(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
|
||||
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES :read_edge_type TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
|
||||
|
||||
assert result[0][0] == 1
|
||||
|
||||
|
||||
def test_can_read_edge_through_c_api_when_given_create_delete_on_edge_type():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_read_edge_through_c_api_when_given_create_delete_on_edge_type(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
|
||||
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON EDGE_TYPES :read_edge_type TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
|
||||
|
||||
assert result[0][0] == 1
|
||||
|
||||
|
||||
def test_can_not_read_edge_through_c_api_when_given_read_global_but_deny_on_edge_type():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_read_edge_through_c_api_when_given_read_global_but_deny_on_edge_type(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
|
||||
@ -188,13 +271,19 @@ def test_can_not_read_edge_through_c_api_when_given_read_global_but_deny_on_edge
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
|
||||
|
||||
assert result[0][0] == 0
|
||||
|
||||
|
||||
def test_can_not_read_edge_through_c_api_when_given_update_global_but_deny_on_edge_type():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_read_edge_through_c_api_when_given_update_global_but_deny_on_edge_type(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
|
||||
@ -202,13 +291,19 @@ def test_can_not_read_edge_through_c_api_when_given_update_global_but_deny_on_ed
|
||||
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
|
||||
|
||||
assert result[0][0] == 0
|
||||
|
||||
|
||||
def test_can_not_read_edge_through_c_api_when_given_create_delete_global_but_deny_on_edge_type():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_read_edge_through_c_api_when_given_create_delete_global_but_deny_on_edge_type(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;")
|
||||
@ -216,6 +311,8 @@ def test_can_not_read_edge_through_c_api_when_given_create_delete_global_but_den
|
||||
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, get_number_of_edges_query)
|
||||
|
||||
assert result[0][0] == 0
|
||||
|
@ -38,6 +38,8 @@ BASIC_PRIVILEGES = [
|
||||
"MODULE_WRITE",
|
||||
"TRANSACTION_MANAGEMENT",
|
||||
"STORAGE_MODE",
|
||||
"MULTI_DATABASE_EDIT",
|
||||
"MULTI_DATABASE_USE",
|
||||
]
|
||||
|
||||
|
||||
@ -61,7 +63,7 @@ def test_lba_procedures_show_privileges_first_user():
|
||||
cursor = connect(username="Josip", password="").cursor()
|
||||
result = execute_and_fetch_all(cursor, "SHOW PRIVILEGES FOR Josip;")
|
||||
|
||||
assert len(result) == 32
|
||||
assert len(result) == 34
|
||||
|
||||
fine_privilege_results = [res for res in result if res[0] not in BASIC_PRIVILEGES]
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -9,107 +9,149 @@
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
from common import (
|
||||
connect,
|
||||
execute_and_fetch_all,
|
||||
reset_update_permissions,
|
||||
)
|
||||
import pytest
|
||||
from common import *
|
||||
|
||||
set_vertex_property_query = "MATCH (n:update_label) CALL update.set_property(n) YIELD * RETURN n.prop;"
|
||||
set_edge_property_query = "MATCH (n:update_label_1)-[r:update_edge_type]->(m:update_label_2) CALL update.set_property(r) YIELD * RETURN r.prop;"
|
||||
|
||||
|
||||
def test_can_not_update_vertex_when_given_read():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_update_vertex_when_given_read(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_update_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, set_vertex_property_query)
|
||||
|
||||
assert result[0][0] == 1
|
||||
|
||||
|
||||
def test_can_update_vertex_when_given_update_grant_on_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_update_vertex_when_given_update_grant_on_label(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_update_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS :update_label TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, set_vertex_property_query)
|
||||
|
||||
assert result[0][0] == 2
|
||||
|
||||
|
||||
def test_can_update_vertex_when_given_create_delete_grant_on_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_update_vertex_when_given_create_delete_grant_on_label(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_update_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS :update_label TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, set_vertex_property_query)
|
||||
|
||||
assert result[0][0] == 2
|
||||
|
||||
|
||||
def test_can_update_vertex_when_given_update_global_grant_on_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_update_vertex_when_given_update_global_grant_on_label(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_update_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, set_vertex_property_query)
|
||||
|
||||
assert result[0][0] == 2
|
||||
|
||||
|
||||
def test_can_update_vertex_when_given_create_delete_global_grant_on_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_update_vertex_when_given_create_delete_global_grant_on_label(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_update_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, set_vertex_property_query)
|
||||
|
||||
assert result[0][0] == 2
|
||||
|
||||
|
||||
def test_can_not_update_vertex_when_denied_update_and_granted_global_update_on_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_update_vertex_when_denied_update_and_granted_global_update_on_label(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_update_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label TO user;")
|
||||
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, set_vertex_property_query)
|
||||
|
||||
assert result[0][0] == 1
|
||||
|
||||
|
||||
def test_can_not_update_vertex_when_denied_update_and_granted_global_create_delete_on_label():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_update_vertex_when_denied_update_and_granted_global_create_delete_on_label(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_update_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label TO user;")
|
||||
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, set_vertex_property_query)
|
||||
|
||||
assert result[0][0] == 1
|
||||
|
||||
|
||||
def test_can_update_edge_when_given_update_grant_on_edge_type():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_update_edge_when_given_update_grant_on_edge_type(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_update_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;")
|
||||
@ -117,13 +159,19 @@ def test_can_update_edge_when_given_update_grant_on_edge_type():
|
||||
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES :update_edge_type TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, set_edge_property_query)
|
||||
|
||||
assert result[0][0] == 2
|
||||
|
||||
|
||||
def test_can_not_update_edge_when_given_read_grant_on_edge_type():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_update_edge_when_given_read_grant_on_edge_type(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_update_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;")
|
||||
@ -131,13 +179,19 @@ def test_can_not_update_edge_when_given_read_grant_on_edge_type():
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES :update_edge_type TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, set_edge_property_query)
|
||||
|
||||
assert result[0][0] == 1
|
||||
|
||||
|
||||
def test_can_update_edge_when_given_create_delete_grant_on_edge_type():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_update_edge_when_given_create_delete_grant_on_edge_type(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_update_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;")
|
||||
@ -145,13 +199,19 @@ def test_can_update_edge_when_given_create_delete_grant_on_edge_type():
|
||||
execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON EDGE_TYPES :update_edge_type TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, set_edge_property_query)
|
||||
|
||||
assert result[0][0] == 2
|
||||
|
||||
|
||||
def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_update_on_edge_type():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_update_on_edge_type(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_update_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;")
|
||||
@ -160,13 +220,19 @@ def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_upd
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, set_edge_property_query)
|
||||
|
||||
assert result[0][0] == 1
|
||||
|
||||
|
||||
def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_create_delete_on_edge_type():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_create_delete_on_edge_type(switch):
|
||||
admin_cursor = connect(username="admin", password="test").cursor()
|
||||
create_multi_db(admin_cursor)
|
||||
if switch:
|
||||
switch_db(admin_cursor)
|
||||
reset_update_permissions(admin_cursor)
|
||||
|
||||
execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;")
|
||||
@ -175,6 +241,8 @@ def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_cre
|
||||
execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES * TO user;")
|
||||
|
||||
test_cursor = connect(username="user", password="test").cursor()
|
||||
if switch:
|
||||
switch_db(test_cursor)
|
||||
result = execute_and_fetch_all(test_cursor, set_edge_property_query)
|
||||
|
||||
assert result[0][0] == 1
|
||||
|
@ -6,8 +6,10 @@ read_query_modules_cluster: &read_query_modules_cluster
|
||||
setup_queries:
|
||||
- "CREATE USER admin IDENTIFIED BY 'test';"
|
||||
- "GRANT ALL PRIVILEGES TO admin"
|
||||
- "GRANT DATABASE * TO admin"
|
||||
- "CREATE USER user IDENTIFIED BY 'test';"
|
||||
- "GRANT ALL PRIVILEGES TO user"
|
||||
- "GRANT DATABASE * TO user"
|
||||
validation_queries: []
|
||||
|
||||
update_query_modules_cluster: &update_query_modules_cluster
|
||||
@ -18,8 +20,10 @@ update_query_modules_cluster: &update_query_modules_cluster
|
||||
setup_queries:
|
||||
- "CREATE USER admin IDENTIFIED BY 'test';"
|
||||
- "GRANT ALL PRIVILEGES TO admin"
|
||||
- "GRANT DATABASE * TO admin"
|
||||
- "CREATE USER user IDENTIFIED BY 'test';"
|
||||
- "GRANT ALL PRIVILEGES TO user"
|
||||
- "GRANT DATABASE * TO user"
|
||||
validation_queries: []
|
||||
|
||||
show_privileges_cluster: &show_privileges_cluster
|
||||
@ -67,8 +71,10 @@ read_permission_queries: &read_permission_queries
|
||||
setup_queries:
|
||||
- "CREATE USER admin IDENTIFIED BY 'test';"
|
||||
- "GRANT ALL PRIVILEGES TO admin"
|
||||
- "GRANT DATABASE * TO admin"
|
||||
- "CREATE USER user IDENTIFIED BY 'test';"
|
||||
- "GRANT ALL PRIVILEGES TO user"
|
||||
- "GRANT DATABASE * TO user"
|
||||
validation_queries: []
|
||||
|
||||
create_delete_query_modules_cluster: &create_delete_query_modules_cluster
|
||||
@ -79,8 +85,10 @@ create_delete_query_modules_cluster: &create_delete_query_modules_cluster
|
||||
setup_queries:
|
||||
- "CREATE USER admin IDENTIFIED BY 'test';"
|
||||
- "GRANT ALL PRIVILEGES TO admin;"
|
||||
- "GRANT DATABASE * TO admin"
|
||||
- "CREATE USER user IDENTIFIED BY 'test';"
|
||||
- "GRANT ALL PRIVILEGES TO user;"
|
||||
- "GRANT DATABASE * TO user"
|
||||
validation_queries: []
|
||||
|
||||
update_permission_queries_cluster: &update_permission_queries_cluster
|
||||
@ -91,8 +99,10 @@ update_permission_queries_cluster: &update_permission_queries_cluster
|
||||
setup_queries:
|
||||
- "CREATE USER admin IDENTIFIED BY 'test';"
|
||||
- "GRANT ALL PRIVILEGES TO admin;"
|
||||
- "GRANT DATABASE * TO admin"
|
||||
- "CREATE USER user IDENTIFIED BY 'test'"
|
||||
- "GRANT ALL PRIVILEGES TO user;"
|
||||
- "GRANT DATABASE * TO user"
|
||||
validation_queries: []
|
||||
|
||||
workloads:
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -9,16 +9,29 @@
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import typing
|
||||
import mgclient
|
||||
import sys
|
||||
import typing
|
||||
|
||||
import mgclient
|
||||
import pytest
|
||||
from common import execute_and_fetch_all, has_n_result_row
|
||||
|
||||
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_return_argument(connection, function_type):
|
||||
@pytest.fixture(scope="function")
|
||||
def multi_db(request, connection):
|
||||
cursor = connection.cursor()
|
||||
if request.param:
|
||||
execute_and_fetch_all(cursor, "CREATE DATABASE clean")
|
||||
execute_and_fetch_all(cursor, "USE DATABASE clean")
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
|
||||
pass
|
||||
yield connection
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_return_argument(multi_db, function_type):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "CREATE (n:Label {id: 1});")
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 1)
|
||||
result = execute_and_fetch_all(
|
||||
@ -31,9 +44,10 @@ def test_return_argument(connection, function_type):
|
||||
assert vertex.properties == {"id": 1}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_return_optional_argument(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
def test_return_optional_argument(multi_db, function_type):
|
||||
cursor = multi_db.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
result = execute_and_fetch_all(
|
||||
cursor,
|
||||
@ -44,9 +58,10 @@ def test_return_optional_argument(connection, function_type):
|
||||
assert result == 42
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_return_optional_argument_no_arg(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
def test_return_optional_argument_no_arg(multi_db, function_type):
|
||||
cursor = multi_db.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
result = execute_and_fetch_all(
|
||||
cursor,
|
||||
@ -57,9 +72,10 @@ def test_return_optional_argument_no_arg(connection, function_type):
|
||||
assert result == 42
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_add_two_numbers(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
def test_add_two_numbers(multi_db, function_type):
|
||||
cursor = multi_db.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
result = execute_and_fetch_all(
|
||||
cursor,
|
||||
@ -70,9 +86,10 @@ def test_add_two_numbers(connection, function_type):
|
||||
assert result_sum == 6
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_return_null(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
def test_return_null(multi_db, function_type):
|
||||
cursor = multi_db.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
result = execute_and_fetch_all(
|
||||
cursor,
|
||||
@ -82,9 +99,10 @@ def test_return_null(connection, function_type):
|
||||
assert result_null is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_too_many_arguments(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
def test_too_many_arguments(multi_db, function_type):
|
||||
cursor = multi_db.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
# Should raise too many arguments
|
||||
with pytest.raises(mgclient.DatabaseError):
|
||||
@ -94,9 +112,10 @@ def test_too_many_arguments(connection, function_type):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_try_to_write(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
def test_try_to_write(multi_db, function_type):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "CREATE (n:Label {id: 1});")
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 1)
|
||||
# Should raise non mutable
|
||||
@ -106,9 +125,11 @@ def test_try_to_write(connection, function_type):
|
||||
f"MATCH (n) RETURN {function_type}_write.try_to_write(n, 'property', 1);",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_case_sensitivity(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
def test_case_sensitivity(multi_db, function_type):
|
||||
cursor = multi_db.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
# Should raise function does not exist
|
||||
with pytest.raises(mgclient.DatabaseError):
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -17,6 +17,7 @@
|
||||
|
||||
DEFINE_uint64(bolt_port, 7687, "Bolt port");
|
||||
DEFINE_uint64(timeout, 120, "Timeout seconds");
|
||||
DEFINE_bool(multi_db, false, "Run test in multi db environment");
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
google::SetUsageMessage("Memgraph E2E Memory Control");
|
||||
@ -34,6 +35,15 @@ int main(int argc, char **argv) {
|
||||
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||
client->DiscardAll();
|
||||
|
||||
if (FLAGS_multi_db) {
|
||||
client->Execute("CREATE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("USE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||
client->DiscardAll();
|
||||
}
|
||||
|
||||
const auto *create_query = "UNWIND range(1, 50) as u CREATE (n {string: \"Some longer string\"}) RETURN n;";
|
||||
|
||||
memgraph::utils::Timer timer;
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -17,6 +17,7 @@
|
||||
|
||||
DEFINE_uint64(bolt_port, 7687, "Bolt port");
|
||||
DEFINE_uint64(timeout, 120, "Timeout seconds");
|
||||
DEFINE_bool(multi_db, false, "Run test in multi db environment");
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
google::SetUsageMessage("Memgraph E2E Memory Limit For Global Allocators");
|
||||
@ -31,6 +32,15 @@ int main(int argc, char **argv) {
|
||||
LOG_FATAL("Failed to connect!");
|
||||
}
|
||||
|
||||
if (FLAGS_multi_db) {
|
||||
client->Execute("CREATE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("USE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||
client->DiscardAll();
|
||||
}
|
||||
|
||||
bool result = client->Execute("CALL libglobal_memory_limit.procedure() YIELD *");
|
||||
MG_ASSERT(result == false);
|
||||
return 0;
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -18,6 +18,7 @@
|
||||
|
||||
DEFINE_uint64(bolt_port, 7687, "Bolt port");
|
||||
DEFINE_uint64(timeout, 120, "Timeout seconds");
|
||||
DEFINE_bool(multi_db, false, "Run test in multi db environment");
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
google::SetUsageMessage("Memgraph E2E Memory Limit For Global Allocators");
|
||||
@ -31,6 +32,16 @@ int main(int argc, char **argv) {
|
||||
if (!client) {
|
||||
LOG_FATAL("Failed to connect!");
|
||||
}
|
||||
|
||||
if (FLAGS_multi_db) {
|
||||
client->Execute("CREATE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("USE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||
client->DiscardAll();
|
||||
}
|
||||
|
||||
MG_ASSERT(client->Execute("CALL libglobal_memory_limit_proc.error() YIELD *"));
|
||||
MG_ASSERT(std::invoke([&] {
|
||||
try {
|
||||
|
@ -21,14 +21,31 @@ workloads:
|
||||
args: ["--bolt-port", *bolt_port, "--timeout", "180"]
|
||||
<<: *template_cluster
|
||||
|
||||
- name: "Memory control multi database"
|
||||
binary: "tests/e2e/memory/memgraph__e2e__memory__control"
|
||||
args: ["--bolt-port", *bolt_port, "--timeout", "180", "--multi-db", "true"]
|
||||
<<: *template_cluster
|
||||
|
||||
- name: "Memory limit for modules upon loading"
|
||||
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc"
|
||||
args: ["--bolt-port", *bolt_port, "--timeout", "180"]
|
||||
proc: "tests/e2e/memory/procedures/"
|
||||
<<: *template_cluster
|
||||
|
||||
- name: "Memory limit for modules upon loading multi database"
|
||||
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc"
|
||||
args: ["--bolt-port", *bolt_port, "--timeout", "180", "--multi-db", "true"]
|
||||
proc: "tests/e2e/memory/procedures/"
|
||||
<<: *template_cluster
|
||||
|
||||
- name: "Memory limit for modules inside a procedure"
|
||||
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc_proc"
|
||||
args: ["--bolt-port", *bolt_port, "--timeout", "180"]
|
||||
proc: "tests/e2e/memory/procedures/"
|
||||
<<: *template_cluster
|
||||
|
||||
- name: "Memory limit for modules inside a procedure multi database"
|
||||
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc_proc"
|
||||
args: ["--bolt-port", *bolt_port, "--timeout", "180", "--multi-db", "true"]
|
||||
proc: "tests/e2e/memory/procedures/"
|
||||
<<: *template_cluster
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -21,6 +21,7 @@
|
||||
|
||||
DEFINE_uint64(bolt_port, 7687, "Bolt port");
|
||||
DEFINE_uint64(timeout, 120, "Timeout seconds");
|
||||
DEFINE_bool(multi_db, false, "Run test in multi db environment");
|
||||
|
||||
namespace {
|
||||
auto GetClient() {
|
||||
@ -181,6 +182,15 @@ int main(int argc, char **argv) {
|
||||
mg::Client::Init();
|
||||
auto client = GetClient();
|
||||
|
||||
if (FLAGS_multi_db) {
|
||||
client->Execute("CREATE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("USE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||
client->DiscardAll();
|
||||
}
|
||||
|
||||
AssertQueryFails<mg::ClientException>(client, CreateModuleFileQuery("some.cpp", "some content"),
|
||||
"mg.create_module_file: The specified file isn't in the supported format.");
|
||||
|
||||
|
@ -12,3 +12,8 @@ workloads:
|
||||
binary: "tests/e2e/module_file_manager/memgraph__e2e__module_file_manager"
|
||||
args: ["--bolt-port", *bolt_port]
|
||||
<<: *template_cluster
|
||||
|
||||
- name: "Module File Manager multi database"
|
||||
binary: "tests/e2e/module_file_manager/memgraph__e2e__module_file_manager"
|
||||
args: ["--bolt-port", *bolt_port, "--multi-db", "true"]
|
||||
<<: *template_cluster
|
||||
|
@ -14,6 +14,19 @@ import typing
|
||||
import mgclient
|
||||
|
||||
|
||||
def switch_db(cursor):
|
||||
execute_and_fetch_all(cursor, "USE DATABASE clean;")
|
||||
|
||||
|
||||
def create_multi_db(cursor):
|
||||
execute_and_fetch_all(cursor, "USE DATABASE memgraph;")
|
||||
try:
|
||||
execute_and_fetch_all(cursor, "DROP DATABASE clean;")
|
||||
except:
|
||||
pass
|
||||
execute_and_fetch_all(cursor, "CREATE DATABASE clean;")
|
||||
|
||||
|
||||
def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]:
|
||||
cursor.execute(query, params)
|
||||
return cursor.fetchall()
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -14,7 +14,7 @@ import os # To be removed
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from common import connect, execute_and_fetch_all
|
||||
from common import connect, create_multi_db, execute_and_fetch_all, switch_db
|
||||
|
||||
COMMON_PATH_PREFIX_TEST1 = "procedures/mage/test_module"
|
||||
COMMON_PATH_PREFIX_TEST2 = "procedures/new_test_module_utils"
|
||||
@ -76,9 +76,13 @@ def postprocess_functions(path1: str, path2: str):
|
||||
)
|
||||
|
||||
|
||||
def test_mg_load_reload_submodule_root_utils():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_mg_load_reload_submodule_root_utils(switch):
|
||||
"""Tests whether mg.load reloads content of some submodule code."""
|
||||
cursor = connect().cursor()
|
||||
if switch:
|
||||
create_multi_db(cursor)
|
||||
switch_db(cursor)
|
||||
# First do a simple experiment
|
||||
test_module_res = execute_and_fetch_all(cursor, "CALL new_test_module.test(10, 2) YIELD * RETURN *;")
|
||||
try:
|
||||
@ -101,9 +105,13 @@ def test_mg_load_reload_submodule_root_utils():
|
||||
execute_and_fetch_all(cursor, "CALL mg.load('new_test_module');")
|
||||
|
||||
|
||||
def test_mg_load_all_reload_submodule_root_utils():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_mg_load_all_reload_submodule_root_utils(switch):
|
||||
"""Tests whether mg.load_all reloads content of some submodule code"""
|
||||
cursor = connect().cursor()
|
||||
if switch:
|
||||
create_multi_db(cursor)
|
||||
switch_db(cursor)
|
||||
# First do a simple experiment
|
||||
test_module_res = execute_and_fetch_all(cursor, "CALL new_test_module.test(10, 2) YIELD * RETURN *;")
|
||||
try:
|
||||
@ -126,9 +134,13 @@ def test_mg_load_all_reload_submodule_root_utils():
|
||||
execute_and_fetch_all(cursor, "CALL mg.load_all();")
|
||||
|
||||
|
||||
def test_mg_load_reload_submodule():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_mg_load_reload_submodule(switch):
|
||||
"""Tests whether mg.load reloads content of some submodule code."""
|
||||
cursor = connect().cursor()
|
||||
if switch:
|
||||
create_multi_db(cursor)
|
||||
switch_db(cursor)
|
||||
# First do a simple experiment
|
||||
test_module_res = execute_and_fetch_all(cursor, "CALL test_module.test(10, 2) YIELD * RETURN *;")
|
||||
try:
|
||||
@ -151,9 +163,13 @@ def test_mg_load_reload_submodule():
|
||||
execute_and_fetch_all(cursor, "CALL mg.load('test_module');")
|
||||
|
||||
|
||||
def test_mg_load_all_reload_submodule():
|
||||
@pytest.mark.parametrize("switch", [False, True])
|
||||
def test_mg_load_all_reload_submodule(switch):
|
||||
"""Tests whether mg.load_all reloads content of some submodule code"""
|
||||
cursor = connect().cursor()
|
||||
if switch:
|
||||
create_multi_db(cursor)
|
||||
switch_db(cursor)
|
||||
# First do a simple experiment
|
||||
test_module_res = execute_and_fetch_all(cursor, "CALL test_module.test(10, 2) YIELD * RETURN *;")
|
||||
try:
|
||||
|
@ -12,7 +12,6 @@
|
||||
import typing
|
||||
|
||||
import mgclient
|
||||
import pytest
|
||||
|
||||
|
||||
def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]:
|
||||
|
@ -12,7 +12,6 @@
|
||||
|
||||
import multiprocessing
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
@ -56,12 +55,28 @@ def test_self_transaction():
|
||||
assert len(results) == 1
|
||||
|
||||
|
||||
def test_multitenant_transactions():
|
||||
"""Tests that show transactions work on another database"""
|
||||
test_cursor = connect().cursor()
|
||||
execute_and_fetch_all(test_cursor, "CREATE DATABASE testing")
|
||||
tx_connection = connect()
|
||||
tx_cursor = tx_connection.cursor()
|
||||
tx_process = multiprocessing.Process(
|
||||
target=process_function, args=(tx_cursor, ["USE DATABASE testing", "MATCH (n) RETURN n"])
|
||||
)
|
||||
tx_process.start()
|
||||
time.sleep(0.5)
|
||||
show_transactions_test(test_cursor, 1)
|
||||
# TODO Add SHOW TRANSACTIONS ON * that should return all transactions
|
||||
|
||||
|
||||
def test_admin_has_one_transaction():
|
||||
"""Creates admin and tests that he sees only one transaction."""
|
||||
# a_cursor is used for creating admin user, simulates main thread
|
||||
superadmin_cursor = connect().cursor()
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
|
||||
admin_cursor = connect(username="admin", password="").cursor()
|
||||
process = multiprocessing.Process(target=show_transactions_test, args=(admin_cursor, 1))
|
||||
process.start()
|
||||
@ -74,6 +89,7 @@ def test_user_can_see_its_transaction():
|
||||
superadmin_cursor = connect().cursor()
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT ALL PRIVILEGES TO admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
|
||||
execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user")
|
||||
user_cursor = connect(username="user", password="").cursor()
|
||||
@ -89,6 +105,7 @@ def test_explicit_transaction_output():
|
||||
superadmin_cursor = connect().cursor()
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
|
||||
admin_connection = connect(username="admin", password="")
|
||||
admin_cursor = admin_connection.cursor()
|
||||
# Admin starts running explicit transaction
|
||||
@ -114,8 +131,10 @@ def test_superadmin_cannot_see_admin_can_see_admin():
|
||||
superadmin_cursor = connect().cursor()
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin1")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin1")
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin2")
|
||||
# Admin starts running infinite query
|
||||
admin_connection_1 = connect(username="admin1", password="")
|
||||
admin_cursor_1 = admin_connection_1.cursor()
|
||||
@ -153,6 +172,7 @@ def test_admin_sees_superadmin():
|
||||
superadmin_cursor = superadmin_connection.cursor()
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
|
||||
# Admin starts running infinite query
|
||||
process = multiprocessing.Process(
|
||||
target=process_function, args=(superadmin_cursor, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"])
|
||||
@ -183,6 +203,7 @@ def test_admin_can_see_user_transaction():
|
||||
superadmin_cursor = connect().cursor()
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
|
||||
# Admin starts running infinite query
|
||||
admin_connection = connect(username="admin", password="")
|
||||
@ -220,8 +241,10 @@ def test_user_cannot_see_admin_transaction():
|
||||
superadmin_cursor = connect().cursor()
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin1")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin1")
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin2")
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
|
||||
admin_connection_1 = connect(username="admin1", password="")
|
||||
admin_cursor_1 = admin_connection_1.cursor()
|
||||
@ -282,6 +305,7 @@ def test_admin_killing_multiple_non_existing_transactions():
|
||||
superadmin_cursor = connect().cursor()
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
|
||||
# Connect with admin
|
||||
admin_cursor = connect(username="admin", password="").cursor()
|
||||
transactions_id = ["'1'", "'2'", "'3'"]
|
||||
@ -298,6 +322,7 @@ def test_user_killing_some_transactions():
|
||||
superadmin_cursor = connect().cursor()
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT ALL PRIVILEGES TO admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
|
||||
execute_and_fetch_all(superadmin_cursor, "CREATE USER user1")
|
||||
execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user1")
|
||||
|
||||
|
@ -24,11 +24,14 @@ def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}
|
||||
def connect(**kwargs) -> mgclient.Connection:
|
||||
connection = mgclient.connect(host="localhost", port=7687, **kwargs)
|
||||
connection.autocommit = True
|
||||
execute_and_fetch_all(connection.cursor(), "USE DATABASE memgraph")
|
||||
try:
|
||||
execute_and_fetch_all(connection.cursor(), "DROP DATABASE clean")
|
||||
except:
|
||||
pass
|
||||
execute_and_fetch_all(connection.cursor(), "MATCH (n) DETACH DELETE n")
|
||||
triggers_list = execute_and_fetch_all(connection.cursor(), "SHOW TRIGGERS;")
|
||||
for trigger in triggers_list:
|
||||
execute_and_fetch_all(connection.cursor(), f"DROP TRIGGER {trigger[0]}")
|
||||
execute_and_fetch_all(connection.cursor(), "MATCH (n) DETACH DELETE n")
|
||||
yield connection
|
||||
for trigger in triggers_list:
|
||||
execute_and_fetch_all(connection.cursor(), f"DROP TRIGGER {trigger[0]}")
|
||||
execute_and_fetch_all(connection.cursor(), "MATCH (n) DETACH DELETE n")
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -16,13 +16,25 @@ import pytest
|
||||
from common import connect, execute_and_fetch_all
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def multi_db(request, connect):
|
||||
cursor = connect.cursor()
|
||||
if request.param:
|
||||
execute_and_fetch_all(cursor, "CREATE DATABASE clean")
|
||||
execute_and_fetch_all(cursor, "USE DATABASE clean")
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
|
||||
pass
|
||||
yield connect
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
@pytest.mark.parametrize("ba_commit", ["BEFORE COMMIT", "AFTER COMMIT"])
|
||||
def test_create_on_create(ba_commit, connect):
|
||||
def test_create_on_create(ba_commit, multi_db):
|
||||
"""
|
||||
Args:
|
||||
ba_commit (str): BEFORE OR AFTER commit
|
||||
"""
|
||||
cursor = connect.cursor()
|
||||
cursor = multi_db.cursor()
|
||||
QUERY_TRIGGER_CREATE = f"""
|
||||
CREATE TRIGGER CreateTriggerEdgesCount
|
||||
ON --> CREATE
|
||||
@ -30,6 +42,7 @@ def test_create_on_create(ba_commit, connect):
|
||||
EXECUTE
|
||||
CREATE (n:CreatedEdge {{count: size(createdEdges)}})
|
||||
"""
|
||||
|
||||
execute_and_fetch_all(cursor, QUERY_TRIGGER_CREATE)
|
||||
execute_and_fetch_all(cursor, "CREATE (n:Node {id: 1})")
|
||||
execute_and_fetch_all(cursor, "CREATE (n:Node {id: 2})")
|
||||
@ -50,14 +63,22 @@ def test_create_on_create(ba_commit, connect):
|
||||
# execute_and_fetch_all(cursor, "DROP TRIGGER CreateTriggerEdgesCount")
|
||||
# execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
|
||||
# check that there is no cross contamination between databases
|
||||
nodes = execute_and_fetch_all(cursor, "SHOW DATABASES")
|
||||
if len(nodes) == 2: # multi db mode
|
||||
execute_and_fetch_all(cursor, "USE DATABASE memgraph")
|
||||
created_edges = execute_and_fetch_all(cursor, "MATCH (n:CreatedEdge) RETURN n")
|
||||
assert len(created_edges) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
@pytest.mark.parametrize("ba_commit", ["AFTER COMMIT", "BEFORE COMMIT"])
|
||||
def test_create_on_delete(ba_commit, connect):
|
||||
def test_create_on_delete(ba_commit, multi_db):
|
||||
"""
|
||||
Args:
|
||||
ba_commit (str): BEFORE OR AFTER commit
|
||||
"""
|
||||
cursor = connect.cursor()
|
||||
cursor = multi_db.cursor()
|
||||
QUERY_TRIGGER_CREATE = f"""
|
||||
CREATE TRIGGER DeleteTriggerEdgesCount
|
||||
ON --> DELETE
|
||||
@ -102,7 +123,15 @@ def test_create_on_delete(ba_commit, connect):
|
||||
# execute_and_fetch_all(cursor, "DROP TRIGGER DeleteTriggerEdgesCount")
|
||||
# execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")``
|
||||
|
||||
# check that there is no cross contamination between databases
|
||||
nodes = execute_and_fetch_all(cursor, "SHOW DATABASES")
|
||||
if len(nodes) == 2: # multi db mode
|
||||
execute_and_fetch_all(cursor, "USE DATABASE memgraph")
|
||||
created_edges = execute_and_fetch_all(cursor, "MATCH (n:CreatedEdge) RETURN n")
|
||||
assert len(created_edges) == 0
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
@pytest.mark.parametrize("ba_commit", ["BEFORE COMMIT", "AFTER COMMIT"])
|
||||
def test_create_on_delete_explicit_transaction(ba_commit):
|
||||
"""
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -9,13 +9,25 @@
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import typing
|
||||
import mgclient
|
||||
import sys
|
||||
import typing
|
||||
|
||||
import mgclient
|
||||
import pytest
|
||||
from common import execute_and_fetch_all, has_n_result_row
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def multi_db(request, connection):
|
||||
cursor = connection.cursor()
|
||||
if request.param:
|
||||
execute_and_fetch_all(cursor, "CREATE DATABASE clean")
|
||||
execute_and_fetch_all(cursor, "USE DATABASE clean")
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
|
||||
pass
|
||||
yield connection
|
||||
|
||||
|
||||
def create_subgraph(cursor):
|
||||
execute_and_fetch_all(cursor, "CREATE (n:Person {id: 1});")
|
||||
execute_and_fetch_all(cursor, "CREATE (n:Person {id: 2});")
|
||||
@ -41,8 +53,9 @@ def create_smaller_subgraph(cursor):
|
||||
execute_and_fetch_all(cursor, "MATCH (p:Person {id: 2}) MATCH (t:Team {id:6}) CREATE (p)-[:SUPPORTS]->(t);")
|
||||
|
||||
|
||||
def test_is_callable(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_is_callable(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor)
|
||||
|
||||
@ -59,8 +72,9 @@ def test_is_callable(connection):
|
||||
)
|
||||
|
||||
|
||||
def test_incorrect_graph_argument_placement(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_incorrect_graph_argument_placement(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor)
|
||||
|
||||
@ -79,8 +93,9 @@ def test_incorrect_graph_argument_placement(connection):
|
||||
)
|
||||
|
||||
|
||||
def test_get_vertices(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_get_vertices(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor)
|
||||
|
||||
@ -97,8 +112,9 @@ def test_get_vertices(connection):
|
||||
)
|
||||
|
||||
|
||||
def test_get_out_edges(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_get_out_edges(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor)
|
||||
|
||||
@ -115,8 +131,9 @@ def test_get_out_edges(connection):
|
||||
)
|
||||
|
||||
|
||||
def test_get_in_edges(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_get_in_edges(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor)
|
||||
|
||||
@ -133,8 +150,9 @@ def test_get_in_edges(connection):
|
||||
)
|
||||
|
||||
|
||||
def test_get_2_hop_edges(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_get_2_hop_edges(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor)
|
||||
|
||||
@ -150,8 +168,9 @@ def test_get_2_hop_edges(connection):
|
||||
)
|
||||
|
||||
|
||||
def test_get_out_edges_vertex_id(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_get_out_edges_vertex_id(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor=cursor)
|
||||
|
||||
@ -168,8 +187,9 @@ def test_get_out_edges_vertex_id(connection):
|
||||
)
|
||||
|
||||
|
||||
def test_subgraph_get_path_vertices(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_subgraph_get_path_vertices(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor)
|
||||
|
||||
@ -185,8 +205,9 @@ def test_subgraph_get_path_vertices(connection):
|
||||
)
|
||||
|
||||
|
||||
def test_subgraph_get_path_edges(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_subgraph_get_path_edges(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor)
|
||||
|
||||
@ -202,8 +223,9 @@ def test_subgraph_get_path_edges(connection):
|
||||
)
|
||||
|
||||
|
||||
def test_subgraph_get_path_vertices_in_subgraph(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_subgraph_get_path_vertices_in_subgraph(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor)
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6)
|
||||
@ -218,8 +240,9 @@ def test_subgraph_get_path_vertices_in_subgraph(connection):
|
||||
)
|
||||
|
||||
|
||||
def test_subgraph_insert_vertex_get_vertices(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_subgraph_insert_vertex_get_vertices(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor)
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6)
|
||||
@ -234,8 +257,9 @@ def test_subgraph_insert_vertex_get_vertices(connection):
|
||||
)
|
||||
|
||||
|
||||
def test_subgraph_insert_edge_get_vertex_out_edges(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_subgraph_insert_edge_get_vertex_out_edges(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor)
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6)
|
||||
@ -250,8 +274,9 @@ def test_subgraph_insert_edge_get_vertex_out_edges(connection):
|
||||
)
|
||||
|
||||
|
||||
def test_subgraph_create_edge_both_vertices_not_in_projected_graph_error(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_subgraph_create_edge_both_vertices_not_in_projected_graph_error(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor)
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6)
|
||||
@ -267,8 +292,9 @@ def test_subgraph_create_edge_both_vertices_not_in_projected_graph_error(connect
|
||||
)
|
||||
|
||||
|
||||
def test_subgraph_remove_edge_get_vertex_out_edges(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_subgraph_remove_edge_get_vertex_out_edges(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor)
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6)
|
||||
@ -283,8 +309,9 @@ def test_subgraph_remove_edge_get_vertex_out_edges(connection):
|
||||
)
|
||||
|
||||
|
||||
def test_subgraph_remove_edge_not_in_subgraph_error(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_subgraph_remove_edge_not_in_subgraph_error(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_subgraph(cursor)
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6)
|
||||
@ -299,8 +326,9 @@ def test_subgraph_remove_edge_not_in_subgraph_error(connection):
|
||||
)
|
||||
|
||||
|
||||
def test_subgraph_remove_vertex_and_out_edges_get_vertices(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_subgraph_remove_vertex_and_out_edges_get_vertices(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
|
||||
create_smaller_subgraph(cursor)
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 4)
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2021 Memgraph Ltd.
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -9,17 +9,30 @@
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import typing
|
||||
import mgclient
|
||||
import sys
|
||||
import typing
|
||||
|
||||
import mgclient
|
||||
import pytest
|
||||
from common import execute_and_fetch_all, has_one_result_row, has_n_result_row
|
||||
from common import execute_and_fetch_all, has_n_result_row, has_one_result_row
|
||||
|
||||
|
||||
def test_is_write(connection):
|
||||
@pytest.fixture(scope="function")
|
||||
def multi_db(request, connection):
|
||||
cursor = connection.cursor()
|
||||
if request.param:
|
||||
execute_and_fetch_all(cursor, "CREATE DATABASE clean")
|
||||
execute_and_fetch_all(cursor, "USE DATABASE clean")
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
|
||||
pass
|
||||
yield connection
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_is_write(multi_db):
|
||||
is_write = 2
|
||||
result_order = "name, signature, is_write"
|
||||
cursor = connection.cursor()
|
||||
cursor = multi_db.cursor()
|
||||
for proc in execute_and_fetch_all(
|
||||
cursor,
|
||||
"CALL mg.procedures() YIELD * WITH name, signature, "
|
||||
@ -41,8 +54,9 @@ def test_is_write(connection):
|
||||
assert cursor.description[2].name == "is_write"
|
||||
|
||||
|
||||
def test_single_vertex(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_single_vertex(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
result = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")
|
||||
vertex = result[0][0]
|
||||
@ -93,8 +107,9 @@ def test_single_vertex(connection):
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
|
||||
|
||||
def test_single_edge(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_single_edge(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
v1_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id
|
||||
v2_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id
|
||||
@ -134,8 +149,9 @@ def test_single_edge(connection):
|
||||
assert has_n_result_row(cursor, "MATCH ()-[e]->() RETURN e", 0)
|
||||
|
||||
|
||||
def test_detach_delete_vertex(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_detach_delete_vertex(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
v1_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id
|
||||
v2_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id
|
||||
@ -156,8 +172,9 @@ def test_detach_delete_vertex(connection):
|
||||
assert has_one_result_row(cursor, f"MATCH (n) WHERE id(n) = {v2_id} RETURN n")
|
||||
|
||||
|
||||
def test_graph_mutability(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_graph_mutability(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
v1_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id
|
||||
v2_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id
|
||||
@ -193,8 +210,9 @@ def test_graph_mutability(connection):
|
||||
test_mutability(False)
|
||||
|
||||
|
||||
def test_log_message(connection):
|
||||
cursor = connection.cursor()
|
||||
@pytest.mark.parametrize("multi_db", [False, True], indirect=True)
|
||||
def test_log_message(multi_db):
|
||||
cursor = multi_db.cursor()
|
||||
success = execute_and_fetch_all(cursor, f"CALL read.log_message('message') YIELD success RETURN success")[0][0]
|
||||
assert (success) is True
|
||||
|
||||
|
@ -21,6 +21,7 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
DEFAULT_DB = "memgraph"
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||
|
||||
@ -37,26 +38,24 @@ QUERIES = [
|
||||
("CREATE (n {name: $name})", {"name": 5, "leftover": 42}),
|
||||
("MATCH (n), (m) CREATE (n)-[:e {when: $when}]->(m)", {"when": 42}),
|
||||
("MATCH (n) RETURN n", {}),
|
||||
(
|
||||
"MATCH (n), (m {type: $type}) RETURN count(n), count(m)",
|
||||
{"type": "dadada"}
|
||||
),
|
||||
("MATCH (n), (m {type: $type}) RETURN count(n), count(m)", {"type": "dadada"}),
|
||||
(
|
||||
"MERGE (n) ON CREATE SET n.created = timestamp() "
|
||||
"ON MATCH SET n.lastSeen = timestamp() "
|
||||
"RETURN n.name, n.created, n.lastSeen",
|
||||
{}
|
||||
),
|
||||
(
|
||||
"MATCH (n {value: $value}) SET n.value = 0 RETURN n",
|
||||
{"value": "nandare!"}
|
||||
{},
|
||||
),
|
||||
("MATCH (n {value: $value}) SET n.value = 0 RETURN n", {"value": "nandare!"}),
|
||||
("MATCH (n), (m) SET n.value = m.value", {}),
|
||||
("MATCH (n {test: $test}) REMOVE n.value", {"test": 48}),
|
||||
("MATCH (n), (m) REMOVE n.value, m.value", {}),
|
||||
("CREATE INDEX ON :User (id)", {}),
|
||||
]
|
||||
|
||||
CREATE_DB_QUERIES = [
|
||||
("CREATE DATABASE clean", {}),
|
||||
]
|
||||
|
||||
|
||||
def wait_for_server(port, delay=0.1):
|
||||
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
|
||||
@ -65,6 +64,13 @@ def wait_for_server(port, delay=0.1):
|
||||
time.sleep(delay)
|
||||
|
||||
|
||||
def gen_mt_queries(queries, db):
|
||||
out = []
|
||||
for query, params in queries:
|
||||
out.append((db, query, params))
|
||||
return out
|
||||
|
||||
|
||||
def execute_test(memgraph_binary, tester_binary):
|
||||
storage_directory = tempfile.TemporaryDirectory()
|
||||
memgraph_args = [
|
||||
@ -74,7 +80,8 @@ def execute_test(memgraph_binary, tester_binary):
|
||||
storage_directory.name,
|
||||
"--audit-enabled",
|
||||
"--log-file=memgraph.log",
|
||||
"--log-level=TRACE"]
|
||||
"--log-level=TRACE",
|
||||
]
|
||||
|
||||
# Start the memgraph binary
|
||||
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
|
||||
@ -90,17 +97,31 @@ def execute_test(memgraph_binary, tester_binary):
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
|
||||
def execute_queries(queries):
|
||||
for query, params in queries:
|
||||
for db, query, params in queries:
|
||||
print(query, params)
|
||||
args = [tester_binary, "--query", query,
|
||||
"--params-json", json.dumps(params)]
|
||||
args = [tester_binary, "--query", query, "--use-db", db, "--params-json", json.dumps(params)]
|
||||
subprocess.run(args).check_returncode()
|
||||
|
||||
# Test default db
|
||||
mt_queries = gen_mt_queries(QUERIES, DEFAULT_DB)
|
||||
|
||||
# Execute all queries
|
||||
print("\033[1;36m~~ Starting query execution ~~\033[0m")
|
||||
execute_queries(QUERIES)
|
||||
execute_queries(mt_queries)
|
||||
print("\033[1;36m~~ Finished query execution ~~\033[0m\n")
|
||||
|
||||
# Test new db
|
||||
print("\033[1;36m~~ Creating clean database ~~\033[0m")
|
||||
mt_queries2 = gen_mt_queries(CREATE_DB_QUERIES, DEFAULT_DB)
|
||||
execute_queries(mt_queries2)
|
||||
print("\033[1;36m~~ Finished creating clean database ~~\033[0m\n")
|
||||
|
||||
# Execute all queries on clean database
|
||||
mt_queries3 = gen_mt_queries(QUERIES, "clean")
|
||||
print("\033[1;36m~~ Starting query execution on clean database ~~\033[0m")
|
||||
execute_queries(mt_queries3)
|
||||
print("\033[1;36m~~ Finished query execution on clean database ~~\033[0m\n")
|
||||
|
||||
# Shutdown the memgraph binary
|
||||
memgraph.terminate()
|
||||
|
||||
@ -109,26 +130,37 @@ def execute_test(memgraph_binary, tester_binary):
|
||||
# Verify the written log
|
||||
print("\033[1;36m~~ Starting log verification ~~\033[0m")
|
||||
with open(os.path.join(storage_directory.name, "audit", "audit.log")) as f:
|
||||
reader = csv.reader(f, delimiter=',', doublequote=False,
|
||||
escapechar='\\', lineterminator='\n',
|
||||
quotechar='"', quoting=csv.QUOTE_MINIMAL,
|
||||
skipinitialspace=False, strict=True)
|
||||
reader = csv.reader(
|
||||
f,
|
||||
delimiter=",",
|
||||
doublequote=False,
|
||||
escapechar="\\",
|
||||
lineterminator="\n",
|
||||
quotechar='"',
|
||||
quoting=csv.QUOTE_MINIMAL,
|
||||
skipinitialspace=False,
|
||||
strict=True,
|
||||
)
|
||||
queries = []
|
||||
for line in reader:
|
||||
timestamp, address, username, query, params = line
|
||||
timestamp, address, username, database, query, params = line
|
||||
params = json.loads(params)
|
||||
queries.append((query, params))
|
||||
print(query, params)
|
||||
if query.startswith("USE DATABASE"):
|
||||
continue # Skip all databases switching queries
|
||||
queries.append((database, query, params))
|
||||
print(database, query, params)
|
||||
|
||||
assert queries == QUERIES, "Logged queries don't match " \
|
||||
"executed queries!"
|
||||
# Combine all queries executed
|
||||
all_queries = mt_queries
|
||||
all_queries += mt_queries2
|
||||
all_queries += mt_queries3
|
||||
assert queries == all_queries, "Logged queries don't match " "executed queries!"
|
||||
print("\033[1;36m~~ Finished log verification ~~\033[0m\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
|
||||
tester_binary = os.path.join(PROJECT_DIR, "build", "tests",
|
||||
"integration", "audit", "tester")
|
||||
tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "audit", "tester")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--memgraph", default=memgraph_binary)
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -9,6 +9,7 @@
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <fmt/core.h>
|
||||
#include <gflags/gflags.h>
|
||||
|
||||
#include <json/json.hpp>
|
||||
@ -25,6 +26,7 @@ DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||
|
||||
DEFINE_string(query, "", "Query to execute");
|
||||
DEFINE_string(params_json, "{}", "Params for the query");
|
||||
DEFINE_string(use_db, "memgraph", "Database to run the query against");
|
||||
|
||||
memgraph::communication::bolt::Value JsonToValue(const nlohmann::json &jv) {
|
||||
memgraph::communication::bolt::Value ret;
|
||||
@ -89,6 +91,7 @@ int main(int argc, char **argv) {
|
||||
memgraph::communication::bolt::Client client(context);
|
||||
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
client.Execute(fmt::format("USE DATABASE {}", FLAGS_use_db), {});
|
||||
client.Execute(FLAGS_query, JsonToValue(nlohmann::json::parse(FLAGS_params_json)).ValueMap());
|
||||
|
||||
return 0;
|
||||
|
@ -29,15 +29,8 @@ PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||
|
||||
QUERIES = [
|
||||
# CREATE
|
||||
(
|
||||
"CREATE (n)",
|
||||
("CREATE",)
|
||||
),
|
||||
(
|
||||
"MATCH (n), (m) CREATE (n)-[:e]->(m)",
|
||||
("CREATE", "MATCH")
|
||||
),
|
||||
|
||||
("CREATE (n)", ("CREATE",)),
|
||||
("MATCH (n), (m) CREATE (n)-[:e]->(m)", ("CREATE", "MATCH")),
|
||||
# DELETE
|
||||
(
|
||||
"MATCH (n) DELETE n",
|
||||
@ -47,116 +40,43 @@ QUERIES = [
|
||||
"MATCH (n) DETACH DELETE n",
|
||||
("DELETE", "MATCH"),
|
||||
),
|
||||
|
||||
# MATCH
|
||||
(
|
||||
"MATCH (n) RETURN n",
|
||||
("MATCH",)
|
||||
),
|
||||
(
|
||||
"MATCH (n), (m) RETURN count(n), count(m)",
|
||||
("MATCH",)
|
||||
),
|
||||
|
||||
("MATCH (n) RETURN n", ("MATCH",)),
|
||||
("MATCH (n), (m) RETURN count(n), count(m)", ("MATCH",)),
|
||||
# MERGE
|
||||
(
|
||||
"MERGE (n) ON CREATE SET n.created = timestamp() "
|
||||
"ON MATCH SET n.lastSeen = timestamp() "
|
||||
"RETURN n.name, n.created, n.lastSeen",
|
||||
("MERGE",)
|
||||
("MERGE",),
|
||||
),
|
||||
|
||||
# SET
|
||||
(
|
||||
"MATCH (n) SET n.value = 0 RETURN n",
|
||||
("SET", "MATCH")
|
||||
),
|
||||
(
|
||||
"MATCH (n), (m) SET n.value = m.value",
|
||||
("SET", "MATCH")
|
||||
),
|
||||
|
||||
("MATCH (n) SET n.value = 0 RETURN n", ("SET", "MATCH")),
|
||||
("MATCH (n), (m) SET n.value = m.value", ("SET", "MATCH")),
|
||||
# REMOVE
|
||||
(
|
||||
"MATCH (n) REMOVE n.value",
|
||||
("REMOVE", "MATCH")
|
||||
),
|
||||
(
|
||||
"MATCH (n), (m) REMOVE n.value, m.value",
|
||||
("REMOVE", "MATCH")
|
||||
),
|
||||
|
||||
("MATCH (n) REMOVE n.value", ("REMOVE", "MATCH")),
|
||||
("MATCH (n), (m) REMOVE n.value, m.value", ("REMOVE", "MATCH")),
|
||||
# INDEX
|
||||
(
|
||||
"CREATE INDEX ON :User (id)",
|
||||
("INDEX",)
|
||||
),
|
||||
|
||||
("CREATE INDEX ON :User (id)", ("INDEX",)),
|
||||
# AUTH
|
||||
(
|
||||
"CREATE ROLE test_role",
|
||||
("AUTH",)
|
||||
),
|
||||
(
|
||||
"DROP ROLE test_role",
|
||||
("AUTH",)
|
||||
),
|
||||
(
|
||||
"SHOW ROLES",
|
||||
("AUTH",)
|
||||
),
|
||||
(
|
||||
"CREATE USER test_user",
|
||||
("AUTH",)
|
||||
),
|
||||
(
|
||||
"SET PASSWORD FOR test_user TO '1234'",
|
||||
("AUTH",)
|
||||
),
|
||||
(
|
||||
"DROP USER test_user",
|
||||
("AUTH",)
|
||||
),
|
||||
(
|
||||
"SHOW USERS",
|
||||
("AUTH",)
|
||||
),
|
||||
(
|
||||
"SET ROLE FOR test_user TO test_role",
|
||||
("AUTH",)
|
||||
),
|
||||
(
|
||||
"CLEAR ROLE FOR test_user",
|
||||
("AUTH",)
|
||||
),
|
||||
(
|
||||
"GRANT ALL PRIVILEGES TO test_user",
|
||||
("AUTH",)
|
||||
),
|
||||
(
|
||||
"DENY ALL PRIVILEGES TO test_user",
|
||||
("AUTH",)
|
||||
),
|
||||
(
|
||||
"REVOKE ALL PRIVILEGES FROM test_user",
|
||||
("AUTH",)
|
||||
),
|
||||
(
|
||||
"SHOW PRIVILEGES FOR test_user",
|
||||
("AUTH",)
|
||||
),
|
||||
(
|
||||
"SHOW ROLE FOR test_user",
|
||||
("AUTH",)
|
||||
),
|
||||
(
|
||||
"SHOW USERS FOR test_role",
|
||||
("AUTH",)
|
||||
),
|
||||
("CREATE ROLE test_role", ("AUTH",)),
|
||||
("DROP ROLE test_role", ("AUTH",)),
|
||||
("SHOW ROLES", ("AUTH",)),
|
||||
("CREATE USER test_user", ("AUTH",)),
|
||||
("SET PASSWORD FOR test_user TO '1234'", ("AUTH",)),
|
||||
("DROP USER test_user", ("AUTH",)),
|
||||
("SHOW USERS", ("AUTH",)),
|
||||
("SET ROLE FOR test_user TO test_role", ("AUTH",)),
|
||||
("CLEAR ROLE FOR test_user", ("AUTH",)),
|
||||
("GRANT ALL PRIVILEGES TO test_user", ("AUTH",)),
|
||||
("DENY ALL PRIVILEGES TO test_user", ("AUTH",)),
|
||||
("REVOKE ALL PRIVILEGES FROM test_user", ("AUTH",)),
|
||||
("SHOW PRIVILEGES FOR test_user", ("AUTH",)),
|
||||
("SHOW ROLE FOR test_user", ("AUTH",)),
|
||||
("SHOW USERS FOR test_role", ("AUTH",)),
|
||||
]
|
||||
|
||||
UNAUTHORIZED_ERROR = "You are not authorized to execute this query! Please " \
|
||||
"contact your database administrator."
|
||||
UNAUTHORIZED_ERROR = r"^You are not authorized to execute this query.*?Please contact your database administrator\."
|
||||
|
||||
|
||||
def wait_for_server(port, delay=0.1):
|
||||
@ -166,8 +86,16 @@ def wait_for_server(port, delay=0.1):
|
||||
time.sleep(delay)
|
||||
|
||||
|
||||
def execute_tester(binary, queries, should_fail=False, failure_message="",
|
||||
username="", password="", check_failure=True):
|
||||
def execute_tester(
|
||||
binary,
|
||||
queries,
|
||||
should_fail=False,
|
||||
failure_message="",
|
||||
username="",
|
||||
password="",
|
||||
check_failure=True,
|
||||
connection_should_fail=False,
|
||||
):
|
||||
args = [binary, "--username", username, "--password", password]
|
||||
if should_fail:
|
||||
args.append("--should-fail")
|
||||
@ -175,6 +103,8 @@ def execute_tester(binary, queries, should_fail=False, failure_message="",
|
||||
args.extend(["--failure-message", failure_message])
|
||||
if check_failure:
|
||||
args.append("--check-failure")
|
||||
if connection_should_fail:
|
||||
args.append("--connection-should-fail")
|
||||
args.extend(queries)
|
||||
subprocess.run(args).check_returncode()
|
||||
|
||||
@ -200,18 +130,31 @@ def check_permissions(query_perms, user_perms):
|
||||
|
||||
def execute_test(memgraph_binary, tester_binary, checker_binary):
|
||||
storage_directory = tempfile.TemporaryDirectory()
|
||||
memgraph_args = [memgraph_binary,
|
||||
"--data-directory", storage_directory.name]
|
||||
memgraph_args = [memgraph_binary, "--data-directory", storage_directory.name]
|
||||
|
||||
def execute_admin_queries(queries):
|
||||
return execute_tester(tester_binary, queries, should_fail=False,
|
||||
check_failure=True, username="admin",
|
||||
password="admin")
|
||||
return execute_tester(
|
||||
tester_binary, queries, should_fail=False, check_failure=True, username="admin", password="admin"
|
||||
)
|
||||
|
||||
def execute_user_queries(queries, should_fail=False, failure_message="",
|
||||
check_failure=True):
|
||||
return execute_tester(tester_binary, queries, should_fail,
|
||||
failure_message, "user", "user", check_failure)
|
||||
def execute_user_queries(
|
||||
queries,
|
||||
should_fail=False,
|
||||
failure_message="",
|
||||
check_failure=True,
|
||||
username="user",
|
||||
connection_should_fail=False,
|
||||
):
|
||||
return execute_tester(
|
||||
tester_binary,
|
||||
queries,
|
||||
should_fail,
|
||||
failure_message,
|
||||
username,
|
||||
"user",
|
||||
check_failure,
|
||||
connection_should_fail,
|
||||
)
|
||||
|
||||
# Start the memgraph binary
|
||||
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
|
||||
@ -226,12 +169,33 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
|
||||
# Prepare the multi database environment
|
||||
execute_admin_queries(
|
||||
[
|
||||
"CREATE DATABASE db1",
|
||||
"CREATE DATABASE db2",
|
||||
]
|
||||
)
|
||||
|
||||
# Prepare all users
|
||||
execute_admin_queries([
|
||||
"CREATE USER ADmin IDENTIFIED BY 'admin'",
|
||||
"GRANT ALL PRIVILEGES TO admIN",
|
||||
"CREATE USER usEr IDENTIFIED BY 'user'",
|
||||
])
|
||||
execute_admin_queries(
|
||||
[
|
||||
"CREATE USER ADmin IDENTIFIED BY 'admin'",
|
||||
"GRANT ALL PRIVILEGES TO admIN",
|
||||
"GRANT DATABASE * TO admin",
|
||||
"CREATE USER usEr IDENTIFIED BY 'user'",
|
||||
"GRANT DATABASE db1 TO user",
|
||||
"GRANT DATABASE db2 TO user",
|
||||
"CREATE USER useR2 IDENTIFIED BY 'user'",
|
||||
"GRANT DATABASE db2 TO user2",
|
||||
"REVOKE DATABASE memgraph FROM user2",
|
||||
"SET MAIN DATABASE db2 FOR user2",
|
||||
"CREATE USER user3 IDENTIFIED BY 'user'",
|
||||
"GRANT ALL PRIVILEGES TO user3",
|
||||
"GRANT DATABASE * TO user3",
|
||||
"REVOKE DATABASE memgraph FROM user3",
|
||||
]
|
||||
)
|
||||
|
||||
# Find all existing permissions
|
||||
permissions = set()
|
||||
@ -241,14 +205,99 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
|
||||
|
||||
# Run the test with all combinations of permissions
|
||||
print("\033[1;36m~~ Starting query test ~~\033[0m")
|
||||
for db in ["memgraph", "db1"]:
|
||||
print("\033[1;36m~~ Running against db {} ~~\033[0m".format(db))
|
||||
execute_user_queries(["USE DATABASE {}".format(db)], should_fail=True, failure_message=UNAUTHORIZED_ERROR)
|
||||
execute_admin_queries(["GRANT MULTI_DATABASE_USE TO User"])
|
||||
execute_user_queries(["USE DATABASE {}".format(db)], check_failure=False, failure_message=UNAUTHORIZED_ERROR)
|
||||
for mask in range(0, 2 ** len(permissions)):
|
||||
user_perms = get_permissions(permissions, mask)
|
||||
print("\033[1;34m~~ Checking queries with privileges: ", ", ".join(user_perms), " ~~\033[0m")
|
||||
admin_queries = ["REVOKE ALL PRIVILEGES FROM uSer"]
|
||||
if len(user_perms) > 0:
|
||||
admin_queries.append("GRANT {} TO User".format(", ".join(user_perms)))
|
||||
execute_admin_queries(admin_queries)
|
||||
authorized, unauthorized = [], []
|
||||
for query, query_perms in QUERIES:
|
||||
if check_permissions(query_perms, user_perms):
|
||||
authorized.append(query)
|
||||
else:
|
||||
unauthorized.append(query)
|
||||
execute_user_queries(authorized, check_failure=False, failure_message=UNAUTHORIZED_ERROR)
|
||||
execute_user_queries(unauthorized, should_fail=True, failure_message=UNAUTHORIZED_ERROR)
|
||||
print("\033[1;36m~~ Finished query test ~~\033[0m\n")
|
||||
|
||||
# Run the user/role permissions test
|
||||
print("\033[1;36m~~ Starting permissions test ~~\033[0m")
|
||||
execute_admin_queries(
|
||||
[
|
||||
"CREATE ROLE roLe",
|
||||
"REVOKE ALL PRIVILEGES FROM uSeR",
|
||||
]
|
||||
)
|
||||
execute_checker(checker_binary, [])
|
||||
for db in ["memgraph", "db1"]:
|
||||
print("\033[1;36m~~ Running against db {} ~~\033[0m".format(db))
|
||||
execute_user_queries(["USE DATABASE {}".format(db)], should_fail=True, failure_message=UNAUTHORIZED_ERROR)
|
||||
execute_admin_queries(["GRANT MULTI_DATABASE_USE TO User"])
|
||||
execute_user_queries(["USE DATABASE {}".format(db)], check_failure=False, failure_message=UNAUTHORIZED_ERROR)
|
||||
execute_admin_queries(["REVOKE MULTI_DATABASE_USE FROM User"])
|
||||
for user_perm in ["GRANT", "DENY", "REVOKE"]:
|
||||
for role_perm in ["GRANT", "DENY", "REVOKE"]:
|
||||
for mapped in [True, False]:
|
||||
print(
|
||||
"\033[1;34m~~ Checking permissions with user ",
|
||||
user_perm,
|
||||
", role ",
|
||||
role_perm,
|
||||
"user mapped to role:",
|
||||
mapped,
|
||||
" ~~\033[0m",
|
||||
)
|
||||
if mapped:
|
||||
execute_admin_queries(["SET ROLE FOR USER TO roLE"])
|
||||
else:
|
||||
execute_admin_queries(["CLEAR ROLE FOR user"])
|
||||
user_prep = "FROM" if user_perm == "REVOKE" else "TO"
|
||||
role_prep = "FROM" if role_perm == "REVOKE" else "TO"
|
||||
execute_admin_queries(
|
||||
[
|
||||
"{} MATCH {} user".format(user_perm, user_prep),
|
||||
"{} MATCH {} rOLe".format(role_perm, role_prep),
|
||||
]
|
||||
)
|
||||
expected = []
|
||||
perms = [user_perm, role_perm] if mapped else [user_perm]
|
||||
if "DENY" in perms:
|
||||
expected = ["MATCH", "DENY"]
|
||||
elif "GRANT" in perms:
|
||||
expected = ["MATCH", "GRANT"]
|
||||
if len(expected) > 0:
|
||||
details = []
|
||||
if user_perm == "GRANT":
|
||||
details.append("GRANTED TO USER")
|
||||
elif user_perm == "DENY":
|
||||
details.append("DENIED TO USER")
|
||||
if mapped:
|
||||
if role_perm == "GRANT":
|
||||
details.append("GRANTED TO ROLE")
|
||||
elif role_perm == "DENY":
|
||||
details.append("DENIED TO ROLE")
|
||||
expected.append(", ".join(details))
|
||||
execute_checker(checker_binary, expected)
|
||||
print("\033[1;36m~~ Finished permissions test ~~\033[0m\n")
|
||||
|
||||
# Check database access
|
||||
# user has access to every db (with global privileges) <- tested above
|
||||
# user2 has access only to db2 (and it set to default)
|
||||
# user3 has access only to db2, but the default db is set to default (shouldn't even connect)
|
||||
print("\033[1;36m~~ Checking privileges with custom default db ~~\033[0m\n")
|
||||
for mask in range(0, 2 ** len(permissions)):
|
||||
user_perms = get_permissions(permissions, mask)
|
||||
print("\033[1;34m~~ Checking queries with privileges: ",
|
||||
", ".join(user_perms), " ~~\033[0m")
|
||||
admin_queries = ["REVOKE ALL PRIVILEGES FROM uSer"]
|
||||
print("\033[1;34m~~ Checking queries with privileges: ", ", ".join(user_perms), " ~~\033[0m")
|
||||
admin_queries = ["REVOKE ALL PRIVILEGES FROM uSer2"]
|
||||
if len(user_perms) > 0:
|
||||
admin_queries.append(
|
||||
"GRANT {} TO User".format(", ".join(user_perms)))
|
||||
admin_queries.append("GRANT {} TO User2".format(", ".join(user_perms)))
|
||||
execute_admin_queries(admin_queries)
|
||||
authorized, unauthorized = [], []
|
||||
for query, query_perms in QUERIES:
|
||||
@ -256,55 +305,26 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
|
||||
authorized.append(query)
|
||||
else:
|
||||
unauthorized.append(query)
|
||||
execute_user_queries(authorized, check_failure=False,
|
||||
failure_message=UNAUTHORIZED_ERROR)
|
||||
execute_user_queries(unauthorized, should_fail=True,
|
||||
failure_message=UNAUTHORIZED_ERROR)
|
||||
print("\033[1;36m~~ Finished query test ~~\033[0m\n")
|
||||
execute_user_queries(authorized, check_failure=False, failure_message=UNAUTHORIZED_ERROR, username="user2")
|
||||
execute_user_queries(unauthorized, should_fail=True, failure_message=UNAUTHORIZED_ERROR, username="user2")
|
||||
print("\033[1;36m~~ Finished custom default db checks ~~\033[0m\n")
|
||||
|
||||
# Run the user/role permissions test
|
||||
print("\033[1;36m~~ Starting permissions test ~~\033[0m")
|
||||
execute_admin_queries([
|
||||
"CREATE ROLE roLe",
|
||||
"REVOKE ALL PRIVILEGES FROM uSeR",
|
||||
])
|
||||
execute_checker(checker_binary, [])
|
||||
for user_perm in ["GRANT", "DENY", "REVOKE"]:
|
||||
for role_perm in ["GRANT", "DENY", "REVOKE"]:
|
||||
for mapped in [True, False]:
|
||||
print("\033[1;34m~~ Checking permissions with user ",
|
||||
user_perm, ", role ", role_perm,
|
||||
"user mapped to role:", mapped, " ~~\033[0m")
|
||||
if mapped:
|
||||
execute_admin_queries(["SET ROLE FOR USER TO roLE"])
|
||||
else:
|
||||
execute_admin_queries(["CLEAR ROLE FOR user"])
|
||||
user_prep = "FROM" if user_perm == "REVOKE" else "TO"
|
||||
role_prep = "FROM" if role_perm == "REVOKE" else "TO"
|
||||
execute_admin_queries([
|
||||
"{} MATCH {} user".format(user_perm, user_prep),
|
||||
"{} MATCH {} rOLe".format(role_perm, role_prep)
|
||||
])
|
||||
expected = []
|
||||
perms = [user_perm, role_perm] if mapped else [user_perm]
|
||||
if "DENY" in perms:
|
||||
expected = ["MATCH", "DENY"]
|
||||
elif "GRANT" in perms:
|
||||
expected = ["MATCH", "GRANT"]
|
||||
if len(expected) > 0:
|
||||
details = []
|
||||
if user_perm == "GRANT":
|
||||
details.append("GRANTED TO USER")
|
||||
elif user_perm == "DENY":
|
||||
details.append("DENIED TO USER")
|
||||
if mapped:
|
||||
if role_perm == "GRANT":
|
||||
details.append("GRANTED TO ROLE")
|
||||
elif role_perm == "DENY":
|
||||
details.append("DENIED TO ROLE")
|
||||
expected.append(", ".join(details))
|
||||
execute_checker(checker_binary, expected)
|
||||
print("\033[1;36m~~ Finished permissions test ~~\033[0m\n")
|
||||
print("\033[1;36m~~ Checking connections and database switching ~~\033[0m\n")
|
||||
for db in ["memgraph", "db1"]:
|
||||
print("\033[1;36m~~ Running against db {} ~~\033[0m".format(db))
|
||||
execute_admin_queries(["GRANT {} TO User2".format("MULTI_DATABASE_USE")])
|
||||
execute_user_queries(
|
||||
["USE DATABASE {}".format(db)], should_fail=True, failure_message=UNAUTHORIZED_ERROR, username="user2"
|
||||
)
|
||||
print("\033[1;36m~~ Running with user3 (shouldn't even connect) ~~\033[0m")
|
||||
execute_admin_queries(["GRANT {} TO User3".format("MULTI_DATABASE_USE")])
|
||||
execute_user_queries(
|
||||
["USE DATABASE db2"],
|
||||
connection_should_fail=True,
|
||||
failure_message="Couldn't communicate with the server!",
|
||||
username="user3",
|
||||
)
|
||||
print("\033[1;36m~~ Finished checking connections and database switching ~~\033[0m\n")
|
||||
|
||||
# Shutdown the memgraph binary
|
||||
memgraph.terminate()
|
||||
@ -313,10 +333,8 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
|
||||
|
||||
if __name__ == "__main__":
|
||||
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
|
||||
tester_binary = os.path.join(PROJECT_DIR, "build", "tests",
|
||||
"integration", "auth", "tester")
|
||||
checker_binary = os.path.join(PROJECT_DIR, "build", "tests",
|
||||
"integration", "auth", "checker")
|
||||
tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "auth", "tester")
|
||||
checker_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "auth", "checker")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--memgraph", default=memgraph_binary)
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -9,6 +9,8 @@
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <regex>
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
|
||||
#include "communication/bolt/client.hpp"
|
||||
@ -23,6 +25,7 @@ DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||
|
||||
DEFINE_bool(check_failure, false, "Set to true to enable failure checking.");
|
||||
DEFINE_bool(should_fail, false, "Set to true to expect a failure.");
|
||||
DEFINE_bool(connection_should_fail, false, "Set to true to expect a connection failure.");
|
||||
DEFINE_string(failure_message, "", "Set to the expected failure message.");
|
||||
|
||||
/**
|
||||
@ -40,7 +43,26 @@ int main(int argc, char **argv) {
|
||||
memgraph::communication::ClientContext context(FLAGS_use_ssl);
|
||||
memgraph::communication::bolt::Client client(context);
|
||||
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
std::regex re(FLAGS_failure_message);
|
||||
|
||||
try {
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
} catch (const memgraph::communication::bolt::ClientFatalException &e) {
|
||||
if (FLAGS_connection_should_fail) {
|
||||
if (!FLAGS_failure_message.empty() && !std::regex_match(e.what(), re)) {
|
||||
LOG_FATAL(
|
||||
"The connection should have failed with an error message of '{}'' but "
|
||||
"instead it failed with '{}'",
|
||||
FLAGS_failure_message, e.what());
|
||||
}
|
||||
return 0;
|
||||
} else {
|
||||
LOG_FATAL(
|
||||
"The connection shoudn't have failed but it failed with an "
|
||||
"error message '{}'",
|
||||
e.what());
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
std::string query(argv[i]);
|
||||
@ -48,7 +70,7 @@ int main(int argc, char **argv) {
|
||||
client.Execute(query, {});
|
||||
} catch (const memgraph::communication::bolt::ClientQueryException &e) {
|
||||
if (!FLAGS_check_failure) {
|
||||
if (!FLAGS_failure_message.empty() && e.what() == FLAGS_failure_message) {
|
||||
if (!FLAGS_failure_message.empty() && std::regex_match(e.what(), re)) {
|
||||
LOG_FATAL(
|
||||
"The query should have succeeded or failed with an error "
|
||||
"message that isn't equal to '{}' but it failed with that error "
|
||||
@ -58,7 +80,7 @@ int main(int argc, char **argv) {
|
||||
continue;
|
||||
}
|
||||
if (FLAGS_should_fail) {
|
||||
if (!FLAGS_failure_message.empty() && e.what() != FLAGS_failure_message) {
|
||||
if (!FLAGS_failure_message.empty() && !std::regex_match(e.what(), re)) {
|
||||
LOG_FATAL(
|
||||
"The query should have failed with an error message of '{}'' but "
|
||||
"instead it failed with '{}'",
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -22,6 +22,7 @@ DEFINE_int32(port, 7687, "Server port");
|
||||
DEFINE_string(username, "admin", "Username for the database");
|
||||
DEFINE_string(password, "admin", "Password for the database");
|
||||
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||
DEFINE_string(use_db, "memgraph", "Database to run the query against");
|
||||
|
||||
/**
|
||||
* Verifies that user 'user' has privileges that are given as positional
|
||||
@ -38,6 +39,7 @@ int main(int argc, char **argv) {
|
||||
memgraph::communication::bolt::Client client(context);
|
||||
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
client.Execute(fmt::format("USE DATABASE {}", FLAGS_use_db), {});
|
||||
|
||||
try {
|
||||
std::string query(argv[1]);
|
||||
|
@ -23,7 +23,7 @@ from typing import List
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||
|
||||
UNAUTHORIZED_ERROR = "You are not authorized to execute this query! Please " "contact your database administrator."
|
||||
UNAUTHORIZED_ERROR = r"^You are not authorized to execute this query.*?Please contact your database administrator\."
|
||||
|
||||
|
||||
def wait_for_server(port, delay=0.1):
|
||||
@ -47,8 +47,10 @@ def execute_tester(
|
||||
subprocess.run(args).check_returncode()
|
||||
|
||||
|
||||
def execute_filtering(binary: str, queries: List[str], expected: int, username: str = "", password: str = "") -> None:
|
||||
args = [binary, "--username", username, "--password", password]
|
||||
def execute_filtering(
|
||||
binary: str, queries: List[str], expected: int, username: str = "", password: str = "", db: str = "memgraph"
|
||||
) -> None:
|
||||
args = [binary, "--username", username, "--password", password, "--use-db", db]
|
||||
|
||||
args.extend(queries)
|
||||
args.append(str(expected))
|
||||
@ -82,35 +84,48 @@ def execute_test(memgraph_binary: str, tester_binary: str, filtering_binary: str
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
|
||||
# Prepare all users
|
||||
execute_admin_queries(
|
||||
[
|
||||
"CREATE USER admin IDENTIFIED BY 'admin'",
|
||||
"GRANT ALL PRIVILEGES TO admin",
|
||||
"CREATE USER user IDENTIFIED BY 'user'",
|
||||
"GRANT ALL PRIVILEGES TO user",
|
||||
"GRANT LABELS :label1, :label2, :label3 TO user",
|
||||
"GRANT EDGE_TYPES :edgeType1, :edgeType2 TO user",
|
||||
"MERGE (l1:label1 {name: 'test1'})",
|
||||
"MERGE (l2:label2 {name: 'test2'})",
|
||||
"MATCH (l1:label1),(l2:label2) WHERE l1.name = 'test1' AND l2.name = 'test2' CREATE (l1)-[r:edgeType1]->(l2)",
|
||||
"MERGE (l3:label3 {name: 'test3'})",
|
||||
"MATCH (l1:label1),(l3:label3) WHERE l1.name = 'test1' AND l3.name = 'test3' CREATE (l1)-[r:edgeType2]->(l3)",
|
||||
"MERGE (mix:label3:label1 {name: 'test4'})",
|
||||
"MATCH (l1:label1),(mix:label3) WHERE l1.name = 'test1' AND mix.name = 'test4' CREATE (l1)-[r:edgeType2]->(mix)",
|
||||
]
|
||||
)
|
||||
def setup_user():
|
||||
execute_admin_queries(
|
||||
[
|
||||
"CREATE USER admin IDENTIFIED BY 'admin'",
|
||||
"GRANT ALL PRIVILEGES TO admin",
|
||||
"CREATE USER user IDENTIFIED BY 'user'",
|
||||
"GRANT ALL PRIVILEGES TO user",
|
||||
"GRANT LABELS :label1, :label2, :label3 TO user",
|
||||
"GRANT EDGE_TYPES :edgeType1, :edgeType2 TO user",
|
||||
]
|
||||
)
|
||||
|
||||
def db_setup():
|
||||
execute_admin_queries(
|
||||
[
|
||||
"MERGE (l1:label1 {name: 'test1'})",
|
||||
"MERGE (l2:label2 {name: 'test2'})",
|
||||
"MATCH (l1:label1),(l2:label2) WHERE l1.name = 'test1' AND l2.name = 'test2' CREATE (l1)-[r:edgeType1]->(l2)",
|
||||
"MERGE (l3:label3 {name: 'test3'})",
|
||||
"MATCH (l1:label1),(l3:label3) WHERE l1.name = 'test1' AND l3.name = 'test3' CREATE (l1)-[r:edgeType2]->(l3)",
|
||||
"MERGE (mix:label3:label1 {name: 'test4'})",
|
||||
"MATCH (l1:label1),(mix:label3) WHERE l1.name = 'test1' AND mix.name = 'test4' CREATE (l1)-[r:edgeType2]->(mix)",
|
||||
]
|
||||
)
|
||||
|
||||
db_setup() # default db setup
|
||||
execute_admin_queries(["CREATE DATABASE db1", "USE DATABASE db1"])
|
||||
db_setup() # db1 setup
|
||||
|
||||
# Run the test with all combinations of permissions
|
||||
print("\033[1;36m~~ Starting edge filtering test ~~\033[0m")
|
||||
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 3, "user", "user")
|
||||
execute_admin_queries(["DENY EDGE_TYPES :edgeType1 TO user"])
|
||||
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 2, "user", "user")
|
||||
execute_admin_queries(["GRANT EDGE_TYPES :edgeType1 TO user", "DENY LABELS :label3 TO user"])
|
||||
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 1, "user", "user")
|
||||
execute_admin_queries(["DENY LABELS :label1 TO user"])
|
||||
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 0, "user", "user")
|
||||
execute_admin_queries(["REVOKE LABELS * FROM user", "REVOKE EDGE_TYPES * FROM user"])
|
||||
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 0, "user", "user")
|
||||
for db in ["memgraph", "db1"]:
|
||||
setup_user()
|
||||
# Run the test with all combinations of permissions
|
||||
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 3, "user", "user", db)
|
||||
execute_admin_queries(["DENY EDGE_TYPES :edgeType1 TO user"])
|
||||
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 2, "user", "user", db)
|
||||
execute_admin_queries(["GRANT EDGE_TYPES :edgeType1 TO user", "DENY LABELS :label3 TO user"])
|
||||
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 1, "user", "user", db)
|
||||
execute_admin_queries(["DENY LABELS :label1 TO user"])
|
||||
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 0, "user", "user", db)
|
||||
execute_admin_queries(["REVOKE LABELS * FROM user", "REVOKE EDGE_TYPES * FROM user"])
|
||||
execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 0, "user", "user", db)
|
||||
|
||||
print("\033[1;36m~~ Finished edge filtering test ~~\033[0m\n")
|
||||
|
||||
|
@ -14,10 +14,14 @@ while ! nc -z -w 1 127.0.0.1 7687; do
|
||||
sleep 0.5
|
||||
done
|
||||
|
||||
# Start the test.
|
||||
# Start the test on default db.
|
||||
$binary_dir/tests/integration/transactions/tester
|
||||
code=$?
|
||||
|
||||
# Start the test on another db.
|
||||
$binary_dir/tests/integration/transactions/tester --use-db db1
|
||||
code2=$?
|
||||
|
||||
# Shutdown the memgraph process.
|
||||
kill $pid
|
||||
wait $pid
|
||||
@ -30,4 +34,12 @@ if [ $code_mg -ne 0 ]; then
|
||||
fi
|
||||
|
||||
# Exit with the exitcode of the test.
|
||||
exit $code
|
||||
if [ $code -ne 0 ]; then
|
||||
echo "Default database tests failed!"
|
||||
exit $code
|
||||
fi
|
||||
|
||||
if [ $code2 -ne 0 ]; then
|
||||
echo "Non default database tests failed!"
|
||||
exit $code2
|
||||
fi
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -11,6 +11,7 @@
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <fmt/core.h>
|
||||
#include <gflags/gflags.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
@ -24,12 +25,16 @@ DEFINE_int32(port, 7687, "Server port");
|
||||
DEFINE_string(username, "", "Username for the database");
|
||||
DEFINE_string(password, "", "Password for the database");
|
||||
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||
DEFINE_string(use_db, "memgraph", "Database to run the query against");
|
||||
|
||||
using namespace memgraph::communication::bolt;
|
||||
|
||||
class BoltClient : public ::testing::Test {
|
||||
protected:
|
||||
virtual void SetUp() { client_.Connect(endpoint_, FLAGS_username, FLAGS_password); }
|
||||
virtual void SetUp() {
|
||||
client_.Connect(endpoint_, FLAGS_username, FLAGS_password);
|
||||
Execute("CREATE DATABASE db1");
|
||||
}
|
||||
|
||||
virtual void TearDown() {}
|
||||
|
||||
@ -90,6 +95,15 @@ const std::string kCommitInvalid =
|
||||
"Transaction can't be committed because there was a previous error. Please "
|
||||
"invoke a rollback instead.";
|
||||
|
||||
TEST_F(BoltClient, SelectDB) { Execute(fmt::format("USE DATABASE {}", FLAGS_use_db)); }
|
||||
|
||||
TEST_F(BoltClient, SelectDBUnderTx) {
|
||||
EXPECT_TRUE(Execute("begin"));
|
||||
EXPECT_THROW(Execute("USE DATABASE memgraph", "Multi-database queries are not allowed in multicommand transactions."),
|
||||
ClientQueryException);
|
||||
EXPECT_FALSE(TransactionActive());
|
||||
}
|
||||
|
||||
TEST_F(BoltClient, CommitWithoutTransaction) {
|
||||
EXPECT_THROW(Execute("commit", kNoCurrentTransactionToCommit), ClientQueryException);
|
||||
EXPECT_FALSE(TransactionActive());
|
||||
|
@ -36,7 +36,7 @@ int main(int argc, char *argv[]) {
|
||||
memgraph::query::Interpreter interpreter{&interpreter_context};
|
||||
|
||||
ResultStreamFaker stream(interpreter_context.db.get());
|
||||
auto [header, _, qid] = interpreter.Prepare(argv[1], {}, nullptr);
|
||||
auto [header, _1, qid, _2] = interpreter.Prepare(argv[1], {}, nullptr);
|
||||
stream.Header(header);
|
||||
auto summary = interpreter.PullAll(&stream);
|
||||
stream.Summary(summary);
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user