Use extent hooks for per query memory limit (#1340)
This commit is contained in:
parent
3d4d841753
commit
a84f570c6d
@ -111,6 +111,22 @@ enum mgp_error mgp_global_aligned_alloc(size_t size_in_bytes, size_t alignment,
|
|||||||
/// The behavior is undefined if `ptr` is not a value returned from a prior
|
/// The behavior is undefined if `ptr` is not a value returned from a prior
|
||||||
/// mgp_global_alloc() or mgp_global_aligned_alloc().
|
/// mgp_global_alloc() or mgp_global_aligned_alloc().
|
||||||
void mgp_global_free(void *p);
|
void mgp_global_free(void *p);
|
||||||
|
|
||||||
|
/// State of the graph database.
|
||||||
|
struct mgp_graph;
|
||||||
|
|
||||||
|
/// Allocations are tracked only for master thread. If new threads are spawned
|
||||||
|
/// inside procedure, by calling following function
|
||||||
|
/// you can start tracking allocations for current thread too. This
|
||||||
|
/// is important if you need query memory limit to work
|
||||||
|
/// for given procedure or per procedure memory limit.
|
||||||
|
enum mgp_error mgp_track_current_thread_allocations(struct mgp_graph *graph);
|
||||||
|
|
||||||
|
/// Once allocations are tracked for current thread, you need to stop tracking allocations
|
||||||
|
/// for given thread, before thread finishes with execution, or is detached.
|
||||||
|
/// Otherwise it might result in slowdown of system due to unnecessary tracking of
|
||||||
|
/// allocations.
|
||||||
|
enum mgp_error mgp_untrack_current_thread_allocations(struct mgp_graph *graph);
|
||||||
///@}
|
///@}
|
||||||
|
|
||||||
/// @name Operations on mgp_value
|
/// @name Operations on mgp_value
|
||||||
@ -854,9 +870,6 @@ enum mgp_error mgp_edge_set_properties(struct mgp_edge *e, struct mgp_map *prope
|
|||||||
enum mgp_error mgp_edge_iter_properties(struct mgp_edge *e, struct mgp_memory *memory,
|
enum mgp_error mgp_edge_iter_properties(struct mgp_edge *e, struct mgp_memory *memory,
|
||||||
struct mgp_properties_iterator **result);
|
struct mgp_properties_iterator **result);
|
||||||
|
|
||||||
/// State of the graph database.
|
|
||||||
struct mgp_graph;
|
|
||||||
|
|
||||||
/// Get the vertex corresponding to given ID, or NULL if no such vertex exists.
|
/// Get the vertex corresponding to given ID, or NULL if no such vertex exists.
|
||||||
/// Resulting vertex must be freed using mgp_vertex_destroy.
|
/// Resulting vertex must be freed using mgp_vertex_destroy.
|
||||||
/// Return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate the vertex.
|
/// Return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate the vertex.
|
||||||
|
@ -23,7 +23,7 @@
|
|||||||
#include "glue/run_id.hpp"
|
#include "glue/run_id.hpp"
|
||||||
#include "helpers.hpp"
|
#include "helpers.hpp"
|
||||||
#include "license/license_sender.hpp"
|
#include "license/license_sender.hpp"
|
||||||
#include "memory/memory_control.hpp"
|
#include "memory/global_memory_control.hpp"
|
||||||
#include "query/config.hpp"
|
#include "query/config.hpp"
|
||||||
#include "query/discard_value_stream.hpp"
|
#include "query/discard_value_stream.hpp"
|
||||||
#include "query/interpreter.hpp"
|
#include "query/interpreter.hpp"
|
||||||
@ -512,6 +512,7 @@ int main(int argc, char **argv) {
|
|||||||
|
|
||||||
server.AwaitShutdown();
|
server.AwaitShutdown();
|
||||||
websocket_server.AwaitShutdown();
|
websocket_server.AwaitShutdown();
|
||||||
|
memgraph::memory::UnsetHooks();
|
||||||
#ifdef MG_ENTERPRISE
|
#ifdef MG_ENTERPRISE
|
||||||
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
|
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
|
||||||
metrics_server.AwaitShutdown();
|
metrics_server.AwaitShutdown();
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
set(memory_src_files
|
set(memory_src_files
|
||||||
new_delete.cpp
|
new_delete.cpp
|
||||||
memory_control.cpp)
|
global_memory_control.cpp
|
||||||
|
query_memory_control.cpp)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,12 +9,16 @@
|
|||||||
// by the Apache License, Version 2.0, included in the file
|
// by the Apache License, Version 2.0, included in the file
|
||||||
// licenses/APL.txt.
|
// licenses/APL.txt.
|
||||||
|
|
||||||
#include "memory_control.hpp"
|
#include <atomic>
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
#include "global_memory_control.hpp"
|
||||||
|
#include "query_memory_control.hpp"
|
||||||
#include "utils/logging.hpp"
|
#include "utils/logging.hpp"
|
||||||
#include "utils/memory_tracker.hpp"
|
#include "utils/memory_tracker.hpp"
|
||||||
|
|
||||||
#if USE_JEMALLOC
|
#if USE_JEMALLOC
|
||||||
#include <jemalloc/jemalloc.h>
|
#include "jemalloc/jemalloc.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace memgraph::memory {
|
namespace memgraph::memory {
|
||||||
@ -57,12 +61,24 @@ void *my_alloc(extent_hooks_t *extent_hooks, void *new_addr, size_t size, size_t
|
|||||||
// This needs to be before, to throw exception in case of too big alloc
|
// This needs to be before, to throw exception in case of too big alloc
|
||||||
if (*commit) [[likely]] {
|
if (*commit) [[likely]] {
|
||||||
memgraph::utils::total_memory_tracker.Alloc(static_cast<int64_t>(size));
|
memgraph::utils::total_memory_tracker.Alloc(static_cast<int64_t>(size));
|
||||||
|
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
|
||||||
|
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
|
||||||
|
if (memory_tracker != nullptr) [[likely]] {
|
||||||
|
memory_tracker->Alloc(static_cast<int64_t>(size));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto *ptr = old_hooks->alloc(extent_hooks, new_addr, size, alignment, zero, commit, arena_ind);
|
auto *ptr = old_hooks->alloc(extent_hooks, new_addr, size, alignment, zero, commit, arena_ind);
|
||||||
if (ptr == nullptr) [[unlikely]] {
|
if (ptr == nullptr) [[unlikely]] {
|
||||||
if (*commit) {
|
if (*commit) {
|
||||||
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
|
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
|
||||||
|
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
|
||||||
|
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
|
||||||
|
if (memory_tracker != nullptr) [[likely]] {
|
||||||
|
memory_tracker->Free(static_cast<int64_t>(size));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return ptr;
|
return ptr;
|
||||||
}
|
}
|
||||||
@ -79,6 +95,13 @@ static bool my_dalloc(extent_hooks_t *extent_hooks, void *addr, size_t size, boo
|
|||||||
|
|
||||||
if (committed) [[likely]] {
|
if (committed) [[likely]] {
|
||||||
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
|
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
|
||||||
|
|
||||||
|
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
|
||||||
|
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
|
||||||
|
if (memory_tracker != nullptr) [[likely]] {
|
||||||
|
memory_tracker->Free(static_cast<int64_t>(size));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
@ -87,6 +110,12 @@ static bool my_dalloc(extent_hooks_t *extent_hooks, void *addr, size_t size, boo
|
|||||||
static void my_destroy(extent_hooks_t *extent_hooks, void *addr, size_t size, bool committed, unsigned arena_ind) {
|
static void my_destroy(extent_hooks_t *extent_hooks, void *addr, size_t size, bool committed, unsigned arena_ind) {
|
||||||
if (committed) [[likely]] {
|
if (committed) [[likely]] {
|
||||||
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
|
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
|
||||||
|
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
|
||||||
|
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
|
||||||
|
if (memory_tracker != nullptr) [[likely]] {
|
||||||
|
memory_tracker->Free(static_cast<int64_t>(size));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
old_hooks->destroy(extent_hooks, addr, size, committed, arena_ind);
|
old_hooks->destroy(extent_hooks, addr, size, committed, arena_ind);
|
||||||
@ -101,6 +130,12 @@ static bool my_commit(extent_hooks_t *extent_hooks, void *addr, size_t size, siz
|
|||||||
}
|
}
|
||||||
|
|
||||||
memgraph::utils::total_memory_tracker.Alloc(static_cast<int64_t>(length));
|
memgraph::utils::total_memory_tracker.Alloc(static_cast<int64_t>(length));
|
||||||
|
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
|
||||||
|
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
|
||||||
|
if (memory_tracker != nullptr) [[likely]] {
|
||||||
|
memory_tracker->Alloc(static_cast<int64_t>(size));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -115,6 +150,12 @@ static bool my_decommit(extent_hooks_t *extent_hooks, void *addr, size_t size, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(length));
|
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(length));
|
||||||
|
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
|
||||||
|
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
|
||||||
|
if (memory_tracker != nullptr) [[likely]] {
|
||||||
|
memory_tracker->Free(static_cast<int64_t>(size));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -129,6 +170,13 @@ static bool my_purge_forced(extent_hooks_t *extent_hooks, void *addr, size_t siz
|
|||||||
}
|
}
|
||||||
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(length));
|
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(length));
|
||||||
|
|
||||||
|
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
|
||||||
|
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
|
||||||
|
if (memory_tracker != nullptr) [[likely]] {
|
||||||
|
memory_tracker->Alloc(static_cast<int64_t>(size));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -153,6 +201,7 @@ void SetHooks() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < n_arenas; i++) {
|
for (int i = 0; i < n_arenas; i++) {
|
||||||
|
GetQueriesMemoryControl().InitializeArenaCounter(i);
|
||||||
std::string func_name = "arena." + std::to_string(i) + ".extent_hooks";
|
std::string func_name = "arena." + std::to_string(i) + ".extent_hooks";
|
||||||
|
|
||||||
size_t hooks_len = sizeof(old_hooks);
|
size_t hooks_len = sizeof(old_hooks);
|
||||||
@ -197,6 +246,45 @@ void SetHooks() {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void UnsetHooks() {
|
||||||
|
#if USE_JEMALLOC
|
||||||
|
|
||||||
|
uint64_t allocated{0};
|
||||||
|
uint64_t sz{sizeof(allocated)};
|
||||||
|
|
||||||
|
sz = sizeof(unsigned);
|
||||||
|
unsigned n_arenas{0};
|
||||||
|
int err = mallctl("opt.narenas", (void *)&n_arenas, &sz, nullptr, 0);
|
||||||
|
|
||||||
|
if (err) {
|
||||||
|
LOG_FATAL("Error setting default hooks for jemalloc arenas");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n_arenas; i++) {
|
||||||
|
GetQueriesMemoryControl().InitializeArenaCounter(i);
|
||||||
|
std::string func_name = "arena." + std::to_string(i) + ".extent_hooks";
|
||||||
|
|
||||||
|
MG_ASSERT(old_hooks);
|
||||||
|
MG_ASSERT(old_hooks->alloc);
|
||||||
|
MG_ASSERT(old_hooks->dalloc);
|
||||||
|
MG_ASSERT(old_hooks->destroy);
|
||||||
|
MG_ASSERT(old_hooks->commit);
|
||||||
|
MG_ASSERT(old_hooks->decommit);
|
||||||
|
MG_ASSERT(old_hooks->purge_forced);
|
||||||
|
MG_ASSERT(old_hooks->purge_lazy);
|
||||||
|
MG_ASSERT(old_hooks->split);
|
||||||
|
MG_ASSERT(old_hooks->merge);
|
||||||
|
|
||||||
|
err = mallctl(func_name.c_str(), nullptr, nullptr, &old_hooks, sizeof(old_hooks));
|
||||||
|
|
||||||
|
if (err) {
|
||||||
|
LOG_FATAL("Error setting default hooks for jemalloc arena {}", i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
void PurgeUnusedMemory() {
|
void PurgeUnusedMemory() {
|
||||||
#if USE_JEMALLOC
|
#if USE_JEMALLOC
|
||||||
mallctl("arena." STRINGIFY(MALLCTL_ARENAS_ALL) ".purge", nullptr, nullptr, nullptr, 0);
|
mallctl("arena." STRINGIFY(MALLCTL_ARENAS_ALL) ".purge", nullptr, nullptr, nullptr, 0);
|
@ -17,5 +17,6 @@ namespace memgraph::memory {
|
|||||||
|
|
||||||
void PurgeUnusedMemory();
|
void PurgeUnusedMemory();
|
||||||
void SetHooks();
|
void SetHooks();
|
||||||
|
void UnsetHooks();
|
||||||
|
|
||||||
} // namespace memgraph::memory
|
} // namespace memgraph::memory
|
140
src/memory/query_memory_control.cpp
Normal file
140
src/memory/query_memory_control.cpp
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
// Copyright 2023 Memgraph Ltd.
|
||||||
|
//
|
||||||
|
// Use of this software is governed by the Business Source License
|
||||||
|
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||||
|
// License, and you may not use this file except in compliance with the Business Source License.
|
||||||
|
//
|
||||||
|
// As of the Change Date specified in that file, in accordance with
|
||||||
|
// the Business Source License, use of this software will be governed
|
||||||
|
// by the Apache License, Version 2.0, included in the file
|
||||||
|
// licenses/APL.txt.
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <iostream>
|
||||||
|
#include <optional>
|
||||||
|
#include <shared_mutex>
|
||||||
|
#include <thread>
|
||||||
|
#include <tuple>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "query_memory_control.hpp"
|
||||||
|
#include "utils/exceptions.hpp"
|
||||||
|
#include "utils/logging.hpp"
|
||||||
|
#include "utils/memory_tracker.hpp"
|
||||||
|
#include "utils/rw_spin_lock.hpp"
|
||||||
|
|
||||||
|
#if USE_JEMALLOC
|
||||||
|
#include "jemalloc/jemalloc.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace memgraph::memory {
|
||||||
|
|
||||||
|
#if USE_JEMALLOC
|
||||||
|
|
||||||
|
unsigned QueriesMemoryControl::GetArenaForThread() {
|
||||||
|
unsigned thread_arena{0};
|
||||||
|
size_t size_thread_arena = sizeof(thread_arena);
|
||||||
|
int err = mallctl("thread.arena", &thread_arena, &size_thread_arena, nullptr, 0);
|
||||||
|
if (err) {
|
||||||
|
LOG_FATAL("Can't get arena for thread.");
|
||||||
|
}
|
||||||
|
return thread_arena;
|
||||||
|
}
|
||||||
|
|
||||||
|
void QueriesMemoryControl::AddTrackingOnArena(unsigned arena_id) { arena_tracking[arena_id].fetch_add(1); }
|
||||||
|
|
||||||
|
void QueriesMemoryControl::RemoveTrackingOnArena(unsigned arena_id) { arena_tracking[arena_id].fetch_sub(1); }
|
||||||
|
|
||||||
|
void QueriesMemoryControl::UpdateThreadToTransactionId(const std::thread::id &thread_id, uint64_t transaction_id) {
|
||||||
|
auto accessor = thread_id_to_transaction_id.access();
|
||||||
|
accessor.insert({thread_id, transaction_id});
|
||||||
|
}
|
||||||
|
|
||||||
|
void QueriesMemoryControl::EraseThreadToTransactionId(const std::thread::id &thread_id, uint64_t transaction_id) {
|
||||||
|
auto accessor = thread_id_to_transaction_id.access();
|
||||||
|
auto elem = accessor.find(thread_id);
|
||||||
|
MG_ASSERT(elem != accessor.end() && elem->transaction_id == transaction_id);
|
||||||
|
accessor.remove(thread_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
utils::MemoryTracker *QueriesMemoryControl::GetTrackerCurrentThread() {
|
||||||
|
auto thread_id_to_transaction_id_accessor = thread_id_to_transaction_id.access();
|
||||||
|
|
||||||
|
// we might be just constructing mapping between thread id and transaction id
|
||||||
|
// so we miss this allocation
|
||||||
|
auto thread_id_to_transaction_id_elem = thread_id_to_transaction_id_accessor.find(std::this_thread::get_id());
|
||||||
|
if (thread_id_to_transaction_id_elem == thread_id_to_transaction_id_accessor.end()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
|
||||||
|
auto transaction_id_to_tracker =
|
||||||
|
transaction_id_to_tracker_accessor.find(thread_id_to_transaction_id_elem->transaction_id);
|
||||||
|
return &transaction_id_to_tracker->tracker;
|
||||||
|
}
|
||||||
|
|
||||||
|
void QueriesMemoryControl::CreateTransactionIdTracker(uint64_t transaction_id, size_t inital_limit) {
|
||||||
|
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
|
||||||
|
|
||||||
|
auto [elem, result] = transaction_id_to_tracker_accessor.insert({transaction_id, utils::MemoryTracker{}});
|
||||||
|
|
||||||
|
elem->tracker.SetMaximumHardLimit(inital_limit);
|
||||||
|
elem->tracker.SetHardLimit(inital_limit);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool QueriesMemoryControl::CheckTransactionIdTrackerExists(uint64_t transaction_id) {
|
||||||
|
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
|
||||||
|
return transaction_id_to_tracker_accessor.contains(transaction_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool QueriesMemoryControl::EraseTransactionIdTracker(uint64_t transaction_id) {
|
||||||
|
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
|
||||||
|
auto removed = transaction_id_to_tracker.access().remove(transaction_id);
|
||||||
|
return removed;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool QueriesMemoryControl::IsArenaTracked(unsigned arena_ind) {
|
||||||
|
return arena_tracking[arena_ind].load(std::memory_order_acquire) != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void QueriesMemoryControl::InitializeArenaCounter(unsigned arena_ind) {
|
||||||
|
arena_tracking[arena_ind].store(0, std::memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void StartTrackingCurrentThreadTransaction(uint64_t transaction_id) {
|
||||||
|
#if USE_JEMALLOC
|
||||||
|
GetQueriesMemoryControl().UpdateThreadToTransactionId(std::this_thread::get_id(), transaction_id);
|
||||||
|
GetQueriesMemoryControl().AddTrackingOnArena(QueriesMemoryControl::GetArenaForThread());
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void StopTrackingCurrentThreadTransaction(uint64_t transaction_id) {
|
||||||
|
#if USE_JEMALLOC
|
||||||
|
GetQueriesMemoryControl().EraseThreadToTransactionId(std::this_thread::get_id(), transaction_id);
|
||||||
|
GetQueriesMemoryControl().RemoveTrackingOnArena(QueriesMemoryControl::GetArenaForThread());
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void TryStartTrackingOnTransaction(uint64_t transaction_id, size_t limit) {
|
||||||
|
#if USE_JEMALLOC
|
||||||
|
if (GetQueriesMemoryControl().CheckTransactionIdTrackerExists(transaction_id)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
GetQueriesMemoryControl().CreateTransactionIdTracker(transaction_id, limit);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void TryStopTrackingOnTransaction(uint64_t transaction_id) {
|
||||||
|
#if USE_JEMALLOC
|
||||||
|
if (!GetQueriesMemoryControl().CheckTransactionIdTrackerExists(transaction_id)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
GetQueriesMemoryControl().EraseTransactionIdTracker(transaction_id);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace memgraph::memory
|
141
src/memory/query_memory_control.hpp
Normal file
141
src/memory/query_memory_control.hpp
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
// Copyright 2023 Memgraph Ltd.
|
||||||
|
//
|
||||||
|
// Use of this software is governed by the Business Source License
|
||||||
|
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||||
|
// License, and you may not use this file except in compliance with the Business Source License.
|
||||||
|
//
|
||||||
|
// As of the Change Date specified in that file, in accordance with
|
||||||
|
// the Business Source License, use of this software will be governed
|
||||||
|
// by the Apache License, Version 2.0, included in the file
|
||||||
|
// licenses/APL.txt.
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <thread>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "utils/memory_tracker.hpp"
|
||||||
|
#include "utils/skip_list.hpp"
|
||||||
|
|
||||||
|
namespace memgraph::memory {
|
||||||
|
|
||||||
|
#if USE_JEMALLOC
|
||||||
|
|
||||||
|
// Track memory allocations per query.
|
||||||
|
// Multiple threads can allocate inside one transaction.
|
||||||
|
// If user forgets to unregister tracking for that thread before it dies, it will continue to
|
||||||
|
// track allocations for that arena indefinitely.
|
||||||
|
// As multiple queries can be executed inside one transaction, one by one (multi-transaction)
|
||||||
|
// it is necessary to restart tracking at the beginning of new query for that transaction.
|
||||||
|
class QueriesMemoryControl {
|
||||||
|
public:
|
||||||
|
/*
|
||||||
|
Arena stats
|
||||||
|
*/
|
||||||
|
|
||||||
|
static unsigned GetArenaForThread();
|
||||||
|
|
||||||
|
// Add counter on threads allocating inside arena
|
||||||
|
void AddTrackingOnArena(unsigned);
|
||||||
|
|
||||||
|
// Remove counter on threads allocating in arena
|
||||||
|
void RemoveTrackingOnArena(unsigned);
|
||||||
|
|
||||||
|
// Are any threads using current arena for allocations
|
||||||
|
// Multiple threads can allocate inside one arena
|
||||||
|
bool IsArenaTracked(unsigned);
|
||||||
|
|
||||||
|
// Initialize arena counter
|
||||||
|
void InitializeArenaCounter(unsigned);
|
||||||
|
|
||||||
|
/*
|
||||||
|
Transaction id <-> tracker
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Create new tracker for transaction_id with initial limit
|
||||||
|
void CreateTransactionIdTracker(uint64_t, size_t);
|
||||||
|
|
||||||
|
// Check if tracker for given transaction id exists
|
||||||
|
bool CheckTransactionIdTrackerExists(uint64_t);
|
||||||
|
|
||||||
|
// Remove current tracker for transaction_id
|
||||||
|
bool EraseTransactionIdTracker(uint64_t);
|
||||||
|
|
||||||
|
/*
|
||||||
|
Thread handlings
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Map thread to transaction with given id
|
||||||
|
// This way we can know which thread belongs to which transaction
|
||||||
|
// and get correct tracker for given transaction
|
||||||
|
void UpdateThreadToTransactionId(const std::thread::id &, uint64_t);
|
||||||
|
|
||||||
|
// Remove tracking of thread from transaction.
|
||||||
|
// Important to reset if one thread gets reused for different transaction
|
||||||
|
void EraseThreadToTransactionId(const std::thread::id &, uint64_t);
|
||||||
|
|
||||||
|
// C-API functionality for thread to transaction mapping
|
||||||
|
void UpdateThreadToTransactionId(const char *, uint64_t);
|
||||||
|
|
||||||
|
// C-API functionality for thread to transaction unmapping
|
||||||
|
void EraseThreadToTransactionId(const char *, uint64_t);
|
||||||
|
|
||||||
|
// Get tracker to current thread if exists, otherwise return
|
||||||
|
// nullptr. This can happen only if tracker is still
|
||||||
|
// being constructed.
|
||||||
|
utils::MemoryTracker *GetTrackerCurrentThread();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unordered_map<unsigned, std::atomic<int>> arena_tracking;
|
||||||
|
|
||||||
|
struct ThreadIdToTransactionId {
|
||||||
|
std::thread::id thread_id;
|
||||||
|
uint64_t transaction_id;
|
||||||
|
|
||||||
|
bool operator<(const ThreadIdToTransactionId &other) const { return thread_id < other.thread_id; }
|
||||||
|
bool operator==(const ThreadIdToTransactionId &other) const { return thread_id == other.thread_id; }
|
||||||
|
|
||||||
|
bool operator<(const std::thread::id other) const { return thread_id < other; }
|
||||||
|
bool operator==(const std::thread::id other) const { return thread_id == other; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TransactionIdToTracker {
|
||||||
|
uint64_t transaction_id;
|
||||||
|
utils::MemoryTracker tracker;
|
||||||
|
|
||||||
|
bool operator<(const TransactionIdToTracker &other) const { return transaction_id < other.transaction_id; }
|
||||||
|
bool operator==(const TransactionIdToTracker &other) const { return transaction_id == other.transaction_id; }
|
||||||
|
|
||||||
|
bool operator<(uint64_t other) const { return transaction_id < other; }
|
||||||
|
bool operator==(uint64_t other) const { return transaction_id == other; }
|
||||||
|
};
|
||||||
|
|
||||||
|
utils::SkipList<ThreadIdToTransactionId> thread_id_to_transaction_id;
|
||||||
|
utils::SkipList<TransactionIdToTracker> transaction_id_to_tracker;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline QueriesMemoryControl &GetQueriesMemoryControl() {
|
||||||
|
static QueriesMemoryControl queries_memory_control_;
|
||||||
|
return queries_memory_control_;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// API function call for to start tracking current thread for given transaction.
|
||||||
|
// Does nothing if jemalloc is not enabled
|
||||||
|
void StartTrackingCurrentThreadTransaction(uint64_t transaction_id);
|
||||||
|
|
||||||
|
// API function call for to stop tracking current thread for given transaction.
|
||||||
|
// Does nothing if jemalloc is not enabled
|
||||||
|
void StopTrackingCurrentThreadTransaction(uint64_t transaction_id);
|
||||||
|
|
||||||
|
// API function call for try to create tracker for transaction and set it to given limit.
|
||||||
|
// Does nothing if jemalloc is not enabled. Does nothing if tracker already exists
|
||||||
|
void TryStartTrackingOnTransaction(uint64_t transaction_id, size_t limit);
|
||||||
|
|
||||||
|
// API function call to stop tracking for given transaction.
|
||||||
|
// Does nothing if jemalloc is not enabled. Does nothing if tracker doesn't exist
|
||||||
|
void TryStopTrackingOnTransaction(uint64_t transaction_id);
|
||||||
|
|
||||||
|
} // namespace memgraph::memory
|
@ -21,6 +21,10 @@ namespace memgraph::query {
|
|||||||
SubgraphDbAccessor::SubgraphDbAccessor(query::DbAccessor db_accessor, Graph *graph)
|
SubgraphDbAccessor::SubgraphDbAccessor(query::DbAccessor db_accessor, Graph *graph)
|
||||||
: db_accessor_(db_accessor), graph_(graph) {}
|
: db_accessor_(db_accessor), graph_(graph) {}
|
||||||
|
|
||||||
|
void SubgraphDbAccessor::TrackCurrentThreadAllocations() { return db_accessor_.TrackCurrentThreadAllocations(); }
|
||||||
|
|
||||||
|
void SubgraphDbAccessor::UntrackCurrentThreadAllocations() { return db_accessor_.TrackCurrentThreadAllocations(); }
|
||||||
|
|
||||||
storage::PropertyId SubgraphDbAccessor::NameToProperty(const std::string_view name) {
|
storage::PropertyId SubgraphDbAccessor::NameToProperty(const std::string_view name) {
|
||||||
return db_accessor_.NameToProperty(name);
|
return db_accessor_.NameToProperty(name);
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
#include <cppitertools/filter.hpp>
|
#include <cppitertools/filter.hpp>
|
||||||
#include <cppitertools/imap.hpp>
|
#include <cppitertools/imap.hpp>
|
||||||
|
|
||||||
|
#include "memory/query_memory_control.hpp"
|
||||||
#include "query/exceptions.hpp"
|
#include "query/exceptions.hpp"
|
||||||
#include "storage/v2/edge_accessor.hpp"
|
#include "storage/v2/edge_accessor.hpp"
|
||||||
#include "storage/v2/id_types.hpp"
|
#include "storage/v2/id_types.hpp"
|
||||||
@ -372,6 +373,16 @@ class DbAccessor final {
|
|||||||
|
|
||||||
void FinalizeTransaction() { accessor_->FinalizeTransaction(); }
|
void FinalizeTransaction() { accessor_->FinalizeTransaction(); }
|
||||||
|
|
||||||
|
void TrackCurrentThreadAllocations() {
|
||||||
|
memgraph::memory::StartTrackingCurrentThreadTransaction(*accessor_->GetTransactionId());
|
||||||
|
}
|
||||||
|
|
||||||
|
void UntrackCurrentThreadAllocations() {
|
||||||
|
memgraph::memory::StopTrackingCurrentThreadTransaction(*accessor_->GetTransactionId());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<uint64_t> GetTransactionId() { return accessor_->GetTransactionId(); }
|
||||||
|
|
||||||
VerticesIterable Vertices(storage::View view) { return VerticesIterable(accessor_->Vertices(view)); }
|
VerticesIterable Vertices(storage::View view) { return VerticesIterable(accessor_->Vertices(view)); }
|
||||||
|
|
||||||
VerticesIterable Vertices(storage::View view, storage::LabelId label) {
|
VerticesIterable Vertices(storage::View view, storage::LabelId label) {
|
||||||
@ -640,6 +651,14 @@ class SubgraphDbAccessor final {
|
|||||||
|
|
||||||
static SubgraphDbAccessor *MakeSubgraphDbAccessor(DbAccessor *db_accessor, Graph *graph);
|
static SubgraphDbAccessor *MakeSubgraphDbAccessor(DbAccessor *db_accessor, Graph *graph);
|
||||||
|
|
||||||
|
void TrackThreadAllocations(const char *thread_id);
|
||||||
|
|
||||||
|
void TrackCurrentThreadAllocations();
|
||||||
|
|
||||||
|
void UntrackThreadAllocations(const char *thread_id);
|
||||||
|
|
||||||
|
void UntrackCurrentThreadAllocations();
|
||||||
|
|
||||||
storage::PropertyId NameToProperty(std::string_view name);
|
storage::PropertyId NameToProperty(std::string_view name);
|
||||||
|
|
||||||
storage::LabelId NameToLabel(std::string_view name);
|
storage::LabelId NameToLabel(std::string_view name);
|
||||||
|
@ -26,6 +26,7 @@
|
|||||||
#include <optional>
|
#include <optional>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
#include <tuple>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
@ -39,7 +40,8 @@
|
|||||||
#include "flags/run_time_configurable.hpp"
|
#include "flags/run_time_configurable.hpp"
|
||||||
#include "glue/communication.hpp"
|
#include "glue/communication.hpp"
|
||||||
#include "license/license.hpp"
|
#include "license/license.hpp"
|
||||||
#include "memory/memory_control.hpp"
|
#include "memory/global_memory_control.hpp"
|
||||||
|
#include "memory/query_memory_control.hpp"
|
||||||
#include "query/config.hpp"
|
#include "query/config.hpp"
|
||||||
#include "query/constants.hpp"
|
#include "query/constants.hpp"
|
||||||
#include "query/context.hpp"
|
#include "query/context.hpp"
|
||||||
@ -1283,6 +1285,24 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
|
|||||||
std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *stream, std::optional<int> n,
|
std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *stream, std::optional<int> n,
|
||||||
const std::vector<Symbol> &output_symbols,
|
const std::vector<Symbol> &output_symbols,
|
||||||
std::map<std::string, TypedValue> *summary) {
|
std::map<std::string, TypedValue> *summary) {
|
||||||
|
std::optional<uint64_t> transaction_id = ctx_.db_accessor->GetTransactionId();
|
||||||
|
MG_ASSERT(transaction_id.has_value());
|
||||||
|
|
||||||
|
if (memory_limit_) {
|
||||||
|
memgraph::memory::TryStartTrackingOnTransaction(*transaction_id, *memory_limit_);
|
||||||
|
memgraph::memory::StartTrackingCurrentThreadTransaction(*transaction_id);
|
||||||
|
}
|
||||||
|
utils::OnScopeExit<std::function<void()>> reset_query_limit{
|
||||||
|
[memory_limit = memory_limit_, transaction_id = *transaction_id]() {
|
||||||
|
if (memory_limit) {
|
||||||
|
// Stopping tracking of transaction occurs in interpreter::pull
|
||||||
|
// Exception can occur so we need to handle that case there.
|
||||||
|
// We can't stop tracking here as there can be multiple pulls
|
||||||
|
// so we need to take care of that after everything was pulled
|
||||||
|
memgraph::memory::StopTrackingCurrentThreadTransaction(transaction_id);
|
||||||
|
}
|
||||||
|
}};
|
||||||
|
|
||||||
// Set up temporary memory for a single Pull. Initial memory comes from the
|
// Set up temporary memory for a single Pull. Initial memory comes from the
|
||||||
// stack. 256 KiB should fit on the stack and should be more than enough for a
|
// stack. 256 KiB should fit on the stack and should be more than enough for a
|
||||||
// single `Pull`.
|
// single `Pull`.
|
||||||
@ -1306,13 +1326,7 @@ std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *strea
|
|||||||
pool_memory.emplace(kMaxBlockPerChunks, 1024, &monotonic_memory, &resource_with_exception);
|
pool_memory.emplace(kMaxBlockPerChunks, 1024, &monotonic_memory, &resource_with_exception);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<utils::LimitedMemoryResource> maybe_limited_resource;
|
|
||||||
if (memory_limit_) {
|
|
||||||
maybe_limited_resource.emplace(&*pool_memory, *memory_limit_);
|
|
||||||
ctx_.evaluation_context.memory = &*maybe_limited_resource;
|
|
||||||
} else {
|
|
||||||
ctx_.evaluation_context.memory = &*pool_memory;
|
ctx_.evaluation_context.memory = &*pool_memory;
|
||||||
}
|
|
||||||
|
|
||||||
// Returns true if a result was pulled.
|
// Returns true if a result was pulled.
|
||||||
const auto pull_result = [&]() -> bool { return cursor_->Pull(frame_, ctx_); };
|
const auto pull_result = [&]() -> bool { return cursor_->Pull(frame_, ctx_); };
|
||||||
@ -1379,6 +1393,7 @@ std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *strea
|
|||||||
}
|
}
|
||||||
cursor_->Shutdown();
|
cursor_->Shutdown();
|
||||||
ctx_.profile_execution_time = execution_time_;
|
ctx_.profile_execution_time = execution_time_;
|
||||||
|
|
||||||
return GetStatsWithTotalTime(ctx_);
|
return GetStatsWithTotalTime(ctx_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
#include <gflags/gflags.h>
|
#include <gflags/gflags.h>
|
||||||
|
|
||||||
#include "dbms/database.hpp"
|
#include "dbms/database.hpp"
|
||||||
|
#include "memory/query_memory_control.hpp"
|
||||||
#include "query/auth_checker.hpp"
|
#include "query/auth_checker.hpp"
|
||||||
#include "query/auth_query_handler.hpp"
|
#include "query/auth_query_handler.hpp"
|
||||||
#include "query/config.hpp"
|
#include "query/config.hpp"
|
||||||
@ -402,6 +403,9 @@ std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std:
|
|||||||
// If the query finished executing, we have received a value which tells
|
// If the query finished executing, we have received a value which tells
|
||||||
// us what to do after.
|
// us what to do after.
|
||||||
if (maybe_res) {
|
if (maybe_res) {
|
||||||
|
if (current_transaction_) {
|
||||||
|
memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_);
|
||||||
|
}
|
||||||
// Save its summary
|
// Save its summary
|
||||||
maybe_summary.emplace(std::move(query_execution->summary));
|
maybe_summary.emplace(std::move(query_execution->summary));
|
||||||
if (!query_execution->notifications.empty()) {
|
if (!query_execution->notifications.empty()) {
|
||||||
@ -440,9 +444,15 @@ std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (const ExplicitTransactionUsageException &) {
|
} catch (const ExplicitTransactionUsageException &) {
|
||||||
|
if (current_transaction_) {
|
||||||
|
memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_);
|
||||||
|
}
|
||||||
query_execution.reset(nullptr);
|
query_execution.reset(nullptr);
|
||||||
throw;
|
throw;
|
||||||
} catch (const utils::BasicException &) {
|
} catch (const utils::BasicException &) {
|
||||||
|
if (current_transaction_) {
|
||||||
|
memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_);
|
||||||
|
}
|
||||||
// Trigger first failed query
|
// Trigger first failed query
|
||||||
metrics::FirstFailedQuery();
|
metrics::FirstFailedQuery();
|
||||||
memgraph::metrics::IncrementCounter(memgraph::metrics::FailedQuery);
|
memgraph::metrics::IncrementCounter(memgraph::metrics::FailedQuery);
|
||||||
|
@ -3596,3 +3596,15 @@ mgp_error mgp_log(const mgp_log_level log_level, const char *output) {
|
|||||||
throw std::invalid_argument{fmt::format("Invalid log level: {}", log_level)};
|
throw std::invalid_argument{fmt::format("Invalid log level: {}", log_level)};
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mgp_error mgp_track_current_thread_allocations(mgp_graph *graph) {
|
||||||
|
return WrapExceptions([&]() {
|
||||||
|
std::visit([](auto *db_accessor) -> void { db_accessor->TrackCurrentThreadAllocations(); }, graph->impl);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
mgp_error mgp_untrack_current_thread_allocations(mgp_graph *graph) {
|
||||||
|
return WrapExceptions([&]() {
|
||||||
|
std::visit([](auto *db_accessor) -> void { db_accessor->UntrackCurrentThreadAllocations(); }, graph->impl);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
@ -89,6 +89,13 @@ void MemoryTracker::TryRaiseHardLimit(const int64_t limit) {
|
|||||||
;
|
;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void MemoryTracker::ResetTrackings() {
|
||||||
|
hard_limit_.store(0, std::memory_order_relaxed);
|
||||||
|
peak_.store(0, std::memory_order_relaxed);
|
||||||
|
amount_.store(0, std::memory_order_relaxed);
|
||||||
|
maximum_hard_limit_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
void MemoryTracker::SetMaximumHardLimit(const int64_t limit) {
|
void MemoryTracker::SetMaximumHardLimit(const int64_t limit) {
|
||||||
if (maximum_hard_limit_ < 0) {
|
if (maximum_hard_limit_ < 0) {
|
||||||
spdlog::warn("Invalid maximum hard limit.");
|
spdlog::warn("Invalid maximum hard limit.");
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
#include "utils/exceptions.hpp"
|
#include "utils/exceptions.hpp"
|
||||||
|
|
||||||
@ -41,9 +42,20 @@ class MemoryTracker final {
|
|||||||
MemoryTracker() = default;
|
MemoryTracker() = default;
|
||||||
~MemoryTracker() = default;
|
~MemoryTracker() = default;
|
||||||
|
|
||||||
|
MemoryTracker(MemoryTracker &&other) noexcept
|
||||||
|
: amount_(other.amount_.load(std::memory_order_acquire)),
|
||||||
|
peak_(other.peak_.load(std::memory_order_acquire)),
|
||||||
|
hard_limit_(other.hard_limit_.load(std::memory_order_acquire)),
|
||||||
|
maximum_hard_limit_(other.maximum_hard_limit_) {
|
||||||
|
other.maximum_hard_limit_ = 0;
|
||||||
|
other.amount_.store(0, std::memory_order_acquire);
|
||||||
|
other.peak_.store(0, std::memory_order_acquire);
|
||||||
|
other.hard_limit_.store(0, std::memory_order_acquire);
|
||||||
|
}
|
||||||
|
|
||||||
MemoryTracker(const MemoryTracker &) = delete;
|
MemoryTracker(const MemoryTracker &) = delete;
|
||||||
MemoryTracker &operator=(const MemoryTracker &) = delete;
|
MemoryTracker &operator=(const MemoryTracker &) = delete;
|
||||||
MemoryTracker(MemoryTracker &&) = delete;
|
|
||||||
MemoryTracker &operator=(MemoryTracker &&) = delete;
|
MemoryTracker &operator=(MemoryTracker &&) = delete;
|
||||||
|
|
||||||
void Alloc(int64_t size);
|
void Alloc(int64_t size);
|
||||||
@ -59,6 +71,8 @@ class MemoryTracker final {
|
|||||||
void TryRaiseHardLimit(int64_t limit);
|
void TryRaiseHardLimit(int64_t limit);
|
||||||
void SetMaximumHardLimit(int64_t limit);
|
void SetMaximumHardLimit(int64_t limit);
|
||||||
|
|
||||||
|
void ResetTrackings();
|
||||||
|
|
||||||
// By creating an object of this class, every allocation in its scope that goes over
|
// By creating an object of this class, every allocation in its scope that goes over
|
||||||
// the set hard limit produces an OutOfMemoryException.
|
// the set hard limit produces an OutOfMemoryException.
|
||||||
class OutOfMemoryExceptionEnabler final {
|
class OutOfMemoryExceptionEnabler final {
|
||||||
|
@ -21,6 +21,7 @@ SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
|||||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", ".."))
|
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", ".."))
|
||||||
BUILD_DIR = os.path.join(PROJECT_DIR, "build")
|
BUILD_DIR = os.path.join(PROJECT_DIR, "build")
|
||||||
MEMGRAPH_BINARY = os.path.join(BUILD_DIR, "memgraph")
|
MEMGRAPH_BINARY = os.path.join(BUILD_DIR, "memgraph")
|
||||||
|
SIGNAL_SIGTERM = 15
|
||||||
|
|
||||||
|
|
||||||
def wait_for_server(port, delay=0.01):
|
def wait_for_server(port, delay=0.01):
|
||||||
@ -133,7 +134,7 @@ class MemgraphInstanceRunner:
|
|||||||
|
|
||||||
pid = self.proc_mg.pid
|
pid = self.proc_mg.pid
|
||||||
try:
|
try:
|
||||||
os.kill(pid, 15) # 15 is the signal number for SIGTERM
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
except os.OSError:
|
except os.OSError:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
@ -11,10 +11,21 @@ target_link_libraries(memgraph__e2e__memory__limit_global_alloc gflags mgclient
|
|||||||
add_executable(memgraph__e2e__memory__limit_global_alloc_proc memory_limit_global_alloc_proc.cpp)
|
add_executable(memgraph__e2e__memory__limit_global_alloc_proc memory_limit_global_alloc_proc.cpp)
|
||||||
target_link_libraries(memgraph__e2e__memory__limit_global_alloc_proc gflags mgclient mg-utils mg-io Threads::Threads)
|
target_link_libraries(memgraph__e2e__memory__limit_global_alloc_proc gflags mgclient mg-utils mg-io Threads::Threads)
|
||||||
|
|
||||||
|
add_executable(memgraph__e2e__memory__limit_query_alloc_proc_multi_thread query_memory_limit_proc_multi_thread.cpp)
|
||||||
|
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_proc_multi_thread gflags mgclient mg-utils mg-io Threads::Threads)
|
||||||
|
|
||||||
|
add_executable(memgraph__e2e__memory__limit_query_alloc_create query_memory_limit_create.cpp)
|
||||||
|
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_create gflags mgclient mg-utils mg-io)
|
||||||
|
|
||||||
|
add_executable(memgraph__e2e__memory__limit_query_alloc_proc query_memory_limit_proc.cpp)
|
||||||
|
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_proc gflags mgclient mg-utils mg-io)
|
||||||
|
|
||||||
|
add_executable(memgraph__e2e__memory__limit_query_alloc_create_multi_thread query_memory_limit_multi_thread.cpp)
|
||||||
|
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_create_multi_thread gflags mgclient mg-utils mg-io Threads::Threads)
|
||||||
|
|
||||||
add_executable(memgraph__e2e__memory__limit_delete memory_limit_delete.cpp)
|
add_executable(memgraph__e2e__memory__limit_delete memory_limit_delete.cpp)
|
||||||
target_link_libraries(memgraph__e2e__memory__limit_delete gflags mgclient mg-utils mg-io)
|
target_link_libraries(memgraph__e2e__memory__limit_delete gflags mgclient mg-utils mg-io)
|
||||||
|
|
||||||
|
|
||||||
add_executable(memgraph__e2e__memory__limit_accumulation memory_limit_accumulation.cpp)
|
add_executable(memgraph__e2e__memory__limit_accumulation memory_limit_accumulation.cpp)
|
||||||
target_link_libraries(memgraph__e2e__memory__limit_accumulation gflags mgclient mg-utils mg-io)
|
target_link_libraries(memgraph__e2e__memory__limit_accumulation gflags mgclient mg-utils mg-io)
|
||||||
|
|
||||||
|
@ -3,3 +3,13 @@ target_include_directories(global_memory_limit PRIVATE ${CMAKE_SOURCE_DIR}/inclu
|
|||||||
|
|
||||||
add_library(global_memory_limit_proc SHARED global_memory_limit_proc.c)
|
add_library(global_memory_limit_proc SHARED global_memory_limit_proc.c)
|
||||||
target_include_directories(global_memory_limit_proc PRIVATE ${CMAKE_SOURCE_DIR}/include)
|
target_include_directories(global_memory_limit_proc PRIVATE ${CMAKE_SOURCE_DIR}/include)
|
||||||
|
|
||||||
|
|
||||||
|
add_library(query_memory_limit_proc_multi_thread SHARED query_memory_limit_proc_multi_thread.cpp)
|
||||||
|
target_include_directories(query_memory_limit_proc_multi_thread PRIVATE ${CMAKE_SOURCE_DIR}/include)
|
||||||
|
target_link_libraries(query_memory_limit_proc_multi_thread mg-utils)
|
||||||
|
|
||||||
|
|
||||||
|
add_library(query_memory_limit_proc SHARED query_memory_limit_proc.cpp)
|
||||||
|
target_include_directories(query_memory_limit_proc PRIVATE ${CMAKE_SOURCE_DIR}/include)
|
||||||
|
target_link_libraries(query_memory_limit_proc mg-utils)
|
||||||
|
70
tests/e2e/memory/procedures/query_memory_limit_proc.cpp
Normal file
70
tests/e2e/memory/procedures/query_memory_limit_proc.cpp
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
// Copyright 2023 Memgraph Ltd.
|
||||||
|
//
|
||||||
|
// Use of this software is governed by the Business Source License
|
||||||
|
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||||
|
// License, and you may not use this file except in compliance with the Business Source License.
|
||||||
|
//
|
||||||
|
// As of the Change Date specified in that file, in accordance with
|
||||||
|
// the Business Source License, use of this software will be governed
|
||||||
|
// by the Apache License, Version 2.0, included in the file
|
||||||
|
// licenses/APL.txt.
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <cassert>
|
||||||
|
#include <exception>
|
||||||
|
#include <functional>
|
||||||
|
#include <mgp.hpp>
|
||||||
|
#include <mutex>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <thread>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mg_procedure.h"
|
||||||
|
#include "utils/on_scope_exit.hpp"
|
||||||
|
|
||||||
|
enum mgp_error Alloc(void *ptr) {
|
||||||
|
const size_t mb_size_268 = 1 << 28;
|
||||||
|
|
||||||
|
return mgp_global_alloc(mb_size_268, (void **)(&ptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Regular(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) {
|
||||||
|
mgp::MemoryDispatcherGuard guard{memory};
|
||||||
|
const auto arguments = mgp::List(args);
|
||||||
|
const auto record_factory = mgp::RecordFactory(result);
|
||||||
|
|
||||||
|
try {
|
||||||
|
void *ptr{nullptr};
|
||||||
|
|
||||||
|
memgraph::utils::OnScopeExit<std::function<void(void)>> cleanup{[&ptr]() {
|
||||||
|
if (nullptr == ptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
mgp_global_free(ptr);
|
||||||
|
}};
|
||||||
|
|
||||||
|
const enum mgp_error alloc_err = Alloc(ptr);
|
||||||
|
auto new_record = record_factory.NewRecord();
|
||||||
|
new_record.Insert("allocated", alloc_err != mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE);
|
||||||
|
} catch (std::exception &e) {
|
||||||
|
record_factory.SetErrorMessage(e.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
|
||||||
|
try {
|
||||||
|
mgp::MemoryDispatcherGuard mdg{memory};
|
||||||
|
|
||||||
|
AddProcedure(Regular, std::string("regular").c_str(), mgp::ProcedureType::Read, {},
|
||||||
|
{mgp::Return(std::string("allocated").c_str(), mgp::Type::Bool)}, module, memory);
|
||||||
|
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int mgp_shutdown_module() { return 0; }
|
@ -0,0 +1,101 @@
|
|||||||
|
// Copyright 2023 Memgraph Ltd.
|
||||||
|
//
|
||||||
|
// Use of this software is governed by the Business Source License
|
||||||
|
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||||
|
// License, and you may not use this file except in compliance with the Business Source License.
|
||||||
|
//
|
||||||
|
// As of the Change Date specified in that file, in accordance with
|
||||||
|
// the Business Source License, use of this software will be governed
|
||||||
|
// by the Apache License, Version 2.0, included in the file
|
||||||
|
// licenses/APL.txt.
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <cassert>
|
||||||
|
#include <exception>
|
||||||
|
#include <functional>
|
||||||
|
#include <mgp.hpp>
|
||||||
|
#include <mutex>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <thread>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mg_procedure.h"
|
||||||
|
#include "utils/on_scope_exit.hpp"
|
||||||
|
|
||||||
|
enum mgp_error Alloc(void *ptr) {
|
||||||
|
const size_t mb_size_268 = 1 << 28;
|
||||||
|
|
||||||
|
return mgp_global_alloc(mb_size_268, (void **)(&ptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
// change communication between threads with feature and promise
|
||||||
|
std::atomic<int> num_allocations{0};
|
||||||
|
std::vector<void *> ptrs_;
|
||||||
|
|
||||||
|
void AllocFunc(mgp_graph *graph) {
|
||||||
|
[[maybe_unused]] const enum mgp_error tracking_error = mgp_track_current_thread_allocations(graph);
|
||||||
|
void *ptr = nullptr;
|
||||||
|
|
||||||
|
ptrs_.emplace_back(ptr);
|
||||||
|
try {
|
||||||
|
enum mgp_error alloc_err { mgp_error::MGP_ERROR_NO_ERROR };
|
||||||
|
alloc_err = Alloc(ptr);
|
||||||
|
if (alloc_err != mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) {
|
||||||
|
num_allocations.fetch_add(1, std::memory_order_relaxed);
|
||||||
|
}
|
||||||
|
if (alloc_err != mgp_error::MGP_ERROR_NO_ERROR) {
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
[[maybe_unused]] const enum mgp_error untracking_error = mgp_untrack_current_thread_allocations(graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DualThread(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) {
|
||||||
|
mgp::MemoryDispatcherGuard guard{memory};
|
||||||
|
const auto arguments = mgp::List(args);
|
||||||
|
const auto record_factory = mgp::RecordFactory(result);
|
||||||
|
num_allocations.store(0, std::memory_order_relaxed);
|
||||||
|
try {
|
||||||
|
std::vector<std::thread> threads;
|
||||||
|
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
threads.emplace_back(AllocFunc, memgraph_graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
threads[i].join();
|
||||||
|
}
|
||||||
|
for (void *ptr : ptrs_) {
|
||||||
|
if (ptr != nullptr) {
|
||||||
|
mgp_global_free(ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto new_record = record_factory.NewRecord();
|
||||||
|
|
||||||
|
new_record.Insert("allocated_all", num_allocations.load(std::memory_order_relaxed) == 2);
|
||||||
|
} catch (std::exception &e) {
|
||||||
|
record_factory.SetErrorMessage(e.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
|
||||||
|
try {
|
||||||
|
mgp::memory = memory;
|
||||||
|
|
||||||
|
AddProcedure(DualThread, std::string("dual_thread").c_str(), mgp::ProcedureType::Read, {},
|
||||||
|
{mgp::Return(std::string("allocated_all").c_str(), mgp::Type::Bool)}, module, memory);
|
||||||
|
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int mgp_shutdown_module() { return 0; }
|
65
tests/e2e/memory/query_memory_limit_create.cpp
Normal file
65
tests/e2e/memory/query_memory_limit_create.cpp
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
// Copyright 2023 Memgraph Ltd.
|
||||||
|
//
|
||||||
|
// Use of this software is governed by the Business Source License
|
||||||
|
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||||
|
// License, and you may not use this file except in compliance with the Business Source License.
|
||||||
|
//
|
||||||
|
// As of the Change Date specified in that file, in accordance with
|
||||||
|
// the Business Source License, use of this software will be governed
|
||||||
|
// by the Apache License, Version 2.0, included in the file
|
||||||
|
// licenses/APL.txt.
|
||||||
|
|
||||||
|
#include <gflags/gflags.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <mgclient.hpp>
|
||||||
|
|
||||||
|
#include "utils/logging.hpp"
|
||||||
|
#include "utils/timer.hpp"
|
||||||
|
|
||||||
|
DEFINE_uint64(bolt_port, 7687, "Bolt port");
|
||||||
|
DEFINE_bool(multi_db, false, "Run test in multi db environment");
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
google::SetUsageMessage("Memgraph E2E Memory Control");
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
|
memgraph::logging::RedirectToStderr();
|
||||||
|
|
||||||
|
mg::Client::Init();
|
||||||
|
|
||||||
|
auto client =
|
||||||
|
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
|
||||||
|
if (!client) {
|
||||||
|
LOG_FATAL("Failed to connect!");
|
||||||
|
}
|
||||||
|
|
||||||
|
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||||
|
client->DiscardAll();
|
||||||
|
|
||||||
|
if (FLAGS_multi_db) {
|
||||||
|
client->Execute("CREATE DATABASE clean;");
|
||||||
|
client->DiscardAll();
|
||||||
|
client->Execute("USE DATABASE clean;");
|
||||||
|
client->DiscardAll();
|
||||||
|
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||||
|
client->DiscardAll();
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto *create_query =
|
||||||
|
"UNWIND range(1, 50000) as u CREATE (n {string: 'Some longer string'}) RETURN n QUERY MEMORY LIMIT 30MB;";
|
||||||
|
|
||||||
|
try {
|
||||||
|
client->Execute(create_query);
|
||||||
|
[[maybe_unused]] auto results = client->FetchAll();
|
||||||
|
if (results->empty()) {
|
||||||
|
assert(true);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
} catch (const mg::TransientException & /*unused*/) {
|
||||||
|
spdlog::info("Memgraph is out of memory");
|
||||||
|
assert(true);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
MG_ASSERT(false, "Query should have failed!");
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
100
tests/e2e/memory/query_memory_limit_multi_thread.cpp
Normal file
100
tests/e2e/memory/query_memory_limit_multi_thread.cpp
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
// Copyright 2023 Memgraph Ltd.
|
||||||
|
//
|
||||||
|
// Use of this software is governed by the Business Source License
|
||||||
|
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||||
|
// License, and you may not use this file except in compliance with the Business Source License.
|
||||||
|
//
|
||||||
|
// As of the Change Date specified in that file, in accordance with
|
||||||
|
// the Business Source License, use of this software will be governed
|
||||||
|
// by the Apache License, Version 2.0, included in the file
|
||||||
|
// licenses/APL.txt.
|
||||||
|
|
||||||
|
#include <exception>
|
||||||
|
#include <future>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#include <gflags/gflags.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <mgclient.hpp>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "utils/logging.hpp"
|
||||||
|
#include "utils/timer.hpp"
|
||||||
|
|
||||||
|
DEFINE_uint64(bolt_port, 7687, "Bolt port");
|
||||||
|
DEFINE_bool(multi_db, false, "Run test in multi db environment");
|
||||||
|
|
||||||
|
static constexpr int kNumberClients{2};
|
||||||
|
|
||||||
|
void Func(std::promise<bool> promise) {
|
||||||
|
auto client =
|
||||||
|
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
|
||||||
|
if (!client) {
|
||||||
|
LOG_FATAL("Failed to connect!");
|
||||||
|
}
|
||||||
|
const auto *create_query =
|
||||||
|
"FOREACH(i in range(1, 300000) | CREATE (n: Node {string: 'Some longer string'})) QUERY MEMORY LIMIT 100MB;";
|
||||||
|
bool err{false};
|
||||||
|
try {
|
||||||
|
client->Execute(create_query);
|
||||||
|
[[maybe_unused]] auto results = client->FetchAll();
|
||||||
|
if (results->empty()) {
|
||||||
|
err = true;
|
||||||
|
}
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
spdlog::info("Good: Exception occured", e.what());
|
||||||
|
err = true;
|
||||||
|
}
|
||||||
|
promise.set_value_at_thread_exit(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
google::SetUsageMessage("Memgraph E2E Memory Control");
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
|
memgraph::logging::RedirectToStderr();
|
||||||
|
|
||||||
|
mg::Client::Init();
|
||||||
|
|
||||||
|
{
|
||||||
|
auto client =
|
||||||
|
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
|
||||||
|
if (!client) {
|
||||||
|
LOG_FATAL("Failed to connect!");
|
||||||
|
}
|
||||||
|
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||||
|
client->DiscardAll();
|
||||||
|
|
||||||
|
if (FLAGS_multi_db) {
|
||||||
|
client->Execute("CREATE DATABASE clean;");
|
||||||
|
client->DiscardAll();
|
||||||
|
client->Execute("USE DATABASE clean;");
|
||||||
|
client->DiscardAll();
|
||||||
|
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||||
|
client->DiscardAll();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::promise<bool>> my_promises;
|
||||||
|
std::vector<std::future<bool>> my_futures;
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
my_promises.push_back(std::promise<bool>());
|
||||||
|
my_futures.emplace_back(my_promises.back().get_future());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::thread> my_threads;
|
||||||
|
|
||||||
|
for (int i = 0; i < kNumberClients; i++) {
|
||||||
|
my_threads.emplace_back(Func, std::move(my_promises[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < kNumberClients; i++) {
|
||||||
|
auto value = my_futures[i].get();
|
||||||
|
MG_ASSERT(value, "Error should have happend in thread");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < kNumberClients; i++) {
|
||||||
|
my_threads[i].join();
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
66
tests/e2e/memory/query_memory_limit_proc.cpp
Normal file
66
tests/e2e/memory/query_memory_limit_proc.cpp
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
// Copyright 2023 Memgraph Ltd.
|
||||||
|
//
|
||||||
|
// Use of this software is governed by the Business Source License
|
||||||
|
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||||
|
// License, and you may not use this file except in compliance with the Business Source License.
|
||||||
|
//
|
||||||
|
// As of the Change Date specified in that file, in accordance with
|
||||||
|
// the Business Source License, use of this software will be governed
|
||||||
|
// by the Apache License, Version 2.0, included in the file
|
||||||
|
// licenses/APL.txt.
|
||||||
|
|
||||||
|
#include <gflags/gflags.h>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <exception>
|
||||||
|
#include <ios>
|
||||||
|
#include <iostream>
|
||||||
|
#include <mgclient.hpp>
|
||||||
|
|
||||||
|
#include "utils/logging.hpp"
|
||||||
|
#include "utils/timer.hpp"
|
||||||
|
|
||||||
|
DEFINE_uint64(bolt_port, 7687, "Bolt port");
|
||||||
|
DEFINE_uint64(timeout, 120, "Timeout seconds");
|
||||||
|
DEFINE_bool(multi_db, false, "Run test in multi db environment");
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
google::SetUsageMessage("Memgraph E2E Query Memory Limit In Multi-Thread For Global Allocators");
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
|
memgraph::logging::RedirectToStderr();
|
||||||
|
|
||||||
|
mg::Client::Init();
|
||||||
|
|
||||||
|
auto client =
|
||||||
|
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
|
||||||
|
if (!client) {
|
||||||
|
LOG_FATAL("Failed to connect!");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (FLAGS_multi_db) {
|
||||||
|
client->Execute("CREATE DATABASE clean;");
|
||||||
|
client->DiscardAll();
|
||||||
|
client->Execute("USE DATABASE clean;");
|
||||||
|
client->DiscardAll();
|
||||||
|
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||||
|
client->DiscardAll();
|
||||||
|
}
|
||||||
|
|
||||||
|
MG_ASSERT(
|
||||||
|
client->Execute("CALL libquery_memory_limit_proc.regular() YIELD allocated RETURN "
|
||||||
|
"allocated QUERY MEMORY LIMIT 250MB"));
|
||||||
|
bool error{false};
|
||||||
|
try {
|
||||||
|
auto result_rows = client->FetchAll();
|
||||||
|
if (result_rows) {
|
||||||
|
auto row = *result_rows->begin();
|
||||||
|
error = row[0].ValueBool() == false;
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
error = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
MG_ASSERT(error, "Error should have happend");
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
66
tests/e2e/memory/query_memory_limit_proc_multi_thread.cpp
Normal file
66
tests/e2e/memory/query_memory_limit_proc_multi_thread.cpp
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
// Copyright 2023 Memgraph Ltd.
|
||||||
|
//
|
||||||
|
// Use of this software is governed by the Business Source License
|
||||||
|
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||||
|
// License, and you may not use this file except in compliance with the Business Source License.
|
||||||
|
//
|
||||||
|
// As of the Change Date specified in that file, in accordance with
|
||||||
|
// the Business Source License, use of this software will be governed
|
||||||
|
// by the Apache License, Version 2.0, included in the file
|
||||||
|
// licenses/APL.txt.
|
||||||
|
|
||||||
|
#include <gflags/gflags.h>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <exception>
|
||||||
|
#include <ios>
|
||||||
|
#include <iostream>
|
||||||
|
#include <mgclient.hpp>
|
||||||
|
|
||||||
|
#include "utils/logging.hpp"
|
||||||
|
#include "utils/timer.hpp"
|
||||||
|
|
||||||
|
DEFINE_uint64(bolt_port, 7687, "Bolt port");
|
||||||
|
DEFINE_uint64(timeout, 120, "Timeout seconds");
|
||||||
|
DEFINE_bool(multi_db, false, "Run test in multi db environment");
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
google::SetUsageMessage("Memgraph E2E Query Memory Limit In Multi-Thread For Global Allocators");
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
|
memgraph::logging::RedirectToStderr();
|
||||||
|
|
||||||
|
mg::Client::Init();
|
||||||
|
|
||||||
|
auto client =
|
||||||
|
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
|
||||||
|
if (!client) {
|
||||||
|
LOG_FATAL("Failed to connect!");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (FLAGS_multi_db) {
|
||||||
|
client->Execute("CREATE DATABASE clean;");
|
||||||
|
client->DiscardAll();
|
||||||
|
client->Execute("USE DATABASE clean;");
|
||||||
|
client->DiscardAll();
|
||||||
|
client->Execute("MATCH (n) DETACH DELETE n;");
|
||||||
|
client->DiscardAll();
|
||||||
|
}
|
||||||
|
|
||||||
|
MG_ASSERT(
|
||||||
|
client->Execute("CALL libquery_memory_limit_proc_multi_thread.dual_thread() YIELD allocated_all RETURN "
|
||||||
|
"allocated_all QUERY MEMORY LIMIT 500MB"));
|
||||||
|
bool error{false};
|
||||||
|
try {
|
||||||
|
auto result_rows = client->FetchAll();
|
||||||
|
if (result_rows) {
|
||||||
|
auto row = *result_rows->begin();
|
||||||
|
error = row[0].ValueBool() == false;
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
error = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
MG_ASSERT(error, "Error should have happend");
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
@ -23,6 +23,20 @@ disk_cluster: &disk_cluster
|
|||||||
- "STORAGE MODE ON_DISK_TRANSACTIONAL"
|
- "STORAGE MODE ON_DISK_TRANSACTIONAL"
|
||||||
validation_queries: []
|
validation_queries: []
|
||||||
|
|
||||||
|
args_query_limit: &args_query_limit
|
||||||
|
- "--bolt-port"
|
||||||
|
- *bolt_port
|
||||||
|
- "--storage-gc-cycle-sec=180"
|
||||||
|
- "--log-level=TRACE"
|
||||||
|
|
||||||
|
in_memory_query_limit_cluster: &in_memory_query_limit_cluster
|
||||||
|
cluster:
|
||||||
|
main:
|
||||||
|
args: *args_query_limit
|
||||||
|
log_file: "memory-e2e.log"
|
||||||
|
setup_queries: []
|
||||||
|
validation_queries: []
|
||||||
|
|
||||||
args_450_MiB_limit: &args_450_MiB_limit
|
args_450_MiB_limit: &args_450_MiB_limit
|
||||||
- "--bolt-port"
|
- "--bolt-port"
|
||||||
- *bolt_port
|
- *bolt_port
|
||||||
@ -95,6 +109,27 @@ workloads:
|
|||||||
proc: "tests/e2e/memory/procedures/"
|
proc: "tests/e2e/memory/procedures/"
|
||||||
<<: *disk_cluster
|
<<: *disk_cluster
|
||||||
|
|
||||||
|
- name: "Memory control query limit proc"
|
||||||
|
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_proc"
|
||||||
|
proc: "tests/e2e/memory/procedures/"
|
||||||
|
args: ["--bolt-port", *bolt_port]
|
||||||
|
<<: *in_memory_query_limit_cluster
|
||||||
|
|
||||||
|
- name: "Memory control query limit proc multi thread"
|
||||||
|
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_proc_multi_thread"
|
||||||
|
args: ["--bolt-port", *bolt_port, "--timeout", "180"]
|
||||||
|
proc: "tests/e2e/memory/procedures/"
|
||||||
|
<<: *in_memory_query_limit_cluster
|
||||||
|
|
||||||
|
- name: "Memory control query limit create"
|
||||||
|
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_create"
|
||||||
|
args: ["--bolt-port", *bolt_port]
|
||||||
|
<<: *in_memory_query_limit_cluster
|
||||||
|
|
||||||
|
- name: "Memory control query limit create multi thread"
|
||||||
|
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_create_multi_thread"
|
||||||
|
args: ["--bolt-port", *bolt_port]
|
||||||
|
<<: *in_memory_query_limit_cluster
|
||||||
- name: "Memory control for detach delete"
|
- name: "Memory control for detach delete"
|
||||||
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_delete"
|
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_delete"
|
||||||
args: ["--bolt-port", *bolt_port]
|
args: ["--bolt-port", *bolt_port]
|
||||||
|
@ -24,6 +24,7 @@ import time
|
|||||||
DEFAULT_DB = "memgraph"
|
DEFAULT_DB = "memgraph"
|
||||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||||
|
SIGNAL_SIGTERM = 15
|
||||||
|
|
||||||
QUERIES = [
|
QUERIES = [
|
||||||
("MATCH (n) DELETE n", {}),
|
("MATCH (n) DELETE n", {}),
|
||||||
@ -92,9 +93,13 @@ def execute_test(memgraph_binary, tester_binary):
|
|||||||
# Register cleanup function
|
# Register cleanup function
|
||||||
@atexit.register
|
@atexit.register
|
||||||
def cleanup():
|
def cleanup():
|
||||||
if memgraph.poll() is None:
|
pid = memgraph.pid
|
||||||
memgraph.terminate()
|
try:
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False
|
||||||
|
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
def execute_queries(queries):
|
def execute_queries(queries):
|
||||||
for db, query, params in queries:
|
for db, query, params in queries:
|
||||||
@ -122,10 +127,12 @@ def execute_test(memgraph_binary, tester_binary):
|
|||||||
execute_queries(mt_queries3)
|
execute_queries(mt_queries3)
|
||||||
print("\033[1;36m~~ Finished query execution on clean database ~~\033[0m\n")
|
print("\033[1;36m~~ Finished query execution on clean database ~~\033[0m\n")
|
||||||
|
|
||||||
# Shutdown the memgraph binary
|
pid = memgraph.pid
|
||||||
memgraph.terminate()
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
except os.OSError:
|
||||||
|
assert False
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
# Verify the written log
|
# Verify the written log
|
||||||
print("\033[1;36m~~ Starting log verification ~~\033[0m")
|
print("\033[1;36m~~ Starting log verification ~~\033[0m")
|
||||||
|
@ -21,6 +21,7 @@ import time
|
|||||||
|
|
||||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||||
|
SIGNAL_SIGTERM = 15
|
||||||
|
|
||||||
# When you create a new permission just add a testcase to this list (a tuple
|
# When you create a new permission just add a testcase to this list (a tuple
|
||||||
# of query, touple of required permissions) and the test will automatically
|
# of query, touple of required permissions) and the test will automatically
|
||||||
@ -166,8 +167,12 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
|
|||||||
@atexit.register
|
@atexit.register
|
||||||
def cleanup():
|
def cleanup():
|
||||||
if memgraph.poll() is None:
|
if memgraph.poll() is None:
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
# Prepare the multi database environment
|
# Prepare the multi database environment
|
||||||
execute_admin_queries(
|
execute_admin_queries(
|
||||||
@ -327,8 +332,12 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
|
|||||||
print("\033[1;36m~~ Finished checking connections and database switching ~~\033[0m\n")
|
print("\033[1;36m~~ Finished checking connections and database switching ~~\033[0m\n")
|
||||||
|
|
||||||
# Shutdown the memgraph binary
|
# Shutdown the memgraph binary
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -20,7 +20,6 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||||
TESTS_DIR = os.path.join(SCRIPT_DIR, "tests")
|
TESTS_DIR = os.path.join(SCRIPT_DIR, "tests")
|
||||||
@ -31,6 +30,8 @@ WAL_FILE_NAME = "wal.bin"
|
|||||||
DUMP_SNAPSHOT_FILE_NAME = "expected_snapshot.cypher"
|
DUMP_SNAPSHOT_FILE_NAME = "expected_snapshot.cypher"
|
||||||
DUMP_WAL_FILE_NAME = "expected_wal.cypher"
|
DUMP_WAL_FILE_NAME = "expected_wal.cypher"
|
||||||
|
|
||||||
|
SIGNAL_SIGTERM = 15
|
||||||
|
|
||||||
|
|
||||||
def wait_for_server(port, delay=0.1):
|
def wait_for_server(port, delay=0.1):
|
||||||
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
|
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
|
||||||
@ -40,7 +41,7 @@ def wait_for_server(port, delay=0.1):
|
|||||||
|
|
||||||
|
|
||||||
def sorted_content(file_path):
|
def sorted_content(file_path):
|
||||||
with open(file_path, 'r') as fin:
|
with open(file_path, "r") as fin:
|
||||||
return sorted(list(map(lambda x: x.strip(), fin.readlines())))
|
return sorted(list(map(lambda x: x.strip(), fin.readlines())))
|
||||||
|
|
||||||
|
|
||||||
@ -52,32 +53,27 @@ def list_to_string(data):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def execute_test(
|
def execute_test(memgraph_binary, dump_binary, test_directory, test_type, write_expected):
|
||||||
memgraph_binary,
|
assert test_type in ["SNAPSHOT", "WAL"], "Test type should be either 'SNAPSHOT' or 'WAL'."
|
||||||
dump_binary,
|
print("\033[1;36m~~ Executing test {} ({}) ~~\033[0m".format(os.path.relpath(test_directory, TESTS_DIR), test_type))
|
||||||
test_directory,
|
|
||||||
test_type,
|
|
||||||
write_expected):
|
|
||||||
assert test_type in ["SNAPSHOT", "WAL"], \
|
|
||||||
"Test type should be either 'SNAPSHOT' or 'WAL'."
|
|
||||||
print("\033[1;36m~~ Executing test {} ({}) ~~\033[0m"
|
|
||||||
.format(os.path.relpath(test_directory, TESTS_DIR), test_type))
|
|
||||||
|
|
||||||
working_data_directory = tempfile.TemporaryDirectory()
|
working_data_directory = tempfile.TemporaryDirectory()
|
||||||
if test_type == "SNAPSHOT":
|
if test_type == "SNAPSHOT":
|
||||||
snapshots_dir = os.path.join(working_data_directory.name, "snapshots")
|
snapshots_dir = os.path.join(working_data_directory.name, "snapshots")
|
||||||
os.makedirs(snapshots_dir)
|
os.makedirs(snapshots_dir)
|
||||||
shutil.copy(os.path.join(test_directory, SNAPSHOT_FILE_NAME),
|
shutil.copy(os.path.join(test_directory, SNAPSHOT_FILE_NAME), snapshots_dir)
|
||||||
snapshots_dir)
|
|
||||||
else:
|
else:
|
||||||
wal_dir = os.path.join(working_data_directory.name, "wal")
|
wal_dir = os.path.join(working_data_directory.name, "wal")
|
||||||
os.makedirs(wal_dir)
|
os.makedirs(wal_dir)
|
||||||
shutil.copy(os.path.join(test_directory, WAL_FILE_NAME), wal_dir)
|
shutil.copy(os.path.join(test_directory, WAL_FILE_NAME), wal_dir)
|
||||||
|
|
||||||
memgraph_args = [memgraph_binary,
|
memgraph_args = [
|
||||||
|
memgraph_binary,
|
||||||
"--storage-recover-on-startup",
|
"--storage-recover-on-startup",
|
||||||
"--storage-properties-on-edges",
|
"--storage-properties-on-edges",
|
||||||
"--data-directory", working_data_directory.name]
|
"--data-directory",
|
||||||
|
working_data_directory.name,
|
||||||
|
]
|
||||||
|
|
||||||
# Start the memgraph binary
|
# Start the memgraph binary
|
||||||
memgraph = subprocess.Popen(memgraph_args)
|
memgraph = subprocess.Popen(memgraph_args)
|
||||||
@ -89,8 +85,12 @@ def execute_test(
|
|||||||
@atexit.register
|
@atexit.register
|
||||||
def cleanup():
|
def cleanup():
|
||||||
if memgraph.poll() is None:
|
if memgraph.poll() is None:
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
# Execute `database dump`
|
# Execute `database dump`
|
||||||
dump_output_file = tempfile.NamedTemporaryFile()
|
dump_output_file = tempfile.NamedTemporaryFile()
|
||||||
@ -98,28 +98,31 @@ def execute_test(
|
|||||||
subprocess.run(dump_args, stdout=dump_output_file, check=True)
|
subprocess.run(dump_args, stdout=dump_output_file, check=True)
|
||||||
|
|
||||||
# Shutdown the memgraph binary
|
# Shutdown the memgraph binary
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
dump_file_name = DUMP_SNAPSHOT_FILE_NAME if test_type == "SNAPSHOT" else DUMP_WAL_FILE_NAME
|
dump_file_name = DUMP_SNAPSHOT_FILE_NAME if test_type == "SNAPSHOT" else DUMP_WAL_FILE_NAME
|
||||||
|
|
||||||
if write_expected:
|
if write_expected:
|
||||||
with open(dump_output_file.name, 'r') as dump:
|
with open(dump_output_file.name, "r") as dump:
|
||||||
queries_got = dump.readlines()
|
queries_got = dump.readlines()
|
||||||
# Write dump files
|
# Write dump files
|
||||||
expected_dump_file = os.path.join(test_directory, dump_file_name)
|
expected_dump_file = os.path.join(test_directory, dump_file_name)
|
||||||
with open(expected_dump_file, 'w') as expected:
|
with open(expected_dump_file, "w") as expected:
|
||||||
expected.writelines(queries_got)
|
expected.writelines(queries_got)
|
||||||
else:
|
else:
|
||||||
# Compare dump files
|
# Compare dump files
|
||||||
expected_dump_file = os.path.join(test_directory, dump_file_name)
|
expected_dump_file = os.path.join(test_directory, dump_file_name)
|
||||||
assert os.path.exists(expected_dump_file), \
|
assert os.path.exists(expected_dump_file), "Could not find expected dump path {}".format(expected_dump_file)
|
||||||
"Could not find expected dump path {}".format(expected_dump_file)
|
|
||||||
queries_got = sorted_content(dump_output_file.name)
|
queries_got = sorted_content(dump_output_file.name)
|
||||||
queries_expected = sorted_content(expected_dump_file)
|
queries_expected = sorted_content(expected_dump_file)
|
||||||
assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" \
|
assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" "{}".format(
|
||||||
"{}".format(list_to_string(queries_got),
|
list_to_string(queries_got), list_to_string(queries_expected)
|
||||||
list_to_string(queries_expected))
|
)
|
||||||
|
|
||||||
print("\033[1;32m~~ Test successful ~~\033[0m\n")
|
print("\033[1;32m~~ Test successful ~~\033[0m\n")
|
||||||
|
|
||||||
@ -141,15 +144,17 @@ def find_test_directories(directory):
|
|||||||
continue
|
continue
|
||||||
snapshot_file = os.path.join(test_dir_path, SNAPSHOT_FILE_NAME)
|
snapshot_file = os.path.join(test_dir_path, SNAPSHOT_FILE_NAME)
|
||||||
wal_file = os.path.join(test_dir_path, WAL_FILE_NAME)
|
wal_file = os.path.join(test_dir_path, WAL_FILE_NAME)
|
||||||
dump_snapshot_file = os.path.join(
|
dump_snapshot_file = os.path.join(test_dir_path, DUMP_SNAPSHOT_FILE_NAME)
|
||||||
test_dir_path, DUMP_SNAPSHOT_FILE_NAME)
|
|
||||||
dump_wal_file = os.path.join(test_dir_path, DUMP_WAL_FILE_NAME)
|
dump_wal_file = os.path.join(test_dir_path, DUMP_WAL_FILE_NAME)
|
||||||
if (os.path.isfile(snapshot_file) and os.path.isfile(dump_snapshot_file)
|
if (
|
||||||
and os.path.isfile(wal_file) and os.path.isfile(dump_wal_file)):
|
os.path.isfile(snapshot_file)
|
||||||
|
and os.path.isfile(dump_snapshot_file)
|
||||||
|
and os.path.isfile(wal_file)
|
||||||
|
and os.path.isfile(dump_wal_file)
|
||||||
|
):
|
||||||
test_dirs.append(test_dir_path)
|
test_dirs.append(test_dir_path)
|
||||||
else:
|
else:
|
||||||
raise Exception("Missing data in test directory '{}'"
|
raise Exception("Missing data in test directory '{}'".format(test_dir_path))
|
||||||
.format(test_dir_path))
|
|
||||||
return test_dirs
|
return test_dirs
|
||||||
|
|
||||||
|
|
||||||
@ -161,26 +166,15 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--memgraph", default=memgraph_binary)
|
parser.add_argument("--memgraph", default=memgraph_binary)
|
||||||
parser.add_argument("--dump", default=dump_binary)
|
parser.add_argument("--dump", default=dump_binary)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--write-expected',
|
"--write-expected", action="store_true", help="Overwrite the expected cypher with results from current run"
|
||||||
action='store_true',
|
)
|
||||||
help='Overwrite the expected cypher with results from current run')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
test_directories = find_test_directories(TESTS_DIR)
|
test_directories = find_test_directories(TESTS_DIR)
|
||||||
assert len(test_directories) > 0, "No tests have been found!"
|
assert len(test_directories) > 0, "No tests have been found!"
|
||||||
|
|
||||||
for test_directory in test_directories:
|
for test_directory in test_directories:
|
||||||
execute_test(
|
execute_test(args.memgraph, args.dump, test_directory, "SNAPSHOT", args.write_expected)
|
||||||
args.memgraph,
|
execute_test(args.memgraph, args.dump, test_directory, "WAL", args.write_expected)
|
||||||
args.dump,
|
|
||||||
test_directory,
|
|
||||||
"SNAPSHOT",
|
|
||||||
args.write_expected)
|
|
||||||
execute_test(
|
|
||||||
args.memgraph,
|
|
||||||
args.dump,
|
|
||||||
test_directory,
|
|
||||||
"WAL",
|
|
||||||
args.write_expected)
|
|
||||||
|
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
@ -22,6 +22,7 @@ from typing import List
|
|||||||
|
|
||||||
SCRIPT_DIR = Path(__file__).absolute()
|
SCRIPT_DIR = Path(__file__).absolute()
|
||||||
PROJECT_DIR = SCRIPT_DIR.parents[3]
|
PROJECT_DIR = SCRIPT_DIR.parents[3]
|
||||||
|
SIGNAL_SIGTERM = 15
|
||||||
|
|
||||||
|
|
||||||
def wait_for_server(port, delay=0.1):
|
def wait_for_server(port, delay=0.1):
|
||||||
@ -68,8 +69,12 @@ def execute_with_user(queries):
|
|||||||
|
|
||||||
def cleanup(memgraph):
|
def cleanup(memgraph):
|
||||||
if memgraph.poll() is None:
|
if memgraph.poll() is None:
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
def execute_without_user(queries, should_fail=False, failure_message="", check_failure=True):
|
def execute_without_user(queries, should_fail=False, failure_message="", check_failure=True):
|
||||||
|
@ -24,6 +24,7 @@ SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
|||||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||||
|
|
||||||
UNAUTHORIZED_ERROR = r"^You are not authorized to execute this query.*?Please contact your database administrator\."
|
UNAUTHORIZED_ERROR = r"^You are not authorized to execute this query.*?Please contact your database administrator\."
|
||||||
|
SIGNAL_SIGTERM = 15
|
||||||
|
|
||||||
|
|
||||||
def wait_for_server(port, delay=0.1):
|
def wait_for_server(port, delay=0.1):
|
||||||
@ -80,8 +81,12 @@ def execute_test(memgraph_binary: str, tester_binary: str, filtering_binary: str
|
|||||||
@atexit.register
|
@atexit.register
|
||||||
def cleanup():
|
def cleanup():
|
||||||
if memgraph.poll() is None:
|
if memgraph.poll() is None:
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
# Prepare all users
|
# Prepare all users
|
||||||
def setup_user():
|
def setup_user():
|
||||||
@ -130,8 +135,12 @@ def execute_test(memgraph_binary: str, tester_binary: str, filtering_binary: str
|
|||||||
print("\033[1;36m~~ Finished edge filtering test ~~\033[0m\n")
|
print("\033[1;36m~~ Finished edge filtering test ~~\033[0m\n")
|
||||||
|
|
||||||
# Shutdown the memgraph binary
|
# Shutdown the memgraph binary
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -21,6 +21,7 @@ from typing import List
|
|||||||
|
|
||||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||||
|
SIGNAL_SIGTERM = 15
|
||||||
|
|
||||||
|
|
||||||
def wait_for_server(port: int, delay: float = 0.1) -> float:
|
def wait_for_server(port: int, delay: float = 0.1) -> float:
|
||||||
@ -86,8 +87,12 @@ def execute_without_user(
|
|||||||
|
|
||||||
def cleanup(memgraph: subprocess):
|
def cleanup(memgraph: subprocess):
|
||||||
if memgraph.poll() is None:
|
if memgraph.poll() is None:
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
def test_without_any_files(tester_binary: str, memgraph_args: List[str]):
|
def test_without_any_files(tester_binary: str, memgraph_args: List[str]):
|
||||||
|
@ -23,6 +23,8 @@ import time
|
|||||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||||
|
|
||||||
|
SIGNAL_SIGTERM = 15
|
||||||
|
|
||||||
CONFIG_TEMPLATE = """
|
CONFIG_TEMPLATE = """
|
||||||
server:
|
server:
|
||||||
host: "127.0.0.1"
|
host: "127.0.0.1"
|
||||||
@ -52,8 +54,7 @@ def wait_for_server(port, delay=0.1):
|
|||||||
time.sleep(delay)
|
time.sleep(delay)
|
||||||
|
|
||||||
|
|
||||||
def execute_tester(binary, queries, username="", password="",
|
def execute_tester(binary, queries, username="", password="", auth_should_fail=False, query_should_fail=False):
|
||||||
auth_should_fail=False, query_should_fail=False):
|
|
||||||
if password == "":
|
if password == "":
|
||||||
password = username
|
password = username
|
||||||
args = [binary, "--username", username, "--password", password]
|
args = [binary, "--username", username, "--password", password]
|
||||||
@ -76,18 +77,14 @@ class Memgraph:
|
|||||||
def start(self, **kwargs):
|
def start(self, **kwargs):
|
||||||
self.stop()
|
self.stop()
|
||||||
self._storage_directory = tempfile.TemporaryDirectory()
|
self._storage_directory = tempfile.TemporaryDirectory()
|
||||||
self._auth_module = os.path.join(self._storage_directory.name,
|
self._auth_module = os.path.join(self._storage_directory.name, "ldap.py")
|
||||||
"ldap.py")
|
self._auth_config = os.path.join(self._storage_directory.name, "ldap.yaml")
|
||||||
self._auth_config = os.path.join(self._storage_directory.name,
|
script_file = os.path.join(PROJECT_DIR, "src", "auth", "reference_modules", "ldap.py")
|
||||||
"ldap.yaml")
|
|
||||||
script_file = os.path.join(PROJECT_DIR, "src", "auth",
|
|
||||||
"reference_modules", "ldap.py")
|
|
||||||
virtualenv_bin = os.path.join(SCRIPT_DIR, "ve3", "bin", "python3")
|
virtualenv_bin = os.path.join(SCRIPT_DIR, "ve3", "bin", "python3")
|
||||||
with open(script_file) as fin:
|
with open(script_file) as fin:
|
||||||
data = fin.read()
|
data = fin.read()
|
||||||
data = data.replace("/usr/bin/python3", virtualenv_bin)
|
data = data.replace("/usr/bin/python3", virtualenv_bin)
|
||||||
data = data.replace("/etc/memgraph/auth/ldap.yaml",
|
data = data.replace("/etc/memgraph/auth/ldap.yaml", self._auth_config)
|
||||||
self._auth_config)
|
|
||||||
with open(self._auth_module, "w") as fout:
|
with open(self._auth_module, "w") as fout:
|
||||||
fout.write(data)
|
fout.write(data)
|
||||||
os.chmod(self._auth_module, stat.S_IRWXU | stat.S_IRWXG)
|
os.chmod(self._auth_module, stat.S_IRWXU | stat.S_IRWXG)
|
||||||
@ -106,10 +103,13 @@ class Memgraph:
|
|||||||
}
|
}
|
||||||
with open(self._auth_config, "w") as f:
|
with open(self._auth_config, "w") as f:
|
||||||
f.write(CONFIG_TEMPLATE.format(**config))
|
f.write(CONFIG_TEMPLATE.format(**config))
|
||||||
args = [self._binary,
|
args = [
|
||||||
"--data-directory", self._storage_directory.name,
|
self._binary,
|
||||||
|
"--data-directory",
|
||||||
|
self._storage_directory.name,
|
||||||
"--auth-module-executable",
|
"--auth-module-executable",
|
||||||
kwargs.pop("module_executable", self._auth_module)]
|
kwargs.pop("module_executable", self._auth_module),
|
||||||
|
]
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
ldap_key = "--auth-module-" + key.replace("_", "-")
|
ldap_key = "--auth-module-" + key.replace("_", "-")
|
||||||
if isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
@ -119,26 +119,27 @@ class Memgraph:
|
|||||||
args.append(value)
|
args.append(value)
|
||||||
self._process = subprocess.Popen(args)
|
self._process = subprocess.Popen(args)
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
assert self._process.poll() is None, "Memgraph process died " \
|
assert self._process.poll() is None, "Memgraph process died " "prematurely!"
|
||||||
"prematurely!"
|
|
||||||
wait_for_server(7687)
|
wait_for_server(7687)
|
||||||
|
|
||||||
def stop(self, check=True):
|
def stop(self, check=True):
|
||||||
if self._process is None:
|
if self._process is None:
|
||||||
return 0
|
return 0
|
||||||
self._process.terminate()
|
pid = self._process.pid
|
||||||
exitcode = self._process.wait()
|
try:
|
||||||
self._process = None
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
if check:
|
if check:
|
||||||
assert exitcode == 0, "Memgraph process didn't exit cleanly!"
|
assert False
|
||||||
return exitcode
|
return -1
|
||||||
|
time.sleep(1)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def initialize_test(memgraph, tester_binary, **kwargs):
|
def initialize_test(memgraph, tester_binary, **kwargs):
|
||||||
memgraph.start(module_executable="")
|
memgraph.start(module_executable="")
|
||||||
|
|
||||||
execute_tester(tester_binary,
|
execute_tester(tester_binary, ["CREATE USER root", "GRANT ALL PRIVILEGES TO root"])
|
||||||
["CREATE USER root", "GRANT ALL PRIVILEGES TO root"])
|
|
||||||
check_login = kwargs.pop("check_login", True)
|
check_login = kwargs.pop("check_login", True)
|
||||||
memgraph.restart(**kwargs)
|
memgraph.restart(**kwargs)
|
||||||
if check_login:
|
if check_login:
|
||||||
@ -170,18 +171,15 @@ def test_role_mapping(memgraph, tester_binary):
|
|||||||
initialize_test(memgraph, tester_binary)
|
initialize_test(memgraph, tester_binary)
|
||||||
|
|
||||||
execute_tester(tester_binary, [], "alice")
|
execute_tester(tester_binary, [], "alice")
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
|
||||||
query_should_fail=True)
|
|
||||||
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root")
|
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root")
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||||
|
|
||||||
execute_tester(tester_binary, [], "bob")
|
execute_tester(tester_binary, [], "bob")
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob",
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob", query_should_fail=True)
|
||||||
query_should_fail=True)
|
|
||||||
|
|
||||||
execute_tester(tester_binary, [], "carol")
|
execute_tester(tester_binary, [], "carol")
|
||||||
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol",
|
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", query_should_fail=True)
|
||||||
query_should_fail=True)
|
|
||||||
execute_tester(tester_binary, ["GRANT CREATE TO admin"], "root")
|
execute_tester(tester_binary, ["GRANT CREATE TO admin"], "root")
|
||||||
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol")
|
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol")
|
||||||
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "dave")
|
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "dave")
|
||||||
@ -192,15 +190,13 @@ def test_role_mapping(memgraph, tester_binary):
|
|||||||
def test_role_removal(memgraph, tester_binary):
|
def test_role_removal(memgraph, tester_binary):
|
||||||
initialize_test(memgraph, tester_binary)
|
initialize_test(memgraph, tester_binary)
|
||||||
execute_tester(tester_binary, [], "alice")
|
execute_tester(tester_binary, [], "alice")
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
|
||||||
query_should_fail=True)
|
|
||||||
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root")
|
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root")
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||||
memgraph.restart(manage_roles=False)
|
memgraph.restart(manage_roles=False)
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||||
execute_tester(tester_binary, ["CLEAR ROLE FOR alice"], "root")
|
execute_tester(tester_binary, ["CLEAR ROLE FOR alice"], "root")
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
|
||||||
query_should_fail=True)
|
|
||||||
memgraph.stop()
|
memgraph.stop()
|
||||||
|
|
||||||
|
|
||||||
@ -229,28 +225,22 @@ def test_user_is_role(memgraph, tester_binary):
|
|||||||
|
|
||||||
def test_user_permissions_persistancy(memgraph, tester_binary):
|
def test_user_permissions_persistancy(memgraph, tester_binary):
|
||||||
initialize_test(memgraph, tester_binary)
|
initialize_test(memgraph, tester_binary)
|
||||||
execute_tester(tester_binary,
|
execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root")
|
||||||
["CREATE USER alice", "GRANT MATCH TO alice"], "root")
|
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||||
memgraph.stop()
|
memgraph.stop()
|
||||||
|
|
||||||
|
|
||||||
def test_role_permissions_persistancy(memgraph, tester_binary):
|
def test_role_permissions_persistancy(memgraph, tester_binary):
|
||||||
initialize_test(memgraph, tester_binary)
|
initialize_test(memgraph, tester_binary)
|
||||||
execute_tester(tester_binary,
|
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
|
||||||
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
|
|
||||||
"root")
|
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||||
memgraph.stop()
|
memgraph.stop()
|
||||||
|
|
||||||
|
|
||||||
def test_only_authentication(memgraph, tester_binary):
|
def test_only_authentication(memgraph, tester_binary):
|
||||||
initialize_test(memgraph, tester_binary, manage_roles=False)
|
initialize_test(memgraph, tester_binary, manage_roles=False)
|
||||||
execute_tester(tester_binary,
|
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
|
||||||
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
|
||||||
"root")
|
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
|
|
||||||
query_should_fail=True)
|
|
||||||
memgraph.stop()
|
memgraph.stop()
|
||||||
|
|
||||||
|
|
||||||
@ -267,22 +257,16 @@ def test_wrong_suffix(memgraph, tester_binary):
|
|||||||
|
|
||||||
|
|
||||||
def test_suffix_with_spaces(memgraph, tester_binary):
|
def test_suffix_with_spaces(memgraph, tester_binary):
|
||||||
initialize_test(memgraph, tester_binary,
|
initialize_test(memgraph, tester_binary, suffix=", ou= people, dc = memgraph, dc = com")
|
||||||
suffix=", ou= people, dc = memgraph, dc = com")
|
execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root")
|
||||||
execute_tester(tester_binary,
|
|
||||||
["CREATE USER alice", "GRANT MATCH TO alice"], "root")
|
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||||
memgraph.stop()
|
memgraph.stop()
|
||||||
|
|
||||||
|
|
||||||
def test_role_mapping_wrong_root_dn(memgraph, tester_binary):
|
def test_role_mapping_wrong_root_dn(memgraph, tester_binary):
|
||||||
initialize_test(memgraph, tester_binary,
|
initialize_test(memgraph, tester_binary, root_dn="ou=invalid,dc=memgraph,dc=com")
|
||||||
root_dn="ou=invalid,dc=memgraph,dc=com")
|
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
|
||||||
execute_tester(tester_binary,
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
|
||||||
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
|
|
||||||
"root")
|
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
|
|
||||||
query_should_fail=True)
|
|
||||||
memgraph.restart()
|
memgraph.restart()
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||||
memgraph.stop()
|
memgraph.stop()
|
||||||
@ -290,11 +274,8 @@ def test_role_mapping_wrong_root_dn(memgraph, tester_binary):
|
|||||||
|
|
||||||
def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary):
|
def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary):
|
||||||
initialize_test(memgraph, tester_binary, root_objectclass="person")
|
initialize_test(memgraph, tester_binary, root_objectclass="person")
|
||||||
execute_tester(tester_binary,
|
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
|
||||||
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
|
||||||
"root")
|
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
|
|
||||||
query_should_fail=True)
|
|
||||||
memgraph.restart()
|
memgraph.restart()
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||||
memgraph.stop()
|
memgraph.stop()
|
||||||
@ -302,11 +283,8 @@ def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary):
|
|||||||
|
|
||||||
def test_role_mapping_wrong_user_attribute(memgraph, tester_binary):
|
def test_role_mapping_wrong_user_attribute(memgraph, tester_binary):
|
||||||
initialize_test(memgraph, tester_binary, user_attribute="cn")
|
initialize_test(memgraph, tester_binary, user_attribute="cn")
|
||||||
execute_tester(tester_binary,
|
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
|
||||||
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
|
||||||
"root")
|
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
|
|
||||||
query_should_fail=True)
|
|
||||||
memgraph.restart()
|
memgraph.restart()
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
|
||||||
memgraph.stop()
|
memgraph.stop()
|
||||||
@ -314,8 +292,7 @@ def test_role_mapping_wrong_user_attribute(memgraph, tester_binary):
|
|||||||
|
|
||||||
def test_wrong_password(memgraph, tester_binary):
|
def test_wrong_password(memgraph, tester_binary):
|
||||||
initialize_test(memgraph, tester_binary)
|
initialize_test(memgraph, tester_binary)
|
||||||
execute_tester(tester_binary, [], "root", password="sudo",
|
execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True)
|
||||||
auth_should_fail=True)
|
|
||||||
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
|
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
|
||||||
memgraph.stop()
|
memgraph.stop()
|
||||||
|
|
||||||
@ -326,12 +303,10 @@ def test_password_persistancy(memgraph, tester_binary):
|
|||||||
execute_tester(tester_binary, ["SHOW USERS"], "root", password="sudo")
|
execute_tester(tester_binary, ["SHOW USERS"], "root", password="sudo")
|
||||||
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
|
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
|
||||||
memgraph.restart()
|
memgraph.restart()
|
||||||
execute_tester(tester_binary, [], "root", password="sudo",
|
execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True)
|
||||||
auth_should_fail=True)
|
|
||||||
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
|
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
|
||||||
memgraph.restart(module_executable="")
|
memgraph.restart(module_executable="")
|
||||||
execute_tester(tester_binary, [], "root", password="sudo",
|
execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True)
|
||||||
auth_should_fail=True)
|
|
||||||
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
|
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
|
||||||
memgraph.stop()
|
memgraph.stop()
|
||||||
|
|
||||||
@ -339,33 +314,25 @@ def test_password_persistancy(memgraph, tester_binary):
|
|||||||
def test_user_multiple_roles(memgraph, tester_binary):
|
def test_user_multiple_roles(memgraph, tester_binary):
|
||||||
initialize_test(memgraph, tester_binary, check_login=False)
|
initialize_test(memgraph, tester_binary, check_login=False)
|
||||||
memgraph.restart()
|
memgraph.restart()
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve",
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
|
||||||
query_should_fail=True)
|
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
|
||||||
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root",
|
|
||||||
query_should_fail=True)
|
|
||||||
memgraph.restart(manage_roles=False)
|
memgraph.restart(manage_roles=False)
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve",
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
|
||||||
query_should_fail=True)
|
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
|
||||||
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root",
|
|
||||||
query_should_fail=True)
|
|
||||||
memgraph.restart(manage_roles=False, root_dn="")
|
memgraph.restart(manage_roles=False, root_dn="")
|
||||||
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve",
|
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
|
||||||
query_should_fail=True)
|
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
|
||||||
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root",
|
|
||||||
query_should_fail=True)
|
|
||||||
memgraph.stop()
|
memgraph.stop()
|
||||||
|
|
||||||
|
|
||||||
def test_starttls_failure(memgraph, tester_binary):
|
def test_starttls_failure(memgraph, tester_binary):
|
||||||
initialize_test(memgraph, tester_binary, encryption="starttls",
|
initialize_test(memgraph, tester_binary, encryption="starttls", check_login=False)
|
||||||
check_login=False)
|
|
||||||
execute_tester(tester_binary, [], "root", auth_should_fail=True)
|
execute_tester(tester_binary, [], "root", auth_should_fail=True)
|
||||||
memgraph.stop()
|
memgraph.stop()
|
||||||
|
|
||||||
|
|
||||||
def test_ssl_failure(memgraph, tester_binary):
|
def test_ssl_failure(memgraph, tester_binary):
|
||||||
initialize_test(memgraph, tester_binary, encryption="ssl",
|
initialize_test(memgraph, tester_binary, encryption="ssl", check_login=False)
|
||||||
check_login=False)
|
|
||||||
execute_tester(tester_binary, [], "root", auth_should_fail=True)
|
execute_tester(tester_binary, [], "root", auth_should_fail=True)
|
||||||
memgraph.stop()
|
memgraph.stop()
|
||||||
|
|
||||||
@ -375,22 +342,19 @@ def test_ssl_failure(memgraph, tester_binary):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
|
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
|
||||||
tester_binary = os.path.join(PROJECT_DIR, "build", "tests",
|
tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "ldap", "tester")
|
||||||
"integration", "ldap", "tester")
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--memgraph", default=memgraph_binary)
|
parser.add_argument("--memgraph", default=memgraph_binary)
|
||||||
parser.add_argument("--tester", default=tester_binary)
|
parser.add_argument("--tester", default=tester_binary)
|
||||||
parser.add_argument("--openldap-dir",
|
parser.add_argument("--openldap-dir", default=os.path.join(SCRIPT_DIR, "openldap-2.4.47"))
|
||||||
default=os.path.join(SCRIPT_DIR, "openldap-2.4.47"))
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Setup Memgraph handler
|
# Setup Memgraph handler
|
||||||
memgraph = Memgraph(args.memgraph)
|
memgraph = Memgraph(args.memgraph)
|
||||||
|
|
||||||
# Start the slapd binary
|
# Start the slapd binary
|
||||||
slapd_args = [os.path.join(args.openldap_dir, "exe", "libexec", "slapd"),
|
slapd_args = [os.path.join(args.openldap_dir, "exe", "libexec", "slapd"), "-h", "ldap://127.0.0.1:1389/", "-d", "0"]
|
||||||
"-h", "ldap://127.0.0.1:1389/", "-d", "0"]
|
|
||||||
slapd = subprocess.Popen(slapd_args)
|
slapd = subprocess.Popen(slapd_args)
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
assert slapd.poll() is None, "slapd process died prematurely!"
|
assert slapd.poll() is None, "slapd process died prematurely!"
|
||||||
@ -409,8 +373,7 @@ if __name__ == "__main__":
|
|||||||
if slapd_stat != 0:
|
if slapd_stat != 0:
|
||||||
print("slapd process didn't exit cleanly!")
|
print("slapd process didn't exit cleanly!")
|
||||||
|
|
||||||
assert mg_stat == 0 and slapd_stat == 0, "Some of the processes " \
|
assert mg_stat == 0 and slapd_stat == 0, "Some of the processes " "(memgraph, slapd) crashed!"
|
||||||
"(memgraph, slapd) crashed!"
|
|
||||||
|
|
||||||
# Execute tests
|
# Execute tests
|
||||||
names = sorted(globals().keys())
|
names = sorted(globals().keys())
|
||||||
|
@ -18,12 +18,13 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
import yaml
|
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||||
BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||||
BUILD_DIR = os.path.join(BASE_DIR, "build")
|
BUILD_DIR = os.path.join(BASE_DIR, "build")
|
||||||
|
SIGNAL_SIGTERM = 15
|
||||||
|
|
||||||
|
|
||||||
def wait_for_server(port, delay=0.1):
|
def wait_for_server(port, delay=0.1):
|
||||||
@ -46,17 +47,14 @@ def list_to_string(data):
|
|||||||
|
|
||||||
|
|
||||||
def verify_lifetime(memgraph_binary, mg_import_csv_binary):
|
def verify_lifetime(memgraph_binary, mg_import_csv_binary):
|
||||||
print("\033[1;36m~~ Verifying that mg_import_csv can't be started while "
|
print("\033[1;36m~~ Verifying that mg_import_csv can't be started while " "memgraph is running ~~\033[0m")
|
||||||
"memgraph is running ~~\033[0m")
|
|
||||||
storage_directory = tempfile.TemporaryDirectory()
|
storage_directory = tempfile.TemporaryDirectory()
|
||||||
|
|
||||||
# Generate common args
|
# Generate common args
|
||||||
common_args = ["--data-directory", storage_directory.name,
|
common_args = ["--data-directory", storage_directory.name, "--storage-properties-on-edges=false"]
|
||||||
"--storage-properties-on-edges=false"]
|
|
||||||
|
|
||||||
# Start the memgraph binary
|
# Start the memgraph binary
|
||||||
memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + \
|
memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + common_args
|
||||||
common_args
|
|
||||||
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
|
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
assert memgraph.poll() is None, "Memgraph process died prematurely!"
|
assert memgraph.poll() is None, "Memgraph process died prematurely!"
|
||||||
@ -66,47 +64,52 @@ def verify_lifetime(memgraph_binary, mg_import_csv_binary):
|
|||||||
@atexit.register
|
@atexit.register
|
||||||
def cleanup():
|
def cleanup():
|
||||||
if memgraph.poll() is None:
|
if memgraph.poll() is None:
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False, "Memgraph process didn't exit cleanly!"
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
# Execute mg_import_csv.
|
# Execute mg_import_csv.
|
||||||
mg_import_csv_args = [mg_import_csv_binary, "--nodes", "/dev/null"] + \
|
mg_import_csv_args = [mg_import_csv_binary, "--nodes", "/dev/null"] + common_args
|
||||||
common_args
|
|
||||||
ret = subprocess.run(mg_import_csv_args)
|
ret = subprocess.run(mg_import_csv_args)
|
||||||
|
|
||||||
# Check the return code
|
# Check the return code
|
||||||
if ret.returncode == 0:
|
if ret.returncode == 0:
|
||||||
raise Exception(
|
raise Exception("The importer was able to run while memgraph was running!")
|
||||||
"The importer was able to run while memgraph was running!")
|
|
||||||
|
|
||||||
# Shutdown the memgraph binary
|
# Shutdown the memgraph binary
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False, "Memgraph process didn't exit cleanly!"
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
print("\033[1;32m~~ Test successful ~~\033[0m\n")
|
print("\033[1;32m~~ Test successful ~~\033[0m\n")
|
||||||
|
|
||||||
|
|
||||||
def execute_test(name, test_path, test_config, memgraph_binary,
|
def execute_test(name, test_path, test_config, memgraph_binary, mg_import_csv_binary, tester_binary, write_expected):
|
||||||
mg_import_csv_binary, tester_binary, write_expected):
|
|
||||||
print("\033[1;36m~~ Executing test", name, "~~\033[0m")
|
print("\033[1;36m~~ Executing test", name, "~~\033[0m")
|
||||||
storage_directory = tempfile.TemporaryDirectory()
|
storage_directory = tempfile.TemporaryDirectory()
|
||||||
|
|
||||||
# Verify test configuration
|
# Verify test configuration
|
||||||
if ("import_should_fail" not in test_config and
|
if ("import_should_fail" not in test_config and "expected" not in test_config) or (
|
||||||
"expected" not in test_config) or \
|
"import_should_fail" in test_config and "expected" in test_config
|
||||||
("import_should_fail" in test_config and
|
):
|
||||||
"expected" in test_config):
|
raise Exception("The test should specify either 'import_should_fail' " "or 'expected'!")
|
||||||
raise Exception("The test should specify either 'import_should_fail' "
|
|
||||||
"or 'expected'!")
|
|
||||||
|
|
||||||
expected_path = test_config.pop("expected", "")
|
expected_path = test_config.pop("expected", "")
|
||||||
import_should_fail = test_config.pop("import_should_fail", False)
|
import_should_fail = test_config.pop("import_should_fail", False)
|
||||||
|
|
||||||
# Generate common args
|
# Generate common args
|
||||||
properties_on_edges = bool(test_config.pop("properties_on_edges", False))
|
properties_on_edges = bool(test_config.pop("properties_on_edges", False))
|
||||||
common_args = ["--data-directory", storage_directory.name,
|
common_args = [
|
||||||
"--storage-properties-on-edges=" +
|
"--data-directory",
|
||||||
str(properties_on_edges).lower()]
|
storage_directory.name,
|
||||||
|
"--storage-properties-on-edges=" + str(properties_on_edges).lower(),
|
||||||
|
]
|
||||||
|
|
||||||
# Generate mg_import_csv args using flags specified in the test
|
# Generate mg_import_csv args using flags specified in the test
|
||||||
mg_import_csv_args = [mg_import_csv_binary] + common_args
|
mg_import_csv_args = [mg_import_csv_binary] + common_args
|
||||||
@ -125,19 +128,16 @@ def execute_test(name, test_path, test_config, memgraph_binary,
|
|||||||
|
|
||||||
if import_should_fail:
|
if import_should_fail:
|
||||||
if ret.returncode == 0:
|
if ret.returncode == 0:
|
||||||
raise Exception("The import should have failed, but it "
|
raise Exception("The import should have failed, but it " "succeeded instead!")
|
||||||
"succeeded instead!")
|
|
||||||
else:
|
else:
|
||||||
print("\033[1;32m~~ Test successful ~~\033[0m\n")
|
print("\033[1;32m~~ Test successful ~~\033[0m\n")
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
if ret.returncode != 0:
|
if ret.returncode != 0:
|
||||||
raise Exception("The import should have succeeded, but it "
|
raise Exception("The import should have succeeded, but it " "failed instead!")
|
||||||
"failed instead!")
|
|
||||||
|
|
||||||
# Start the memgraph binary
|
# Start the memgraph binary
|
||||||
memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + \
|
memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + common_args
|
||||||
common_args
|
|
||||||
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
|
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
assert memgraph.poll() is None, "Memgraph process died prematurely!"
|
assert memgraph.poll() is None, "Memgraph process died prematurely!"
|
||||||
@ -147,21 +147,29 @@ def execute_test(name, test_path, test_config, memgraph_binary,
|
|||||||
@atexit.register
|
@atexit.register
|
||||||
def cleanup():
|
def cleanup():
|
||||||
if memgraph.poll() is None:
|
if memgraph.poll() is None:
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False, "Memgraph process didn't exit cleanly!"
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
# Get the contents of the database
|
# Get the contents of the database
|
||||||
queries_got = extract_rows(subprocess.run(
|
queries_got = extract_rows(
|
||||||
[tester_binary], stdout=subprocess.PIPE,
|
subprocess.run([tester_binary], stdout=subprocess.PIPE, check=True).stdout.decode("utf-8")
|
||||||
check=True).stdout.decode("utf-8"))
|
)
|
||||||
|
|
||||||
# Shutdown the memgraph binary
|
# Shutdown the memgraph binary
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False, "Memgraph process didn't exit cleanly!"
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
if write_expected:
|
if write_expected:
|
||||||
with open(os.path.join(test_path, expected_path), 'w') as expected:
|
with open(os.path.join(test_path, expected_path), "w") as expected:
|
||||||
expected.write('\n'.join(queries_got))
|
expected.write("\n".join(queries_got))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if expected_path:
|
if expected_path:
|
||||||
@ -173,18 +181,16 @@ def execute_test(name, test_path, test_config, memgraph_binary,
|
|||||||
# Verify the queries
|
# Verify the queries
|
||||||
queries_expected.sort()
|
queries_expected.sort()
|
||||||
queries_got.sort()
|
queries_got.sort()
|
||||||
assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" \
|
assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" "{}".format(
|
||||||
"{}".format(list_to_string(queries_got),
|
list_to_string(queries_got), list_to_string(queries_expected)
|
||||||
list_to_string(queries_expected))
|
)
|
||||||
print("\033[1;32m~~ Test successful ~~\033[0m\n")
|
print("\033[1;32m~~ Test successful ~~\033[0m\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
memgraph_binary = os.path.join(BUILD_DIR, "memgraph")
|
memgraph_binary = os.path.join(BUILD_DIR, "memgraph")
|
||||||
mg_import_csv_binary = os.path.join(
|
mg_import_csv_binary = os.path.join(BUILD_DIR, "src", "mg_import_csv")
|
||||||
BUILD_DIR, "src", "mg_import_csv")
|
tester_binary = os.path.join(BUILD_DIR, "tests", "integration", "mg_import_csv", "tester")
|
||||||
tester_binary = os.path.join(
|
|
||||||
BUILD_DIR, "tests", "integration", "mg_import_csv", "tester")
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--memgraph", default=memgraph_binary)
|
parser.add_argument("--memgraph", default=memgraph_binary)
|
||||||
@ -193,7 +199,8 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--write-expected",
|
"--write-expected",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Overwrite the expected values with the results of the current run")
|
help="Overwrite the expected values with the results of the current run",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# First test whether the CSV importer can be started while the main
|
# First test whether the CSV importer can be started while the main
|
||||||
@ -211,7 +218,8 @@ if __name__ == "__main__":
|
|||||||
testcases = yaml.safe_load(f)
|
testcases = yaml.safe_load(f)
|
||||||
for test_config in testcases:
|
for test_config in testcases:
|
||||||
test_name = name + "/" + test_config.pop("name")
|
test_name = name + "/" + test_config.pop("name")
|
||||||
execute_test(test_name, test_path, test_config, args.memgraph,
|
execute_test(
|
||||||
args.mg_import_csv, args.tester, args.write_expected)
|
test_name, test_path, test_config, args.memgraph, args.mg_import_csv, args.tester, args.write_expected
|
||||||
|
)
|
||||||
|
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
@ -23,6 +23,7 @@ from typing import List
|
|||||||
|
|
||||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||||
|
SIGNAL_SIGTERM = 15
|
||||||
|
|
||||||
|
|
||||||
def wait_for_server(port: int, delay: float = 0.1) -> float:
|
def wait_for_server(port: int, delay: float = 0.1) -> float:
|
||||||
@ -91,8 +92,12 @@ def check_config(tester_binary: str, flag: str, value: str) -> None:
|
|||||||
|
|
||||||
def cleanup(memgraph: subprocess):
|
def cleanup(memgraph: subprocess):
|
||||||
if memgraph.poll() is None:
|
if memgraph.poll() is None:
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False, "Memgraph process didn't exit cleanly!"
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
def run_test(tester_binary: str, memgraph_args: List[str], server_name: str, query_tx: str):
|
def run_test(tester_binary: str, memgraph_args: List[str], server_name: str, query_tx: str):
|
||||||
|
@ -22,6 +22,8 @@ assertion_queries = [
|
|||||||
f"MATCH (n)-[e]->(m) WITH count(e) as cnt RETURN assert(cnt={len(edge_queries)});",
|
f"MATCH (n)-[e]->(m) WITH count(e) as cnt RETURN assert(cnt={len(edge_queries)});",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
SIGNAL_SIGTERM = 15
|
||||||
|
|
||||||
|
|
||||||
def wait_for_server(port, delay=0.1):
|
def wait_for_server(port, delay=0.1):
|
||||||
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
|
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
|
||||||
@ -40,9 +42,12 @@ def prepare_memgraph(memgraph_args):
|
|||||||
|
|
||||||
|
|
||||||
def terminate_memgraph(memgraph):
|
def terminate_memgraph(memgraph):
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
time.sleep(0.1)
|
try:
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False, "Memgraph process didn't exit cleanly!"
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
def execute_tester(
|
def execute_tester(
|
||||||
@ -90,8 +95,12 @@ def execute_test_analytical_mode(memgraph_binary: str, tester_binary: str) -> No
|
|||||||
|
|
||||||
execute_queries(assertion_queries)
|
execute_queries(assertion_queries)
|
||||||
|
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False, "Memgraph process didn't exit cleanly!"
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
def execute_test_switch_analytical_transactional(memgraph_binary: str, tester_binary: str) -> None:
|
def execute_test_switch_analytical_transactional(memgraph_binary: str, tester_binary: str) -> None:
|
||||||
@ -135,8 +144,12 @@ def execute_test_switch_analytical_transactional(memgraph_binary: str, tester_bi
|
|||||||
execute_queries(assertion_queries)
|
execute_queries(assertion_queries)
|
||||||
|
|
||||||
print("\033[1;36m~~ Terminating memgraph ~~\033[0m\n")
|
print("\033[1;36m~~ Terminating memgraph ~~\033[0m\n")
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False, "Memgraph process didn't exit cleanly!"
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
def execute_test_switch_transactional_analytical(memgraph_binary: str, tester_binary: str) -> None:
|
def execute_test_switch_transactional_analytical(memgraph_binary: str, tester_binary: str) -> None:
|
||||||
@ -177,8 +190,12 @@ def execute_test_switch_transactional_analytical(memgraph_binary: str, tester_bi
|
|||||||
execute_queries(assertion_queries)
|
execute_queries(assertion_queries)
|
||||||
|
|
||||||
print("\033[1;36m~~ Terminating memgraph ~~\033[0m\n")
|
print("\033[1;36m~~ Terminating memgraph ~~\033[0m\n")
|
||||||
memgraph.terminate()
|
pid = memgraph.pid
|
||||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
try:
|
||||||
|
os.kill(pid, SIGNAL_SIGTERM)
|
||||||
|
except os.OSError:
|
||||||
|
assert False, "Memgraph process didn't exit cleanly!"
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -11,12 +11,13 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import tempfile
|
|
||||||
import shutil
|
|
||||||
import time
|
|
||||||
from common import get_absolute_path, set_cpus
|
from common import get_absolute_path, set_cpus
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -36,13 +37,12 @@ class Memgraph:
|
|||||||
"""
|
"""
|
||||||
Knows how to start and stop memgraph.
|
Knows how to start and stop memgraph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, args, num_workers):
|
def __init__(self, args, num_workers):
|
||||||
self.log = logging.getLogger("MemgraphRunner")
|
self.log = logging.getLogger("MemgraphRunner")
|
||||||
argp = ArgumentParser("MemgraphArgumentParser")
|
argp = ArgumentParser("MemgraphArgumentParser")
|
||||||
argp.add_argument("--runner-bin",
|
argp.add_argument("--runner-bin", default=get_absolute_path("memgraph", "build"))
|
||||||
default=get_absolute_path("memgraph", "build"))
|
argp.add_argument("--port", default="7687", help="Database and client port")
|
||||||
argp.add_argument("--port", default="7687",
|
|
||||||
help="Database and client port")
|
|
||||||
argp.add_argument("--data-directory", default=None)
|
argp.add_argument("--data-directory", default=None)
|
||||||
argp.add_argument("--storage-snapshot-on-exit", action="store_true")
|
argp.add_argument("--storage-snapshot-on-exit", action="store_true")
|
||||||
argp.add_argument("--storage-recover-on-startup", action="store_true")
|
argp.add_argument("--storage-recover-on-startup", action="store_true")
|
||||||
@ -55,8 +55,7 @@ class Memgraph:
|
|||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
self.log.info("start")
|
self.log.info("start")
|
||||||
database_args = ["--bolt-port", self.args.port,
|
database_args = ["--bolt-port", self.args.port, "--query-execution-timeout-sec", "0"]
|
||||||
"--query-execution-timeout-sec", "0"]
|
|
||||||
if self.num_workers:
|
if self.num_workers:
|
||||||
database_args += ["--bolt-num-workers", str(self.num_workers)]
|
database_args += ["--bolt-num-workers", str(self.num_workers)]
|
||||||
if self.args.data_directory:
|
if self.args.data_directory:
|
||||||
@ -82,15 +81,13 @@ class Neo:
|
|||||||
"""
|
"""
|
||||||
Knows how to start and stop neo4j.
|
Knows how to start and stop neo4j.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, args, config):
|
def __init__(self, args, config):
|
||||||
self.log = logging.getLogger("NeoRunner")
|
self.log = logging.getLogger("NeoRunner")
|
||||||
argp = ArgumentParser("NeoArgumentParser")
|
argp = ArgumentParser("NeoArgumentParser")
|
||||||
argp.add_argument("--runner-bin", default=get_absolute_path(
|
argp.add_argument("--runner-bin", default=get_absolute_path("neo4j/bin/neo4j", "libs"))
|
||||||
"neo4j/bin/neo4j", "libs"))
|
argp.add_argument("--port", default="7687", help="Database and client port")
|
||||||
argp.add_argument("--port", default="7687",
|
argp.add_argument("--http-port", default="7474", help="Database and client port")
|
||||||
help="Database and client port")
|
|
||||||
argp.add_argument("--http-port", default="7474",
|
|
||||||
help="Database and client port")
|
|
||||||
self.log.info("Initializing Runner with arguments %r", args)
|
self.log.info("Initializing Runner with arguments %r", args)
|
||||||
self.args, _ = argp.parse_known_args(args)
|
self.args, _ = argp.parse_known_args(args)
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -105,24 +102,22 @@ class Neo:
|
|||||||
self.neo4j_home_path = tempfile.mkdtemp(dir="/dev/shm")
|
self.neo4j_home_path = tempfile.mkdtemp(dir="/dev/shm")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.symlink(os.path.join(get_absolute_path("neo4j", "libs"), "lib"),
|
os.symlink(
|
||||||
os.path.join(self.neo4j_home_path, "lib"))
|
os.path.join(get_absolute_path("neo4j", "libs"), "lib"), os.path.join(self.neo4j_home_path, "lib")
|
||||||
|
)
|
||||||
neo4j_conf_dir = os.path.join(self.neo4j_home_path, "conf")
|
neo4j_conf_dir = os.path.join(self.neo4j_home_path, "conf")
|
||||||
neo4j_conf_file = os.path.join(neo4j_conf_dir, "neo4j.conf")
|
neo4j_conf_file = os.path.join(neo4j_conf_dir, "neo4j.conf")
|
||||||
os.mkdir(neo4j_conf_dir)
|
os.mkdir(neo4j_conf_dir)
|
||||||
shutil.copyfile(self.config, neo4j_conf_file)
|
shutil.copyfile(self.config, neo4j_conf_file)
|
||||||
with open(neo4j_conf_file, "a") as f:
|
with open(neo4j_conf_file, "a") as f:
|
||||||
f.write("\ndbms.connector.bolt.listen_address=:" +
|
f.write("\ndbms.connector.bolt.listen_address=:" + self.args.port + "\n")
|
||||||
self.args.port + "\n")
|
f.write("\ndbms.connector.http.listen_address=:" + self.args.http_port + "\n")
|
||||||
f.write("\ndbms.connector.http.listen_address=:" +
|
|
||||||
self.args.http_port + "\n")
|
|
||||||
|
|
||||||
# environment
|
# environment
|
||||||
cwd = os.path.dirname(self.args.runner_bin)
|
cwd = os.path.dirname(self.args.runner_bin)
|
||||||
env = {"NEO4J_HOME": self.neo4j_home_path}
|
env = {"NEO4J_HOME": self.neo4j_home_path}
|
||||||
|
|
||||||
self.database_bin.run(self.args.runner_bin, args=["console"],
|
self.database_bin.run(self.args.runner_bin, args=["console"], env=env, timeout=600, cwd=cwd)
|
||||||
env=env, timeout=600, cwd=cwd)
|
|
||||||
except:
|
except:
|
||||||
shutil.rmtree(self.neo4j_home_path)
|
shutil.rmtree(self.neo4j_home_path)
|
||||||
raise Exception("Couldn't run Neo4j!")
|
raise Exception("Couldn't run Neo4j!")
|
||||||
|
Loading…
Reference in New Issue
Block a user