Use extent hooks for per query memory limit (#1340)

This commit is contained in:
Antonio Filipovic 2023-10-25 16:01:59 +02:00 committed by GitHub
parent 3d4d841753
commit a84f570c6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 1275 additions and 267 deletions

View File

@ -111,6 +111,22 @@ enum mgp_error mgp_global_aligned_alloc(size_t size_in_bytes, size_t alignment,
/// The behavior is undefined if `ptr` is not a value returned from a prior /// The behavior is undefined if `ptr` is not a value returned from a prior
/// mgp_global_alloc() or mgp_global_aligned_alloc(). /// mgp_global_alloc() or mgp_global_aligned_alloc().
void mgp_global_free(void *p); void mgp_global_free(void *p);
/// State of the graph database.
struct mgp_graph;
/// Allocations are tracked only for master thread. If new threads are spawned
/// inside procedure, by calling following function
/// you can start tracking allocations for current thread too. This
/// is important if you need query memory limit to work
/// for given procedure or per procedure memory limit.
enum mgp_error mgp_track_current_thread_allocations(struct mgp_graph *graph);
/// Once allocations are tracked for current thread, you need to stop tracking allocations
/// for given thread, before thread finishes with execution, or is detached.
/// Otherwise it might result in slowdown of system due to unnecessary tracking of
/// allocations.
enum mgp_error mgp_untrack_current_thread_allocations(struct mgp_graph *graph);
///@} ///@}
/// @name Operations on mgp_value /// @name Operations on mgp_value
@ -854,9 +870,6 @@ enum mgp_error mgp_edge_set_properties(struct mgp_edge *e, struct mgp_map *prope
enum mgp_error mgp_edge_iter_properties(struct mgp_edge *e, struct mgp_memory *memory, enum mgp_error mgp_edge_iter_properties(struct mgp_edge *e, struct mgp_memory *memory,
struct mgp_properties_iterator **result); struct mgp_properties_iterator **result);
/// State of the graph database.
struct mgp_graph;
/// Get the vertex corresponding to given ID, or NULL if no such vertex exists. /// Get the vertex corresponding to given ID, or NULL if no such vertex exists.
/// Resulting vertex must be freed using mgp_vertex_destroy. /// Resulting vertex must be freed using mgp_vertex_destroy.
/// Return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate the vertex. /// Return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate the vertex.

View File

@ -23,7 +23,7 @@
#include "glue/run_id.hpp" #include "glue/run_id.hpp"
#include "helpers.hpp" #include "helpers.hpp"
#include "license/license_sender.hpp" #include "license/license_sender.hpp"
#include "memory/memory_control.hpp" #include "memory/global_memory_control.hpp"
#include "query/config.hpp" #include "query/config.hpp"
#include "query/discard_value_stream.hpp" #include "query/discard_value_stream.hpp"
#include "query/interpreter.hpp" #include "query/interpreter.hpp"
@ -512,6 +512,7 @@ int main(int argc, char **argv) {
server.AwaitShutdown(); server.AwaitShutdown();
websocket_server.AwaitShutdown(); websocket_server.AwaitShutdown();
memgraph::memory::UnsetHooks();
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
metrics_server.AwaitShutdown(); metrics_server.AwaitShutdown();

View File

@ -1,6 +1,7 @@
set(memory_src_files set(memory_src_files
new_delete.cpp new_delete.cpp
memory_control.cpp) global_memory_control.cpp
query_memory_control.cpp)

View File

@ -9,12 +9,16 @@
// by the Apache License, Version 2.0, included in the file // by the Apache License, Version 2.0, included in the file
// licenses/APL.txt. // licenses/APL.txt.
#include "memory_control.hpp" #include <atomic>
#include <cstdint>
#include "global_memory_control.hpp"
#include "query_memory_control.hpp"
#include "utils/logging.hpp" #include "utils/logging.hpp"
#include "utils/memory_tracker.hpp" #include "utils/memory_tracker.hpp"
#if USE_JEMALLOC #if USE_JEMALLOC
#include <jemalloc/jemalloc.h> #include "jemalloc/jemalloc.h"
#endif #endif
namespace memgraph::memory { namespace memgraph::memory {
@ -57,12 +61,24 @@ void *my_alloc(extent_hooks_t *extent_hooks, void *new_addr, size_t size, size_t
// This needs to be before, to throw exception in case of too big alloc // This needs to be before, to throw exception in case of too big alloc
if (*commit) [[likely]] { if (*commit) [[likely]] {
memgraph::utils::total_memory_tracker.Alloc(static_cast<int64_t>(size)); memgraph::utils::total_memory_tracker.Alloc(static_cast<int64_t>(size));
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
if (memory_tracker != nullptr) [[likely]] {
memory_tracker->Alloc(static_cast<int64_t>(size));
}
}
} }
auto *ptr = old_hooks->alloc(extent_hooks, new_addr, size, alignment, zero, commit, arena_ind); auto *ptr = old_hooks->alloc(extent_hooks, new_addr, size, alignment, zero, commit, arena_ind);
if (ptr == nullptr) [[unlikely]] { if (ptr == nullptr) [[unlikely]] {
if (*commit) { if (*commit) {
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size)); memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
if (memory_tracker != nullptr) [[likely]] {
memory_tracker->Free(static_cast<int64_t>(size));
}
}
} }
return ptr; return ptr;
} }
@ -79,6 +95,13 @@ static bool my_dalloc(extent_hooks_t *extent_hooks, void *addr, size_t size, boo
if (committed) [[likely]] { if (committed) [[likely]] {
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size)); memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
if (memory_tracker != nullptr) [[likely]] {
memory_tracker->Free(static_cast<int64_t>(size));
}
}
} }
return false; return false;
@ -87,6 +110,12 @@ static bool my_dalloc(extent_hooks_t *extent_hooks, void *addr, size_t size, boo
static void my_destroy(extent_hooks_t *extent_hooks, void *addr, size_t size, bool committed, unsigned arena_ind) { static void my_destroy(extent_hooks_t *extent_hooks, void *addr, size_t size, bool committed, unsigned arena_ind) {
if (committed) [[likely]] { if (committed) [[likely]] {
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size)); memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
if (memory_tracker != nullptr) [[likely]] {
memory_tracker->Free(static_cast<int64_t>(size));
}
}
} }
old_hooks->destroy(extent_hooks, addr, size, committed, arena_ind); old_hooks->destroy(extent_hooks, addr, size, committed, arena_ind);
@ -101,6 +130,12 @@ static bool my_commit(extent_hooks_t *extent_hooks, void *addr, size_t size, siz
} }
memgraph::utils::total_memory_tracker.Alloc(static_cast<int64_t>(length)); memgraph::utils::total_memory_tracker.Alloc(static_cast<int64_t>(length));
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
if (memory_tracker != nullptr) [[likely]] {
memory_tracker->Alloc(static_cast<int64_t>(size));
}
}
return false; return false;
} }
@ -115,6 +150,12 @@ static bool my_decommit(extent_hooks_t *extent_hooks, void *addr, size_t size, s
} }
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(length)); memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(length));
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
if (memory_tracker != nullptr) [[likely]] {
memory_tracker->Free(static_cast<int64_t>(size));
}
}
return false; return false;
} }
@ -129,6 +170,13 @@ static bool my_purge_forced(extent_hooks_t *extent_hooks, void *addr, size_t siz
} }
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(length)); memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(length));
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
if (memory_tracker != nullptr) [[likely]] {
memory_tracker->Alloc(static_cast<int64_t>(size));
}
}
return false; return false;
} }
@ -153,6 +201,7 @@ void SetHooks() {
} }
for (int i = 0; i < n_arenas; i++) { for (int i = 0; i < n_arenas; i++) {
GetQueriesMemoryControl().InitializeArenaCounter(i);
std::string func_name = "arena." + std::to_string(i) + ".extent_hooks"; std::string func_name = "arena." + std::to_string(i) + ".extent_hooks";
size_t hooks_len = sizeof(old_hooks); size_t hooks_len = sizeof(old_hooks);
@ -197,6 +246,45 @@ void SetHooks() {
#endif #endif
} }
void UnsetHooks() {
#if USE_JEMALLOC
uint64_t allocated{0};
uint64_t sz{sizeof(allocated)};
sz = sizeof(unsigned);
unsigned n_arenas{0};
int err = mallctl("opt.narenas", (void *)&n_arenas, &sz, nullptr, 0);
if (err) {
LOG_FATAL("Error setting default hooks for jemalloc arenas");
}
for (int i = 0; i < n_arenas; i++) {
GetQueriesMemoryControl().InitializeArenaCounter(i);
std::string func_name = "arena." + std::to_string(i) + ".extent_hooks";
MG_ASSERT(old_hooks);
MG_ASSERT(old_hooks->alloc);
MG_ASSERT(old_hooks->dalloc);
MG_ASSERT(old_hooks->destroy);
MG_ASSERT(old_hooks->commit);
MG_ASSERT(old_hooks->decommit);
MG_ASSERT(old_hooks->purge_forced);
MG_ASSERT(old_hooks->purge_lazy);
MG_ASSERT(old_hooks->split);
MG_ASSERT(old_hooks->merge);
err = mallctl(func_name.c_str(), nullptr, nullptr, &old_hooks, sizeof(old_hooks));
if (err) {
LOG_FATAL("Error setting default hooks for jemalloc arena {}", i);
}
}
#endif
}
void PurgeUnusedMemory() { void PurgeUnusedMemory() {
#if USE_JEMALLOC #if USE_JEMALLOC
mallctl("arena." STRINGIFY(MALLCTL_ARENAS_ALL) ".purge", nullptr, nullptr, nullptr, 0); mallctl("arena." STRINGIFY(MALLCTL_ARENAS_ALL) ".purge", nullptr, nullptr, nullptr, 0);

View File

@ -17,5 +17,6 @@ namespace memgraph::memory {
void PurgeUnusedMemory(); void PurgeUnusedMemory();
void SetHooks(); void SetHooks();
void UnsetHooks();
} // namespace memgraph::memory } // namespace memgraph::memory

View File

@ -0,0 +1,140 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <atomic>
#include <cstdint>
#include <iostream>
#include <optional>
#include <shared_mutex>
#include <thread>
#include <tuple>
#include <utility>
#include "query_memory_control.hpp"
#include "utils/exceptions.hpp"
#include "utils/logging.hpp"
#include "utils/memory_tracker.hpp"
#include "utils/rw_spin_lock.hpp"
#if USE_JEMALLOC
#include "jemalloc/jemalloc.h"
#endif
namespace memgraph::memory {
#if USE_JEMALLOC
unsigned QueriesMemoryControl::GetArenaForThread() {
unsigned thread_arena{0};
size_t size_thread_arena = sizeof(thread_arena);
int err = mallctl("thread.arena", &thread_arena, &size_thread_arena, nullptr, 0);
if (err) {
LOG_FATAL("Can't get arena for thread.");
}
return thread_arena;
}
void QueriesMemoryControl::AddTrackingOnArena(unsigned arena_id) { arena_tracking[arena_id].fetch_add(1); }
void QueriesMemoryControl::RemoveTrackingOnArena(unsigned arena_id) { arena_tracking[arena_id].fetch_sub(1); }
void QueriesMemoryControl::UpdateThreadToTransactionId(const std::thread::id &thread_id, uint64_t transaction_id) {
auto accessor = thread_id_to_transaction_id.access();
accessor.insert({thread_id, transaction_id});
}
void QueriesMemoryControl::EraseThreadToTransactionId(const std::thread::id &thread_id, uint64_t transaction_id) {
auto accessor = thread_id_to_transaction_id.access();
auto elem = accessor.find(thread_id);
MG_ASSERT(elem != accessor.end() && elem->transaction_id == transaction_id);
accessor.remove(thread_id);
}
utils::MemoryTracker *QueriesMemoryControl::GetTrackerCurrentThread() {
auto thread_id_to_transaction_id_accessor = thread_id_to_transaction_id.access();
// we might be just constructing mapping between thread id and transaction id
// so we miss this allocation
auto thread_id_to_transaction_id_elem = thread_id_to_transaction_id_accessor.find(std::this_thread::get_id());
if (thread_id_to_transaction_id_elem == thread_id_to_transaction_id_accessor.end()) {
return nullptr;
}
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
auto transaction_id_to_tracker =
transaction_id_to_tracker_accessor.find(thread_id_to_transaction_id_elem->transaction_id);
return &transaction_id_to_tracker->tracker;
}
void QueriesMemoryControl::CreateTransactionIdTracker(uint64_t transaction_id, size_t inital_limit) {
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
auto [elem, result] = transaction_id_to_tracker_accessor.insert({transaction_id, utils::MemoryTracker{}});
elem->tracker.SetMaximumHardLimit(inital_limit);
elem->tracker.SetHardLimit(inital_limit);
}
bool QueriesMemoryControl::CheckTransactionIdTrackerExists(uint64_t transaction_id) {
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
return transaction_id_to_tracker_accessor.contains(transaction_id);
}
bool QueriesMemoryControl::EraseTransactionIdTracker(uint64_t transaction_id) {
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
auto removed = transaction_id_to_tracker.access().remove(transaction_id);
return removed;
}
bool QueriesMemoryControl::IsArenaTracked(unsigned arena_ind) {
return arena_tracking[arena_ind].load(std::memory_order_acquire) != 0;
}
void QueriesMemoryControl::InitializeArenaCounter(unsigned arena_ind) {
arena_tracking[arena_ind].store(0, std::memory_order_relaxed);
}
#endif
void StartTrackingCurrentThreadTransaction(uint64_t transaction_id) {
#if USE_JEMALLOC
GetQueriesMemoryControl().UpdateThreadToTransactionId(std::this_thread::get_id(), transaction_id);
GetQueriesMemoryControl().AddTrackingOnArena(QueriesMemoryControl::GetArenaForThread());
#endif
}
void StopTrackingCurrentThreadTransaction(uint64_t transaction_id) {
#if USE_JEMALLOC
GetQueriesMemoryControl().EraseThreadToTransactionId(std::this_thread::get_id(), transaction_id);
GetQueriesMemoryControl().RemoveTrackingOnArena(QueriesMemoryControl::GetArenaForThread());
#endif
}
void TryStartTrackingOnTransaction(uint64_t transaction_id, size_t limit) {
#if USE_JEMALLOC
if (GetQueriesMemoryControl().CheckTransactionIdTrackerExists(transaction_id)) {
return;
}
GetQueriesMemoryControl().CreateTransactionIdTracker(transaction_id, limit);
#endif
}
void TryStopTrackingOnTransaction(uint64_t transaction_id) {
#if USE_JEMALLOC
if (!GetQueriesMemoryControl().CheckTransactionIdTrackerExists(transaction_id)) {
return;
}
GetQueriesMemoryControl().EraseTransactionIdTracker(transaction_id);
#endif
}
} // namespace memgraph::memory

View File

@ -0,0 +1,141 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <cstddef>
#include <cstdint>
#include <thread>
#include <unordered_map>
#include "utils/memory_tracker.hpp"
#include "utils/skip_list.hpp"
namespace memgraph::memory {
#if USE_JEMALLOC
// Track memory allocations per query.
// Multiple threads can allocate inside one transaction.
// If user forgets to unregister tracking for that thread before it dies, it will continue to
// track allocations for that arena indefinitely.
// As multiple queries can be executed inside one transaction, one by one (multi-transaction)
// it is necessary to restart tracking at the beginning of new query for that transaction.
class QueriesMemoryControl {
public:
/*
Arena stats
*/
static unsigned GetArenaForThread();
// Add counter on threads allocating inside arena
void AddTrackingOnArena(unsigned);
// Remove counter on threads allocating in arena
void RemoveTrackingOnArena(unsigned);
// Are any threads using current arena for allocations
// Multiple threads can allocate inside one arena
bool IsArenaTracked(unsigned);
// Initialize arena counter
void InitializeArenaCounter(unsigned);
/*
Transaction id <-> tracker
*/
// Create new tracker for transaction_id with initial limit
void CreateTransactionIdTracker(uint64_t, size_t);
// Check if tracker for given transaction id exists
bool CheckTransactionIdTrackerExists(uint64_t);
// Remove current tracker for transaction_id
bool EraseTransactionIdTracker(uint64_t);
/*
Thread handlings
*/
// Map thread to transaction with given id
// This way we can know which thread belongs to which transaction
// and get correct tracker for given transaction
void UpdateThreadToTransactionId(const std::thread::id &, uint64_t);
// Remove tracking of thread from transaction.
// Important to reset if one thread gets reused for different transaction
void EraseThreadToTransactionId(const std::thread::id &, uint64_t);
// C-API functionality for thread to transaction mapping
void UpdateThreadToTransactionId(const char *, uint64_t);
// C-API functionality for thread to transaction unmapping
void EraseThreadToTransactionId(const char *, uint64_t);
// Get tracker to current thread if exists, otherwise return
// nullptr. This can happen only if tracker is still
// being constructed.
utils::MemoryTracker *GetTrackerCurrentThread();
private:
std::unordered_map<unsigned, std::atomic<int>> arena_tracking;
struct ThreadIdToTransactionId {
std::thread::id thread_id;
uint64_t transaction_id;
bool operator<(const ThreadIdToTransactionId &other) const { return thread_id < other.thread_id; }
bool operator==(const ThreadIdToTransactionId &other) const { return thread_id == other.thread_id; }
bool operator<(const std::thread::id other) const { return thread_id < other; }
bool operator==(const std::thread::id other) const { return thread_id == other; }
};
struct TransactionIdToTracker {
uint64_t transaction_id;
utils::MemoryTracker tracker;
bool operator<(const TransactionIdToTracker &other) const { return transaction_id < other.transaction_id; }
bool operator==(const TransactionIdToTracker &other) const { return transaction_id == other.transaction_id; }
bool operator<(uint64_t other) const { return transaction_id < other; }
bool operator==(uint64_t other) const { return transaction_id == other; }
};
utils::SkipList<ThreadIdToTransactionId> thread_id_to_transaction_id;
utils::SkipList<TransactionIdToTracker> transaction_id_to_tracker;
};
inline QueriesMemoryControl &GetQueriesMemoryControl() {
static QueriesMemoryControl queries_memory_control_;
return queries_memory_control_;
}
#endif
// API function call for to start tracking current thread for given transaction.
// Does nothing if jemalloc is not enabled
void StartTrackingCurrentThreadTransaction(uint64_t transaction_id);
// API function call for to stop tracking current thread for given transaction.
// Does nothing if jemalloc is not enabled
void StopTrackingCurrentThreadTransaction(uint64_t transaction_id);
// API function call for try to create tracker for transaction and set it to given limit.
// Does nothing if jemalloc is not enabled. Does nothing if tracker already exists
void TryStartTrackingOnTransaction(uint64_t transaction_id, size_t limit);
// API function call to stop tracking for given transaction.
// Does nothing if jemalloc is not enabled. Does nothing if tracker doesn't exist
void TryStopTrackingOnTransaction(uint64_t transaction_id);
} // namespace memgraph::memory

View File

@ -21,6 +21,10 @@ namespace memgraph::query {
SubgraphDbAccessor::SubgraphDbAccessor(query::DbAccessor db_accessor, Graph *graph) SubgraphDbAccessor::SubgraphDbAccessor(query::DbAccessor db_accessor, Graph *graph)
: db_accessor_(db_accessor), graph_(graph) {} : db_accessor_(db_accessor), graph_(graph) {}
void SubgraphDbAccessor::TrackCurrentThreadAllocations() { return db_accessor_.TrackCurrentThreadAllocations(); }
void SubgraphDbAccessor::UntrackCurrentThreadAllocations() { return db_accessor_.TrackCurrentThreadAllocations(); }
storage::PropertyId SubgraphDbAccessor::NameToProperty(const std::string_view name) { storage::PropertyId SubgraphDbAccessor::NameToProperty(const std::string_view name) {
return db_accessor_.NameToProperty(name); return db_accessor_.NameToProperty(name);
} }

View File

@ -17,6 +17,7 @@
#include <cppitertools/filter.hpp> #include <cppitertools/filter.hpp>
#include <cppitertools/imap.hpp> #include <cppitertools/imap.hpp>
#include "memory/query_memory_control.hpp"
#include "query/exceptions.hpp" #include "query/exceptions.hpp"
#include "storage/v2/edge_accessor.hpp" #include "storage/v2/edge_accessor.hpp"
#include "storage/v2/id_types.hpp" #include "storage/v2/id_types.hpp"
@ -372,6 +373,16 @@ class DbAccessor final {
void FinalizeTransaction() { accessor_->FinalizeTransaction(); } void FinalizeTransaction() { accessor_->FinalizeTransaction(); }
void TrackCurrentThreadAllocations() {
memgraph::memory::StartTrackingCurrentThreadTransaction(*accessor_->GetTransactionId());
}
void UntrackCurrentThreadAllocations() {
memgraph::memory::StopTrackingCurrentThreadTransaction(*accessor_->GetTransactionId());
}
std::optional<uint64_t> GetTransactionId() { return accessor_->GetTransactionId(); }
VerticesIterable Vertices(storage::View view) { return VerticesIterable(accessor_->Vertices(view)); } VerticesIterable Vertices(storage::View view) { return VerticesIterable(accessor_->Vertices(view)); }
VerticesIterable Vertices(storage::View view, storage::LabelId label) { VerticesIterable Vertices(storage::View view, storage::LabelId label) {
@ -640,6 +651,14 @@ class SubgraphDbAccessor final {
static SubgraphDbAccessor *MakeSubgraphDbAccessor(DbAccessor *db_accessor, Graph *graph); static SubgraphDbAccessor *MakeSubgraphDbAccessor(DbAccessor *db_accessor, Graph *graph);
void TrackThreadAllocations(const char *thread_id);
void TrackCurrentThreadAllocations();
void UntrackThreadAllocations(const char *thread_id);
void UntrackCurrentThreadAllocations();
storage::PropertyId NameToProperty(std::string_view name); storage::PropertyId NameToProperty(std::string_view name);
storage::LabelId NameToLabel(std::string_view name); storage::LabelId NameToLabel(std::string_view name);

View File

@ -26,6 +26,7 @@
#include <optional> #include <optional>
#include <stdexcept> #include <stdexcept>
#include <thread> #include <thread>
#include <tuple>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <variant> #include <variant>
@ -39,7 +40,8 @@
#include "flags/run_time_configurable.hpp" #include "flags/run_time_configurable.hpp"
#include "glue/communication.hpp" #include "glue/communication.hpp"
#include "license/license.hpp" #include "license/license.hpp"
#include "memory/memory_control.hpp" #include "memory/global_memory_control.hpp"
#include "memory/query_memory_control.hpp"
#include "query/config.hpp" #include "query/config.hpp"
#include "query/constants.hpp" #include "query/constants.hpp"
#include "query/context.hpp" #include "query/context.hpp"
@ -1283,6 +1285,24 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *stream, std::optional<int> n, std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *stream, std::optional<int> n,
const std::vector<Symbol> &output_symbols, const std::vector<Symbol> &output_symbols,
std::map<std::string, TypedValue> *summary) { std::map<std::string, TypedValue> *summary) {
std::optional<uint64_t> transaction_id = ctx_.db_accessor->GetTransactionId();
MG_ASSERT(transaction_id.has_value());
if (memory_limit_) {
memgraph::memory::TryStartTrackingOnTransaction(*transaction_id, *memory_limit_);
memgraph::memory::StartTrackingCurrentThreadTransaction(*transaction_id);
}
utils::OnScopeExit<std::function<void()>> reset_query_limit{
[memory_limit = memory_limit_, transaction_id = *transaction_id]() {
if (memory_limit) {
// Stopping tracking of transaction occurs in interpreter::pull
// Exception can occur so we need to handle that case there.
// We can't stop tracking here as there can be multiple pulls
// so we need to take care of that after everything was pulled
memgraph::memory::StopTrackingCurrentThreadTransaction(transaction_id);
}
}};
// Set up temporary memory for a single Pull. Initial memory comes from the // Set up temporary memory for a single Pull. Initial memory comes from the
// stack. 256 KiB should fit on the stack and should be more than enough for a // stack. 256 KiB should fit on the stack and should be more than enough for a
// single `Pull`. // single `Pull`.
@ -1306,13 +1326,7 @@ std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *strea
pool_memory.emplace(kMaxBlockPerChunks, 1024, &monotonic_memory, &resource_with_exception); pool_memory.emplace(kMaxBlockPerChunks, 1024, &monotonic_memory, &resource_with_exception);
} }
std::optional<utils::LimitedMemoryResource> maybe_limited_resource;
if (memory_limit_) {
maybe_limited_resource.emplace(&*pool_memory, *memory_limit_);
ctx_.evaluation_context.memory = &*maybe_limited_resource;
} else {
ctx_.evaluation_context.memory = &*pool_memory; ctx_.evaluation_context.memory = &*pool_memory;
}
// Returns true if a result was pulled. // Returns true if a result was pulled.
const auto pull_result = [&]() -> bool { return cursor_->Pull(frame_, ctx_); }; const auto pull_result = [&]() -> bool { return cursor_->Pull(frame_, ctx_); };
@ -1379,6 +1393,7 @@ std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *strea
} }
cursor_->Shutdown(); cursor_->Shutdown();
ctx_.profile_execution_time = execution_time_; ctx_.profile_execution_time = execution_time_;
return GetStatsWithTotalTime(ctx_); return GetStatsWithTotalTime(ctx_);
} }

View File

@ -16,6 +16,7 @@
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include "dbms/database.hpp" #include "dbms/database.hpp"
#include "memory/query_memory_control.hpp"
#include "query/auth_checker.hpp" #include "query/auth_checker.hpp"
#include "query/auth_query_handler.hpp" #include "query/auth_query_handler.hpp"
#include "query/config.hpp" #include "query/config.hpp"
@ -402,6 +403,9 @@ std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std:
// If the query finished executing, we have received a value which tells // If the query finished executing, we have received a value which tells
// us what to do after. // us what to do after.
if (maybe_res) { if (maybe_res) {
if (current_transaction_) {
memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_);
}
// Save its summary // Save its summary
maybe_summary.emplace(std::move(query_execution->summary)); maybe_summary.emplace(std::move(query_execution->summary));
if (!query_execution->notifications.empty()) { if (!query_execution->notifications.empty()) {
@ -440,9 +444,15 @@ std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std:
} }
} }
} catch (const ExplicitTransactionUsageException &) { } catch (const ExplicitTransactionUsageException &) {
if (current_transaction_) {
memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_);
}
query_execution.reset(nullptr); query_execution.reset(nullptr);
throw; throw;
} catch (const utils::BasicException &) { } catch (const utils::BasicException &) {
if (current_transaction_) {
memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_);
}
// Trigger first failed query // Trigger first failed query
metrics::FirstFailedQuery(); metrics::FirstFailedQuery();
memgraph::metrics::IncrementCounter(memgraph::metrics::FailedQuery); memgraph::metrics::IncrementCounter(memgraph::metrics::FailedQuery);

