Use extent hooks for per query memory limit (#1340)
This commit is contained in:
parent
3d4d841753
commit
a84f570c6d
@ -111,6 +111,22 @@ enum mgp_error mgp_global_aligned_alloc(size_t size_in_bytes, size_t alignment,
|
||||
/// The behavior is undefined if `ptr` is not a value returned from a prior
|
||||
/// mgp_global_alloc() or mgp_global_aligned_alloc().
|
||||
void mgp_global_free(void *p);
|
||||
|
||||
/// State of the graph database.
|
||||
struct mgp_graph;
|
||||
|
||||
/// Allocations are tracked only for master thread. If new threads are spawned
|
||||
/// inside procedure, by calling following function
|
||||
/// you can start tracking allocations for current thread too. This
|
||||
/// is important if you need query memory limit to work
|
||||
/// for given procedure or per procedure memory limit.
|
||||
enum mgp_error mgp_track_current_thread_allocations(struct mgp_graph *graph);
|
||||
|
||||
/// Once allocations are tracked for current thread, you need to stop tracking allocations
|
||||
/// for given thread, before thread finishes with execution, or is detached.
|
||||
/// Otherwise it might result in slowdown of system due to unnecessary tracking of
|
||||
/// allocations.
|
||||
enum mgp_error mgp_untrack_current_thread_allocations(struct mgp_graph *graph);
|
||||
///@}
|
||||
|
||||
/// @name Operations on mgp_value
|
||||
@ -854,9 +870,6 @@ enum mgp_error mgp_edge_set_properties(struct mgp_edge *e, struct mgp_map *prope
|
||||
enum mgp_error mgp_edge_iter_properties(struct mgp_edge *e, struct mgp_memory *memory,
|
||||
struct mgp_properties_iterator **result);
|
||||
|
||||
/// State of the graph database.
|
||||
struct mgp_graph;
|
||||
|
||||
/// Get the vertex corresponding to given ID, or NULL if no such vertex exists.
|
||||
/// Resulting vertex must be freed using mgp_vertex_destroy.
|
||||
/// Return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate the vertex.
|
||||
|
@ -23,7 +23,7 @@
|
||||
#include "glue/run_id.hpp"
|
||||
#include "helpers.hpp"
|
||||
#include "license/license_sender.hpp"
|
||||
#include "memory/memory_control.hpp"
|
||||
#include "memory/global_memory_control.hpp"
|
||||
#include "query/config.hpp"
|
||||
#include "query/discard_value_stream.hpp"
|
||||
#include "query/interpreter.hpp"
|
||||
@ -512,6 +512,7 @@ int main(int argc, char **argv) {
|
||||
|
||||
server.AwaitShutdown();
|
||||
websocket_server.AwaitShutdown();
|
||||
memgraph::memory::UnsetHooks();
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
|
||||
metrics_server.AwaitShutdown();
|
||||
|
@ -1,6 +1,7 @@
|
||||
set(memory_src_files
|
||||
new_delete.cpp
|
||||
memory_control.cpp)
|
||||
global_memory_control.cpp
|
||||
query_memory_control.cpp)
|
||||
|
||||
|
||||
|
||||
|
@ -9,12 +9,16 @@
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include "memory_control.hpp"
|
||||
#include <atomic>
|
||||
#include <cstdint>
|
||||
|
||||
#include "global_memory_control.hpp"
|
||||
#include "query_memory_control.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/memory_tracker.hpp"
|
||||
|
||||
#if USE_JEMALLOC
|
||||
#include <jemalloc/jemalloc.h>
|
||||
#include "jemalloc/jemalloc.h"
|
||||
#endif
|
||||
|
||||
namespace memgraph::memory {
|
||||
@ -57,12 +61,24 @@ void *my_alloc(extent_hooks_t *extent_hooks, void *new_addr, size_t size, size_t
|
||||
// This needs to be before, to throw exception in case of too big alloc
|
||||
if (*commit) [[likely]] {
|
||||
memgraph::utils::total_memory_tracker.Alloc(static_cast<int64_t>(size));
|
||||
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
|
||||
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
|
||||
if (memory_tracker != nullptr) [[likely]] {
|
||||
memory_tracker->Alloc(static_cast<int64_t>(size));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto *ptr = old_hooks->alloc(extent_hooks, new_addr, size, alignment, zero, commit, arena_ind);
|
||||
if (ptr == nullptr) [[unlikely]] {
|
||||
if (*commit) {
|
||||
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
|
||||
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
|
||||
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
|
||||
if (memory_tracker != nullptr) [[likely]] {
|
||||
memory_tracker->Free(static_cast<int64_t>(size));
|
||||
}
|
||||
}
|
||||
}
|
||||
return ptr;
|
||||
}
|
||||
@ -79,6 +95,13 @@ static bool my_dalloc(extent_hooks_t *extent_hooks, void *addr, size_t size, boo
|
||||
|
||||
if (committed) [[likely]] {
|
||||
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
|
||||
|
||||
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
|
||||
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
|
||||
if (memory_tracker != nullptr) [[likely]] {
|
||||
memory_tracker->Free(static_cast<int64_t>(size));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
@ -87,6 +110,12 @@ static bool my_dalloc(extent_hooks_t *extent_hooks, void *addr, size_t size, boo
|
||||
static void my_destroy(extent_hooks_t *extent_hooks, void *addr, size_t size, bool committed, unsigned arena_ind) {
|
||||
if (committed) [[likely]] {
|
||||
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
|
||||
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
|
||||
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
|
||||
if (memory_tracker != nullptr) [[likely]] {
|
||||
memory_tracker->Free(static_cast<int64_t>(size));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
old_hooks->destroy(extent_hooks, addr, size, committed, arena_ind);
|
||||
@ -101,6 +130,12 @@ static bool my_commit(extent_hooks_t *extent_hooks, void *addr, size_t size, siz
|
||||
}
|
||||
|
||||
memgraph::utils::total_memory_tracker.Alloc(static_cast<int64_t>(length));
|
||||
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
|
||||
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
|
||||
if (memory_tracker != nullptr) [[likely]] {
|
||||
memory_tracker->Alloc(static_cast<int64_t>(size));
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
@ -115,6 +150,12 @@ static bool my_decommit(extent_hooks_t *extent_hooks, void *addr, size_t size, s
|
||||
}
|
||||
|
||||
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(length));
|
||||
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
|
||||
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
|
||||
if (memory_tracker != nullptr) [[likely]] {
|
||||
memory_tracker->Free(static_cast<int64_t>(size));
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
@ -129,6 +170,13 @@ static bool my_purge_forced(extent_hooks_t *extent_hooks, void *addr, size_t siz
|
||||
}
|
||||
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(length));
|
||||
|
||||
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
|
||||
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
|
||||
if (memory_tracker != nullptr) [[likely]] {
|
||||
memory_tracker->Alloc(static_cast<int64_t>(size));
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -153,6 +201,7 @@ void SetHooks() {
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_arenas; i++) {
|
||||
GetQueriesMemoryControl().InitializeArenaCounter(i);
|
||||
std::string func_name = "arena." + std::to_string(i) + ".extent_hooks";
|
||||
|
||||
size_t hooks_len = sizeof(old_hooks);
|
||||
@ -197,6 +246,45 @@ void SetHooks() {
|
||||
#endif
|
||||
}
|
||||
|
||||
void UnsetHooks() {
|
||||
#if USE_JEMALLOC
|
||||
|
||||
uint64_t allocated{0};
|
||||
uint64_t sz{sizeof(allocated)};
|
||||
|
||||
sz = sizeof(unsigned);
|
||||
unsigned n_arenas{0};
|
||||
int err = mallctl("opt.narenas", (void *)&n_arenas, &sz, nullptr, 0);
|
||||
|
||||
if (err) {
|
||||
LOG_FATAL("Error setting default hooks for jemalloc arenas");
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_arenas; i++) {
|
||||
GetQueriesMemoryControl().InitializeArenaCounter(i);
|
||||
std::string func_name = "arena." + std::to_string(i) + ".extent_hooks";
|
||||
|
||||
MG_ASSERT(old_hooks);
|
||||
MG_ASSERT(old_hooks->alloc);
|
||||
MG_ASSERT(old_hooks->dalloc);
|
||||
MG_ASSERT(old_hooks->destroy);
|
||||
MG_ASSERT(old_hooks->commit);
|
||||
MG_ASSERT(old_hooks->decommit);
|
||||
MG_ASSERT(old_hooks->purge_forced);
|
||||
MG_ASSERT(old_hooks->purge_lazy);
|
||||
MG_ASSERT(old_hooks->split);
|
||||
MG_ASSERT(old_hooks->merge);
|
||||
|
||||
err = mallctl(func_name.c_str(), nullptr, nullptr, &old_hooks, sizeof(old_hooks));
|
||||
|
||||
if (err) {
|
||||
LOG_FATAL("Error setting default hooks for jemalloc arena {}", i);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
void PurgeUnusedMemory() {
|
||||
#if USE_JEMALLOC
|
||||
mallctl("arena." STRINGIFY(MALLCTL_ARENAS_ALL) ".purge", nullptr, nullptr, nullptr, 0);
|
@ -17,5 +17,6 @@ namespace memgraph::memory {
|
||||
|
||||
void PurgeUnusedMemory();
|
||||
void SetHooks();
|
||||
void UnsetHooks();
|
||||
|
||||
} // namespace memgraph::memory
|
140
src/memory/query_memory_control.cpp
Normal file
140
src/memory/query_memory_control.cpp
Normal file
@ -0,0 +1,140 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
#include <shared_mutex>
|
||||
#include <thread>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "query_memory_control.hpp"
|
||||
#include "utils/exceptions.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/memory_tracker.hpp"
|
||||
#include "utils/rw_spin_lock.hpp"
|
||||
|
||||
#if USE_JEMALLOC
|
||||
#include "jemalloc/jemalloc.h"
|
||||
#endif
|
||||
|
||||
namespace memgraph::memory {
|
||||
|
||||
#if USE_JEMALLOC
|
||||
|
||||
unsigned QueriesMemoryControl::GetArenaForThread() {
|
||||
unsigned thread_arena{0};
|
||||
size_t size_thread_arena = sizeof(thread_arena);
|
||||
int err = mallctl("thread.arena", &thread_arena, &size_thread_arena, nullptr, 0);
|
||||
if (err) {
|
||||
LOG_FATAL("Can't get arena for thread.");
|
||||
}
|
||||
return thread_arena;
|
||||
}
|
||||
|
||||
void QueriesMemoryControl::AddTrackingOnArena(unsigned arena_id) { arena_tracking[arena_id].fetch_add(1); }
|
||||
|
||||
void QueriesMemoryControl::RemoveTrackingOnArena(unsigned arena_id) { arena_tracking[arena_id].fetch_sub(1); }
|
||||
|
||||
void QueriesMemoryControl::UpdateThreadToTransactionId(const std::thread::id &thread_id, uint64_t transaction_id) {
|
||||
auto accessor = thread_id_to_transaction_id.access();
|
||||
accessor.insert({thread_id, transaction_id});
|
||||
}
|
||||
|
||||
void QueriesMemoryControl::EraseThreadToTransactionId(const std::thread::id &thread_id, uint64_t transaction_id) {
|
||||
auto accessor = thread_id_to_transaction_id.access();
|
||||
auto elem = accessor.find(thread_id);
|
||||
MG_ASSERT(elem != accessor.end() && elem->transaction_id == transaction_id);
|
||||
accessor.remove(thread_id);
|
||||
}
|
||||
|
||||
utils::MemoryTracker *QueriesMemoryControl::GetTrackerCurrentThread() {
|
||||
auto thread_id_to_transaction_id_accessor = thread_id_to_transaction_id.access();
|
||||
|
||||
// we might be just constructing mapping between thread id and transaction id
|
||||
// so we miss this allocation
|
||||
auto thread_id_to_transaction_id_elem = thread_id_to_transaction_id_accessor.find(std::this_thread::get_id());
|
||||
if (thread_id_to_transaction_id_elem == thread_id_to_transaction_id_accessor.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
|
||||
auto transaction_id_to_tracker =
|
||||
transaction_id_to_tracker_accessor.find(thread_id_to_transaction_id_elem->transaction_id);
|
||||
return &transaction_id_to_tracker->tracker;
|
||||
}
|
||||
|
||||
void QueriesMemoryControl::CreateTransactionIdTracker(uint64_t transaction_id, size_t inital_limit) {
|
||||
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
|
||||
|
||||
auto [elem, result] = transaction_id_to_tracker_accessor.insert({transaction_id, utils::MemoryTracker{}});
|
||||
|
||||
elem->tracker.SetMaximumHardLimit(inital_limit);
|
||||
elem->tracker.SetHardLimit(inital_limit);
|
||||
}
|
||||
|
||||
bool QueriesMemoryControl::CheckTransactionIdTrackerExists(uint64_t transaction_id) {
|
||||
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
|
||||
return transaction_id_to_tracker_accessor.contains(transaction_id);
|
||||
}
|
||||
|
||||
bool QueriesMemoryControl::EraseTransactionIdTracker(uint64_t transaction_id) {
|
||||
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
|
||||
auto removed = transaction_id_to_tracker.access().remove(transaction_id);
|
||||
return removed;
|
||||
}
|
||||
|
||||
bool QueriesMemoryControl::IsArenaTracked(unsigned arena_ind) {
|
||||
return arena_tracking[arena_ind].load(std::memory_order_acquire) != 0;
|
||||
}
|
||||
|
||||
void QueriesMemoryControl::InitializeArenaCounter(unsigned arena_ind) {
|
||||
arena_tracking[arena_ind].store(0, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
void StartTrackingCurrentThreadTransaction(uint64_t transaction_id) {
|
||||
#if USE_JEMALLOC
|
||||
GetQueriesMemoryControl().UpdateThreadToTransactionId(std::this_thread::get_id(), transaction_id);
|
||||
GetQueriesMemoryControl().AddTrackingOnArena(QueriesMemoryControl::GetArenaForThread());
|
||||
#endif
|
||||
}
|
||||
|
||||
void StopTrackingCurrentThreadTransaction(uint64_t transaction_id) {
|
||||
#if USE_JEMALLOC
|
||||
GetQueriesMemoryControl().EraseThreadToTransactionId(std::this_thread::get_id(), transaction_id);
|
||||
GetQueriesMemoryControl().RemoveTrackingOnArena(QueriesMemoryControl::GetArenaForThread());
|
||||
#endif
|
||||
}
|
||||
|
||||
void TryStartTrackingOnTransaction(uint64_t transaction_id, size_t limit) {
|
||||
#if USE_JEMALLOC
|
||||
if (GetQueriesMemoryControl().CheckTransactionIdTrackerExists(transaction_id)) {
|
||||
return;
|
||||
}
|
||||
GetQueriesMemoryControl().CreateTransactionIdTracker(transaction_id, limit);
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
void TryStopTrackingOnTransaction(uint64_t transaction_id) {
|
||||
#if USE_JEMALLOC
|
||||
if (!GetQueriesMemoryControl().CheckTransactionIdTrackerExists(transaction_id)) {
|
||||
return;
|
||||
}
|
||||
GetQueriesMemoryControl().EraseTransactionIdTracker(transaction_id);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace memgraph::memory
|
141
src/memory/query_memory_control.hpp
Normal file
141
src/memory/query_memory_control.hpp
Normal file
@ -0,0 +1,141 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "utils/memory_tracker.hpp"
|
||||
#include "utils/skip_list.hpp"
|
||||
|
||||
namespace memgraph::memory {
|
||||
|
||||
#if USE_JEMALLOC
|
||||
|
||||
// Track memory allocations per query.
|
||||
// Multiple threads can allocate inside one transaction.
|
||||
// If user forgets to unregister tracking for that thread before it dies, it will continue to
|
||||
// track allocations for that arena indefinitely.
|
||||
// As multiple queries can be executed inside one transaction, one by one (multi-transaction)
|
||||
// it is necessary to restart tracking at the beginning of new query for that transaction.
|
||||
class QueriesMemoryControl {
|
||||
public:
|
||||
/*
|
||||
Arena stats
|
||||
*/
|
||||
|
||||
static unsigned GetArenaForThread();
|
||||
|
||||
// Add counter on threads allocating inside arena
|
||||
void AddTrackingOnArena(unsigned);
|
||||
|
||||
// Remove counter on threads allocating in arena
|
||||
void RemoveTrackingOnArena(unsigned);
|
||||
|
||||
// Are any threads using current arena for allocations
|
||||
// Multiple threads can allocate inside one arena
|
||||
bool IsArenaTracked(unsigned);
|
||||
|
||||
// Initialize arena counter
|
||||
void InitializeArenaCounter(unsigned);
|
||||
|
||||
/*
|
||||
Transaction id <-> tracker
|
||||
*/
|
||||
|
||||
// Create new tracker for transaction_id with initial limit
|
||||
void CreateTransactionIdTracker(uint64_t, size_t);
|
||||
|
||||
// Check if tracker for given transaction id exists
|
||||
bool CheckTransactionIdTrackerExists(uint64_t);
|
||||
|
||||
// Remove current tracker for transaction_id
|
||||
bool EraseTransactionIdTracker(uint64_t);
|
||||
|
||||
/*
|
||||
Thread handlings
|
||||
*/
|
||||
|
||||
// Map thread to transaction with given id
|
||||
// This way we can know which thread belongs to which transaction
|
||||
// and get correct tracker for given transaction
|
||||
void UpdateThreadToTransactionId(const std::thread::id &, uint64_t);
|
||||
|
||||
// Remove tracking of thread from transaction.
|
||||
// Important to reset if one thread gets reused for different transaction
|
||||
void EraseThreadToTransactionId(const std::thread::id &, uint64_t);
|
||||
|
||||
// C-API functionality for thread to transaction mapping
|
||||
void UpdateThreadToTransactionId(const char *, uint64_t);
|
||||
|
||||
// C-API functionality for thread to transaction unmapping
|
||||
void EraseThreadToTransactionId(const char *, uint64_t);
|
||||
|
||||
// Get tracker to current thread if exists, otherwise return
|
||||
// nullptr. This can happen only if tracker is still
|
||||
// being constructed.
|
||||
utils::MemoryTracker *GetTrackerCurrentThread();
|
||||
|
||||
private:
|
||||
std::unordered_map<unsigned, std::atomic<int>> arena_tracking;
|
||||
|
||||
struct ThreadIdToTransactionId {
|
||||
std::thread::id thread_id;
|
||||
uint64_t transaction_id;
|
||||
|
||||
bool operator<(const ThreadIdToTransactionId &other) const { return thread_id < other.thread_id; }
|
||||
bool operator==(const ThreadIdToTransactionId &other) const { return thread_id == other.thread_id; }
|
||||
|
||||
bool operator<(const std::thread::id other) const { return thread_id < other; }
|
||||
bool operator==(const std::thread::id other) const { return thread_id == other; }
|
||||
};
|
||||
|
||||
struct TransactionIdToTracker {
|
||||
uint64_t transaction_id;
|
||||
utils::MemoryTracker tracker;
|
||||
|
||||
bool operator<(const TransactionIdToTracker &other) const { return transaction_id < other.transaction_id; }
|
||||
bool operator==(const TransactionIdToTracker &other) const { return transaction_id == other.transaction_id; }
|
||||
|
||||
bool operator<(uint64_t other) const { return transaction_id < other; }
|
||||
bool operator==(uint64_t other) const { return transaction_id == other; }
|
||||
};
|
||||
|
||||
utils::SkipList<ThreadIdToTransactionId> thread_id_to_transaction_id;
|
||||
utils::SkipList<TransactionIdToTracker> transaction_id_to_tracker;
|
||||
};
|
||||
|
||||
inline QueriesMemoryControl &GetQueriesMemoryControl() {
|
||||
static QueriesMemoryControl queries_memory_control_;
|
||||
return queries_memory_control_;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// API function call for to start tracking current thread for given transaction.
|
||||
// Does nothing if jemalloc is not enabled
|
||||
void StartTrackingCurrentThreadTransaction(uint64_t transaction_id);
|
||||
|
||||
// API function call for to stop tracking current thread for given transaction.
|
||||
// Does nothing if jemalloc is not enabled
|
||||
void StopTrackingCurrentThreadTransaction(uint64_t transaction_id);
|
||||
|
||||
// API function call for try to create tracker for transaction and set it to given limit.
|
||||
// Does nothing if jemalloc is not enabled. Does nothing if tracker already exists
|
||||
void TryStartTrackingOnTransaction(uint64_t transaction_id, size_t limit);
|
||||
|
||||
// API function call to stop tracking for given transaction.
|
||||
// Does nothing if jemalloc is not enabled. Does nothing if tracker doesn't exist
|
||||
void TryStopTrackingOnTransaction(uint64_t transaction_id);
|
||||
|
||||
} // namespace memgraph::memory
|
@ -21,6 +21,10 @@ namespace memgraph::query {
|
||||
SubgraphDbAccessor::SubgraphDbAccessor(query::DbAccessor db_accessor, Graph *graph)
|
||||
: db_accessor_(db_accessor), graph_(graph) {}
|
||||
|
||||
void SubgraphDbAccessor::TrackCurrentThreadAllocations() { return db_accessor_.TrackCurrentThreadAllocations(); }
|
||||
|
||||
void SubgraphDbAccessor::UntrackCurrentThreadAllocations() { return db_accessor_.TrackCurrentThreadAllocations(); }
|
||||
|
||||
storage::PropertyId SubgraphDbAccessor::NameToProperty(const std::string_view name) {
|
||||
return db_accessor_.NameToProperty(name);
|
||||
}
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include <cppitertools/filter.hpp>
|
||||
#include <cppitertools/imap.hpp>
|
||||
|
||||
#include "memory/query_memory_control.hpp"
|
||||
#include "query/exceptions.hpp"
|
||||
#include "storage/v2/edge_accessor.hpp"
|
||||
#include "storage/v2/id_types.hpp"
|
||||
@ -372,6 +373,16 @@ class DbAccessor final {
|
||||
|
||||
void FinalizeTransaction() { accessor_->FinalizeTransaction(); }
|
||||
|
||||
void TrackCurrentThreadAllocations() {
|
||||
memgraph::memory::StartTrackingCurrentThreadTransaction(*accessor_->GetTransactionId());
|
||||
}
|
||||
|
||||
void UntrackCurrentThreadAllocations() {
|
||||
memgraph::memory::StopTrackingCurrentThreadTransaction(*accessor_->GetTransactionId());
|
||||
}
|
||||
|
||||
std::optional<uint64_t> GetTransactionId() { return accessor_->GetTransactionId(); }
|
||||
|
||||
VerticesIterable Vertices(storage::View view) { return VerticesIterable(accessor_->Vertices(view)); }
|
||||
|
||||
VerticesIterable Vertices(storage::View view, storage::LabelId label) {
|
||||
@ -640,6 +651,14 @@ class SubgraphDbAccessor final {
|
||||
|
||||
static SubgraphDbAccessor *MakeSubgraphDbAccessor(DbAccessor *db_accessor, Graph *graph);
|
||||
|
||||
void TrackThreadAllocations(const char *thread_id);
|
||||
|
||||
void TrackCurrentThreadAllocations();
|
||||
|
||||
void UntrackThreadAllocations(const char *thread_id);
|
||||
|
||||
void UntrackCurrentThreadAllocations();
|
||||
|
||||
storage::PropertyId NameToProperty(std::string_view name);
|
||||
|
||||
storage::LabelId NameToLabel(std::string_view name);
|
||||
|
@ -26,6 +26,7 @@
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
#include <thread>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
@ -39,7 +40,8 @@
|
||||
#include "flags/run_time_configurable.hpp"
|
||||
#include "glue/communication.hpp"
|
||||
#include "license/license.hpp"
|
||||
#include "memory/memory_control.hpp"
|
||||
#include "memory/global_memory_control.hpp"
|
||||
#include "memory/query_memory_control.hpp"
|
||||
#include "query/config.hpp"
|
||||
#include "query/constants.hpp"
|
||||
#include "query/context.hpp"
|
||||
@ -1283,6 +1285,24 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
|
||||
std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *stream, std::optional<int> n,
|
||||
const std::vector<Symbol> &output_symbols,
|
||||
std::map<std::string, TypedValue> *summary) {
|
||||
std::optional<uint64_t> transaction_id = ctx_.db_accessor->GetTransactionId();
|
||||
MG_ASSERT(transaction_id.has_value());
|
||||
|
||||
if (memory_limit_) {
|
||||
memgraph::memory::TryStartTrackingOnTransaction(*transaction_id, *memory_limit_);
|
||||
memgraph::memory::StartTrackingCurrentThreadTransaction(*transaction_id);
|
||||
}
|
||||
utils::OnScopeExit<std::function<void()>> reset_query_limit{
|
||||
[memory_limit = memory_limit_, transaction_id = *transaction_id]() {
|
||||
if (memory_limit) {
|
||||
// Stopping tracking of transaction occurs in interpreter::pull
|
||||
// Exception can occur so we need to handle that case there.
|
||||
// We can't stop tracking here as there can be multiple pulls
|
||||
// so we need to take care of that after everything was pulled
|
||||
memgraph::memory::StopTrackingCurrentThreadTransaction(transaction_id);
|
||||
}
|
||||
}};
|
||||
|
||||
// Set up temporary memory for a single Pull. Initial memory comes from the
|
||||
// stack. 256 KiB should fit on the stack and should be more than enough for a
|
||||
// single `Pull`.
|
||||
@ -1306,13 +1326,7 @@ std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *strea
|
||||
pool_memory.emplace(kMaxBlockPerChunks, 1024, &monotonic_memory, &resource_with_exception);
|
||||
}
|
||||
|
||||
std::optional<utils::LimitedMemoryResource> maybe_limited_resource;
|
||||
if (memory_limit_) {
|
||||
maybe_limited_resource.emplace(&*pool_memory, *memory_limit_);
|
||||
ctx_.evaluation_context.memory = &*maybe_limited_resource;
|
||||
} else {
|
||||
ctx_.evaluation_context.memory = &*pool_memory;
|
||||
}
|
||||
|
||||
// Returns true if a result was pulled.
|
||||
const auto pull_result = [&]() -> bool { return cursor_->Pull(frame_, ctx_); };
|
||||
@ -1379,6 +1393,7 @@ std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *strea
|
||||
}
|
||||
cursor_->Shutdown();
|
||||
ctx_.profile_execution_time = execution_time_;
|
||||
|
||||
return GetStatsWithTotalTime(ctx_);
|
||||
}
|
||||
|
||||
|
@ -16,6 +16,7 @@
|
||||
#include <gflags/gflags.h>
|
||||
|
||||
#include "dbms/database.hpp"
|
||||
#include "memory/query_memory_control.hpp"
|
||||
#include "query/auth_checker.hpp"
|
||||
#include "query/auth_query_handler.hpp"
|
||||
#include "query/config.hpp"
|
||||
@ -402,6 +403,9 @@ std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std:
|
||||
// If the query finished executing, we have received a value which tells
|
||||
// us what to do after.
|
||||
if (maybe_res) {
|
||||
if (current_transaction_) {
|
||||
memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_);
|
||||
}
|
||||
// Save its summary
|
||||
maybe_summary.emplace(std::move(query_execution->summary));
|
||||
if (!query_execution->notifications.empty()) {
|
||||
@ -440,9 +444,15 @@ std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std:
|
||||
}
|
||||
}
|
||||
} catch (const ExplicitTransactionUsageException &) {
|
||||
if (current_transaction_) {
|
||||
memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_);
|
||||
}
|
||||
query_execution.reset(nullptr);
|
||||
throw;
|
||||
} catch (const utils::BasicException &) {
|
||||
if (current_transaction_) {
|
||||
memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_);
|
||||
}
|
||||
// Trigger first failed query
|
||||
metrics::FirstFailedQuery();
|
||||
memgraph::metrics::IncrementCounter(memgraph::metrics::FailedQuery);
|
||||
|
@ -3596,3 +3596,15 @@ mgp_error mgp_log(const mgp_log_level log_level, const char *output) {
|
||||
throw std::invalid_argument{fmt::format("Invalid log level: {}", log_level)};
|
||||
});
|
||||
}
|
||||
|
||||
mgp_error mgp_track_current_thread_allocations(mgp_graph *graph) {
|
||||
return WrapExceptions([&]() {
|
||||
std::visit([](auto *db_accessor) -> void { db_accessor->TrackCurrentThreadAllocations(); }, graph->impl);
|
||||
});
|
||||
}
|
||||
|
||||
mgp_error mgp_untrack_current_thread_allocations(mgp_graph *graph) {
|
||||
return WrapExceptions([&]() {
|
||||
std::visit([](auto *db_accessor) -> void { db_accessor->UntrackCurrentThreadAllocations(); }, graph->impl);
|
||||
});
|
||||
}
|
||||
|
@ -89,6 +89,13 @@ void MemoryTracker::TryRaiseHardLimit(const int64_t limit) {
|
||||
;
|
||||
}
|
||||
|
||||
void MemoryTracker::ResetTrackings() {
|
||||
hard_limit_.store(0, std::memory_order_relaxed);
|
||||
peak_.store(0, std::memory_order_relaxed);
|
||||
amount_.store(0, std::memory_order_relaxed);
|
||||
maximum_hard_limit_ = 0;
|
||||
}
|
||||
|
||||
void MemoryTracker::SetMaximumHardLimit(const int64_t limit) {
|
||||
if (maximum_hard_limit_ < 0) {
|
||||
spdlog::warn("Invalid maximum hard limit.");
|
||||
|
@ -12,6 +12,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <type_traits>
|
||||
|
||||
#include "utils/exceptions.hpp"
|
||||
|
||||
@ -41,9 +42,20 @@ class MemoryTracker final {
|
||||
MemoryTracker() = default;
|
||||
~MemoryTracker() = default;
|
||||
|
||||
MemoryTracker(MemoryTracker &&other) noexcept
|
||||
: amount_(other.amount_.load(std::memory_order_acquire)),
|
||||
peak_(other.peak_.load(std::memory_order_acquire)),
|
||||
hard_limit_(other.hard_limit_.load(std::memory_order_acquire)),
|
||||
maximum_hard_limit_(other.maximum_hard_limit_) {
|
||||
other.maximum_hard_limit_ = 0;
|
||||
other.amount_.store(0, std::memory_order_acquire);
|
||||
other.peak_.store(0, std::memory_order_acquire);
|
||||
other.hard_limit_.store(0, std::memory_order_acquire);
|
||||
}
|
||||
|
||||
MemoryTracker(const MemoryTracker &) = delete;
|
||||
MemoryTracker &operator=(const MemoryTracker &) = delete;
|
||||
MemoryTracker(MemoryTracker &&) = delete;
|
||||
|
||||
MemoryTracker &operator=(MemoryTracker &&) = delete;
|
||||
|
||||
void Alloc(int64_t size);
|
||||
@ -59,6 +71,8 @@ class MemoryTracker final {
|
||||
void TryRaiseHardLimit(int64_t limit);
|
||||
void SetMaximumHardLimit(int64_t limit);
|
||||
|
||||
void ResetTrackings();
|
||||
|
||||
// By creating an object of this class, every allocation in its scope that goes over
|
||||
// the set hard limit produces an OutOfMemoryException.
|
||||
class OutOfMemoryExceptionEnabler final {
|
||||
|
@ -21,6 +21,7 @@ SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", ".."))
|
||||
BUILD_DIR = os.path.join(PROJECT_DIR, "build")
|
||||
MEMGRAPH_BINARY = os.path.join(BUILD_DIR, "memgraph")
|
||||
SIGNAL_SIGTERM = 15
|
||||
|
||||
|
||||
def wait_for_server(port, delay=0.01):
|
||||
@ -133,7 +134,7 @@ class MemgraphInstanceRunner:
|
||||
|
||||
pid = self.proc_mg.pid
|
||||
try:
|
||||
os.kill(pid, 15) # 15 is the signal number for SIGTERM
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False
|
||||
|
||||
|
@ -11,10 +11,21 @@ target_link_libraries(memgraph__e2e__memory__limit_global_alloc gflags mgclient
|
||||
add_executable(memgraph__e2e__memory__limit_global_alloc_proc memory_limit_global_alloc_proc.cpp)
|
||||
target_link_libraries(memgraph__e2e__memory__limit_global_alloc_proc gflags mgclient mg-utils mg-io Threads::Threads)
|
||||
|
||||
add_executable(memgraph__e2e__memory__limit_query_alloc_proc_multi_thread query_memory_limit_proc_multi_thread.cpp)
|
||||
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_proc_multi_thread gflags mgclient mg-utils mg-io Threads::Threads)
|
||||
|
||||
add_executable(memgraph__e2e__memory__limit_query_alloc_create query_memory_limit_create.cpp)
|
||||
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_create gflags mgclient mg-utils mg-io)
|
||||
|
||||
add_executable(memgraph__e2e__memory__limit_query_alloc_proc query_memory_limit_proc.cpp)
|
||||
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_proc gflags mgclient mg-utils mg-io)
|
||||
|
||||
add_executable(memgraph__e2e__memory__limit_query_alloc_create_multi_thread query_memory_limit_multi_thread.cpp)
|
||||
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_create_multi_thread gflags mgclient mg-utils mg-io Threads::Threads)
|
||||
|
||||
add_executable(memgraph__e2e__memory__limit_delete memory_limit_delete.cpp)
|
||||
target_link_libraries(memgraph__e2e__memory__limit_delete gflags mgclient mg-utils mg-io)
|
||||
|
||||
|
||||
add_executable(memgraph__e2e__memory__limit_accumulation memory_limit_accumulation.cpp)
|
||||
target_link_libraries(memgraph__e2e__memory__limit_accumulation gflags mgclient mg-utils mg-io)
|
||||
|
||||
|
@ -3,3 +3,13 @@ target_include_directories(global_memory_limit PRIVATE ${CMAKE_SOURCE_DIR}/inclu
|
||||
|
||||
add_library(global_memory_limit_proc SHARED global_memory_limit_proc.c)
|
||||
target_include_directories(global_memory_limit_proc PRIVATE ${CMAKE_SOURCE_DIR}/include)
|
||||
|
||||
|
||||
add_library(query_memory_limit_proc_multi_thread SHARED query_memory_limit_proc_multi_thread.cpp)
|
||||
target_include_directories(query_memory_limit_proc_multi_thread PRIVATE ${CMAKE_SOURCE_DIR}/include)
|
||||
target_link_libraries(query_memory_limit_proc_multi_thread mg-utils)
|
||||
|
||||
|
||||
add_library(query_memory_limit_proc SHARED query_memory_limit_proc.cpp)
|
||||
target_include_directories(query_memory_limit_proc PRIVATE ${CMAKE_SOURCE_DIR}/include)
|
||||
target_link_libraries(query_memory_limit_proc mg-utils)
|
||||
|
70
tests/e2e/memory/procedures/query_memory_limit_proc.cpp
Normal file
70
tests/e2e/memory/procedures/query_memory_limit_proc.cpp
Normal file
@ -0,0 +1,70 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <mgp.hpp>
|
||||
#include <mutex>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "mg_procedure.h"
|
||||
#include "utils/on_scope_exit.hpp"
|
||||
|
||||
enum mgp_error Alloc(void *ptr) {
|
||||
const size_t mb_size_268 = 1 << 28;
|
||||
|
||||
return mgp_global_alloc(mb_size_268, (void **)(&ptr));
|
||||
}
|
||||
|
||||
void Regular(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) {
|
||||
mgp::MemoryDispatcherGuard guard{memory};
|
||||
const auto arguments = mgp::List(args);
|
||||
const auto record_factory = mgp::RecordFactory(result);
|
||||
|
||||
try {
|
||||
void *ptr{nullptr};
|
||||
|
||||
memgraph::utils::OnScopeExit<std::function<void(void)>> cleanup{[&ptr]() {
|
||||
if (nullptr == ptr) {
|
||||
return;
|
||||
}
|
||||
mgp_global_free(ptr);
|
||||
}};
|
||||
|
||||
const enum mgp_error alloc_err = Alloc(ptr);
|
||||
auto new_record = record_factory.NewRecord();
|
||||
new_record.Insert("allocated", alloc_err != mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE);
|
||||
} catch (std::exception &e) {
|
||||
record_factory.SetErrorMessage(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
|
||||
try {
|
||||
mgp::MemoryDispatcherGuard mdg{memory};
|
||||
|
||||
AddProcedure(Regular, std::string("regular").c_str(), mgp::ProcedureType::Read, {},
|
||||
{mgp::Return(std::string("allocated").c_str(), mgp::Type::Bool)}, module, memory);
|
||||
|
||||
} catch (const std::exception &e) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern "C" int mgp_shutdown_module() { return 0; }
|
@ -0,0 +1,101 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <mgp.hpp>
|
||||
#include <mutex>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "mg_procedure.h"
|
||||
#include "utils/on_scope_exit.hpp"
|
||||
|
||||
enum mgp_error Alloc(void *ptr) {
|
||||
const size_t mb_size_268 = 1 << 28;
|
||||
|
||||
return mgp_global_alloc(mb_size_268, (void **)(&ptr));
|
||||
}
|
||||
|
||||
// change communication between threads with feature and promise
|
||||
std::atomic<int> num_allocations{0};
|
||||
std::vector<void *> ptrs_;
|
||||
|
||||
void AllocFunc(mgp_graph *graph) {
|
||||
[[maybe_unused]] const enum mgp_error tracking_error = mgp_track_current_thread_allocations(graph);
|
||||
void *ptr = nullptr;
|
||||
|
||||
ptrs_.emplace_back(ptr);
|
||||
try {
|
||||
enum mgp_error alloc_err { mgp_error::MGP_ERROR_NO_ERROR };
|
||||
alloc_err = Alloc(ptr);
|
||||
if (alloc_err != mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) {
|
||||
num_allocations.fetch_add(1, std::memory_order_relaxed);
|
||||
}
|
||||
if (alloc_err != mgp_error::MGP_ERROR_NO_ERROR) {
|
||||
assert(false);
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
assert(false);
|
||||
}
|
||||
|
||||
[[maybe_unused]] const enum mgp_error untracking_error = mgp_untrack_current_thread_allocations(graph);
|
||||
}
|
||||
|
||||
void DualThread(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) {
|
||||
mgp::MemoryDispatcherGuard guard{memory};
|
||||
const auto arguments = mgp::List(args);
|
||||
const auto record_factory = mgp::RecordFactory(result);
|
||||
num_allocations.store(0, std::memory_order_relaxed);
|
||||
try {
|
||||
std::vector<std::thread> threads;
|
||||
|
||||
for (int i = 0; i < 2; i++) {
|
||||
threads.emplace_back(AllocFunc, memgraph_graph);
|
||||
}
|
||||
|
||||
for (int i = 0; i < 2; i++) {
|
||||
threads[i].join();
|
||||
}
|
||||
for (void *ptr : ptrs_) {
|
||||
if (ptr != nullptr) {
|
||||
mgp_global_free(ptr);
|
||||
}
|
||||
}
|
||||
|
||||
auto new_record = record_factory.NewRecord();
|
||||
|
||||
new_record.Insert("allocated_all", num_allocations.load(std::memory_order_relaxed) == 2);
|
||||
} catch (std::exception &e) {
|
||||
record_factory.SetErrorMessage(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
|
||||
try {
|
||||
mgp::memory = memory;
|
||||
|
||||
AddProcedure(DualThread, std::string("dual_thread").c_str(), mgp::ProcedureType::Read, {},
|
||||
{mgp::Return(std::string("allocated_all").c_str(), mgp::Type::Bool)}, module, memory);
|
||||
|
||||
} catch (const std::exception &e) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern "C" int mgp_shutdown_module() { return 0; }
|
65
tests/e2e/memory/query_memory_limit_create.cpp
Normal file
65
tests/e2e/memory/query_memory_limit_create.cpp
Normal file
@ -0,0 +1,65 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
#include <iostream>
|
||||
#include <mgclient.hpp>
|
||||
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/timer.hpp"
|
||||
|
||||
DEFINE_uint64(bolt_port, 7687, "Bolt port");
|
||||
DEFINE_bool(multi_db, false, "Run test in multi db environment");
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
google::SetUsageMessage("Memgraph E2E Memory Control");
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
memgraph::logging::RedirectToStderr();
|
||||
|
||||
mg::Client::Init();
|
||||
|
||||
auto client =
|
||||
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
|
||||
if (!client) {
|
||||
LOG_FATAL("Failed to connect!");
|
||||
}
|
||||
|
||||
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||
client->DiscardAll();
|
||||
|
||||
if (FLAGS_multi_db) {
|
||||
client->Execute("CREATE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("USE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||
client->DiscardAll();
|
||||
}
|
||||
|
||||
const auto *create_query =
|
||||
"UNWIND range(1, 50000) as u CREATE (n {string: 'Some longer string'}) RETURN n QUERY MEMORY LIMIT 30MB;";
|
||||
|
||||
try {
|
||||
client->Execute(create_query);
|
||||
[[maybe_unused]] auto results = client->FetchAll();
|
||||
if (results->empty()) {
|
||||
assert(true);
|
||||
return 0;
|
||||
}
|
||||
} catch (const mg::TransientException & /*unused*/) {
|
||||
spdlog::info("Memgraph is out of memory");
|
||||
assert(true);
|
||||
return 0;
|
||||
}
|
||||
MG_ASSERT(false, "Query should have failed!");
|
||||
|
||||
return 0;
|
||||
}
|
100
tests/e2e/memory/query_memory_limit_multi_thread.cpp
Normal file
100
tests/e2e/memory/query_memory_limit_multi_thread.cpp
Normal file
@ -0,0 +1,100 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <exception>
|
||||
#include <future>
|
||||
#include <thread>
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
#include <iostream>
|
||||
#include <mgclient.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/timer.hpp"
|
||||
|
||||
DEFINE_uint64(bolt_port, 7687, "Bolt port");
|
||||
DEFINE_bool(multi_db, false, "Run test in multi db environment");
|
||||
|
||||
static constexpr int kNumberClients{2};
|
||||
|
||||
void Func(std::promise<bool> promise) {
|
||||
auto client =
|
||||
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
|
||||
if (!client) {
|
||||
LOG_FATAL("Failed to connect!");
|
||||
}
|
||||
const auto *create_query =
|
||||
"FOREACH(i in range(1, 300000) | CREATE (n: Node {string: 'Some longer string'})) QUERY MEMORY LIMIT 100MB;";
|
||||
bool err{false};
|
||||
try {
|
||||
client->Execute(create_query);
|
||||
[[maybe_unused]] auto results = client->FetchAll();
|
||||
if (results->empty()) {
|
||||
err = true;
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
spdlog::info("Good: Exception occured", e.what());
|
||||
err = true;
|
||||
}
|
||||
promise.set_value_at_thread_exit(err);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
google::SetUsageMessage("Memgraph E2E Memory Control");
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
memgraph::logging::RedirectToStderr();
|
||||
|
||||
mg::Client::Init();
|
||||
|
||||
{
|
||||
auto client =
|
||||
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
|
||||
if (!client) {
|
||||
LOG_FATAL("Failed to connect!");
|
||||
}
|
||||
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||
client->DiscardAll();
|
||||
|
||||
if (FLAGS_multi_db) {
|
||||
client->Execute("CREATE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("USE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||
client->DiscardAll();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::promise<bool>> my_promises;
|
||||
std::vector<std::future<bool>> my_futures;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
my_promises.push_back(std::promise<bool>());
|
||||
my_futures.emplace_back(my_promises.back().get_future());
|
||||
}
|
||||
|
||||
std::vector<std::thread> my_threads;
|
||||
|
||||
for (int i = 0; i < kNumberClients; i++) {
|
||||
my_threads.emplace_back(Func, std::move(my_promises[i]));
|
||||
}
|
||||
|
||||
for (int i = 0; i < kNumberClients; i++) {
|
||||
auto value = my_futures[i].get();
|
||||
MG_ASSERT(value, "Error should have happend in thread");
|
||||
}
|
||||
|
||||
for (int i = 0; i < kNumberClients; i++) {
|
||||
my_threads[i].join();
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
66
tests/e2e/memory/query_memory_limit_proc.cpp
Normal file
66
tests/e2e/memory/query_memory_limit_proc.cpp
Normal file
@ -0,0 +1,66 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
#include <algorithm>
|
||||
#include <exception>
|
||||
#include <ios>
|
||||
#include <iostream>
|
||||
#include <mgclient.hpp>
|
||||
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/timer.hpp"
|
||||
|
||||
DEFINE_uint64(bolt_port, 7687, "Bolt port");
|
||||
DEFINE_uint64(timeout, 120, "Timeout seconds");
|
||||
DEFINE_bool(multi_db, false, "Run test in multi db environment");
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
google::SetUsageMessage("Memgraph E2E Query Memory Limit In Multi-Thread For Global Allocators");
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
memgraph::logging::RedirectToStderr();
|
||||
|
||||
mg::Client::Init();
|
||||
|
||||
auto client =
|
||||
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
|
||||
if (!client) {
|
||||
LOG_FATAL("Failed to connect!");
|
||||
}
|
||||
|
||||
if (FLAGS_multi_db) {
|
||||
client->Execute("CREATE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("USE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||
client->DiscardAll();
|
||||
}
|
||||
|
||||
MG_ASSERT(
|
||||
client->Execute("CALL libquery_memory_limit_proc.regular() YIELD allocated RETURN "
|
||||
"allocated QUERY MEMORY LIMIT 250MB"));
|
||||
bool error{false};
|
||||
try {
|
||||
auto result_rows = client->FetchAll();
|
||||
if (result_rows) {
|
||||
auto row = *result_rows->begin();
|
||||
error = row[0].ValueBool() == false;
|
||||
}
|
||||
|
||||
} catch (const std::exception &e) {
|
||||
error = true;
|
||||
}
|
||||
|
||||
MG_ASSERT(error, "Error should have happend");
|
||||
|
||||
return 0;
|
||||
}
|
66
tests/e2e/memory/query_memory_limit_proc_multi_thread.cpp
Normal file
66
tests/e2e/memory/query_memory_limit_proc_multi_thread.cpp
Normal file
@ -0,0 +1,66 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
#include <algorithm>
|
||||
#include <exception>
|
||||
#include <ios>
|
||||
#include <iostream>
|
||||
#include <mgclient.hpp>
|
||||
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/timer.hpp"
|
||||
|
||||
DEFINE_uint64(bolt_port, 7687, "Bolt port");
|
||||
DEFINE_uint64(timeout, 120, "Timeout seconds");
|
||||
DEFINE_bool(multi_db, false, "Run test in multi db environment");
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
google::SetUsageMessage("Memgraph E2E Query Memory Limit In Multi-Thread For Global Allocators");
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
memgraph::logging::RedirectToStderr();
|
||||
|
||||
mg::Client::Init();
|
||||
|
||||
auto client =
|
||||
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
|
||||
if (!client) {
|
||||
LOG_FATAL("Failed to connect!");
|
||||
}
|
||||
|
||||
if (FLAGS_multi_db) {
|
||||
client->Execute("CREATE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("USE DATABASE clean;");
|
||||
client->DiscardAll();
|
||||
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||
client->DiscardAll();
|
||||
}
|
||||
|
||||
MG_ASSERT(
|
||||
client->Execute("CALL libquery_memory_limit_proc_multi_thread.dual_thread() YIELD allocated_all RETURN "
|
||||
"allocated_all QUERY MEMORY LIMIT 500MB"));
|
||||
bool error{false};
|
||||
try {
|
||||
auto result_rows = client->FetchAll();
|
||||
if (result_rows) {
|
||||
auto row = *result_rows->begin();
|
||||
error = row[0].ValueBool() == false;
|
||||
}
|
||||
|
||||
} catch (const std::exception &e) {
|
||||
error = true;
|
||||
}
|
||||
|
||||
MG_ASSERT(error, "Error should have happend");
|
||||
|
||||
return 0;
|
||||
}
|
@ -23,6 +23,20 @@ disk_cluster: &disk_cluster
|
||||
- "STORAGE MODE ON_DISK_TRANSACTIONAL"
|
||||
validation_queries: []
|
||||
|
||||
args_query_limit: &args_query_limit
|
||||
- "--bolt-port"
|
||||
- *bolt_port
|
||||
- "--storage-gc-cycle-sec=180"
|
||||
- "--log-level=TRACE"
|
||||
|
||||
in_memory_query_limit_cluster: &in_memory_query_limit_cluster
|
||||
cluster:
|
||||
main:
|
||||
args: *args_query_limit
|
||||
log_file: "memory-e2e.log"
|
||||
setup_queries: []
|
||||
validation_queries: []
|
||||
|
||||
args_450_MiB_limit: &args_450_MiB_limit
|
||||
- "--bolt-port"
|
||||
- *bolt_port
|
||||
@ -95,6 +109,27 @@ workloads:
|
||||
proc: "tests/e2e/memory/procedures/"
|
||||
<<: *disk_cluster
|
||||
|
||||
- name: "Memory control query limit proc"
|
||||
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_proc"
|
||||
proc: "tests/e2e/memory/procedures/"
|
||||
args: ["--bolt-port", *bolt_port]
|
||||
<<: *in_memory_query_limit_cluster
|
||||
|
||||
- name: "Memory control query limit proc multi thread"
|
||||
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_proc_multi_thread"
|
||||
args: ["--bolt-port", *bolt_port, "--timeout", "180"]
|
||||
proc: "tests/e2e/memory/procedures/"
|
||||
<<: *in_memory_query_limit_cluster
|
||||
|
||||
- name: "Memory control query limit create"
|
||||
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_create"
|
||||
args: ["--bolt-port", *bolt_port]
|
||||
<<: *in_memory_query_limit_cluster
|
||||
|
||||
- name: "Memory control query limit create multi thread"
|
||||
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_create_multi_thread"
|
||||
args: ["--bolt-port", *bolt_port]
|
||||
<<: *in_memory_query_limit_cluster
|
||||
- name: "Memory control for detach delete"
|
||||
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_delete"
|
||||
args: ["--bolt-port", *bolt_port]
|
||||
|
@ -24,6 +24,7 @@ import time
|
||||
DEFAULT_DB = "memgraph"
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||
SIGNAL_SIGTERM = 15
|
||||
|
||||
QUERIES = [
|
||||
("MATCH (n) DELETE n", {}),
|
||||
@ -92,9 +93,13 @@ def execute_test(memgraph_binary, tester_binary):
|
||||
# Register cleanup function
|
||||
@atexit.register
|
||||
def cleanup():
|
||||
if memgraph.poll() is None:
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
def execute_queries(queries):
|
||||
for db, query, params in queries:
|
||||
@ -122,10 +127,12 @@ def execute_test(memgraph_binary, tester_binary):
|
||||
execute_queries(mt_queries3)
|
||||
print("\033[1;36m~~ Finished query execution on clean database ~~\033[0m\n")
|
||||
|
||||
# Shutdown the memgraph binary
|
||||
memgraph.terminate()
|
||||
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False
|
||||
time.sleep(1)
|
||||
|
||||
# Verify the written log
|
||||
print("\033[1;36m~~ Starting log verification ~~\033[0m")
|
||||
|
@ -21,6 +21,7 @@ import time
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||
SIGNAL_SIGTERM = 15
|
||||
|
||||
# When you create a new permission just add a testcase to this list (a tuple
|
||||
# of query, touple of required permissions) and the test will automatically
|
||||
@ -166,8 +167,12 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
|
||||
@atexit.register
|
||||
def cleanup():
|
||||
if memgraph.poll() is None:
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False
|
||||
time.sleep(1)
|
||||
|
||||
# Prepare the multi database environment
|
||||
execute_admin_queries(
|
||||
@ -327,8 +332,12 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
|
||||
print("\033[1;36m~~ Finished checking connections and database switching ~~\033[0m\n")
|
||||
|
||||
# Shutdown the memgraph binary
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -20,7 +20,6 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||
TESTS_DIR = os.path.join(SCRIPT_DIR, "tests")
|
||||
@ -31,6 +30,8 @@ WAL_FILE_NAME = "wal.bin"
|
||||
DUMP_SNAPSHOT_FILE_NAME = "expected_snapshot.cypher"
|
||||
DUMP_WAL_FILE_NAME = "expected_wal.cypher"
|
||||
|
||||
SIGNAL_SIGTERM = 15
|
||||
|
||||
|
||||
def wait_for_server(port, delay=0.1):
|
||||
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
|
||||
@ -40,7 +41,7 @@ def wait_for_server(port, delay=0.1):
|
||||
|
||||
|
||||
def sorted_content(file_path):
|
||||
with open(file_path, 'r') as fin:
|
||||
with open(file_path, "r") as fin:
|
||||
return sorted(list(map(lambda x: x.strip(), fin.readlines())))
|
||||
|
||||
|
||||
@ -52,32 +53,27 @@ def list_to_string(data):
|
||||
return ret
|
||||
|
||||
|
||||
def execute_test(
|
||||
memgraph_binary,
|
||||
dump_binary,
|
||||
test_directory,
|
||||
test_type,
|
||||
write_expected):
|
||||
assert test_type in ["SNAPSHOT", "WAL"], \
|
||||
"Test type should be either 'SNAPSHOT' or 'WAL'."
|
||||
print("\033[1;36m~~ Executing test {} ({}) ~~\033[0m"
|
||||
.format(os.path.relpath(test_directory, TESTS_DIR), test_type))
|
||||
def execute_test(memgraph_binary, dump_binary, test_directory, test_type, write_expected):
|
||||
assert test_type in ["SNAPSHOT", "WAL"], "Test type should be either 'SNAPSHOT' or 'WAL'."
|
||||
print("\033[1;36m~~ Executing test {} ({}) ~~\033[0m".format(os.path.relpath(test_directory, TESTS_DIR), test_type))
|
||||
|
||||
working_data_directory = tempfile.TemporaryDirectory()
|
||||
if test_type == "SNAPSHOT":
|
||||
snapshots_dir = os.path.join(working_data_directory.name, "snapshots")
|
||||
os.makedirs(snapshots_dir)
|
||||
shutil.copy(os.path.join(test_directory, SNAPSHOT_FILE_NAME),
|
||||
snapshots_dir)
|
||||
shutil.copy(os.path.join(test_directory, SNAPSHOT_FILE_NAME), snapshots_dir)
|
||||
else:
|
||||
wal_dir = os.path.join(working_data_directory.name, "wal")
|
||||
os.makedirs(wal_dir)
|
||||
shutil.copy(os.path.join(test_directory, WAL_FILE_NAME), wal_dir)
|
||||
|
||||
memgraph_args = [memgraph_binary,
|
||||
memgraph_args = [
|
||||
memgraph_binary,
|
||||
"--storage-recover-on-startup",
|
||||
"--storage-properties-on-edges",
|
||||
"--data-directory", working_data_directory.name]
|
||||
"--data-directory",
|
||||
working_data_directory.name,
|
||||
]
|
||||
|
||||
# Start the memgraph binary
|
||||
memgraph = subprocess.Popen(memgraph_args)
|
||||
@ -89,8 +85,12 @@ def execute_test(
|
||||
@atexit.register
|
||||
def cleanup():
|
||||
if memgraph.poll() is None:
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False
|
||||
time.sleep(1)
|
||||
|
||||
# Execute `database dump`
|
||||
dump_output_file = tempfile.NamedTemporaryFile()
|
||||
@ -98,28 +98,31 @@ def execute_test(
|
||||
subprocess.run(dump_args, stdout=dump_output_file, check=True)
|
||||
|
||||
# Shutdown the memgraph binary
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False
|
||||
time.sleep(1)
|
||||
|
||||
dump_file_name = DUMP_SNAPSHOT_FILE_NAME if test_type == "SNAPSHOT" else DUMP_WAL_FILE_NAME
|
||||
|
||||
if write_expected:
|
||||
with open(dump_output_file.name, 'r') as dump:
|
||||
with open(dump_output_file.name, "r") as dump:
|
||||
queries_got = dump.readlines()
|
||||
# Write dump files
|
||||
expected_dump_file = os.path.join(test_directory, dump_file_name)
|
||||
with open(expected_dump_file, 'w') as expected:
|
||||
with open(expected_dump_file, "w") as expected:
|
||||
expected.writelines(queries_got)
|
||||
else:
|
||||
# Compare dump files
|
||||
expected_dump_file = os.path.join(test_directory, dump_file_name)
|
||||
assert os.path.exists(expected_dump_file), \
|
||||
"Could not find expected dump path {}".format(expected_dump_file)
|
||||
assert os.path.exists(expected_dump_file), "Could not find expected dump path {}".format(expected_dump_file)
|
||||
queries_got = sorted_content(dump_output_file.name)
|
||||
queries_expected = sorted_content(expected_dump_file)
|
||||
assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" \
|
||||
"{}".format(list_to_string(queries_got),
|
||||
list_to_string(queries_expected))
|
||||
assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" "{}".format(
|
||||
list_to_string(queries_got), list_to_string(queries_expected)
|
||||
)
|
||||
|
||||
print("\033[1;32m~~ Test successful ~~\033[0m\n")
|
||||
|
||||
@ -141,15 +144,17 @@ def find_test_directories(directory):
|
||||
continue
|
||||
snapshot_file = os.path.join(test_dir_path, SNAPSHOT_FILE_NAME)
|
||||
wal_file = os.path.join(test_dir_path, WAL_FILE_NAME)
|
||||
dump_snapshot_file = os.path.join(
|
||||
test_dir_path, DUMP_SNAPSHOT_FILE_NAME)
|
||||
dump_snapshot_file = os.path.join(test_dir_path, DUMP_SNAPSHOT_FILE_NAME)
|
||||
dump_wal_file = os.path.join(test_dir_path, DUMP_WAL_FILE_NAME)
|
||||
if (os.path.isfile(snapshot_file) and os.path.isfile(dump_snapshot_file)
|
||||
and os.path.isfile(wal_file) and os.path.isfile(dump_wal_file)):
|
||||
if (
|
||||
os.path.isfile(snapshot_file)
|
||||
and os.path.isfile(dump_snapshot_file)
|
||||
and os.path.isfile(wal_file)
|
||||
and os.path.isfile(dump_wal_file)
|
||||
):
|
||||
test_dirs.append(test_dir_path)
|
||||
else:
|
||||
raise Exception("Missing data in test directory '{}'"
|
||||
.format(test_dir_path))
|
||||
raise Exception("Missing data in test directory '{}'".format(test_dir_path))
|
||||
return test_dirs
|
||||
|
||||
|
||||
@ -161,26 +166,15 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--memgraph", default=memgraph_binary)
|
||||
parser.add_argument("--dump", default=dump_binary)
|
||||
parser.add_argument(
|
||||
'--write-expected',
|
||||
action='store_true',
|
||||
help='Overwrite the expected cypher with results from current run')
|
||||
"--write-expected", action="store_true", help="Overwrite the expected cypher with results from current run"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
test_directories = find_test_directories(TESTS_DIR)
|
||||
assert len(test_directories) > 0, "No tests have been found!"
|
||||
|
||||
for test_directory in test_directories:
|
||||
execute_test(
|
||||
args.memgraph,
|
||||
args.dump,
|
||||
test_directory,
|
||||
"SNAPSHOT",
|
||||
args.write_expected)
|
||||
execute_test(
|
||||
args.memgraph,
|
||||
args.dump,
|
||||
test_directory,
|
||||
"WAL",
|
||||
args.write_expected)
|
||||
execute_test(args.memgraph, args.dump, test_directory, "SNAPSHOT", args.write_expected)
|
||||
execute_test(args.memgraph, args.dump, test_directory, "WAL", args.write_expected)
|
||||
|
||||
sys.exit(0)
|
||||
|
@ -22,6 +22,7 @@ from typing import List
|
||||
|
||||
SCRIPT_DIR = Path(__file__).absolute()
|
||||
PROJECT_DIR = SCRIPT_DIR.parents[3]
|
||||
SIGNAL_SIGTERM = 15
|
||||
|
||||
|
||||
def wait_for_server(port, delay=0.1):
|
||||
@ -68,8 +69,12 @@ def execute_with_user(queries):
|
||||
|
||||
def cleanup(memgraph):
|
||||
if memgraph.poll() is None:
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def execute_without_user(queries, should_fail=False, failure_message="", check_failure=True):
|
||||
|
@ -24,6 +24,7 @@ SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||
|
||||
UNAUTHORIZED_ERROR = r"^You are not authorized to execute this query.*?Please contact your database administrator\."
|
||||
SIGNAL_SIGTERM = 15
|
||||
|
||||
|
||||
def wait_for_server(port, delay=0.1):
|
||||
@ -80,8 +81,12 @@ def execute_test(memgraph_binary: str, tester_binary: str, filtering_binary: str
|
||||
@atexit.register
|
||||
def cleanup():
|
||||
if memgraph.poll() is None:
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False
|
||||
time.sleep(1)
|
||||
|
||||
# Prepare all users
|
||||
def setup_user():
|
||||
@ -130,8 +135,12 @@ def execute_test(memgraph_binary: str, tester_binary: str, filtering_binary: str
|
||||
print("\033[1;36m~~ Finished edge filtering test ~~\033[0m\n")
|
||||
|
||||
# Shutdown the memgraph binary
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -21,6 +21,7 @@ from typing import List
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||
SIGNAL_SIGTERM = 15
|
||||
|
||||
|
||||
def wait_for_server(port: int, delay: float = 0.1) -> float:
|
||||
@ -86,8 +87,12 @@ def execute_without_user(
|
||||
|
||||
def cleanup(memgraph: subprocess):
|
||||
if memgraph.poll() is None:
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def test_without_any_files(tester_binary: str, memgraph_args: List[str]):
|
||||
|
@ -23,6 +23,8 @@ import time
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||
|
||||
SIGNAL_SIGTERM = 15
|
||||
|
||||
CONFIG_TEMPLATE = """
|
||||
server:
|
||||
host: "127.0.0.1"
|
||||
@ -52,8 +54,7 @@ def wait_for_server(port, delay=0.1):
|
||||
time.sleep(delay)
|
||||
|
||||
|
||||
def execute_tester(binary, queries, username="", password="",
|
||||
auth_should_fail=False, query_should_fail=False):
|
||||
def execute_tester(binary, queries, username="", password="", auth_should_fail=False, query_should_fail=False):
|
||||
if password == "":
|
||||
password = username
|
||||
args = [binary, "--username", username, "--password", password]
|
||||
@ -76,18 +77,14 @@ class Memgraph:
|
||||
def start(self, **kwargs):
|
||||
self.stop()
|
||||
self._storage_directory = tempfile.TemporaryDirectory()
|
||||
self._auth_module = os.path.join(self._storage_directory.name,
|
||||
"ldap.py")
|
||||
self._auth_config = os.path.join(self._storage_directory.name,
|
||||
"ldap.yaml")
|
||||
script_file = os.path.join(PROJECT_DIR, "src", "auth",
|
||||
"reference_modules", "ldap.py")
|
||||
self._auth_module = os.path.join(self._storage_directory.name, "ldap.py")
|
||||
self._auth_config = os.path.join(self._storage_directory.name, "ldap.yaml")
|
||||
script_file = os.path.join(PROJECT_DIR, "src", "auth", "reference_modules", "ldap.py")
|
||||
virtualenv_bin = os.path.join(SCRIPT_DIR, "ve3", "bin", "python3")
|
||||
with open(script_file) as fin:
|
||||
data = fin.read()
|
||||
data = data.replace("/usr/bin/python3", virtualenv_bin)
|
||||
data = data.replace("/etc/memgraph/auth/ldap.yaml",
|
||||
self._auth_config)
|
||||
data = data.replace("/etc/memgraph/auth/ldap.yaml", self._auth_config)
|
||||
with open(self._auth_module, "w") as fout:
|
||||
fout.write(data)
|
||||
os.chmod(self._auth_module, stat.S_IRWXU | stat.S_IRWXG)
|
||||
@ -106,10 +103,13 @@ class Memgraph:
|
||||
}
|
||||
with open(self._auth_config, "w") as f:
|
||||
f.write(CONFIG_TEMPLATE.format(**config))
|
||||
args = [self._binary,
|
||||
"--data-directory", self._storage_directory.name,
|
||||
args = [
|
||||
self._binary,
|
||||
"--data-directory",
|
||||
self._storage_directory.name,
|
||||
"--auth-module-executable",
|
||||
kwargs.pop("module_executable", self._auth_module)]
|
||||
kwargs.pop("module_executable", self._auth_module),
|
||||
]
|
||||
for key, value in kwargs.items():
|
||||
ldap_key = "--auth-module-" + key.replace("_", "-")
|
||||
if isinstance(value, bool):
|
||||
@ -119,26 +119,27 @@ class Memgraph:
|
||||
args.append(value)
|
||||
self._process = subprocess.Popen(args)
|
||||
time.sleep(0.1)
|
||||
assert self._process.poll() is None, "Memgraph process died " \
|
||||
"prematurely!"
|
||||
assert self._process.poll() is None, "Memgraph process died " "prematurely!"
|
||||
wait_for_server(7687)
|
||||
|
||||
def stop(self, check=True):
|
||||
if self._process is None:
|
||||
return 0
|
||||
self._process.terminate()
|
||||
exitcode = self._process.wait()
|
||||
self._process = None
|
||||
pid = self._process.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
if check:
|
||||
assert exitcode == 0, "Memgraph process didn't exit cleanly!"
|
||||
return exitcode
|
||||
assert False
|
||||
return -1
|
||||
time.sleep(1)
|
||||
return 0
|
||||
|
||||
|
||||
def initialize_test(memgraph, tester_binary, **kwargs):
|
||||
memgraph.start(module_executable="")
|
||||
|
||||
execute_tester(tester_binary,
|
||||
["CREATE USER root", "GRANT ALL PRIVILEGES TO root"])
|
||||
execute_tester(tester_binary, ["CREATE USER root", "GRANT ALL PRIVILEGES TO root"])
|
||||
check_login = kwargs.pop("check_login", True)
|
||||
memgraph.restart(**kwargs)
|
||||
if check_login:
|
||||
@ -170,18 +171,15 @@ def test_role_mapping(memgraph, tester_binary):
|
||||
initialize_test(memgraph, tester_binary)
|
||||
|
||||
execute_tester(tester_binary, [], "alice")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
|
||||
query_should_fail=True)
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
|
||||
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||
|
||||
execute_tester(tester_binary, [], "bob")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob",
|
||||
query_should_fail=True)
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob", query_should_fail=True)
|
||||
|
||||
execute_tester(tester_binary, [], "carol")
|
||||
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol",
|
||||
query_should_fail=True)
|
||||
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", query_should_fail=True)
|
||||
execute_tester(tester_binary, ["GRANT CREATE TO admin"], "root")
|
||||
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol")
|
||||
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "dave")
|
||||
@ -192,15 +190,13 @@ def test_role_mapping(memgraph, tester_binary):
|
||||
def test_role_removal(memgraph, tester_binary):
|
||||
initialize_test(memgraph, tester_binary)
|
||||
execute_tester(tester_binary, [], "alice")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
|
||||
query_should_fail=True)
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
|
||||
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||
memgraph.restart(manage_roles=False)
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||
execute_tester(tester_binary, ["CLEAR ROLE FOR alice"], "root")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
|
||||
query_should_fail=True)
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
|
||||
memgraph.stop()
|
||||
|
||||
|
||||
@ -229,28 +225,22 @@ def test_user_is_role(memgraph, tester_binary):
|
||||
|
||||
def test_user_permissions_persistancy(memgraph, tester_binary):
|
||||
initialize_test(memgraph, tester_binary)
|
||||
execute_tester(tester_binary,
|
||||
["CREATE USER alice", "GRANT MATCH TO alice"], "root")
|
||||
execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||
memgraph.stop()
|
||||
|
||||
|
||||
def test_role_permissions_persistancy(memgraph, tester_binary):
|
||||
initialize_test(memgraph, tester_binary)
|
||||
execute_tester(tester_binary,
|
||||
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
|
||||
"root")
|
||||
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||
memgraph.stop()
|
||||
|
||||
|
||||
def test_only_authentication(memgraph, tester_binary):
|
||||
initialize_test(memgraph, tester_binary, manage_roles=False)
|
||||
execute_tester(tester_binary,
|
||||
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
|
||||
"root")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
|
||||
query_should_fail=True)
|
||||
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
|
||||
memgraph.stop()
|
||||
|
||||
|
||||
@ -267,22 +257,16 @@ def test_wrong_suffix(memgraph, tester_binary):
|
||||
|
||||
|
||||
def test_suffix_with_spaces(memgraph, tester_binary):
|
||||
initialize_test(memgraph, tester_binary,
|
||||
suffix=", ou= people, dc = memgraph, dc = com")
|
||||
execute_tester(tester_binary,
|
||||
["CREATE USER alice", "GRANT MATCH TO alice"], "root")
|
||||
initialize_test(memgraph, tester_binary, suffix=", ou= people, dc = memgraph, dc = com")
|
||||
execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||
memgraph.stop()
|
||||
|
||||
|
||||
def test_role_mapping_wrong_root_dn(memgraph, tester_binary):
|
||||
initialize_test(memgraph, tester_binary,
|
||||
root_dn="ou=invalid,dc=memgraph,dc=com")
|
||||
execute_tester(tester_binary,
|
||||
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
|
||||
"root")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
|
||||
query_should_fail=True)
|
||||
initialize_test(memgraph, tester_binary, root_dn="ou=invalid,dc=memgraph,dc=com")
|
||||
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
|
||||
memgraph.restart()
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||
memgraph.stop()
|
||||
@ -290,11 +274,8 @@ def test_role_mapping_wrong_root_dn(memgraph, tester_binary):
|
||||
|
||||
def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary):
|
||||
initialize_test(memgraph, tester_binary, root_objectclass="person")
|
||||
execute_tester(tester_binary,
|
||||
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
|
||||
"root")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
|
||||
query_should_fail=True)
|
||||
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
|
||||
memgraph.restart()
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||
memgraph.stop()
|
||||
@ -302,11 +283,8 @@ def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary):
|
||||
|
||||
def test_role_mapping_wrong_user_attribute(memgraph, tester_binary):
|
||||
initialize_test(memgraph, tester_binary, user_attribute="cn")
|
||||
execute_tester(tester_binary,
|
||||
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
|
||||
"root")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
|
||||
query_should_fail=True)
|
||||
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
|
||||
memgraph.restart()
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||
memgraph.stop()
|
||||
@ -314,8 +292,7 @@ def test_role_mapping_wrong_user_attribute(memgraph, tester_binary):
|
||||
|
||||
def test_wrong_password(memgraph, tester_binary):
|
||||
initialize_test(memgraph, tester_binary)
|
||||
execute_tester(tester_binary, [], "root", password="sudo",
|
||||
auth_should_fail=True)
|
||||
execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True)
|
||||
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
|
||||
memgraph.stop()
|
||||
|
||||
@ -326,12 +303,10 @@ def test_password_persistancy(memgraph, tester_binary):
|
||||
execute_tester(tester_binary, ["SHOW USERS"], "root", password="sudo")
|
||||
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
|
||||
memgraph.restart()
|
||||
execute_tester(tester_binary, [], "root", password="sudo",
|
||||
auth_should_fail=True)
|
||||
execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True)
|
||||
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
|
||||
memgraph.restart(module_executable="")
|
||||
execute_tester(tester_binary, [], "root", password="sudo",
|
||||
auth_should_fail=True)
|
||||
execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True)
|
||||
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
|
||||
memgraph.stop()
|
||||
|
||||
@ -339,33 +314,25 @@ def test_password_persistancy(memgraph, tester_binary):
|
||||
def test_user_multiple_roles(memgraph, tester_binary):
|
||||
initialize_test(memgraph, tester_binary, check_login=False)
|
||||
memgraph.restart()
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve",
|
||||
query_should_fail=True)
|
||||
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root",
|
||||
query_should_fail=True)
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
|
||||
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
|
||||
memgraph.restart(manage_roles=False)
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve",
|
||||
query_should_fail=True)
|
||||
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root",
|
||||
query_should_fail=True)
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
|
||||
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
|
||||
memgraph.restart(manage_roles=False, root_dn="")
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve",
|
||||
query_should_fail=True)
|
||||
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root",
|
||||
query_should_fail=True)
|
||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
|
||||
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
|
||||
memgraph.stop()
|
||||
|
||||
|
||||
def test_starttls_failure(memgraph, tester_binary):
|
||||
initialize_test(memgraph, tester_binary, encryption="starttls",
|
||||
check_login=False)
|
||||
initialize_test(memgraph, tester_binary, encryption="starttls", check_login=False)
|
||||
execute_tester(tester_binary, [], "root", auth_should_fail=True)
|
||||
memgraph.stop()
|
||||
|
||||
|
||||
def test_ssl_failure(memgraph, tester_binary):
|
||||
initialize_test(memgraph, tester_binary, encryption="ssl",
|
||||
check_login=False)
|
||||
initialize_test(memgraph, tester_binary, encryption="ssl", check_login=False)
|
||||
execute_tester(tester_binary, [], "root", auth_should_fail=True)
|
||||
memgraph.stop()
|
||||
|
||||
@ -375,22 +342,19 @@ def test_ssl_failure(memgraph, tester_binary):
|
||||
|
||||
if __name__ == "__main__":
|
||||
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
|
||||
tester_binary = os.path.join(PROJECT_DIR, "build", "tests",
|
||||
"integration", "ldap", "tester")
|
||||
tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "ldap", "tester")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--memgraph", default=memgraph_binary)
|
||||
parser.add_argument("--tester", default=tester_binary)
|
||||
parser.add_argument("--openldap-dir",
|
||||
default=os.path.join(SCRIPT_DIR, "openldap-2.4.47"))
|
||||
parser.add_argument("--openldap-dir", default=os.path.join(SCRIPT_DIR, "openldap-2.4.47"))
|
||||
args = parser.parse_args()
|
||||
|
||||
# Setup Memgraph handler
|
||||
memgraph = Memgraph(args.memgraph)
|
||||
|
||||
# Start the slapd binary
|
||||
slapd_args = [os.path.join(args.openldap_dir, "exe", "libexec", "slapd"),
|
||||
"-h", "ldap://127.0.0.1:1389/", "-d", "0"]
|
||||
slapd_args = [os.path.join(args.openldap_dir, "exe", "libexec", "slapd"), "-h", "ldap://127.0.0.1:1389/", "-d", "0"]
|
||||
slapd = subprocess.Popen(slapd_args)
|
||||
time.sleep(0.1)
|
||||
assert slapd.poll() is None, "slapd process died prematurely!"
|
||||
@ -409,8 +373,7 @@ if __name__ == "__main__":
|
||||
if slapd_stat != 0:
|
||||
print("slapd process didn't exit cleanly!")
|
||||
|
||||
assert mg_stat == 0 and slapd_stat == 0, "Some of the processes " \
|
||||
"(memgraph, slapd) crashed!"
|
||||
assert mg_stat == 0 and slapd_stat == 0, "Some of the processes " "(memgraph, slapd) crashed!"
|
||||
|
||||
# Execute tests
|
||||
names = sorted(globals().keys())
|
||||
|
@ -18,12 +18,13 @@ import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import yaml
|
||||
|
||||
import yaml
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||
BUILD_DIR = os.path.join(BASE_DIR, "build")
|
||||
SIGNAL_SIGTERM = 15
|
||||
|
||||
|
||||
def wait_for_server(port, delay=0.1):
|
||||
@ -46,17 +47,14 @@ def list_to_string(data):
|
||||
|
||||
|
||||
def verify_lifetime(memgraph_binary, mg_import_csv_binary):
|
||||
print("\033[1;36m~~ Verifying that mg_import_csv can't be started while "
|
||||
"memgraph is running ~~\033[0m")
|
||||
print("\033[1;36m~~ Verifying that mg_import_csv can't be started while " "memgraph is running ~~\033[0m")
|
||||
storage_directory = tempfile.TemporaryDirectory()
|
||||
|
||||
# Generate common args
|
||||
common_args = ["--data-directory", storage_directory.name,
|
||||
"--storage-properties-on-edges=false"]
|
||||
common_args = ["--data-directory", storage_directory.name, "--storage-properties-on-edges=false"]
|
||||
|
||||
# Start the memgraph binary
|
||||
memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + \
|
||||
common_args
|
||||
memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + common_args
|
||||
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
|
||||
time.sleep(0.1)
|
||||
assert memgraph.poll() is None, "Memgraph process died prematurely!"
|
||||
@ -66,47 +64,52 @@ def verify_lifetime(memgraph_binary, mg_import_csv_binary):
|
||||
@atexit.register
|
||||
def cleanup():
|
||||
if memgraph.poll() is None:
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False, "Memgraph process didn't exit cleanly!"
|
||||
time.sleep(1)
|
||||
|
||||
# Execute mg_import_csv.
|
||||
mg_import_csv_args = [mg_import_csv_binary, "--nodes", "/dev/null"] + \
|
||||
common_args
|
||||
mg_import_csv_args = [mg_import_csv_binary, "--nodes", "/dev/null"] + common_args
|
||||
ret = subprocess.run(mg_import_csv_args)
|
||||
|
||||
# Check the return code
|
||||
if ret.returncode == 0:
|
||||
raise Exception(
|
||||
"The importer was able to run while memgraph was running!")
|
||||
raise Exception("The importer was able to run while memgraph was running!")
|
||||
|
||||
# Shutdown the memgraph binary
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False, "Memgraph process didn't exit cleanly!"
|
||||
time.sleep(1)
|
||||
|
||||
print("\033[1;32m~~ Test successful ~~\033[0m\n")
|
||||
|
||||
|
||||
def execute_test(name, test_path, test_config, memgraph_binary,
|
||||
mg_import_csv_binary, tester_binary, write_expected):
|
||||
def execute_test(name, test_path, test_config, memgraph_binary, mg_import_csv_binary, tester_binary, write_expected):
|
||||
print("\033[1;36m~~ Executing test", name, "~~\033[0m")
|
||||
storage_directory = tempfile.TemporaryDirectory()
|
||||
|
||||
# Verify test configuration
|
||||
if ("import_should_fail" not in test_config and
|
||||
"expected" not in test_config) or \
|
||||
("import_should_fail" in test_config and
|
||||
"expected" in test_config):
|
||||
raise Exception("The test should specify either 'import_should_fail' "
|
||||
"or 'expected'!")
|
||||
if ("import_should_fail" not in test_config and "expected" not in test_config) or (
|
||||
"import_should_fail" in test_config and "expected" in test_config
|
||||
):
|
||||
raise Exception("The test should specify either 'import_should_fail' " "or 'expected'!")
|
||||
|
||||
expected_path = test_config.pop("expected", "")
|
||||
import_should_fail = test_config.pop("import_should_fail", False)
|
||||
|
||||
# Generate common args
|
||||
properties_on_edges = bool(test_config.pop("properties_on_edges", False))
|
||||
common_args = ["--data-directory", storage_directory.name,
|
||||
"--storage-properties-on-edges=" +
|
||||
str(properties_on_edges).lower()]
|
||||
common_args = [
|
||||
"--data-directory",
|
||||
storage_directory.name,
|
||||
"--storage-properties-on-edges=" + str(properties_on_edges).lower(),
|
||||
]
|
||||
|
||||
# Generate mg_import_csv args using flags specified in the test
|
||||
mg_import_csv_args = [mg_import_csv_binary] + common_args
|
||||
@ -125,19 +128,16 @@ def execute_test(name, test_path, test_config, memgraph_binary,
|
||||
|
||||
if import_should_fail:
|
||||
if ret.returncode == 0:
|
||||
raise Exception("The import should have failed, but it "
|
||||
"succeeded instead!")
|
||||
raise Exception("The import should have failed, but it " "succeeded instead!")
|
||||
else:
|
||||
print("\033[1;32m~~ Test successful ~~\033[0m\n")
|
||||
return
|
||||
else:
|
||||
if ret.returncode != 0:
|
||||
raise Exception("The import should have succeeded, but it "
|
||||
"failed instead!")
|
||||
raise Exception("The import should have succeeded, but it " "failed instead!")
|
||||
|
||||
# Start the memgraph binary
|
||||
memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + \
|
||||
common_args
|
||||
memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + common_args
|
||||
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
|
||||
time.sleep(0.1)
|
||||
assert memgraph.poll() is None, "Memgraph process died prematurely!"
|
||||
@ -147,21 +147,29 @@ def execute_test(name, test_path, test_config, memgraph_binary,
|
||||
@atexit.register
|
||||
def cleanup():
|
||||
if memgraph.poll() is None:
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False, "Memgraph process didn't exit cleanly!"
|
||||
time.sleep(1)
|
||||
|
||||
# Get the contents of the database
|
||||
queries_got = extract_rows(subprocess.run(
|
||||
[tester_binary], stdout=subprocess.PIPE,
|
||||
check=True).stdout.decode("utf-8"))
|
||||
queries_got = extract_rows(
|
||||
subprocess.run([tester_binary], stdout=subprocess.PIPE, check=True).stdout.decode("utf-8")
|
||||
)
|
||||
|
||||
# Shutdown the memgraph binary
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False, "Memgraph process didn't exit cleanly!"
|
||||
time.sleep(1)
|
||||
|
||||
if write_expected:
|
||||
with open(os.path.join(test_path, expected_path), 'w') as expected:
|
||||
expected.write('\n'.join(queries_got))
|
||||
with open(os.path.join(test_path, expected_path), "w") as expected:
|
||||
expected.write("\n".join(queries_got))
|
||||
|
||||
else:
|
||||
if expected_path:
|
||||
@ -173,18 +181,16 @@ def execute_test(name, test_path, test_config, memgraph_binary,
|
||||
# Verify the queries
|
||||
queries_expected.sort()
|
||||
queries_got.sort()
|
||||
assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" \
|
||||
"{}".format(list_to_string(queries_got),
|
||||
list_to_string(queries_expected))
|
||||
assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" "{}".format(
|
||||
list_to_string(queries_got), list_to_string(queries_expected)
|
||||
)
|
||||
print("\033[1;32m~~ Test successful ~~\033[0m\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
memgraph_binary = os.path.join(BUILD_DIR, "memgraph")
|
||||
mg_import_csv_binary = os.path.join(
|
||||
BUILD_DIR, "src", "mg_import_csv")
|
||||
tester_binary = os.path.join(
|
||||
BUILD_DIR, "tests", "integration", "mg_import_csv", "tester")
|
||||
mg_import_csv_binary = os.path.join(BUILD_DIR, "src", "mg_import_csv")
|
||||
tester_binary = os.path.join(BUILD_DIR, "tests", "integration", "mg_import_csv", "tester")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--memgraph", default=memgraph_binary)
|
||||
@ -193,7 +199,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--write-expected",
|
||||
action="store_true",
|
||||
help="Overwrite the expected values with the results of the current run")
|
||||
help="Overwrite the expected values with the results of the current run",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# First test whether the CSV importer can be started while the main
|
||||
@ -211,7 +218,8 @@ if __name__ == "__main__":
|
||||
testcases = yaml.safe_load(f)
|
||||
for test_config in testcases:
|
||||
test_name = name + "/" + test_config.pop("name")
|
||||
execute_test(test_name, test_path, test_config, args.memgraph,
|
||||
args.mg_import_csv, args.tester, args.write_expected)
|
||||
execute_test(
|
||||
test_name, test_path, test_config, args.memgraph, args.mg_import_csv, args.tester, args.write_expected
|
||||
)
|
||||
|
||||
sys.exit(0)
|
||||
|
@ -23,6 +23,7 @@ from typing import List
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||
SIGNAL_SIGTERM = 15
|
||||
|
||||
|
||||
def wait_for_server(port: int, delay: float = 0.1) -> float:
|
||||
@ -91,8 +92,12 @@ def check_config(tester_binary: str, flag: str, value: str) -> None:
|
||||
|
||||
def cleanup(memgraph: subprocess):
|
||||
if memgraph.poll() is None:
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False, "Memgraph process didn't exit cleanly!"
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def run_test(tester_binary: str, memgraph_args: List[str], server_name: str, query_tx: str):
|
||||
|
@ -22,6 +22,8 @@ assertion_queries = [
|
||||
f"MATCH (n)-[e]->(m) WITH count(e) as cnt RETURN assert(cnt={len(edge_queries)});",
|
||||
]
|
||||
|
||||
SIGNAL_SIGTERM = 15
|
||||
|
||||
|
||||
def wait_for_server(port, delay=0.1):
|
||||
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
|
||||
@ -40,9 +42,12 @@ def prepare_memgraph(memgraph_args):
|
||||
|
||||
|
||||
def terminate_memgraph(memgraph):
|
||||
memgraph.terminate()
|
||||
time.sleep(0.1)
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False, "Memgraph process didn't exit cleanly!"
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def execute_tester(
|
||||
@ -90,8 +95,12 @@ def execute_test_analytical_mode(memgraph_binary: str, tester_binary: str) -> No
|
||||
|
||||
execute_queries(assertion_queries)
|
||||
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False, "Memgraph process didn't exit cleanly!"
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def execute_test_switch_analytical_transactional(memgraph_binary: str, tester_binary: str) -> None:
|
||||
@ -135,8 +144,12 @@ def execute_test_switch_analytical_transactional(memgraph_binary: str, tester_bi
|
||||
execute_queries(assertion_queries)
|
||||
|
||||
print("\033[1;36m~~ Terminating memgraph ~~\033[0m\n")
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False, "Memgraph process didn't exit cleanly!"
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def execute_test_switch_transactional_analytical(memgraph_binary: str, tester_binary: str) -> None:
|
||||
@ -177,8 +190,12 @@ def execute_test_switch_transactional_analytical(memgraph_binary: str, tester_bi
|
||||
execute_queries(assertion_queries)
|
||||
|
||||
print("\033[1;36m~~ Terminating memgraph ~~\033[0m\n")
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
pid = memgraph.pid
|
||||
try:
|
||||
os.kill(pid, SIGNAL_SIGTERM)
|
||||
except os.OSError:
|
||||
assert False, "Memgraph process didn't exit cleanly!"
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -11,12 +11,13 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from argparse import ArgumentParser
|
||||
from collections import defaultdict
|
||||
import tempfile
|
||||
import shutil
|
||||
import time
|
||||
|
||||
from common import get_absolute_path, set_cpus
|
||||
|
||||
try:
|
||||
@ -36,13 +37,12 @@ class Memgraph:
|
||||
"""
|
||||
Knows how to start and stop memgraph.
|
||||
"""
|
||||
|
||||
def __init__(self, args, num_workers):
|
||||
self.log = logging.getLogger("MemgraphRunner")
|
||||
argp = ArgumentParser("MemgraphArgumentParser")
|
||||
argp.add_argument("--runner-bin",
|
||||
default=get_absolute_path("memgraph", "build"))
|
||||
argp.add_argument("--port", default="7687",
|
||||
help="Database and client port")
|
||||
argp.add_argument("--runner-bin", default=get_absolute_path("memgraph", "build"))
|
||||
argp.add_argument("--port", default="7687", help="Database and client port")
|
||||
argp.add_argument("--data-directory", default=None)
|
||||
argp.add_argument("--storage-snapshot-on-exit", action="store_true")
|
||||
argp.add_argument("--storage-recover-on-startup", action="store_true")
|
||||
@ -55,8 +55,7 @@ class Memgraph:
|
||||
|
||||
def start(self):
|
||||
self.log.info("start")
|
||||
database_args = ["--bolt-port", self.args.port,
|
||||
"--query-execution-timeout-sec", "0"]
|
||||
database_args = ["--bolt-port", self.args.port, "--query-execution-timeout-sec", "0"]
|
||||
if self.num_workers:
|
||||
database_args += ["--bolt-num-workers", str(self.num_workers)]
|
||||
if self.args.data_directory:
|
||||
@ -82,15 +81,13 @@ class Neo:
|
||||
"""
|
||||
Knows how to start and stop neo4j.
|
||||
"""
|
||||
|
||||
def __init__(self, args, config):
|
||||
self.log = logging.getLogger("NeoRunner")
|
||||
argp = ArgumentParser("NeoArgumentParser")
|
||||
argp.add_argument("--runner-bin", default=get_absolute_path(
|
||||
"neo4j/bin/neo4j", "libs"))
|
||||
argp.add_argument("--port", default="7687",
|
||||
help="Database and client port")
|
||||
argp.add_argument("--http-port", default="7474",
|
||||
help="Database and client port")
|
||||
argp.add_argument("--runner-bin", default=get_absolute_path("neo4j/bin/neo4j", "libs"))
|
||||
argp.add_argument("--port", default="7687", help="Database and client port")
|
||||
argp.add_argument("--http-port", default="7474", help="Database and client port")
|
||||
self.log.info("Initializing Runner with arguments %r", args)
|
||||
self.args, _ = argp.parse_known_args(args)
|
||||
self.config = config
|
||||
@ -105,24 +102,22 @@ class Neo:
|
||||
self.neo4j_home_path = tempfile.mkdtemp(dir="/dev/shm")
|
||||
|
||||
try:
|
||||
os.symlink(os.path.join(get_absolute_path("neo4j", "libs"), "lib"),
|
||||
os.path.join(self.neo4j_home_path, "lib"))
|
||||
os.symlink(
|
||||
os.path.join(get_absolute_path("neo4j", "libs"), "lib"), os.path.join(self.neo4j_home_path, "lib")
|
||||
)
|
||||
neo4j_conf_dir = os.path.join(self.neo4j_home_path, "conf")
|
||||
neo4j_conf_file = os.path.join(neo4j_conf_dir, "neo4j.conf")
|
||||
os.mkdir(neo4j_conf_dir)
|
||||
shutil.copyfile(self.config, neo4j_conf_file)
|
||||
with open(neo4j_conf_file, "a") as f:
|
||||
f.write("\ndbms.connector.bolt.listen_address=:" +
|
||||
self.args.port + "\n")
|
||||
f.write("\ndbms.connector.http.listen_address=:" +
|
||||
self.args.http_port + "\n")
|
||||
f.write("\ndbms.connector.bolt.listen_address=:" + self.args.port + "\n")
|
||||
f.write("\ndbms.connector.http.listen_address=:" + self.args.http_port + "\n")
|
||||
|
||||
# environment
|
||||
cwd = os.path.dirname(self.args.runner_bin)
|
||||
env = {"NEO4J_HOME": self.neo4j_home_path}
|
||||
|
||||
self.database_bin.run(self.args.runner_bin, args=["console"],
|
||||
env=env, timeout=600, cwd=cwd)
|
||||
self.database_bin.run(self.args.runner_bin, args=["console"], env=env, timeout=600, cwd=cwd)
|
||||
except:
|
||||
shutil.rmtree(self.neo4j_home_path)
|
||||
raise Exception("Couldn't run Neo4j!")
|
||||
|
Loading…
Reference in New Issue
Block a user