Decoupling Interpreter from Storage (#1186)

Unique/global InterpreterContext that is Storage agnostic (has a reference to the DbmsHandler instead)

* InterpreterContext is no longer the owner of Storage
* New Database structure that handles Storage, Triggers, Streams
* Renamed SessinContextHandler to DbmsHandler and simplified the multi-tenant logic
* Added Gatekeeper and updated handlers to use it

---------

Co-authored-by: Gareth Lloyd <gareth.lloyd@memgraph.io>
This commit is contained in:
andrejtonev 2023-09-20 13:13:54 +02:00 committed by GitHub
parent 404cdf05d3
commit bce48361ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
67 changed files with 2947 additions and 3389 deletions

3
.gitignore vendored
View File

@ -16,8 +16,7 @@
.ycm_extra_conf.pyc
.temp/
Testing/
build
build/
/build*/
release/examples/build
cmake-build-*
cmake/DownloadProject/

View File

@ -18,6 +18,7 @@ add_subdirectory(rpc)
add_subdirectory(license)
add_subdirectory(auth)
add_subdirectory(audit)
add_subdirectory(dbms)
add_subdirectory(flags)
string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type)

View File

@ -54,7 +54,7 @@ class SessionException : public utils::BasicException {
* @tparam TOutputStream type of output stream that will be used
*/
template <typename TInputStream, typename TOutputStream>
class Session : public dbms::SessionInterface {
class Session {
public:
using TEncoder = Encoder<ChunkedEncoderBuffer<TOutputStream>>;
@ -208,8 +208,8 @@ class Session : public dbms::SessionInterface {
Version version_;
std::string GetDatabaseName() const override = 0;
std::string UUID() const final { return session_uuid_; }
virtual std::string GetDatabaseName() const = 0;
std::string UUID() const { return session_uuid_; }
private:
void ClientFailureInvalidData() {

View File

@ -110,11 +110,7 @@ class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TS
return std::shared_ptr<WebsocketSession>(new WebsocketSession(std::forward<Args>(args)...));
}
#ifdef MG_ENTERPRISE
~WebsocketSession() { session_context_->Delete(session_); }
#else
~WebsocketSession() = default;
#endif
WebsocketSession(const WebsocketSession &) = delete;
WebsocketSession &operator=(const WebsocketSession &) = delete;
@ -171,14 +167,15 @@ class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TS
: ws_(std::move(socket)),
strand_{boost::asio::make_strand(ws_.get_executor())},
output_stream_([this](const uint8_t *data, size_t len, bool /*have_more*/) { return Write(data, len); }),
session_{*session_context, endpoint, input_buffer_.read_end(), &output_stream_},
session_{session_context->ic, endpoint, input_buffer_.read_end(), &output_stream_, session_context->auth,
#ifdef MG_ENTERPRISE
session_context->audit_log
#endif
},
session_context_{session_context},
endpoint_{endpoint},
remote_endpoint_{ws_.next_layer().socket().remote_endpoint()},
service_name_{service_name} {
#ifdef MG_ENTERPRISE
session_context_->Register(session_);
#endif
}
void OnAccept(boost::beast::error_code ec) {
@ -286,11 +283,7 @@ class Session final : public std::enable_shared_from_this<Session<TSession, TSes
return std::shared_ptr<Session>(new Session(std::forward<Args>(args)...));
}
#ifdef MG_ENTERPRISE
~Session() { session_context_->Delete(session_); }
#else
~Session() = default;
#endif
Session(const Session &) = delete;
Session(Session &&) = delete;
@ -366,17 +359,17 @@ class Session final : public std::enable_shared_from_this<Session<TSession, TSes
: socket_(CreateSocket(std::move(socket), server_context)),
strand_{boost::asio::make_strand(GetExecutor())},
output_stream_([this](const uint8_t *data, size_t len, bool have_more) { return Write(data, len, have_more); }),
session_{*session_context, endpoint, input_buffer_.read_end(), &output_stream_},
session_{session_context->ic, endpoint, input_buffer_.read_end(), &output_stream_, session_context->auth,
#ifdef MG_ENTERPRISE
session_context->audit_log
#endif
},
session_context_{session_context},
endpoint_{endpoint},
remote_endpoint_{GetRemoteEndpoint()},
service_name_{service_name},
timeout_seconds_(inactivity_timeout_sec),
timeout_timer_(GetExecutor()) {
#ifdef MG_ENTERPRISE
// TODO Try to remove Register (see comment at SessionInterface declaration)
session_context_->Register(session_);
#endif
ExecuteForSocket([](auto &&socket) {
socket.lowest_layer().set_option(tcp::no_delay(true)); // enable PSH
socket.lowest_layer().set_option(boost::asio::socket_base::keep_alive(true)); // enable SO_KEEPALIVE

3
src/dbms/CMakeLists.txt Normal file
View File

@ -0,0 +1,3 @@
add_library(mg-dbms STATIC database.cpp)
target_link_libraries(mg-dbms mg-utils mg-storage-v2 mg-query)

34
src/dbms/database.cpp Normal file
View File

@ -0,0 +1,34 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "dbms/database.hpp"
#include "storage/v2/disk/storage.hpp"
#include "storage/v2/inmemory/storage.hpp"
template struct memgraph::utils::Gatekeeper<memgraph::dbms::Database>;
namespace memgraph::dbms {
Database::Database(const storage::Config &config)
: trigger_store_(config.durability.storage_directory / "triggers"),
streams_{config.durability.storage_directory / "streams"} {
if (config.force_on_disk || utils::DirExists(config.disk.main_storage_directory)) {
storage_ = std::make_unique<storage::DiskStorage>(config);
} else {
storage_ = std::make_unique<storage::InMemoryStorage>(config);
}
}
void Database::SwitchToOnDisk() {
storage_ = std::make_unique<memgraph::storage::DiskStorage>(std::move(storage_->config_));
}
} // namespace memgraph::dbms

148
src/dbms/database.hpp Normal file
View File

@ -0,0 +1,148 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <algorithm>
#include <filesystem>
#include <iterator>
#include <memory>
#include <optional>
#include <string_view>
#include <unordered_map>
#include "query/cypher_query_interpreter.hpp"
#include "query/stream/streams.hpp"
#include "query/trigger.hpp"
#include "storage/v2/storage.hpp"
#include "utils/gatekeeper.hpp"
namespace memgraph::dbms {
/**
* @brief Class containing everything associated with a single Database
*
*/
class Database {
public:
/**
* @brief Construct a new Database object
*
* @param config storage configuration
*/
explicit Database(const storage::Config &config);
/**
* @brief Returns the raw storage pointer.
* @note Ideally everybody would be using an accessor
* TODO: Remove
*
* @return storage::Storage*
*/
storage::Storage *storage() { return storage_.get(); }
/**
* @brief Storage's Accessor
*
* @param override_isolation_level
* @return std::unique_ptr<storage::Storage::Accessor>
*/
std::unique_ptr<storage::Storage::Accessor> Access(
std::optional<storage::IsolationLevel> override_isolation_level = {}) {
return storage_->Access(override_isolation_level);
}
/**
* @brief Unique storage identified (name)
*
* @return const std::string&
*/
const std::string &id() const { return storage_->id(); }
/**
* @brief Returns the storage configuration
*
* @return const storage::Config&
*/
const storage::Config &config() const { return storage_->config_; }
/**
* @brief Get the storage mode
*
* @return storage::StorageMode
*/
storage::StorageMode GetStorageMode() const { return storage_->GetStorageMode(); }
/**
* @brief Get the storage info
*
* @return storage::StorageInfo
*/
storage::StorageInfo GetInfo() const { return storage_->GetInfo(); }
/**
* @brief Switch storage to OnDisk
*
*/
void SwitchToOnDisk();
/**
* @brief Returns the raw TriggerStore pointer
*
* @return query::TriggerStore*
*/
query::TriggerStore *trigger_store() { return &trigger_store_; }
/**
* @brief Returns the raw Streams pointer
*
* @return query::stream::Streams*
*/
query::stream::Streams *streams() { return &streams_; }
/**
* @brief Returns the raw ThreadPool pointer (used for after commit triggers)
*
* @return utils::ThreadPool*
*/
utils::ThreadPool *thread_pool() { return &after_commit_trigger_pool_; }
/**
* @brief Add task to the after commit trigger thread pool
*
* @param new_task
*/
void AddTask(std::function<void()> new_task) { after_commit_trigger_pool_.AddTask(std::move(new_task)); }
/**
* @brief Returns the PlanCache vector raw pointer
*
* @return utils::SkipList<query::PlanCacheEntry>*
*/
utils::SkipList<query::PlanCacheEntry> *plan_cache() { return &plan_cache_; }
private:
std::unique_ptr<storage::Storage> storage_; //!< Underlying storage
query::TriggerStore trigger_store_; //!< Triggers associated with the storage
utils::ThreadPool after_commit_trigger_pool_{1}; //!< Thread pool for executing after commit triggers
query::stream::Streams streams_; //!< Streams associated with the storage
// TODO: Move to a better place
utils::SkipList<query::PlanCacheEntry> plan_cache_; //!< Plan cache associated with the storage
};
} // namespace memgraph::dbms
extern template struct memgraph::utils::Gatekeeper<memgraph::dbms::Database>;
namespace memgraph::dbms {
using DatabaseAccess = memgraph::utils::Gatekeeper<memgraph::dbms::Database>::Accessor;
} // namespace memgraph::dbms

View File

@ -0,0 +1,97 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#ifdef MG_ENTERPRISE
#include <algorithm>
#include <filesystem>
#include <iterator>
#include <memory>
#include <optional>
#include <string_view>
#include <unordered_map>
#include "dbms/database.hpp"
#include "handler.hpp"
namespace memgraph::dbms {
/* NOTE
* The Database object is shared. All the higher-level function calls should be protected.
* Storage function calls should already be protected; add protection where needed.
*
* Current implementation uses a handler of Database objects. It owns them and gives
* Gatekeeper::Accessor to it. These guarantee that the object won't be
* destroyed unless no one is using it.
*/
/**Config
* @brief Multi-database storage handler
*
*/
class DatabaseHandler : public Handler<Database> {
public:
using HandlerT = Handler<Database>;
/**
* @brief Generate new storage associated with the passed name.
*
* @param name Name associating the new interpreter context
* @param config Storage configuration
* @return HandlerT::NewResult
*/
HandlerT::NewResult New(std::string_view name, storage::Config config) {
// Control that no one is using the same data directory
if (std::any_of(begin(), end(), [&](auto &elem) {
auto db_acc = elem.second.access();
MG_ASSERT(db_acc.has_value(), "Gatekeeper in invalid state");
return db_acc->get()->config().durability.storage_directory == config.durability.storage_directory;
})) {
spdlog::info("Tried to generate new storage using a claimed directory.");
return NewError::EXISTS;
}
config.name = name; // Set storage id via config
return HandlerT::New(std::piecewise_construct, name, config);
}
/**
* @brief All currently active storage.
*
* @return std::vector<std::string>
*/
std::vector<std::string> All() const {
std::vector<std::string> res;
res.reserve(std::distance(cbegin(), cend()));
std::for_each(cbegin(), cend(), [&](const auto &elem) { res.push_back(elem.first); });
return res;
}
/**
* @brief Get the associated storage's configuration
*
* @param name
* @return std::optional<storage::Config>
*/
std::optional<storage::Config> GetConfig(std::string_view name) {
auto db = Get(name);
if (db) {
return (*db)->config();
}
return std::nullopt;
}
};
} // namespace memgraph::dbms
#endif

390
src/dbms/dbms_handler.hpp Normal file
View File

@ -0,0 +1,390 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <algorithm>
#include <concepts>
#include <cstdint>
#include <filesystem>
#include <memory>
#include <mutex>
#include <optional>
#include <ostream>
#include <stdexcept>
#include <system_error>
#include <unordered_map>
#include "auth/auth.hpp"
#include "constants.hpp"
#include "dbms/database_handler.hpp"
#include "global.hpp"
#include "query/config.hpp"
#include "query/interpreter_context.hpp"
#include "spdlog/spdlog.h"
#include "storage/v2/durability/durability.hpp"
#include "storage/v2/durability/paths.hpp"
#include "utils/exceptions.hpp"
#include "utils/file.hpp"
#include "utils/logging.hpp"
#include "utils/result.hpp"
#include "utils/rw_lock.hpp"
#include "utils/synchronized.hpp"
#include "utils/uuid.hpp"
namespace memgraph::dbms {
#ifdef MG_ENTERPRISE
using DeleteResult = utils::BasicResult<DeleteError>;
/**
* @brief Multi-database session contexts handler.
*/
class DbmsHandler {
public:
using LockT = utils::RWLock;
using NewResultT = utils::BasicResult<NewError, DatabaseAccess>;
struct Statistics {
uint64_t num_vertex; //!< Sum of vertexes in every database
uint64_t num_edges; //!< Sum of edges in every database
uint64_t num_databases; //! number of isolated databases
};
/**
* @brief Initialize the handler.
*
* @param configs storage and interpreter configurations
* @param auth pointer to the global authenticator
* @param recovery_on_startup restore databases (and its content) and authentication data
* @param delete_on_drop when dropping delete any associated directories on disk
*/
DbmsHandler(storage::Config config, auto *auth, bool recovery_on_startup, bool delete_on_drop)
: lock_{utils::RWLock::Priority::READ}, default_config_{std::move(config)}, delete_on_drop_(delete_on_drop) {
// TODO: Decouple storage config from dbms config
// TODO: Save individual db configs inside the kvstore and restore from there
storage::UpdatePaths(*default_config_, default_config_->durability.storage_directory / "databases");
const auto &db_dir = default_config_->durability.storage_directory;
const auto durability_dir = db_dir / ".durability";
utils::EnsureDirOrDie(db_dir);
utils::EnsureDirOrDie(durability_dir);
durability_ = std::make_unique<kvstore::KVStore>(durability_dir);
// Generate the default database
MG_ASSERT(!NewDefault_().HasError(), "Failed while creating the default DB.");
// Recover previous databases
if (recovery_on_startup) {
for (const auto &[name, _] : *durability_) {
if (name == kDefaultDB) continue; // Already set
spdlog::info("Restoring database {}.", name);
MG_ASSERT(!New_(name).HasError(), "Failed while creating database {}.", name);
spdlog::info("Database {} restored.", name);
}
} else { // Clear databases from the durability list and auth
auto locked_auth = auth->Lock();
for (const auto &[name, _] : *durability_) {
if (name == kDefaultDB) continue;
locked_auth->DeleteDatabase(name);
durability_->Delete(name);
}
}
}
/**
* @brief Create a new Database associated with the "name" database
*
* @param name name of the database
* @return NewResultT context on success, error on failure
*/
NewResultT New(const std::string &name) {
std::lock_guard<LockT> wr(lock_);
return New_(name, name);
}
/**
* @brief Get the context associated with the "name" database
*
* @param name
* @return DatabaseAccess
* @throw UnknownDatabaseException if database not found
*/
DatabaseAccess Get(std::string_view name) {
std::shared_lock<LockT> rd(lock_);
return Get_(name);
}
/**
* @brief Delete database.
*
* @param db_name database name
* @return DeleteResult error on failure
*/
DeleteResult Delete(const std::string &db_name) {
std::lock_guard<LockT> wr(lock_);
if (db_name == kDefaultDB) {
// MSG cannot delete the default db
return DeleteError::DEFAULT_DB;
}
const auto storage_path = StorageDir_(db_name);
if (!storage_path) return DeleteError::NON_EXISTENT;
// Check if db exists
try {
// Low level handlers
if (!db_handler_.Delete(db_name)) {
return DeleteError::USING;
}
} catch (utils::BasicException &) {
return DeleteError::NON_EXISTENT;
}
// Remove from durability list
if (durability_) durability_->Delete(db_name);
// Delete disk storage
if (delete_on_drop_) {
std::error_code ec;
(void)std::filesystem::remove_all(*storage_path, ec);
if (ec) {
spdlog::error("Failed to clean disk while deleting database \"{}\".", db_name);
defunct_dbs_.emplace(db_name);
return DeleteError::DISK_FAIL;
}
}
// Delete from defunct_dbs_ (in case a second delete call was successful)
defunct_dbs_.erase(db_name);
return {}; // Success
}
/**
* @brief Return all active databases.
*
* @return std::vector<std::string>
*/
std::vector<std::string> All() const {
std::shared_lock<LockT> rd(lock_);
return db_handler_.All();
}
/**
* @brief Return the number of vertex across all databases.
*
* @return uint64_t
*/
Statistics Info() {
// TODO: Handle overflow?
uint64_t nv = 0;
uint64_t ne = 0;
std::shared_lock<LockT> rd(lock_);
const uint64_t ndb = std::distance(db_handler_.cbegin(), db_handler_.cend());
for (auto &[_, db_gk] : db_handler_) {
auto db_acc_opt = db_gk.access();
if (!db_acc_opt) continue;
auto &db_acc = *db_acc_opt;
const auto &info = db_acc->GetInfo();
nv += info.vertex_count;
ne += info.edge_count;
}
return {nv, ne, ndb};
}
/**
* @brief Restore triggers for all currently defined databases.
* @note: Triggers can execute query procedures, so we need to reload the modules first and then the triggers
*
* @param ic global InterpreterContext
*/
void RestoreTriggers(query::InterpreterContext *ic) {
std::lock_guard<LockT> wr(lock_);
for (auto &[_, db_gk] : db_handler_) {
auto db_acc_opt = db_gk.access();
if (!db_acc_opt) continue;
auto &db_acc = *db_acc_opt;
spdlog::debug("Restoring trigger for database \"{}\"", db_acc->id());
auto storage_accessor = db_acc->Access();
auto dba = memgraph::query::DbAccessor{storage_accessor.get()};
db_acc->trigger_store()->RestoreTriggers(&ic->ast_cache, &dba, ic->config.query, ic->auth_checker);
}
}
/**
* @brief Restore streams of all currently defined databases.
* @note: Stream transformations are using modules, they have to be restored after the query modules are loaded.
*
* @param ic global InterpreterContext
*/
void RestoreStreams(query::InterpreterContext *ic) {
std::lock_guard<LockT> wr(lock_);
for (auto &[_, db_gk] : db_handler_) {
auto db_acc = db_gk.access();
if (!db_acc) continue;
auto *db = db_acc->get();
spdlog::debug("Restoring streams for database \"{}\"", db->id());
db->streams()->RestoreStreams(*db_acc, ic);
}
}
private:
/**
* @brief return the storage directory of the associated database
*
* @param name Database name
* @return std::optional<std::filesystem::path>
*/
std::optional<std::filesystem::path> StorageDir_(const std::string &name) {
const auto conf = db_handler_.GetConfig(name);
if (conf) {
return conf->durability.storage_directory;
}
spdlog::debug("Failed to find storage dir for database \"{}\"", name);
return {};
}
/**
* @brief Create a new Database associated with the "name" database
*
* @param name name of the database
* @return NewResultT context on success, error on failure
*/
NewResultT New_(const std::string &name) { return New_(name, name); }
/**
* @brief Create a new Database associated with the "name" database
*
* @param name name of the database
* @param storage_subdir undelying RocksDB directory
* @return NewResultT context on success, error on failure
*/
NewResultT New_(const std::string &name, std::filesystem::path storage_subdir) {
if (default_config_) {
auto config_copy = *default_config_;
storage::UpdatePaths(config_copy, default_config_->durability.storage_directory / storage_subdir);
return New_(name, config_copy);
}
spdlog::info("Trying to generate session context without any configurations.");
return NewError::NO_CONFIGS;
}
/**
* @brief Create a new Database associated with the "name" database
*
* @param name name of the database
* @param storage_config storage configuration
* @return NewResultT context on success, error on failure
*/
NewResultT New_(const std::string &name, storage::Config &storage_config) {
if (defunct_dbs_.contains(name)) {
spdlog::warn("Failed to generate database due to the unknown state of the previously defunct database \"{}\".",
name);
return NewError::DEFUNCT;
}
auto new_db = db_handler_.New(name, storage_config);
if (new_db.HasValue()) {
// Success
if (durability_) durability_->Put(name, "ok"); // TODO: Serialize the configuration?
return new_db.GetValue();
}
return new_db.GetError();
}
/**
* @brief Create a new Database associated with the default database
*
* @return NewResultT context on success, error on failure
*/
NewResultT NewDefault_() {
// Create the default DB in the root (this is how it was done pre multi-tenancy)
auto res = New_(kDefaultDB, "..");
if (res.HasValue()) {
// For back-compatibility...
// Recreate the dbms layout for the default db and symlink to the root
const auto dir = StorageDir_(kDefaultDB);
MG_ASSERT(dir, "Failed to find storage path.");
const auto main_dir = *dir / "databases" / kDefaultDB;
if (!std::filesystem::exists(main_dir)) {
std::filesystem::create_directory(main_dir);
}
// Force link on-disk directories
const auto conf = db_handler_.GetConfig(kDefaultDB);
MG_ASSERT(conf, "No configuration for the default database.");
const auto &tmp_conf = conf->disk;
std::vector<std::filesystem::path> to_link{
tmp_conf.main_storage_directory, tmp_conf.label_index_directory,
tmp_conf.label_property_index_directory, tmp_conf.unique_constraints_directory,
tmp_conf.name_id_mapper_directory, tmp_conf.id_name_mapper_directory,
tmp_conf.durability_directory, tmp_conf.wal_directory,
};
// Add in-memory paths
// Some directories are redundant (skip those)
const std::vector<std::string> skip{".lock", "audit_log", "auth", "databases", "internal_modules", "settings"};
for (auto const &item : std::filesystem::directory_iterator{*dir}) {
const auto dir_name = std::filesystem::relative(item.path(), item.path().parent_path());
if (std::find(skip.begin(), skip.end(), dir_name) != skip.end()) continue;
to_link.push_back(item.path());
}
// Symlink to root dir
for (auto const &item : to_link) {
const auto dir_name = std::filesystem::relative(item, item.parent_path());
const auto link = main_dir / dir_name;
const auto to = std::filesystem::relative(item, main_dir);
if (!std::filesystem::is_symlink(link) && !std::filesystem::exists(link)) {
std::filesystem::create_directory_symlink(to, link);
} else { // Check existing link
std::error_code ec;
const auto test_link = std::filesystem::read_symlink(link, ec);
if (ec || test_link != to) {
MG_ASSERT(false,
"Memgraph storage directory incompatible with new version.\n"
"Please use a clean directory or remove \"{}\" and try again.",
link.string());
}
}
}
}
return res;
}
/**
* @brief Get the DatabaseAccess for the database associated with the "name"
*
* @param name
* @return DatabaseAccess
* @throw UnknownDatabaseException if trying to get unknown database
*/
DatabaseAccess Get_(std::string_view name) {
auto db = db_handler_.Get(name);
if (db) {
return *db;
}
throw UnknownDatabaseException("Tried to retrieve an unknown database \"{}\".", name);
}
// Should storage objects ever be deleted?
mutable LockT lock_; //!< protective lock
DatabaseHandler db_handler_; //!< multi-tenancy storage handler
std::optional<storage::Config> default_config_; //!< Storage configuration used when creating new databases
std::unique_ptr<kvstore::KVStore> durability_; //!< list of active dbs (pointer so we can postpone its creation)
std::set<std::string> defunct_dbs_; //!< Databases that are in an unknown state due to various failures
bool delete_on_drop_; //!< Flag defining if dropping storage also deletes its directory
};
#endif
} // namespace memgraph::dbms

View File

@ -60,51 +60,4 @@ class UnknownDatabaseException : public utils::BasicException {
using utils::BasicException::BasicException;
};
/**
* @brief Session interface used by the DBMS to handle the the active sessions.
* @todo Try to remove this dependency from SessionContextHandler. OnDelete could be removed, as it only does an assert.
* OnChange could be removed if SetFor returned the pointer and the called then handled the OnChange execution.
* However, the interface is very useful to decouple the interpreter's query execution and the sessions themselves.
*/
class SessionInterface {
public:
SessionInterface() = default;
virtual ~SessionInterface() = default;
SessionInterface(const SessionInterface &) = default;
SessionInterface &operator=(const SessionInterface &) = default;
SessionInterface(SessionInterface &&) noexcept = default;
SessionInterface &operator=(SessionInterface &&) noexcept = default;
/**
* @brief Return the unique string identifying the session.
*
* @return std::string
*/
virtual std::string UUID() const = 0;
/**
* @brief Return the currently active database.
*
* @return std::string
*/
virtual std::string GetDatabaseName() const = 0;
#ifdef MG_ENTERPRISE
/**
* @brief Gets called on database change.
*
* @return SetForResult enum (SUCCESS, ALREADY_SET or FAIL)
*/
virtual dbms::SetForResult OnChange(const std::string &) = 0;
/**
* @brief Callback that gets called on database delete (drop).
*
* @return true on success
*/
virtual bool OnDelete(const std::string &) = 0;
#endif
};
} // namespace memgraph::dbms

View File

@ -18,21 +18,21 @@
#include <unordered_map>
#include "global.hpp"
#include "utils/exceptions.hpp"
#include "utils/gatekeeper.hpp"
#include "utils/result.hpp"
#include "utils/sync_ptr.hpp"
namespace memgraph::dbms {
/**
* @brief Generic multi-database content handler.
*
* @tparam TContext
* @tparam TConfig
* @tparam T
*/
template <typename TContext, typename TConfig>
template <typename T>
class Handler {
public:
using NewResult = utils::BasicResult<NewError, std::shared_ptr<TContext>>;
using NewResult = utils::BasicResult<NewError, typename utils::Gatekeeper<T>::Accessor>;
/**
* @brief Empty Handler constructor.
@ -43,67 +43,65 @@ class Handler {
/**
* @brief Generate a new context and corresponding configuration.
*
* @tparam T1 Variadic template of context constructor arguments
* @tparam T2 Variadic template of config constructor arguments
* @param name Name associated with the new context/config pair
* @param args1 Arguments passed (as a tuple) to the context constructor
* @param args2 Arguments passed (as a tuple) to the config constructor
* @tparam Args Variadic template of constructor arguments of T
* @param name Name associated with the new T
* @param args Arguments passed to the constructor of T
* @return NewResult
*/
template <typename... T1, typename... T2>
NewResult New(std::string name, std::tuple<T1...> args1, std::tuple<T2...> args2) {
return New_(name, args1, args2, std::make_index_sequence<sizeof...(T1)>{},
std::make_index_sequence<sizeof...(T2)>{});
template <typename... Args>
NewResult New(std::piecewise_construct_t /* marker */, std::string_view name, Args... args) {
// Make sure the emplace will succeed, since we don't want to create temporary objects that could break something
if (!Has(name)) {
auto [itr, _] = items_.emplace(std::piecewise_construct, std::forward_as_tuple(name),
std::forward_as_tuple(std::forward<Args>(args)...));
auto db_acc = itr->second.access();
if (db_acc) return std::move(*db_acc);
return NewError::DEFUNCT;
}
spdlog::info("Item with name \"{}\" already exists.", name);
return NewError::EXISTS;
}
/**
* @brief Get pointer to context.
*
* @param name Name associated with the wanted context
* @return std::optional<std::shared_ptr<TContext>>
* @return std::optional<typename utils::Gatekeeper<T>::Accessor>
*/
std::optional<std::shared_ptr<TContext>> Get(const std::string &name) {
std::optional<typename utils::Gatekeeper<T>::Accessor> Get(std::string_view name) {
if (auto search = items_.find(name); search != items_.end()) {
return search->second.get();
return search->second.access();
}
return {};
return std::nullopt;
}
/**
* @brief Get the config.
* @brief Delete the context associated with the name.
*
* @param name Name associated with the wanted config
* @return std::optional<TConfig>
*/
std::optional<TConfig> GetConfig(const std::string &name) const {
if (auto search = items_.find(name); search != items_.end()) {
return search->second.config();
}
return {};
}
/**
* @brief Delete the context/config pair associated with the name.
*
* @param name Name associated with the context/config pair to delete
* @param name Name associated with the context to delete
* @return true on success
* @throw BasicException
*/
bool Delete(const std::string &name) {
if (auto itr = items_.find(name); itr != items_.end()) {
itr->second.DestroyAndSync();
items_.erase(itr);
return true;
auto db_acc = itr->second.access();
if (db_acc && db_acc->try_delete()) {
db_acc->reset();
items_.erase(itr);
return true;
}
return false;
}
return false;
throw utils::BasicException("Unknown item \"{}\".", name);
}
/**
* @brief Check if a name is already used.
*
* @param name Name to check
* @return true if a context/config pair is already associated with the name
* @return true if a T is already associated with the name
*/
bool Has(const std::string &name) const { return items_.find(name) != items_.end(); }
bool Has(std::string_view name) const { return items_.find(name) != items_.end(); }
auto begin() { return items_.begin(); }
auto end() { return items_.end(); }
@ -112,31 +110,16 @@ class Handler {
auto cbegin() const { return items_.cbegin(); }
auto cend() const { return items_.cend(); }
private:
/**
* @brief Lower level handler that hides some ugly code.
*
* @tparam T1 Variadic template of context constructor arguments
* @tparam T2 Variadic template of config constructor arguments
* @tparam I1 List of indexes associated with the first tuple
* @tparam I2 List of indexes associated with the second tuple
*/
template <typename... T1, typename... T2, std::size_t... I1, std::size_t... I2>
NewResult New_(std::string name, std::tuple<T1...> &args1, std::tuple<T2...> &args2,
std::integer_sequence<std::size_t, I1...> /*not-used*/,
std::integer_sequence<std::size_t, I2...> /*not-used*/) {
// Make sure the emplace will succeed, since we don't want to create temporary objects that could break something
if (!Has(name)) {
auto [itr, _] = items_.emplace(std::piecewise_construct, std::forward_as_tuple(name),
std::forward_as_tuple(TConfig{std::forward<T1>(std::get<I1>(args1))...},
std::forward<T2>(std::get<I2>(args2))...));
return itr->second.get();
}
spdlog::info("Item with name \"{}\" already exists.", name);
return NewError::EXISTS;
}
struct string_hash {
using is_transparent = void;
[[nodiscard]] size_t operator()(const char *s) const { return std::hash<std::string_view>{}(s); }
[[nodiscard]] size_t operator()(std::string_view s) const { return std::hash<std::string_view>{}(s); }
[[nodiscard]] size_t operator()(const std::string &s) const { return std::hash<std::string>{}(s); }
};
std::unordered_map<std::string, utils::SyncPtr<TContext, TConfig>> items_; //!< map to all active items
private:
std::unordered_map<std::string, utils::Gatekeeper<T>, string_hash, std::equal_to<>>
items_; //!< map to all active items
};
} // namespace memgraph::dbms

View File

@ -1,106 +0,0 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#ifdef MG_ENTERPRISE
#include "global.hpp"
#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/interpreter.hpp"
#include "storage/v2/storage.hpp"
#include "handler.hpp"
namespace memgraph::dbms {
/**
* @brief Simple class that adds useful information to the query's InterpreterContext
*
* @tparam T Multi-database handler type
*/
template <typename T>
class ExpandedInterpContext : public query::InterpreterContext {
public:
template <typename... TArgs>
explicit ExpandedInterpContext(T &ref, TArgs &&...args)
: query::InterpreterContext(std::forward<TArgs>(args)...), sc_handler_(ref) {}
T &sc_handler_; //!< Multi-database/SessionContext handler (used in some queries)
};
/**
* @brief Simple structure that expands on the query's InterpreterConfig
*
*/
struct ExpandedInterpConfig {
storage::Config storage_config; //!< Storage configuration
query::InterpreterConfig interp_config; //!< Interpreter configuration
};
/**
* @brief Multi-database interpreter context handler
*
* @tparam TSCHandler High-level multi-database/SessionContext handler type
*/
template <typename TSCHandler>
class InterpContextHandler : public Handler<ExpandedInterpContext<TSCHandler>, ExpandedInterpConfig> {
public:
using InterpContextT = ExpandedInterpContext<TSCHandler>;
using HandlerT = Handler<InterpContextT, ExpandedInterpConfig>;
/**
* @brief Generate a new interpreter context associated with the passed name.
*
* @param name Name associating the new interpreter context
* @param sc_handler Multi-database/SessionContext handler used (some queries might use it)
* @param db Storage associated with the interpreter context
* @param config Interpreter's configuration
* @param dir Directory used by the interpreter
* @param auth_handler AuthQueryHandler used
* @param auth_checker AuthChecker used
* @return HandlerT::NewResult
*/
typename HandlerT::NewResult New(const std::string &name, TSCHandler &sc_handler, storage::Config storage_config,
const query::InterpreterConfig &interpreter_config,
query::AuthQueryHandler &auth_handler, query::AuthChecker &auth_checker) {
// Check if compatible with the existing interpreters
if (std::any_of(HandlerT::cbegin(), HandlerT::cend(), [&](const auto &elem) {
const auto &config = elem.second.config().storage_config;
return config.durability.storage_directory == storage_config.durability.storage_directory;
})) {
spdlog::info("Tried to generate a new context using claimed directory and/or storage.");
return NewError::EXISTS;
}
const auto dir = storage_config.durability.storage_directory;
storage_config.name = name; // Set storage id via config
return HandlerT::New(
name, std::forward_as_tuple(storage_config, interpreter_config),
std::forward_as_tuple(sc_handler, storage_config, interpreter_config, dir, &auth_handler, &auth_checker));
}
/**
* @brief All currently active storage.
*
* @return std::vector<std::string>
*/
std::vector<std::string> All() const {
std::vector<std::string> res;
res.reserve(std::distance(HandlerT::cbegin(), HandlerT::cend()));
std::for_each(HandlerT::cbegin(), HandlerT::cend(), [&](const auto &elem) { res.push_back(elem.first); });
return res;
}
};
} // namespace memgraph::dbms
#endif

View File

@ -1,61 +0,0 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "auth/auth.hpp"
#include "query/interpreter.hpp"
#include "storage/v2/storage.hpp"
#include "utils/synchronized.hpp"
#if MG_ENTERPRISE
#include "audit/log.hpp"
#endif
namespace memgraph::dbms {
/**
* @brief Structure encapsulating storage and interpreter context.
*
* @note Each session contains a copy.
*/
struct SessionContext {
// Explicit constructor here to ensure that pointers to all objects are
// supplied.
SessionContext(std::shared_ptr<memgraph::query::InterpreterContext> interpreter_context, std::string run,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth
#ifdef MG_ENTERPRISE
,
memgraph::audit::Log *audit_log
#endif
)
: interpreter_context(interpreter_context),
run_id(run),
auth(auth)
#ifdef MG_ENTERPRISE
,
audit_log(audit_log)
#endif
{
}
std::shared_ptr<memgraph::query::InterpreterContext> interpreter_context;
std::string run_id;
// std::shared_ptr<AuthContext> auth_context;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth;
#ifdef MG_ENTERPRISE
memgraph::audit::Log *audit_log;
#endif
};
} // namespace memgraph::dbms

View File

@ -1,603 +0,0 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <algorithm>
#include <concepts>
#include <cstdint>
#include <filesystem>
#include <memory>
#include <mutex>
#include <optional>
#include <ostream>
#include <stdexcept>
#include <system_error>
#include <unordered_map>
#include "constants.hpp"
#include "global.hpp"
#include "interp_handler.hpp"
#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/interpreter.hpp"
#include "session_context.hpp"
#include "spdlog/spdlog.h"
#include "storage/v2/durability/durability.hpp"
#include "storage/v2/durability/paths.hpp"
#include "utils/exceptions.hpp"
#include "utils/file.hpp"
#include "utils/logging.hpp"
#include "utils/result.hpp"
#include "utils/rw_lock.hpp"
#include "utils/synchronized.hpp"
#include "utils/uuid.hpp"
#include "handler.hpp"
namespace memgraph::dbms {
#ifdef MG_ENTERPRISE
using DeleteResult = utils::BasicResult<DeleteError>;
/**
* @brief Multi-database session contexts handler.
*/
class SessionContextHandler {
public:
using StorageT = storage::Storage;
using StorageConfigT = storage::Config;
using LockT = utils::RWLock;
using NewResultT = utils::BasicResult<NewError, SessionContext>;
struct Config {
StorageConfigT storage_config; //!< Storage configuration
query::InterpreterConfig interp_config; //!< Interpreter context configuration
std::function<void(utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *,
std::unique_ptr<query::AuthQueryHandler> &, std::unique_ptr<query::AuthChecker> &)>
glue_auth;
};
struct Statistics {
uint64_t num_vertex; //!< Sum of vertexes in every database
uint64_t num_edges; //!< Sum of edges in every database
uint64_t num_databases; //! number of isolated databases
};
/**
* @brief Initialize the handler.
*
* @param audit_log pointer to the audit logger (ENTERPRISE only)
* @param configs storage and interpreter configurations
* @param recovery_on_startup restore databases (and its content) and authentication data
*/
SessionContextHandler(memgraph::audit::Log &audit_log, Config configs, bool recovery_on_startup, bool delete_on_drop)
: lock_{utils::RWLock::Priority::READ},
default_configs_(configs),
run_id_{utils::GenerateUUID()},
audit_log_(&audit_log),
delete_on_drop_(delete_on_drop) {
const auto &root = configs.storage_config.durability.storage_directory;
utils::EnsureDirOrDie(root);
// Verify that the user that started the process is the same user that is
// the owner of the storage directory.
storage::durability::VerifyStorageDirectoryOwnerAndProcessUserOrDie(root);
// Create the lock file and open a handle to it. This will crash the
// database if it can't open the file for writing or if any other process is
// holding the file opened.
lock_file_path_ = root / ".lock";
lock_file_handle_.Open(lock_file_path_, utils::OutputFile::Mode::OVERWRITE_EXISTING);
MG_ASSERT(lock_file_handle_.AcquireLock(),
"Couldn't acquire lock on the storage directory {}"
"!\nAnother Memgraph process is currently running with the same "
"storage directory, please stop it first before starting this "
"process!",
root);
// TODO: Figure out if this is needed/wanted
// Clear auth database since we are not recovering
// if (!recovery_on_startup) {
// const auto &auth_dir = root / "auth";
// // Backup if auth present
// if (utils::DirExists(auth_dir)) {
// auto backup_dir = root / storage::durability::kBackupDirectory;
// std::error_code error_code;
// utils::EnsureDirOrDie(backup_dir);
// std::error_code ec;
// const auto now = std::chrono::system_clock::now();
// std::ostringstream os;
// os << now.time_since_epoch().count();
// std::filesystem::rename(auth_dir, backup_dir / ("auth-" + os.str()), ec);
// MG_ASSERT(!ec, "Couldn't backup auth directory because of: {}", ec.message());
// spdlog::warn(
// "Since Memgraph was not supposed to recover on startup the authentication files will be "
// "overwritten. To prevent important data loss, Memgraph has stored those files into .backup directory "
// "inside the storage directory.");
// }
// // Clear
// if (std::filesystem::exists(auth_dir)) {
// std::filesystem::remove_all(auth_dir);
// }
// }
// Lazy initialization of auth_
auth_ = std::make_unique<utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock>>(root / "auth");
configs.glue_auth(auth_.get(), auth_handler_, auth_checker_);
// TODO: Decouple storage config from dbms config
// TODO: Save individual db configs inside the kvstore and restore from there
storage::UpdatePaths(default_configs_->storage_config,
default_configs_->storage_config.durability.storage_directory / "databases");
const auto &db_dir = default_configs_->storage_config.durability.storage_directory;
const auto durability_dir = db_dir / ".durability";
utils::EnsureDirOrDie(db_dir);
utils::EnsureDirOrDie(durability_dir);
durability_ = std::make_unique<kvstore::KVStore>(durability_dir);
// Generate the default database
MG_ASSERT(!NewDefault_().HasError(), "Failed while creating the default DB.");
// Recover previous databases
if (recovery_on_startup) {
for (const auto &[name, _] : *durability_) {
if (name == kDefaultDB) continue; // Already set
spdlog::info("Restoring database {}.", name);
MG_ASSERT(!New_(name).HasError(), "Failed while creating database {}.", name);
spdlog::info("Database {} restored.", name);
}
} else { // Clear databases from the durability list and auth
auto locked_auth = auth_->Lock();
for (const auto &[name, _] : *durability_) {
if (name == kDefaultDB) continue;
locked_auth->DeleteDatabase(name);
durability_->Delete(name);
}
}
}
void Shutdown() {
for (auto &ic : interp_handler_) memgraph::query::Shutdown(ic.second.get().get());
}
/**
* @brief Create a new SessionContext associated with the "name" database
*
* @param name name of the database
* @return NewResultT context on success, error on failure
*/
NewResultT New(const std::string &name) {
std::lock_guard<LockT> wr(lock_);
return New_(name, name);
}
/**
* @brief Get the context associated with the "name" database
*
* @param name
* @return SessionContext
* @throw UnknownDatabaseException if getting unknown database
*/
SessionContext Get(const std::string &name) {
std::shared_lock<LockT> rd(lock_);
return Get_(name);
}
/**
* @brief Set the undelying database for a particular session.
*
* @param uuid unique session identifier
* @param db_name unique database name
* @return SetForResult enum
* @throws UnknownDatabaseException, UnknownSessionException or anything OnChange throws
*/
SetForResult SetFor(const std::string &uuid, const std::string &db_name) {
std::shared_lock<LockT> rd(lock_);
(void)Get_(
db_name); // throws if db doesn't exist (TODO: Better to pass it via OnChange - but injecting dependency)
try {
auto &s = sessions_.at(uuid);
return s.OnChange(db_name);
} catch (std::out_of_range &) {
throw UnknownSessionException("Unknown session \"{}\"", uuid);
}
}
/**
* @brief Set the undelying database from a session itself. SessionContext handler.
*
* @param db_name unique database name
* @param handler function that gets called in place with the appropriate SessionContext
* @return SetForResult enum
*/
template <typename THandler>
requires std::invocable<THandler, SessionContext> SetForResult SetInPlace(const std::string &db_name,
THandler handler) {
std::shared_lock<LockT> rd(lock_);
return handler(Get_(db_name));
}
/**
* @brief Call void handler under a shared lock.
*
* @param handler function that gets called in place
*/
template <typename THandler>
requires std::invocable<THandler>
void CallInPlace(THandler handler) {
std::shared_lock<LockT> rd(lock_);
handler();
}
/**
* @brief Register an active session (used to handle callbacks).
*
* @param session
* @return true on success
*/
bool Register(SessionInterface &session) {
std::lock_guard<LockT> wr(lock_);
auto [_, success] = sessions_.emplace(session.UUID(), session);
return success;
}
/**
* @brief Delete a session.
*
* @param session
*/
bool Delete(const SessionInterface &session) {
std::lock_guard<LockT> wr(lock_);
return sessions_.erase(session.UUID()) > 0;
}
/**
* @brief Delete database.
*
* @param db_name database name
* @return DeleteResult error on failure
*/
DeleteResult Delete(const std::string &db_name) {
std::lock_guard<LockT> wr(lock_);
if (db_name == kDefaultDB) {
// MSG cannot delete the default db
return DeleteError::DEFAULT_DB;
}
// Check if db exists
try {
auto sc = Get_(db_name);
// Check if a session is using the db
if (!sc.interpreter_context->interpreters->empty()) {
return DeleteError::USING;
}
} catch (UnknownDatabaseException &) {
return DeleteError::NON_EXISTENT;
}
// High level handlers
for (auto &[_, s] : sessions_) {
if (!s.OnDelete(db_name)) {
spdlog::error("Partial failure while deleting database \"{}\".", db_name);
defunct_dbs_.emplace(db_name);
return DeleteError::FAIL;
}
}
// Low level handlers
const auto storage_path = StorageDir_(db_name);
MG_ASSERT(storage_path, "Missing storage for {}", db_name);
if (!interp_handler_.Delete(db_name)) {
spdlog::error("Partial failure while deleting database \"{}\".", db_name);
defunct_dbs_.emplace(db_name);
return DeleteError::FAIL;
}
// Remove from auth
auth_->Lock()->DeleteDatabase(db_name);
// Remove from durability list
if (durability_) durability_->Delete(db_name);
// Delete disk storage
if (delete_on_drop_) {
std::error_code ec;
(void)std::filesystem::remove_all(*storage_path, ec);
if (ec) {
spdlog::error("Failed to clean disk while deleting database \"{}\".", db_name);
defunct_dbs_.emplace(db_name);
return DeleteError::DISK_FAIL;
}
}
// Delete from defunct_dbs_ (in case a second delete call was successful)
defunct_dbs_.erase(db_name);
return {}; // Success
}
/**
* @brief Set the default configurations.
*
* @param configs storage, interpreter and authorization configurations
*/
void SetDefaultConfigs(Config configs) {
std::lock_guard<LockT> wr(lock_);
default_configs_ = configs;
}
/**
* @brief Get the default configurations.
*
* @return std::optional<Config>
*/
std::optional<Config> GetDefaultConfigs() const {
std::shared_lock<LockT> rd(lock_);
return default_configs_;
}
/**
* @brief Return all active databases.
*
* @return std::vector<std::string>
*/
std::vector<std::string> All() const {
std::shared_lock<LockT> rd(lock_);
return interp_handler_.All();
}
/**
* @brief Return the number of vertex across all databases.
*
* @return uint64_t
*/
Statistics Info() const {
// TODO: Handle overflow
uint64_t nv = 0;
uint64_t ne = 0;
std::shared_lock<LockT> rd(lock_);
const uint64_t ndb = std::distance(interp_handler_.cbegin(), interp_handler_.cend());
for (const auto &ic : interp_handler_) {
const auto &info = ic.second.get()->db->GetInfo();
nv += info.vertex_count;
ne += info.edge_count;
}
return {nv, ne, ndb};
}
/**
* @brief Return the currently active database for a particular session.
*
* @param uuid session's unique identifier
* @return std::string name of the database
* @throw
*/
std::string Current(const std::string &uuid) const {
std::shared_lock<LockT> rd(lock_);
return sessions_.at(uuid).GetDatabaseName();
}
/**
* @brief Restore triggers for all currently defined databases.
* @note: Triggers can execute query procedures, so we need to reload the modules first and then the triggers
*/
void RestoreTriggers() {
std::lock_guard<LockT> wr(lock_);
for (auto &ic_itr : interp_handler_) {
auto ic = ic_itr.second.get();
spdlog::debug("Restoring trigger for database \"{}\"", ic->db->id());
auto storage_accessor = ic->db->Access();
auto dba = memgraph::query::DbAccessor{storage_accessor.get()};
ic->trigger_store.RestoreTriggers(&ic->ast_cache, &dba, ic->config.query, ic->auth_checker);
}
}
/**
* @brief Restore streams of all currently defined databases.
* @note: Stream transformations are using modules, they have to be restored after the query modules are loaded.
*/
void RestoreStreams() {
std::lock_guard<LockT> wr(lock_);
for (auto &ic_itr : interp_handler_) {
auto ic = ic_itr.second.get();
spdlog::debug("Restoring streams for database \"{}\"", ic->db->id());
ic->streams.RestoreStreams();
}
}
private:
std::optional<std::filesystem::path> StorageDir_(const std::string &name) const {
const auto conf = interp_handler_.GetConfig(name);
if (conf) {
return conf->storage_config.durability.storage_directory;
}
spdlog::debug("Failed to find storage dir for database \"{}\"", name);
return {};
}
/**
* @brief Create a new SessionContext associated with the "name" database
*
* @param name name of the database
* @return NewResultT context on success, error on failure
*/
NewResultT New_(const std::string &name) { return New_(name, name); }
/**
* @brief Create a new SessionContext associated with the "name" database
*
* @param name name of the database
* @param storage_subdir undelying RocksDB directory
* @return NewResultT context on success, error on failure
*/
NewResultT New_(const std::string &name, std::filesystem::path storage_subdir) {
if (default_configs_) {
auto storage = default_configs_->storage_config;
storage::UpdatePaths(storage, storage.durability.storage_directory / storage_subdir);
return New_(name, storage, default_configs_->interp_config);
}
spdlog::info("Trying to generate session context without any configurations.");
return NewError::NO_CONFIGS;
}
/**
* @brief Create a new SessionContext associated with the "name" database
*
* @param name name of the database
* @param storage_config storage configuration
* @param inter_config interpreter configuration
* @return NewResultT context on success, error on failure
*/
NewResultT New_(const std::string &name, StorageConfigT &storage_config, query::InterpreterConfig &inter_config/*,
const std::string &ah_flags*/) {
MG_ASSERT(auth_handler_, "No high level AuthQueryHandler has been supplied.");
MG_ASSERT(auth_checker_, "No high level AuthChecker has been supplied.");
if (defunct_dbs_.contains(name)) {
spdlog::warn("Failed to generate database due to the unknown state of the previously defunct database \"{}\".",
name);
return NewError::DEFUNCT;
}
auto new_interp = interp_handler_.New(name, *this, storage_config, inter_config, *auth_handler_, *auth_checker_);
if (new_interp.HasValue()) {
// Success
if (durability_) durability_->Put(name, "ok");
return SessionContext{new_interp.GetValue(), run_id_, auth_.get(), audit_log_};
}
return new_interp.GetError();
}
/**
* @brief Create a new SessionContext associated with the default database
*
* @return NewResultT context on success, error on failure
*/
NewResultT NewDefault_() {
// Create the default DB in the root (this is how it was done pre multi-tenancy)
auto res = New_(kDefaultDB, "..");
if (res.HasValue()) {
// For back-compatibility...
// Recreate the dbms layout for the default db and symlink to the root
const auto dir = StorageDir_(kDefaultDB);
MG_ASSERT(dir, "Failed to find storage path.");
const auto main_dir = *dir / "databases" / kDefaultDB;
if (!std::filesystem::exists(main_dir)) {
std::filesystem::create_directory(main_dir);
}
// Force link on-disk directories
const auto conf = interp_handler_.GetConfig(kDefaultDB);
MG_ASSERT(conf, "No configuration for the default database.");
const auto &tmp_conf = conf->storage_config.disk;
std::vector<std::filesystem::path> to_link{
tmp_conf.main_storage_directory, tmp_conf.label_index_directory,
tmp_conf.label_property_index_directory, tmp_conf.unique_constraints_directory,
tmp_conf.name_id_mapper_directory, tmp_conf.id_name_mapper_directory,
tmp_conf.durability_directory, tmp_conf.wal_directory,
};
// Add in-memory paths
// Some directories are redundant (skip those)
const std::vector<std::string> skip{".lock", "audit_log", "auth", "databases", "internal_modules", "settings"};
for (auto const &item : std::filesystem::directory_iterator{*dir}) {
const auto dir_name = std::filesystem::relative(item.path(), item.path().parent_path());
if (std::find(skip.begin(), skip.end(), dir_name) != skip.end()) continue;
to_link.push_back(item.path());
}
// Symlink to root dir
for (auto const &item : to_link) {
const auto dir_name = std::filesystem::relative(item, item.parent_path());
const auto link = main_dir / dir_name;
const auto to = std::filesystem::relative(item, main_dir);
if (!std::filesystem::is_symlink(link) && !std::filesystem::exists(link)) {
std::filesystem::create_directory_symlink(to, link);
} else { // Check existing link
std::error_code ec;
const auto test_link = std::filesystem::read_symlink(link, ec);
if (ec || test_link != to) {
MG_ASSERT(false,
"Memgraph storage directory incompatible with new version.\n"
"Please use a clean directory or remove \"{}\" and try again.",
link.string());
}
}
}
}
return res;
}
/**
* @brief Get the context associated with the "name" database
*
* @param name
* @return SessionContext
* @throw UnknownDatabaseException if trying to get unknown database
*/
SessionContext Get_(const std::string &name) {
auto interp = interp_handler_.Get(name);
if (interp) {
return SessionContext{*interp, run_id_, auth_.get(), audit_log_};
}
throw UnknownDatabaseException("Tried to retrieve an unknown database \"{}\".", name);
}
// Should storage objects ever be deleted?
mutable LockT lock_; //!< protective lock
std::filesystem::path lock_file_path_; //!< Lock file protecting the main storage
utils::OutputFile lock_file_handle_; //!< Handler the lock (crash if already open)
InterpContextHandler<SessionContextHandler> interp_handler_; //!< multi-tenancy interpreter handler
// AuthContextHandler auth_handler_; //!< multi-tenancy authorization handler (currently we use a single global
// auth)
std::unique_ptr<utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock>> auth_;
std::unique_ptr<query::AuthQueryHandler> auth_handler_;
std::unique_ptr<query::AuthChecker> auth_checker_;
std::optional<Config> default_configs_; //!< default storage and interpreter configurations
const std::string run_id_; //!< run's unique identifier (auto generated)
memgraph::audit::Log *audit_log_; //!< pointer to the audit logger
std::unordered_map<std::string, SessionInterface &> sessions_; //!< map of active/registered sessions
std::unique_ptr<kvstore::KVStore> durability_; //!< list of active dbs (pointer so we can postpone its creation)
std::set<std::string> defunct_dbs_; //!< Databases that are in an unknown state due to various failures
bool delete_on_drop_; //!< Flag defining if dropping storage also deletes its directory
public:
static SessionContextHandler &ExtractSCH(query::InterpreterContext *interpreter_context) {
return static_cast<typename decltype(interp_handler_)::InterpContextT *>(interpreter_context)->sc_handler_;
}
};
#else
/**
* @brief Initialize the handler.
*
* @param auth pointer to the authenticator
* @param configs storage and interpreter configurations
*/
static inline SessionContext Init(storage::Config &storage_config, query::InterpreterConfig &interp_config,
utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth,
query::AuthQueryHandler *auth_handler, query::AuthChecker *auth_checker) {
MG_ASSERT(auth, "Passed a nullptr auth");
MG_ASSERT(auth_handler, "Passed a nullptr auth_handler");
MG_ASSERT(auth_checker, "Passed a nullptr auth_checker");
storage_config.name = kDefaultDB;
auto interp_context = std::make_shared<query::InterpreterContext>(
storage_config, interp_config, storage_config.durability.storage_directory, auth_handler, auth_checker);
MG_ASSERT(interp_context, "Failed to construct main interpret context.");
return SessionContext{interp_context, utils::GenerateUUID(), auth};
}
#endif
} // namespace memgraph::dbms

View File

@ -1,4 +1,11 @@
add_library(mg-glue STATIC )
target_sources(mg-glue PRIVATE auth.cpp auth_checker.cpp auth_handler.cpp communication.cpp SessionHL.cpp ServerT.cpp MonitoringServerT.cpp)
target_sources(mg-glue PRIVATE auth.cpp
auth_checker.cpp
auth_handler.cpp
communication.cpp
SessionHL.cpp
ServerT.cpp
MonitoringServerT.cpp
run_id.cpp)
target_link_libraries(mg-glue mg-query mg-auth mg-audit)
target_precompile_headers(mg-glue INTERFACE auth_checker.hpp auth_handler.hpp)

View File

@ -10,5 +10,4 @@
// licenses/APL.txt.
#include "glue/MonitoringServerT.hpp"
template class memgraph::communication::http::Server<
memgraph::http::MetricsRequestHandler<memgraph::dbms::SessionContext>, memgraph::dbms::SessionContext>;
template class memgraph::communication::http::Server<memgraph::http::MetricsRequestHandler, memgraph::storage::Storage>;

View File

@ -11,15 +11,14 @@
#pragma once
#include "communication/http/server.hpp"
#include "dbms/session_context.hpp"
#include "http_handlers/metrics.hpp"
#include "storage/v2/storage.hpp"
extern template class memgraph::communication::http::Server<
memgraph::http::MetricsRequestHandler<memgraph::dbms::SessionContext>, memgraph::dbms::SessionContext>;
extern template class memgraph::communication::http::Server<memgraph::http::MetricsRequestHandler,
memgraph::storage::Storage>;
namespace memgraph::glue {
using MonitoringServerT =
memgraph::communication::http::Server<memgraph::http::MetricsRequestHandler<memgraph::dbms::SessionContext>,
memgraph::dbms::SessionContext>;
memgraph::communication::http::Server<memgraph::http::MetricsRequestHandler, memgraph::storage::Storage>;
} // namespace memgraph::glue

View File

@ -10,8 +10,4 @@
// licenses/APL.txt.
#include "glue/ServerT.hpp"
#ifdef MG_ENTERPRISE
template class memgraph::communication::v2::Server<memgraph::glue::SessionHL, memgraph::dbms::SessionContextHandler>;
#else
template class memgraph::communication::v2::Server<memgraph::glue::SessionHL, memgraph::dbms::SessionContext>;
#endif
template class memgraph::communication::v2::Server<memgraph::glue::SessionHL, Context>;

View File

@ -12,24 +12,35 @@
#include "communication/v2/server.hpp"
#include "glue/SessionHL.hpp"
#include "utils/synchronized.hpp"
#ifdef MG_ENTERPRISE
#include "dbms/session_context_handler.hpp"
#else
#include "dbms/session_context.hpp"
namespace memgraph::query {
struct InterpreterContext;
}
#if MG_ENTERPRISE
namespace memgraph::audit {
class Log;
}
#endif
#ifdef MG_ENTERPRISE
extern template class memgraph::communication::v2::Server<memgraph::glue::SessionHL,
memgraph::dbms::SessionContextHandler>;
#else
extern template class memgraph::communication::v2::Server<memgraph::glue::SessionHL, memgraph::dbms::SessionContext>;
namespace memgraph::auth {
class Auth;
}
namespace memgraph::utils {
class WritePrioritizedRWLock;
}
struct Context {
memgraph::query::InterpreterContext *ic;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth;
#if MG_ENTERPRISE
memgraph::audit::Log *audit_log;
#endif
};
extern template class memgraph::communication::v2::Server<memgraph::glue::SessionHL, Context>;
namespace memgraph::glue {
#ifdef MG_ENTERPRISE
using ServerT = memgraph::communication::v2::Server<memgraph::glue::SessionHL, memgraph::dbms::SessionContextHandler>;
#else
using ServerT = memgraph::communication::v2::Server<memgraph::glue::SessionHL, memgraph::dbms::SessionContext>;
#endif
using ServerT = memgraph::communication::v2::Server<memgraph::glue::SessionHL, Context>;
} // namespace memgraph::glue

View File

@ -9,19 +9,21 @@
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "glue/SessionHL.hpp"
#include <optional>
#include "gflags/gflags.h"
#include "audit/log.hpp"
#include "dbms/constants.hpp"
#include "flags/run_time_configurable.hpp"
#include "glue/SessionHL.hpp"
#include "glue/auth_checker.hpp"
#include "glue/communication.hpp"
#include "glue/run_id.hpp"
#include "license/license.hpp"
#include "query/discard_value_stream.hpp"
#include "query/interpreter_context.hpp"
#include "utils/spin_lock.hpp"
#include "gflags/gflags.h"
namespace memgraph::metrics {
extern const Event ActiveBoltSessions;
} // namespace memgraph::metrics
@ -47,14 +49,14 @@ auto ToQueryExtras(const memgraph::communication::bolt::Value &extra) -> memgrap
class TypedValueResultStreamBase {
public:
explicit TypedValueResultStreamBase(memgraph::query::InterpreterContext *interpreterContext);
explicit TypedValueResultStreamBase(memgraph::storage::Storage *storage);
std::vector<memgraph::communication::bolt::Value> DecodeValues(
const std::vector<memgraph::query::TypedValue> &values) const;
private:
protected:
// NOTE: Needed only for ToBoltValue conversions
memgraph::query::InterpreterContext *interpreter_context_;
memgraph::storage::Storage *storage_;
};
/// Wrapper around TEncoder which converts TypedValue to Value
@ -62,8 +64,8 @@ class TypedValueResultStreamBase {
template <typename TEncoder>
class TypedValueResultStream : public TypedValueResultStreamBase {
public:
TypedValueResultStream(TEncoder *encoder, memgraph::query::InterpreterContext *ic)
: TypedValueResultStreamBase{ic}, encoder_(encoder) {}
TypedValueResultStream(TEncoder *encoder, memgraph::storage::Storage *storage)
: TypedValueResultStreamBase{storage}, encoder_(encoder) {}
void Result(const std::vector<memgraph::query::TypedValue> &values) { encoder_->MessageRecord(DecodeValues(values)); }
@ -76,7 +78,7 @@ std::vector<memgraph::communication::bolt::Value> TypedValueResultStreamBase::De
std::vector<memgraph::communication::bolt::Value> decoded_values;
decoded_values.reserve(values.size());
for (const auto &v : values) {
auto maybe_value = memgraph::glue::ToBoltValue(v, *interpreter_context_->db, memgraph::storage::View::NEW);
auto maybe_value = memgraph::glue::ToBoltValue(v, *storage_, memgraph::storage::View::NEW);
if (maybe_value.HasError()) {
switch (maybe_value.GetError()) {
case memgraph::storage::Error::DELETED_OBJECT:
@ -93,33 +95,13 @@ std::vector<memgraph::communication::bolt::Value> TypedValueResultStreamBase::De
}
return decoded_values;
}
TypedValueResultStreamBase::TypedValueResultStreamBase(memgraph::query::InterpreterContext *interpreterContext)
: interpreter_context_(interpreterContext) {}
TypedValueResultStreamBase::TypedValueResultStreamBase(memgraph::storage::Storage *storage) : storage_(storage) {}
namespace memgraph::glue {
#ifdef MG_ENTERPRISE
void SessionHL::UpdateAndDefunct(const std::string &db_name) {
UpdateAndDefunct(ContextWrapper(sc_handler_.Get(db_name)));
}
void SessionHL::UpdateAndDefunct(ContextWrapper &&cntxt) {
defunct_.emplace(std::move(current_));
Update(std::forward<ContextWrapper>(cntxt));
defunct_->Defunct();
}
void SessionHL::Update(const std::string &db_name) {
ContextWrapper tmp(sc_handler_.Get(db_name));
Update(std::move(tmp));
}
void SessionHL::Update(ContextWrapper &&cntxt) {
current_ = std::move(cntxt);
interpreter_ = current_.interp();
interpreter_->in_explicit_db_ = in_explicit_db_;
interpreter_context_ = current_.interpreter_context();
}
void SessionHL::MultiDatabaseAuth(const std::string &db) {
if (user_ && !AuthChecker::IsUserAuthorized(*user_, {}, db)) {
inline static void MultiDatabaseAuth(const std::optional<auth::User> &user, std::string_view db) {
if (user && !AuthChecker::IsUserAuthorized(*user, {}, std::string(db))) {
throw memgraph::communication::bolt::ClientError(
"You are not authorized on the database \"{}\"! Please contact your database administrator.", db);
}
@ -130,23 +112,13 @@ std::string SessionHL::GetDefaultDB() {
}
return memgraph::dbms::kDefaultDB;
}
bool SessionHL::OnDelete(const std::string &db_name) {
MG_ASSERT(current_.interpreter_context()->db->id() != db_name && (!defunct_ || defunct_->defunct()),
"Trying to delete a database while still in use.");
return true;
}
memgraph::dbms::SetForResult SessionHL::OnChange(const std::string &db_name) {
MultiDatabaseAuth(db_name);
if (db_name != current_.interpreter_context()->db->id()) {
UpdateAndDefunct(db_name); // Done during Pull, so we cannot just replace the current db
return memgraph::dbms::SetForResult::SUCCESS;
}
return memgraph::dbms::SetForResult::ALREADY_SET;
}
#endif
std::string SessionHL::GetDatabaseName() const { return interpreter_context_->db->id(); }
std::string SessionHL::GetDatabaseName() const {
if (!interpreter_.db_acc_) return "";
const auto *db = interpreter_.db_acc_->get();
return db->id();
}
std::optional<std::string> SessionHL::GetServerNameForInit() {
auto locked_name = flags::run_time::bolt_server_name_.Lock();
@ -154,30 +126,29 @@ std::optional<std::string> SessionHL::GetServerNameForInit() {
}
bool SessionHL::Authenticate(const std::string &username, const std::string &password) {
auto locked_auth = auth_->Lock();
if (!locked_auth->HasUsers()) {
return true;
}
user_ = locked_auth->Authenticate(username, password);
#ifdef MG_ENTERPRISE
if (user_.has_value()) {
const auto &db = user_->db_access().GetDefault();
// Check if the underlying database needs to be updated
if (db != current_.interpreter_context()->db->id()) {
const auto &res = sc_handler_.SetFor(UUID(), db);
return res == memgraph::dbms::SetForResult::SUCCESS || res == memgraph::dbms::SetForResult::ALREADY_SET;
bool res = true;
{
auto locked_auth = auth_->Lock();
if (locked_auth->HasUsers()) {
user_ = locked_auth->Authenticate(username, password);
res = user_.has_value();
}
}
#ifdef MG_ENTERPRISE
// Start off with the default database
interpreter_.SetCurrentDB(GetDefaultDB());
#endif
return user_.has_value();
implicit_db_.emplace(GetDatabaseName());
return res;
}
void SessionHL::Abort() { interpreter_->Abort(); }
void SessionHL::Abort() { interpreter_.Abort(); }
std::map<std::string, memgraph::communication::bolt::Value> SessionHL::Discard(std::optional<int> n,
std::optional<int> qid) {
try {
memgraph::query::DiscardValueResultStream stream;
return DecodeSummary(interpreter_->Pull(&stream, n, qid));
return DecodeSummary(interpreter_.Pull(&stream, n, qid));
} catch (const memgraph::query::QueryException &e) {
// Wrap QueryException into ClientError, because we want to allow the
// client to fix their query.
@ -187,15 +158,18 @@ std::map<std::string, memgraph::communication::bolt::Value> SessionHL::Discard(s
std::map<std::string, memgraph::communication::bolt::Value> SessionHL::Pull(SessionHL::TEncoder *encoder,
std::optional<int> n,
std::optional<int> qid) {
// TODO: Update once interpreter can handle non-database queries (db_acc will be nullopt)
auto *db = interpreter_.db_acc_->get();
try {
TypedValueResultStream<TEncoder> stream(encoder, interpreter_context_);
return DecodeSummary(interpreter_->Pull(&stream, n, qid));
TypedValueResultStream<TEncoder> stream(encoder, db->storage());
return DecodeSummary(interpreter_.Pull(&stream, n, qid));
} catch (const memgraph::query::QueryException &e) {
// Wrap QueryException into ClientError, because we want to allow the
// client to fix their query.
throw memgraph::communication::bolt::ClientError(e.what());
}
}
std::pair<std::vector<std::string>, std::optional<int>> SessionHL::Interpret(
const std::string &query, const std::map<std::string, memgraph::communication::bolt::Value> &params,
const std::map<std::string, memgraph::communication::bolt::Value> &extra) {
@ -209,16 +183,18 @@ std::pair<std::vector<std::string>, std::optional<int>> SessionHL::Interpret(
}
#ifdef MG_ENTERPRISE
// TODO: Update once interpreter can handle non-database queries (db_acc will be nullopt)
auto *db = interpreter_.db_acc_->get();
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
audit_log_->Record(endpoint_.address().to_string(), user_ ? *username : "", query,
memgraph::storage::PropertyValue(params_pv), interpreter_context_->db->id());
memgraph::storage::PropertyValue(params_pv), db->id());
}
#endif
try {
auto result = interpreter_->Prepare(query, params_pv, username, ToQueryExtras(extra), UUID());
auto result = interpreter_.Prepare(query, params_pv, username, ToQueryExtras(extra), UUID());
const std::string db_name = result.db ? *result.db : "";
if (user_ && !AuthChecker::IsUserAuthorized(*user_, result.privileges, db_name)) {
interpreter_->Abort();
interpreter_.Abort();
if (db_name.empty()) {
throw memgraph::communication::bolt::ClientError(
"You are not authorized to execute this query! Please contact your database administrator.");
@ -238,10 +214,10 @@ std::pair<std::vector<std::string>, std::optional<int>> SessionHL::Interpret(
throw memgraph::communication::bolt::ClientError(e.what());
}
}
void SessionHL::RollbackTransaction() { interpreter_->RollbackTransaction(); }
void SessionHL::CommitTransaction() { interpreter_->CommitTransaction(); }
void SessionHL::RollbackTransaction() { interpreter_.RollbackTransaction(); }
void SessionHL::CommitTransaction() { interpreter_.CommitTransaction(); }
void SessionHL::BeginTransaction(const std::map<std::string, memgraph::communication::bolt::Value> &extra) {
interpreter_->BeginTransaction(ToQueryExtras(extra));
interpreter_.BeginTransaction(ToQueryExtras(extra));
}
void SessionHL::Configure(const std::map<std::string, memgraph::communication::bolt::Value> &run_time_info) {
#ifdef MG_ENTERPRISE
@ -254,108 +230,68 @@ void SessionHL::Configure(const std::map<std::string, memgraph::communication::b
throw memgraph::communication::bolt::ClientError("Malformed database name.");
}
db = db_info.ValueString();
update = db != current_.interpreter_context()->db->id();
const auto &current = GetDatabaseName();
update = db != current;
if (!in_explicit_db_) implicit_db_.emplace(current); // Still not in an explicit database, save for recovery
in_explicit_db_ = true;
// NOTE: Once in a transaction, the drivers stop explicitly sending the db and count on using it until commit
} else if (in_explicit_db_ && !interpreter_->in_explicit_transaction_) { // Just on a switch
db = GetDefaultDB();
update = db != current_.interpreter_context()->db->id();
} else if (in_explicit_db_ && !interpreter_.in_explicit_transaction_) { // Just on a switch
if (implicit_db_) {
db = *implicit_db_;
} else {
db = GetDefaultDB();
}
update = db != GetDatabaseName();
in_explicit_db_ = false;
}
// Check if the underlying database needs to be updated
if (update) {
sc_handler_.SetInPlace(db, [this](auto new_sc) mutable {
const auto &db_name = new_sc.interpreter_context->db->id();
MultiDatabaseAuth(db_name);
try {
Update(ContextWrapper(new_sc));
return memgraph::dbms::SetForResult::SUCCESS;
} catch (memgraph::dbms::UnknownDatabaseException &e) {
throw memgraph::communication::bolt::ClientError("No database named \"{}\" found!", db_name);
}
});
MultiDatabaseAuth(user_, db);
interpreter_.SetCurrentDB(db);
}
#endif
}
SessionHL::~SessionHL() { memgraph::metrics::DecrementCounter(memgraph::metrics::ActiveBoltSessions); }
SessionHL::SessionHL(
SessionHL::SessionHL(memgraph::query::InterpreterContext *interpreter_context,
const memgraph::communication::v2::ServerEndpoint &endpoint,
memgraph::communication::v2::InputStream *input_stream,
memgraph::communication::v2::OutputStream *output_stream,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth
#ifdef MG_ENTERPRISE
memgraph::dbms::SessionContextHandler &sc_handler,
#else
memgraph::dbms::SessionContext sc,
,
memgraph::audit::Log *audit_log
#endif
const memgraph::communication::v2::ServerEndpoint &endpoint, memgraph::communication::v2::InputStream *input_stream,
memgraph::communication::v2::OutputStream *output_stream, const std::string &default_db) // NOLINT
)
: Session<memgraph::communication::v2::InputStream, memgraph::communication::v2::OutputStream>(input_stream,
output_stream),
interpreter_context_(interpreter_context),
interpreter_(interpreter_context_),
#ifdef MG_ENTERPRISE
sc_handler_(sc_handler),
current_(sc_handler_.Get(default_db)),
#else
current_(sc),
#endif
interpreter_context_(current_.interpreter_context()),
interpreter_(current_.interp()),
auth_(current_.auth()),
#ifdef MG_ENTERPRISE
audit_log_(current_.audit_log()),
audit_log_(audit_log),
#endif
auth_(auth),
endpoint_(endpoint),
run_id_(current_.run_id()) {
implicit_db_(dbms::kDefaultDB) {
// Metrics update
memgraph::metrics::IncrementCounter(memgraph::metrics::ActiveBoltSessions);
#ifdef MG_ENTERPRISE
interpreter_.OnChangeCB([&](std::string_view db_name) { MultiDatabaseAuth(user_, db_name); });
#endif
interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter_); });
}
/// ContextWrapper
ContextWrapper::ContextWrapper(memgraph::dbms::SessionContext sc)
: session_context(sc),
interpreter(std::make_unique<memgraph::query::Interpreter>(session_context.interpreter_context.get())),
defunct_(false) {
session_context.interpreter_context->interpreters.WithLock(
[this](auto &interpreters) { interpreters.insert(interpreter.get()); });
SessionHL::~SessionHL() {
memgraph::metrics::DecrementCounter(memgraph::metrics::ActiveBoltSessions);
interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.erase(&interpreter_); });
}
ContextWrapper::~ContextWrapper() { Defunct(); }
void ContextWrapper::Defunct() {
if (!defunct_) {
session_context.interpreter_context->interpreters.WithLock(
[this](auto &interpreters) { interpreters.erase(interpreter.get()); });
defunct_ = true;
}
}
ContextWrapper::ContextWrapper(ContextWrapper &&in) noexcept
: session_context(std::move(in.session_context)), interpreter(std::move(in.interpreter)), defunct_(in.defunct_) {
in.defunct_ = true;
}
ContextWrapper &ContextWrapper::operator=(ContextWrapper &&in) noexcept {
if (this != &in) {
Defunct();
session_context = std::move(in.session_context);
interpreter = std::move(in.interpreter);
defunct_ = in.defunct_;
in.defunct_ = true;
}
return *this;
}
memgraph::query::InterpreterContext *ContextWrapper::interpreter_context() {
return session_context.interpreter_context.get();
}
memgraph::query::Interpreter *ContextWrapper::interp() { return interpreter.get(); }
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *ContextWrapper::auth()
const {
return session_context.auth;
}
std::string ContextWrapper::run_id() const { return session_context.run_id; }
bool ContextWrapper::defunct() const { return defunct_; }
#ifdef MG_ENTERPRISE
memgraph::audit::Log *ContextWrapper::audit_log() const { return session_context.audit_log; }
#endif
std::map<std::string, memgraph::communication::bolt::Value> SessionHL::DecodeSummary(
const std::map<std::string, memgraph::query::TypedValue> &summary) {
// TODO: Update once interpreter can handle non-database queries (db_acc will be nullopt)
auto *db = interpreter_.db_acc_->get();
std::map<std::string, memgraph::communication::bolt::Value> decoded_summary;
for (const auto &kv : summary) {
auto maybe_value = ToBoltValue(kv.second, *interpreter_context_->db, memgraph::storage::View::NEW);
auto maybe_value = ToBoltValue(kv.second, *db->storage(), memgraph::storage::View::NEW);
if (maybe_value.HasError()) {
switch (maybe_value.GetError()) {
case memgraph::storage::Error::DELETED_OBJECT:
@ -372,14 +308,7 @@ std::map<std::string, memgraph::communication::bolt::Value> SessionHL::DecodeSum
// This is sent with every query, instead of only on bolt init inside
// communication/bolt/v1/states/init.hpp because neo4jdriver does not
// read the init message.
if (auto run_id = run_id_; run_id) {
decoded_summary.emplace("run_id", *run_id);
}
// Clean up previous session (session gets defunct when switching between databases)
if (defunct_) {
defunct_.reset();
}
decoded_summary.emplace("run_id", memgraph::glue::run_id_);
return decoded_summary;
}

View File

@ -10,56 +10,28 @@
// licenses/APL.txt.
#pragma once
#include "audit/log.hpp"
#include "auth/auth.hpp"
#include "communication/v2/server.hpp"
#include "communication/v2/session.hpp"
#include "dbms/session_context.hpp"
#ifdef MG_ENTERPRISE
#include "dbms/session_context_handler.hpp"
#else
#include "dbms/session_context.hpp"
#endif
#include "dbms/database.hpp"
#include "query/interpreter.hpp"
namespace memgraph::glue {
struct ContextWrapper {
explicit ContextWrapper(memgraph::dbms::SessionContext sc);
~ContextWrapper();
ContextWrapper(const ContextWrapper &) = delete;
ContextWrapper &operator=(const ContextWrapper &) = delete;
ContextWrapper(ContextWrapper &&in) noexcept;
ContextWrapper &operator=(ContextWrapper &&in) noexcept;
void Defunct();
memgraph::query::InterpreterContext *interpreter_context();
memgraph::query::Interpreter *interp();
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth() const;
std::string run_id() const;
bool defunct() const;
#ifdef MG_ENTERPRISE
memgraph::audit::Log *audit_log() const;
#endif
private:
memgraph::dbms::SessionContext session_context;
std::unique_ptr<memgraph::query::Interpreter> interpreter;
bool defunct_;
};
class SessionHL final : public memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
memgraph::communication::v2::OutputStream> {
public:
SessionHL(
SessionHL(memgraph::query::InterpreterContext *interpreter_context,
const memgraph::communication::v2::ServerEndpoint &endpoint,
memgraph::communication::v2::InputStream *input_stream,
memgraph::communication::v2::OutputStream *output_stream,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth
#ifdef MG_ENTERPRISE
memgraph::dbms::SessionContextHandler &sc_handler,
#else
memgraph::dbms::SessionContext sc,
,
memgraph::audit::Log *audit_log
#endif
const memgraph::communication::v2::ServerEndpoint &endpoint,
memgraph::communication::v2::InputStream *input_stream, memgraph::communication::v2::OutputStream *output_stream,
const std::string &default_db = memgraph::dbms::kDefaultDB);
);
~SessionHL() override;
@ -92,14 +64,8 @@ class SessionHL final : public memgraph::communication::bolt::Session<memgraph::
void Abort() override;
// Called during Init
// During Init, the user cannot choose the landing DB (switch is done during query execution)
bool Authenticate(const std::string &username, const std::string &password) override;
#ifdef MG_ENTERPRISE
memgraph::dbms::SetForResult OnChange(const std::string &db_name) override;
bool OnDelete(const std::string &db_name) override;
#endif
std::optional<std::string> GetServerNameForInit() override;
std::string GetDatabaseName() const override;
@ -108,54 +74,23 @@ class SessionHL final : public memgraph::communication::bolt::Session<memgraph::
std::map<std::string, memgraph::communication::bolt::Value> DecodeSummary(
const std::map<std::string, memgraph::query::TypedValue> &summary);
#ifdef MG_ENTERPRISE
/**
* @brief Update setup to the new database.
*
* @param db_name name of the target database
* @throws UnknownDatabaseException if handler cannot get it
*/
void UpdateAndDefunct(const std::string &db_name);
void UpdateAndDefunct(ContextWrapper &&cntxt);
void Update(const std::string &db_name);
void Update(ContextWrapper &&cntxt);
/**
* @brief Authenticate user on passed database.
*
* @param db database to check against
* @throws bolt::ClientError when user is not authorized
*/
void MultiDatabaseAuth(const std::string &db);
/**
* @brief Get the user's default database
*
* @return std::string
*/
std::string GetDefaultDB();
#endif
#ifdef MG_ENTERPRISE
memgraph::dbms::SessionContextHandler &sc_handler_;
#endif
ContextWrapper current_;
std::optional<ContextWrapper> defunct_;
memgraph::query::InterpreterContext *interpreter_context_;
memgraph::query::Interpreter *interpreter_;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
memgraph::query::Interpreter interpreter_;
std::optional<memgraph::auth::User> user_;
#ifdef MG_ENTERPRISE
memgraph::audit::Log *audit_log_;
bool in_explicit_db_{false}; //!< If true, the user has defined the database to use via metadata
#endif
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
memgraph::communication::v2::ServerEndpoint endpoint_;
// NOTE: run_id should be const but that complicates code a lot.
std::optional<std::string> run_id_;
std::optional<std::string> implicit_db_;
};
} // namespace memgraph::glue

View File

@ -406,6 +406,14 @@ bool AuthQueryHandler::SetMainDatabase(const std::string &db, const std::string
throw memgraph::query::QueryRuntimeException(e.what());
}
}
void AuthQueryHandler::DeleteDatabase(std::string_view db) {
try {
auth_->Lock()->DeleteDatabase(std::string(db));
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
#endif
bool AuthQueryHandler::DropRole(const std::string &rolename) {

View File

@ -17,7 +17,7 @@
#include "auth_global.hpp"
#include "glue/auth.hpp"
#include "license/license.hpp"
#include "query/interpreter.hpp"
#include "query/auth_query_handler.hpp"
#include "utils/string.hpp"
namespace memgraph::glue {
@ -45,6 +45,8 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
std::vector<std::vector<memgraph::query::TypedValue>> GetDatabasePrivileges(const std::string &username) override;
bool SetMainDatabase(const std::string &db, const std::string &username) override;
void DeleteDatabase(std::string_view db) override;
#endif
bool CreateRole(const std::string &rolename) override;

15
src/glue/run_id.cpp Normal file
View File

@ -0,0 +1,15 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "glue/run_id.hpp"
#include "utils/uuid.hpp"
const std::string memgraph::glue::run_id_ = memgraph::utils::GenerateUUID();

16
src/glue/run_id.hpp Normal file
View File

@ -0,0 +1,16 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
namespace memgraph::glue {
extern const std::string run_id_;
} // namespace memgraph::glue

View File

@ -47,10 +47,9 @@ struct MetricsResponse {
std::vector<std::tuple<std::string, std::string, uint64_t>> event_histograms{};
};
template <typename TSessionContext>
class MetricsService {
public:
explicit MetricsService(TSessionContext *session_context) : db_(session_context->interpreter_context->db.get()) {}
explicit MetricsService(storage::Storage *storage) : db_(storage) {}
nlohmann::json GetMetricsJSON() {
auto response = GetMetrics();
@ -98,7 +97,7 @@ class MetricsService {
return metrics_response;
}
auto GetEventCounters() {
inline static std::vector<std::tuple<std::string, std::string, uint64_t>> GetEventCounters() {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<std::tuple<std::string, std::string, uint64_t>> event_counters{};
event_counters.reserve(memgraph::metrics::CounterEnd());
@ -111,7 +110,7 @@ class MetricsService {
return event_counters;
}
auto GetEventGauges() {
inline static std::vector<std::tuple<std::string, std::string, uint64_t>> GetEventGauges() {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<std::tuple<std::string, std::string, uint64_t>> event_gauges{};
event_gauges.reserve(memgraph::metrics::GaugeEnd());
@ -124,7 +123,7 @@ class MetricsService {
return event_gauges;
}
auto GetEventHistograms() {
inline static std::vector<std::tuple<std::string, std::string, uint64_t>> GetEventHistograms() {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<std::tuple<std::string, std::string, uint64_t>> event_histograms{};
@ -143,10 +142,11 @@ class MetricsService {
}
};
template <typename TSessionContext>
// TODO: Should this be inside Database?
// Raw pointer could be dangerous
class MetricsRequestHandler final {
public:
explicit MetricsRequestHandler(TSessionContext *session_context) : service_(session_context) {
explicit MetricsRequestHandler(storage::Storage *storage) : service_(storage) {
spdlog::info("Basic request handler started!");
}
@ -208,6 +208,6 @@ class MetricsRequestHandler final {
}
private:
MetricsService<TSessionContext> service_;
MetricsService service_;
};
} // namespace memgraph::http

View File

@ -9,22 +9,22 @@
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "flags/run_time_configurable.hpp"
#ifndef MG_ENTERPRISE
#include "dbms/session_context_handler.hpp"
#endif
#include "audit/log.hpp"
#include "communication/websocket/auth.hpp"
#include "communication/websocket/server.hpp"
#include "dbms/constants.hpp"
#include "flags/all.hpp"
#include "flags/run_time_configurable.hpp"
#include "glue/MonitoringServerT.hpp"
#include "glue/ServerT.hpp"
#include "glue/auth_checker.hpp"
#include "glue/auth_handler.hpp"
#include "glue/run_id.hpp"
#include "helpers.hpp"
#include "license/license_sender.hpp"
#include "query/config.hpp"
#include "query/discard_value_stream.hpp"
#include "query/interpreter.hpp"
#include "query/procedure/callable_alias_mapper.hpp"
#include "query/procedure/module.hpp"
#include "query/procedure/py_module.hpp"
@ -36,13 +36,18 @@
#include "utils/terminate_handler.hpp"
#include "version.hpp"
#include "dbms/dbms_handler.hpp"
#include "query/auth_query_handler.hpp"
#include "query/interpreter_context.hpp"
constexpr const char *kMgUser = "MEMGRAPH_USER";
constexpr const char *kMgPassword = "MEMGRAPH_PASSWORD";
constexpr const char *kMgPassfile = "MEMGRAPH_PASSFILE";
void InitFromCypherlFile(memgraph::query::InterpreterContext &ctx, std::string cypherl_file_path,
memgraph::audit::Log *audit_log = nullptr) {
memgraph::query::Interpreter interpreter(&ctx);
// TODO: move elsewhere so that we can remove need of interpreter.hpp
void InitFromCypherlFile(memgraph::query::InterpreterContext &ctx, memgraph::dbms::DatabaseAccess &db_acc,
std::string cypherl_file_path, memgraph::audit::Log *audit_log = nullptr) {
memgraph::query::Interpreter interpreter(&ctx, db_acc);
std::ifstream file(cypherl_file_path);
if (!file.is_open()) {
@ -198,6 +203,22 @@ int main(int argc, char **argv) {
auto data_directory = std::filesystem::path(FLAGS_data_directory);
memgraph::utils::EnsureDirOrDie(data_directory);
// Verify that the user that started the process is the same user that is
// the owner of the storage directory.
memgraph::storage::durability::VerifyStorageDirectoryOwnerAndProcessUserOrDie(data_directory);
// Create the lock file and open a handle to it. This will crash the
// database if it can't open the file for writing or if any other process is
// holding the file opened.
memgraph::utils::OutputFile lock_file_handle;
lock_file_handle.Open(data_directory / ".lock", memgraph::utils::OutputFile::Mode::OVERWRITE_EXISTING);
MG_ASSERT(lock_file_handle.AcquireLock(),
"Couldn't acquire lock on the storage directory {}"
"!\nAnother Memgraph process is currently running with the same "
"storage directory, please stop it first before starting this "
"process!",
data_directory);
const auto memory_limit = memgraph::flags::GetMemoryLimit();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
spdlog::info("Memory limit in config is set to {}", memgraph::utils::GetReadableSize(memory_limit));
@ -320,60 +341,63 @@ int main(int argc, char **argv) {
}
};
#ifdef MG_ENTERPRISE
// SessionContext handler (multi-tenancy)
memgraph::dbms::SessionContextHandler sc_handler(audit_log, {db_config, interp_config, auth_glue},
FLAGS_storage_recover_on_startup || FLAGS_data_recovery_on_startup,
FLAGS_storage_delete_on_drop);
// Just for current support... TODO remove
auto session_context = sc_handler.Get(memgraph::dbms::kDefaultDB);
#else
// WIP
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth_{data_directory /
"auth"};
std::unique_ptr<memgraph::query::AuthQueryHandler> auth_handler;
std::unique_ptr<memgraph::query::AuthChecker> auth_checker;
auth_glue(&auth_, auth_handler, auth_checker);
auto session_context = memgraph::dbms::Init(db_config, interp_config, &auth_, auth_handler.get(), auth_checker.get());
#ifdef MG_ENTERPRISE
memgraph::dbms::DbmsHandler new_handler(db_config, &auth_, FLAGS_data_recovery_on_startup,
FLAGS_storage_delete_on_drop);
auto db_acc = new_handler.Get(memgraph::dbms::kDefaultDB);
memgraph::query::InterpreterContext interpreter_context_(interp_config, &new_handler, auth_handler.get(),
auth_checker.get());
#else
memgraph::utils::Gatekeeper<memgraph::dbms::Database> db_gatekeeper{db_config};
auto db_acc_opt = db_gatekeeper.access();
MG_ASSERT(db_acc_opt, "Failed to access the main database");
auto &db_acc = *db_acc_opt;
memgraph::query::InterpreterContext interpreter_context_(interp_config, nullptr, auth_handler.get(),
auth_checker.get());
#endif
auto *auth = session_context.auth;
auto &interpreter_context = *session_context.interpreter_context; // TODO remove
MG_ASSERT(db_acc, "Failed to access the main database");
memgraph::query::procedure::gModuleRegistry.SetModulesDirectory(memgraph::flags::ParseQueryModulesDirectory(),
FLAGS_data_directory);
memgraph::query::procedure::gModuleRegistry.UnloadAndLoadModulesFromDirectories();
memgraph::query::procedure::gCallableAliasMapper.LoadMapping(FLAGS_query_callable_mappings_path);
// TODO Make multi-tenant
if (!FLAGS_init_file.empty()) {
spdlog::info("Running init file...");
#ifdef MG_ENTERPRISE
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
InitFromCypherlFile(interpreter_context, FLAGS_init_file, &audit_log);
InitFromCypherlFile(interpreter_context_, db_acc, FLAGS_init_file, &audit_log);
} else {
InitFromCypherlFile(interpreter_context, FLAGS_init_file);
InitFromCypherlFile(interpreter_context_, db_acc, FLAGS_init_file);
}
#else
InitFromCypherlFile(interpreter_context, FLAGS_init_file);
InitFromCypherlFile(interpreter_context_, db_acc, FLAGS_init_file);
#endif
}
#ifdef MG_ENTERPRISE
sc_handler.RestoreTriggers();
sc_handler.RestoreStreams();
new_handler.RestoreTriggers(&interpreter_context_);
new_handler.RestoreStreams(&interpreter_context_);
#else
{
// Triggers can execute query procedures, so we need to reload the modules first and then
// the triggers
auto storage_accessor = interpreter_context.db->Access();
auto storage_accessor = db_acc->Access();
auto dba = memgraph::query::DbAccessor{storage_accessor.get()};
interpreter_context.trigger_store.RestoreTriggers(
&interpreter_context.ast_cache, &dba, interpreter_context.config.query, interpreter_context.auth_checker);
db_acc->trigger_store()->RestoreTriggers(&interpreter_context_.ast_cache, &dba, interpreter_context_.config.query,
interpreter_context_.auth_checker);
}
// As the Stream transformations are using modules, they have to be restored after the query modules are loaded.
interpreter_context.streams.RestoreStreams();
db_acc->streams()->RestoreStreams(db_acc, &interpreter_context_);
#endif
ServerContext context;
@ -389,29 +413,31 @@ int main(int argc, char **argv) {
auto server_endpoint = memgraph::communication::v2::ServerEndpoint{
boost::asio::ip::address::from_string(FLAGS_bolt_address), static_cast<uint16_t>(FLAGS_bolt_port)};
#ifdef MG_ENTERPRISE
memgraph::glue::ServerT server(server_endpoint, &sc_handler, &context, FLAGS_bolt_session_inactivity_timeout,
service_name, FLAGS_bolt_num_workers);
Context session_context{&interpreter_context_, &auth_, &audit_log};
#else
Context session_context{&interpreter_context_, &auth_};
#endif
memgraph::glue::ServerT server(server_endpoint, &session_context, &context, FLAGS_bolt_session_inactivity_timeout,
service_name, FLAGS_bolt_num_workers);
#endif
const auto machine_id = memgraph::utils::GetMachineId();
const auto run_id = session_context.run_id; // For current compatibility
// Setup telemetry
static constexpr auto telemetry_server{"https://telemetry.memgraph.com/88b5e7e8-746a-11e8-9f85-538a9e9690cc/"};
std::optional<memgraph::telemetry::Telemetry> telemetry;
if (FLAGS_telemetry_enabled) {
telemetry.emplace(telemetry_server, data_directory / "telemetry", run_id, machine_id, std::chrono::minutes(10));
telemetry.emplace(telemetry_server, data_directory / "telemetry", memgraph::glue::run_id_, machine_id,
std::chrono::minutes(10));
#ifdef MG_ENTERPRISE
telemetry->AddCollector("storage", [&sc_handler]() -> nlohmann::json {
const auto &info = sc_handler.Info();
telemetry->AddCollector("storage", [&new_handler]() -> nlohmann::json {
const auto &info = new_handler.Info();
return {{"vertices", info.num_vertex}, {"edges", info.num_edges}, {"databases", info.num_databases}};
});
#else
telemetry->AddCollector("storage", [&interpreter_context]() -> nlohmann::json {
auto info = interpreter_context.db->GetInfo();
telemetry->AddCollector("storage", [gk = &db_gatekeeper]() -> nlohmann::json {
auto db_acc = gk->access();
MG_ASSERT(db_acc, "Failed to get access to the default database");
auto info = db_acc->get()->GetInfo();
return {{"vertices", info.vertex_count}, {"edges", info.edge_count}};
});
#endif
@ -427,67 +453,46 @@ int main(int argc, char **argv) {
return memgraph::query::plan::CallProcedure::GetAndResetCounters();
});
}
memgraph::license::LicenseInfoSender license_info_sender(telemetry_server, run_id, machine_id, memory_limit,
memgraph::license::LicenseInfoSender license_info_sender(telemetry_server, memgraph::glue::run_id_, machine_id,
memory_limit,
memgraph::license::global_license_checker.GetLicenseInfo());
memgraph::communication::websocket::SafeAuth websocket_auth{auth};
memgraph::communication::websocket::SafeAuth websocket_auth{&auth_};
memgraph::communication::websocket::Server websocket_server{
{FLAGS_monitoring_address, static_cast<uint16_t>(FLAGS_monitoring_port)}, &context, websocket_auth};
memgraph::flags::AddLoggerSink(websocket_server.GetLoggingSink());
memgraph::glue::MonitoringServerT metrics_server{
{FLAGS_metrics_address, static_cast<uint16_t>(FLAGS_metrics_port)}, &session_context, &context};
#ifdef MG_ENTERPRISE
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
// Handler for regular termination signals
auto shutdown = [&metrics_server, &websocket_server, &server, &sc_handler] {
// Server needs to be shutdown first and then the database. This prevents
// a race condition when a transaction is accepted during server shutdown.
server.Shutdown();
// After the server is notified to stop accepting and processing
// connections we tell the execution engine to stop processing all pending
// queries.
sc_handler.Shutdown();
// TODO: Make multi-tenant
memgraph::glue::MonitoringServerT metrics_server{
{FLAGS_metrics_address, static_cast<uint16_t>(FLAGS_metrics_port)}, db_acc->storage(), &context};
#endif
websocket_server.Shutdown();
metrics_server.Shutdown();
};
InitSignalHandlers(shutdown);
} else {
// Handler for regular termination signals
auto shutdown = [&websocket_server, &server, &interpreter_context] {
// Server needs to be shutdown first and then the database. This prevents
// a race condition when a transaction is accepted during server shutdown.
server.Shutdown();
// After the server is notified to stop accepting and processing
// connections we tell the execution engine to stop processing all pending
// queries.
memgraph::query::Shutdown(&interpreter_context);
websocket_server.Shutdown();
};
InitSignalHandlers(shutdown);
}
#else
// Handler for regular termination signals
auto shutdown = [&websocket_server, &server, &interpreter_context] {
auto shutdown = [
#ifdef MG_ENTERPRISE
&metrics_server,
#endif
&websocket_server, &server, &interpreter_context_] {
// Server needs to be shutdown first and then the database. This prevents
// a race condition when a transaction is accepted during server shutdown.
server.Shutdown();
// After the server is notified to stop accepting and processing
// connections we tell the execution engine to stop processing all pending
// queries.
memgraph::query::Shutdown(&interpreter_context);
interpreter_context_.Shutdown();
websocket_server.Shutdown();
#ifdef MG_ENTERPRISE
metrics_server.Shutdown();
#endif
};
InitSignalHandlers(shutdown);
#endif
// Release the temporary database access
db_acc.reset();
// Startup the main server
MG_ASSERT(server.Start(), "Couldn't start the Bolt server!");
websocket_server.Start();
@ -500,13 +505,16 @@ int main(int argc, char **argv) {
if (!FLAGS_init_data_file.empty()) {
spdlog::info("Running init data file.");
#ifdef MG_ENTERPRISE
auto db_acc = new_handler.Get(memgraph::dbms::kDefaultDB);
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
InitFromCypherlFile(interpreter_context, FLAGS_init_data_file, &audit_log);
InitFromCypherlFile(interpreter_context_, db_acc, FLAGS_init_data_file, &audit_log);
} else {
InitFromCypherlFile(interpreter_context, FLAGS_init_data_file);
InitFromCypherlFile(interpreter_context_, db_acc, FLAGS_init_data_file);
}
#else
InitFromCypherlFile(interpreter_context, FLAGS_init_data_file);
auto db_acc_2 = db_gatekeeper.access();
MG_ASSERT(db_acc_2, "Failed to gain access to the main database");
InitFromCypherlFile(interpreter_context_, *db_acc_2, FLAGS_init_data_file);
#endif
}

View File

@ -36,7 +36,9 @@ set(mg_query_sources
trigger_context.cpp
typed_value.cpp
graph.cpp
db_accessor.cpp)
db_accessor.cpp
auth_query_handler.cpp
interpreter_context.cpp)
add_library(mg-query STATIC ${mg_query_sources})
target_include_directories(mg-query PUBLIC ${CMAKE_SOURCE_DIR}/include)
@ -51,7 +53,8 @@ target_link_libraries(mg-query PUBLIC dl
mg-kvstore
mg-memory
mg::csv
mg-flags)
mg-flags
mg-dbms)
if(NOT "${MG_PYTHON_PATH}" STREQUAL "")
set(Python3_ROOT_DIR "${MG_PYTHON_PATH}")
endif()

View File

@ -11,7 +11,11 @@
#pragma once
#include "query/db_accessor.hpp"
#include <memory>
#include <optional>
#include <string>
#include <vector>
#include "query/frontend/ast/ast.hpp"
#include "storage/v2/id_types.hpp"
@ -19,17 +23,19 @@ namespace memgraph::query {
class FineGrainedAuthChecker;
class DbAccessor;
class AuthChecker {
public:
virtual ~AuthChecker() = default;
[[nodiscard]] virtual bool IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<query::AuthQuery::Privilege> &privileges,
const std::vector<AuthQuery::Privilege> &privileges,
const std::string &db_name) const = 0;
#ifdef MG_ENTERPRISE
[[nodiscard]] virtual std::unique_ptr<FineGrainedAuthChecker> GetFineGrainedAuthChecker(
const std::string &username, const memgraph::query::DbAccessor *db_accessor) const = 0;
const std::string &username, const DbAccessor *db_accessor) const = 0;
virtual void ClearCache() const = 0;
#endif
@ -39,75 +45,73 @@ class FineGrainedAuthChecker {
public:
virtual ~FineGrainedAuthChecker() = default;
[[nodiscard]] virtual bool Has(const query::VertexAccessor &vertex, memgraph::storage::View view,
query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0;
[[nodiscard]] virtual bool Has(const VertexAccessor &vertex, memgraph::storage::View view,
AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0;
[[nodiscard]] virtual bool Has(const query::EdgeAccessor &edge,
query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0;
[[nodiscard]] virtual bool Has(const EdgeAccessor &edge,
AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0;
[[nodiscard]] virtual bool Has(const std::vector<memgraph::storage::LabelId> &labels,
query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0;
AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0;
[[nodiscard]] virtual bool Has(const memgraph::storage::EdgeTypeId &edge_type,
query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0;
AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0;
[[nodiscard]] virtual bool HasGlobalPrivilegeOnVertices(
memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0;
AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0;
[[nodiscard]] virtual bool HasGlobalPrivilegeOnEdges(
memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0;
AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0;
};
class AllowEverythingFineGrainedAuthChecker final : public query::FineGrainedAuthChecker {
class AllowEverythingFineGrainedAuthChecker final : public FineGrainedAuthChecker {
public:
bool Has(const VertexAccessor & /*vertex*/, const memgraph::storage::View /*view*/,
const query::AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override {
const AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override {
return true;
}
bool Has(const memgraph::query::EdgeAccessor & /*edge*/,
const query::AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override {
bool Has(const EdgeAccessor & /*edge*/,
const AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override {
return true;
}
bool Has(const std::vector<memgraph::storage::LabelId> & /*labels*/,
const query::AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override {
const AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override {
return true;
}
bool Has(const memgraph::storage::EdgeTypeId & /*edge_type*/,
const query::AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override {
const AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override {
return true;
}
bool HasGlobalPrivilegeOnVertices(
const memgraph::query::AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override {
bool HasGlobalPrivilegeOnVertices(const AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override {
return true;
}
bool HasGlobalPrivilegeOnEdges(
const memgraph::query::AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override {
bool HasGlobalPrivilegeOnEdges(const AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override {
return true;
}
}; // namespace memgraph::query
};
#endif
class AllowEverythingAuthChecker final : public query::AuthChecker {
class AllowEverythingAuthChecker final : public AuthChecker {
public:
bool IsUserAuthorized(const std::optional<std::string> & /*username*/,
const std::vector<query::AuthQuery::Privilege> & /*privileges*/,
const std::vector<AuthQuery::Privilege> & /*privileges*/,
const std::string & /*db*/) const override {
return true;
}
#ifdef MG_ENTERPRISE
std::unique_ptr<FineGrainedAuthChecker> GetFineGrainedAuthChecker(const std::string & /*username*/,
const query::DbAccessor * /*dba*/) const override {
const DbAccessor * /*dba*/) const override {
return std::make_unique<AllowEverythingFineGrainedAuthChecker>();
}
void ClearCache() const override {}
#endif
}; // namespace memgraph::query
};
} // namespace memgraph::query

View File

@ -0,0 +1,12 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/auth_query_handler.hpp"

View File

@ -0,0 +1,126 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <optional>
#include <string>
#include <string_view>
#include <vector>
#include "query/frontend/ast/ast.hpp" // overkill
#include "query/typed_value.hpp"
namespace memgraph::query {
class AuthQueryHandler {
public:
AuthQueryHandler() = default;
virtual ~AuthQueryHandler() = default;
AuthQueryHandler(const AuthQueryHandler &) = delete;
AuthQueryHandler(AuthQueryHandler &&) = delete;
AuthQueryHandler &operator=(const AuthQueryHandler &) = delete;
AuthQueryHandler &operator=(AuthQueryHandler &&) = delete;
/// Return false if the user already exists.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool CreateUser(const std::string &username, const std::optional<std::string> &password) = 0;
/// Return false if the user does not exist.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool DropUser(const std::string &username) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetPassword(const std::string &username, const std::optional<std::string> &password) = 0;
#ifdef MG_ENTERPRISE
/// Return true if access revoked successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) = 0;
/// Return true if access granted successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username) = 0;
/// Returns database access rights for the user
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<std::vector<memgraph::query::TypedValue>> GetDatabasePrivileges(const std::string &username) = 0;
/// Return true if main database set successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool SetMainDatabase(const std::string &db, const std::string &username) = 0;
/// Delete database from all users
/// @throw QueryRuntimeException if an error ocurred.
virtual void DeleteDatabase(std::string_view db) = 0;
#endif
/// Return false if the role already exists.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool CreateRole(const std::string &rolename) = 0;
/// Return false if the role does not exist.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool DropRole(const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<memgraph::query::TypedValue> GetUsernames() = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<memgraph::query::TypedValue> GetRolenames() = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::optional<std::string> GetRolenameForUser(const std::string &username) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<memgraph::query::TypedValue> GetUsernamesForRole(const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetRole(const std::string &username, const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void ClearRole(const std::string &username) = 0;
virtual std::vector<std::vector<memgraph::query::TypedValue>> GetPrivileges(const std::string &user_or_role) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void GrantPrivilege(
const std::string &user_or_role, const std::vector<memgraph::query::AuthQuery::Privilege> &privileges
#ifdef MG_ENTERPRISE
,
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&label_privileges,
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges
#endif
) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void DenyPrivilege(const std::string &user_or_role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void RevokePrivilege(
const std::string &user_or_role, const std::vector<memgraph::query::AuthQuery::Privilege> &privileges
#ifdef MG_ENTERPRISE
,
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&label_privileges,
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges
#endif
) = 0;
};
} // namespace memgraph::query

View File

@ -10,6 +10,8 @@
// licenses/APL.txt.
#include "query/cypher_query_interpreter.hpp"
#include "query/frontend/ast/cypher_main_visitor.hpp"
#include "query/frontend/opencypher/parser.hpp"
// NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_bool(query_cost_planner, true, "Use the cost-estimating query planner.");

View File

@ -12,8 +12,6 @@
#pragma once
#include "query/config.hpp"
#include "query/frontend/ast/cypher_main_visitor.hpp"
#include "query/frontend/opencypher/parser.hpp"
#include "query/frontend/semantic/required_privileges.hpp"
#include "query/frontend/semantic/symbol_generator.hpp"
#include "query/frontend/stripped.hpp"

View File

@ -158,6 +158,11 @@ class ExplicitTransactionUsageException : public QueryRuntimeException {
using QueryRuntimeException::QueryRuntimeException;
};
class DatabaseContextRequiredException : public QueryRuntimeException {
public:
using QueryRuntimeException::QueryRuntimeException;
};
class WriteVertexOperationInEdgeImportModeException : public QueryException {
public:
WriteVertexOperationInEdgeImportModeException()

File diff suppressed because it is too large Load Diff

View File

@ -15,7 +15,9 @@
#include <gflags/gflags.h>
#include "dbms/database.hpp"
#include "query/auth_checker.hpp"
#include "query/auth_query_handler.hpp"
#include "query/config.hpp"
#include "query/context.hpp"
#include "query/cypher_query_interpreter.hpp"
@ -53,106 +55,11 @@ extern const Event FailedQuery;
namespace memgraph::query {
struct InterpreterContext;
inline constexpr size_t kExecutionMemoryBlockSize = 1UL * 1024UL * 1024UL;
inline constexpr size_t kExecutionPoolMaxBlockSize = 1024UL; // 2 ^ 10
class AuthQueryHandler {
public:
AuthQueryHandler() = default;
virtual ~AuthQueryHandler() = default;
AuthQueryHandler(const AuthQueryHandler &) = delete;
AuthQueryHandler(AuthQueryHandler &&) = delete;
AuthQueryHandler &operator=(const AuthQueryHandler &) = delete;
AuthQueryHandler &operator=(AuthQueryHandler &&) = delete;
/// Return false if the user already exists.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool CreateUser(const std::string &username, const std::optional<std::string> &password) = 0;
/// Return false if the user does not exist.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool DropUser(const std::string &username) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetPassword(const std::string &username, const std::optional<std::string> &password) = 0;
#ifdef MG_ENTERPRISE
/// Return true if access revoked successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) = 0;
/// Return true if access granted successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username) = 0;
/// Returns database access rights for the user
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<std::vector<memgraph::query::TypedValue>> GetDatabasePrivileges(const std::string &username) = 0;
/// Return true if main database set successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool SetMainDatabase(const std::string &db, const std::string &username) = 0;
#endif
/// Return false if the role already exists.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool CreateRole(const std::string &rolename) = 0;
/// Return false if the role does not exist.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool DropRole(const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<TypedValue> GetUsernames() = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<TypedValue> GetRolenames() = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::optional<std::string> GetRolenameForUser(const std::string &username) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<TypedValue> GetUsernamesForRole(const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetRole(const std::string &username, const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void ClearRole(const std::string &username) = 0;
virtual std::vector<std::vector<TypedValue>> GetPrivileges(const std::string &user_or_role) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void GrantPrivilege(
const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges
#ifdef MG_ENTERPRISE
,
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&label_privileges,
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges
#endif
) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void DenyPrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void RevokePrivilege(
const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges
#ifdef MG_ENTERPRISE
,
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&label_privileges,
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges
#endif
) = 0;
};
enum class QueryHandlerResult { COMMIT, ABORT, NOTHING };
class ReplicationQueryHandler {
@ -232,57 +139,10 @@ struct QueryExtras {
std::optional<int64_t> tx_timeout;
};
class Interpreter;
/**
* Holds data shared between multiple `Interpreter` instances (which might be
* running concurrently).
*
*/
/// TODO: andi decouple in a separate file why here?
struct InterpreterContext {
explicit InterpreterContext(storage::Config storage_config, InterpreterConfig interpreter_config,
const std::filesystem::path &data_directory, query::AuthQueryHandler *ah = nullptr,
query::AuthChecker *ac = nullptr);
InterpreterContext(std::unique_ptr<storage::Storage> &&db, InterpreterConfig interpreter_config,
const std::filesystem::path &data_directory, query::AuthQueryHandler *ah = nullptr,
query::AuthChecker *ac = nullptr);
std::unique_ptr<storage::Storage> db;
// ANTLR has singleton instance that is shared between threads. It is
// protected by locks inside of ANTLR. Unfortunately, they are not protected
// in a very good way. Once we have ANTLR version without race conditions we
// can remove this lock. This will probably never happen since ANTLR
// developers introduce more bugs in each version. Fortunately, we have
// cache so this lock probably won't impact performance much...
utils::SpinLock antlr_lock;
std::optional<double> tsc_frequency{utils::GetTSCFrequency()};
std::atomic<bool> is_shutting_down{false};
AuthQueryHandler *auth;
AuthChecker *auth_checker;
utils::SkipList<QueryCacheEntry> ast_cache;
utils::SkipList<PlanCacheEntry> plan_cache;
TriggerStore trigger_store;
utils::ThreadPool after_commit_trigger_pool{1};
const InterpreterConfig config;
query::stream::Streams streams;
utils::Synchronized<std::unordered_set<Interpreter *>, utils::SpinLock> interpreters;
};
/// Function that is used to tell all active interpreters that they should stop
/// their ongoing execution.
inline void Shutdown(InterpreterContext *context) { context->is_shutting_down.store(true, std::memory_order_release); }
class Interpreter final {
public:
explicit Interpreter(InterpreterContext *interpreter_context);
Interpreter(InterpreterContext *interpreter_context);
Interpreter(InterpreterContext *interpreter_context, memgraph::dbms::DatabaseAccess db);
Interpreter(const Interpreter &) = delete;
Interpreter &operator=(const Interpreter &) = delete;
Interpreter(Interpreter &&) = delete;
@ -303,6 +163,14 @@ class Interpreter final {
std::shared_ptr<utils::AsyncTimer> explicit_transaction_timer_{};
std::optional<std::map<std::string, storage::PropertyValue>> metadata_{}; //!< User defined transaction metadata
std::optional<memgraph::dbms::DatabaseAccess> db_acc_; // Current db (TODO: expand to support multiple)
#ifdef MG_ENTERPRISE
void SetCurrentDB(std::string_view db_name);
void SetCurrentDB(memgraph::dbms::DatabaseAccess new_db);
void OnChangeCB(auto cb) { on_change_.emplace(cb); }
#endif
/**
* Prepare a query for execution.
*
@ -435,6 +303,8 @@ class Interpreter final {
// To avoid this, we use unique_ptr with which we manualy control construction
// and deletion of a single query execution, i.e. when a query finishes,
// we reset the corresponding unique_ptr.
// TODO Figure out how this would work for multi-database
// Exists only during a single transaction (for now should be okay as is)
std::vector<std::unique_ptr<QueryExecution>> query_executions_;
// all queries that are run as part of the current transaction
utils::Synchronized<std::vector<std::string>, utils::SpinLock> transaction_queries_;
@ -462,6 +332,8 @@ class Interpreter final {
return std::count_if(query_executions_.begin(), query_executions_.end(),
[](const auto &execution) { return execution && execution->prepared_query; });
}
std::optional<std::function<void(std::string_view)>> on_change_{};
};
class TransactionQueueQueryHandler {
@ -475,13 +347,9 @@ class TransactionQueueQueryHandler {
TransactionQueueQueryHandler(TransactionQueueQueryHandler &&) = default;
TransactionQueueQueryHandler &operator=(TransactionQueueQueryHandler &&) = default;
static std::vector<std::vector<TypedValue>> ShowTransactions(const std::unordered_set<Interpreter *> &interpreters,
const std::optional<std::string> &username,
bool hasTransactionManagementPrivilege);
static std::vector<std::vector<TypedValue>> KillTransactions(
InterpreterContext *interpreter_context, const std::vector<std::string> &maybe_kill_transaction_ids,
const std::optional<std::string> &username, bool hasTransactionManagementPrivilege);
static std::vector<std::vector<TypedValue>> ShowTransactions(
const std::unordered_set<Interpreter *> &interpreters, const std::optional<std::string> &username,
bool hasTransactionManagementPrivilege, std::optional<memgraph::dbms::DatabaseAccess> &filter_db_acc);
};
template <typename TStream>

View File

@ -0,0 +1,73 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/interpreter_context.hpp"
#include "query/interpreter.hpp"
namespace memgraph::query {
std::vector<std::vector<TypedValue>> InterpreterContext::KillTransactions(
std::vector<std::string> maybe_kill_transaction_ids, const std::optional<std::string> &username,
bool hasTransactionManagementPrivilege, Interpreter &calling_interpreter) {
auto not_found_midpoint = maybe_kill_transaction_ids.end();
// Multiple simultaneous TERMINATE TRANSACTIONS aren't allowed
// TERMINATE and SHOW TRANSACTIONS are mutually exclusive
interpreters.WithLock([&not_found_midpoint, &maybe_kill_transaction_ids, username, hasTransactionManagementPrivilege,
filter_db_acc = &calling_interpreter.db_acc_](const auto &interpreters) {
for (Interpreter *interpreter : interpreters) {
TransactionStatus alive_status = TransactionStatus::ACTIVE;
// if it is just checking kill, commit and abort should wait for the end of the check
// The only way to start checking if the transaction will get killed is if the transaction_status is
// active
if (!interpreter->transaction_status_.compare_exchange_strong(alive_status, TransactionStatus::VERIFYING)) {
continue;
}
bool killed = false;
utils::OnScopeExit clean_status([interpreter, &killed]() {
if (killed) {
interpreter->transaction_status_.store(TransactionStatus::TERMINATED, std::memory_order_release);
} else {
interpreter->transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release);
}
});
if (interpreter->db_acc_ != *filter_db_acc) continue;
std::optional<uint64_t> intr_trans = interpreter->GetTransactionId();
if (!intr_trans.has_value()) continue;
auto transaction_id = std::to_string(intr_trans.value());
auto it = std::find(maybe_kill_transaction_ids.begin(), not_found_midpoint, transaction_id);
if (it != not_found_midpoint) {
// update the maybe_kill_transaction_ids (partitioning not found + killed)
--not_found_midpoint;
std::iter_swap(it, not_found_midpoint);
if (interpreter->username_ == username || hasTransactionManagementPrivilege) {
killed = true; // Note: this is used by the above `clean_status` (OnScopeExit)
spdlog::warn("Transaction {} successfully killed", transaction_id);
} else {
spdlog::warn("Not enough rights to kill the transaction");
}
}
}
});
std::vector<std::vector<TypedValue>> results;
for (auto it = maybe_kill_transaction_ids.begin(); it != not_found_midpoint; ++it) {
results.push_back({TypedValue(*it), TypedValue(false)});
spdlog::warn("Transaction {} not found", *it);
}
for (auto it = not_found_midpoint; it != maybe_kill_transaction_ids.end(); ++it) {
results.push_back({TypedValue(*it), TypedValue(true)});
}
return results;
}
} // namespace memgraph::query

View File

@ -0,0 +1,85 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <optional>
#include <string>
#include <unordered_set>
#include <vector>
#include "query/config.hpp"
#include "query/cypher_query_interpreter.hpp"
#include "query/typed_value.hpp"
#include "utils/gatekeeper.hpp"
#include "utils/skip_list.hpp"
#include "utils/spin_lock.hpp"
#include "utils/synchronized.hpp"
namespace memgraph::dbms {
#ifdef MG_ENTERPRISE
class DbmsHandler;
#else
class Database;
#endif
} // namespace memgraph::dbms
namespace memgraph::query {
class AuthQueryHandler;
class AuthChecker;
class Interpreter;
/**
* Holds data shared between multiple `Interpreter` instances (which might be
* running concurrently).
*
*/
struct InterpreterContext {
#ifdef MG_ENTERPRISE
InterpreterContext(InterpreterConfig interpreter_config, memgraph::dbms::DbmsHandler *db_handler,
AuthQueryHandler *ah = nullptr, AuthChecker *ac = nullptr);
#else
InterpreterContext(InterpreterConfig interpreter_config,
memgraph::utils::Gatekeeper<memgraph::dbms::Database> *db_gatekeeper,
query::AuthQueryHandler *ah = nullptr, query::AuthChecker *ac = nullptr);
#endif
#ifdef MG_ENTERPRISE
memgraph::dbms::DbmsHandler *db_handler;
#else
memgraph::utils::Gatekeeper<memgraph::dbms::Database> *db_gatekeeper;
#endif
// Internal
const InterpreterConfig config;
std::atomic<bool> is_shutting_down{false}; // TODO: Do we even need this, since there is a global one also
memgraph::utils::SkipList<QueryCacheEntry> ast_cache;
// GLOBAL
AuthQueryHandler *auth;
AuthChecker *auth_checker;
// Used to check active transactions
// TODO: Have a way to read the current database
memgraph::utils::Synchronized<std::unordered_set<Interpreter *>, memgraph::utils::SpinLock> interpreters;
/// Function that is used to tell all active interpreters that they should stop
/// their ongoing execution.
void Shutdown() { is_shutting_down.store(true, std::memory_order_release); }
std::vector<std::vector<TypedValue>> KillTransactions(std::vector<std::string> maybe_kill_transaction_ids,
const std::optional<std::string> &username,
bool hasTransactionManagementPrivilege,
Interpreter &calling_interpreter);
};
} // namespace memgraph::query

View File

@ -12,6 +12,7 @@
#include "query/procedure/module.hpp"
#include <filesystem>
#include <fstream>
#include <optional>
extern "C" {

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -51,8 +51,8 @@ void KafkaStream::Stop() { consumer_->Stop(); }
bool KafkaStream::IsRunning() const { return consumer_->IsRunning(); }
void KafkaStream::Check(std::optional<std::chrono::milliseconds> timeout, std::optional<uint64_t> batch_limit,
const ConsumerFunction<integrations::kafka::Message> &consumer_function) const {
consumer_->Check(timeout, batch_limit, consumer_function);
ConsumerFunction<integrations::kafka::Message> consumer_function) const {
consumer_->Check(timeout, batch_limit, std::move(consumer_function));
}
utils::BasicResult<std::string> KafkaStream::SetStreamOffset(const int64_t offset) {
@ -115,8 +115,8 @@ void PulsarStream::StartWithLimit(uint64_t batch_limit, std::optional<std::chron
void PulsarStream::Stop() { consumer_->Stop(); }
bool PulsarStream::IsRunning() const { return consumer_->IsRunning(); }
void PulsarStream::Check(std::optional<std::chrono::milliseconds> timeout, std::optional<uint64_t> batch_limit,
const ConsumerFunction<Message> &consumer_function) const {
consumer_->Check(timeout, batch_limit, consumer_function);
ConsumerFunction<Message> consumer_function) const {
consumer_->Check(timeout, batch_limit, std::move(consumer_function));
}
namespace {

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -41,7 +41,7 @@ struct KafkaStream {
bool IsRunning() const;
void Check(std::optional<std::chrono::milliseconds> timeout, std::optional<uint64_t> batch_limit,
const ConsumerFunction<Message> &consumer_function) const;
ConsumerFunction<Message> consumer_function) const;
utils::BasicResult<std::string> SetStreamOffset(int64_t offset);
@ -77,7 +77,7 @@ struct PulsarStream {
bool IsRunning() const;
void Check(std::optional<std::chrono::milliseconds> timeout, std::optional<uint64_t> batch_limit,
const ConsumerFunction<Message> &consumer_function) const;
ConsumerFunction<Message> consumer_function) const;
private:
using Consumer = integrations::pulsar::Consumer;

View File

@ -18,6 +18,8 @@
#include <spdlog/spdlog.h>
#include <json/json.hpp>
#include "dbms/database.hpp"
#include "dbms/dbms_handler.hpp"
#include "integrations/constants.hpp"
#include "mg_procedure.h"
#include "query/db_accessor.hpp"
@ -161,10 +163,7 @@ void from_json(const nlohmann::json &data, StreamStatus<TStream> &status) {
from_json(data, status.info);
}
Streams::Streams(InterpreterContext *interpreter_context, std::filesystem::path directory)
: interpreter_context_(interpreter_context), storage_(std::move(directory)) {
RegisterProcedures();
}
Streams::Streams(std::filesystem::path directory) : storage_(std::move(directory)) { RegisterProcedures(); }
void Streams::RegisterProcedures() {
RegisterKafkaProcedures();
@ -448,11 +447,12 @@ void Streams::RegisterPulsarProcedures() {
}
}
template <Stream TStream>
template <Stream TStream, typename TDbAccess>
void Streams::Create(const std::string &stream_name, typename TStream::StreamInfo info,
std::optional<std::string> owner) {
std::optional<std::string> owner, TDbAccess db_acc, InterpreterContext *ic) {
auto locked_streams = streams_.Lock();
auto it = CreateConsumer<TStream>(*locked_streams, stream_name, std::move(info), std::move(owner));
auto it = CreateConsumer<TStream, TDbAccess>(*locked_streams, stream_name, std::move(info), std::move(owner),
std::move(db_acc), ic);
try {
std::visit(
@ -467,30 +467,38 @@ void Streams::Create(const std::string &stream_name, typename TStream::StreamInf
}
}
template void Streams::Create<KafkaStream>(const std::string &stream_name, KafkaStream::StreamInfo info,
std::optional<std::string> owner);
template void Streams::Create<PulsarStream>(const std::string &stream_name, PulsarStream::StreamInfo info,
std::optional<std::string> owner);
template void Streams::Create<KafkaStream, dbms::DatabaseAccess>(const std::string &stream_name,
KafkaStream::StreamInfo info,
std::optional<std::string> owner,
dbms::DatabaseAccess db, InterpreterContext *ic);
template void Streams::Create<PulsarStream, dbms::DatabaseAccess>(const std::string &stream_name,
PulsarStream::StreamInfo info,
std::optional<std::string> owner,
dbms::DatabaseAccess db, InterpreterContext *ic);
template <Stream TStream>
template <Stream TStream, typename TDbAccess>
Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std::string &stream_name,
typename TStream::StreamInfo stream_info,
std::optional<std::string> owner) {
std::optional<std::string> owner, TDbAccess db_acc,
InterpreterContext *interpreter_context) {
if (map.contains(stream_name)) {
throw StreamsException{"Stream already exists with name '{}'", stream_name};
}
auto *memory_resource = utils::NewDeleteResource();
auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, stream_name,
auto consumer_function = [interpreter_context, memory_resource, stream_name,
transformation_name = stream_info.common_info.transformation_name, owner = owner,
interpreter = std::make_shared<Interpreter>(interpreter_context_),
interpreter = std::make_shared<Interpreter>(interpreter_context, std::move(db_acc)),
result = mgp_result{nullptr, memory_resource},
total_retries = interpreter_context_->config.stream_transaction_conflict_retries,
retry_interval = interpreter_context_->config.stream_transaction_retry_interval](
total_retries = interpreter_context->config.stream_transaction_conflict_retries,
retry_interval = interpreter_context->config.stream_transaction_retry_interval](
const std::vector<typename TStream::Message> &messages) mutable {
auto accessor = interpreter_context->db->Access();
// register new interpreter into interpreter_context_
#ifdef MG_ENTERPRISE
interpreter->OnChangeCB([](auto) { return false; }); // Disable database change
#endif
auto accessor = interpreter->db_acc_->get()->Access();
// register new interpreter into interpreter_context
interpreter_context->interpreters->insert(interpreter.get());
utils::OnScopeExit interpreter_cleanup{
[interpreter_context, interpreter]() { interpreter_context->interpreters->erase(interpreter.get()); }};
@ -552,7 +560,8 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
return insert_result.first;
}
void Streams::RestoreStreams() {
template <typename TDbAccess>
void Streams::RestoreStreams(TDbAccess db, InterpreterContext *ic) {
spdlog::info("Loading streams...");
auto locked_streams_map = streams_.Lock();
MG_ASSERT(locked_streams_map->empty(), "Cannot restore streams when some streams already exist!");
@ -563,8 +572,8 @@ void Streams::RestoreStreams() {
return fmt::format("Failed to load stream '{}', because: {} caused by {}", stream_name, message, nested_message);
};
const auto create_consumer = [&, &stream_name = stream_name, this]<typename T>(StreamStatus<T> status,
auto &&stream_json_data) {
const auto create_consumer = [&, &stream_name = stream_name]<typename T>(StreamStatus<T> status,
auto &&stream_json_data) {
try {
stream_json_data.get_to(status);
} catch (const nlohmann::json::type_error &exception) {
@ -577,7 +586,8 @@ void Streams::RestoreStreams() {
MG_ASSERT(status.name == stream_name, "Expected stream name is '{}', but got '{}'", status.name, stream_name);
try {
auto it = CreateConsumer<T>(*locked_streams_map, stream_name, std::move(status.info), std::move(status.owner));
auto it = CreateConsumer<T>(*locked_streams_map, stream_name, std::move(status.info), std::move(status.owner),
db, ic);
if (status.is_running) {
std::visit(
[&](const auto &stream_data) {
@ -613,6 +623,8 @@ void Streams::RestoreStreams() {
}
}
template void Streams::RestoreStreams<dbms::DatabaseAccess>(dbms::DatabaseAccess db, InterpreterContext *ic);
void Streams::Drop(const std::string &stream_name) {
auto locked_streams = streams_.Lock();
@ -722,7 +734,9 @@ std::vector<StreamStatus<>> Streams::GetStreamInfo() const {
return result;
}
TransformationResult Streams::Check(const std::string &stream_name, std::optional<std::chrono::milliseconds> timeout,
template <typename TDbAccess>
TransformationResult Streams::Check(const std::string &stream_name, TDbAccess db_acc,
std::optional<std::chrono::milliseconds> timeout,
std::optional<uint64_t> batch_limit) const {
std::optional locked_streams{streams_.ReadLock()};
auto it = GetStream(**locked_streams, stream_name);
@ -739,10 +753,9 @@ TransformationResult Streams::Check(const std::string &stream_name, std::optiona
mgp_result result{nullptr, memory_resource};
TransformationResult test_result;
auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, &stream_name,
&transformation_name = transformation_name, &result,
&test_result]<typename T>(const std::vector<T> &messages) mutable {
auto accessor = interpreter_context->db->Access();
auto consumer_function = [&db_acc, memory_resource, &stream_name, &transformation_name = transformation_name,
&result, &test_result]<typename T>(const std::vector<T> &messages) mutable {
auto accessor = db_acc->Access();
CallCustomTransformation(transformation_name, messages, result, *accessor, *memory_resource, stream_name);
auto result_row = std::vector<TypedValue>();
@ -774,4 +787,9 @@ TransformationResult Streams::Check(const std::string &stream_name, std::optiona
it->second);
}
template TransformationResult Streams::Check<dbms::DatabaseAccess>(const std::string &stream_name,
dbms::DatabaseAccess db_acc,
std::optional<std::chrono::milliseconds> timeout,
std::optional<uint64_t> batch_limit) const;
} // namespace memgraph::query::stream

View File

@ -81,13 +81,14 @@ class Streams final {
///
/// @param interpreter_context context to use to run the result of transformations
/// @param directory a directory path to store the persisted streams metadata
Streams(InterpreterContext *interpreter_context, std::filesystem::path directory);
explicit Streams(std::filesystem::path directory);
/// Restores the streams from the persisted metadata.
/// The restoration is done in a best effort manner, therefore no exception is thrown on failure, but the error is
/// logged. If a stream was running previously, then after restoration it will be started.
/// This function should only be called when there are no existing streams.
void RestoreStreams();
template <typename TDbAccess>
void RestoreStreams(TDbAccess db, InterpreterContext *interpreter_context);
/// Creates a new import stream.
/// The create implies connecting to the server to get metadata necessary to initialize the stream. This
@ -97,8 +98,9 @@ class Streams final {
/// @param stream_info the necessary informations needed to create the Kafka consumer and transform the messages
///
/// @throws StreamsException if the stream with the same name exists or if the creation of Kafka consumer fails
template <Stream TStream>
void Create(const std::string &stream_name, typename TStream::StreamInfo info, std::optional<std::string> owner);
template <Stream TStream, typename TDbAccess>
void Create(const std::string &stream_name, typename TStream::StreamInfo info, std::optional<std::string> owner,
TDbAccess db, InterpreterContext *interpreter_context);
/// Deletes an existing stream and all the data that was persisted.
///
@ -161,7 +163,8 @@ class Streams final {
/// @throws StreamsException if the stream doesn't exist
/// @throws ConsumerRunningException if the consumer is already running
/// @throws ConsumerCheckFailedException if the transformation function throws any std::exception during processing
TransformationResult Check(const std::string &stream_name,
template <typename TDbAccess>
TransformationResult Check(const std::string &stream_name, TDbAccess db,
std::optional<std::chrono::milliseconds> timeout = std::nullopt,
std::optional<uint64_t> batch_limit = std::nullopt) const;
@ -180,9 +183,10 @@ class Streams final {
using StreamsMap = std::unordered_map<std::string, StreamDataVariant>;
using SynchronizedStreamsMap = utils::Synchronized<StreamsMap, utils::WritePrioritizedRWLock>;
template <Stream TStream>
template <Stream TStream, typename TDbAccess>
StreamsMap::iterator CreateConsumer(StreamsMap &map, const std::string &stream_name,
typename TStream::StreamInfo stream_info, std::optional<std::string> owner);
typename TStream::StreamInfo stream_info, std::optional<std::string> owner,
TDbAccess db, InterpreterContext *interpreter_context);
template <Stream TStream>
void Persist(StreamStatus<TStream> &&status) {
@ -196,7 +200,6 @@ class Streams final {
void RegisterKafkaProcedures();
void RegisterPulsarProcedures();
InterpreterContext *interpreter_context_;
kvstore::KVStore storage_;
SynchronizedStreamsMap streams_;

View File

@ -78,6 +78,7 @@ struct Config {
} disk;
std::string name;
bool force_on_disk{false};
};
static inline void UpdatePaths(Config &config, const std::filesystem::path &storage_dir) {

213
src/utils/gatekeeper.hpp Normal file
View File

@ -0,0 +1,213 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <optional>
#include <type_traits>
#include <utility>
#include <vector>
namespace memgraph::utils {
struct run_t {};
struct not_run_t {};
template <typename Ret>
struct EvalResult;
template <>
struct EvalResult<void> {
template <typename Func, typename T>
EvalResult(run_t /* marker */, Func &&func, T &arg) : was_run{true} {
std::invoke(std::forward<Func>(func), arg);
}
EvalResult(not_run_t /* marker */) : was_run{false} {}
~EvalResult() = default;
EvalResult(EvalResult const &) = delete;
EvalResult(EvalResult &&) = delete;
EvalResult &operator=(EvalResult const &) = delete;
EvalResult &operator=(EvalResult &&) = delete;
explicit operator bool() const { return was_run; }
private:
bool was_run;
};
template <typename Ret>
struct EvalResult {
template <typename Func, typename T>
EvalResult(run_t /* marker */, Func &&func, T &arg) : return_result{std::invoke(std::forward<Func>(func), arg)} {}
EvalResult(not_run_t /* marker */) {}
~EvalResult() = default;
EvalResult(EvalResult const &) = delete;
EvalResult(EvalResult &&) = delete;
EvalResult &operator=(EvalResult const &) = delete;
EvalResult &operator=(EvalResult &&) = delete;
explicit operator bool() const { return return_result.has_value(); }
constexpr const Ret &value() const & { return return_result.value(); }
constexpr Ret &value() & { return return_result.value(); }
constexpr Ret &&value() && { return return_result.value(); }
constexpr const Ret &&value() const && { return return_result.value(); }
private:
std::optional<Ret> return_result = std::nullopt;
};
template <typename Func, typename T>
EvalResult(run_t, Func &&, T &) -> EvalResult<std::invoke_result_t<Func, T &>>;
template <typename T>
struct Gatekeeper {
template <typename... Args>
explicit Gatekeeper(Args &&...args) : value_{std::forward<Args>(args)...} {}
Gatekeeper(Gatekeeper const &) = delete;
Gatekeeper(Gatekeeper &&) noexcept = delete;
Gatekeeper &operator=(Gatekeeper const &) = delete;
Gatekeeper &operator=(Gatekeeper &&) = delete;
struct Accessor {
friend Gatekeeper;
private:
explicit Accessor(Gatekeeper *owner) : owner_{owner} { ++owner_->count_; }
public:
Accessor(Accessor const &other) : owner_{other.owner_} {
if (owner_) {
auto guard = std::unique_lock{owner_->mutex_};
++owner_->count_;
}
};
Accessor(Accessor &&other) noexcept : owner_{std::exchange(other.owner_, nullptr)} {};
Accessor &operator=(Accessor const &other) {
// no change assignment
if (owner_ == other.owner_) {
return *this;
}
// gain ownership
if (other.owner_) {
auto guard = std::unique_lock{other.owner_->mutex_};
++other.owner_->count_;
}
// reliquish ownership
if (owner_) {
auto guard = std::unique_lock{owner_->mutex_};
--owner_->count_;
}
// correct owner
owner_ = other.owner_;
return *this;
};
Accessor &operator=(Accessor &&other) noexcept {
// self assignment
if (&other == this) return *this;
// reliquish ownership
if (owner_) {
auto guard = std::unique_lock{owner_->mutex_};
--owner_->count_;
}
// correct owners
owner_ = std::exchange(other.owner_, nullptr);
return *this;
}
~Accessor() { reset(); }
auto get() -> T * { return std::addressof(*owner_->value_); }
auto get() const -> const T * { return std::addressof(*owner_->value_); }
T *operator->() { return std::addressof(*owner_->value_); }
const T *operator->() const { return std::addressof(*owner_->value_); }
template <typename Func>
[[nodiscard]] auto try_exclusively(Func &&func) -> EvalResult<std::invoke_result_t<Func, T &>> {
// Prevent new access
auto guard = std::unique_lock{owner_->mutex_};
// Only invoke if we have exclusive access
if (owner_->count_ != 1) {
return {not_run_t{}};
}
// Invoke and hold result in wrapper type
return {run_t{}, std::forward<Func>(func), *owner_->value_};
}
// Completely invalidated the accessor if return true
[[nodiscard]] bool try_delete(std::chrono::milliseconds timeout = std::chrono::milliseconds(100)) {
// Prevent new access
auto guard = std::unique_lock{owner_->mutex_};
if (!owner_->cv_.wait_for(guard, timeout, [this] { return owner_->count_ == 1; })) {
return false;
}
// Delete value
owner_->value_ = std::nullopt;
return true;
}
explicit operator bool() const { return owner_ != nullptr; }
void reset() {
if (owner_) {
{
auto guard = std::unique_lock{owner_->mutex_};
--owner_->count_;
}
owner_->cv_.notify_all();
}
owner_ = nullptr;
}
friend bool operator==(Accessor const &lhs, Accessor const &rhs) { return lhs.owner_ == rhs.owner_; }
private:
Gatekeeper *owner_ = nullptr;
};
std::optional<Accessor> access() {
auto guard = std::unique_lock{mutex_};
if (value_) {
return Accessor{this};
}
return std::nullopt;
}
~Gatekeeper() {
// wait for count to drain to 0
auto lock = std::unique_lock{mutex_};
cv_.wait(lock, [this] { return count_ == 0; });
}
private:
std::optional<T> value_;
uint64_t count_ = 0;
std::mutex mutex_; // TODO change to something cheaper?
std::condition_variable cv_;
};
} // namespace memgraph::utils

View File

@ -1,189 +0,0 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <memory>
#include "utils/exceptions.hpp"
namespace memgraph::utils {
/**
* @brief
*
* @tparam TContext
* @tparam TConfig
*/
template <typename TContext, typename TConfig = void>
struct SyncPtr {
/**
* @brief Construct a new synched pointer.
*
* @tparam TArgs variable templates used by the TContext constructor
* @param config Additional metadata associated with context
* @param args Arguments to pass to TContext constructor
*/
template <typename... TArgs>
explicit SyncPtr(TConfig config, TArgs &&...args)
: timeout_{1000}, config_{config}, ptr_{new TContext(std::forward<TArgs>(args)...), [this](TContext *ptr) {
this->OnDelete(ptr);
}} {}
~SyncPtr() = default;
SyncPtr(const SyncPtr &) = delete;
SyncPtr &operator=(const SyncPtr &) = delete;
SyncPtr(SyncPtr &&) noexcept = delete;
SyncPtr &operator=(SyncPtr &&) noexcept = delete;
/**
* @brief Destroy the synched pointer and wait for all copies to get destroyed.
*
*/
void DestroyAndSync() {
ptr_.reset();
SyncOnDelete();
}
/**
* @brief Get (copy) the underlying shared pointer.
*
* @return std::shared_ptr<TContext>
*/
std::shared_ptr<TContext> get() { return ptr_; }
std::shared_ptr<const TContext> get() const { return ptr_; }
/**
* @brief Return saved configuration (metadata)
*
* @return TConfig
*/
TConfig config() { return config_; }
const TConfig &config() const { return config_; }
void timeout(const std::chrono::milliseconds to) { timeout_ = to; }
std::chrono::milliseconds timeout() const { return timeout_; }
private:
/**
* @brief Block until OnDelete gets called.
*
*/
void SyncOnDelete() {
std::unique_lock<std::mutex> lock(in_use_mtx_);
if (!in_use_cv_.wait_for(lock, timeout_, [this] { return !in_use_; })) {
throw utils::BasicException("Syncronization timeout!");
}
}
/**
* @brief Custom destructor used to sync the shared_ptr release.
*
* @param p Pointer to the undelying object.
*/
void OnDelete(TContext *p) {
delete p;
{
std::lock_guard<std::mutex> lock(in_use_mtx_);
in_use_ = false;
}
in_use_cv_.notify_all();
}
bool in_use_{true}; //!< Flag used to signal sync
mutable std::mutex in_use_mtx_; //!< Mutex used in the cv sync
mutable std::condition_variable in_use_cv_; //!< cv used to signal a sync
std::chrono::milliseconds timeout_; //!< Synchronization timeout in ms
TConfig config_; //!< Additional metadata associated with the context
std::shared_ptr<TContext> ptr_; //!< Pointer being synced
};
template <typename TContext>
class SyncPtr<TContext, void> {
public:
/**
* @brief Construct a new synched pointer.
*
* @tparam TArgs variable templates used by the TContext constructor
* @param config Additional metadata associated with context
* @param args Arguments to pass to TContext constructor
*/
template <typename... TArgs>
explicit SyncPtr(TArgs &&...args)
: timeout_{1000}, ptr_{new TContext(std::forward<TArgs>(args)...), [this](TContext *ptr) {
this->OnDelete(ptr);
}} {}
~SyncPtr() = default;
SyncPtr(const SyncPtr &) = delete;
SyncPtr &operator=(const SyncPtr &) = delete;
SyncPtr(SyncPtr &&) noexcept = delete;
SyncPtr &operator=(SyncPtr &&) noexcept = delete;
/**
* @brief Destroy the synched pointer and wait for all copies to get destroyed.
*
*/
void DestroyAndSync() {
ptr_.reset();
SyncOnDelete();
}
/**
* @brief Get (copy) the underlying shared pointer.
*
* @return std::shared_ptr<TContext>
*/
std::shared_ptr<TContext> get() { return ptr_; }
std::shared_ptr<const TContext> get() const { return ptr_; }
void timeout(const std::chrono::milliseconds to) { timeout_ = to; }
std::chrono::milliseconds timeout() const { return timeout_; }
private:
/**
* @brief Block until OnDelete gets called.
*
*/
void SyncOnDelete() {
std::unique_lock<std::mutex> lock(in_use_mtx_);
if (!in_use_cv_.wait_for(lock, timeout_, [this] { return !in_use_; })) {
throw utils::BasicException("Syncronization timeout!");
}
}
/**
* @brief Custom destructor used to sync the shared_ptr release.
*
* @param p Pointer to the undelying object.
*/
void OnDelete(TContext *p) {
delete p;
{
std::lock_guard<std::mutex> lock(in_use_mtx_);
in_use_ = false;
}
in_use_cv_.notify_all();
}
bool in_use_{true}; //!< Flag used to signal sync
mutable std::mutex in_use_mtx_; //!< Mutex used in the cv sync
mutable std::condition_variable in_use_cv_; //!< cv used to signal a sync
std::chrono::milliseconds timeout_; //!< Synchronization timeout in ms
std::shared_ptr<TContext> ptr_; //!< Pointer being synced
};
} // namespace memgraph::utils

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -20,12 +20,14 @@ extern "C" {
namespace memgraph::utils {
uint64_t ReadTSC() { return rdtsc(); }
std::optional<double> GetTSCFrequency() {
bool IsAvailableTSC() {
// init is only needed for fetching frequency
static auto result = std::invoke([] { return rdtsc_init(); });
return result == 0 ? std::optional{rdtsc_get_tsc_hz()} : std::nullopt;
static bool available = [] { return rdtsc_init() == 0; }(); // iile
return available;
}
std::optional<double> GetTSCFrequency() { return IsAvailableTSC() ? std::optional{rdtsc_get_tsc_hz()} : std::nullopt; }
TSCTimer::TSCTimer(std::optional<double> frequency) : frequency_(frequency) {
if (!frequency_) return;
start_value_ = utils::ReadTSC();

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -18,6 +18,8 @@ namespace memgraph::utils {
// TSC stands for Time-Stamp Counter
bool IsAvailableTSC();
uint64_t ReadTSC();
std::optional<double> GetTSCFrequency();

View File

@ -10,28 +10,35 @@
// licenses/APL.txt.
#include <benchmark/benchmark.h>
#include <memory>
#include "communication/result_stream_faker.hpp"
#include "query/config.hpp"
#include "query/interpreter.hpp"
#include "query/interpreter_context.hpp"
#include "query/typed_value.hpp"
#include "storage/v2/inmemory/storage.hpp"
#include "storage/v2/isolation_level.hpp"
#include "utils/logging.hpp"
class ExpansionBenchFixture : public benchmark::Fixture {
protected:
std::optional<memgraph::query::InterpreterContext> interpreter_context;
std::optional<memgraph::query::Interpreter> interpreter;
std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "expansion-benchmark"};
std::optional<memgraph::utils::Gatekeeper<memgraph::dbms::Database>> db_gk{memgraph::storage::Config{
.durability.storage_directory = data_directory, .disk.main_storage_directory = data_directory / "disk"}};
void SetUp(const benchmark::State &state) override {
interpreter_context.emplace(memgraph::storage::Config{}, memgraph::query::InterpreterConfig{}, data_directory);
auto *db = interpreter_context->db.get();
auto db_acc_opt = db_gk->access();
MG_ASSERT(db_acc_opt, "Failed to access db");
auto &db_acc = *db_acc_opt;
interpreter_context.emplace(memgraph::query::InterpreterConfig{}, nullptr);
auto label = db->NameToLabel("Starting");
auto label = db_acc->storage()->NameToLabel("Starting");
{
auto dba = db->Access();
auto dba = db_acc->Access();
for (int i = 0; i < state.range(0); i++) dba->CreateVertex();
// the fixed part is one vertex expanding to 1000 others
@ -45,14 +52,15 @@ class ExpansionBenchFixture : public benchmark::Fixture {
MG_ASSERT(!dba->Commit().HasError());
}
MG_ASSERT(!db->CreateIndex(label).HasError());
MG_ASSERT(!db_acc->storage()->CreateIndex(label).HasError());
interpreter.emplace(&*interpreter_context);
interpreter.emplace(&*interpreter_context, std::move(db_acc));
}
void TearDown(const benchmark::State &) override {
interpreter = std::nullopt;
interpreter_context = std::nullopt;
db_gk.reset();
std::filesystem::remove_all(data_directory);
}
};
@ -61,7 +69,7 @@ BENCHMARK_DEFINE_F(ExpansionBenchFixture, Match)(benchmark::State &state) {
auto query = "MATCH (s:Starting) return s";
while (state.KeepRunning()) {
ResultStreamFaker results(interpreter_context->db.get());
ResultStreamFaker results(interpreter->db_acc_->get()->storage());
interpreter->Prepare(query, {}, nullptr);
interpreter->PullAll(&results);
}
@ -76,7 +84,7 @@ BENCHMARK_DEFINE_F(ExpansionBenchFixture, Expand)(benchmark::State &state) {
auto query = "MATCH (s:Starting) WITH s MATCH (s)--(d) RETURN count(d)";
while (state.KeepRunning()) {
ResultStreamFaker results(interpreter_context->db.get());
ResultStreamFaker results(interpreter->db_acc_->get()->storage());
interpreter->Prepare(query, {}, nullptr);
interpreter->PullAll(&results);
}

View File

@ -13,6 +13,7 @@
#include "license/license.hpp"
#include "query/config.hpp"
#include "query/interpreter.hpp"
#include "query/interpreter_context.hpp"
#include "storage/v2/config.hpp"
#include "storage/v2/inmemory/storage.hpp"
#include "storage/v2/isolation_level.hpp"
@ -31,11 +32,15 @@ int main(int argc, char *argv[]) {
memgraph::utils::OnScopeExit([&data_directory] { std::filesystem::remove_all(data_directory); });
memgraph::license::global_license_checker.EnableTesting();
memgraph::query::InterpreterContext interpreter_context{memgraph::storage::Config{},
memgraph::query::InterpreterConfig{}, data_directory};
memgraph::query::Interpreter interpreter{&interpreter_context};
memgraph::utils::Gatekeeper<memgraph::dbms::Database> db_gk(memgraph::storage::Config{
.durability.storage_directory = data_directory, .disk.main_storage_directory = data_directory / "disk"});
auto db_acc_opt = db_gk.access();
MG_ASSERT(db_acc_opt, "Failed to access db");
auto &db_acc = *db_acc_opt;
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr);
memgraph::query::Interpreter interpreter{&interpreter_context, db_acc};
ResultStreamFaker stream(interpreter_context.db.get());
ResultStreamFaker stream(db_acc->storage());
auto [header, _1, qid, _2] = interpreter.Prepare(argv[1], {}, nullptr);
stream.Header(header);
auto summary = interpreter.PullAll(&stream);

View File

@ -257,9 +257,6 @@ target_link_libraries(${test_prefix}utils_signals mg-utils)
add_unit_test(utils_string.cpp)
target_link_libraries(${test_prefix}utils_string mg-utils)
add_unit_test(utils_sync_ptr.cpp)
target_link_libraries(${test_prefix}utils_sync_ptr)
add_unit_test(utils_synchronized.cpp)
target_link_libraries(${test_prefix}utils_synchronized mg-utils)
@ -404,17 +401,9 @@ target_link_libraries(${test_prefix}monitoring mg-communication Boost::headers)
# Test multi-database
if(MG_ENTERPRISE)
# add_unit_test(dbms_storage.cpp)
# target_link_libraries(${test_prefix}dbms_storage mg-storage-v2 mg-query mg-glue)
add_unit_test(dbms_database.cpp)
target_link_libraries(${test_prefix}dbms_database mg-storage-v2 mg-query mg-glue mg-dbms)
add_unit_test(dbms_interp.cpp)
target_link_libraries(${test_prefix}dbms_interp mg-query)
endif()
# add_unit_test(dbms_auth.cpp)
# target_link_libraries(${test_prefix}dbms_auth mg-glue)
if(MG_ENTERPRISE)
add_unit_test_with_custom_main(dbms_sc_handler.cpp)
target_link_libraries(${test_prefix}dbms_sc_handler mg-query mg-audit mg-glue)
add_unit_test_with_custom_main(dbms_handler.cpp)
target_link_libraries(${test_prefix}dbms_handler mg-query mg-auth mg-glue mg-dbms)
endif()

View File

@ -119,13 +119,6 @@ class TestSession final : public Session<TestInputStream, TestOutputStream> {
void Configure(const std::map<std::string, memgraph::communication::bolt::Value> &) override {}
std::string GetDatabaseName() const override { return ""; }
#ifdef MG_ENTERPRISE
memgraph::dbms::SetForResult OnChange(const std::string &db_name) override {
return memgraph::dbms::SetForResult::SUCCESS;
}
bool OnDelete(const std::string &) override { return true; }
#endif
void TestHook_ShouldAbort() { should_abort_ = true; }
private:

View File

@ -0,0 +1,224 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <filesystem>
#include "dbms/database_handler.hpp"
#include "dbms/global.hpp"
#include "license/license.hpp"
#include "query_plan_common.hpp"
#include "storage/v2/view.hpp"
std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_database"};
memgraph::storage::Config default_conf(std::string name = "") {
return {.durability = {.storage_directory = storage_directory / name,
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL},
.disk = {.main_storage_directory = storage_directory / name / "disk"}};
}
class DBMS_Database : public ::testing::Test {
protected:
void SetUp() override { Clear(); }
void TearDown() override { Clear(); }
private:
void Clear() {
if (std::filesystem::exists(storage_directory)) {
std::filesystem::remove_all(storage_directory);
}
}
};
#ifdef MG_ENTERPRISE
TEST_F(DBMS_Database, New) {
memgraph::dbms::DatabaseHandler db_handler;
{ ASSERT_FALSE(db_handler.GetConfig("db1")); }
{ // With custom config
memgraph::storage::Config db_config{
.durability = {.storage_directory = storage_directory / "db2",
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL},
.disk = {.main_storage_directory = storage_directory / "disk"}};
auto db2 = db_handler.New("db2", db_config);
ASSERT_TRUE(db2.HasValue() && db2.GetValue());
ASSERT_TRUE(std::filesystem::exists(storage_directory / "db2"));
}
{
// With default config
auto db3 = db_handler.New("db3", default_conf("db3"));
ASSERT_TRUE(db3.HasValue() && db3.GetValue());
ASSERT_TRUE(std::filesystem::exists(storage_directory / "db3"));
auto db4 = db_handler.New("db4", default_conf("four"));
ASSERT_TRUE(db4.HasValue() && db4.GetValue());
ASSERT_TRUE(std::filesystem::exists(storage_directory / "four"));
auto db5 = db_handler.New("db5", default_conf("db3"));
ASSERT_TRUE(db5.HasError() && db5.GetError() == memgraph::dbms::NewError::EXISTS);
}
auto all = db_handler.All();
std::sort(all.begin(), all.end());
ASSERT_EQ(all.size(), 3);
ASSERT_EQ(all[0], "db2");
ASSERT_EQ(all[1], "db3");
ASSERT_EQ(all[2], "db4");
}
TEST_F(DBMS_Database, Get) {
memgraph::dbms::DatabaseHandler db_handler;
auto db1 = db_handler.New("db1", default_conf("db1"));
auto db2 = db_handler.New("db2", default_conf("db2"));
auto db3 = db_handler.New("db3", default_conf("db3"));
ASSERT_TRUE(db1.HasValue());
ASSERT_TRUE(db2.HasValue());
ASSERT_TRUE(db3.HasValue());
auto get_db1 = db_handler.Get("db1");
auto get_db2 = db_handler.Get("db2");
auto get_db3 = db_handler.Get("db3");
ASSERT_TRUE(get_db1 && *get_db1 == db1.GetValue());
ASSERT_TRUE(get_db2 && *get_db2 == db2.GetValue());
ASSERT_TRUE(get_db3 && *get_db3 == db3.GetValue());
ASSERT_FALSE(db_handler.Get("db123"));
ASSERT_FALSE(db_handler.Get("db2 "));
ASSERT_FALSE(db_handler.Get(" db3"));
}
TEST_F(DBMS_Database, Delete) {
memgraph::dbms::DatabaseHandler db_handler;
auto db1 = db_handler.New("db1", default_conf("db1"));
auto db2 = db_handler.New("db2", default_conf("db2"));
auto db3 = db_handler.New("db3", default_conf("db3"));
ASSERT_TRUE(db1.HasValue());
ASSERT_TRUE(db2.HasValue());
ASSERT_TRUE(db3.HasValue());
{
// Release accessor to storage
db1.GetValue().reset();
// Delete from handler
ASSERT_TRUE(db_handler.Delete("db1"));
ASSERT_FALSE(db_handler.Get("db1"));
auto all = db_handler.All();
std::sort(all.begin(), all.end());
ASSERT_EQ(all.size(), 2);
ASSERT_EQ(all[0], "db2");
ASSERT_EQ(all[1], "db3");
}
{
ASSERT_THROW(db_handler.Delete("db0"), memgraph::utils::BasicException);
ASSERT_THROW(db_handler.Delete("db1"), memgraph::utils::BasicException);
auto all = db_handler.All();
std::sort(all.begin(), all.end());
ASSERT_EQ(all.size(), 2);
ASSERT_EQ(all[0], "db2");
ASSERT_EQ(all[1], "db3");
}
}
TEST_F(DBMS_Database, DeleteAndRecover) {
memgraph::license::global_license_checker.EnableTesting();
memgraph::dbms::DatabaseHandler db_handler;
{
auto db1 = db_handler.New("db1", default_conf("db1"));
auto db2 = db_handler.New("db2", default_conf("db2"));
memgraph::storage::Config conf_w_snap{
.durability = {.storage_directory = storage_directory / "db3",
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL,
.snapshot_on_exit = true},
.disk = {.main_storage_directory = storage_directory / "db3" / "disk"}};
auto db3 = db_handler.New("db3", conf_w_snap);
ASSERT_TRUE(db1.HasValue());
ASSERT_TRUE(db2.HasValue());
ASSERT_TRUE(db3.HasValue());
// Add data to graphs
{
auto storage_dba = db1.GetValue()->Access();
memgraph::query::DbAccessor dba{storage_dba.get()};
memgraph::query::VertexAccessor v1{dba.InsertVertex()};
memgraph::query::VertexAccessor v2{dba.InsertVertex()};
ASSERT_TRUE(v1.AddLabel(dba.NameToLabel("l11")).HasValue());
ASSERT_TRUE(v2.AddLabel(dba.NameToLabel("l12")).HasValue());
ASSERT_FALSE(dba.Commit().HasError());
}
{
auto storage_dba = db3.GetValue()->Access();
memgraph::query::DbAccessor dba{storage_dba.get()};
memgraph::query::VertexAccessor v1{dba.InsertVertex()};
memgraph::query::VertexAccessor v2{dba.InsertVertex()};
memgraph::query::VertexAccessor v3{dba.InsertVertex()};
ASSERT_TRUE(v1.AddLabel(dba.NameToLabel("l31")).HasValue());
ASSERT_TRUE(v2.AddLabel(dba.NameToLabel("l32")).HasValue());
ASSERT_TRUE(v3.AddLabel(dba.NameToLabel("l33")).HasValue());
ASSERT_FALSE(dba.Commit().HasError());
}
}
// Delete from handler
ASSERT_TRUE(db_handler.Delete("db1"));
ASSERT_TRUE(db_handler.Delete("db2"));
ASSERT_TRUE(db_handler.Delete("db3"));
{
// Recover graphs (only db3)
auto db1 = db_handler.New("db1", default_conf("db1"));
auto db2 = db_handler.New("db2", default_conf("db2"));
memgraph::storage::Config conf_w_rec{
.durability = {.storage_directory = storage_directory / "db3",
.recover_on_startup = true,
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL},
.disk = {.main_storage_directory = storage_directory / "db3" / "disk"}};
auto db3 = db_handler.New("db3", conf_w_rec);
// Check content
{
// Empty
auto storage_dba = db1.GetValue()->Access();
memgraph::query::DbAccessor dba{storage_dba.get()};
ASSERT_EQ(dba.VerticesCount(), 0);
}
{
// Empty
auto storage_dba = db2.GetValue()->Access();
memgraph::query::DbAccessor dba{storage_dba.get()};
ASSERT_EQ(dba.VerticesCount(), 0);
}
{
// Full
auto storage_dba = db3.GetValue()->Access();
memgraph::query::DbAccessor dba{storage_dba.get()};
ASSERT_EQ(dba.VerticesCount(), 3);
}
}
}
#endif

193
tests/unit/dbms_handler.cpp Normal file
View File

@ -0,0 +1,193 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/auth_query_handler.hpp"
#ifdef MG_ENTERPRISE
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <filesystem>
#include <system_error>
#include "dbms/constants.hpp"
#include "dbms/dbms_handler.hpp"
#include "dbms/global.hpp"
#include "glue/auth_checker.hpp"
#include "glue/auth_handler.hpp"
#include "query/config.hpp"
#include "query/interpreter.hpp"
// Global
std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_handler"};
static memgraph::storage::Config storage_conf;
std::unique_ptr<memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock>> auth;
// Let this be global so we can test it different states throughout
class TestEnvironment : public ::testing::Environment {
public:
static memgraph::dbms::DbmsHandler *get() { return ptr_.get(); }
void SetUp() override {
// Setup config
memgraph::storage::UpdatePaths(storage_conf, storage_directory);
storage_conf.durability.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL;
// Clean storage directory (running multiple parallel test, run only if the first process)
if (std::filesystem::exists(storage_directory)) {
memgraph::utils::OutputFile lock_file_handle_;
lock_file_handle_.Open(storage_directory / ".lock", memgraph::utils::OutputFile::Mode::OVERWRITE_EXISTING);
if (lock_file_handle_.AcquireLock()) {
std::filesystem::remove_all(storage_directory);
}
}
auth =
std::make_unique<memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock>>(
storage_directory / "auth");
ptr_ = std::make_unique<memgraph::dbms::DbmsHandler>(storage_conf, auth.get(), false, true);
}
void TearDown() override {
ptr_.reset();
auth.reset();
}
static std::unique_ptr<memgraph::dbms::DbmsHandler> ptr_;
};
std::unique_ptr<memgraph::dbms::DbmsHandler> TestEnvironment::ptr_ = nullptr;
class DBMS_Handler : public testing::Test {};
using DBMS_HandlerDeath = DBMS_Handler;
TEST(DBMS_Handler, Init) {
// Check that the default db has been created successfully
std::vector<std::string> dirs = {"snapshots", "streams", "triggers", "wal"};
for (const auto &dir : dirs)
ASSERT_TRUE(std::filesystem::exists(storage_directory / dir)) << (storage_directory / dir);
const auto db_path = storage_directory / "databases" / memgraph::dbms::kDefaultDB;
ASSERT_TRUE(std::filesystem::exists(db_path));
for (const auto &dir : dirs) {
std::error_code ec;
const auto test_link = std::filesystem::read_symlink(db_path / dir, ec);
ASSERT_TRUE(!ec) << ec.message();
ASSERT_EQ(test_link, "../../" + dir);
}
}
TEST(DBMS_Handler, New) {
auto &dbms = *TestEnvironment::get();
{
const auto all = dbms.All();
ASSERT_EQ(all.size(), 1);
ASSERT_EQ(all[0], memgraph::dbms::kDefaultDB);
}
{
auto db1 = dbms.New("db1");
ASSERT_TRUE(db1.HasValue());
ASSERT_TRUE(db1.GetValue());
ASSERT_TRUE(std::filesystem::exists(storage_directory / "databases" / "db1"));
ASSERT_TRUE(db1.GetValue()->storage() != nullptr);
ASSERT_TRUE(db1.GetValue()->streams() != nullptr);
ASSERT_TRUE(db1.GetValue()->trigger_store() != nullptr);
ASSERT_TRUE(db1.GetValue()->thread_pool() != nullptr);
const auto all = dbms.All();
ASSERT_EQ(all.size(), 2);
ASSERT_TRUE(std::find(all.begin(), all.end(), memgraph::dbms::kDefaultDB) != all.end());
ASSERT_TRUE(std::find(all.begin(), all.end(), "db1") != all.end());
}
{
// Fail if name exists
auto db2 = dbms.New("db1");
ASSERT_TRUE(db2.HasError() && db2.GetError() == memgraph::dbms::NewError::EXISTS);
}
{
auto db3 = dbms.New("db3");
ASSERT_TRUE(db3.HasValue());
ASSERT_TRUE(std::filesystem::exists(storage_directory / "databases" / "db3"));
ASSERT_TRUE(db3.GetValue()->storage() != nullptr);
ASSERT_TRUE(db3.GetValue()->streams() != nullptr);
ASSERT_TRUE(db3.GetValue()->trigger_store() != nullptr);
ASSERT_TRUE(db3.GetValue()->thread_pool() != nullptr);
const auto all = dbms.All();
ASSERT_EQ(all.size(), 3);
ASSERT_TRUE(std::find(all.begin(), all.end(), "db3") != all.end());
}
}
TEST(DBMS_Handler, Get) {
auto &dbms = *TestEnvironment::get();
auto default_db = dbms.Get(memgraph::dbms::kDefaultDB);
ASSERT_TRUE(default_db);
ASSERT_TRUE(default_db->storage() != nullptr);
ASSERT_TRUE(default_db->streams() != nullptr);
ASSERT_TRUE(default_db->trigger_store() != nullptr);
ASSERT_TRUE(default_db->thread_pool() != nullptr);
ASSERT_ANY_THROW(dbms.Get("non-existent"));
auto db1 = dbms.Get("db1");
ASSERT_TRUE(db1);
ASSERT_TRUE(db1->storage() != nullptr);
ASSERT_TRUE(db1->streams() != nullptr);
ASSERT_TRUE(db1->trigger_store() != nullptr);
ASSERT_TRUE(db1->thread_pool() != nullptr);
auto db3 = dbms.Get("db3");
ASSERT_TRUE(db3);
ASSERT_TRUE(db3->storage() != nullptr);
ASSERT_TRUE(db3->streams() != nullptr);
ASSERT_TRUE(db3->trigger_store() != nullptr);
ASSERT_TRUE(db3->thread_pool() != nullptr);
}
TEST(DBMS_Handler, Delete) {
auto &dbms = *TestEnvironment::get();
auto db1_acc = dbms.Get("db1"); // Holds access to database
{
auto del = dbms.Delete(memgraph::dbms::kDefaultDB);
ASSERT_TRUE(del.HasError() && del.GetError() == memgraph::dbms::DeleteError::DEFAULT_DB);
}
{
auto del = dbms.Delete("non-existent");
ASSERT_TRUE(del.HasError() && del.GetError() == memgraph::dbms::DeleteError::NON_EXISTENT);
}
{
// db1_acc is using db1
auto del = dbms.Delete("db1");
ASSERT_TRUE(del.HasError());
ASSERT_TRUE(del.GetError() == memgraph::dbms::DeleteError::USING);
}
{
// Reset db1_acc (releases access) so delete will succeed
db1_acc.reset();
ASSERT_FALSE(db1_acc);
auto del = dbms.Delete("db1");
ASSERT_FALSE(del.HasError()) << (int)del.GetError();
auto del2 = dbms.Delete("db1");
ASSERT_TRUE(del2.HasError() && del2.GetError() == memgraph::dbms::DeleteError::NON_EXISTENT);
}
{
auto del = dbms.Delete("db3");
ASSERT_FALSE(del.HasError());
ASSERT_FALSE(std::filesystem::exists(storage_directory / "databases" / "db3"));
}
}
int main(int argc, char *argv[]) {
::testing::InitGoogleTest(&argc, argv);
// gtest takes ownership of the TestEnvironment ptr - we don't delete it.
::testing::AddGlobalTestEnvironment(new TestEnvironment);
return RUN_ALL_TESTS();
}
#endif

View File

@ -1,431 +0,0 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#ifdef MG_ENTERPRISE
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <filesystem>
#include "dbms/global.hpp"
#include "dbms/interp_handler.hpp"
#include "query/auth_checker.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/interpreter.hpp"
class TestAuthHandler : public memgraph::query::AuthQueryHandler {
public:
TestAuthHandler() = default;
bool CreateUser(const std::string & /*username*/, const std::optional<std::string> & /*password*/) override {
return true;
}
bool DropUser(const std::string & /*username*/) override { return true; }
void SetPassword(const std::string & /*username*/, const std::optional<std::string> & /*password*/) override {}
bool RevokeDatabaseFromUser(const std::string & /*db*/, const std::string & /*username*/) override { return true; }
bool GrantDatabaseToUser(const std::string & /*db*/, const std::string & /*username*/) override { return true; }
bool SetMainDatabase(const std::string & /*db*/, const std::string & /*username*/) override { return true; }
std::vector<std::vector<memgraph::query::TypedValue>> GetDatabasePrivileges(const std::string & /*user*/) override {
return {};
}
bool CreateRole(const std::string & /*rolename*/) override { return true; }
bool DropRole(const std::string & /*rolename*/) override { return true; }
std::vector<memgraph::query::TypedValue> GetUsernames() override { return {}; }
std::vector<memgraph::query::TypedValue> GetRolenames() override { return {}; }
std::optional<std::string> GetRolenameForUser(const std::string & /*username*/) override { return {}; }
std::vector<memgraph::query::TypedValue> GetUsernamesForRole(const std::string & /*rolename*/) override { return {}; }
void SetRole(const std::string &username, const std::string & /*rolename*/) override {}
void ClearRole(const std::string &username) override {}
std::vector<std::vector<memgraph::query::TypedValue>> GetPrivileges(const std::string & /*user_or_role*/) override {
return {};
}
void GrantPrivilege(
const std::string & /*user_or_role*/, const std::vector<memgraph::query::AuthQuery::Privilege> & /*privileges*/,
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
& /*label_privileges*/,
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
& /*edge_type_privileges*/) override {}
void DenyPrivilege(const std::string & /*user_or_role*/,
const std::vector<memgraph::query::AuthQuery::Privilege> & /*privileges*/) override {}
void RevokePrivilege(
const std::string & /*user_or_role*/, const std::vector<memgraph::query::AuthQuery::Privilege> & /*privileges*/,
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
& /*label_privileges*/,
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
& /*edge_type_privileges*/) override {}
};
class TestAuthChecker : public memgraph::query::AuthChecker {
public:
bool IsUserAuthorized(const std::optional<std::string> & /*username*/,
const std::vector<memgraph::query::AuthQuery::Privilege> & /*privileges*/,
const std::string & /*db*/) const override {
return true;
}
std::unique_ptr<memgraph::query::FineGrainedAuthChecker> GetFineGrainedAuthChecker(
const std::string & /*username*/, const memgraph::query::DbAccessor * /*db_accessor*/) const override {
return {};
}
void ClearCache() const override {}
};
std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_interp"};
memgraph::query::InterpreterConfig default_conf{};
memgraph::storage::Config default_storage_conf(std::string name = "") {
return {.durability = {.storage_directory = storage_directory / name,
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL},
.disk = {.main_storage_directory = storage_directory / name / "disk"}};
}
class TestHandler {
} test_handler;
class DBMS_Interp : public ::testing::Test {
protected:
void SetUp() override { Clear(); }
void TearDown() override { Clear(); }
private:
void Clear() {
if (std::filesystem::exists(storage_directory)) {
std::filesystem::remove_all(storage_directory);
}
}
};
TEST_F(DBMS_Interp, New) {
memgraph::dbms::InterpContextHandler<TestHandler> ih;
memgraph::storage::Config db_conf{
.durability = {.storage_directory = storage_directory,
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL},
.disk = {.main_storage_directory = storage_directory / "disk"}};
TestAuthHandler ah;
TestAuthChecker ac;
{
// Clean initialization
auto ic1 = ih.New("ic1", test_handler, db_conf, default_conf, ah, ac);
ASSERT_TRUE(ic1.HasValue() && ic1.GetValue() != nullptr);
ASSERT_TRUE(std::filesystem::exists(storage_directory / "triggers"));
ASSERT_TRUE(std::filesystem::exists(storage_directory / "streams"));
ASSERT_NE(ic1.GetValue()->db, nullptr);
ASSERT_EQ(&ic1.GetValue()->sc_handler_, &test_handler);
ASSERT_EQ(ih.GetConfig("ic1")->storage_config.durability.storage_directory, storage_directory);
}
{
memgraph::storage::Config db_conf2{
.durability = {.storage_directory = storage_directory,
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL},
.disk = {.main_storage_directory = storage_directory / "disk"}};
// Try to override data directory
auto ic2 = ih.New("ic2", test_handler, db_conf2, default_conf, ah, ac);
ASSERT_TRUE(ic2.HasError() && ic2.GetError() == memgraph::dbms::NewError::EXISTS);
}
{
memgraph::storage::Config db_conf3{
.durability = {.storage_directory = storage_directory / "ic3",
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL},
.disk = {.main_storage_directory = storage_directory / "disk"}};
// Try to override the name "ic1"
auto ic3 = ih.New("ic1", test_handler, db_conf3, default_conf, ah, ac);
ASSERT_TRUE(ic3.HasError() && ic3.GetError() == memgraph::dbms::NewError::EXISTS);
}
{
// Another clean initialization
memgraph::storage::Config db_conf4{
.durability = {.storage_directory = storage_directory / "ic4",
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL},
.disk = {.main_storage_directory = storage_directory / "disk"}};
auto ic4 = ih.New("ic4", test_handler, db_conf4, default_conf, ah, ac);
ASSERT_TRUE(ic4.HasValue() && ic4.GetValue() != nullptr);
ASSERT_TRUE(std::filesystem::exists(storage_directory / "ic4" / "triggers"));
ASSERT_TRUE(std::filesystem::exists(storage_directory / "ic4" / "streams"));
ASSERT_EQ(&ic4.GetValue()->sc_handler_, &test_handler);
ASSERT_EQ(ih.GetConfig("ic4")->storage_config.durability.storage_directory, storage_directory / "ic4");
}
}
TEST_F(DBMS_Interp, Get) {
memgraph::dbms::InterpContextHandler<TestHandler> ih;
TestAuthHandler ah;
TestAuthChecker ac;
memgraph::storage::Config db_conf{
.durability = {.storage_directory = storage_directory / "ic1",
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL},
.disk = {.main_storage_directory = storage_directory / "disk"}};
auto ic1 = ih.New("ic1", test_handler, db_conf, default_conf, ah, ac);
ASSERT_TRUE(ic1.HasValue() && ic1.GetValue() != nullptr);
auto ic1_get = ih.Get("ic1");
ASSERT_TRUE(ic1_get && *ic1_get == ic1.GetValue());
memgraph::storage::Config db_conf2{
.durability = {.storage_directory = storage_directory / "ic2",
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL},
.disk = {.main_storage_directory = storage_directory / "disk"}};
auto ic2 = ih.New("ic2", test_handler, db_conf2, default_conf, ah, ac);
ASSERT_TRUE(ic2.HasValue() && ic2.GetValue() != nullptr);
auto ic2_get = ih.Get("ic2");
ASSERT_TRUE(ic2_get && *ic2_get == ic2.GetValue());
ASSERT_FALSE(ih.Get("aa"));
ASSERT_FALSE(ih.Get("ic1 "));
ASSERT_FALSE(ih.Get("ic21"));
ASSERT_FALSE(ih.Get(" ic2"));
}
TEST_F(DBMS_Interp, Delete) {
memgraph::dbms::InterpContextHandler<TestHandler> ih;
TestAuthHandler ah;
TestAuthChecker ac;
memgraph::storage::Config db_conf{
.durability = {.storage_directory = storage_directory / "ic1",
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL},
.disk = {.main_storage_directory = storage_directory / "disk"}};
{
auto ic1 = ih.New("ic1", test_handler, db_conf, default_conf, ah, ac);
ASSERT_TRUE(ic1.HasValue() && ic1.GetValue() != nullptr);
}
memgraph::storage::Config db_conf2{
.durability = {.storage_directory = storage_directory / "ic2",
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL},
.disk = {.main_storage_directory = storage_directory / "disk"}};
{
auto ic2 = ih.New("ic2", test_handler, db_conf2, default_conf, ah, ac);
ASSERT_TRUE(ic2.HasValue() && ic2.GetValue() != nullptr);
}
ASSERT_TRUE(ih.Delete("ic1"));
ASSERT_FALSE(ih.Get("ic1"));
ASSERT_FALSE(ih.Delete("ic1"));
ASSERT_FALSE(ih.Delete("ic3"));
}
/**
*
*
*
*
*
* Test storage (previous StorageHandler, now handled via InterpretContext)
*
*
*
*
*
*/
TEST_F(DBMS_Interp, StorageNew) {
memgraph::dbms::InterpContextHandler<TestHandler> ih;
TestAuthHandler ah;
TestAuthChecker ac;
{ ASSERT_FALSE(ih.GetConfig("db1")); }
{
// With custom config
memgraph::storage::Config db_config{
.durability = {.storage_directory = storage_directory / "db2",
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL},
.disk = {.main_storage_directory = storage_directory / "disk"}};
auto db2 = ih.New("db2", test_handler, db_config, default_conf, ah, ac);
ASSERT_TRUE(db2.HasValue() && db2.GetValue() != nullptr);
ASSERT_TRUE(std::filesystem::exists(storage_directory / "db2"));
}
{
// With default config
auto db3 = ih.New("db3", test_handler, default_storage_conf("db3"), default_conf, ah, ac);
ASSERT_TRUE(db3.HasValue() && db3.GetValue() != nullptr);
ASSERT_TRUE(std::filesystem::exists(storage_directory / "db3"));
auto db4 = ih.New("db4", test_handler, default_storage_conf("four"), default_conf, ah, ac);
ASSERT_TRUE(db4.HasValue() && db4.GetValue() != nullptr);
ASSERT_TRUE(std::filesystem::exists(storage_directory / "four"));
auto db5 = ih.New("db5", test_handler, default_storage_conf("db3"), default_conf, ah, ac);
ASSERT_TRUE(db5.HasError() && db5.GetError() == memgraph::dbms::NewError::EXISTS);
}
auto all = ih.All();
std::sort(all.begin(), all.end());
ASSERT_EQ(all.size(), 3);
ASSERT_EQ(all[0], "db2");
ASSERT_EQ(all[1], "db3");
ASSERT_EQ(all[2], "db4");
}
TEST_F(DBMS_Interp, StorageGet) {
memgraph::dbms::InterpContextHandler<TestHandler> ih;
TestAuthHandler ah;
TestAuthChecker ac;
auto db1 = ih.New("db1", test_handler, default_storage_conf("db1"), default_conf, ah, ac);
auto db2 = ih.New("db2", test_handler, default_storage_conf("db2"), default_conf, ah, ac);
auto db3 = ih.New("db3", test_handler, default_storage_conf("db3"), default_conf, ah, ac);
ASSERT_TRUE(db1.HasValue());
ASSERT_TRUE(db2.HasValue());
ASSERT_TRUE(db3.HasValue());
auto get_db1 = ih.Get("db1");
auto get_db2 = ih.Get("db2");
auto get_db3 = ih.Get("db3");
ASSERT_TRUE(get_db1 && *get_db1 == db1.GetValue());
ASSERT_TRUE(get_db2 && *get_db2 == db2.GetValue());
ASSERT_TRUE(get_db3 && *get_db3 == db3.GetValue());
ASSERT_FALSE(ih.Get("db123"));
ASSERT_FALSE(ih.Get("db2 "));
ASSERT_FALSE(ih.Get(" db3"));
}
TEST_F(DBMS_Interp, StorageDelete) {
memgraph::dbms::InterpContextHandler<TestHandler> ih;
TestAuthHandler ah;
TestAuthChecker ac;
auto db1 = ih.New("db1", test_handler, default_storage_conf("db1"), default_conf, ah, ac);
auto db2 = ih.New("db2", test_handler, default_storage_conf("db2"), default_conf, ah, ac);
auto db3 = ih.New("db3", test_handler, default_storage_conf("db3"), default_conf, ah, ac);
ASSERT_TRUE(db1.HasValue());
ASSERT_TRUE(db2.HasValue());
ASSERT_TRUE(db3.HasValue());
{
// Release pointer to storage
db1.GetValue().reset();
// Delete from handler
ASSERT_TRUE(ih.Delete("db1"));
ASSERT_FALSE(ih.Get("db1"));
auto all = ih.All();
std::sort(all.begin(), all.end());
ASSERT_EQ(all.size(), 2);
ASSERT_EQ(all[0], "db2");
ASSERT_EQ(all[1], "db3");
}
{
ASSERT_FALSE(ih.Delete("db0"));
ASSERT_FALSE(ih.Delete("db1"));
auto all = ih.All();
std::sort(all.begin(), all.end());
ASSERT_EQ(all.size(), 2);
ASSERT_EQ(all[0], "db2");
ASSERT_EQ(all[1], "db3");
}
}
TEST_F(DBMS_Interp, StorageDeleteAndRecover) {
// memgraph::license::global_license_checker.EnableTesting();
memgraph::dbms::InterpContextHandler<TestHandler> ih;
TestAuthHandler ah;
TestAuthChecker ac;
{
auto db1 = ih.New("db1", test_handler, default_storage_conf("db1"), default_conf, ah, ac);
auto db2 = ih.New("db2", test_handler, default_storage_conf("db2"), default_conf, ah, ac);
memgraph::storage::Config conf_w_snap{
.durability = {.storage_directory = storage_directory / "db3",
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL,
.snapshot_on_exit = true},
.disk = {.main_storage_directory = storage_directory / "db3" / "disk"}};
auto db3 = ih.New("db3", test_handler, conf_w_snap, default_conf, ah, ac);
ASSERT_TRUE(db1.HasValue());
ASSERT_TRUE(db2.HasValue());
ASSERT_TRUE(db3.HasValue());
// Add data to graphs
{
auto storage_dba = db1.GetValue()->db->Access();
memgraph::query::DbAccessor dba{storage_dba.get()};
memgraph::query::VertexAccessor v1{dba.InsertVertex()};
memgraph::query::VertexAccessor v2{dba.InsertVertex()};
ASSERT_TRUE(v1.AddLabel(dba.NameToLabel("l11")).HasValue());
ASSERT_TRUE(v2.AddLabel(dba.NameToLabel("l12")).HasValue());
ASSERT_FALSE(dba.Commit().HasError());
}
{
auto storage_dba = db3.GetValue()->db->Access();
memgraph::query::DbAccessor dba{storage_dba.get()};
memgraph::query::VertexAccessor v1{dba.InsertVertex()};
memgraph::query::VertexAccessor v2{dba.InsertVertex()};
memgraph::query::VertexAccessor v3{dba.InsertVertex()};
ASSERT_TRUE(v1.AddLabel(dba.NameToLabel("l31")).HasValue());
ASSERT_TRUE(v2.AddLabel(dba.NameToLabel("l32")).HasValue());
ASSERT_TRUE(v3.AddLabel(dba.NameToLabel("l33")).HasValue());
ASSERT_FALSE(dba.Commit().HasError());
}
}
// Delete from handler
ASSERT_TRUE(ih.Delete("db1"));
ASSERT_TRUE(ih.Delete("db2"));
ASSERT_TRUE(ih.Delete("db3"));
{
// Recover graphs (only db3)
auto db1 = ih.New("db1", test_handler, default_storage_conf("db1"), default_conf, ah, ac);
auto db2 = ih.New("db2", test_handler, default_storage_conf("db2"), default_conf, ah, ac);
memgraph::storage::Config conf_w_rec{
.durability = {.storage_directory = storage_directory / "db3",
.recover_on_startup = true,
.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL},
.disk = {.main_storage_directory = storage_directory / "db3" / "disk"}};
auto db3 = ih.New("db3", test_handler, conf_w_rec, default_conf, ah, ac);
// Check content
{
// Empty
auto storage_dba = db1.GetValue()->db->Access();
memgraph::query::DbAccessor dba{storage_dba.get()};
ASSERT_EQ(dba.VerticesCount(), 0);
}
{
// Empty
auto storage_dba = db2.GetValue()->db->Access();
memgraph::query::DbAccessor dba{storage_dba.get()};
ASSERT_EQ(dba.VerticesCount(), 0);
}
{
// Full
auto storage_dba = db3.GetValue()->db->Access();
memgraph::query::DbAccessor dba{storage_dba.get()};
ASSERT_EQ(dba.VerticesCount(), 3);
}
}
}
#endif

View File

@ -1,343 +0,0 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <system_error>
#include "query/interpreter.hpp"
#ifdef MG_ENTERPRISE
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <filesystem>
#include "dbms/constants.hpp"
#include "dbms/global.hpp"
#include "dbms/session_context_handler.hpp"
#include "glue/auth_checker.hpp"
#include "glue/auth_handler.hpp"
#include "query/config.hpp"
std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_sc_handler"};
static memgraph::storage::Config storage_conf;
memgraph::query::InterpreterConfig interp_conf;
// Global
memgraph::audit::Log audit_log{storage_directory / "audit", 100, 1000};
class TestInterface : public memgraph::dbms::SessionInterface {
public:
TestInterface(std::string name, auto on_change, auto on_delete) : id_(id++), db_(name) {
on_change_ = on_change;
on_delete_ = on_delete;
}
std::string UUID() const override { return std::to_string(id_); }
std::string GetDatabaseName() const override { return db_; }
memgraph::dbms::SetForResult OnChange(const std::string &name) override { return on_change_(name); }
bool OnDelete(const std::string &name) override { return on_delete_(name); }
static int id;
int id_;
std::string db_;
std::function<memgraph::dbms::SetForResult(const std::string &)> on_change_;
std::function<bool(const std::string &)> on_delete_;
};
int TestInterface::id{0};
// Let this be global so we can test it different states throughout
class TestEnvironment : public ::testing::Environment {
public:
static memgraph::dbms::SessionContextHandler *get() { return ptr_.get(); }
void SetUp() override {
// Setup config
memgraph::storage::UpdatePaths(storage_conf, storage_directory);
storage_conf.durability.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL;
// Clean storage directory (running multiple parallel test, run only if the first process)
if (std::filesystem::exists(storage_directory)) {
memgraph::utils::OutputFile lock_file_handle_;
lock_file_handle_.Open(storage_directory / ".lock", memgraph::utils::OutputFile::Mode::OVERWRITE_EXISTING);
if (lock_file_handle_.AcquireLock()) {
std::filesystem::remove_all(storage_directory);
}
}
ptr_ = std::make_unique<memgraph::dbms::SessionContextHandler>(
audit_log,
memgraph::dbms::SessionContextHandler::Config{
storage_conf, interp_conf,
[](memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth,
std::unique_ptr<memgraph::query::AuthQueryHandler> &ah,
std::unique_ptr<memgraph::query::AuthChecker> &ac) {
// Glue high level auth implementations to the query side
ah = std::make_unique<memgraph::glue::AuthQueryHandler>(auth, "");
ac = std::make_unique<memgraph::glue::AuthChecker>(auth);
}},
false, true);
}
void TearDown() override { ptr_.reset(); }
static std::unique_ptr<memgraph::dbms::SessionContextHandler> ptr_;
};
std::unique_ptr<memgraph::dbms::SessionContextHandler> TestEnvironment::ptr_ = nullptr;
class DBMS_Handler : public testing::Test {};
using DBMS_HandlerDeath = DBMS_Handler;
TEST(DBMS_Handler, Init) {
// Check that the default db has been created successfully
std::vector<std::string> dirs = {"snapshots", "streams", "triggers", "wal"};
for (const auto &dir : dirs)
ASSERT_TRUE(std::filesystem::exists(storage_directory / dir)) << (storage_directory / dir);
const auto db_path = storage_directory / "databases" / memgraph::dbms::kDefaultDB;
ASSERT_TRUE(std::filesystem::exists(db_path));
for (const auto &dir : dirs) {
std::error_code ec;
const auto test_link = std::filesystem::read_symlink(db_path / dir, ec);
ASSERT_TRUE(!ec) << ec.message();
ASSERT_EQ(test_link, "../../" + dir);
}
}
TEST(DBMS_HandlerDeath, InitSameDir) {
// This will be executed in a clean process (so the singleton will NOT be initalized)
(void)(::testing::GTEST_FLAG(death_test_style) = "threadsafe");
// NOTE: Init test has ran in another process (so holds the lock)
ASSERT_DEATH(
{
memgraph::dbms::SessionContextHandler sch(
audit_log,
{storage_conf, interp_conf,
[](memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth,
std::unique_ptr<memgraph::query::AuthQueryHandler> &ah,
std::unique_ptr<memgraph::query::AuthChecker> &ac) {
// Glue high level auth implementations to the query side
ah = std::make_unique<memgraph::glue::AuthQueryHandler>(auth, "");
ac = std::make_unique<memgraph::glue::AuthChecker>(auth);
}},
false, true);
},
R"(\b.*\b)");
}
TEST(DBMS_Handler, New) {
auto &sch = *TestEnvironment::get();
{
const auto all = sch.All();
ASSERT_EQ(all.size(), 1);
ASSERT_EQ(all[0], memgraph::dbms::kDefaultDB);
}
{
auto sc1 = sch.New("sc1");
ASSERT_TRUE(sc1.HasValue());
ASSERT_TRUE(std::filesystem::exists(storage_directory / "databases" / "sc1"));
ASSERT_TRUE(sc1.GetValue().interpreter_context->db != nullptr);
ASSERT_TRUE(sc1.GetValue().interpreter_context != nullptr);
ASSERT_TRUE(sc1.GetValue().audit_log != nullptr);
ASSERT_TRUE(sc1.GetValue().auth != nullptr);
const auto all = sch.All();
ASSERT_EQ(all.size(), 2);
ASSERT_TRUE(std::find(all.begin(), all.end(), memgraph::dbms::kDefaultDB) != all.end());
ASSERT_TRUE(std::find(all.begin(), all.end(), "sc1") != all.end());
}
{
// Fail if name exists
auto sc2 = sch.New("sc1");
ASSERT_TRUE(sc2.HasError() && sc2.GetError() == memgraph::dbms::NewError::EXISTS);
}
{
auto sc3 = sch.New("sc3");
ASSERT_TRUE(sc3.HasValue());
ASSERT_TRUE(std::filesystem::exists(storage_directory / "databases" / "sc3"));
ASSERT_TRUE(sc3.GetValue().interpreter_context->db != nullptr);
ASSERT_TRUE(sc3.GetValue().interpreter_context != nullptr);
ASSERT_TRUE(sc3.GetValue().audit_log != nullptr);
ASSERT_TRUE(sc3.GetValue().auth != nullptr);
const auto all = sch.All();
ASSERT_EQ(all.size(), 3);
ASSERT_TRUE(std::find(all.begin(), all.end(), "sc3") != all.end());
}
}
TEST(DBMS_Handler, Get) {
auto &sch = *TestEnvironment::get();
auto default_sc = sch.Get(memgraph::dbms::kDefaultDB);
ASSERT_TRUE(default_sc.interpreter_context->db != nullptr);
ASSERT_TRUE(default_sc.interpreter_context != nullptr);
ASSERT_TRUE(default_sc.audit_log != nullptr);
ASSERT_TRUE(default_sc.auth != nullptr);
ASSERT_ANY_THROW(sch.Get("non-existent"));
auto sc1 = sch.Get("sc1");
ASSERT_TRUE(sc1.interpreter_context->db != nullptr);
ASSERT_TRUE(sc1.interpreter_context != nullptr);
ASSERT_TRUE(sc1.audit_log != nullptr);
ASSERT_TRUE(sc1.auth != nullptr);
auto sc3 = sch.Get("sc3");
ASSERT_TRUE(sc3.interpreter_context->db != nullptr);
ASSERT_TRUE(sc3.interpreter_context != nullptr);
ASSERT_TRUE(sc3.audit_log != nullptr);
ASSERT_TRUE(sc3.auth != nullptr);
}
TEST(DBMS_Handler, SetFor) {
auto &sch = *TestEnvironment::get();
ASSERT_TRUE(sch.New("db1").HasValue());
bool ti0_on_change_ = false;
bool ti0_on_delete_ = false;
TestInterface ti0(
"memgraph",
[&ti0, &ti0_on_change_](const std::string &name) -> memgraph::dbms::SetForResult {
ti0_on_change_ = true;
if (name != ti0.db_) {
ti0.db_ = name;
return memgraph::dbms::SetForResult::SUCCESS;
}
return memgraph::dbms::SetForResult::ALREADY_SET;
},
[&](const std::string &name) -> bool {
ti0_on_delete_ = true;
return true;
});
bool ti1_on_change_ = false;
bool ti1_on_delete_ = false;
TestInterface ti1(
"db1",
[&](const std::string &name) -> memgraph::dbms::SetForResult {
ti1_on_change_ = true;
return memgraph::dbms::SetForResult::SUCCESS;
},
[&](const std::string &name) -> bool {
ti1_on_delete_ = true;
return true;
});
ASSERT_TRUE(sch.Register(ti0));
ASSERT_FALSE(sch.Register(ti0));
{
ASSERT_EQ(sch.SetFor("0", "db1"), memgraph::dbms::SetForResult::SUCCESS);
ASSERT_TRUE(ti0_on_change_);
ti0_on_change_ = false;
ASSERT_EQ(sch.SetFor("0", "db1"), memgraph::dbms::SetForResult::ALREADY_SET);
ASSERT_TRUE(ti0_on_change_);
ti0_on_change_ = false;
ASSERT_ANY_THROW(sch.SetFor(std::to_string(TestInterface::id), "db1")); // Session does not exist
ASSERT_ANY_THROW(sch.SetFor("1", "db1")); // Session not registered
ASSERT_ANY_THROW(sch.SetFor("0", "db2")); // No db2
ASSERT_EQ(sch.SetFor("0", "memgraph"), memgraph::dbms::SetForResult::SUCCESS);
ASSERT_TRUE(ti0_on_change_);
}
ASSERT_TRUE(sch.Delete(ti0));
ASSERT_FALSE(sch.Delete(ti1));
}
TEST(DBMS_Handler, Delete) {
auto &sch = *TestEnvironment::get();
bool ti0_on_change_ = false;
bool ti0_on_delete_ = false;
TestInterface ti0(
"memgraph",
[&](const std::string &name) -> memgraph::dbms::SetForResult {
ti0_on_change_ = true;
if (name != "sc3") return memgraph::dbms::SetForResult::SUCCESS;
return memgraph::dbms::SetForResult::FAIL;
},
[&](const std::string &name) -> bool {
ti0_on_delete_ = true;
return (name != "sc3");
});
bool ti1_on_change_ = false;
bool ti1_on_delete_ = false;
TestInterface ti1(
"sc1",
[&](const std::string &name) -> memgraph::dbms::SetForResult {
ti1_on_change_ = true;
ti1.db_ = name;
return memgraph::dbms::SetForResult::SUCCESS;
},
[&](const std::string &name) -> bool {
ti1_on_delete_ = true;
return ti1.db_ != name;
});
ASSERT_TRUE(sch.Register(ti0));
ASSERT_TRUE(sch.Register(ti1));
{
auto del = sch.Delete(memgraph::dbms::kDefaultDB);
ASSERT_TRUE(del.HasError() && del.GetError() == memgraph::dbms::DeleteError::DEFAULT_DB);
}
{
auto del = sch.Delete("non-existent");
ASSERT_TRUE(del.HasError() && del.GetError() == memgraph::dbms::DeleteError::NON_EXISTENT);
}
{
// ti1 is using sc1
auto del = sch.Delete("sc1");
ASSERT_TRUE(del.HasError());
ASSERT_TRUE(del.GetError() == memgraph::dbms::DeleteError::FAIL);
}
{
// Delete ti1 so delete will succeed
ASSERT_EQ(sch.SetFor(ti1.UUID(), "memgraph"), memgraph::dbms::SetForResult::SUCCESS);
auto del = sch.Delete("sc1");
ASSERT_FALSE(del.HasError()) << (int)del.GetError();
auto del2 = sch.Delete("sc1");
ASSERT_TRUE(del2.HasError() && del2.GetError() == memgraph::dbms::DeleteError::NON_EXISTENT);
}
{
// Using based on the active interpreters
auto new_sc = sch.New("sc1");
ASSERT_TRUE(new_sc.HasValue()) << (int)new_sc.GetError();
auto sc = sch.Get("sc1");
memgraph::query::Interpreter interpreter(sc.interpreter_context.get());
sc.interpreter_context->interpreters.WithLock([&](auto &interpreters) { interpreters.insert(&interpreter); });
auto del = sch.Delete("sc1");
ASSERT_TRUE(del.HasError());
ASSERT_EQ(del.GetError(), memgraph::dbms::DeleteError::USING);
sc.interpreter_context->interpreters.WithLock([&](auto &interpreters) { interpreters.erase(&interpreter); });
}
{
// Interpreter deactivated, so we should be able to delete
auto del = sch.Delete("sc1");
ASSERT_FALSE(del.HasError()) << (int)del.GetError();
}
{
ASSERT_TRUE(sch.Delete(ti0));
auto del = sch.Delete("sc3");
ASSERT_FALSE(del.HasError());
ASSERT_FALSE(std::filesystem::exists(storage_directory / "databases" / "sc3"));
}
ASSERT_TRUE(sch.Delete(ti1));
}
int main(int argc, char *argv[]) {
::testing::InitGoogleTest(&argc, argv);
// gtest takes ownership of the TestEnvironment ptr - we don't delete it.
::testing::AddGlobalTestEnvironment(new TestEnvironment);
return RUN_ALL_TESTS();
}
#endif

View File

@ -26,12 +26,14 @@
#include "query/config.hpp"
#include "query/exceptions.hpp"
#include "query/interpreter.hpp"
#include "query/interpreter_context.hpp"
#include "query/stream.hpp"
#include "query/typed_value.hpp"
#include "query_common.hpp"
#include "storage/v2/inmemory/storage.hpp"
#include "storage/v2/isolation_level.hpp"
#include "storage/v2/property_value.hpp"
#include "storage/v2/storage_mode.hpp"
#include "utils/logging.hpp"
namespace {
@ -49,20 +51,43 @@ auto ToEdgeList(const memgraph::communication::bolt::Value &v) {
// TODO: This is not a unit test, but tests/integration dir is chaotic at the
// moment. After tests refactoring is done, move/rename this.
constexpr auto kNoHandler = nullptr;
template <typename StorageType>
class InterpreterTest : public ::testing::Test {
public:
const std::string testSuite = "interpreter";
const std::string testSuiteCsv = "interpreter_csv";
std::filesystem::path data_directory = std::filesystem::temp_directory_path() / "MG_tests_unit_interpreter";
InterpreterTest()
: data_directory(std::filesystem::temp_directory_path() / "MG_tests_unit_interpreter"),
interpreter_context(std::make_unique<StorageType>(disk_test_utils::GenerateOnDiskConfig(testSuite)), {},
data_directory) {
memgraph::flags::run_time::execution_timeout_sec_ = 600.0;
}
InterpreterTest() : interpreter_context({}, kNoHandler) { memgraph::flags::run_time::execution_timeout_sec_ = 600.0; }
memgraph::utils::Gatekeeper<memgraph::dbms::Database> db_gk{
[&]() {
memgraph::storage::Config config{};
config.durability.storage_directory = data_directory;
config.disk.main_storage_directory = config.durability.storage_directory / "disk";
if constexpr (std::is_same_v<StorageType, memgraph::storage::DiskStorage>) {
config.disk = disk_test_utils::GenerateOnDiskConfig(testSuite).disk;
config.force_on_disk = true;
}
return config;
}() // iile
};
memgraph::dbms::DatabaseAccess db{
[&]() {
auto db_acc_opt = db_gk.access();
MG_ASSERT(db_acc_opt, "Failed to access db");
auto &db_acc = *db_acc_opt;
MG_ASSERT(db_acc->GetStorageMode() == (std::is_same_v<StorageType, memgraph::storage::DiskStorage>
? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL),
"Wrong storage mode!");
return db_acc;
}() // iile
};
std::filesystem::path data_directory;
memgraph::query::InterpreterContext interpreter_context;
void TearDown() override {
@ -72,7 +97,7 @@ class InterpreterTest : public ::testing::Test {
}
}
InterpreterFaker default_interpreter{&interpreter_context};
InterpreterFaker default_interpreter{&interpreter_context, db};
auto Prepare(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) {
return default_interpreter.Prepare(query, params);
@ -320,7 +345,7 @@ TYPED_TEST(InterpreterTest, Bfs) {
// Set up.
{
auto storage_dba = this->interpreter_context.db->Access();
auto storage_dba = this->db->Access();
memgraph::query::DbAccessor dba(storage_dba.get());
auto add_node = [&](int level, bool reachable) {
auto node = dba.InsertVertex();
@ -633,7 +658,7 @@ TYPED_TEST(InterpreterTest, UniqueConstraintTest) {
}
TYPED_TEST(InterpreterTest, ExplainQuery) {
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(this->db->plan_cache()->size(), 0U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U);
auto stream = this->Interpret("EXPLAIN MATCH (n) RETURN *;");
ASSERT_EQ(stream.GetHeader().size(), 1U);
@ -647,16 +672,16 @@ TYPED_TEST(InterpreterTest, ExplainQuery) {
++expected_it;
}
// We should have a plan cache for MATCH ...
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
// We should have AST cache for EXPLAIN ... and for inner MATCH ...
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
this->Interpret("MATCH (n) RETURN *;");
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
}
TYPED_TEST(InterpreterTest, ExplainQueryMultiplePulls) {
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(this->db->plan_cache()->size(), 0U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U);
auto [stream, qid] = this->Prepare("EXPLAIN MATCH (n) RETURN *;");
ASSERT_EQ(stream.GetHeader().size(), 1U);
@ -680,16 +705,16 @@ TYPED_TEST(InterpreterTest, ExplainQueryMultiplePulls) {
ASSERT_EQ(stream.GetResults()[2].size(), 1U);
EXPECT_EQ(stream.GetResults()[2].front().ValueString(), *expected_it);
// We should have a plan cache for MATCH ...
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
// We should have AST cache for EXPLAIN ... and for inner MATCH ...
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
this->Interpret("MATCH (n) RETURN *;");
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
}
TYPED_TEST(InterpreterTest, ExplainQueryInMulticommandTransaction) {
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(this->db->plan_cache()->size(), 0U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U);
this->Interpret("BEGIN");
auto stream = this->Interpret("EXPLAIN MATCH (n) RETURN *;");
@ -705,16 +730,16 @@ TYPED_TEST(InterpreterTest, ExplainQueryInMulticommandTransaction) {
++expected_it;
}
// We should have a plan cache for MATCH ...
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
// We should have AST cache for EXPLAIN ... and for inner MATCH ...
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
this->Interpret("MATCH (n) RETURN *;");
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
}
TYPED_TEST(InterpreterTest, ExplainQueryWithParams) {
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(this->db->plan_cache()->size(), 0U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U);
auto stream =
this->Interpret("EXPLAIN MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue(42)}});
@ -729,16 +754,16 @@ TYPED_TEST(InterpreterTest, ExplainQueryWithParams) {
++expected_it;
}
// We should have a plan cache for MATCH ...
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
// We should have AST cache for EXPLAIN ... and for inner MATCH ...
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
this->Interpret("MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue("something else")}});
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
}
TYPED_TEST(InterpreterTest, ProfileQuery) {
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(this->db->plan_cache()->size(), 0U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U);
auto stream = this->Interpret("PROFILE MATCH (n) RETURN *;");
std::vector<std::string> expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"};
@ -752,16 +777,16 @@ TYPED_TEST(InterpreterTest, ProfileQuery) {
++expected_it;
}
// We should have a plan cache for MATCH ...
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
// We should have AST cache for PROFILE ... and for inner MATCH ...
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
this->Interpret("MATCH (n) RETURN *;");
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
}
TYPED_TEST(InterpreterTest, ProfileQueryMultiplePulls) {
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(this->db->plan_cache()->size(), 0U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U);
auto [stream, qid] = this->Prepare("PROFILE MATCH (n) RETURN *;");
std::vector<std::string> expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"};
@ -788,11 +813,11 @@ TYPED_TEST(InterpreterTest, ProfileQueryMultiplePulls) {
ASSERT_EQ(stream.GetResults()[2][0].ValueString(), *expected_it);
// We should have a plan cache for MATCH ...
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
// We should have AST cache for PROFILE ... and for inner MATCH ...
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
this->Interpret("MATCH (n) RETURN *;");
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
}
@ -803,7 +828,7 @@ TYPED_TEST(InterpreterTest, ProfileQueryInMulticommandTransaction) {
}
TYPED_TEST(InterpreterTest, ProfileQueryWithParams) {
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(this->db->plan_cache()->size(), 0U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U);
auto stream =
this->Interpret("PROFILE MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue(42)}});
@ -818,16 +843,16 @@ TYPED_TEST(InterpreterTest, ProfileQueryWithParams) {
++expected_it;
}
// We should have a plan cache for MATCH ...
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
// We should have AST cache for PROFILE ... and for inner MATCH ...
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
this->Interpret("MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue("something else")}});
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
}
TYPED_TEST(InterpreterTest, ProfileQueryWithLiterals) {
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(this->db->plan_cache()->size(), 0U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U);
auto stream = this->Interpret("PROFILE UNWIND range(1, 1000) AS x CREATE (:Node {id: x});", {});
std::vector<std::string> expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"};
@ -841,11 +866,11 @@ TYPED_TEST(InterpreterTest, ProfileQueryWithLiterals) {
++expected_it;
}
// We should have a plan cache for UNWIND ...
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
// We should have AST cache for PROFILE ... and for inner UNWIND ...
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
this->Interpret("UNWIND range(42, 4242) AS x CREATE (:Node {id: x});", {});
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U);
}
@ -1071,7 +1096,7 @@ TYPED_TEST(InterpreterTest, CacheableQueries) {
SCOPED_TRACE("Cacheable query");
this->Interpret("RETURN 1");
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 1U);
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
}
{
@ -1080,7 +1105,7 @@ TYPED_TEST(InterpreterTest, CacheableQueries) {
// result signature could be changed
this->Interpret("CALL mg.load_all()");
EXPECT_EQ(this->interpreter_context.ast_cache.size(), 1U);
EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U);
EXPECT_EQ(this->db->plan_cache()->size(), 1U);
}
}
@ -1098,11 +1123,25 @@ TYPED_TEST(InterpreterTest, AllowLoadCsvConfig) {
"CREATE TRIGGER trigger ON CREATE BEFORE COMMIT EXECUTE LOAD CSV FROM 'file.csv' WITH HEADER AS row RETURN "
"row"};
memgraph::query::InterpreterContext csv_interpreter_context{
std::make_unique<TypeParam>(disk_test_utils::GenerateOnDiskConfig(this->testSuiteCsv)),
{.query = {.allow_load_csv = allow_load_csv}},
directory_manager.Path()};
InterpreterFaker interpreter_faker{&csv_interpreter_context};
memgraph::storage::Config config2{};
config2.durability.storage_directory = directory_manager.Path();
config2.disk.main_storage_directory = config2.durability.storage_directory / "disk";
if constexpr (std::is_same_v<TypeParam, memgraph::storage::DiskStorage>) {
config2.disk = disk_test_utils::GenerateOnDiskConfig(this->testSuiteCsv).disk;
config2.force_on_disk = true;
}
memgraph::utils::Gatekeeper<memgraph::dbms::Database> db_gk2(config2);
auto db_acc_opt = db_gk2.access();
ASSERT_TRUE(db_acc_opt) << "Failed to access db2";
auto &db_acc = *db_acc_opt;
ASSERT_TRUE(db_acc->GetStorageMode() == (std::is_same_v<TypeParam, memgraph::storage::DiskStorage>
? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL))
<< "Wrong storage mode!";
memgraph::query::InterpreterContext csv_interpreter_context{{.query = {.allow_load_csv = allow_load_csv}}, nullptr};
InterpreterFaker interpreter_faker{&csv_interpreter_context, db_acc};
for (const auto &query : queries) {
if (allow_load_csv) {
SCOPED_TRACE(fmt::format("'{}' should not throw because LOAD CSV is allowed", query));

View File

@ -11,16 +11,17 @@
#include "communication/result_stream_faker.hpp"
#include "query/interpreter.hpp"
#include "query/interpreter_context.hpp"
struct InterpreterFaker {
InterpreterFaker(memgraph::query::InterpreterContext *interpreter_context)
: interpreter_context(interpreter_context), interpreter(interpreter_context) {
InterpreterFaker(memgraph::query::InterpreterContext *interpreter_context, memgraph::dbms::DatabaseAccess db)
: interpreter_context(interpreter_context), interpreter(interpreter_context, db) {
interpreter_context->auth_checker = &auth_checker;
interpreter_context->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter); });
}
auto Prepare(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) {
ResultStreamFaker stream(interpreter_context->db.get());
ResultStreamFaker stream(interpreter.db_acc_->get()->storage());
const auto [header, _1, qid, _2] = interpreter.Prepare(query, params, nullptr);
stream.Header(header);
return std::make_pair(std::move(stream), qid);

View File

@ -17,10 +17,13 @@
#include <vector>
#include "communication/result_stream_faker.hpp"
#include "dbms/database.hpp"
#include "disk_test_utils.hpp"
#include "query/config.hpp"
#include "query/dump.hpp"
#include "query/interpreter.hpp"
#include "query/interpreter_context.hpp"
#include "query/stream/streams.hpp"
#include "query/typed_value.hpp"
#include "storage/v2/disk/storage.hpp"
#include "storage/v2/edge_accessor.hpp"
@ -205,9 +208,10 @@ DatabaseState GetState(memgraph::storage::Storage *db) {
return {vertices, edges, label_indices, label_property_indices, existence_constraints, unique_constraints};
}
auto Execute(memgraph::query::InterpreterContext *context, const std::string &query) {
memgraph::query::Interpreter interpreter(context);
ResultStreamFaker stream(context->db.get());
auto Execute(memgraph::query::InterpreterContext *context, memgraph::dbms::DatabaseAccess db,
const std::string &query) {
memgraph::query::Interpreter interpreter(context, db);
ResultStreamFaker stream(db->storage());
auto [header, _1, qid, _2] = interpreter.Prepare(query, {}, nullptr);
stream.Header(header);
@ -275,14 +279,40 @@ class DumpTest : public ::testing::Test {
public:
const std::string testSuite = "query_dump";
std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_query_dump_class"};
memgraph::query::InterpreterContext context{
std::make_unique<StorageType>(disk_test_utils::GenerateOnDiskConfig(testSuite)),
memgraph::query::InterpreterConfig{}, data_directory};
memgraph::utils::Gatekeeper<memgraph::dbms::Database> db_gk{
[&]() {
memgraph::storage::Config config{};
config.durability.storage_directory = data_directory;
config.disk.main_storage_directory = config.durability.storage_directory / "disk";
if constexpr (std::is_same_v<StorageType, memgraph::storage::DiskStorage>) {
config.disk = disk_test_utils::GenerateOnDiskConfig(testSuite).disk;
config.force_on_disk = true;
}
return config;
}() // iile
};
memgraph::dbms::DatabaseAccess db{
[&]() {
auto db_acc_opt = db_gk.access();
MG_ASSERT(db_acc_opt, "Failed to access db");
auto &db_acc = *db_acc_opt;
MG_ASSERT(db_acc->GetStorageMode() == (std::is_same_v<StorageType, memgraph::storage::DiskStorage>
? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL),
"Wrong storage mode!");
return db_acc;
}() // iile
};
memgraph::query::InterpreterContext context{memgraph::query::InterpreterConfig{}, nullptr};
void TearDown() override {
if (std::is_same<StorageType, memgraph::storage::DiskStorage>::value) {
disk_test_utils::RemoveRocksDbDirs(testSuite);
}
std::filesystem::remove_all(data_directory);
}
};
@ -291,10 +321,10 @@ TYPED_TEST_CASE(DumpTest, StorageTypes);
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, EmptyGraph) {
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -304,16 +334,16 @@ TYPED_TEST(DumpTest, EmptyGraph) {
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, SingleVertex) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
CreateVertex(dba.get(), {}, {}, false);
ASSERT_FALSE(dba->Commit().HasError());
}
{
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -325,16 +355,16 @@ TYPED_TEST(DumpTest, SingleVertex) {
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, VertexWithSingleLabel) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
CreateVertex(dba.get(), {"Label1"}, {}, false);
ASSERT_FALSE(dba->Commit().HasError());
}
{
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -346,16 +376,16 @@ TYPED_TEST(DumpTest, VertexWithSingleLabel) {
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, VertexWithMultipleLabels) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
CreateVertex(dba.get(), {"Label1", "Label 2"}, {}, false);
ASSERT_FALSE(dba->Commit().HasError());
}
{
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -368,16 +398,16 @@ TYPED_TEST(DumpTest, VertexWithMultipleLabels) {
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, VertexWithSingleProperty) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
CreateVertex(dba.get(), {}, {{"prop", memgraph::storage::PropertyValue(42)}}, false);
ASSERT_FALSE(dba->Commit().HasError());
}
{
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -389,7 +419,7 @@ TYPED_TEST(DumpTest, VertexWithSingleProperty) {
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, MultipleVertices) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
CreateVertex(dba.get(), {}, {}, false);
CreateVertex(dba.get(), {}, {}, false);
CreateVertex(dba.get(), {}, {}, false);
@ -397,10 +427,10 @@ TYPED_TEST(DumpTest, MultipleVertices) {
}
{
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -412,7 +442,7 @@ TYPED_TEST(DumpTest, MultipleVertices) {
TYPED_TEST(DumpTest, PropertyValue) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
auto null_value = memgraph::storage::PropertyValue();
auto int_value = memgraph::storage::PropertyValue(13);
auto bool_value = memgraph::storage::PropertyValue(true);
@ -435,10 +465,10 @@ TYPED_TEST(DumpTest, PropertyValue) {
}
{
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -454,7 +484,7 @@ TYPED_TEST(DumpTest, PropertyValue) {
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, SingleEdge) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
auto u = CreateVertex(dba.get(), {}, {}, false);
auto v = CreateVertex(dba.get(), {}, {}, false);
CreateEdge(dba.get(), &u, &v, "EdgeType", {}, false);
@ -462,10 +492,10 @@ TYPED_TEST(DumpTest, SingleEdge) {
}
{
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -480,7 +510,7 @@ TYPED_TEST(DumpTest, SingleEdge) {
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, MultipleEdges) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
auto u = CreateVertex(dba.get(), {}, {}, false);
auto v = CreateVertex(dba.get(), {}, {}, false);
auto w = CreateVertex(dba.get(), {}, {}, false);
@ -491,10 +521,10 @@ TYPED_TEST(DumpTest, MultipleEdges) {
}
{
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -513,7 +543,7 @@ TYPED_TEST(DumpTest, MultipleEdges) {
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, EdgeWithProperties) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
auto u = CreateVertex(dba.get(), {}, {}, false);
auto v = CreateVertex(dba.get(), {}, {}, false);
CreateEdge(dba.get(), &u, &v, "EdgeType", {{"prop", memgraph::storage::PropertyValue(13)}}, false);
@ -521,10 +551,10 @@ TYPED_TEST(DumpTest, EdgeWithProperties) {
}
{
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -539,22 +569,24 @@ TYPED_TEST(DumpTest, EdgeWithProperties) {
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, IndicesKeys) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
CreateVertex(dba.get(), {"Label1", "Label 2"}, {{"p", memgraph::storage::PropertyValue(1)}}, false);
ASSERT_FALSE(dba->Commit().HasError());
}
ASSERT_FALSE(
this->context.db->CreateIndex(this->context.db->NameToLabel("Label1"), this->context.db->NameToProperty("prop"))
this->db->storage()
->CreateIndex(this->db->storage()->NameToLabel("Label1"), this->db->storage()->NameToProperty("prop"))
.HasError());
ASSERT_FALSE(
this->db->storage()
->CreateIndex(this->db->storage()->NameToLabel("Label 2"), this->db->storage()->NameToProperty("prop `"))
.HasError());
ASSERT_FALSE(this->context.db
->CreateIndex(this->context.db->NameToLabel("Label 2"), this->context.db->NameToProperty("prop `"))
.HasError());
{
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -567,21 +599,21 @@ TYPED_TEST(DumpTest, IndicesKeys) {
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, ExistenceConstraints) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
CreateVertex(dba.get(), {"L`abel 1"}, {{"prop", memgraph::storage::PropertyValue(1)}}, false);
ASSERT_FALSE(dba->Commit().HasError());
}
{
auto res = this->context.db->CreateExistenceConstraint(this->context.db->NameToLabel("L`abel 1"),
this->context.db->NameToProperty("prop"), {});
auto res = this->db->storage()->CreateExistenceConstraint(this->db->storage()->NameToLabel("L`abel 1"),
this->db->storage()->NameToProperty("prop"), {});
ASSERT_FALSE(res.HasError());
}
{
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -593,7 +625,7 @@ TYPED_TEST(DumpTest, ExistenceConstraints) {
TYPED_TEST(DumpTest, UniqueConstraints) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
CreateVertex(dba.get(), {"Label"},
{{"prop", memgraph::storage::PropertyValue(1)}, {"prop2", memgraph::storage::PropertyValue(2)}},
false);
@ -603,18 +635,18 @@ TYPED_TEST(DumpTest, UniqueConstraints) {
ASSERT_FALSE(dba->Commit().HasError());
}
{
auto res = this->context.db->CreateUniqueConstraint(
this->context.db->NameToLabel("Label"),
{this->context.db->NameToProperty("prop"), this->context.db->NameToProperty("prop2")}, {});
auto res = this->db->storage()->CreateUniqueConstraint(
this->db->storage()->NameToLabel("Label"),
{this->db->storage()->NameToProperty("prop"), this->db->storage()->NameToProperty("prop2")}, {});
ASSERT_TRUE(res.HasValue());
ASSERT_EQ(res.GetValue(), memgraph::storage::UniqueConstraints::CreationStatus::SUCCESS);
}
{
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -633,7 +665,7 @@ TYPED_TEST(DumpTest, UniqueConstraints) {
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, CheckStateVertexWithMultipleProperties) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
std::map<std::string, memgraph::storage::PropertyValue> prop1 = {
{"nested1", memgraph::storage::PropertyValue(1337)}, {"nested2", memgraph::storage::PropertyValue(3.14)}};
@ -644,15 +676,30 @@ TYPED_TEST(DumpTest, CheckStateVertexWithMultipleProperties) {
ASSERT_FALSE(dba->Commit().HasError());
}
auto data_directory = std::filesystem::temp_directory_path() / "MG_tests_unit_query_dump";
memgraph::query::InterpreterContext interpreter_context(std::make_unique<TypeParam>(),
memgraph::query::InterpreterConfig{}, data_directory);
memgraph::storage::Config config{};
config.durability.storage_directory = this->data_directory / "s1";
config.disk.main_storage_directory = config.durability.storage_directory / "disk";
if constexpr (std::is_same_v<TypeParam, memgraph::storage::DiskStorage>) {
config.disk = disk_test_utils::GenerateOnDiskConfig("query-dump-s1").disk;
config.force_on_disk = true;
}
memgraph::utils::Gatekeeper<memgraph::dbms::Database> db_gk(config);
auto db_acc_opt = db_gk.access();
ASSERT_TRUE(db_acc_opt) << "Failed to access db";
auto &db_acc = *db_acc_opt;
ASSERT_TRUE(db_acc->GetStorageMode() == (std::is_same_v<TypeParam, memgraph::storage::DiskStorage>
? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL))
<< "Wrong storage mode!";
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr);
{
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -661,7 +708,7 @@ TYPED_TEST(DumpTest, CheckStateVertexWithMultipleProperties) {
for (const auto &item : results) {
ASSERT_EQ(item.size(), 1);
ASSERT_TRUE(item[0].IsString());
Execute(&interpreter_context, item[0].ValueString());
Execute(&interpreter_context, db_acc, item[0].ValueString());
}
}
}
@ -669,7 +716,7 @@ TYPED_TEST(DumpTest, CheckStateVertexWithMultipleProperties) {
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, CheckStateSimpleGraph) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
auto u = CreateVertex(dba.get(), {"Person"}, {{"name", memgraph::storage::PropertyValue("Ivan")}});
auto v = CreateVertex(dba.get(), {"Person"}, {{"name", memgraph::storage::PropertyValue("Josko")}});
auto w = CreateVertex(
@ -710,33 +757,48 @@ TYPED_TEST(DumpTest, CheckStateSimpleGraph) {
ASSERT_FALSE(dba->Commit().HasError());
}
{
auto ret = this->context.db->CreateExistenceConstraint(this->context.db->NameToLabel("Person"),
this->context.db->NameToProperty("name"), {});
auto ret = this->db->storage()->CreateExistenceConstraint(this->db->storage()->NameToLabel("Person"),
this->db->storage()->NameToProperty("name"), {});
ASSERT_FALSE(ret.HasError());
}
{
auto ret = this->context.db->CreateUniqueConstraint(this->context.db->NameToLabel("Person"),
{this->context.db->NameToProperty("name")}, {});
auto ret = this->db->storage()->CreateUniqueConstraint(this->db->storage()->NameToLabel("Person"),
{this->db->storage()->NameToProperty("name")}, {});
ASSERT_TRUE(ret.HasValue());
ASSERT_EQ(ret.GetValue(), memgraph::storage::UniqueConstraints::CreationStatus::SUCCESS);
}
ASSERT_FALSE(
this->context.db->CreateIndex(this->context.db->NameToLabel("Person"), this->context.db->NameToProperty("id"))
.HasError());
ASSERT_FALSE(this->context.db
->CreateIndex(this->context.db->NameToLabel("Person"),
this->context.db->NameToProperty("unexisting_property"))
ASSERT_FALSE(this->db->storage()
->CreateIndex(this->db->storage()->NameToLabel("Person"), this->db->storage()->NameToProperty("id"))
.HasError());
ASSERT_FALSE(this->db->storage()
->CreateIndex(this->db->storage()->NameToLabel("Person"),
this->db->storage()->NameToProperty("unexisting_property"))
.HasError());
const auto &db_initial_state = GetState(this->context.db.get());
auto data_directory = std::filesystem::temp_directory_path() / "MG_tests_unit_query_dump";
memgraph::query::InterpreterContext interpreter_context(std::make_unique<TypeParam>(),
memgraph::query::InterpreterConfig{}, data_directory);
const auto &db_initial_state = GetState(this->db->storage());
memgraph::storage::Config config{};
config.durability.storage_directory = this->data_directory / "s2";
config.disk.main_storage_directory = config.durability.storage_directory / "disk";
if constexpr (std::is_same_v<TypeParam, memgraph::storage::DiskStorage>) {
config.disk = disk_test_utils::GenerateOnDiskConfig("query-dump-s2").disk;
config.force_on_disk = true;
}
memgraph::utils::Gatekeeper<memgraph::dbms::Database> db_gk(config);
auto db_acc_opt = db_gk.access();
ASSERT_TRUE(db_acc_opt) << "Failed to access db";
auto &db_acc = *db_acc_opt;
ASSERT_TRUE(db_acc->GetStorageMode() == (std::is_same_v<TypeParam, memgraph::storage::DiskStorage>
? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL))
<< "Wrong storage mode!";
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr);
{
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
{
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream);
}
@ -749,23 +811,23 @@ TYPED_TEST(DumpTest, CheckStateSimpleGraph) {
ASSERT_EQ(item.size(), 1);
ASSERT_TRUE(item[0].IsString());
spdlog::debug("Query: {}", item[0].ValueString());
Execute(&interpreter_context, item[0].ValueString());
Execute(&interpreter_context, db_acc, item[0].ValueString());
++i;
}
}
ASSERT_EQ(GetState(this->context.db.get()), db_initial_state);
ASSERT_EQ(GetState(this->db->storage()), db_initial_state);
}
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, ExecuteDumpDatabase) {
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
CreateVertex(dba.get(), {}, {}, false);
ASSERT_FALSE(dba->Commit().HasError());
}
{
auto stream = Execute(&this->context, "DUMP DATABASE");
auto stream = Execute(&this->context, this->db, "DUMP DATABASE");
const auto &header = stream.GetHeader();
const auto &results = stream.GetResults();
ASSERT_EQ(header.size(), 1U);
@ -784,11 +846,11 @@ TYPED_TEST(DumpTest, ExecuteDumpDatabase) {
class StatefulInterpreter {
public:
explicit StatefulInterpreter(memgraph::query::InterpreterContext *context)
: context_(context), interpreter_(context_) {}
explicit StatefulInterpreter(memgraph::query::InterpreterContext *context, memgraph::dbms::DatabaseAccess db)
: context_(context), interpreter_(context_, db) {}
auto Execute(const std::string &query) {
ResultStreamFaker stream(context_->db.get());
ResultStreamFaker stream(interpreter_.db_acc_->get()->storage());
auto [header, _1, qid, _2] = interpreter_.Prepare(query, {}, nullptr);
stream.Header(header);
@ -799,18 +861,13 @@ class StatefulInterpreter {
}
private:
static const std::filesystem::path data_directory_;
memgraph::query::InterpreterContext *context_;
memgraph::query::Interpreter interpreter_;
};
const std::filesystem::path StatefulInterpreter::data_directory_{std::filesystem::temp_directory_path() /
"MG_tests_unit_query_dump_stateful"};
// NOLINTNEXTLINE(hicpp-special-member-functions)
TYPED_TEST(DumpTest, ExecuteDumpDatabaseInMulticommandTransaction) {
StatefulInterpreter interpreter(&this->context);
StatefulInterpreter interpreter(&this->context, this->db);
// Begin the transaction before the vertex is created.
interpreter.Execute("BEGIN");
@ -827,7 +884,7 @@ TYPED_TEST(DumpTest, ExecuteDumpDatabaseInMulticommandTransaction) {
// Create the vertex.
{
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
CreateVertex(dba.get(), {}, {}, false);
ASSERT_FALSE(dba->Commit().HasError());
}
@ -875,39 +932,41 @@ TYPED_TEST(DumpTest, MultiplePartialPulls) {
{
// Create indices
ASSERT_FALSE(
this->context.db->CreateIndex(this->context.db->NameToLabel("PERSON"), this->context.db->NameToProperty("name"))
this->db->storage()
->CreateIndex(this->db->storage()->NameToLabel("PERSON"), this->db->storage()->NameToProperty("name"))
.HasError());
ASSERT_FALSE(
this->db->storage()
->CreateIndex(this->db->storage()->NameToLabel("PERSON"), this->db->storage()->NameToProperty("surname"))
.HasError());
ASSERT_FALSE(this->context.db
->CreateIndex(this->context.db->NameToLabel("PERSON"), this->context.db->NameToProperty("surname"))
.HasError());
// Create existence constraints
{
auto res = this->context.db->CreateExistenceConstraint(this->context.db->NameToLabel("PERSON"),
this->context.db->NameToProperty("name"), {});
auto res = this->db->storage()->CreateExistenceConstraint(this->db->storage()->NameToLabel("PERSON"),
this->db->storage()->NameToProperty("name"), {});
ASSERT_FALSE(res.HasError());
}
{
auto res = this->context.db->CreateExistenceConstraint(this->context.db->NameToLabel("PERSON"),
this->context.db->NameToProperty("surname"), {});
auto res = this->db->storage()->CreateExistenceConstraint(this->db->storage()->NameToLabel("PERSON"),
this->db->storage()->NameToProperty("surname"), {});
ASSERT_FALSE(res.HasError());
}
// Create unique constraints
{
auto res = this->context.db->CreateUniqueConstraint(this->context.db->NameToLabel("PERSON"),
{this->context.db->NameToProperty("name")}, {});
auto res = this->db->storage()->CreateUniqueConstraint(this->db->storage()->NameToLabel("PERSON"),
{this->db->storage()->NameToProperty("name")}, {});
ASSERT_TRUE(res.HasValue());
ASSERT_EQ(res.GetValue(), memgraph::storage::UniqueConstraints::CreationStatus::SUCCESS);
}
{
auto res = this->context.db->CreateUniqueConstraint(this->context.db->NameToLabel("PERSON"),
{this->context.db->NameToProperty("surname")}, {});
auto res = this->db->storage()->CreateUniqueConstraint(this->db->storage()->NameToLabel("PERSON"),
{this->db->storage()->NameToProperty("surname")}, {});
ASSERT_TRUE(res.HasValue());
ASSERT_EQ(res.GetValue(), memgraph::storage::UniqueConstraints::CreationStatus::SUCCESS);
}
auto dba = this->context.db->Access();
auto dba = this->db->storage()->Access();
auto p1 = CreateVertex(dba.get(), {"PERSON"},
{{"name", memgraph::storage::PropertyValue("Person1")},
{"surname", memgraph::storage::PropertyValue("Unique1")}},
@ -935,9 +994,9 @@ TYPED_TEST(DumpTest, MultiplePartialPulls) {
ASSERT_FALSE(dba->Commit().HasError());
}
ResultStreamFaker stream(this->context.db.get());
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());
auto acc = this->context.db->Access();
auto acc = this->db->storage()->Access();
memgraph::query::DbAccessor dba(acc.get());
memgraph::query::PullPlanDump pullPlan{&dba};

View File

@ -23,6 +23,8 @@
#include "communication/result_stream_faker.hpp"
#include "query/interpreter.hpp"
#include "query/interpreter_context.hpp"
#include "query/stream/streams.hpp"
#include "storage/v2/inmemory/storage.hpp"
#include "storage/v2/storage.hpp"
@ -32,24 +34,48 @@ template <typename StorageType>
class QueryExecution : public testing::Test {
protected:
const std::string testSuite = "query_plan_edge_cases";
std::optional<memgraph::dbms::DatabaseAccess> db_acc_;
std::optional<memgraph::query::InterpreterContext> interpreter_context_;
std::optional<memgraph::query::Interpreter> interpreter_;
std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_query_plan_edge_cases"};
std::optional<memgraph::utils::Gatekeeper<memgraph::dbms::Database>> db_gk{
[&]() {
memgraph::storage::Config config{};
config.durability.storage_directory = data_directory;
config.disk.main_storage_directory = config.durability.storage_directory / "disk";
if constexpr (std::is_same_v<StorageType, memgraph::storage::DiskStorage>) {
config.disk = disk_test_utils::GenerateOnDiskConfig(testSuite).disk;
config.force_on_disk = true;
}
return config;
}() // iile
};
void SetUp() {
interpreter_context_.emplace(std::make_unique<StorageType>(disk_test_utils::GenerateOnDiskConfig(testSuite)),
memgraph::query::InterpreterConfig{}, data_directory);
interpreter_.emplace(&*interpreter_context_);
auto db_acc_opt = db_gk->access();
MG_ASSERT(db_acc_opt, "Failed to access db");
auto &db_acc = *db_acc_opt;
MG_ASSERT(db_acc->GetStorageMode() == (std::is_same_v<StorageType, memgraph::storage::DiskStorage>
? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL),
"Wrong storage mode!");
db_acc_ = std::move(db_acc);
interpreter_context_.emplace(memgraph::query::InterpreterConfig{}, nullptr);
interpreter_.emplace(&*interpreter_context_, *db_acc_);
}
void TearDown() {
interpreter_ = std::nullopt;
interpreter_context_ = std::nullopt;
db_acc_.reset();
db_gk.reset();
if (std::is_same<StorageType, memgraph::storage::DiskStorage>::value) {
disk_test_utils::RemoveRocksDbDirs(testSuite);
}
std::filesystem::remove_all(data_directory);
}
/**
@ -58,7 +84,7 @@ class QueryExecution : public testing::Test {
* Return the query results.
*/
auto Execute(const std::string &query) {
ResultStreamFaker stream(this->interpreter_context_->db.get());
ResultStreamFaker stream(this->db_acc_->get()->storage());
auto [header, _1, qid, _2] = interpreter_->Prepare(query, {}, nullptr);
stream.Header(header);

View File

@ -22,6 +22,7 @@
#include "kafka_mock.hpp"
#include "query/config.hpp"
#include "query/interpreter.hpp"
#include "query/interpreter_context.hpp"
#include "query/stream/streams.hpp"
#include "storage/v2/disk/storage.hpp"
#include "storage/v2/inmemory/storage.hpp"
@ -74,9 +75,32 @@ class StreamsTestFixture : public ::testing::Test {
// Streams constructor.
// InterpreterContext::auth_checker_ is used in the Streams object, but only in the message processing part. Because
// these tests don't send any messages, the auth_checker_ pointer can be left as nullptr.
memgraph::query::InterpreterContext interpreter_context_{
std::make_unique<StorageType>(disk_test_utils::GenerateOnDiskConfig(testSuite)),
memgraph::query::InterpreterConfig{}, data_directory_};
memgraph::utils::Gatekeeper<memgraph::dbms::Database> db_gk{
[&]() {
memgraph::storage::Config config{};
config.durability.storage_directory = data_directory_;
config.disk.main_storage_directory = config.durability.storage_directory / "disk";
if constexpr (std::is_same_v<StorageType, memgraph::storage::DiskStorage>) {
config.disk = disk_test_utils::GenerateOnDiskConfig(testSuite).disk;
config.force_on_disk = true;
}
return config;
}() // iile
};
memgraph::dbms::DatabaseAccess db_{
[&]() {
auto db_acc_opt = db_gk.access();
MG_ASSERT(db_acc_opt, "Failed to access db");
auto &db_acc = *db_acc_opt;
MG_ASSERT(db_acc->GetStorageMode() == (std::is_same_v<StorageType, memgraph::storage::DiskStorage>
? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL),
"Wrong storage mode!");
return db_acc;
}() // iile
};
memgraph::query::InterpreterContext interpreter_context_{memgraph::query::InterpreterConfig{}, nullptr};
std::filesystem::path streams_data_directory_{data_directory_ / "separate-dir-for-test"};
std::optional<StreamsTest> proxyStreams_;
@ -84,11 +108,12 @@ class StreamsTestFixture : public ::testing::Test {
if (std::is_same<StorageType, memgraph::storage::DiskStorage>::value) {
disk_test_utils::RemoveRocksDbDirs(testSuite);
}
std::filesystem::remove_all(data_directory_);
}
void ResetStreamsObject() {
proxyStreams_.emplace();
proxyStreams_->streams_.emplace(&interpreter_context_, streams_data_directory_);
proxyStreams_->streams_.emplace(streams_data_directory_);
}
void CheckStreamStatus(const StreamCheckData &check_data) {
@ -151,8 +176,8 @@ TYPED_TEST_CASE(StreamsTestFixture, StorageTypes);
TYPED_TEST(StreamsTestFixture, SimpleStreamManagement) {
auto check_data = this->CreateDefaultStreamCheckData();
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(check_data.name, check_data.info,
check_data.owner);
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
check_data.name, check_data.info, check_data.owner, this->db_, &this->interpreter_context_);
EXPECT_NO_FATAL_FAILURE(this->CheckStreamStatus(check_data));
EXPECT_NO_THROW(this->proxyStreams_->streams_->Start(check_data.name));
@ -178,12 +203,12 @@ TYPED_TEST(StreamsTestFixture, SimpleStreamManagement) {
TYPED_TEST(StreamsTestFixture, CreateAlreadyExisting) {
auto stream_info = this->CreateDefaultStreamInfo();
auto stream_name = GetDefaultStreamName();
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(stream_name, stream_info,
std::nullopt);
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_);
try {
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(stream_name, stream_info,
std::nullopt);
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_);
FAIL() << "Creating already existing stream should throw\n";
} catch (memgraph::query::stream::StreamsException &exception) {
EXPECT_EQ(exception.what(), fmt::format("Stream already exists with name '{}'", stream_name));
@ -194,8 +219,8 @@ TYPED_TEST(StreamsTestFixture, DropNotExistingStream) {
const auto stream_info = this->CreateDefaultStreamInfo();
const auto stream_name = GetDefaultStreamName();
const std::string not_existing_stream_name{"ThisDoesn'tExists"};
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(stream_name, stream_info,
std::nullopt);
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_);
try {
this->proxyStreams_->streams_->Drop(not_existing_stream_name);
@ -250,7 +275,7 @@ TYPED_TEST(StreamsTestFixture, RestoreStreams) {
// Reset the Streams object to trigger reloading
this->ResetStreamsObject();
EXPECT_TRUE(this->proxyStreams_->streams_->GetStreamInfo().empty());
this->proxyStreams_->streams_->RestoreStreams();
this->proxyStreams_->streams_->RestoreStreams(this->db_, &this->interpreter_context_);
EXPECT_EQ(stream_check_datas.size(), this->proxyStreams_->streams_->GetStreamInfo().size());
for (const auto &check_data : stream_check_datas) {
ASSERT_NO_FATAL_FAILURE(this->CheckStreamStatus(check_data));
@ -258,12 +283,12 @@ TYPED_TEST(StreamsTestFixture, RestoreStreams) {
}
};
this->proxyStreams_->streams_->RestoreStreams();
this->proxyStreams_->streams_->RestoreStreams(this->db_, &this->interpreter_context_);
EXPECT_TRUE(this->proxyStreams_->streams_->GetStreamInfo().empty());
for (auto &check_data : stream_check_datas) {
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
check_data.name, check_data.info, check_data.owner);
check_data.name, check_data.info, check_data.owner, this->db_, &this->interpreter_context_);
}
{
SCOPED_TRACE("After streams are created");
@ -299,13 +324,13 @@ TYPED_TEST(StreamsTestFixture, RestoreStreams) {
TYPED_TEST(StreamsTestFixture, CheckWithTimeout) {
const auto stream_info = this->CreateDefaultStreamInfo();
const auto stream_name = GetDefaultStreamName();
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(stream_name, stream_info,
std::nullopt);
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_);
std::chrono::milliseconds timeout{3000};
const auto start = std::chrono::steady_clock::now();
EXPECT_THROW(this->proxyStreams_->streams_->Check(stream_name, timeout, std::nullopt),
EXPECT_THROW(this->proxyStreams_->streams_->Check(stream_name, this->db_, timeout, std::nullopt),
memgraph::integrations::kafka::ConsumerCheckFailedException);
const auto end = std::chrono::steady_clock::now();
@ -325,7 +350,7 @@ TYPED_TEST(StreamsTestFixture, CheckInvalidConfig) {
EXPECT_TRUE(message.find(kConfigValue) != std::string::npos) << message;
};
EXPECT_THROW_WITH_MSG(this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
stream_name, stream_info, std::nullopt),
stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_),
memgraph::integrations::kafka::SettingCustomConfigFailed, checker);
}
@ -341,6 +366,6 @@ TYPED_TEST(StreamsTestFixture, CheckInvalidCredentials) {
EXPECT_TRUE(message.find(kCredentialValue) == std::string::npos) << message;
};
EXPECT_THROW_WITH_MSG(this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
stream_name, stream_info, std::nullopt),
stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_),
memgraph::integrations::kafka::SettingCustomConfigFailed, checker);
}

View File

@ -18,6 +18,7 @@
#include "interpreter_faker.hpp"
#include "query/exceptions.hpp"
#include "query/interpreter_context.hpp"
#include "storage/v2/inmemory/storage.hpp"
#include "storage/v2/isolation_level.hpp"
#include "storage/v2/storage_mode.hpp"
@ -68,10 +69,26 @@ INSTANTIATE_TEST_CASE_P(ParameterizedStorageModeTests, StorageModeTest, ::testin
class StorageModeMultiTxTest : public ::testing::Test {
protected:
std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_storage_mode"};
memgraph::query::InterpreterContext interpreter_context{
std::make_unique<memgraph::storage::InMemoryStorage>(), {}, data_directory};
InterpreterFaker running_interpreter{&interpreter_context}, main_interpreter{&interpreter_context};
std::filesystem::path data_directory = []() {
const auto tmp = std::filesystem::temp_directory_path() / "MG_tests_unit_storage_mode";
std::filesystem::remove_all(tmp);
return tmp;
}(); // iile
memgraph::utils::Gatekeeper<memgraph::dbms::Database> db_gk{memgraph::storage::Config{
.durability.storage_directory = data_directory, .disk.main_storage_directory = data_directory / "disk"}};
memgraph::dbms::DatabaseAccess db{
[&]() {
auto db_acc_opt = db_gk.access();
auto &db_acc = *db_acc_opt;
MG_ASSERT(db_acc, "Failed to access db");
return db_acc;
}() // iile
};
memgraph::query::InterpreterContext interpreter_context{{}, nullptr};
InterpreterFaker running_interpreter{&interpreter_context, db}, main_interpreter{&interpreter_context, db};
};
TEST_F(StorageModeMultiTxTest, ModeSwitchInactiveTransaction) {
@ -87,11 +104,11 @@ TEST_F(StorageModeMultiTxTest, ModeSwitchInactiveTransaction) {
while (!started) {
std::this_thread::sleep_for(std::chrono::milliseconds(20));
}
ASSERT_EQ(interpreter_context.db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL);
ASSERT_EQ(db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL);
main_interpreter.Interpret("STORAGE MODE IN_MEMORY_ANALYTICAL");
// should change state
ASSERT_EQ(interpreter_context.db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_ANALYTICAL);
ASSERT_EQ(db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_ANALYTICAL);
// finish thread
running_thread.request_stop();
@ -100,7 +117,7 @@ TEST_F(StorageModeMultiTxTest, ModeSwitchInactiveTransaction) {
TEST_F(StorageModeMultiTxTest, ModeSwitchActiveTransaction) {
// transactional state
ASSERT_EQ(interpreter_context.db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL);
ASSERT_EQ(db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL);
main_interpreter.Interpret("BEGIN");
bool started = false;
@ -119,7 +136,7 @@ TEST_F(StorageModeMultiTxTest, ModeSwitchActiveTransaction) {
std::this_thread::sleep_for(std::chrono::milliseconds(20));
}
// should not change still
ASSERT_EQ(interpreter_context.db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL);
ASSERT_EQ(db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL);
main_interpreter.Interpret("COMMIT");
@ -127,7 +144,7 @@ TEST_F(StorageModeMultiTxTest, ModeSwitchActiveTransaction) {
std::this_thread::sleep_for(std::chrono::milliseconds(20));
}
// should change state
ASSERT_EQ(interpreter_context.db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_ANALYTICAL);
ASSERT_EQ(db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_ANALYTICAL);
// finish thread
running_thread.request_stop();
@ -135,11 +152,11 @@ TEST_F(StorageModeMultiTxTest, ModeSwitchActiveTransaction) {
}
TEST_F(StorageModeMultiTxTest, ErrorChangeIsolationLevel) {
ASSERT_EQ(interpreter_context.db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL);
ASSERT_EQ(db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL);
main_interpreter.Interpret("STORAGE MODE IN_MEMORY_ANALYTICAL");
// should change state
ASSERT_EQ(interpreter_context.db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_ANALYTICAL);
ASSERT_EQ(db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_ANALYTICAL);
ASSERT_THROW(running_interpreter.Interpret("SET GLOBAL TRANSACTION ISOLATION LEVEL READ COMMITTED;"),
memgraph::query::IsolationLevelModificationInAnalyticsException);

View File

@ -19,6 +19,7 @@
#include "disk_test_utils.hpp"
#include "interpreter_faker.hpp"
#include "query/interpreter_context.hpp"
#include "storage/v2/inmemory/storage.hpp"
/*
@ -30,11 +31,38 @@ class TransactionQueueSimpleTest : public ::testing::Test {
protected:
const std::string testSuite = "transactin_queue";
std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_transaction_queue_intr"};
memgraph::query::InterpreterContext interpreter_context{
std::make_unique<StorageType>(disk_test_utils::GenerateOnDiskConfig(testSuite)), {}, data_directory};
InterpreterFaker running_interpreter{&interpreter_context}, main_interpreter{&interpreter_context};
memgraph::utils::Gatekeeper<memgraph::dbms::Database> db_gk{
[&]() {
memgraph::storage::Config config{};
config.durability.storage_directory = data_directory;
config.disk.main_storage_directory = config.durability.storage_directory / "disk";
if constexpr (std::is_same_v<StorageType, memgraph::storage::DiskStorage>) {
config.disk = disk_test_utils::GenerateOnDiskConfig(testSuite).disk;
config.force_on_disk = true;
}
return config;
}() // iile
};
void TearDown() override { disk_test_utils::RemoveRocksDbDirs(testSuite); }
memgraph::dbms::DatabaseAccess db{
[&]() {
auto db_acc_opt = db_gk.access();
MG_ASSERT(db_acc_opt, "Failed to access db");
auto &db_acc = *db_acc_opt;
MG_ASSERT(db_acc->GetStorageMode() == (std::is_same_v<StorageType, memgraph::storage::DiskStorage>
? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL),
"Wrong storage mode!");
return db_acc;
}() // iile
};
memgraph::query::InterpreterContext interpreter_context{{}, nullptr};
InterpreterFaker running_interpreter{&interpreter_context, db}, main_interpreter{&interpreter_context, db};
void TearDown() override {
disk_test_utils::RemoveRocksDbDirs(testSuite);
std::filesystem::remove_all(data_directory);
}
};
using StorageTypes = ::testing::Types<memgraph::storage::InMemoryStorage, memgraph::storage::DiskStorage>;

View File

@ -22,6 +22,7 @@
#include "disk_test_utils.hpp"
#include "interpreter_faker.hpp"
#include "query/exceptions.hpp"
#include "query/interpreter_context.hpp"
#include "storage/v2/config.hpp"
#include "storage/v2/disk/storage.hpp"
#include "storage/v2/inmemory/storage.hpp"
@ -38,14 +39,39 @@ class TransactionQueueMultipleTest : public ::testing::Test {
const std::string testSuite = "transactin_queue_multiple";
std::filesystem::path data_directory{std::filesystem::temp_directory_path() /
"MG_tests_unit_transaction_queue_multiple_intr"};
memgraph::query::InterpreterContext interpreter_context{
std::make_unique<StorageType>(disk_test_utils::GenerateOnDiskConfig(testSuite)), {}, data_directory};
InterpreterFaker main_interpreter{&interpreter_context};
memgraph::utils::Gatekeeper<memgraph::dbms::Database> db_gk{
[&]() {
memgraph::storage::Config config{};
config.durability.storage_directory = data_directory;
config.disk.main_storage_directory = config.durability.storage_directory / "disk";
if constexpr (std::is_same_v<StorageType, memgraph::storage::DiskStorage>) {
config.disk = disk_test_utils::GenerateOnDiskConfig(testSuite).disk;
config.force_on_disk = true;
}
return config;
}() // iile
};
memgraph::dbms::DatabaseAccess db{
[&]() {
auto db_acc_opt = db_gk.access();
MG_ASSERT(db_acc_opt, "Failed to access db");
auto &db_acc = *db_acc_opt;
MG_ASSERT(db_acc->GetStorageMode() == (std::is_same_v<StorageType, memgraph::storage::DiskStorage>
? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL),
"Wrong storage mode!");
return db_acc;
}() // iile
};
memgraph::query::InterpreterContext interpreter_context{{}, nullptr};
InterpreterFaker main_interpreter{&interpreter_context, db};
std::vector<InterpreterFaker *> running_interpreters;
TransactionQueueMultipleTest() {
for (int i = 0; i < NUM_INTERPRETERS; ++i) {
InterpreterFaker *faker = new InterpreterFaker(&interpreter_context);
InterpreterFaker *faker = new InterpreterFaker(&interpreter_context, db);
running_interpreters.push_back(faker);
}
}
@ -55,6 +81,7 @@ class TransactionQueueMultipleTest : public ::testing::Test {
delete running_interpreters[i];
}
disk_test_utils::RemoveRocksDbDirs(testSuite);
std::filesystem::remove_all(data_directory);
}
};

View File

@ -1,296 +0,0 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <gtest/gtest.h>
#include <atomic>
#include <chrono>
#include <thread>
#include <utils/sync_ptr.hpp>
#include "utils/exceptions.hpp"
using namespace std::chrono_literals;
TEST(SyncPtr, Basic) {
std::atomic_bool alive{false};
struct Test {
Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; }
~Test() { alive_ = false; }
std::atomic_bool &alive_;
};
ASSERT_FALSE(alive);
memgraph::utils::SyncPtr<Test> sp(alive);
ASSERT_TRUE(alive);
auto sp_copy1 = sp.get();
auto sp_copy2 = sp.get();
sp_copy1.reset();
ASSERT_TRUE(alive);
sp_copy2.reset();
ASSERT_TRUE(alive);
sp.DestroyAndSync();
ASSERT_FALSE(alive);
}
TEST(SyncPtr, BasicWConfig) {
std::atomic_bool alive{false};
struct Test {
Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; }
~Test() { alive_ = false; }
std::atomic_bool &alive_;
};
struct TestConf {
TestConf(int i) : conf_(i) {}
int conf_;
};
ASSERT_FALSE(alive);
memgraph::utils::SyncPtr<Test, TestConf> sp(123, alive);
ASSERT_TRUE(alive);
ASSERT_EQ(sp.config().conf_, 123);
auto sp_copy1 = sp.get();
auto sp_copy2 = sp.get();
sp_copy1.reset();
ASSERT_TRUE(alive);
sp_copy2.reset();
ASSERT_TRUE(alive);
sp.DestroyAndSync();
ASSERT_FALSE(alive);
}
TEST(SyncPtr, Sync) {
std::atomic_bool alive{false};
struct Test {
Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; }
~Test() { alive_ = false; }
std::atomic_bool &alive_;
};
std::thread th;
ASSERT_FALSE(alive);
memgraph::utils::SyncPtr<Test> sp(alive);
{
using namespace std::chrono_literals;
sp.timeout(10000ms); // 10sec
ASSERT_TRUE(alive);
auto sp_copy1 = sp.get();
auto sp_copy2 = sp.get();
th = std::thread([&alive, p = sp.get()]() mutable {
// Wait for a second and then release the pointer
// SyncPtr will be destroyed in the mean time (and block)
std::this_thread::sleep_for(std::chrono::milliseconds(200));
ASSERT_TRUE(alive);
p.reset();
ASSERT_FALSE(alive);
});
}
ASSERT_TRUE(alive);
sp.DestroyAndSync();
th.join();
ASSERT_FALSE(alive);
}
TEST(SyncPtr, SyncWConfig) {
std::atomic_bool alive{false};
struct Test {
Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; }
~Test() { alive_ = false; }
std::atomic_bool &alive_;
};
struct TestConf {
TestConf(int i) : conf_(i) {}
int conf_;
};
std::thread th;
ASSERT_FALSE(alive);
memgraph::utils::SyncPtr<Test, TestConf> sp(456, alive);
{
using namespace std::chrono_literals;
sp.timeout(10000ms); // 10sec
ASSERT_TRUE(alive);
ASSERT_EQ(sp.config().conf_, 456);
auto sp_copy1 = sp.get();
auto sp_copy2 = sp.get();
th = std::thread([&alive, p = sp.get()]() mutable {
// Wait for a second and then release the pointer
// SyncPtr will be destroyed in the mean time (and block)
std::this_thread::sleep_for(std::chrono::milliseconds(200));
ASSERT_TRUE(alive);
p.reset();
ASSERT_FALSE(alive);
});
}
ASSERT_TRUE(alive);
sp.DestroyAndSync();
th.join();
ASSERT_FALSE(alive);
}
TEST(SyncPtr, Timeout100ms) {
std::atomic_bool alive{false};
struct Test {
Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; }
~Test() { alive_ = false; }
std::atomic_bool &alive_;
};
std::thread th;
ASSERT_FALSE(alive);
memgraph::utils::SyncPtr<Test> sp(alive);
using namespace std::chrono_literals;
sp.timeout(100ms);
ASSERT_TRUE(alive);
auto p = sp.get();
ASSERT_TRUE(alive);
auto start_100ms = std::chrono::system_clock::now();
ASSERT_THROW(sp.DestroyAndSync(), memgraph::utils::BasicException);
auto end_100ms = std::chrono::system_clock::now();
auto delta_100ms = std::chrono::duration_cast<std::chrono::milliseconds>(end_100ms - start_100ms).count();
ASSERT_NEAR(delta_100ms, 100, 100);
p.reset();
ASSERT_FALSE(alive);
}
TEST(SyncPtr, Timeout567ms) {
std::atomic_bool alive{false};
struct Test {
Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; }
~Test() { alive_ = false; }
std::atomic_bool &alive_;
};
std::thread th;
ASSERT_FALSE(alive);
memgraph::utils::SyncPtr<Test> sp(alive);
using namespace std::chrono_literals;
sp.timeout(567ms);
ASSERT_TRUE(alive);
auto p = sp.get();
ASSERT_TRUE(alive);
auto start = std::chrono::system_clock::now();
ASSERT_THROW(sp.DestroyAndSync(), memgraph::utils::BasicException);
auto end = std::chrono::system_clock::now();
auto delta_ms = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
ASSERT_NEAR(delta_ms, 567, 100);
p.reset();
ASSERT_FALSE(alive);
}
TEST(SyncPtr, Timeout100msWConfig) {
std::atomic_bool alive{false};
struct Test {
Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; }
~Test() { alive_ = false; }
std::atomic_bool &alive_;
};
struct TestConf {
TestConf(int i) : conf_(i) {}
int conf_;
};
std::thread th;
ASSERT_FALSE(alive);
memgraph::utils::SyncPtr<Test, TestConf> sp(0, alive);
using namespace std::chrono_literals;
sp.timeout(100ms);
ASSERT_TRUE(alive);
auto p = sp.get();
ASSERT_TRUE(alive);
auto start_100ms = std::chrono::system_clock::now();
ASSERT_THROW(sp.DestroyAndSync(), memgraph::utils::BasicException);
auto end_100ms = std::chrono::system_clock::now();
auto delta_100ms = std::chrono::duration_cast<std::chrono::milliseconds>(end_100ms - start_100ms).count();
ASSERT_NEAR(delta_100ms, 100, 100);
p.reset();
ASSERT_FALSE(alive);
}
TEST(SyncPtr, Timeout567msWConfig) {
std::atomic_bool alive{false};
struct Test {
Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; }
~Test() { alive_ = false; }
std::atomic_bool &alive_;
};
struct TestConf {
TestConf(int i) : conf_(i) {}
int conf_;
};
std::thread th;
ASSERT_FALSE(alive);
memgraph::utils::SyncPtr<Test, TestConf> sp(2, alive);
using namespace std::chrono_literals;
sp.timeout(567ms);
ASSERT_TRUE(alive);
auto p = sp.get();
ASSERT_TRUE(alive);
auto start = std::chrono::system_clock::now();
ASSERT_THROW(sp.DestroyAndSync(), memgraph::utils::BasicException);
auto end = std::chrono::system_clock::now();
auto delta_ms = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
ASSERT_NEAR(delta_ms, 567, 100);
p.reset();
ASSERT_FALSE(alive);
}