View File

@ -3596,3 +3596,15 @@ mgp_error mgp_log(const mgp_log_level log_level, const char *output) {
throw std::invalid_argument{fmt::format("Invalid log level: {}", log_level)}; throw std::invalid_argument{fmt::format("Invalid log level: {}", log_level)};
}); });
} }
mgp_error mgp_track_current_thread_allocations(mgp_graph *graph) {
return WrapExceptions([&]() {
std::visit([](auto *db_accessor) -> void { db_accessor->TrackCurrentThreadAllocations(); }, graph->impl);
});
}
mgp_error mgp_untrack_current_thread_allocations(mgp_graph *graph) {
return WrapExceptions([&]() {
std::visit([](auto *db_accessor) -> void { db_accessor->UntrackCurrentThreadAllocations(); }, graph->impl);
});
}

View File

@ -89,6 +89,13 @@ void MemoryTracker::TryRaiseHardLimit(const int64_t limit) {
; ;
} }
void MemoryTracker::ResetTrackings() {
hard_limit_.store(0, std::memory_order_relaxed);
peak_.store(0, std::memory_order_relaxed);
amount_.store(0, std::memory_order_relaxed);
maximum_hard_limit_ = 0;
}
void MemoryTracker::SetMaximumHardLimit(const int64_t limit) { void MemoryTracker::SetMaximumHardLimit(const int64_t limit) {
if (maximum_hard_limit_ < 0) { if (maximum_hard_limit_ < 0) {
spdlog::warn("Invalid maximum hard limit."); spdlog::warn("Invalid maximum hard limit.");

View File

@ -12,6 +12,7 @@
#pragma once #pragma once
#include <atomic> #include <atomic>
#include <type_traits>
#include "utils/exceptions.hpp" #include "utils/exceptions.hpp"
@ -41,9 +42,20 @@ class MemoryTracker final {
MemoryTracker() = default; MemoryTracker() = default;
~MemoryTracker() = default; ~MemoryTracker() = default;
MemoryTracker(MemoryTracker &&other) noexcept
: amount_(other.amount_.load(std::memory_order_acquire)),
peak_(other.peak_.load(std::memory_order_acquire)),
hard_limit_(other.hard_limit_.load(std::memory_order_acquire)),
maximum_hard_limit_(other.maximum_hard_limit_) {
other.maximum_hard_limit_ = 0;
other.amount_.store(0, std::memory_order_acquire);
other.peak_.store(0, std::memory_order_acquire);
other.hard_limit_.store(0, std::memory_order_acquire);
}
MemoryTracker(const MemoryTracker &) = delete; MemoryTracker(const MemoryTracker &) = delete;
MemoryTracker &operator=(const MemoryTracker &) = delete; MemoryTracker &operator=(const MemoryTracker &) = delete;
MemoryTracker(MemoryTracker &&) = delete;
MemoryTracker &operator=(MemoryTracker &&) = delete; MemoryTracker &operator=(MemoryTracker &&) = delete;
void Alloc(int64_t size); void Alloc(int64_t size);
@ -59,6 +71,8 @@ class MemoryTracker final {
void TryRaiseHardLimit(int64_t limit); void TryRaiseHardLimit(int64_t limit);
void SetMaximumHardLimit(int64_t limit); void SetMaximumHardLimit(int64_t limit);
void ResetTrackings();
// By creating an object of this class, every allocation in its scope that goes over // By creating an object of this class, every allocation in its scope that goes over
// the set hard limit produces an OutOfMemoryException. // the set hard limit produces an OutOfMemoryException.
class OutOfMemoryExceptionEnabler final { class OutOfMemoryExceptionEnabler final {

View File

@ -21,6 +21,7 @@ SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..")) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", ".."))
BUILD_DIR = os.path.join(PROJECT_DIR, "build") BUILD_DIR = os.path.join(PROJECT_DIR, "build")
MEMGRAPH_BINARY = os.path.join(BUILD_DIR, "memgraph") MEMGRAPH_BINARY = os.path.join(BUILD_DIR, "memgraph")
SIGNAL_SIGTERM = 15
def wait_for_server(port, delay=0.01): def wait_for_server(port, delay=0.01):
@ -133,7 +134,7 @@ class MemgraphInstanceRunner:
pid = self.proc_mg.pid pid = self.proc_mg.pid
try: try:
os.kill(pid, 15) # 15 is the signal number for SIGTERM os.kill(pid, SIGNAL_SIGTERM)
except os.OSError: except os.OSError:
assert False assert False

View File

@ -11,10 +11,21 @@ target_link_libraries(memgraph__e2e__memory__limit_global_alloc gflags mgclient
add_executable(memgraph__e2e__memory__limit_global_alloc_proc memory_limit_global_alloc_proc.cpp) add_executable(memgraph__e2e__memory__limit_global_alloc_proc memory_limit_global_alloc_proc.cpp)
target_link_libraries(memgraph__e2e__memory__limit_global_alloc_proc gflags mgclient mg-utils mg-io Threads::Threads) target_link_libraries(memgraph__e2e__memory__limit_global_alloc_proc gflags mgclient mg-utils mg-io Threads::Threads)
add_executable(memgraph__e2e__memory__limit_query_alloc_proc_multi_thread query_memory_limit_proc_multi_thread.cpp)
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_proc_multi_thread gflags mgclient mg-utils mg-io Threads::Threads)
add_executable(memgraph__e2e__memory__limit_query_alloc_create query_memory_limit_create.cpp)
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_create gflags mgclient mg-utils mg-io)
add_executable(memgraph__e2e__memory__limit_query_alloc_proc query_memory_limit_proc.cpp)
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_proc gflags mgclient mg-utils mg-io)
add_executable(memgraph__e2e__memory__limit_query_alloc_create_multi_thread query_memory_limit_multi_thread.cpp)
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_create_multi_thread gflags mgclient mg-utils mg-io Threads::Threads)
add_executable(memgraph__e2e__memory__limit_delete memory_limit_delete.cpp) add_executable(memgraph__e2e__memory__limit_delete memory_limit_delete.cpp)
target_link_libraries(memgraph__e2e__memory__limit_delete gflags mgclient mg-utils mg-io) target_link_libraries(memgraph__e2e__memory__limit_delete gflags mgclient mg-utils mg-io)
add_executable(memgraph__e2e__memory__limit_accumulation memory_limit_accumulation.cpp) add_executable(memgraph__e2e__memory__limit_accumulation memory_limit_accumulation.cpp)
target_link_libraries(memgraph__e2e__memory__limit_accumulation gflags mgclient mg-utils mg-io) target_link_libraries(memgraph__e2e__memory__limit_accumulation gflags mgclient mg-utils mg-io)

View File

@ -3,3 +3,13 @@ target_include_directories(global_memory_limit PRIVATE ${CMAKE_SOURCE_DIR}/inclu
add_library(global_memory_limit_proc SHARED global_memory_limit_proc.c) add_library(global_memory_limit_proc SHARED global_memory_limit_proc.c)
target_include_directories(global_memory_limit_proc PRIVATE ${CMAKE_SOURCE_DIR}/include) target_include_directories(global_memory_limit_proc PRIVATE ${CMAKE_SOURCE_DIR}/include)
add_library(query_memory_limit_proc_multi_thread SHARED query_memory_limit_proc_multi_thread.cpp)
target_include_directories(query_memory_limit_proc_multi_thread PRIVATE ${CMAKE_SOURCE_DIR}/include)
target_link_libraries(query_memory_limit_proc_multi_thread mg-utils)
add_library(query_memory_limit_proc SHARED query_memory_limit_proc.cpp)
target_include_directories(query_memory_limit_proc PRIVATE ${CMAKE_SOURCE_DIR}/include)
target_link_libraries(query_memory_limit_proc mg-utils)

View File

@ -0,0 +1,70 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <atomic>
#include <cassert>
#include <exception>
#include <functional>
#include <mgp.hpp>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "mg_procedure.h"
#include "utils/on_scope_exit.hpp"
enum mgp_error Alloc(void *ptr) {
const size_t mb_size_268 = 1 << 28;
return mgp_global_alloc(mb_size_268, (void **)(&ptr));
}
void Regular(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) {
mgp::MemoryDispatcherGuard guard{memory};
const auto arguments = mgp::List(args);
const auto record_factory = mgp::RecordFactory(result);
try {
void *ptr{nullptr};
memgraph::utils::OnScopeExit<std::function<void(void)>> cleanup{[&ptr]() {
if (nullptr == ptr) {
return;
}
mgp_global_free(ptr);
}};
const enum mgp_error alloc_err = Alloc(ptr);
auto new_record = record_factory.NewRecord();
new_record.Insert("allocated", alloc_err != mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE);
} catch (std::exception &e) {
record_factory.SetErrorMessage(e.what());
}
}
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
try {
mgp::MemoryDispatcherGuard mdg{memory};
AddProcedure(Regular, std::string("regular").c_str(), mgp::ProcedureType::Read, {},
{mgp::Return(std::string("allocated").c_str(), mgp::Type::Bool)}, module, memory);
} catch (const std::exception &e) {
return 1;
}
return 0;
}
extern "C" int mgp_shutdown_module() { return 0; }

View File

@ -0,0 +1,101 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <atomic>
#include <cassert>
#include <exception>
#include <functional>
#include <mgp.hpp>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "mg_procedure.h"
#include "utils/on_scope_exit.hpp"
enum mgp_error Alloc(void *ptr) {
const size_t mb_size_268 = 1 << 28;
return mgp_global_alloc(mb_size_268, (void **)(&ptr));
}
// change communication between threads with feature and promise
std::atomic<int> num_allocations{0};
std::vector<void *> ptrs_;
void AllocFunc(mgp_graph *graph) {
[[maybe_unused]] const enum mgp_error tracking_error = mgp_track_current_thread_allocations(graph);
void *ptr = nullptr;
ptrs_.emplace_back(ptr);
try {
enum mgp_error alloc_err { mgp_error::MGP_ERROR_NO_ERROR };
alloc_err = Alloc(ptr);
if (alloc_err != mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) {
num_allocations.fetch_add(1, std::memory_order_relaxed);
}
if (alloc_err != mgp_error::MGP_ERROR_NO_ERROR) {
assert(false);
}
} catch (const std::exception &e) {
assert(false);
}
[[maybe_unused]] const enum mgp_error untracking_error = mgp_untrack_current_thread_allocations(graph);
}
void DualThread(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) {
mgp::MemoryDispatcherGuard guard{memory};
const auto arguments = mgp::List(args);
const auto record_factory = mgp::RecordFactory(result);
num_allocations.store(0, std::memory_order_relaxed);
try {
std::vector<std::thread> threads;
for (int i = 0; i < 2; i++) {
threads.emplace_back(AllocFunc, memgraph_graph);
}
for (int i = 0; i < 2; i++) {
threads[i].join();
}
for (void *ptr : ptrs_) {
if (ptr != nullptr) {
mgp_global_free(ptr);
}
}
auto new_record = record_factory.NewRecord();
new_record.Insert("allocated_all", num_allocations.load(std::memory_order_relaxed) == 2);
} catch (std::exception &e) {
record_factory.SetErrorMessage(e.what());
}
}
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
try {
mgp::memory = memory;
AddProcedure(DualThread, std::string("dual_thread").c_str(), mgp::ProcedureType::Read, {},
{mgp::Return(std::string("allocated_all").c_str(), mgp::Type::Bool)}, module, memory);
} catch (const std::exception &e) {
return 1;
}
return 0;
}
extern "C" int mgp_shutdown_module() { return 0; }

View File

@ -0,0 +1,65 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <gflags/gflags.h>
#include <iostream>
#include <mgclient.hpp>
#include "utils/logging.hpp"
#include "utils/timer.hpp"
DEFINE_uint64(bolt_port, 7687, "Bolt port");
DEFINE_bool(multi_db, false, "Run test in multi db environment");
int main(int argc, char **argv) {
google::SetUsageMessage("Memgraph E2E Memory Control");
gflags::ParseCommandLineFlags(&argc, &argv, true);
memgraph::logging::RedirectToStderr();
mg::Client::Init();
auto client =
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
if (!client) {
LOG_FATAL("Failed to connect!");
}
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
if (FLAGS_multi_db) {
client->Execute("CREATE DATABASE clean;");
client->DiscardAll();
client->Execute("USE DATABASE clean;");
client->DiscardAll();
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
}
const auto *create_query =
"UNWIND range(1, 50000) as u CREATE (n {string: 'Some longer string'}) RETURN n QUERY MEMORY LIMIT 30MB;";
try {
client->Execute(create_query);
[[maybe_unused]] auto results = client->FetchAll();
if (results->empty()) {
assert(true);
return 0;
}
} catch (const mg::TransientException & /*unused*/) {
spdlog::info("Memgraph is out of memory");
assert(true);
return 0;
}
MG_ASSERT(false, "Query should have failed!");
return 0;
}

View File

@ -0,0 +1,100 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <exception>
#include <future>
#include <thread>
#include <gflags/gflags.h>
#include <iostream>
#include <mgclient.hpp>
#include <utility>
#include "utils/logging.hpp"
#include "utils/timer.hpp"
DEFINE_uint64(bolt_port, 7687, "Bolt port");
DEFINE_bool(multi_db, false, "Run test in multi db environment");
static constexpr int kNumberClients{2};
void Func(std::promise<bool> promise) {
auto client =
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
if (!client) {
LOG_FATAL("Failed to connect!");
}
const auto *create_query =
"FOREACH(i in range(1, 300000) | CREATE (n: Node {string: 'Some longer string'})) QUERY MEMORY LIMIT 100MB;";
bool err{false};
try {
client->Execute(create_query);
[[maybe_unused]] auto results = client->FetchAll();
if (results->empty()) {
err = true;
}
} catch (const std::exception &e) {
spdlog::info("Good: Exception occured", e.what());
err = true;
}
promise.set_value_at_thread_exit(err);
}
int main(int argc, char **argv) {
google::SetUsageMessage("Memgraph E2E Memory Control");
gflags::ParseCommandLineFlags(&argc, &argv, true);
memgraph::logging::RedirectToStderr();
mg::Client::Init();
{
auto client =
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
if (!client) {
LOG_FATAL("Failed to connect!");
}
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
if (FLAGS_multi_db) {
client->Execute("CREATE DATABASE clean;");
client->DiscardAll();
client->Execute("USE DATABASE clean;");
client->DiscardAll();
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
}
}
std::vector<std::promise<bool>> my_promises;
std::vector<std::future<bool>> my_futures;
for (int i = 0; i < 4; i++) {
my_promises.push_back(std::promise<bool>());
my_futures.emplace_back(my_promises.back().get_future());
}
std::vector<std::thread> my_threads;
for (int i = 0; i < kNumberClients; i++) {
my_threads.emplace_back(Func, std::move(my_promises[i]));
}
for (int i = 0; i < kNumberClients; i++) {
auto value = my_futures[i].get();
MG_ASSERT(value, "Error should have happend in thread");
}
for (int i = 0; i < kNumberClients; i++) {
my_threads[i].join();
}
return 0;
}

View File

@ -0,0 +1,66 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <gflags/gflags.h>
#include <algorithm>
#include <exception>
#include <ios>
#include <iostream>
#include <mgclient.hpp>
#include "utils/logging.hpp"
#include "utils/timer.hpp"
DEFINE_uint64(bolt_port, 7687, "Bolt port");
DEFINE_uint64(timeout, 120, "Timeout seconds");
DEFINE_bool(multi_db, false, "Run test in multi db environment");
int main(int argc, char **argv) {
google::SetUsageMessage("Memgraph E2E Query Memory Limit In Multi-Thread For Global Allocators");
gflags::ParseCommandLineFlags(&argc, &argv, true);
memgraph::logging::RedirectToStderr();
mg::Client::Init();
auto client =
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
if (!client) {
LOG_FATAL("Failed to connect!");
}
if (FLAGS_multi_db) {
client->Execute("CREATE DATABASE clean;");
client->DiscardAll();
client->Execute("USE DATABASE clean;");
client->DiscardAll();
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
}
MG_ASSERT(
client->Execute("CALL libquery_memory_limit_proc.regular() YIELD allocated RETURN "
"allocated QUERY MEMORY LIMIT 250MB"));
bool error{false};
try {
auto result_rows = client->FetchAll();
if (result_rows) {
auto row = *result_rows->begin();
error = row[0].ValueBool() == false;
}
} catch (const std::exception &e) {
error = true;
}
MG_ASSERT(error, "Error should have happend");
return 0;
}

View File

@ -0,0 +1,66 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <gflags/gflags.h>
#include <algorithm>
#include <exception>
#include <ios>
#include <iostream>
#include <mgclient.hpp>
#include "utils/logging.hpp"
#include "utils/timer.hpp"
DEFINE_uint64(bolt_port, 7687, "Bolt port");
DEFINE_uint64(timeout, 120, "Timeout seconds");
DEFINE_bool(multi_db, false, "Run test in multi db environment");
int main(int argc, char **argv) {
google::SetUsageMessage("Memgraph E2E Query Memory Limit In Multi-Thread For Global Allocators");
gflags::ParseCommandLineFlags(&argc, &argv, true);
memgraph::logging::RedirectToStderr();
mg::Client::Init();
auto client =
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
if (!client) {
LOG_FATAL("Failed to connect!");
}
if (FLAGS_multi_db) {
client->Execute("CREATE DATABASE clean;");
client->DiscardAll();
client->Execute("USE DATABASE clean;");
client->DiscardAll();
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
}
MG_ASSERT(
client->Execute("CALL libquery_memory_limit_proc_multi_thread.dual_thread() YIELD allocated_all RETURN "
"allocated_all QUERY MEMORY LIMIT 500MB"));
bool error{false};
try {
auto result_rows = client->FetchAll();
if (result_rows) {
auto row = *result_rows->begin();
error = row[0].ValueBool() == false;
}
} catch (const std::exception &e) {
error = true;
}
MG_ASSERT(error, "Error should have happend");
return 0;
}

View File

@ -23,6 +23,20 @@ disk_cluster: &disk_cluster
- "STORAGE MODE ON_DISK_TRANSACTIONAL" - "STORAGE MODE ON_DISK_TRANSACTIONAL"
validation_queries: [] validation_queries: []
args_query_limit: &args_query_limit
- "--bolt-port"
- *bolt_port
- "--storage-gc-cycle-sec=180"
- "--log-level=TRACE"
in_memory_query_limit_cluster: &in_memory_query_limit_cluster
cluster:
main:
args: *args_query_limit
log_file: "memory-e2e.log"
setup_queries: []
validation_queries: []
args_450_MiB_limit: &args_450_MiB_limit args_450_MiB_limit: &args_450_MiB_limit
- "--bolt-port" - "--bolt-port"
- *bolt_port - *bolt_port
@ -95,6 +109,27 @@ workloads:
proc: "tests/e2e/memory/procedures/" proc: "tests/e2e/memory/procedures/"
<<: *disk_cluster <<: *disk_cluster
- name: "Memory control query limit proc"
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_proc"
proc: "tests/e2e/memory/procedures/"
args: ["--bolt-port", *bolt_port]
<<: *in_memory_query_limit_cluster
- name: "Memory control query limit proc multi thread"
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_proc_multi_thread"
args: ["--bolt-port", *bolt_port, "--timeout", "180"]
proc: "tests/e2e/memory/procedures/"
<<: *in_memory_query_limit_cluster
- name: "Memory control query limit create"
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_create"
args: ["--bolt-port", *bolt_port]
<<: *in_memory_query_limit_cluster
- name: "Memory control query limit create multi thread"
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_create_multi_thread"
args: ["--bolt-port", *bolt_port]
<<: *in_memory_query_limit_cluster
- name: "Memory control for detach delete" - name: "Memory control for detach delete"
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_delete" binary: "tests/e2e/memory/memgraph__e2e__memory__limit_delete"
args: ["--bolt-port", *bolt_port] args: ["--bolt-port", *bolt_port]

View File

@ -24,6 +24,7 @@ import time
DEFAULT_DB = "memgraph" DEFAULT_DB = "memgraph"
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
SIGNAL_SIGTERM = 15
QUERIES = [ QUERIES = [
("MATCH (n) DELETE n", {}), ("MATCH (n) DELETE n", {}),
@ -92,9 +93,13 @@ def execute_test(memgraph_binary, tester_binary):
# Register cleanup function # Register cleanup function
@atexit.register @atexit.register
def cleanup(): def cleanup():
if memgraph.poll() is None: pid = memgraph.pid
memgraph.terminate() try:
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
def execute_queries(queries): def execute_queries(queries):
for db, query, params in queries: for db, query, params in queries:
@ -122,10 +127,12 @@ def execute_test(memgraph_binary, tester_binary):
execute_queries(mt_queries3) execute_queries(mt_queries3)
print("\033[1;36m~~ Finished query execution on clean database ~~\033[0m\n") print("\033[1;36m~~ Finished query execution on clean database ~~\033[0m\n")
# Shutdown the memgraph binary pid = memgraph.pid
memgraph.terminate() try:
os.kill(pid, SIGNAL_SIGTERM)
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" except os.OSError:
assert False
time.sleep(1)
# Verify the written log # Verify the written log
print("\033[1;36m~~ Starting log verification ~~\033[0m") print("\033[1;36m~~ Starting log verification ~~\033[0m")

View File

@ -21,6 +21,7 @@ import time
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
SIGNAL_SIGTERM = 15
# When you create a new permission just add a testcase to this list (a tuple # When you create a new permission just add a testcase to this list (a tuple
# of query, touple of required permissions) and the test will automatically # of query, touple of required permissions) and the test will automatically
@ -166,8 +167,12 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
@atexit.register @atexit.register
def cleanup(): def cleanup():
if memgraph.poll() is None: if memgraph.poll() is None:
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
# Prepare the multi database environment # Prepare the multi database environment
execute_admin_queries( execute_admin_queries(
@ -327,8 +332,12 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
print("\033[1;36m~~ Finished checking connections and database switching ~~\033[0m\n") print("\033[1;36m~~ Finished checking connections and database switching ~~\033[0m\n")
# Shutdown the memgraph binary # Shutdown the memgraph binary
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -20,7 +20,6 @@ import sys
import tempfile import tempfile
import time import time
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
TESTS_DIR = os.path.join(SCRIPT_DIR, "tests") TESTS_DIR = os.path.join(SCRIPT_DIR, "tests")
@ -31,6 +30,8 @@ WAL_FILE_NAME = "wal.bin"
DUMP_SNAPSHOT_FILE_NAME = "expected_snapshot.cypher" DUMP_SNAPSHOT_FILE_NAME = "expected_snapshot.cypher"
DUMP_WAL_FILE_NAME = "expected_wal.cypher" DUMP_WAL_FILE_NAME = "expected_wal.cypher"
SIGNAL_SIGTERM = 15
def wait_for_server(port, delay=0.1): def wait_for_server(port, delay=0.1):
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)] cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
@ -40,7 +41,7 @@ def wait_for_server(port, delay=0.1):
def sorted_content(file_path): def sorted_content(file_path):
with open(file_path, 'r') as fin: with open(file_path, "r") as fin:
return sorted(list(map(lambda x: x.strip(), fin.readlines()))) return sorted(list(map(lambda x: x.strip(), fin.readlines())))
@ -52,32 +53,27 @@ def list_to_string(data):
return ret return ret
def execute_test( def execute_test(memgraph_binary, dump_binary, test_directory, test_type, write_expected):
memgraph_binary, assert test_type in ["SNAPSHOT", "WAL"], "Test type should be either 'SNAPSHOT' or 'WAL'."
dump_binary, print("\033[1;36m~~ Executing test {} ({}) ~~\033[0m".format(os.path.relpath(test_directory, TESTS_DIR), test_type))
test_directory,
test_type,
write_expected):
assert test_type in ["SNAPSHOT", "WAL"], \
"Test type should be either 'SNAPSHOT' or 'WAL'."
print("\033[1;36m~~ Executing test {} ({}) ~~\033[0m"
.format(os.path.relpath(test_directory, TESTS_DIR), test_type))
working_data_directory = tempfile.TemporaryDirectory() working_data_directory = tempfile.TemporaryDirectory()
if test_type == "SNAPSHOT": if test_type == "SNAPSHOT":
snapshots_dir = os.path.join(working_data_directory.name, "snapshots") snapshots_dir = os.path.join(working_data_directory.name, "snapshots")
os.makedirs(snapshots_dir) os.makedirs(snapshots_dir)
shutil.copy(os.path.join(test_directory, SNAPSHOT_FILE_NAME), shutil.copy(os.path.join(test_directory, SNAPSHOT_FILE_NAME), snapshots_dir)
snapshots_dir)
else: else:
wal_dir = os.path.join(working_data_directory.name, "wal") wal_dir = os.path.join(working_data_directory.name, "wal")
os.makedirs(wal_dir) os.makedirs(wal_dir)
shutil.copy(os.path.join(test_directory, WAL_FILE_NAME), wal_dir) shutil.copy(os.path.join(test_directory, WAL_FILE_NAME), wal_dir)
memgraph_args = [memgraph_binary, memgraph_args = [
memgraph_binary,
"--storage-recover-on-startup", "--storage-recover-on-startup",
"--storage-properties-on-edges", "--storage-properties-on-edges",
"--data-directory", working_data_directory.name] "--data-directory",
working_data_directory.name,
]
# Start the memgraph binary # Start the memgraph binary
memgraph = subprocess.Popen(memgraph_args) memgraph = subprocess.Popen(memgraph_args)
@ -89,8 +85,12 @@ def execute_test(
@atexit.register @atexit.register
def cleanup(): def cleanup():
if memgraph.poll() is None: if memgraph.poll() is None:
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
# Execute `database dump` # Execute `database dump`
dump_output_file = tempfile.NamedTemporaryFile() dump_output_file = tempfile.NamedTemporaryFile()
@ -98,28 +98,31 @@ def execute_test(
subprocess.run(dump_args, stdout=dump_output_file, check=True) subprocess.run(dump_args, stdout=dump_output_file, check=True)
# Shutdown the memgraph binary # Shutdown the memgraph binary
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
dump_file_name = DUMP_SNAPSHOT_FILE_NAME if test_type == "SNAPSHOT" else DUMP_WAL_FILE_NAME dump_file_name = DUMP_SNAPSHOT_FILE_NAME if test_type == "SNAPSHOT" else DUMP_WAL_FILE_NAME
if write_expected: if write_expected:
with open(dump_output_file.name, 'r') as dump: with open(dump_output_file.name, "r") as dump:
queries_got = dump.readlines() queries_got = dump.readlines()
# Write dump files # Write dump files
expected_dump_file = os.path.join(test_directory, dump_file_name) expected_dump_file = os.path.join(test_directory, dump_file_name)
with open(expected_dump_file, 'w') as expected: with open(expected_dump_file, "w") as expected:
expected.writelines(queries_got) expected.writelines(queries_got)
else: else:
# Compare dump files # Compare dump files
expected_dump_file = os.path.join(test_directory, dump_file_name) expected_dump_file = os.path.join(test_directory, dump_file_name)
assert os.path.exists(expected_dump_file), \ assert os.path.exists(expected_dump_file), "Could not find expected dump path {}".format(expected_dump_file)
"Could not find expected dump path {}".format(expected_dump_file)
queries_got = sorted_content(dump_output_file.name) queries_got = sorted_content(dump_output_file.name)
queries_expected = sorted_content(expected_dump_file) queries_expected = sorted_content(expected_dump_file)
assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" \ assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" "{}".format(
"{}".format(list_to_string(queries_got), list_to_string(queries_got), list_to_string(queries_expected)
list_to_string(queries_expected)) )
print("\033[1;32m~~ Test successful ~~\033[0m\n") print("\033[1;32m~~ Test successful ~~\033[0m\n")
@ -141,15 +144,17 @@ def find_test_directories(directory):
continue continue
snapshot_file = os.path.join(test_dir_path, SNAPSHOT_FILE_NAME) snapshot_file = os.path.join(test_dir_path, SNAPSHOT_FILE_NAME)
wal_file = os.path.join(test_dir_path, WAL_FILE_NAME) wal_file = os.path.join(test_dir_path, WAL_FILE_NAME)
dump_snapshot_file = os.path.join( dump_snapshot_file = os.path.join(test_dir_path, DUMP_SNAPSHOT_FILE_NAME)
test_dir_path, DUMP_SNAPSHOT_FILE_NAME)
dump_wal_file = os.path.join(test_dir_path, DUMP_WAL_FILE_NAME) dump_wal_file = os.path.join(test_dir_path, DUMP_WAL_FILE_NAME)
if (os.path.isfile(snapshot_file) and os.path.isfile(dump_snapshot_file) if (
and os.path.isfile(wal_file) and os.path.isfile(dump_wal_file)): os.path.isfile(snapshot_file)
and os.path.isfile(dump_snapshot_file)
and os.path.isfile(wal_file)
and os.path.isfile(dump_wal_file)
):
test_dirs.append(test_dir_path) test_dirs.append(test_dir_path)
else: else:
raise Exception("Missing data in test directory '{}'" raise Exception("Missing data in test directory '{}'".format(test_dir_path))
.format(test_dir_path))
return test_dirs return test_dirs
@ -161,26 +166,15 @@ if __name__ == "__main__":
parser.add_argument("--memgraph", default=memgraph_binary) parser.add_argument("--memgraph", default=memgraph_binary)
parser.add_argument("--dump", default=dump_binary) parser.add_argument("--dump", default=dump_binary)
parser.add_argument( parser.add_argument(
'--write-expected', "--write-expected", action="store_true", help="Overwrite the expected cypher with results from current run"
action='store_true', )
help='Overwrite the expected cypher with results from current run')
args = parser.parse_args() args = parser.parse_args()
test_directories = find_test_directories(TESTS_DIR) test_directories = find_test_directories(TESTS_DIR)
assert len(test_directories) > 0, "No tests have been found!" assert len(test_directories) > 0, "No tests have been found!"
for test_directory in test_directories: for test_directory in test_directories:
execute_test( execute_test(args.memgraph, args.dump, test_directory, "SNAPSHOT", args.write_expected)
args.memgraph, execute_test(args.memgraph, args.dump, test_directory, "WAL", args.write_expected)
args.dump,
test_directory,
"SNAPSHOT",
args.write_expected)
execute_test(
args.memgraph,
args.dump,
test_directory,
"WAL",
args.write_expected)
sys.exit(0) sys.exit(0)

View File

@ -22,6 +22,7 @@ from typing import List
SCRIPT_DIR = Path(__file__).absolute() SCRIPT_DIR = Path(__file__).absolute()
PROJECT_DIR = SCRIPT_DIR.parents[3] PROJECT_DIR = SCRIPT_DIR.parents[3]
SIGNAL_SIGTERM = 15
def wait_for_server(port, delay=0.1): def wait_for_server(port, delay=0.1):
@ -68,8 +69,12 @@ def execute_with_user(queries):
def cleanup(memgraph): def cleanup(memgraph):
if memgraph.poll() is None: if memgraph.poll() is None:
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
def execute_without_user(queries, should_fail=False, failure_message="", check_failure=True): def execute_without_user(queries, should_fail=False, failure_message="", check_failure=True):

View File

@ -24,6 +24,7 @@ SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
UNAUTHORIZED_ERROR = r"^You are not authorized to execute this query.*?Please contact your database administrator\." UNAUTHORIZED_ERROR = r"^You are not authorized to execute this query.*?Please contact your database administrator\."
SIGNAL_SIGTERM = 15
def wait_for_server(port, delay=0.1): def wait_for_server(port, delay=0.1):
@ -80,8 +81,12 @@ def execute_test(memgraph_binary: str, tester_binary: str, filtering_binary: str
@atexit.register @atexit.register
def cleanup(): def cleanup():
if memgraph.poll() is None: if memgraph.poll() is None:
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
# Prepare all users # Prepare all users
def setup_user(): def setup_user():
@ -130,8 +135,12 @@ def execute_test(memgraph_binary: str, tester_binary: str, filtering_binary: str
print("\033[1;36m~~ Finished edge filtering test ~~\033[0m\n") print("\033[1;36m~~ Finished edge filtering test ~~\033[0m\n")
# Shutdown the memgraph binary # Shutdown the memgraph binary
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -21,6 +21,7 @@ from typing import List
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
SIGNAL_SIGTERM = 15
def wait_for_server(port: int, delay: float = 0.1) -> float: def wait_for_server(port: int, delay: float = 0.1) -> float:
@ -86,8 +87,12 @@ def execute_without_user(
def cleanup(memgraph: subprocess): def cleanup(memgraph: subprocess):
if memgraph.poll() is None: if memgraph.poll() is None:
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
def test_without_any_files(tester_binary: str, memgraph_args: List[str]): def test_without_any_files(tester_binary: str, memgraph_args: List[str]):

View File

@ -23,6 +23,8 @@ import time
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
SIGNAL_SIGTERM = 15
CONFIG_TEMPLATE = """ CONFIG_TEMPLATE = """
server: server:
host: "127.0.0.1" host: "127.0.0.1"
@ -52,8 +54,7 @@ def wait_for_server(port, delay=0.1):
time.sleep(delay) time.sleep(delay)
def execute_tester(binary, queries, username="", password="", def execute_tester(binary, queries, username="", password="", auth_should_fail=False, query_should_fail=False):
auth_should_fail=False, query_should_fail=False):
if password == "": if password == "":
password = username password = username
args = [binary, "--username", username, "--password", password] args = [binary, "--username", username, "--password", password]
@ -76,18 +77,14 @@ class Memgraph:
def start(self, **kwargs): def start(self, **kwargs):
self.stop() self.stop()
self._storage_directory = tempfile.TemporaryDirectory() self._storage_directory = tempfile.TemporaryDirectory()
self._auth_module = os.path.join(self._storage_directory.name, self._auth_module = os.path.join(self._storage_directory.name, "ldap.py")
"ldap.py") self._auth_config = os.path.join(self._storage_directory.name, "ldap.yaml")
self._auth_config = os.path.join(self._storage_directory.name, script_file = os.path.join(PROJECT_DIR, "src", "auth", "reference_modules", "ldap.py")
"ldap.yaml")
script_file = os.path.join(PROJECT_DIR, "src", "auth",
"reference_modules", "ldap.py")
virtualenv_bin = os.path.join(SCRIPT_DIR, "ve3", "bin", "python3") virtualenv_bin = os.path.join(SCRIPT_DIR, "ve3", "bin", "python3")
with open(script_file) as fin: with open(script_file) as fin:
data = fin.read() data = fin.read()
data = data.replace("/usr/bin/python3", virtualenv_bin) data = data.replace("/usr/bin/python3", virtualenv_bin)
data = data.replace("/etc/memgraph/auth/ldap.yaml", data = data.replace("/etc/memgraph/auth/ldap.yaml", self._auth_config)
self._auth_config)
with open(self._auth_module, "w") as fout: with open(self._auth_module, "w") as fout:
fout.write(data) fout.write(data)
os.chmod(self._auth_module, stat.S_IRWXU | stat.S_IRWXG) os.chmod(self._auth_module, stat.S_IRWXU | stat.S_IRWXG)
@ -106,10 +103,13 @@ class Memgraph:
} }
with open(self._auth_config, "w") as f: with open(self._auth_config, "w") as f:
f.write(CONFIG_TEMPLATE.format(**config)) f.write(CONFIG_TEMPLATE.format(**config))
args = [self._binary, args = [
"--data-directory", self._storage_directory.name, self._binary,
"--data-directory",
self._storage_directory.name,
"--auth-module-executable", "--auth-module-executable",
kwargs.pop("module_executable", self._auth_module)] kwargs.pop("module_executable", self._auth_module),
]
for key, value in kwargs.items(): for key, value in kwargs.items():
ldap_key = "--auth-module-" + key.replace("_", "-") ldap_key = "--auth-module-" + key.replace("_", "-")
if isinstance(value, bool): if isinstance(value, bool):
@ -119,26 +119,27 @@ class Memgraph:
args.append(value) args.append(value)
self._process = subprocess.Popen(args) self._process = subprocess.Popen(args)
time.sleep(0.1) time.sleep(0.1)
assert self._process.poll() is None, "Memgraph process died " \ assert self._process.poll() is None, "Memgraph process died " "prematurely!"
"prematurely!"
wait_for_server(7687) wait_for_server(7687)
def stop(self, check=True): def stop(self, check=True):
if self._process is None: if self._process is None:
return 0 return 0
self._process.terminate() pid = self._process.pid
exitcode = self._process.wait() try:
self._process = None os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
if check: if check:
assert exitcode == 0, "Memgraph process didn't exit cleanly!" assert False
return exitcode return -1
time.sleep(1)
return 0
def initialize_test(memgraph, tester_binary, **kwargs): def initialize_test(memgraph, tester_binary, **kwargs):
memgraph.start(module_executable="") memgraph.start(module_executable="")
execute_tester(tester_binary, execute_tester(tester_binary, ["CREATE USER root", "GRANT ALL PRIVILEGES TO root"])
["CREATE USER root", "GRANT ALL PRIVILEGES TO root"])
check_login = kwargs.pop("check_login", True) check_login = kwargs.pop("check_login", True)
memgraph.restart(**kwargs) memgraph.restart(**kwargs)
if check_login: if check_login:
@ -170,18 +171,15 @@ def test_role_mapping(memgraph, tester_binary):
initialize_test(memgraph, tester_binary) initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, [], "alice") execute_tester(tester_binary, [], "alice")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root") execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
execute_tester(tester_binary, [], "bob") execute_tester(tester_binary, [], "bob")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob", execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob", query_should_fail=True)
query_should_fail=True)
execute_tester(tester_binary, [], "carol") execute_tester(tester_binary, [], "carol")
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", query_should_fail=True)
query_should_fail=True)
execute_tester(tester_binary, ["GRANT CREATE TO admin"], "root") execute_tester(tester_binary, ["GRANT CREATE TO admin"], "root")
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol") execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol")
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "dave") execute_tester(tester_binary, ["CREATE (n) RETURN n"], "dave")
@ -192,15 +190,13 @@ def test_role_mapping(memgraph, tester_binary):
def test_role_removal(memgraph, tester_binary): def test_role_removal(memgraph, tester_binary):
initialize_test(memgraph, tester_binary) initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, [], "alice") execute_tester(tester_binary, [], "alice")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root") execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.restart(manage_roles=False) memgraph.restart(manage_roles=False)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
execute_tester(tester_binary, ["CLEAR ROLE FOR alice"], "root") execute_tester(tester_binary, ["CLEAR ROLE FOR alice"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
query_should_fail=True)
memgraph.stop() memgraph.stop()
@ -229,28 +225,22 @@ def test_user_is_role(memgraph, tester_binary):
def test_user_permissions_persistancy(memgraph, tester_binary): def test_user_permissions_persistancy(memgraph, tester_binary):
initialize_test(memgraph, tester_binary) initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root")
["CREATE USER alice", "GRANT MATCH TO alice"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop() memgraph.stop()
def test_role_permissions_persistancy(memgraph, tester_binary): def test_role_permissions_persistancy(memgraph, tester_binary):
initialize_test(memgraph, tester_binary) initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
"root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop() memgraph.stop()
def test_only_authentication(memgraph, tester_binary): def test_only_authentication(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, manage_roles=False) initialize_test(memgraph, tester_binary, manage_roles=False)
execute_tester(tester_binary, execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
["CREATE ROLE moderator", "GRANT MATCH TO moderator"], execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
"root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
query_should_fail=True)
memgraph.stop() memgraph.stop()
@ -267,22 +257,16 @@ def test_wrong_suffix(memgraph, tester_binary):
def test_suffix_with_spaces(memgraph, tester_binary): def test_suffix_with_spaces(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, initialize_test(memgraph, tester_binary, suffix=", ou= people, dc = memgraph, dc = com")
suffix=", ou= people, dc = memgraph, dc = com") execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root")
execute_tester(tester_binary,
["CREATE USER alice", "GRANT MATCH TO alice"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop() memgraph.stop()
def test_role_mapping_wrong_root_dn(memgraph, tester_binary): def test_role_mapping_wrong_root_dn(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, initialize_test(memgraph, tester_binary, root_dn="ou=invalid,dc=memgraph,dc=com")
root_dn="ou=invalid,dc=memgraph,dc=com") execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
"root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
query_should_fail=True)
memgraph.restart() memgraph.restart()
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop() memgraph.stop()
@ -290,11 +274,8 @@ def test_role_mapping_wrong_root_dn(memgraph, tester_binary):
def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary): def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, root_objectclass="person") initialize_test(memgraph, tester_binary, root_objectclass="person")
execute_tester(tester_binary, execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
["CREATE ROLE moderator", "GRANT MATCH TO moderator"], execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
"root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
query_should_fail=True)
memgraph.restart() memgraph.restart()
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop() memgraph.stop()
@ -302,11 +283,8 @@ def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary):
def test_role_mapping_wrong_user_attribute(memgraph, tester_binary): def test_role_mapping_wrong_user_attribute(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, user_attribute="cn") initialize_test(memgraph, tester_binary, user_attribute="cn")
execute_tester(tester_binary, execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
["CREATE ROLE moderator", "GRANT MATCH TO moderator"], execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
"root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
query_should_fail=True)
memgraph.restart() memgraph.restart()
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop() memgraph.stop()
@ -314,8 +292,7 @@ def test_role_mapping_wrong_user_attribute(memgraph, tester_binary):
def test_wrong_password(memgraph, tester_binary): def test_wrong_password(memgraph, tester_binary):
initialize_test(memgraph, tester_binary) initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, [], "root", password="sudo", execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True)
auth_should_fail=True)
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
memgraph.stop() memgraph.stop()
@ -326,12 +303,10 @@ def test_password_persistancy(memgraph, tester_binary):
execute_tester(tester_binary, ["SHOW USERS"], "root", password="sudo") execute_tester(tester_binary, ["SHOW USERS"], "root", password="sudo")
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
memgraph.restart() memgraph.restart()
execute_tester(tester_binary, [], "root", password="sudo", execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True)
auth_should_fail=True)
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
memgraph.restart(module_executable="") memgraph.restart(module_executable="")
execute_tester(tester_binary, [], "root", password="sudo", execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True)
auth_should_fail=True)
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
memgraph.stop() memgraph.stop()
@ -339,33 +314,25 @@ def test_password_persistancy(memgraph, tester_binary):
def test_user_multiple_roles(memgraph, tester_binary): def test_user_multiple_roles(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, check_login=False) initialize_test(memgraph, tester_binary, check_login=False)
memgraph.restart() memgraph.restart()
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
query_should_fail=True) execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root",
query_should_fail=True)
memgraph.restart(manage_roles=False) memgraph.restart(manage_roles=False)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
query_should_fail=True) execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root",
query_should_fail=True)
memgraph.restart(manage_roles=False, root_dn="") memgraph.restart(manage_roles=False, root_dn="")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
query_should_fail=True) execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root",
query_should_fail=True)
memgraph.stop() memgraph.stop()
def test_starttls_failure(memgraph, tester_binary): def test_starttls_failure(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, encryption="starttls", initialize_test(memgraph, tester_binary, encryption="starttls", check_login=False)
check_login=False)
execute_tester(tester_binary, [], "root", auth_should_fail=True) execute_tester(tester_binary, [], "root", auth_should_fail=True)
memgraph.stop() memgraph.stop()
def test_ssl_failure(memgraph, tester_binary): def test_ssl_failure(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, encryption="ssl", initialize_test(memgraph, tester_binary, encryption="ssl", check_login=False)
check_login=False)
execute_tester(tester_binary, [], "root", auth_should_fail=True) execute_tester(tester_binary, [], "root", auth_should_fail=True)
memgraph.stop() memgraph.stop()
@ -375,22 +342,19 @@ def test_ssl_failure(memgraph, tester_binary):
if __name__ == "__main__": if __name__ == "__main__":
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph") memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
tester_binary = os.path.join(PROJECT_DIR, "build", "tests", tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "ldap", "tester")
"integration", "ldap", "tester")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--memgraph", default=memgraph_binary) parser.add_argument("--memgraph", default=memgraph_binary)
parser.add_argument("--tester", default=tester_binary) parser.add_argument("--tester", default=tester_binary)
parser.add_argument("--openldap-dir", parser.add_argument("--openldap-dir", default=os.path.join(SCRIPT_DIR, "openldap-2.4.47"))
default=os.path.join(SCRIPT_DIR, "openldap-2.4.47"))
args = parser.parse_args() args = parser.parse_args()
# Setup Memgraph handler # Setup Memgraph handler
memgraph = Memgraph(args.memgraph) memgraph = Memgraph(args.memgraph)
# Start the slapd binary # Start the slapd binary
slapd_args = [os.path.join(args.openldap_dir, "exe", "libexec", "slapd"), slapd_args = [os.path.join(args.openldap_dir, "exe", "libexec", "slapd"), "-h", "ldap://127.0.0.1:1389/", "-d", "0"]
"-h", "ldap://127.0.0.1:1389/", "-d", "0"]
slapd = subprocess.Popen(slapd_args) slapd = subprocess.Popen(slapd_args)
time.sleep(0.1) time.sleep(0.1)
assert slapd.poll() is None, "slapd process died prematurely!" assert slapd.poll() is None, "slapd process died prematurely!"
@ -409,8 +373,7 @@ if __name__ == "__main__":
if slapd_stat != 0: if slapd_stat != 0:
print("slapd process didn't exit cleanly!") print("slapd process didn't exit cleanly!")
assert mg_stat == 0 and slapd_stat == 0, "Some of the processes " \ assert mg_stat == 0 and slapd_stat == 0, "Some of the processes " "(memgraph, slapd) crashed!"
"(memgraph, slapd) crashed!"
# Execute tests # Execute tests
names = sorted(globals().keys()) names = sorted(globals().keys())

