diff --git a/include/mg_procedure.h b/include/mg_procedure.h index ba964a807..0bd831174 100644 --- a/include/mg_procedure.h +++ b/include/mg_procedure.h @@ -111,6 +111,22 @@ enum mgp_error mgp_global_aligned_alloc(size_t size_in_bytes, size_t alignment, /// The behavior is undefined if `ptr` is not a value returned from a prior /// mgp_global_alloc() or mgp_global_aligned_alloc(). void mgp_global_free(void *p); + +/// State of the graph database. +struct mgp_graph; + +/// Allocations are tracked only for master thread. If new threads are spawned +/// inside procedure, by calling following function +/// you can start tracking allocations for current thread too. This +/// is important if you need query memory limit to work +/// for given procedure or per procedure memory limit. +enum mgp_error mgp_track_current_thread_allocations(struct mgp_graph *graph); + +/// Once allocations are tracked for current thread, you need to stop tracking allocations +/// for given thread, before thread finishes with execution, or is detached. +/// Otherwise it might result in slowdown of system due to unnecessary tracking of +/// allocations. +enum mgp_error mgp_untrack_current_thread_allocations(struct mgp_graph *graph); ///@} /// @name Operations on mgp_value @@ -854,9 +870,6 @@ enum mgp_error mgp_edge_set_properties(struct mgp_edge *e, struct mgp_map *prope enum mgp_error mgp_edge_iter_properties(struct mgp_edge *e, struct mgp_memory *memory, struct mgp_properties_iterator **result); -/// State of the graph database. -struct mgp_graph; - /// Get the vertex corresponding to given ID, or NULL if no such vertex exists. /// Resulting vertex must be freed using mgp_vertex_destroy. /// Return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate the vertex. diff --git a/src/memgraph.cpp b/src/memgraph.cpp index 5fe343ffb..dc4b21577 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -23,7 +23,7 @@ #include "glue/run_id.hpp" #include "helpers.hpp" #include "license/license_sender.hpp" -#include "memory/memory_control.hpp" +#include "memory/global_memory_control.hpp" #include "query/config.hpp" #include "query/discard_value_stream.hpp" #include "query/interpreter.hpp" @@ -512,6 +512,7 @@ int main(int argc, char **argv) { server.AwaitShutdown(); websocket_server.AwaitShutdown(); + memgraph::memory::UnsetHooks(); #ifdef MG_ENTERPRISE if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { metrics_server.AwaitShutdown(); diff --git a/src/memory/CMakeLists.txt b/src/memory/CMakeLists.txt index 4a669b9e6..aadbbe23c 100644 --- a/src/memory/CMakeLists.txt +++ b/src/memory/CMakeLists.txt @@ -1,6 +1,7 @@ set(memory_src_files new_delete.cpp - memory_control.cpp) + global_memory_control.cpp + query_memory_control.cpp) diff --git a/src/memory/memory_control.cpp b/src/memory/global_memory_control.cpp similarity index 67% rename from src/memory/memory_control.cpp rename to src/memory/global_memory_control.cpp index b3eeb6c26..128d5b046 100644 --- a/src/memory/memory_control.cpp +++ b/src/memory/global_memory_control.cpp @@ -9,12 +9,16 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -#include "memory_control.hpp" +#include +#include + +#include "global_memory_control.hpp" +#include "query_memory_control.hpp" #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" #if USE_JEMALLOC -#include +#include "jemalloc/jemalloc.h" #endif namespace memgraph::memory { @@ -57,12 +61,24 @@ void *my_alloc(extent_hooks_t *extent_hooks, void *new_addr, size_t size, size_t // This needs to be before, to throw exception in case of too big alloc if (*commit) [[likely]] { memgraph::utils::total_memory_tracker.Alloc(static_cast(size)); + if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] { + auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread(); + if (memory_tracker != nullptr) [[likely]] { + memory_tracker->Alloc(static_cast(size)); + } + } } auto *ptr = old_hooks->alloc(extent_hooks, new_addr, size, alignment, zero, commit, arena_ind); if (ptr == nullptr) [[unlikely]] { if (*commit) { memgraph::utils::total_memory_tracker.Free(static_cast(size)); + if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] { + auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread(); + if (memory_tracker != nullptr) [[likely]] { + memory_tracker->Free(static_cast(size)); + } + } } return ptr; } @@ -79,6 +95,13 @@ static bool my_dalloc(extent_hooks_t *extent_hooks, void *addr, size_t size, boo if (committed) [[likely]] { memgraph::utils::total_memory_tracker.Free(static_cast(size)); + + if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] { + auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread(); + if (memory_tracker != nullptr) [[likely]] { + memory_tracker->Free(static_cast(size)); + } + } } return false; @@ -87,6 +110,12 @@ static bool my_dalloc(extent_hooks_t *extent_hooks, void *addr, size_t size, boo static void my_destroy(extent_hooks_t *extent_hooks, void *addr, size_t size, bool committed, unsigned arena_ind) { if (committed) [[likely]] { memgraph::utils::total_memory_tracker.Free(static_cast(size)); + if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] { + auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread(); + if (memory_tracker != nullptr) [[likely]] { + memory_tracker->Free(static_cast(size)); + } + } } old_hooks->destroy(extent_hooks, addr, size, committed, arena_ind); @@ -101,6 +130,12 @@ static bool my_commit(extent_hooks_t *extent_hooks, void *addr, size_t size, siz } memgraph::utils::total_memory_tracker.Alloc(static_cast(length)); + if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] { + auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread(); + if (memory_tracker != nullptr) [[likely]] { + memory_tracker->Alloc(static_cast(size)); + } + } return false; } @@ -115,6 +150,12 @@ static bool my_decommit(extent_hooks_t *extent_hooks, void *addr, size_t size, s } memgraph::utils::total_memory_tracker.Free(static_cast(length)); + if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] { + auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread(); + if (memory_tracker != nullptr) [[likely]] { + memory_tracker->Free(static_cast(size)); + } + } return false; } @@ -129,6 +170,13 @@ static bool my_purge_forced(extent_hooks_t *extent_hooks, void *addr, size_t siz } memgraph::utils::total_memory_tracker.Free(static_cast(length)); + if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] { + auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread(); + if (memory_tracker != nullptr) [[likely]] { + memory_tracker->Alloc(static_cast(size)); + } + } + return false; } @@ -153,6 +201,7 @@ void SetHooks() { } for (int i = 0; i < n_arenas; i++) { + GetQueriesMemoryControl().InitializeArenaCounter(i); std::string func_name = "arena." + std::to_string(i) + ".extent_hooks"; size_t hooks_len = sizeof(old_hooks); @@ -197,6 +246,45 @@ void SetHooks() { #endif } +void UnsetHooks() { +#if USE_JEMALLOC + + uint64_t allocated{0}; + uint64_t sz{sizeof(allocated)}; + + sz = sizeof(unsigned); + unsigned n_arenas{0}; + int err = mallctl("opt.narenas", (void *)&n_arenas, &sz, nullptr, 0); + + if (err) { + LOG_FATAL("Error setting default hooks for jemalloc arenas"); + } + + for (int i = 0; i < n_arenas; i++) { + GetQueriesMemoryControl().InitializeArenaCounter(i); + std::string func_name = "arena." + std::to_string(i) + ".extent_hooks"; + + MG_ASSERT(old_hooks); + MG_ASSERT(old_hooks->alloc); + MG_ASSERT(old_hooks->dalloc); + MG_ASSERT(old_hooks->destroy); + MG_ASSERT(old_hooks->commit); + MG_ASSERT(old_hooks->decommit); + MG_ASSERT(old_hooks->purge_forced); + MG_ASSERT(old_hooks->purge_lazy); + MG_ASSERT(old_hooks->split); + MG_ASSERT(old_hooks->merge); + + err = mallctl(func_name.c_str(), nullptr, nullptr, &old_hooks, sizeof(old_hooks)); + + if (err) { + LOG_FATAL("Error setting default hooks for jemalloc arena {}", i); + } + } + +#endif +} + void PurgeUnusedMemory() { #if USE_JEMALLOC mallctl("arena." STRINGIFY(MALLCTL_ARENAS_ALL) ".purge", nullptr, nullptr, nullptr, 0); diff --git a/src/memory/memory_control.hpp b/src/memory/global_memory_control.hpp similarity index 97% rename from src/memory/memory_control.hpp rename to src/memory/global_memory_control.hpp index 471acf774..dedce0cc9 100644 --- a/src/memory/memory_control.hpp +++ b/src/memory/global_memory_control.hpp @@ -17,5 +17,6 @@ namespace memgraph::memory { void PurgeUnusedMemory(); void SetHooks(); +void UnsetHooks(); } // namespace memgraph::memory diff --git a/src/memory/query_memory_control.cpp b/src/memory/query_memory_control.cpp new file mode 100644 index 000000000..d44ad1a83 --- /dev/null +++ b/src/memory/query_memory_control.cpp @@ -0,0 +1,140 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "query_memory_control.hpp" +#include "utils/exceptions.hpp" +#include "utils/logging.hpp" +#include "utils/memory_tracker.hpp" +#include "utils/rw_spin_lock.hpp" + +#if USE_JEMALLOC +#include "jemalloc/jemalloc.h" +#endif + +namespace memgraph::memory { + +#if USE_JEMALLOC + +unsigned QueriesMemoryControl::GetArenaForThread() { + unsigned thread_arena{0}; + size_t size_thread_arena = sizeof(thread_arena); + int err = mallctl("thread.arena", &thread_arena, &size_thread_arena, nullptr, 0); + if (err) { + LOG_FATAL("Can't get arena for thread."); + } + return thread_arena; +} + +void QueriesMemoryControl::AddTrackingOnArena(unsigned arena_id) { arena_tracking[arena_id].fetch_add(1); } + +void QueriesMemoryControl::RemoveTrackingOnArena(unsigned arena_id) { arena_tracking[arena_id].fetch_sub(1); } + +void QueriesMemoryControl::UpdateThreadToTransactionId(const std::thread::id &thread_id, uint64_t transaction_id) { + auto accessor = thread_id_to_transaction_id.access(); + accessor.insert({thread_id, transaction_id}); +} + +void QueriesMemoryControl::EraseThreadToTransactionId(const std::thread::id &thread_id, uint64_t transaction_id) { + auto accessor = thread_id_to_transaction_id.access(); + auto elem = accessor.find(thread_id); + MG_ASSERT(elem != accessor.end() && elem->transaction_id == transaction_id); + accessor.remove(thread_id); +} + +utils::MemoryTracker *QueriesMemoryControl::GetTrackerCurrentThread() { + auto thread_id_to_transaction_id_accessor = thread_id_to_transaction_id.access(); + + // we might be just constructing mapping between thread id and transaction id + // so we miss this allocation + auto thread_id_to_transaction_id_elem = thread_id_to_transaction_id_accessor.find(std::this_thread::get_id()); + if (thread_id_to_transaction_id_elem == thread_id_to_transaction_id_accessor.end()) { + return nullptr; + } + + auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access(); + auto transaction_id_to_tracker = + transaction_id_to_tracker_accessor.find(thread_id_to_transaction_id_elem->transaction_id); + return &transaction_id_to_tracker->tracker; +} + +void QueriesMemoryControl::CreateTransactionIdTracker(uint64_t transaction_id, size_t inital_limit) { + auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access(); + + auto [elem, result] = transaction_id_to_tracker_accessor.insert({transaction_id, utils::MemoryTracker{}}); + + elem->tracker.SetMaximumHardLimit(inital_limit); + elem->tracker.SetHardLimit(inital_limit); +} + +bool QueriesMemoryControl::CheckTransactionIdTrackerExists(uint64_t transaction_id) { + auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access(); + return transaction_id_to_tracker_accessor.contains(transaction_id); +} + +bool QueriesMemoryControl::EraseTransactionIdTracker(uint64_t transaction_id) { + auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access(); + auto removed = transaction_id_to_tracker.access().remove(transaction_id); + return removed; +} + +bool QueriesMemoryControl::IsArenaTracked(unsigned arena_ind) { + return arena_tracking[arena_ind].load(std::memory_order_acquire) != 0; +} + +void QueriesMemoryControl::InitializeArenaCounter(unsigned arena_ind) { + arena_tracking[arena_ind].store(0, std::memory_order_relaxed); +} + +#endif + +void StartTrackingCurrentThreadTransaction(uint64_t transaction_id) { +#if USE_JEMALLOC + GetQueriesMemoryControl().UpdateThreadToTransactionId(std::this_thread::get_id(), transaction_id); + GetQueriesMemoryControl().AddTrackingOnArena(QueriesMemoryControl::GetArenaForThread()); +#endif +} + +void StopTrackingCurrentThreadTransaction(uint64_t transaction_id) { +#if USE_JEMALLOC + GetQueriesMemoryControl().EraseThreadToTransactionId(std::this_thread::get_id(), transaction_id); + GetQueriesMemoryControl().RemoveTrackingOnArena(QueriesMemoryControl::GetArenaForThread()); +#endif +} + +void TryStartTrackingOnTransaction(uint64_t transaction_id, size_t limit) { +#if USE_JEMALLOC + if (GetQueriesMemoryControl().CheckTransactionIdTrackerExists(transaction_id)) { + return; + } + GetQueriesMemoryControl().CreateTransactionIdTracker(transaction_id, limit); + +#endif +} + +void TryStopTrackingOnTransaction(uint64_t transaction_id) { +#if USE_JEMALLOC + if (!GetQueriesMemoryControl().CheckTransactionIdTrackerExists(transaction_id)) { + return; + } + GetQueriesMemoryControl().EraseTransactionIdTracker(transaction_id); +#endif +} + +} // namespace memgraph::memory diff --git a/src/memory/query_memory_control.hpp b/src/memory/query_memory_control.hpp new file mode 100644 index 000000000..491dd8c36 --- /dev/null +++ b/src/memory/query_memory_control.hpp @@ -0,0 +1,141 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. +#pragma once + +#include +#include +#include +#include + +#include "utils/memory_tracker.hpp" +#include "utils/skip_list.hpp" + +namespace memgraph::memory { + +#if USE_JEMALLOC + +// Track memory allocations per query. +// Multiple threads can allocate inside one transaction. +// If user forgets to unregister tracking for that thread before it dies, it will continue to +// track allocations for that arena indefinitely. +// As multiple queries can be executed inside one transaction, one by one (multi-transaction) +// it is necessary to restart tracking at the beginning of new query for that transaction. +class QueriesMemoryControl { + public: + /* + Arena stats + */ + + static unsigned GetArenaForThread(); + + // Add counter on threads allocating inside arena + void AddTrackingOnArena(unsigned); + + // Remove counter on threads allocating in arena + void RemoveTrackingOnArena(unsigned); + + // Are any threads using current arena for allocations + // Multiple threads can allocate inside one arena + bool IsArenaTracked(unsigned); + + // Initialize arena counter + void InitializeArenaCounter(unsigned); + + /* + Transaction id <-> tracker + */ + + // Create new tracker for transaction_id with initial limit + void CreateTransactionIdTracker(uint64_t, size_t); + + // Check if tracker for given transaction id exists + bool CheckTransactionIdTrackerExists(uint64_t); + + // Remove current tracker for transaction_id + bool EraseTransactionIdTracker(uint64_t); + + /* + Thread handlings + */ + + // Map thread to transaction with given id + // This way we can know which thread belongs to which transaction + // and get correct tracker for given transaction + void UpdateThreadToTransactionId(const std::thread::id &, uint64_t); + + // Remove tracking of thread from transaction. + // Important to reset if one thread gets reused for different transaction + void EraseThreadToTransactionId(const std::thread::id &, uint64_t); + + // C-API functionality for thread to transaction mapping + void UpdateThreadToTransactionId(const char *, uint64_t); + + // C-API functionality for thread to transaction unmapping + void EraseThreadToTransactionId(const char *, uint64_t); + + // Get tracker to current thread if exists, otherwise return + // nullptr. This can happen only if tracker is still + // being constructed. + utils::MemoryTracker *GetTrackerCurrentThread(); + + private: + std::unordered_map> arena_tracking; + + struct ThreadIdToTransactionId { + std::thread::id thread_id; + uint64_t transaction_id; + + bool operator<(const ThreadIdToTransactionId &other) const { return thread_id < other.thread_id; } + bool operator==(const ThreadIdToTransactionId &other) const { return thread_id == other.thread_id; } + + bool operator<(const std::thread::id other) const { return thread_id < other; } + bool operator==(const std::thread::id other) const { return thread_id == other; } + }; + + struct TransactionIdToTracker { + uint64_t transaction_id; + utils::MemoryTracker tracker; + + bool operator<(const TransactionIdToTracker &other) const { return transaction_id < other.transaction_id; } + bool operator==(const TransactionIdToTracker &other) const { return transaction_id == other.transaction_id; } + + bool operator<(uint64_t other) const { return transaction_id < other; } + bool operator==(uint64_t other) const { return transaction_id == other; } + }; + + utils::SkipList thread_id_to_transaction_id; + utils::SkipList transaction_id_to_tracker; +}; + +inline QueriesMemoryControl &GetQueriesMemoryControl() { + static QueriesMemoryControl queries_memory_control_; + return queries_memory_control_; +} + +#endif + +// API function call for to start tracking current thread for given transaction. +// Does nothing if jemalloc is not enabled +void StartTrackingCurrentThreadTransaction(uint64_t transaction_id); + +// API function call for to stop tracking current thread for given transaction. +// Does nothing if jemalloc is not enabled +void StopTrackingCurrentThreadTransaction(uint64_t transaction_id); + +// API function call for try to create tracker for transaction and set it to given limit. +// Does nothing if jemalloc is not enabled. Does nothing if tracker already exists +void TryStartTrackingOnTransaction(uint64_t transaction_id, size_t limit); + +// API function call to stop tracking for given transaction. +// Does nothing if jemalloc is not enabled. Does nothing if tracker doesn't exist +void TryStopTrackingOnTransaction(uint64_t transaction_id); + +} // namespace memgraph::memory diff --git a/src/query/db_accessor.cpp b/src/query/db_accessor.cpp index cf914b4e3..0250ab695 100644 --- a/src/query/db_accessor.cpp +++ b/src/query/db_accessor.cpp @@ -21,6 +21,10 @@ namespace memgraph::query { SubgraphDbAccessor::SubgraphDbAccessor(query::DbAccessor db_accessor, Graph *graph) : db_accessor_(db_accessor), graph_(graph) {} +void SubgraphDbAccessor::TrackCurrentThreadAllocations() { return db_accessor_.TrackCurrentThreadAllocations(); } + +void SubgraphDbAccessor::UntrackCurrentThreadAllocations() { return db_accessor_.TrackCurrentThreadAllocations(); } + storage::PropertyId SubgraphDbAccessor::NameToProperty(const std::string_view name) { return db_accessor_.NameToProperty(name); } diff --git a/src/query/db_accessor.hpp b/src/query/db_accessor.hpp index 3725d9848..d6114edaf 100644 --- a/src/query/db_accessor.hpp +++ b/src/query/db_accessor.hpp @@ -17,6 +17,7 @@ #include #include +#include "memory/query_memory_control.hpp" #include "query/exceptions.hpp" #include "storage/v2/edge_accessor.hpp" #include "storage/v2/id_types.hpp" @@ -372,6 +373,16 @@ class DbAccessor final { void FinalizeTransaction() { accessor_->FinalizeTransaction(); } + void TrackCurrentThreadAllocations() { + memgraph::memory::StartTrackingCurrentThreadTransaction(*accessor_->GetTransactionId()); + } + + void UntrackCurrentThreadAllocations() { + memgraph::memory::StopTrackingCurrentThreadTransaction(*accessor_->GetTransactionId()); + } + + std::optional GetTransactionId() { return accessor_->GetTransactionId(); } + VerticesIterable Vertices(storage::View view) { return VerticesIterable(accessor_->Vertices(view)); } VerticesIterable Vertices(storage::View view, storage::LabelId label) { @@ -640,6 +651,14 @@ class SubgraphDbAccessor final { static SubgraphDbAccessor *MakeSubgraphDbAccessor(DbAccessor *db_accessor, Graph *graph); + void TrackThreadAllocations(const char *thread_id); + + void TrackCurrentThreadAllocations(); + + void UntrackThreadAllocations(const char *thread_id); + + void UntrackCurrentThreadAllocations(); + storage::PropertyId NameToProperty(std::string_view name); storage::LabelId NameToLabel(std::string_view name); diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 503613612..a55cba115 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -39,7 +40,8 @@ #include "flags/run_time_configurable.hpp" #include "glue/communication.hpp" #include "license/license.hpp" -#include "memory/memory_control.hpp" +#include "memory/global_memory_control.hpp" +#include "memory/query_memory_control.hpp" #include "query/config.hpp" #include "query/constants.hpp" #include "query/context.hpp" @@ -1283,6 +1285,24 @@ PullPlan::PullPlan(const std::shared_ptr plan, const Parameters &par std::optional PullPlan::Pull(AnyStream *stream, std::optional n, const std::vector &output_symbols, std::map *summary) { + std::optional transaction_id = ctx_.db_accessor->GetTransactionId(); + MG_ASSERT(transaction_id.has_value()); + + if (memory_limit_) { + memgraph::memory::TryStartTrackingOnTransaction(*transaction_id, *memory_limit_); + memgraph::memory::StartTrackingCurrentThreadTransaction(*transaction_id); + } + utils::OnScopeExit> reset_query_limit{ + [memory_limit = memory_limit_, transaction_id = *transaction_id]() { + if (memory_limit) { + // Stopping tracking of transaction occurs in interpreter::pull + // Exception can occur so we need to handle that case there. + // We can't stop tracking here as there can be multiple pulls + // so we need to take care of that after everything was pulled + memgraph::memory::StopTrackingCurrentThreadTransaction(transaction_id); + } + }}; + // Set up temporary memory for a single Pull. Initial memory comes from the // stack. 256 KiB should fit on the stack and should be more than enough for a // single `Pull`. @@ -1306,13 +1326,7 @@ std::optional PullPlan::Pull(AnyStream *strea pool_memory.emplace(kMaxBlockPerChunks, 1024, &monotonic_memory, &resource_with_exception); } - std::optional maybe_limited_resource; - if (memory_limit_) { - maybe_limited_resource.emplace(&*pool_memory, *memory_limit_); - ctx_.evaluation_context.memory = &*maybe_limited_resource; - } else { - ctx_.evaluation_context.memory = &*pool_memory; - } + ctx_.evaluation_context.memory = &*pool_memory; // Returns true if a result was pulled. const auto pull_result = [&]() -> bool { return cursor_->Pull(frame_, ctx_); }; @@ -1379,6 +1393,7 @@ std::optional PullPlan::Pull(AnyStream *strea } cursor_->Shutdown(); ctx_.profile_execution_time = execution_time_; + return GetStatsWithTotalTime(ctx_); } diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 970f43961..66231059d 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -16,6 +16,7 @@ #include #include "dbms/database.hpp" +#include "memory/query_memory_control.hpp" #include "query/auth_checker.hpp" #include "query/auth_query_handler.hpp" #include "query/config.hpp" @@ -402,6 +403,9 @@ std::map Interpreter::Pull(TStream *result_stream, std: // If the query finished executing, we have received a value which tells // us what to do after. if (maybe_res) { + if (current_transaction_) { + memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_); + } // Save its summary maybe_summary.emplace(std::move(query_execution->summary)); if (!query_execution->notifications.empty()) { @@ -440,9 +444,15 @@ std::map Interpreter::Pull(TStream *result_stream, std: } } } catch (const ExplicitTransactionUsageException &) { + if (current_transaction_) { + memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_); + } query_execution.reset(nullptr); throw; } catch (const utils::BasicException &) { + if (current_transaction_) { + memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_); + } // Trigger first failed query metrics::FirstFailedQuery(); memgraph::metrics::IncrementCounter(memgraph::metrics::FailedQuery); diff --git a/src/query/procedure/mg_procedure_impl.cpp b/src/query/procedure/mg_procedure_impl.cpp index 9841f318a..2a657aeb3 100644 --- a/src/query/procedure/mg_procedure_impl.cpp +++ b/src/query/procedure/mg_procedure_impl.cpp @@ -3596,3 +3596,15 @@ mgp_error mgp_log(const mgp_log_level log_level, const char *output) { throw std::invalid_argument{fmt::format("Invalid log level: {}", log_level)}; }); } + +mgp_error mgp_track_current_thread_allocations(mgp_graph *graph) { + return WrapExceptions([&]() { + std::visit([](auto *db_accessor) -> void { db_accessor->TrackCurrentThreadAllocations(); }, graph->impl); + }); +} + +mgp_error mgp_untrack_current_thread_allocations(mgp_graph *graph) { + return WrapExceptions([&]() { + std::visit([](auto *db_accessor) -> void { db_accessor->UntrackCurrentThreadAllocations(); }, graph->impl); + }); +} diff --git a/src/utils/memory_tracker.cpp b/src/utils/memory_tracker.cpp index fa9910e5e..774faf72c 100644 --- a/src/utils/memory_tracker.cpp +++ b/src/utils/memory_tracker.cpp @@ -89,6 +89,13 @@ void MemoryTracker::TryRaiseHardLimit(const int64_t limit) { ; } +void MemoryTracker::ResetTrackings() { + hard_limit_.store(0, std::memory_order_relaxed); + peak_.store(0, std::memory_order_relaxed); + amount_.store(0, std::memory_order_relaxed); + maximum_hard_limit_ = 0; +} + void MemoryTracker::SetMaximumHardLimit(const int64_t limit) { if (maximum_hard_limit_ < 0) { spdlog::warn("Invalid maximum hard limit."); diff --git a/src/utils/memory_tracker.hpp b/src/utils/memory_tracker.hpp index 3502368fc..0385d4517 100644 --- a/src/utils/memory_tracker.hpp +++ b/src/utils/memory_tracker.hpp @@ -12,6 +12,7 @@ #pragma once #include +#include #include "utils/exceptions.hpp" @@ -41,9 +42,20 @@ class MemoryTracker final { MemoryTracker() = default; ~MemoryTracker() = default; + MemoryTracker(MemoryTracker &&other) noexcept + : amount_(other.amount_.load(std::memory_order_acquire)), + peak_(other.peak_.load(std::memory_order_acquire)), + hard_limit_(other.hard_limit_.load(std::memory_order_acquire)), + maximum_hard_limit_(other.maximum_hard_limit_) { + other.maximum_hard_limit_ = 0; + other.amount_.store(0, std::memory_order_acquire); + other.peak_.store(0, std::memory_order_acquire); + other.hard_limit_.store(0, std::memory_order_acquire); + } + MemoryTracker(const MemoryTracker &) = delete; MemoryTracker &operator=(const MemoryTracker &) = delete; - MemoryTracker(MemoryTracker &&) = delete; + MemoryTracker &operator=(MemoryTracker &&) = delete; void Alloc(int64_t size); @@ -59,6 +71,8 @@ class MemoryTracker final { void TryRaiseHardLimit(int64_t limit); void SetMaximumHardLimit(int64_t limit); + void ResetTrackings(); + // By creating an object of this class, every allocation in its scope that goes over // the set hard limit produces an OutOfMemoryException. class OutOfMemoryExceptionEnabler final { diff --git a/tests/e2e/memgraph.py b/tests/e2e/memgraph.py index 2b80c2f62..a65bed2ed 100755 --- a/tests/e2e/memgraph.py +++ b/tests/e2e/memgraph.py @@ -21,6 +21,7 @@ SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..")) BUILD_DIR = os.path.join(PROJECT_DIR, "build") MEMGRAPH_BINARY = os.path.join(BUILD_DIR, "memgraph") +SIGNAL_SIGTERM = 15 def wait_for_server(port, delay=0.01): @@ -133,7 +134,7 @@ class MemgraphInstanceRunner: pid = self.proc_mg.pid try: - os.kill(pid, 15) # 15 is the signal number for SIGTERM + os.kill(pid, SIGNAL_SIGTERM) except os.OSError: assert False diff --git a/tests/e2e/memory/CMakeLists.txt b/tests/e2e/memory/CMakeLists.txt index c127d0d55..d06885483 100644 --- a/tests/e2e/memory/CMakeLists.txt +++ b/tests/e2e/memory/CMakeLists.txt @@ -11,10 +11,21 @@ target_link_libraries(memgraph__e2e__memory__limit_global_alloc gflags mgclient add_executable(memgraph__e2e__memory__limit_global_alloc_proc memory_limit_global_alloc_proc.cpp) target_link_libraries(memgraph__e2e__memory__limit_global_alloc_proc gflags mgclient mg-utils mg-io Threads::Threads) +add_executable(memgraph__e2e__memory__limit_query_alloc_proc_multi_thread query_memory_limit_proc_multi_thread.cpp) +target_link_libraries(memgraph__e2e__memory__limit_query_alloc_proc_multi_thread gflags mgclient mg-utils mg-io Threads::Threads) + +add_executable(memgraph__e2e__memory__limit_query_alloc_create query_memory_limit_create.cpp) +target_link_libraries(memgraph__e2e__memory__limit_query_alloc_create gflags mgclient mg-utils mg-io) + +add_executable(memgraph__e2e__memory__limit_query_alloc_proc query_memory_limit_proc.cpp) +target_link_libraries(memgraph__e2e__memory__limit_query_alloc_proc gflags mgclient mg-utils mg-io) + +add_executable(memgraph__e2e__memory__limit_query_alloc_create_multi_thread query_memory_limit_multi_thread.cpp) +target_link_libraries(memgraph__e2e__memory__limit_query_alloc_create_multi_thread gflags mgclient mg-utils mg-io Threads::Threads) + add_executable(memgraph__e2e__memory__limit_delete memory_limit_delete.cpp) target_link_libraries(memgraph__e2e__memory__limit_delete gflags mgclient mg-utils mg-io) - add_executable(memgraph__e2e__memory__limit_accumulation memory_limit_accumulation.cpp) target_link_libraries(memgraph__e2e__memory__limit_accumulation gflags mgclient mg-utils mg-io) diff --git a/tests/e2e/memory/procedures/CMakeLists.txt b/tests/e2e/memory/procedures/CMakeLists.txt index 21201e59b..4ea9db247 100644 --- a/tests/e2e/memory/procedures/CMakeLists.txt +++ b/tests/e2e/memory/procedures/CMakeLists.txt @@ -3,3 +3,13 @@ target_include_directories(global_memory_limit PRIVATE ${CMAKE_SOURCE_DIR}/inclu add_library(global_memory_limit_proc SHARED global_memory_limit_proc.c) target_include_directories(global_memory_limit_proc PRIVATE ${CMAKE_SOURCE_DIR}/include) + + +add_library(query_memory_limit_proc_multi_thread SHARED query_memory_limit_proc_multi_thread.cpp) +target_include_directories(query_memory_limit_proc_multi_thread PRIVATE ${CMAKE_SOURCE_DIR}/include) +target_link_libraries(query_memory_limit_proc_multi_thread mg-utils) + + +add_library(query_memory_limit_proc SHARED query_memory_limit_proc.cpp) +target_include_directories(query_memory_limit_proc PRIVATE ${CMAKE_SOURCE_DIR}/include) +target_link_libraries(query_memory_limit_proc mg-utils) diff --git a/tests/e2e/memory/procedures/query_memory_limit_proc.cpp b/tests/e2e/memory/procedures/query_memory_limit_proc.cpp new file mode 100644 index 000000000..767f5c480 --- /dev/null +++ b/tests/e2e/memory/procedures/query_memory_limit_proc.cpp @@ -0,0 +1,70 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mg_procedure.h" +#include "utils/on_scope_exit.hpp" + +enum mgp_error Alloc(void *ptr) { + const size_t mb_size_268 = 1 << 28; + + return mgp_global_alloc(mb_size_268, (void **)(&ptr)); +} + +void Regular(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) { + mgp::MemoryDispatcherGuard guard{memory}; + const auto arguments = mgp::List(args); + const auto record_factory = mgp::RecordFactory(result); + + try { + void *ptr{nullptr}; + + memgraph::utils::OnScopeExit> cleanup{[&ptr]() { + if (nullptr == ptr) { + return; + } + mgp_global_free(ptr); + }}; + + const enum mgp_error alloc_err = Alloc(ptr); + auto new_record = record_factory.NewRecord(); + new_record.Insert("allocated", alloc_err != mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE); + } catch (std::exception &e) { + record_factory.SetErrorMessage(e.what()); + } +} + +extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) { + try { + mgp::MemoryDispatcherGuard mdg{memory}; + + AddProcedure(Regular, std::string("regular").c_str(), mgp::ProcedureType::Read, {}, + {mgp::Return(std::string("allocated").c_str(), mgp::Type::Bool)}, module, memory); + + } catch (const std::exception &e) { + return 1; + } + + return 0; +} + +extern "C" int mgp_shutdown_module() { return 0; } diff --git a/tests/e2e/memory/procedures/query_memory_limit_proc_multi_thread.cpp b/tests/e2e/memory/procedures/query_memory_limit_proc_multi_thread.cpp new file mode 100644 index 000000000..ffc509ff3 --- /dev/null +++ b/tests/e2e/memory/procedures/query_memory_limit_proc_multi_thread.cpp @@ -0,0 +1,101 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mg_procedure.h" +#include "utils/on_scope_exit.hpp" + +enum mgp_error Alloc(void *ptr) { + const size_t mb_size_268 = 1 << 28; + + return mgp_global_alloc(mb_size_268, (void **)(&ptr)); +} + +// change communication between threads with feature and promise +std::atomic num_allocations{0}; +std::vector ptrs_; + +void AllocFunc(mgp_graph *graph) { + [[maybe_unused]] const enum mgp_error tracking_error = mgp_track_current_thread_allocations(graph); + void *ptr = nullptr; + + ptrs_.emplace_back(ptr); + try { + enum mgp_error alloc_err { mgp_error::MGP_ERROR_NO_ERROR }; + alloc_err = Alloc(ptr); + if (alloc_err != mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + num_allocations.fetch_add(1, std::memory_order_relaxed); + } + if (alloc_err != mgp_error::MGP_ERROR_NO_ERROR) { + assert(false); + } + } catch (const std::exception &e) { + assert(false); + } + + [[maybe_unused]] const enum mgp_error untracking_error = mgp_untrack_current_thread_allocations(graph); +} + +void DualThread(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) { + mgp::MemoryDispatcherGuard guard{memory}; + const auto arguments = mgp::List(args); + const auto record_factory = mgp::RecordFactory(result); + num_allocations.store(0, std::memory_order_relaxed); + try { + std::vector threads; + + for (int i = 0; i < 2; i++) { + threads.emplace_back(AllocFunc, memgraph_graph); + } + + for (int i = 0; i < 2; i++) { + threads[i].join(); + } + for (void *ptr : ptrs_) { + if (ptr != nullptr) { + mgp_global_free(ptr); + } + } + + auto new_record = record_factory.NewRecord(); + + new_record.Insert("allocated_all", num_allocations.load(std::memory_order_relaxed) == 2); + } catch (std::exception &e) { + record_factory.SetErrorMessage(e.what()); + } +} + +extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) { + try { + mgp::memory = memory; + + AddProcedure(DualThread, std::string("dual_thread").c_str(), mgp::ProcedureType::Read, {}, + {mgp::Return(std::string("allocated_all").c_str(), mgp::Type::Bool)}, module, memory); + + } catch (const std::exception &e) { + return 1; + } + + return 0; +} + +extern "C" int mgp_shutdown_module() { return 0; } diff --git a/tests/e2e/memory/query_memory_limit_create.cpp b/tests/e2e/memory/query_memory_limit_create.cpp new file mode 100644 index 000000000..f9bcae133 --- /dev/null +++ b/tests/e2e/memory/query_memory_limit_create.cpp @@ -0,0 +1,65 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include + +#include "utils/logging.hpp" +#include "utils/timer.hpp" + +DEFINE_uint64(bolt_port, 7687, "Bolt port"); +DEFINE_bool(multi_db, false, "Run test in multi db environment"); + +int main(int argc, char **argv) { + google::SetUsageMessage("Memgraph E2E Memory Control"); + gflags::ParseCommandLineFlags(&argc, &argv, true); + memgraph::logging::RedirectToStderr(); + + mg::Client::Init(); + + auto client = + mg::Client::Connect({.host = "127.0.0.1", .port = static_cast(FLAGS_bolt_port), .use_ssl = false}); + if (!client) { + LOG_FATAL("Failed to connect!"); + } + + client->Execute("MATCH (n) DETACH DELETE n;"); + client->DiscardAll(); + + if (FLAGS_multi_db) { + client->Execute("CREATE DATABASE clean;"); + client->DiscardAll(); + client->Execute("USE DATABASE clean;"); + client->DiscardAll(); + client->Execute("MATCH (n) DETACH DELETE n;"); + client->DiscardAll(); + } + + const auto *create_query = + "UNWIND range(1, 50000) as u CREATE (n {string: 'Some longer string'}) RETURN n QUERY MEMORY LIMIT 30MB;"; + + try { + client->Execute(create_query); + [[maybe_unused]] auto results = client->FetchAll(); + if (results->empty()) { + assert(true); + return 0; + } + } catch (const mg::TransientException & /*unused*/) { + spdlog::info("Memgraph is out of memory"); + assert(true); + return 0; + } + MG_ASSERT(false, "Query should have failed!"); + + return 0; +} diff --git a/tests/e2e/memory/query_memory_limit_multi_thread.cpp b/tests/e2e/memory/query_memory_limit_multi_thread.cpp new file mode 100644 index 000000000..81f0bd5d8 --- /dev/null +++ b/tests/e2e/memory/query_memory_limit_multi_thread.cpp @@ -0,0 +1,100 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include + +#include +#include +#include +#include + +#include "utils/logging.hpp" +#include "utils/timer.hpp" + +DEFINE_uint64(bolt_port, 7687, "Bolt port"); +DEFINE_bool(multi_db, false, "Run test in multi db environment"); + +static constexpr int kNumberClients{2}; + +void Func(std::promise promise) { + auto client = + mg::Client::Connect({.host = "127.0.0.1", .port = static_cast(FLAGS_bolt_port), .use_ssl = false}); + if (!client) { + LOG_FATAL("Failed to connect!"); + } + const auto *create_query = + "FOREACH(i in range(1, 300000) | CREATE (n: Node {string: 'Some longer string'})) QUERY MEMORY LIMIT 100MB;"; + bool err{false}; + try { + client->Execute(create_query); + [[maybe_unused]] auto results = client->FetchAll(); + if (results->empty()) { + err = true; + } + } catch (const std::exception &e) { + spdlog::info("Good: Exception occured", e.what()); + err = true; + } + promise.set_value_at_thread_exit(err); +} + +int main(int argc, char **argv) { + google::SetUsageMessage("Memgraph E2E Memory Control"); + gflags::ParseCommandLineFlags(&argc, &argv, true); + memgraph::logging::RedirectToStderr(); + + mg::Client::Init(); + + { + auto client = + mg::Client::Connect({.host = "127.0.0.1", .port = static_cast(FLAGS_bolt_port), .use_ssl = false}); + if (!client) { + LOG_FATAL("Failed to connect!"); + } + client->Execute("MATCH (n) DETACH DELETE n;"); + client->DiscardAll(); + + if (FLAGS_multi_db) { + client->Execute("CREATE DATABASE clean;"); + client->DiscardAll(); + client->Execute("USE DATABASE clean;"); + client->DiscardAll(); + client->Execute("MATCH (n) DETACH DELETE n;"); + client->DiscardAll(); + } + } + + std::vector> my_promises; + std::vector> my_futures; + for (int i = 0; i < 4; i++) { + my_promises.push_back(std::promise()); + my_futures.emplace_back(my_promises.back().get_future()); + } + + std::vector my_threads; + + for (int i = 0; i < kNumberClients; i++) { + my_threads.emplace_back(Func, std::move(my_promises[i])); + } + + for (int i = 0; i < kNumberClients; i++) { + auto value = my_futures[i].get(); + MG_ASSERT(value, "Error should have happend in thread"); + } + + for (int i = 0; i < kNumberClients; i++) { + my_threads[i].join(); + } + + return 0; +} diff --git a/tests/e2e/memory/query_memory_limit_proc.cpp b/tests/e2e/memory/query_memory_limit_proc.cpp new file mode 100644 index 000000000..39baf4d8e --- /dev/null +++ b/tests/e2e/memory/query_memory_limit_proc.cpp @@ -0,0 +1,66 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include +#include +#include +#include + +#include "utils/logging.hpp" +#include "utils/timer.hpp" + +DEFINE_uint64(bolt_port, 7687, "Bolt port"); +DEFINE_uint64(timeout, 120, "Timeout seconds"); +DEFINE_bool(multi_db, false, "Run test in multi db environment"); + +int main(int argc, char **argv) { + google::SetUsageMessage("Memgraph E2E Query Memory Limit In Multi-Thread For Global Allocators"); + gflags::ParseCommandLineFlags(&argc, &argv, true); + memgraph::logging::RedirectToStderr(); + + mg::Client::Init(); + + auto client = + mg::Client::Connect({.host = "127.0.0.1", .port = static_cast(FLAGS_bolt_port), .use_ssl = false}); + if (!client) { + LOG_FATAL("Failed to connect!"); + } + + if (FLAGS_multi_db) { + client->Execute("CREATE DATABASE clean;"); + client->DiscardAll(); + client->Execute("USE DATABASE clean;"); + client->DiscardAll(); + client->Execute("MATCH (n) DETACH DELETE n;"); + client->DiscardAll(); + } + + MG_ASSERT( + client->Execute("CALL libquery_memory_limit_proc.regular() YIELD allocated RETURN " + "allocated QUERY MEMORY LIMIT 250MB")); + bool error{false}; + try { + auto result_rows = client->FetchAll(); + if (result_rows) { + auto row = *result_rows->begin(); + error = row[0].ValueBool() == false; + } + + } catch (const std::exception &e) { + error = true; + } + + MG_ASSERT(error, "Error should have happend"); + + return 0; +} diff --git a/tests/e2e/memory/query_memory_limit_proc_multi_thread.cpp b/tests/e2e/memory/query_memory_limit_proc_multi_thread.cpp new file mode 100644 index 000000000..5a5ec94f0 --- /dev/null +++ b/tests/e2e/memory/query_memory_limit_proc_multi_thread.cpp @@ -0,0 +1,66 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include +#include +#include +#include + +#include "utils/logging.hpp" +#include "utils/timer.hpp" + +DEFINE_uint64(bolt_port, 7687, "Bolt port"); +DEFINE_uint64(timeout, 120, "Timeout seconds"); +DEFINE_bool(multi_db, false, "Run test in multi db environment"); + +int main(int argc, char **argv) { + google::SetUsageMessage("Memgraph E2E Query Memory Limit In Multi-Thread For Global Allocators"); + gflags::ParseCommandLineFlags(&argc, &argv, true); + memgraph::logging::RedirectToStderr(); + + mg::Client::Init(); + + auto client = + mg::Client::Connect({.host = "127.0.0.1", .port = static_cast(FLAGS_bolt_port), .use_ssl = false}); + if (!client) { + LOG_FATAL("Failed to connect!"); + } + + if (FLAGS_multi_db) { + client->Execute("CREATE DATABASE clean;"); + client->DiscardAll(); + client->Execute("USE DATABASE clean;"); + client->DiscardAll(); + client->Execute("MATCH (n) DETACH DELETE n;"); + client->DiscardAll(); + } + + MG_ASSERT( + client->Execute("CALL libquery_memory_limit_proc_multi_thread.dual_thread() YIELD allocated_all RETURN " + "allocated_all QUERY MEMORY LIMIT 500MB")); + bool error{false}; + try { + auto result_rows = client->FetchAll(); + if (result_rows) { + auto row = *result_rows->begin(); + error = row[0].ValueBool() == false; + } + + } catch (const std::exception &e) { + error = true; + } + + MG_ASSERT(error, "Error should have happend"); + + return 0; +} diff --git a/tests/e2e/memory/workloads.yaml b/tests/e2e/memory/workloads.yaml index aa0079bee..a23527c85 100644 --- a/tests/e2e/memory/workloads.yaml +++ b/tests/e2e/memory/workloads.yaml @@ -23,6 +23,20 @@ disk_cluster: &disk_cluster - "STORAGE MODE ON_DISK_TRANSACTIONAL" validation_queries: [] +args_query_limit: &args_query_limit + - "--bolt-port" + - *bolt_port + - "--storage-gc-cycle-sec=180" + - "--log-level=TRACE" + +in_memory_query_limit_cluster: &in_memory_query_limit_cluster + cluster: + main: + args: *args_query_limit + log_file: "memory-e2e.log" + setup_queries: [] + validation_queries: [] + args_450_MiB_limit: &args_450_MiB_limit - "--bolt-port" - *bolt_port @@ -95,6 +109,27 @@ workloads: proc: "tests/e2e/memory/procedures/" <<: *disk_cluster + - name: "Memory control query limit proc" + binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_proc" + proc: "tests/e2e/memory/procedures/" + args: ["--bolt-port", *bolt_port] + <<: *in_memory_query_limit_cluster + + - name: "Memory control query limit proc multi thread" + binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_proc_multi_thread" + args: ["--bolt-port", *bolt_port, "--timeout", "180"] + proc: "tests/e2e/memory/procedures/" + <<: *in_memory_query_limit_cluster + + - name: "Memory control query limit create" + binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_create" + args: ["--bolt-port", *bolt_port] + <<: *in_memory_query_limit_cluster + + - name: "Memory control query limit create multi thread" + binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_create_multi_thread" + args: ["--bolt-port", *bolt_port] + <<: *in_memory_query_limit_cluster - name: "Memory control for detach delete" binary: "tests/e2e/memory/memgraph__e2e__memory__limit_delete" args: ["--bolt-port", *bolt_port] diff --git a/tests/integration/audit/runner.py b/tests/integration/audit/runner.py index 466c91d9a..01f9f53ec 100755 --- a/tests/integration/audit/runner.py +++ b/tests/integration/audit/runner.py @@ -24,6 +24,7 @@ import time DEFAULT_DB = "memgraph" SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) +SIGNAL_SIGTERM = 15 QUERIES = [ ("MATCH (n) DELETE n", {}), @@ -92,9 +93,13 @@ def execute_test(memgraph_binary, tester_binary): # Register cleanup function @atexit.register def cleanup(): - if memgraph.poll() is None: - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False + + time.sleep(1) def execute_queries(queries): for db, query, params in queries: @@ -122,10 +127,12 @@ def execute_test(memgraph_binary, tester_binary): execute_queries(mt_queries3) print("\033[1;36m~~ Finished query execution on clean database ~~\033[0m\n") - # Shutdown the memgraph binary - memgraph.terminate() - - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False + time.sleep(1) # Verify the written log print("\033[1;36m~~ Starting log verification ~~\033[0m") diff --git a/tests/integration/auth/runner.py b/tests/integration/auth/runner.py index 9c4ab8ca7..a74b19a4e 100755 --- a/tests/integration/auth/runner.py +++ b/tests/integration/auth/runner.py @@ -21,6 +21,7 @@ import time SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) +SIGNAL_SIGTERM = 15 # When you create a new permission just add a testcase to this list (a tuple # of query, touple of required permissions) and the test will automatically @@ -166,8 +167,12 @@ def execute_test(memgraph_binary, tester_binary, checker_binary): @atexit.register def cleanup(): if memgraph.poll() is None: - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False + time.sleep(1) # Prepare the multi database environment execute_admin_queries( @@ -327,8 +332,12 @@ def execute_test(memgraph_binary, tester_binary, checker_binary): print("\033[1;36m~~ Finished checking connections and database switching ~~\033[0m\n") # Shutdown the memgraph binary - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False + time.sleep(1) if __name__ == "__main__": diff --git a/tests/integration/durability/runner.py b/tests/integration/durability/runner.py index dd8c41456..8a62f7e3b 100755 --- a/tests/integration/durability/runner.py +++ b/tests/integration/durability/runner.py @@ -20,7 +20,6 @@ import sys import tempfile import time - SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) TESTS_DIR = os.path.join(SCRIPT_DIR, "tests") @@ -31,6 +30,8 @@ WAL_FILE_NAME = "wal.bin" DUMP_SNAPSHOT_FILE_NAME = "expected_snapshot.cypher" DUMP_WAL_FILE_NAME = "expected_wal.cypher" +SIGNAL_SIGTERM = 15 + def wait_for_server(port, delay=0.1): cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)] @@ -40,7 +41,7 @@ def wait_for_server(port, delay=0.1): def sorted_content(file_path): - with open(file_path, 'r') as fin: + with open(file_path, "r") as fin: return sorted(list(map(lambda x: x.strip(), fin.readlines()))) @@ -52,32 +53,27 @@ def list_to_string(data): return ret -def execute_test( - memgraph_binary, - dump_binary, - test_directory, - test_type, - write_expected): - assert test_type in ["SNAPSHOT", "WAL"], \ - "Test type should be either 'SNAPSHOT' or 'WAL'." - print("\033[1;36m~~ Executing test {} ({}) ~~\033[0m" - .format(os.path.relpath(test_directory, TESTS_DIR), test_type)) +def execute_test(memgraph_binary, dump_binary, test_directory, test_type, write_expected): + assert test_type in ["SNAPSHOT", "WAL"], "Test type should be either 'SNAPSHOT' or 'WAL'." + print("\033[1;36m~~ Executing test {} ({}) ~~\033[0m".format(os.path.relpath(test_directory, TESTS_DIR), test_type)) working_data_directory = tempfile.TemporaryDirectory() if test_type == "SNAPSHOT": snapshots_dir = os.path.join(working_data_directory.name, "snapshots") os.makedirs(snapshots_dir) - shutil.copy(os.path.join(test_directory, SNAPSHOT_FILE_NAME), - snapshots_dir) + shutil.copy(os.path.join(test_directory, SNAPSHOT_FILE_NAME), snapshots_dir) else: wal_dir = os.path.join(working_data_directory.name, "wal") os.makedirs(wal_dir) shutil.copy(os.path.join(test_directory, WAL_FILE_NAME), wal_dir) - memgraph_args = [memgraph_binary, - "--storage-recover-on-startup", - "--storage-properties-on-edges", - "--data-directory", working_data_directory.name] + memgraph_args = [ + memgraph_binary, + "--storage-recover-on-startup", + "--storage-properties-on-edges", + "--data-directory", + working_data_directory.name, + ] # Start the memgraph binary memgraph = subprocess.Popen(memgraph_args) @@ -89,8 +85,12 @@ def execute_test( @atexit.register def cleanup(): if memgraph.poll() is None: - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False + time.sleep(1) # Execute `database dump` dump_output_file = tempfile.NamedTemporaryFile() @@ -98,28 +98,31 @@ def execute_test( subprocess.run(dump_args, stdout=dump_output_file, check=True) # Shutdown the memgraph binary - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False + time.sleep(1) dump_file_name = DUMP_SNAPSHOT_FILE_NAME if test_type == "SNAPSHOT" else DUMP_WAL_FILE_NAME if write_expected: - with open(dump_output_file.name, 'r') as dump: + with open(dump_output_file.name, "r") as dump: queries_got = dump.readlines() # Write dump files expected_dump_file = os.path.join(test_directory, dump_file_name) - with open(expected_dump_file, 'w') as expected: + with open(expected_dump_file, "w") as expected: expected.writelines(queries_got) else: # Compare dump files expected_dump_file = os.path.join(test_directory, dump_file_name) - assert os.path.exists(expected_dump_file), \ - "Could not find expected dump path {}".format(expected_dump_file) + assert os.path.exists(expected_dump_file), "Could not find expected dump path {}".format(expected_dump_file) queries_got = sorted_content(dump_output_file.name) queries_expected = sorted_content(expected_dump_file) - assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" \ - "{}".format(list_to_string(queries_got), - list_to_string(queries_expected)) + assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" "{}".format( + list_to_string(queries_got), list_to_string(queries_expected) + ) print("\033[1;32m~~ Test successful ~~\033[0m\n") @@ -141,15 +144,17 @@ def find_test_directories(directory): continue snapshot_file = os.path.join(test_dir_path, SNAPSHOT_FILE_NAME) wal_file = os.path.join(test_dir_path, WAL_FILE_NAME) - dump_snapshot_file = os.path.join( - test_dir_path, DUMP_SNAPSHOT_FILE_NAME) + dump_snapshot_file = os.path.join(test_dir_path, DUMP_SNAPSHOT_FILE_NAME) dump_wal_file = os.path.join(test_dir_path, DUMP_WAL_FILE_NAME) - if (os.path.isfile(snapshot_file) and os.path.isfile(dump_snapshot_file) - and os.path.isfile(wal_file) and os.path.isfile(dump_wal_file)): + if ( + os.path.isfile(snapshot_file) + and os.path.isfile(dump_snapshot_file) + and os.path.isfile(wal_file) + and os.path.isfile(dump_wal_file) + ): test_dirs.append(test_dir_path) else: - raise Exception("Missing data in test directory '{}'" - .format(test_dir_path)) + raise Exception("Missing data in test directory '{}'".format(test_dir_path)) return test_dirs @@ -161,26 +166,15 @@ if __name__ == "__main__": parser.add_argument("--memgraph", default=memgraph_binary) parser.add_argument("--dump", default=dump_binary) parser.add_argument( - '--write-expected', - action='store_true', - help='Overwrite the expected cypher with results from current run') + "--write-expected", action="store_true", help="Overwrite the expected cypher with results from current run" + ) args = parser.parse_args() test_directories = find_test_directories(TESTS_DIR) assert len(test_directories) > 0, "No tests have been found!" for test_directory in test_directories: - execute_test( - args.memgraph, - args.dump, - test_directory, - "SNAPSHOT", - args.write_expected) - execute_test( - args.memgraph, - args.dump, - test_directory, - "WAL", - args.write_expected) + execute_test(args.memgraph, args.dump, test_directory, "SNAPSHOT", args.write_expected) + execute_test(args.memgraph, args.dump, test_directory, "WAL", args.write_expected) sys.exit(0) diff --git a/tests/integration/env_variable_check/runner.py b/tests/integration/env_variable_check/runner.py index 64cc18365..acd26709e 100644 --- a/tests/integration/env_variable_check/runner.py +++ b/tests/integration/env_variable_check/runner.py @@ -22,6 +22,7 @@ from typing import List SCRIPT_DIR = Path(__file__).absolute() PROJECT_DIR = SCRIPT_DIR.parents[3] +SIGNAL_SIGTERM = 15 def wait_for_server(port, delay=0.1): @@ -68,8 +69,12 @@ def execute_with_user(queries): def cleanup(memgraph): if memgraph.poll() is None: - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False + time.sleep(1) def execute_without_user(queries, should_fail=False, failure_message="", check_failure=True): diff --git a/tests/integration/fine_grained_access/runner.py b/tests/integration/fine_grained_access/runner.py index 6f284aa39..590cf1409 100644 --- a/tests/integration/fine_grained_access/runner.py +++ b/tests/integration/fine_grained_access/runner.py @@ -24,6 +24,7 @@ SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) UNAUTHORIZED_ERROR = r"^You are not authorized to execute this query.*?Please contact your database administrator\." +SIGNAL_SIGTERM = 15 def wait_for_server(port, delay=0.1): @@ -80,8 +81,12 @@ def execute_test(memgraph_binary: str, tester_binary: str, filtering_binary: str @atexit.register def cleanup(): if memgraph.poll() is None: - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False + time.sleep(1) # Prepare all users def setup_user(): @@ -130,8 +135,12 @@ def execute_test(memgraph_binary: str, tester_binary: str, filtering_binary: str print("\033[1;36m~~ Finished edge filtering test ~~\033[0m\n") # Shutdown the memgraph binary - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False + time.sleep(1) if __name__ == "__main__": diff --git a/tests/integration/flag_check/runner.py b/tests/integration/flag_check/runner.py index 5306f1b80..1a3999871 100644 --- a/tests/integration/flag_check/runner.py +++ b/tests/integration/flag_check/runner.py @@ -21,6 +21,7 @@ from typing import List SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) +SIGNAL_SIGTERM = 15 def wait_for_server(port: int, delay: float = 0.1) -> float: @@ -86,8 +87,12 @@ def execute_without_user( def cleanup(memgraph: subprocess): if memgraph.poll() is None: - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False + time.sleep(1) def test_without_any_files(tester_binary: str, memgraph_args: List[str]): diff --git a/tests/integration/ldap/runner.py b/tests/integration/ldap/runner.py index 6b0446690..8fc3af913 100755 --- a/tests/integration/ldap/runner.py +++ b/tests/integration/ldap/runner.py @@ -23,6 +23,8 @@ import time SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) +SIGNAL_SIGTERM = 15 + CONFIG_TEMPLATE = """ server: host: "127.0.0.1" @@ -52,8 +54,7 @@ def wait_for_server(port, delay=0.1): time.sleep(delay) -def execute_tester(binary, queries, username="", password="", - auth_should_fail=False, query_should_fail=False): +def execute_tester(binary, queries, username="", password="", auth_should_fail=False, query_should_fail=False): if password == "": password = username args = [binary, "--username", username, "--password", password] @@ -76,18 +77,14 @@ class Memgraph: def start(self, **kwargs): self.stop() self._storage_directory = tempfile.TemporaryDirectory() - self._auth_module = os.path.join(self._storage_directory.name, - "ldap.py") - self._auth_config = os.path.join(self._storage_directory.name, - "ldap.yaml") - script_file = os.path.join(PROJECT_DIR, "src", "auth", - "reference_modules", "ldap.py") + self._auth_module = os.path.join(self._storage_directory.name, "ldap.py") + self._auth_config = os.path.join(self._storage_directory.name, "ldap.yaml") + script_file = os.path.join(PROJECT_DIR, "src", "auth", "reference_modules", "ldap.py") virtualenv_bin = os.path.join(SCRIPT_DIR, "ve3", "bin", "python3") with open(script_file) as fin: data = fin.read() data = data.replace("/usr/bin/python3", virtualenv_bin) - data = data.replace("/etc/memgraph/auth/ldap.yaml", - self._auth_config) + data = data.replace("/etc/memgraph/auth/ldap.yaml", self._auth_config) with open(self._auth_module, "w") as fout: fout.write(data) os.chmod(self._auth_module, stat.S_IRWXU | stat.S_IRWXG) @@ -106,10 +103,13 @@ class Memgraph: } with open(self._auth_config, "w") as f: f.write(CONFIG_TEMPLATE.format(**config)) - args = [self._binary, - "--data-directory", self._storage_directory.name, - "--auth-module-executable", - kwargs.pop("module_executable", self._auth_module)] + args = [ + self._binary, + "--data-directory", + self._storage_directory.name, + "--auth-module-executable", + kwargs.pop("module_executable", self._auth_module), + ] for key, value in kwargs.items(): ldap_key = "--auth-module-" + key.replace("_", "-") if isinstance(value, bool): @@ -119,26 +119,27 @@ class Memgraph: args.append(value) self._process = subprocess.Popen(args) time.sleep(0.1) - assert self._process.poll() is None, "Memgraph process died " \ - "prematurely!" + assert self._process.poll() is None, "Memgraph process died " "prematurely!" wait_for_server(7687) def stop(self, check=True): if self._process is None: return 0 - self._process.terminate() - exitcode = self._process.wait() - self._process = None - if check: - assert exitcode == 0, "Memgraph process didn't exit cleanly!" - return exitcode + pid = self._process.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + if check: + assert False + return -1 + time.sleep(1) + return 0 def initialize_test(memgraph, tester_binary, **kwargs): memgraph.start(module_executable="") - execute_tester(tester_binary, - ["CREATE USER root", "GRANT ALL PRIVILEGES TO root"]) + execute_tester(tester_binary, ["CREATE USER root", "GRANT ALL PRIVILEGES TO root"]) check_login = kwargs.pop("check_login", True) memgraph.restart(**kwargs) if check_login: @@ -170,18 +171,15 @@ def test_role_mapping(memgraph, tester_binary): initialize_test(memgraph, tester_binary) execute_tester(tester_binary, [], "alice") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", - query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") execute_tester(tester_binary, [], "bob") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob", - query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob", query_should_fail=True) execute_tester(tester_binary, [], "carol") - execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", - query_should_fail=True) + execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", query_should_fail=True) execute_tester(tester_binary, ["GRANT CREATE TO admin"], "root") execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol") execute_tester(tester_binary, ["CREATE (n) RETURN n"], "dave") @@ -192,15 +190,13 @@ def test_role_mapping(memgraph, tester_binary): def test_role_removal(memgraph, tester_binary): initialize_test(memgraph, tester_binary) execute_tester(tester_binary, [], "alice") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", - query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.restart(manage_roles=False) execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") execute_tester(tester_binary, ["CLEAR ROLE FOR alice"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", - query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) memgraph.stop() @@ -229,28 +225,22 @@ def test_user_is_role(memgraph, tester_binary): def test_user_permissions_persistancy(memgraph, tester_binary): initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, - ["CREATE USER alice", "GRANT MATCH TO alice"], "root") + execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() def test_role_permissions_persistancy(memgraph, tester_binary): initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, - ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], - "root") + execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() def test_only_authentication(memgraph, tester_binary): initialize_test(memgraph, tester_binary, manage_roles=False) - execute_tester(tester_binary, - ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], - "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", - query_should_fail=True) + execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) memgraph.stop() @@ -267,22 +257,16 @@ def test_wrong_suffix(memgraph, tester_binary): def test_suffix_with_spaces(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, - suffix=", ou= people, dc = memgraph, dc = com") - execute_tester(tester_binary, - ["CREATE USER alice", "GRANT MATCH TO alice"], "root") + initialize_test(memgraph, tester_binary, suffix=", ou= people, dc = memgraph, dc = com") + execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() def test_role_mapping_wrong_root_dn(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, - root_dn="ou=invalid,dc=memgraph,dc=com") - execute_tester(tester_binary, - ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], - "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", - query_should_fail=True) + initialize_test(memgraph, tester_binary, root_dn="ou=invalid,dc=memgraph,dc=com") + execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) memgraph.restart() execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() @@ -290,11 +274,8 @@ def test_role_mapping_wrong_root_dn(memgraph, tester_binary): def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary): initialize_test(memgraph, tester_binary, root_objectclass="person") - execute_tester(tester_binary, - ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], - "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", - query_should_fail=True) + execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) memgraph.restart() execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() @@ -302,11 +283,8 @@ def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary): def test_role_mapping_wrong_user_attribute(memgraph, tester_binary): initialize_test(memgraph, tester_binary, user_attribute="cn") - execute_tester(tester_binary, - ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], - "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", - query_should_fail=True) + execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) memgraph.restart() execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() @@ -314,8 +292,7 @@ def test_role_mapping_wrong_user_attribute(memgraph, tester_binary): def test_wrong_password(memgraph, tester_binary): initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, [], "root", password="sudo", - auth_should_fail=True) + execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True) execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") memgraph.stop() @@ -326,12 +303,10 @@ def test_password_persistancy(memgraph, tester_binary): execute_tester(tester_binary, ["SHOW USERS"], "root", password="sudo") execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") memgraph.restart() - execute_tester(tester_binary, [], "root", password="sudo", - auth_should_fail=True) + execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True) execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") memgraph.restart(module_executable="") - execute_tester(tester_binary, [], "root", password="sudo", - auth_should_fail=True) + execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True) execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") memgraph.stop() @@ -339,33 +314,25 @@ def test_password_persistancy(memgraph, tester_binary): def test_user_multiple_roles(memgraph, tester_binary): initialize_test(memgraph, tester_binary, check_login=False) memgraph.restart() - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", - query_should_fail=True) - execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", - query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True) + execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True) memgraph.restart(manage_roles=False) - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", - query_should_fail=True) - execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", - query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True) + execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True) memgraph.restart(manage_roles=False, root_dn="") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", - query_should_fail=True) - execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", - query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True) + execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True) memgraph.stop() def test_starttls_failure(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, encryption="starttls", - check_login=False) + initialize_test(memgraph, tester_binary, encryption="starttls", check_login=False) execute_tester(tester_binary, [], "root", auth_should_fail=True) memgraph.stop() def test_ssl_failure(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, encryption="ssl", - check_login=False) + initialize_test(memgraph, tester_binary, encryption="ssl", check_login=False) execute_tester(tester_binary, [], "root", auth_should_fail=True) memgraph.stop() @@ -375,22 +342,19 @@ def test_ssl_failure(memgraph, tester_binary): if __name__ == "__main__": memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph") - tester_binary = os.path.join(PROJECT_DIR, "build", "tests", - "integration", "ldap", "tester") + tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "ldap", "tester") parser = argparse.ArgumentParser() parser.add_argument("--memgraph", default=memgraph_binary) parser.add_argument("--tester", default=tester_binary) - parser.add_argument("--openldap-dir", - default=os.path.join(SCRIPT_DIR, "openldap-2.4.47")) + parser.add_argument("--openldap-dir", default=os.path.join(SCRIPT_DIR, "openldap-2.4.47")) args = parser.parse_args() # Setup Memgraph handler memgraph = Memgraph(args.memgraph) # Start the slapd binary - slapd_args = [os.path.join(args.openldap_dir, "exe", "libexec", "slapd"), - "-h", "ldap://127.0.0.1:1389/", "-d", "0"] + slapd_args = [os.path.join(args.openldap_dir, "exe", "libexec", "slapd"), "-h", "ldap://127.0.0.1:1389/", "-d", "0"] slapd = subprocess.Popen(slapd_args) time.sleep(0.1) assert slapd.poll() is None, "slapd process died prematurely!" @@ -409,8 +373,7 @@ if __name__ == "__main__": if slapd_stat != 0: print("slapd process didn't exit cleanly!") - assert mg_stat == 0 and slapd_stat == 0, "Some of the processes " \ - "(memgraph, slapd) crashed!" + assert mg_stat == 0 and slapd_stat == 0, "Some of the processes " "(memgraph, slapd) crashed!" # Execute tests names = sorted(globals().keys()) diff --git a/tests/integration/mg_import_csv/runner.py b/tests/integration/mg_import_csv/runner.py index afaa58e28..4bd54dce8 100755 --- a/tests/integration/mg_import_csv/runner.py +++ b/tests/integration/mg_import_csv/runner.py @@ -18,12 +18,13 @@ import subprocess import sys import tempfile import time -import yaml +import yaml SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) BUILD_DIR = os.path.join(BASE_DIR, "build") +SIGNAL_SIGTERM = 15 def wait_for_server(port, delay=0.1): @@ -46,17 +47,14 @@ def list_to_string(data): def verify_lifetime(memgraph_binary, mg_import_csv_binary): - print("\033[1;36m~~ Verifying that mg_import_csv can't be started while " - "memgraph is running ~~\033[0m") + print("\033[1;36m~~ Verifying that mg_import_csv can't be started while " "memgraph is running ~~\033[0m") storage_directory = tempfile.TemporaryDirectory() # Generate common args - common_args = ["--data-directory", storage_directory.name, - "--storage-properties-on-edges=false"] + common_args = ["--data-directory", storage_directory.name, "--storage-properties-on-edges=false"] # Start the memgraph binary - memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + \ - common_args + memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + common_args memgraph = subprocess.Popen(list(map(str, memgraph_args))) time.sleep(0.1) assert memgraph.poll() is None, "Memgraph process died prematurely!" @@ -66,47 +64,52 @@ def verify_lifetime(memgraph_binary, mg_import_csv_binary): @atexit.register def cleanup(): if memgraph.poll() is None: - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False, "Memgraph process didn't exit cleanly!" + time.sleep(1) # Execute mg_import_csv. - mg_import_csv_args = [mg_import_csv_binary, "--nodes", "/dev/null"] + \ - common_args + mg_import_csv_args = [mg_import_csv_binary, "--nodes", "/dev/null"] + common_args ret = subprocess.run(mg_import_csv_args) # Check the return code if ret.returncode == 0: - raise Exception( - "The importer was able to run while memgraph was running!") + raise Exception("The importer was able to run while memgraph was running!") # Shutdown the memgraph binary - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False, "Memgraph process didn't exit cleanly!" + time.sleep(1) print("\033[1;32m~~ Test successful ~~\033[0m\n") -def execute_test(name, test_path, test_config, memgraph_binary, - mg_import_csv_binary, tester_binary, write_expected): +def execute_test(name, test_path, test_config, memgraph_binary, mg_import_csv_binary, tester_binary, write_expected): print("\033[1;36m~~ Executing test", name, "~~\033[0m") storage_directory = tempfile.TemporaryDirectory() # Verify test configuration - if ("import_should_fail" not in test_config and - "expected" not in test_config) or \ - ("import_should_fail" in test_config and - "expected" in test_config): - raise Exception("The test should specify either 'import_should_fail' " - "or 'expected'!") + if ("import_should_fail" not in test_config and "expected" not in test_config) or ( + "import_should_fail" in test_config and "expected" in test_config + ): + raise Exception("The test should specify either 'import_should_fail' " "or 'expected'!") expected_path = test_config.pop("expected", "") import_should_fail = test_config.pop("import_should_fail", False) # Generate common args properties_on_edges = bool(test_config.pop("properties_on_edges", False)) - common_args = ["--data-directory", storage_directory.name, - "--storage-properties-on-edges=" + - str(properties_on_edges).lower()] + common_args = [ + "--data-directory", + storage_directory.name, + "--storage-properties-on-edges=" + str(properties_on_edges).lower(), + ] # Generate mg_import_csv args using flags specified in the test mg_import_csv_args = [mg_import_csv_binary] + common_args @@ -125,19 +128,16 @@ def execute_test(name, test_path, test_config, memgraph_binary, if import_should_fail: if ret.returncode == 0: - raise Exception("The import should have failed, but it " - "succeeded instead!") + raise Exception("The import should have failed, but it " "succeeded instead!") else: print("\033[1;32m~~ Test successful ~~\033[0m\n") return else: if ret.returncode != 0: - raise Exception("The import should have succeeded, but it " - "failed instead!") + raise Exception("The import should have succeeded, but it " "failed instead!") # Start the memgraph binary - memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + \ - common_args + memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + common_args memgraph = subprocess.Popen(list(map(str, memgraph_args))) time.sleep(0.1) assert memgraph.poll() is None, "Memgraph process died prematurely!" @@ -147,21 +147,29 @@ def execute_test(name, test_path, test_config, memgraph_binary, @atexit.register def cleanup(): if memgraph.poll() is None: - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False, "Memgraph process didn't exit cleanly!" + time.sleep(1) # Get the contents of the database - queries_got = extract_rows(subprocess.run( - [tester_binary], stdout=subprocess.PIPE, - check=True).stdout.decode("utf-8")) + queries_got = extract_rows( + subprocess.run([tester_binary], stdout=subprocess.PIPE, check=True).stdout.decode("utf-8") + ) # Shutdown the memgraph binary - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False, "Memgraph process didn't exit cleanly!" + time.sleep(1) if write_expected: - with open(os.path.join(test_path, expected_path), 'w') as expected: - expected.write('\n'.join(queries_got)) + with open(os.path.join(test_path, expected_path), "w") as expected: + expected.write("\n".join(queries_got)) else: if expected_path: @@ -173,18 +181,16 @@ def execute_test(name, test_path, test_config, memgraph_binary, # Verify the queries queries_expected.sort() queries_got.sort() - assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" \ - "{}".format(list_to_string(queries_got), - list_to_string(queries_expected)) + assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" "{}".format( + list_to_string(queries_got), list_to_string(queries_expected) + ) print("\033[1;32m~~ Test successful ~~\033[0m\n") if __name__ == "__main__": memgraph_binary = os.path.join(BUILD_DIR, "memgraph") - mg_import_csv_binary = os.path.join( - BUILD_DIR, "src", "mg_import_csv") - tester_binary = os.path.join( - BUILD_DIR, "tests", "integration", "mg_import_csv", "tester") + mg_import_csv_binary = os.path.join(BUILD_DIR, "src", "mg_import_csv") + tester_binary = os.path.join(BUILD_DIR, "tests", "integration", "mg_import_csv", "tester") parser = argparse.ArgumentParser() parser.add_argument("--memgraph", default=memgraph_binary) @@ -193,7 +199,8 @@ if __name__ == "__main__": parser.add_argument( "--write-expected", action="store_true", - help="Overwrite the expected values with the results of the current run") + help="Overwrite the expected values with the results of the current run", + ) args = parser.parse_args() # First test whether the CSV importer can be started while the main @@ -211,7 +218,8 @@ if __name__ == "__main__": testcases = yaml.safe_load(f) for test_config in testcases: test_name = name + "/" + test_config.pop("name") - execute_test(test_name, test_path, test_config, args.memgraph, - args.mg_import_csv, args.tester, args.write_expected) + execute_test( + test_name, test_path, test_config, args.memgraph, args.mg_import_csv, args.tester, args.write_expected + ) sys.exit(0) diff --git a/tests/integration/run_time_settings/runner.py b/tests/integration/run_time_settings/runner.py index 32e20d018..3d03077af 100755 --- a/tests/integration/run_time_settings/runner.py +++ b/tests/integration/run_time_settings/runner.py @@ -23,6 +23,7 @@ from typing import List SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) +SIGNAL_SIGTERM = 15 def wait_for_server(port: int, delay: float = 0.1) -> float: @@ -91,8 +92,12 @@ def check_config(tester_binary: str, flag: str, value: str) -> None: def cleanup(memgraph: subprocess): if memgraph.poll() is None: - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False, "Memgraph process didn't exit cleanly!" + time.sleep(1) def run_test(tester_binary: str, memgraph_args: List[str], server_name: str, query_tx: str): diff --git a/tests/integration/storage_mode/runner.py b/tests/integration/storage_mode/runner.py index 9a3182149..995d6d834 100644 --- a/tests/integration/storage_mode/runner.py +++ b/tests/integration/storage_mode/runner.py @@ -22,6 +22,8 @@ assertion_queries = [ f"MATCH (n)-[e]->(m) WITH count(e) as cnt RETURN assert(cnt={len(edge_queries)});", ] +SIGNAL_SIGTERM = 15 + def wait_for_server(port, delay=0.1): cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)] @@ -40,9 +42,12 @@ def prepare_memgraph(memgraph_args): def terminate_memgraph(memgraph): - memgraph.terminate() - time.sleep(0.1) - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False, "Memgraph process didn't exit cleanly!" + time.sleep(1) def execute_tester( @@ -90,8 +95,12 @@ def execute_test_analytical_mode(memgraph_binary: str, tester_binary: str) -> No execute_queries(assertion_queries) - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False, "Memgraph process didn't exit cleanly!" + time.sleep(1) def execute_test_switch_analytical_transactional(memgraph_binary: str, tester_binary: str) -> None: @@ -135,8 +144,12 @@ def execute_test_switch_analytical_transactional(memgraph_binary: str, tester_bi execute_queries(assertion_queries) print("\033[1;36m~~ Terminating memgraph ~~\033[0m\n") - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False, "Memgraph process didn't exit cleanly!" + time.sleep(1) def execute_test_switch_transactional_analytical(memgraph_binary: str, tester_binary: str) -> None: @@ -177,8 +190,12 @@ def execute_test_switch_transactional_analytical(memgraph_binary: str, tester_bi execute_queries(assertion_queries) print("\033[1;36m~~ Terminating memgraph ~~\033[0m\n") - memgraph.terminate() - assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False, "Memgraph process didn't exit cleanly!" + time.sleep(1) if __name__ == "__main__": diff --git a/tests/macro_benchmark/databases.py b/tests/macro_benchmark/databases.py index 439464c85..bf1896adc 100644 --- a/tests/macro_benchmark/databases.py +++ b/tests/macro_benchmark/databases.py @@ -11,12 +11,13 @@ import logging import os +import shutil import subprocess +import tempfile +import time from argparse import ArgumentParser from collections import defaultdict -import tempfile -import shutil -import time + from common import get_absolute_path, set_cpus try: @@ -36,13 +37,12 @@ class Memgraph: """ Knows how to start and stop memgraph. """ + def __init__(self, args, num_workers): self.log = logging.getLogger("MemgraphRunner") argp = ArgumentParser("MemgraphArgumentParser") - argp.add_argument("--runner-bin", - default=get_absolute_path("memgraph", "build")) - argp.add_argument("--port", default="7687", - help="Database and client port") + argp.add_argument("--runner-bin", default=get_absolute_path("memgraph", "build")) + argp.add_argument("--port", default="7687", help="Database and client port") argp.add_argument("--data-directory", default=None) argp.add_argument("--storage-snapshot-on-exit", action="store_true") argp.add_argument("--storage-recover-on-startup", action="store_true") @@ -55,8 +55,7 @@ class Memgraph: def start(self): self.log.info("start") - database_args = ["--bolt-port", self.args.port, - "--query-execution-timeout-sec", "0"] + database_args = ["--bolt-port", self.args.port, "--query-execution-timeout-sec", "0"] if self.num_workers: database_args += ["--bolt-num-workers", str(self.num_workers)] if self.args.data_directory: @@ -82,15 +81,13 @@ class Neo: """ Knows how to start and stop neo4j. """ + def __init__(self, args, config): self.log = logging.getLogger("NeoRunner") argp = ArgumentParser("NeoArgumentParser") - argp.add_argument("--runner-bin", default=get_absolute_path( - "neo4j/bin/neo4j", "libs")) - argp.add_argument("--port", default="7687", - help="Database and client port") - argp.add_argument("--http-port", default="7474", - help="Database and client port") + argp.add_argument("--runner-bin", default=get_absolute_path("neo4j/bin/neo4j", "libs")) + argp.add_argument("--port", default="7687", help="Database and client port") + argp.add_argument("--http-port", default="7474", help="Database and client port") self.log.info("Initializing Runner with arguments %r", args) self.args, _ = argp.parse_known_args(args) self.config = config @@ -105,24 +102,22 @@ class Neo: self.neo4j_home_path = tempfile.mkdtemp(dir="/dev/shm") try: - os.symlink(os.path.join(get_absolute_path("neo4j", "libs"), "lib"), - os.path.join(self.neo4j_home_path, "lib")) + os.symlink( + os.path.join(get_absolute_path("neo4j", "libs"), "lib"), os.path.join(self.neo4j_home_path, "lib") + ) neo4j_conf_dir = os.path.join(self.neo4j_home_path, "conf") neo4j_conf_file = os.path.join(neo4j_conf_dir, "neo4j.conf") os.mkdir(neo4j_conf_dir) shutil.copyfile(self.config, neo4j_conf_file) with open(neo4j_conf_file, "a") as f: - f.write("\ndbms.connector.bolt.listen_address=:" + - self.args.port + "\n") - f.write("\ndbms.connector.http.listen_address=:" + - self.args.http_port + "\n") + f.write("\ndbms.connector.bolt.listen_address=:" + self.args.port + "\n") + f.write("\ndbms.connector.http.listen_address=:" + self.args.http_port + "\n") # environment cwd = os.path.dirname(self.args.runner_bin) env = {"NEO4J_HOME": self.neo4j_home_path} - self.database_bin.run(self.args.runner_bin, args=["console"], - env=env, timeout=600, cwd=cwd) + self.database_bin.run(self.args.runner_bin, args=["console"], env=env, timeout=600, cwd=cwd) except: shutil.rmtree(self.neo4j_home_path) raise Exception("Couldn't run Neo4j!")