View File

@ -18,12 +18,13 @@ import subprocess
import sys import sys
import tempfile import tempfile
import time import time
import yaml
import yaml
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
BUILD_DIR = os.path.join(BASE_DIR, "build") BUILD_DIR = os.path.join(BASE_DIR, "build")
SIGNAL_SIGTERM = 15
def wait_for_server(port, delay=0.1): def wait_for_server(port, delay=0.1):
@ -46,17 +47,14 @@ def list_to_string(data):
def verify_lifetime(memgraph_binary, mg_import_csv_binary): def verify_lifetime(memgraph_binary, mg_import_csv_binary):
print("\033[1;36m~~ Verifying that mg_import_csv can't be started while " print("\033[1;36m~~ Verifying that mg_import_csv can't be started while " "memgraph is running ~~\033[0m")
"memgraph is running ~~\033[0m")
storage_directory = tempfile.TemporaryDirectory() storage_directory = tempfile.TemporaryDirectory()
# Generate common args # Generate common args
common_args = ["--data-directory", storage_directory.name, common_args = ["--data-directory", storage_directory.name, "--storage-properties-on-edges=false"]
"--storage-properties-on-edges=false"]
# Start the memgraph binary # Start the memgraph binary
memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + \ memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + common_args
common_args
memgraph = subprocess.Popen(list(map(str, memgraph_args))) memgraph = subprocess.Popen(list(map(str, memgraph_args)))
time.sleep(0.1) time.sleep(0.1)
assert memgraph.poll() is None, "Memgraph process died prematurely!" assert memgraph.poll() is None, "Memgraph process died prematurely!"
@ -66,47 +64,52 @@ def verify_lifetime(memgraph_binary, mg_import_csv_binary):
@atexit.register @atexit.register
def cleanup(): def cleanup():
if memgraph.poll() is None: if memgraph.poll() is None:
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
# Execute mg_import_csv. # Execute mg_import_csv.
mg_import_csv_args = [mg_import_csv_binary, "--nodes", "/dev/null"] + \ mg_import_csv_args = [mg_import_csv_binary, "--nodes", "/dev/null"] + common_args
common_args
ret = subprocess.run(mg_import_csv_args) ret = subprocess.run(mg_import_csv_args)
# Check the return code # Check the return code
if ret.returncode == 0: if ret.returncode == 0:
raise Exception( raise Exception("The importer was able to run while memgraph was running!")
"The importer was able to run while memgraph was running!")
# Shutdown the memgraph binary # Shutdown the memgraph binary
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
print("\033[1;32m~~ Test successful ~~\033[0m\n") print("\033[1;32m~~ Test successful ~~\033[0m\n")
def execute_test(name, test_path, test_config, memgraph_binary, def execute_test(name, test_path, test_config, memgraph_binary, mg_import_csv_binary, tester_binary, write_expected):
mg_import_csv_binary, tester_binary, write_expected):
print("\033[1;36m~~ Executing test", name, "~~\033[0m") print("\033[1;36m~~ Executing test", name, "~~\033[0m")
storage_directory = tempfile.TemporaryDirectory() storage_directory = tempfile.TemporaryDirectory()
# Verify test configuration # Verify test configuration
if ("import_should_fail" not in test_config and if ("import_should_fail" not in test_config and "expected" not in test_config) or (
"expected" not in test_config) or \ "import_should_fail" in test_config and "expected" in test_config
("import_should_fail" in test_config and ):
"expected" in test_config): raise Exception("The test should specify either 'import_should_fail' " "or 'expected'!")
raise Exception("The test should specify either 'import_should_fail' "
"or 'expected'!")
expected_path = test_config.pop("expected", "") expected_path = test_config.pop("expected", "")
import_should_fail = test_config.pop("import_should_fail", False) import_should_fail = test_config.pop("import_should_fail", False)
# Generate common args # Generate common args
properties_on_edges = bool(test_config.pop("properties_on_edges", False)) properties_on_edges = bool(test_config.pop("properties_on_edges", False))
common_args = ["--data-directory", storage_directory.name, common_args = [
"--storage-properties-on-edges=" + "--data-directory",
str(properties_on_edges).lower()] storage_directory.name,
"--storage-properties-on-edges=" + str(properties_on_edges).lower(),
]
# Generate mg_import_csv args using flags specified in the test # Generate mg_import_csv args using flags specified in the test
mg_import_csv_args = [mg_import_csv_binary] + common_args mg_import_csv_args = [mg_import_csv_binary] + common_args
@ -125,19 +128,16 @@ def execute_test(name, test_path, test_config, memgraph_binary,
if import_should_fail: if import_should_fail:
if ret.returncode == 0: if ret.returncode == 0:
raise Exception("The import should have failed, but it " raise Exception("The import should have failed, but it " "succeeded instead!")
"succeeded instead!")
else: else:
print("\033[1;32m~~ Test successful ~~\033[0m\n") print("\033[1;32m~~ Test successful ~~\033[0m\n")
return return
else: else:
if ret.returncode != 0: if ret.returncode != 0:
raise Exception("The import should have succeeded, but it " raise Exception("The import should have succeeded, but it " "failed instead!")
"failed instead!")
# Start the memgraph binary # Start the memgraph binary
memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + \ memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + common_args
common_args
memgraph = subprocess.Popen(list(map(str, memgraph_args))) memgraph = subprocess.Popen(list(map(str, memgraph_args)))
time.sleep(0.1) time.sleep(0.1)
assert memgraph.poll() is None, "Memgraph process died prematurely!" assert memgraph.poll() is None, "Memgraph process died prematurely!"
@ -147,21 +147,29 @@ def execute_test(name, test_path, test_config, memgraph_binary,
@atexit.register @atexit.register
def cleanup(): def cleanup():
if memgraph.poll() is None: if memgraph.poll() is None:
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
# Get the contents of the database # Get the contents of the database
queries_got = extract_rows(subprocess.run( queries_got = extract_rows(
[tester_binary], stdout=subprocess.PIPE, subprocess.run([tester_binary], stdout=subprocess.PIPE, check=True).stdout.decode("utf-8")
check=True).stdout.decode("utf-8")) )
# Shutdown the memgraph binary # Shutdown the memgraph binary
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
if write_expected: if write_expected:
with open(os.path.join(test_path, expected_path), 'w') as expected: with open(os.path.join(test_path, expected_path), "w") as expected:
expected.write('\n'.join(queries_got)) expected.write("\n".join(queries_got))
else: else:
if expected_path: if expected_path:
@ -173,18 +181,16 @@ def execute_test(name, test_path, test_config, memgraph_binary,
# Verify the queries # Verify the queries
queries_expected.sort() queries_expected.sort()
queries_got.sort() queries_got.sort()
assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" \ assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" "{}".format(
"{}".format(list_to_string(queries_got), list_to_string(queries_got), list_to_string(queries_expected)
list_to_string(queries_expected)) )
print("\033[1;32m~~ Test successful ~~\033[0m\n") print("\033[1;32m~~ Test successful ~~\033[0m\n")
if __name__ == "__main__": if __name__ == "__main__":
memgraph_binary = os.path.join(BUILD_DIR, "memgraph") memgraph_binary = os.path.join(BUILD_DIR, "memgraph")
mg_import_csv_binary = os.path.join( mg_import_csv_binary = os.path.join(BUILD_DIR, "src", "mg_import_csv")
BUILD_DIR, "src", "mg_import_csv") tester_binary = os.path.join(BUILD_DIR, "tests", "integration", "mg_import_csv", "tester")
tester_binary = os.path.join(
BUILD_DIR, "tests", "integration", "mg_import_csv", "tester")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--memgraph", default=memgraph_binary) parser.add_argument("--memgraph", default=memgraph_binary)
@ -193,7 +199,8 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--write-expected", "--write-expected",
action="store_true", action="store_true",
help="Overwrite the expected values with the results of the current run") help="Overwrite the expected values with the results of the current run",
)
args = parser.parse_args() args = parser.parse_args()
# First test whether the CSV importer can be started while the main # First test whether the CSV importer can be started while the main
@ -211,7 +218,8 @@ if __name__ == "__main__":
testcases = yaml.safe_load(f) testcases = yaml.safe_load(f)
for test_config in testcases: for test_config in testcases:
test_name = name + "/" + test_config.pop("name") test_name = name + "/" + test_config.pop("name")
execute_test(test_name, test_path, test_config, args.memgraph, execute_test(
args.mg_import_csv, args.tester, args.write_expected) test_name, test_path, test_config, args.memgraph, args.mg_import_csv, args.tester, args.write_expected
)
sys.exit(0) sys.exit(0)

View File

@ -23,6 +23,7 @@ from typing import List
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
SIGNAL_SIGTERM = 15
def wait_for_server(port: int, delay: float = 0.1) -> float: def wait_for_server(port: int, delay: float = 0.1) -> float:
@ -91,8 +92,12 @@ def check_config(tester_binary: str, flag: str, value: str) -> None:
def cleanup(memgraph: subprocess): def cleanup(memgraph: subprocess):
if memgraph.poll() is None: if memgraph.poll() is None:
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
def run_test(tester_binary: str, memgraph_args: List[str], server_name: str, query_tx: str): def run_test(tester_binary: str, memgraph_args: List[str], server_name: str, query_tx: str):

View File

@ -22,6 +22,8 @@ assertion_queries = [
f"MATCH (n)-[e]->(m) WITH count(e) as cnt RETURN assert(cnt={len(edge_queries)});", f"MATCH (n)-[e]->(m) WITH count(e) as cnt RETURN assert(cnt={len(edge_queries)});",
] ]
SIGNAL_SIGTERM = 15
def wait_for_server(port, delay=0.1): def wait_for_server(port, delay=0.1):
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)] cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
@ -40,9 +42,12 @@ def prepare_memgraph(memgraph_args):
def terminate_memgraph(memgraph): def terminate_memgraph(memgraph):
memgraph.terminate() pid = memgraph.pid
time.sleep(0.1) try:
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
def execute_tester( def execute_tester(
@ -90,8 +95,12 @@ def execute_test_analytical_mode(memgraph_binary: str, tester_binary: str) -> No
execute_queries(assertion_queries) execute_queries(assertion_queries)
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
def execute_test_switch_analytical_transactional(memgraph_binary: str, tester_binary: str) -> None: def execute_test_switch_analytical_transactional(memgraph_binary: str, tester_binary: str) -> None:
@ -135,8 +144,12 @@ def execute_test_switch_analytical_transactional(memgraph_binary: str, tester_bi
execute_queries(assertion_queries) execute_queries(assertion_queries)
print("\033[1;36m~~ Terminating memgraph ~~\033[0m\n") print("\033[1;36m~~ Terminating memgraph ~~\033[0m\n")
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
def execute_test_switch_transactional_analytical(memgraph_binary: str, tester_binary: str) -> None: def execute_test_switch_transactional_analytical(memgraph_binary: str, tester_binary: str) -> None:
@ -177,8 +190,12 @@ def execute_test_switch_transactional_analytical(memgraph_binary: str, tester_bi
execute_queries(assertion_queries) execute_queries(assertion_queries)
print("\033[1;36m~~ Terminating memgraph ~~\033[0m\n") print("\033[1;36m~~ Terminating memgraph ~~\033[0m\n")
memgraph.terminate() pid = memgraph.pid
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -11,12 +11,13 @@
import logging import logging
import os import os
import shutil
import subprocess import subprocess
import tempfile
import time
from argparse import ArgumentParser from argparse import ArgumentParser
from collections import defaultdict from collections import defaultdict
import tempfile
import shutil
import time
from common import get_absolute_path, set_cpus from common import get_absolute_path, set_cpus
try: try:
@ -36,13 +37,12 @@ class Memgraph:
""" """
Knows how to start and stop memgraph. Knows how to start and stop memgraph.
""" """
def __init__(self, args, num_workers): def __init__(self, args, num_workers):
self.log = logging.getLogger("MemgraphRunner") self.log = logging.getLogger("MemgraphRunner")
argp = ArgumentParser("MemgraphArgumentParser") argp = ArgumentParser("MemgraphArgumentParser")
argp.add_argument("--runner-bin", argp.add_argument("--runner-bin", default=get_absolute_path("memgraph", "build"))
default=get_absolute_path("memgraph", "build")) argp.add_argument("--port", default="7687", help="Database and client port")
argp.add_argument("--port", default="7687",
help="Database and client port")
argp.add_argument("--data-directory", default=None) argp.add_argument("--data-directory", default=None)
argp.add_argument("--storage-snapshot-on-exit", action="store_true") argp.add_argument("--storage-snapshot-on-exit", action="store_true")
argp.add_argument("--storage-recover-on-startup", action="store_true") argp.add_argument("--storage-recover-on-startup", action="store_true")
@ -55,8 +55,7 @@ class Memgraph:
def start(self): def start(self):
self.log.info("start") self.log.info("start")
database_args = ["--bolt-port", self.args.port, database_args = ["--bolt-port", self.args.port, "--query-execution-timeout-sec", "0"]
"--query-execution-timeout-sec", "0"]
if self.num_workers: if self.num_workers:
database_args += ["--bolt-num-workers", str(self.num_workers)] database_args += ["--bolt-num-workers", str(self.num_workers)]
if self.args.data_directory: if self.args.data_directory:
@ -82,15 +81,13 @@ class Neo:
""" """
Knows how to start and stop neo4j. Knows how to start and stop neo4j.
""" """
def __init__(self, args, config): def __init__(self, args, config):
self.log = logging.getLogger("NeoRunner") self.log = logging.getLogger("NeoRunner")
argp = ArgumentParser("NeoArgumentParser") argp = ArgumentParser("NeoArgumentParser")
argp.add_argument("--runner-bin", default=get_absolute_path( argp.add_argument("--runner-bin", default=get_absolute_path("neo4j/bin/neo4j", "libs"))
"neo4j/bin/neo4j", "libs")) argp.add_argument("--port", default="7687", help="Database and client port")
argp.add_argument("--port", default="7687", argp.add_argument("--http-port", default="7474", help="Database and client port")
help="Database and client port")
argp.add_argument("--http-port", default="7474",
help="Database and client port")
self.log.info("Initializing Runner with arguments %r", args) self.log.info("Initializing Runner with arguments %r", args)
self.args, _ = argp.parse_known_args(args) self.args, _ = argp.parse_known_args(args)
self.config = config self.config = config
@ -105,24 +102,22 @@ class Neo:
self.neo4j_home_path = tempfile.mkdtemp(dir="/dev/shm") self.neo4j_home_path = tempfile.mkdtemp(dir="/dev/shm")
try: try:
os.symlink(os.path.join(get_absolute_path("neo4j", "libs"), "lib"), os.symlink(
os.path.join(self.neo4j_home_path, "lib")) os.path.join(get_absolute_path("neo4j", "libs"), "lib"), os.path.join(self.neo4j_home_path, "lib")
)
neo4j_conf_dir = os.path.join(self.neo4j_home_path, "conf") neo4j_conf_dir = os.path.join(self.neo4j_home_path, "conf")
neo4j_conf_file = os.path.join(neo4j_conf_dir, "neo4j.conf") neo4j_conf_file = os.path.join(neo4j_conf_dir, "neo4j.conf")
os.mkdir(neo4j_conf_dir) os.mkdir(neo4j_conf_dir)
shutil.copyfile(self.config, neo4j_conf_file) shutil.copyfile(self.config, neo4j_conf_file)
with open(neo4j_conf_file, "a") as f: with open(neo4j_conf_file, "a") as f:
f.write("\ndbms.connector.bolt.listen_address=:" + f.write("\ndbms.connector.bolt.listen_address=:" + self.args.port + "\n")
self.args.port + "\n") f.write("\ndbms.connector.http.listen_address=:" + self.args.http_port + "\n")
f.write("\ndbms.connector.http.listen_address=:" +
self.args.http_port + "\n")
# environment # environment
cwd = os.path.dirname(self.args.runner_bin) cwd = os.path.dirname(self.args.runner_bin)
env = {"NEO4J_HOME": self.neo4j_home_path} env = {"NEO4J_HOME": self.neo4j_home_path}
self.database_bin.run(self.args.runner_bin, args=["console"], self.database_bin.run(self.args.runner_bin, args=["console"], env=env, timeout=600, cwd=cwd)
env=env, timeout=600, cwd=cwd)
except: except:
shutil.rmtree(self.neo4j_home_path) shutil.rmtree(self.neo4j_home_path)
raise Exception("Couldn't run Neo4j!") raise Exception("Couldn't run Neo4j!")