Use extent hooks for per query memory limit (#1340)

This commit is contained in:
Antonio Filipovic 2023-10-25 16:01:59 +02:00 committed by GitHub
parent 3d4d841753
commit a84f570c6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 1275 additions and 267 deletions

View File

@ -111,6 +111,22 @@ enum mgp_error mgp_global_aligned_alloc(size_t size_in_bytes, size_t alignment,
/// The behavior is undefined if `ptr` is not a value returned from a prior
/// mgp_global_alloc() or mgp_global_aligned_alloc().
void mgp_global_free(void *p);
/// State of the graph database.
struct mgp_graph;
/// Allocations are tracked only for master thread. If new threads are spawned
/// inside procedure, by calling following function
/// you can start tracking allocations for current thread too. This
/// is important if you need query memory limit to work
/// for given procedure or per procedure memory limit.
enum mgp_error mgp_track_current_thread_allocations(struct mgp_graph *graph);
/// Once allocations are tracked for current thread, you need to stop tracking allocations
/// for given thread, before thread finishes with execution, or is detached.
/// Otherwise it might result in slowdown of system due to unnecessary tracking of
/// allocations.
enum mgp_error mgp_untrack_current_thread_allocations(struct mgp_graph *graph);
///@}
/// @name Operations on mgp_value
@ -854,9 +870,6 @@ enum mgp_error mgp_edge_set_properties(struct mgp_edge *e, struct mgp_map *prope
enum mgp_error mgp_edge_iter_properties(struct mgp_edge *e, struct mgp_memory *memory,
struct mgp_properties_iterator **result);
/// State of the graph database.
struct mgp_graph;
/// Get the vertex corresponding to given ID, or NULL if no such vertex exists.
/// Resulting vertex must be freed using mgp_vertex_destroy.
/// Return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate the vertex.

View File

@ -23,7 +23,7 @@
#include "glue/run_id.hpp"
#include "helpers.hpp"
#include "license/license_sender.hpp"
#include "memory/memory_control.hpp"
#include "memory/global_memory_control.hpp"
#include "query/config.hpp"
#include "query/discard_value_stream.hpp"
#include "query/interpreter.hpp"
@ -512,6 +512,7 @@ int main(int argc, char **argv) {
server.AwaitShutdown();
websocket_server.AwaitShutdown();
memgraph::memory::UnsetHooks();
#ifdef MG_ENTERPRISE
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
metrics_server.AwaitShutdown();

View File

@ -1,6 +1,7 @@
set(memory_src_files
new_delete.cpp
memory_control.cpp)
global_memory_control.cpp
query_memory_control.cpp)

View File

@ -9,12 +9,16 @@
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "memory_control.hpp"
#include <atomic>
#include <cstdint>
#include "global_memory_control.hpp"
#include "query_memory_control.hpp"
#include "utils/logging.hpp"
#include "utils/memory_tracker.hpp"
#if USE_JEMALLOC
#include <jemalloc/jemalloc.h>
#include "jemalloc/jemalloc.h"
#endif
namespace memgraph::memory {
@ -57,12 +61,24 @@ void *my_alloc(extent_hooks_t *extent_hooks, void *new_addr, size_t size, size_t
// This needs to be before, to throw exception in case of too big alloc
if (*commit) [[likely]] {
memgraph::utils::total_memory_tracker.Alloc(static_cast<int64_t>(size));
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
if (memory_tracker != nullptr) [[likely]] {
memory_tracker->Alloc(static_cast<int64_t>(size));
}
}
}
auto *ptr = old_hooks->alloc(extent_hooks, new_addr, size, alignment, zero, commit, arena_ind);
if (ptr == nullptr) [[unlikely]] {
if (*commit) {
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
if (memory_tracker != nullptr) [[likely]] {
memory_tracker->Free(static_cast<int64_t>(size));
}
}
}
return ptr;
}
@ -79,6 +95,13 @@ static bool my_dalloc(extent_hooks_t *extent_hooks, void *addr, size_t size, boo
if (committed) [[likely]] {
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
if (memory_tracker != nullptr) [[likely]] {
memory_tracker->Free(static_cast<int64_t>(size));
}
}
}
return false;
@ -87,6 +110,12 @@ static bool my_dalloc(extent_hooks_t *extent_hooks, void *addr, size_t size, boo
static void my_destroy(extent_hooks_t *extent_hooks, void *addr, size_t size, bool committed, unsigned arena_ind) {
if (committed) [[likely]] {
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(size));
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
if (memory_tracker != nullptr) [[likely]] {
memory_tracker->Free(static_cast<int64_t>(size));
}
}
}
old_hooks->destroy(extent_hooks, addr, size, committed, arena_ind);
@ -101,6 +130,12 @@ static bool my_commit(extent_hooks_t *extent_hooks, void *addr, size_t size, siz
}
memgraph::utils::total_memory_tracker.Alloc(static_cast<int64_t>(length));
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
if (memory_tracker != nullptr) [[likely]] {
memory_tracker->Alloc(static_cast<int64_t>(size));
}
}
return false;
}
@ -115,6 +150,12 @@ static bool my_decommit(extent_hooks_t *extent_hooks, void *addr, size_t size, s
}
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(length));
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
if (memory_tracker != nullptr) [[likely]] {
memory_tracker->Free(static_cast<int64_t>(size));
}
}
return false;
}
@ -129,6 +170,13 @@ static bool my_purge_forced(extent_hooks_t *extent_hooks, void *addr, size_t siz
}
memgraph::utils::total_memory_tracker.Free(static_cast<int64_t>(length));
if (GetQueriesMemoryControl().IsArenaTracked(arena_ind)) [[unlikely]] {
auto *memory_tracker = GetQueriesMemoryControl().GetTrackerCurrentThread();
if (memory_tracker != nullptr) [[likely]] {
memory_tracker->Alloc(static_cast<int64_t>(size));
}
}
return false;
}
@ -153,6 +201,7 @@ void SetHooks() {
}
for (int i = 0; i < n_arenas; i++) {
GetQueriesMemoryControl().InitializeArenaCounter(i);
std::string func_name = "arena." + std::to_string(i) + ".extent_hooks";
size_t hooks_len = sizeof(old_hooks);
@ -197,6 +246,45 @@ void SetHooks() {
#endif
}
void UnsetHooks() {
#if USE_JEMALLOC
uint64_t allocated{0};
uint64_t sz{sizeof(allocated)};
sz = sizeof(unsigned);
unsigned n_arenas{0};
int err = mallctl("opt.narenas", (void *)&n_arenas, &sz, nullptr, 0);
if (err) {
LOG_FATAL("Error setting default hooks for jemalloc arenas");
}
for (int i = 0; i < n_arenas; i++) {
GetQueriesMemoryControl().InitializeArenaCounter(i);
std::string func_name = "arena." + std::to_string(i) + ".extent_hooks";
MG_ASSERT(old_hooks);
MG_ASSERT(old_hooks->alloc);
MG_ASSERT(old_hooks->dalloc);
MG_ASSERT(old_hooks->destroy);
MG_ASSERT(old_hooks->commit);
MG_ASSERT(old_hooks->decommit);
MG_ASSERT(old_hooks->purge_forced);
MG_ASSERT(old_hooks->purge_lazy);
MG_ASSERT(old_hooks->split);
MG_ASSERT(old_hooks->merge);
err = mallctl(func_name.c_str(), nullptr, nullptr, &old_hooks, sizeof(old_hooks));
if (err) {
LOG_FATAL("Error setting default hooks for jemalloc arena {}", i);
}
}
#endif
}
void PurgeUnusedMemory() {
#if USE_JEMALLOC
mallctl("arena." STRINGIFY(MALLCTL_ARENAS_ALL) ".purge", nullptr, nullptr, nullptr, 0);

View File

@ -17,5 +17,6 @@ namespace memgraph::memory {
void PurgeUnusedMemory();
void SetHooks();
void UnsetHooks();
} // namespace memgraph::memory

View File

@ -0,0 +1,140 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <atomic>
#include <cstdint>
#include <iostream>
#include <optional>
#include <shared_mutex>
#include <thread>
#include <tuple>
#include <utility>
#include "query_memory_control.hpp"
#include "utils/exceptions.hpp"
#include "utils/logging.hpp"
#include "utils/memory_tracker.hpp"
#include "utils/rw_spin_lock.hpp"
#if USE_JEMALLOC
#include "jemalloc/jemalloc.h"
#endif
namespace memgraph::memory {
#if USE_JEMALLOC
unsigned QueriesMemoryControl::GetArenaForThread() {
unsigned thread_arena{0};
size_t size_thread_arena = sizeof(thread_arena);
int err = mallctl("thread.arena", &thread_arena, &size_thread_arena, nullptr, 0);
if (err) {
LOG_FATAL("Can't get arena for thread.");
}
return thread_arena;
}
void QueriesMemoryControl::AddTrackingOnArena(unsigned arena_id) { arena_tracking[arena_id].fetch_add(1); }
void QueriesMemoryControl::RemoveTrackingOnArena(unsigned arena_id) { arena_tracking[arena_id].fetch_sub(1); }
void QueriesMemoryControl::UpdateThreadToTransactionId(const std::thread::id &thread_id, uint64_t transaction_id) {
auto accessor = thread_id_to_transaction_id.access();
accessor.insert({thread_id, transaction_id});
}
void QueriesMemoryControl::EraseThreadToTransactionId(const std::thread::id &thread_id, uint64_t transaction_id) {
auto accessor = thread_id_to_transaction_id.access();
auto elem = accessor.find(thread_id);
MG_ASSERT(elem != accessor.end() && elem->transaction_id == transaction_id);
accessor.remove(thread_id);
}
utils::MemoryTracker *QueriesMemoryControl::GetTrackerCurrentThread() {
auto thread_id_to_transaction_id_accessor = thread_id_to_transaction_id.access();
// we might be just constructing mapping between thread id and transaction id
// so we miss this allocation
auto thread_id_to_transaction_id_elem = thread_id_to_transaction_id_accessor.find(std::this_thread::get_id());
if (thread_id_to_transaction_id_elem == thread_id_to_transaction_id_accessor.end()) {
return nullptr;
}
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
auto transaction_id_to_tracker =
transaction_id_to_tracker_accessor.find(thread_id_to_transaction_id_elem->transaction_id);
return &transaction_id_to_tracker->tracker;
}
void QueriesMemoryControl::CreateTransactionIdTracker(uint64_t transaction_id, size_t inital_limit) {
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
auto [elem, result] = transaction_id_to_tracker_accessor.insert({transaction_id, utils::MemoryTracker{}});
elem->tracker.SetMaximumHardLimit(inital_limit);
elem->tracker.SetHardLimit(inital_limit);
}
bool QueriesMemoryControl::CheckTransactionIdTrackerExists(uint64_t transaction_id) {
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
return transaction_id_to_tracker_accessor.contains(transaction_id);
}
bool QueriesMemoryControl::EraseTransactionIdTracker(uint64_t transaction_id) {
auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access();
auto removed = transaction_id_to_tracker.access().remove(transaction_id);
return removed;
}
bool QueriesMemoryControl::IsArenaTracked(unsigned arena_ind) {
return arena_tracking[arena_ind].load(std::memory_order_acquire) != 0;
}
void QueriesMemoryControl::InitializeArenaCounter(unsigned arena_ind) {
arena_tracking[arena_ind].store(0, std::memory_order_relaxed);
}
#endif
void StartTrackingCurrentThreadTransaction(uint64_t transaction_id) {
#if USE_JEMALLOC
GetQueriesMemoryControl().UpdateThreadToTransactionId(std::this_thread::get_id(), transaction_id);
GetQueriesMemoryControl().AddTrackingOnArena(QueriesMemoryControl::GetArenaForThread());
#endif
}
void StopTrackingCurrentThreadTransaction(uint64_t transaction_id) {
#if USE_JEMALLOC
GetQueriesMemoryControl().EraseThreadToTransactionId(std::this_thread::get_id(), transaction_id);
GetQueriesMemoryControl().RemoveTrackingOnArena(QueriesMemoryControl::GetArenaForThread());
#endif
}
void TryStartTrackingOnTransaction(uint64_t transaction_id, size_t limit) {
#if USE_JEMALLOC
if (GetQueriesMemoryControl().CheckTransactionIdTrackerExists(transaction_id)) {
return;
}
GetQueriesMemoryControl().CreateTransactionIdTracker(transaction_id, limit);
#endif
}
void TryStopTrackingOnTransaction(uint64_t transaction_id) {
#if USE_JEMALLOC
if (!GetQueriesMemoryControl().CheckTransactionIdTrackerExists(transaction_id)) {
return;
}
GetQueriesMemoryControl().EraseTransactionIdTracker(transaction_id);
#endif
}
} // namespace memgraph::memory

View File

@ -0,0 +1,141 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <cstddef>
#include <cstdint>
#include <thread>
#include <unordered_map>
#include "utils/memory_tracker.hpp"
#include "utils/skip_list.hpp"
namespace memgraph::memory {
#if USE_JEMALLOC
// Track memory allocations per query.
// Multiple threads can allocate inside one transaction.
// If user forgets to unregister tracking for that thread before it dies, it will continue to
// track allocations for that arena indefinitely.
// As multiple queries can be executed inside one transaction, one by one (multi-transaction)
// it is necessary to restart tracking at the beginning of new query for that transaction.
class QueriesMemoryControl {
public:
/*
Arena stats
*/
static unsigned GetArenaForThread();
// Add counter on threads allocating inside arena
void AddTrackingOnArena(unsigned);
// Remove counter on threads allocating in arena
void RemoveTrackingOnArena(unsigned);
// Are any threads using current arena for allocations
// Multiple threads can allocate inside one arena
bool IsArenaTracked(unsigned);
// Initialize arena counter
void InitializeArenaCounter(unsigned);
/*
Transaction id <-> tracker
*/
// Create new tracker for transaction_id with initial limit
void CreateTransactionIdTracker(uint64_t, size_t);
// Check if tracker for given transaction id exists
bool CheckTransactionIdTrackerExists(uint64_t);
// Remove current tracker for transaction_id
bool EraseTransactionIdTracker(uint64_t);
/*
Thread handlings
*/
// Map thread to transaction with given id
// This way we can know which thread belongs to which transaction
// and get correct tracker for given transaction
void UpdateThreadToTransactionId(const std::thread::id &, uint64_t);
// Remove tracking of thread from transaction.
// Important to reset if one thread gets reused for different transaction
void EraseThreadToTransactionId(const std::thread::id &, uint64_t);
// C-API functionality for thread to transaction mapping
void UpdateThreadToTransactionId(const char *, uint64_t);
// C-API functionality for thread to transaction unmapping
void EraseThreadToTransactionId(const char *, uint64_t);
// Get tracker to current thread if exists, otherwise return
// nullptr. This can happen only if tracker is still
// being constructed.
utils::MemoryTracker *GetTrackerCurrentThread();
private:
std::unordered_map<unsigned, std::atomic<int>> arena_tracking;
struct ThreadIdToTransactionId {
std::thread::id thread_id;
uint64_t transaction_id;
bool operator<(const ThreadIdToTransactionId &other) const { return thread_id < other.thread_id; }
bool operator==(const ThreadIdToTransactionId &other) const { return thread_id == other.thread_id; }
bool operator<(const std::thread::id other) const { return thread_id < other; }
bool operator==(const std::thread::id other) const { return thread_id == other; }
};
struct TransactionIdToTracker {
uint64_t transaction_id;
utils::MemoryTracker tracker;
bool operator<(const TransactionIdToTracker &other) const { return transaction_id < other.transaction_id; }
bool operator==(const TransactionIdToTracker &other) const { return transaction_id == other.transaction_id; }
bool operator<(uint64_t other) const { return transaction_id < other; }
bool operator==(uint64_t other) const { return transaction_id == other; }
};
utils::SkipList<ThreadIdToTransactionId> thread_id_to_transaction_id;
utils::SkipList<TransactionIdToTracker> transaction_id_to_tracker;
};
inline QueriesMemoryControl &GetQueriesMemoryControl() {
static QueriesMemoryControl queries_memory_control_;
return queries_memory_control_;
}
#endif
// API function call for to start tracking current thread for given transaction.
// Does nothing if jemalloc is not enabled
void StartTrackingCurrentThreadTransaction(uint64_t transaction_id);
// API function call for to stop tracking current thread for given transaction.
// Does nothing if jemalloc is not enabled
void StopTrackingCurrentThreadTransaction(uint64_t transaction_id);
// API function call for try to create tracker for transaction and set it to given limit.
// Does nothing if jemalloc is not enabled. Does nothing if tracker already exists
void TryStartTrackingOnTransaction(uint64_t transaction_id, size_t limit);
// API function call to stop tracking for given transaction.
// Does nothing if jemalloc is not enabled. Does nothing if tracker doesn't exist
void TryStopTrackingOnTransaction(uint64_t transaction_id);
} // namespace memgraph::memory

View File

@ -21,6 +21,10 @@ namespace memgraph::query {
SubgraphDbAccessor::SubgraphDbAccessor(query::DbAccessor db_accessor, Graph *graph)
: db_accessor_(db_accessor), graph_(graph) {}
void SubgraphDbAccessor::TrackCurrentThreadAllocations() { return db_accessor_.TrackCurrentThreadAllocations(); }
void SubgraphDbAccessor::UntrackCurrentThreadAllocations() { return db_accessor_.TrackCurrentThreadAllocations(); }
storage::PropertyId SubgraphDbAccessor::NameToProperty(const std::string_view name) {
return db_accessor_.NameToProperty(name);
}

View File

@ -17,6 +17,7 @@
#include <cppitertools/filter.hpp>
#include <cppitertools/imap.hpp>
#include "memory/query_memory_control.hpp"
#include "query/exceptions.hpp"
#include "storage/v2/edge_accessor.hpp"
#include "storage/v2/id_types.hpp"
@ -372,6 +373,16 @@ class DbAccessor final {
void FinalizeTransaction() { accessor_->FinalizeTransaction(); }
void TrackCurrentThreadAllocations() {
memgraph::memory::StartTrackingCurrentThreadTransaction(*accessor_->GetTransactionId());
}
void UntrackCurrentThreadAllocations() {
memgraph::memory::StopTrackingCurrentThreadTransaction(*accessor_->GetTransactionId());
}
std::optional<uint64_t> GetTransactionId() { return accessor_->GetTransactionId(); }
VerticesIterable Vertices(storage::View view) { return VerticesIterable(accessor_->Vertices(view)); }
VerticesIterable Vertices(storage::View view, storage::LabelId label) {
@ -640,6 +651,14 @@ class SubgraphDbAccessor final {
static SubgraphDbAccessor *MakeSubgraphDbAccessor(DbAccessor *db_accessor, Graph *graph);
void TrackThreadAllocations(const char *thread_id);
void TrackCurrentThreadAllocations();
void UntrackThreadAllocations(const char *thread_id);
void UntrackCurrentThreadAllocations();
storage::PropertyId NameToProperty(std::string_view name);
storage::LabelId NameToLabel(std::string_view name);

View File

@ -26,6 +26,7 @@
#include <optional>
#include <stdexcept>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <variant>
@ -39,7 +40,8 @@
#include "flags/run_time_configurable.hpp"
#include "glue/communication.hpp"
#include "license/license.hpp"
#include "memory/memory_control.hpp"
#include "memory/global_memory_control.hpp"
#include "memory/query_memory_control.hpp"
#include "query/config.hpp"
#include "query/constants.hpp"
#include "query/context.hpp"
@ -1283,6 +1285,24 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *stream, std::optional<int> n,
const std::vector<Symbol> &output_symbols,
std::map<std::string, TypedValue> *summary) {
std::optional<uint64_t> transaction_id = ctx_.db_accessor->GetTransactionId();
MG_ASSERT(transaction_id.has_value());
if (memory_limit_) {
memgraph::memory::TryStartTrackingOnTransaction(*transaction_id, *memory_limit_);
memgraph::memory::StartTrackingCurrentThreadTransaction(*transaction_id);
}
utils::OnScopeExit<std::function<void()>> reset_query_limit{
[memory_limit = memory_limit_, transaction_id = *transaction_id]() {
if (memory_limit) {
// Stopping tracking of transaction occurs in interpreter::pull
// Exception can occur so we need to handle that case there.
// We can't stop tracking here as there can be multiple pulls
// so we need to take care of that after everything was pulled
memgraph::memory::StopTrackingCurrentThreadTransaction(transaction_id);
}
}};
// Set up temporary memory for a single Pull. Initial memory comes from the
// stack. 256 KiB should fit on the stack and should be more than enough for a
// single `Pull`.
@ -1306,13 +1326,7 @@ std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *strea
pool_memory.emplace(kMaxBlockPerChunks, 1024, &monotonic_memory, &resource_with_exception);
}
std::optional<utils::LimitedMemoryResource> maybe_limited_resource;
if (memory_limit_) {
maybe_limited_resource.emplace(&*pool_memory, *memory_limit_);
ctx_.evaluation_context.memory = &*maybe_limited_resource;
} else {
ctx_.evaluation_context.memory = &*pool_memory;
}
// Returns true if a result was pulled.
const auto pull_result = [&]() -> bool { return cursor_->Pull(frame_, ctx_); };
@ -1379,6 +1393,7 @@ std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *strea
}
cursor_->Shutdown();
ctx_.profile_execution_time = execution_time_;
return GetStatsWithTotalTime(ctx_);
}

View File

@ -16,6 +16,7 @@
#include <gflags/gflags.h>
#include "dbms/database.hpp"
#include "memory/query_memory_control.hpp"
#include "query/auth_checker.hpp"
#include "query/auth_query_handler.hpp"
#include "query/config.hpp"
@ -402,6 +403,9 @@ std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std:
// If the query finished executing, we have received a value which tells
// us what to do after.
if (maybe_res) {
if (current_transaction_) {
memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_);
}
// Save its summary
maybe_summary.emplace(std::move(query_execution->summary));
if (!query_execution->notifications.empty()) {
@ -440,9 +444,15 @@ std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std:
}
}
} catch (const ExplicitTransactionUsageException &) {
if (current_transaction_) {
memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_);
}
query_execution.reset(nullptr);
throw;
} catch (const utils::BasicException &) {
if (current_transaction_) {
memgraph::memory::TryStopTrackingOnTransaction(*current_transaction_);
}
// Trigger first failed query
metrics::FirstFailedQuery();
memgraph::metrics::IncrementCounter(memgraph::metrics::FailedQuery);

View File

@ -3596,3 +3596,15 @@ mgp_error mgp_log(const mgp_log_level log_level, const char *output) {
throw std::invalid_argument{fmt::format("Invalid log level: {}", log_level)};
});
}
mgp_error mgp_track_current_thread_allocations(mgp_graph *graph) {
return WrapExceptions([&]() {
std::visit([](auto *db_accessor) -> void { db_accessor->TrackCurrentThreadAllocations(); }, graph->impl);
});
}
mgp_error mgp_untrack_current_thread_allocations(mgp_graph *graph) {
return WrapExceptions([&]() {
std::visit([](auto *db_accessor) -> void { db_accessor->UntrackCurrentThreadAllocations(); }, graph->impl);
});
}

View File

@ -89,6 +89,13 @@ void MemoryTracker::TryRaiseHardLimit(const int64_t limit) {
;
}
void MemoryTracker::ResetTrackings() {
hard_limit_.store(0, std::memory_order_relaxed);
peak_.store(0, std::memory_order_relaxed);
amount_.store(0, std::memory_order_relaxed);
maximum_hard_limit_ = 0;
}
void MemoryTracker::SetMaximumHardLimit(const int64_t limit) {
if (maximum_hard_limit_ < 0) {
spdlog::warn("Invalid maximum hard limit.");

View File

@ -12,6 +12,7 @@
#pragma once
#include <atomic>
#include <type_traits>
#include "utils/exceptions.hpp"
@ -41,9 +42,20 @@ class MemoryTracker final {
MemoryTracker() = default;
~MemoryTracker() = default;
MemoryTracker(MemoryTracker &&other) noexcept
: amount_(other.amount_.load(std::memory_order_acquire)),
peak_(other.peak_.load(std::memory_order_acquire)),
hard_limit_(other.hard_limit_.load(std::memory_order_acquire)),
maximum_hard_limit_(other.maximum_hard_limit_) {
other.maximum_hard_limit_ = 0;
other.amount_.store(0, std::memory_order_acquire);
other.peak_.store(0, std::memory_order_acquire);
other.hard_limit_.store(0, std::memory_order_acquire);
}
MemoryTracker(const MemoryTracker &) = delete;
MemoryTracker &operator=(const MemoryTracker &) = delete;
MemoryTracker(MemoryTracker &&) = delete;
MemoryTracker &operator=(MemoryTracker &&) = delete;
void Alloc(int64_t size);
@ -59,6 +71,8 @@ class MemoryTracker final {
void TryRaiseHardLimit(int64_t limit);
void SetMaximumHardLimit(int64_t limit);
void ResetTrackings();
// By creating an object of this class, every allocation in its scope that goes over
// the set hard limit produces an OutOfMemoryException.
class OutOfMemoryExceptionEnabler final {

View File

@ -21,6 +21,7 @@ SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", ".."))
BUILD_DIR = os.path.join(PROJECT_DIR, "build")
MEMGRAPH_BINARY = os.path.join(BUILD_DIR, "memgraph")
SIGNAL_SIGTERM = 15
def wait_for_server(port, delay=0.01):
@ -133,7 +134,7 @@ class MemgraphInstanceRunner:
pid = self.proc_mg.pid
try:
os.kill(pid, 15) # 15 is the signal number for SIGTERM
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False

View File

@ -11,10 +11,21 @@ target_link_libraries(memgraph__e2e__memory__limit_global_alloc gflags mgclient
add_executable(memgraph__e2e__memory__limit_global_alloc_proc memory_limit_global_alloc_proc.cpp)
target_link_libraries(memgraph__e2e__memory__limit_global_alloc_proc gflags mgclient mg-utils mg-io Threads::Threads)
add_executable(memgraph__e2e__memory__limit_query_alloc_proc_multi_thread query_memory_limit_proc_multi_thread.cpp)
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_proc_multi_thread gflags mgclient mg-utils mg-io Threads::Threads)
add_executable(memgraph__e2e__memory__limit_query_alloc_create query_memory_limit_create.cpp)
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_create gflags mgclient mg-utils mg-io)
add_executable(memgraph__e2e__memory__limit_query_alloc_proc query_memory_limit_proc.cpp)
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_proc gflags mgclient mg-utils mg-io)
add_executable(memgraph__e2e__memory__limit_query_alloc_create_multi_thread query_memory_limit_multi_thread.cpp)
target_link_libraries(memgraph__e2e__memory__limit_query_alloc_create_multi_thread gflags mgclient mg-utils mg-io Threads::Threads)
add_executable(memgraph__e2e__memory__limit_delete memory_limit_delete.cpp)
target_link_libraries(memgraph__e2e__memory__limit_delete gflags mgclient mg-utils mg-io)
add_executable(memgraph__e2e__memory__limit_accumulation memory_limit_accumulation.cpp)
target_link_libraries(memgraph__e2e__memory__limit_accumulation gflags mgclient mg-utils mg-io)

View File

@ -3,3 +3,13 @@ target_include_directories(global_memory_limit PRIVATE ${CMAKE_SOURCE_DIR}/inclu
add_library(global_memory_limit_proc SHARED global_memory_limit_proc.c)
target_include_directories(global_memory_limit_proc PRIVATE ${CMAKE_SOURCE_DIR}/include)
add_library(query_memory_limit_proc_multi_thread SHARED query_memory_limit_proc_multi_thread.cpp)
target_include_directories(query_memory_limit_proc_multi_thread PRIVATE ${CMAKE_SOURCE_DIR}/include)
target_link_libraries(query_memory_limit_proc_multi_thread mg-utils)
add_library(query_memory_limit_proc SHARED query_memory_limit_proc.cpp)
target_include_directories(query_memory_limit_proc PRIVATE ${CMAKE_SOURCE_DIR}/include)
target_link_libraries(query_memory_limit_proc mg-utils)

View File

@ -0,0 +1,70 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <atomic>
#include <cassert>
#include <exception>
#include <functional>
#include <mgp.hpp>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "mg_procedure.h"
#include "utils/on_scope_exit.hpp"
enum mgp_error Alloc(void *ptr) {
const size_t mb_size_268 = 1 << 28;
return mgp_global_alloc(mb_size_268, (void **)(&ptr));
}
void Regular(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) {
mgp::MemoryDispatcherGuard guard{memory};
const auto arguments = mgp::List(args);
const auto record_factory = mgp::RecordFactory(result);
try {
void *ptr{nullptr};
memgraph::utils::OnScopeExit<std::function<void(void)>> cleanup{[&ptr]() {
if (nullptr == ptr) {
return;
}
mgp_global_free(ptr);
}};
const enum mgp_error alloc_err = Alloc(ptr);
auto new_record = record_factory.NewRecord();
new_record.Insert("allocated", alloc_err != mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE);
} catch (std::exception &e) {
record_factory.SetErrorMessage(e.what());
}
}
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
try {
mgp::MemoryDispatcherGuard mdg{memory};
AddProcedure(Regular, std::string("regular").c_str(), mgp::ProcedureType::Read, {},
{mgp::Return(std::string("allocated").c_str(), mgp::Type::Bool)}, module, memory);
} catch (const std::exception &e) {
return 1;
}
return 0;
}
extern "C" int mgp_shutdown_module() { return 0; }

View File

@ -0,0 +1,101 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <atomic>
#include <cassert>
#include <exception>
#include <functional>
#include <mgp.hpp>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "mg_procedure.h"
#include "utils/on_scope_exit.hpp"
enum mgp_error Alloc(void *ptr) {
const size_t mb_size_268 = 1 << 28;
return mgp_global_alloc(mb_size_268, (void **)(&ptr));
}
// change communication between threads with feature and promise
std::atomic<int> num_allocations{0};
std::vector<void *> ptrs_;
void AllocFunc(mgp_graph *graph) {
[[maybe_unused]] const enum mgp_error tracking_error = mgp_track_current_thread_allocations(graph);
void *ptr = nullptr;
ptrs_.emplace_back(ptr);
try {
enum mgp_error alloc_err { mgp_error::MGP_ERROR_NO_ERROR };
alloc_err = Alloc(ptr);
if (alloc_err != mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) {
num_allocations.fetch_add(1, std::memory_order_relaxed);
}
if (alloc_err != mgp_error::MGP_ERROR_NO_ERROR) {
assert(false);
}
} catch (const std::exception &e) {
assert(false);
}
[[maybe_unused]] const enum mgp_error untracking_error = mgp_untrack_current_thread_allocations(graph);
}
void DualThread(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) {
mgp::MemoryDispatcherGuard guard{memory};
const auto arguments = mgp::List(args);
const auto record_factory = mgp::RecordFactory(result);
num_allocations.store(0, std::memory_order_relaxed);
try {
std::vector<std::thread> threads;
for (int i = 0; i < 2; i++) {
threads.emplace_back(AllocFunc, memgraph_graph);
}
for (int i = 0; i < 2; i++) {
threads[i].join();
}
for (void *ptr : ptrs_) {
if (ptr != nullptr) {
mgp_global_free(ptr);
}
}
auto new_record = record_factory.NewRecord();
new_record.Insert("allocated_all", num_allocations.load(std::memory_order_relaxed) == 2);
} catch (std::exception &e) {
record_factory.SetErrorMessage(e.what());
}
}
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
try {
mgp::memory = memory;
AddProcedure(DualThread, std::string("dual_thread").c_str(), mgp::ProcedureType::Read, {},
{mgp::Return(std::string("allocated_all").c_str(), mgp::Type::Bool)}, module, memory);
} catch (const std::exception &e) {
return 1;
}
return 0;
}
extern "C" int mgp_shutdown_module() { return 0; }

View File

@ -0,0 +1,65 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <gflags/gflags.h>
#include <iostream>
#include <mgclient.hpp>
#include "utils/logging.hpp"
#include "utils/timer.hpp"
DEFINE_uint64(bolt_port, 7687, "Bolt port");
DEFINE_bool(multi_db, false, "Run test in multi db environment");
int main(int argc, char **argv) {
google::SetUsageMessage("Memgraph E2E Memory Control");
gflags::ParseCommandLineFlags(&argc, &argv, true);
memgraph::logging::RedirectToStderr();
mg::Client::Init();
auto client =
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
if (!client) {
LOG_FATAL("Failed to connect!");
}
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
if (FLAGS_multi_db) {
client->Execute("CREATE DATABASE clean;");
client->DiscardAll();
client->Execute("USE DATABASE clean;");
client->DiscardAll();
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
}
const auto *create_query =
"UNWIND range(1, 50000) as u CREATE (n {string: 'Some longer string'}) RETURN n QUERY MEMORY LIMIT 30MB;";
try {
client->Execute(create_query);
[[maybe_unused]] auto results = client->FetchAll();
if (results->empty()) {
assert(true);
return 0;
}
} catch (const mg::TransientException & /*unused*/) {
spdlog::info("Memgraph is out of memory");
assert(true);
return 0;
}
MG_ASSERT(false, "Query should have failed!");
return 0;
}

View File

@ -0,0 +1,100 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <exception>
#include <future>
#include <thread>
#include <gflags/gflags.h>
#include <iostream>
#include <mgclient.hpp>
#include <utility>
#include "utils/logging.hpp"
#include "utils/timer.hpp"
DEFINE_uint64(bolt_port, 7687, "Bolt port");
DEFINE_bool(multi_db, false, "Run test in multi db environment");
static constexpr int kNumberClients{2};
void Func(std::promise<bool> promise) {
auto client =
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
if (!client) {
LOG_FATAL("Failed to connect!");
}
const auto *create_query =
"FOREACH(i in range(1, 300000) | CREATE (n: Node {string: 'Some longer string'})) QUERY MEMORY LIMIT 100MB;";
bool err{false};
try {
client->Execute(create_query);
[[maybe_unused]] auto results = client->FetchAll();
if (results->empty()) {
err = true;
}
} catch (const std::exception &e) {
spdlog::info("Good: Exception occured", e.what());
err = true;
}
promise.set_value_at_thread_exit(err);
}
int main(int argc, char **argv) {
google::SetUsageMessage("Memgraph E2E Memory Control");
gflags::ParseCommandLineFlags(&argc, &argv, true);
memgraph::logging::RedirectToStderr();
mg::Client::Init();
{
auto client =
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
if (!client) {
LOG_FATAL("Failed to connect!");
}
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
if (FLAGS_multi_db) {
client->Execute("CREATE DATABASE clean;");
client->DiscardAll();
client->Execute("USE DATABASE clean;");
client->DiscardAll();
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
}
}
std::vector<std::promise<bool>> my_promises;
std::vector<std::future<bool>> my_futures;
for (int i = 0; i < 4; i++) {
my_promises.push_back(std::promise<bool>());
my_futures.emplace_back(my_promises.back().get_future());
}
std::vector<std::thread> my_threads;
for (int i = 0; i < kNumberClients; i++) {
my_threads.emplace_back(Func, std::move(my_promises[i]));
}
for (int i = 0; i < kNumberClients; i++) {
auto value = my_futures[i].get();
MG_ASSERT(value, "Error should have happend in thread");
}
for (int i = 0; i < kNumberClients; i++) {
my_threads[i].join();
}
return 0;
}

View File

@ -0,0 +1,66 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <gflags/gflags.h>
#include <algorithm>
#include <exception>
#include <ios>
#include <iostream>
#include <mgclient.hpp>
#include "utils/logging.hpp"
#include "utils/timer.hpp"
DEFINE_uint64(bolt_port, 7687, "Bolt port");
DEFINE_uint64(timeout, 120, "Timeout seconds");
DEFINE_bool(multi_db, false, "Run test in multi db environment");
int main(int argc, char **argv) {
google::SetUsageMessage("Memgraph E2E Query Memory Limit In Multi-Thread For Global Allocators");
gflags::ParseCommandLineFlags(&argc, &argv, true);
memgraph::logging::RedirectToStderr();
mg::Client::Init();
auto client =
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
if (!client) {
LOG_FATAL("Failed to connect!");
}
if (FLAGS_multi_db) {
client->Execute("CREATE DATABASE clean;");
client->DiscardAll();
client->Execute("USE DATABASE clean;");
client->DiscardAll();
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
}
MG_ASSERT(
client->Execute("CALL libquery_memory_limit_proc.regular() YIELD allocated RETURN "
"allocated QUERY MEMORY LIMIT 250MB"));
bool error{false};
try {
auto result_rows = client->FetchAll();
if (result_rows) {
auto row = *result_rows->begin();
error = row[0].ValueBool() == false;
}
} catch (const std::exception &e) {
error = true;
}
MG_ASSERT(error, "Error should have happend");
return 0;
}

View File

@ -0,0 +1,66 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <gflags/gflags.h>
#include <algorithm>
#include <exception>
#include <ios>
#include <iostream>
#include <mgclient.hpp>
#include "utils/logging.hpp"
#include "utils/timer.hpp"
DEFINE_uint64(bolt_port, 7687, "Bolt port");
DEFINE_uint64(timeout, 120, "Timeout seconds");
DEFINE_bool(multi_db, false, "Run test in multi db environment");
int main(int argc, char **argv) {
google::SetUsageMessage("Memgraph E2E Query Memory Limit In Multi-Thread For Global Allocators");
gflags::ParseCommandLineFlags(&argc, &argv, true);
memgraph::logging::RedirectToStderr();
mg::Client::Init();
auto client =
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
if (!client) {
LOG_FATAL("Failed to connect!");
}
if (FLAGS_multi_db) {
client->Execute("CREATE DATABASE clean;");
client->DiscardAll();
client->Execute("USE DATABASE clean;");
client->DiscardAll();
client->Execute("MATCH (n) DETACH DELETE n;");
client->DiscardAll();
}
MG_ASSERT(
client->Execute("CALL libquery_memory_limit_proc_multi_thread.dual_thread() YIELD allocated_all RETURN "
"allocated_all QUERY MEMORY LIMIT 500MB"));
bool error{false};
try {
auto result_rows = client->FetchAll();
if (result_rows) {
auto row = *result_rows->begin();
error = row[0].ValueBool() == false;
}
} catch (const std::exception &e) {
error = true;
}
MG_ASSERT(error, "Error should have happend");
return 0;
}

View File

@ -23,6 +23,20 @@ disk_cluster: &disk_cluster
- "STORAGE MODE ON_DISK_TRANSACTIONAL"
validation_queries: []
args_query_limit: &args_query_limit
- "--bolt-port"
- *bolt_port
- "--storage-gc-cycle-sec=180"
- "--log-level=TRACE"
in_memory_query_limit_cluster: &in_memory_query_limit_cluster
cluster:
main:
args: *args_query_limit
log_file: "memory-e2e.log"
setup_queries: []
validation_queries: []
args_450_MiB_limit: &args_450_MiB_limit
- "--bolt-port"
- *bolt_port
@ -95,6 +109,27 @@ workloads:
proc: "tests/e2e/memory/procedures/"
<<: *disk_cluster
- name: "Memory control query limit proc"
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_proc"
proc: "tests/e2e/memory/procedures/"
args: ["--bolt-port", *bolt_port]
<<: *in_memory_query_limit_cluster
- name: "Memory control query limit proc multi thread"
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_proc_multi_thread"
args: ["--bolt-port", *bolt_port, "--timeout", "180"]
proc: "tests/e2e/memory/procedures/"
<<: *in_memory_query_limit_cluster
- name: "Memory control query limit create"
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_create"
args: ["--bolt-port", *bolt_port]
<<: *in_memory_query_limit_cluster
- name: "Memory control query limit create multi thread"
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_query_alloc_create_multi_thread"
args: ["--bolt-port", *bolt_port]
<<: *in_memory_query_limit_cluster
- name: "Memory control for detach delete"
binary: "tests/e2e/memory/memgraph__e2e__memory__limit_delete"
args: ["--bolt-port", *bolt_port]

View File

@ -24,6 +24,7 @@ import time
DEFAULT_DB = "memgraph"
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
SIGNAL_SIGTERM = 15
QUERIES = [
("MATCH (n) DELETE n", {}),
@ -92,9 +93,13 @@ def execute_test(memgraph_binary, tester_binary):
# Register cleanup function
@atexit.register
def cleanup():
if memgraph.poll() is None:
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
def execute_queries(queries):
for db, query, params in queries:
@ -122,10 +127,12 @@ def execute_test(memgraph_binary, tester_binary):
execute_queries(mt_queries3)
print("\033[1;36m~~ Finished query execution on clean database ~~\033[0m\n")
# Shutdown the memgraph binary
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
# Verify the written log
print("\033[1;36m~~ Starting log verification ~~\033[0m")

View File

@ -21,6 +21,7 @@ import time
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
SIGNAL_SIGTERM = 15
# When you create a new permission just add a testcase to this list (a tuple
# of query, touple of required permissions) and the test will automatically
@ -166,8 +167,12 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
@atexit.register
def cleanup():
if memgraph.poll() is None:
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
# Prepare the multi database environment
execute_admin_queries(
@ -327,8 +332,12 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
print("\033[1;36m~~ Finished checking connections and database switching ~~\033[0m\n")
# Shutdown the memgraph binary
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
if __name__ == "__main__":

View File

@ -20,7 +20,6 @@ import sys
import tempfile
import time
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
TESTS_DIR = os.path.join(SCRIPT_DIR, "tests")
@ -31,6 +30,8 @@ WAL_FILE_NAME = "wal.bin"
DUMP_SNAPSHOT_FILE_NAME = "expected_snapshot.cypher"
DUMP_WAL_FILE_NAME = "expected_wal.cypher"
SIGNAL_SIGTERM = 15
def wait_for_server(port, delay=0.1):
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
@ -40,7 +41,7 @@ def wait_for_server(port, delay=0.1):
def sorted_content(file_path):
with open(file_path, 'r') as fin:
with open(file_path, "r") as fin:
return sorted(list(map(lambda x: x.strip(), fin.readlines())))
@ -52,32 +53,27 @@ def list_to_string(data):
return ret
def execute_test(
memgraph_binary,
dump_binary,
test_directory,
test_type,
write_expected):
assert test_type in ["SNAPSHOT", "WAL"], \
"Test type should be either 'SNAPSHOT' or 'WAL'."
print("\033[1;36m~~ Executing test {} ({}) ~~\033[0m"
.format(os.path.relpath(test_directory, TESTS_DIR), test_type))
def execute_test(memgraph_binary, dump_binary, test_directory, test_type, write_expected):
assert test_type in ["SNAPSHOT", "WAL"], "Test type should be either 'SNAPSHOT' or 'WAL'."
print("\033[1;36m~~ Executing test {} ({}) ~~\033[0m".format(os.path.relpath(test_directory, TESTS_DIR), test_type))
working_data_directory = tempfile.TemporaryDirectory()
if test_type == "SNAPSHOT":
snapshots_dir = os.path.join(working_data_directory.name, "snapshots")
os.makedirs(snapshots_dir)
shutil.copy(os.path.join(test_directory, SNAPSHOT_FILE_NAME),
snapshots_dir)
shutil.copy(os.path.join(test_directory, SNAPSHOT_FILE_NAME), snapshots_dir)
else:
wal_dir = os.path.join(working_data_directory.name, "wal")
os.makedirs(wal_dir)
shutil.copy(os.path.join(test_directory, WAL_FILE_NAME), wal_dir)
memgraph_args = [memgraph_binary,
memgraph_args = [
memgraph_binary,
"--storage-recover-on-startup",
"--storage-properties-on-edges",
"--data-directory", working_data_directory.name]
"--data-directory",
working_data_directory.name,
]
# Start the memgraph binary
memgraph = subprocess.Popen(memgraph_args)
@ -89,8 +85,12 @@ def execute_test(
@atexit.register
def cleanup():
if memgraph.poll() is None:
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
# Execute `database dump`
dump_output_file = tempfile.NamedTemporaryFile()
@ -98,28 +98,31 @@ def execute_test(
subprocess.run(dump_args, stdout=dump_output_file, check=True)
# Shutdown the memgraph binary
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
dump_file_name = DUMP_SNAPSHOT_FILE_NAME if test_type == "SNAPSHOT" else DUMP_WAL_FILE_NAME
if write_expected:
with open(dump_output_file.name, 'r') as dump:
with open(dump_output_file.name, "r") as dump:
queries_got = dump.readlines()
# Write dump files
expected_dump_file = os.path.join(test_directory, dump_file_name)
with open(expected_dump_file, 'w') as expected:
with open(expected_dump_file, "w") as expected:
expected.writelines(queries_got)
else:
# Compare dump files
expected_dump_file = os.path.join(test_directory, dump_file_name)
assert os.path.exists(expected_dump_file), \
"Could not find expected dump path {}".format(expected_dump_file)
assert os.path.exists(expected_dump_file), "Could not find expected dump path {}".format(expected_dump_file)
queries_got = sorted_content(dump_output_file.name)
queries_expected = sorted_content(expected_dump_file)
assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" \
"{}".format(list_to_string(queries_got),
list_to_string(queries_expected))
assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" "{}".format(
list_to_string(queries_got), list_to_string(queries_expected)
)
print("\033[1;32m~~ Test successful ~~\033[0m\n")
@ -141,15 +144,17 @@ def find_test_directories(directory):
continue
snapshot_file = os.path.join(test_dir_path, SNAPSHOT_FILE_NAME)
wal_file = os.path.join(test_dir_path, WAL_FILE_NAME)
dump_snapshot_file = os.path.join(
test_dir_path, DUMP_SNAPSHOT_FILE_NAME)
dump_snapshot_file = os.path.join(test_dir_path, DUMP_SNAPSHOT_FILE_NAME)
dump_wal_file = os.path.join(test_dir_path, DUMP_WAL_FILE_NAME)
if (os.path.isfile(snapshot_file) and os.path.isfile(dump_snapshot_file)
and os.path.isfile(wal_file) and os.path.isfile(dump_wal_file)):
if (
os.path.isfile(snapshot_file)
and os.path.isfile(dump_snapshot_file)
and os.path.isfile(wal_file)
and os.path.isfile(dump_wal_file)
):
test_dirs.append(test_dir_path)
else:
raise Exception("Missing data in test directory '{}'"
.format(test_dir_path))
raise Exception("Missing data in test directory '{}'".format(test_dir_path))
return test_dirs
@ -161,26 +166,15 @@ if __name__ == "__main__":
parser.add_argument("--memgraph", default=memgraph_binary)
parser.add_argument("--dump", default=dump_binary)
parser.add_argument(
'--write-expected',
action='store_true',
help='Overwrite the expected cypher with results from current run')
"--write-expected", action="store_true", help="Overwrite the expected cypher with results from current run"
)
args = parser.parse_args()
test_directories = find_test_directories(TESTS_DIR)
assert len(test_directories) > 0, "No tests have been found!"
for test_directory in test_directories:
execute_test(
args.memgraph,
args.dump,
test_directory,
"SNAPSHOT",
args.write_expected)
execute_test(
args.memgraph,
args.dump,
test_directory,
"WAL",
args.write_expected)
execute_test(args.memgraph, args.dump, test_directory, "SNAPSHOT", args.write_expected)
execute_test(args.memgraph, args.dump, test_directory, "WAL", args.write_expected)
sys.exit(0)

View File

@ -22,6 +22,7 @@ from typing import List
SCRIPT_DIR = Path(__file__).absolute()
PROJECT_DIR = SCRIPT_DIR.parents[3]
SIGNAL_SIGTERM = 15
def wait_for_server(port, delay=0.1):
@ -68,8 +69,12 @@ def execute_with_user(queries):
def cleanup(memgraph):
if memgraph.poll() is None:
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
def execute_without_user(queries, should_fail=False, failure_message="", check_failure=True):

View File

@ -24,6 +24,7 @@ SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
UNAUTHORIZED_ERROR = r"^You are not authorized to execute this query.*?Please contact your database administrator\."
SIGNAL_SIGTERM = 15
def wait_for_server(port, delay=0.1):
@ -80,8 +81,12 @@ def execute_test(memgraph_binary: str, tester_binary: str, filtering_binary: str
@atexit.register
def cleanup():
if memgraph.poll() is None:
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
# Prepare all users
def setup_user():
@ -130,8 +135,12 @@ def execute_test(memgraph_binary: str, tester_binary: str, filtering_binary: str
print("\033[1;36m~~ Finished edge filtering test ~~\033[0m\n")
# Shutdown the memgraph binary
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
if __name__ == "__main__":

View File

@ -21,6 +21,7 @@ from typing import List
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
SIGNAL_SIGTERM = 15
def wait_for_server(port: int, delay: float = 0.1) -> float:
@ -86,8 +87,12 @@ def execute_without_user(
def cleanup(memgraph: subprocess):
if memgraph.poll() is None:
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False
time.sleep(1)
def test_without_any_files(tester_binary: str, memgraph_args: List[str]):

View File

@ -23,6 +23,8 @@ import time
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
SIGNAL_SIGTERM = 15
CONFIG_TEMPLATE = """
server:
host: "127.0.0.1"
@ -52,8 +54,7 @@ def wait_for_server(port, delay=0.1):
time.sleep(delay)
def execute_tester(binary, queries, username="", password="",
auth_should_fail=False, query_should_fail=False):
def execute_tester(binary, queries, username="", password="", auth_should_fail=False, query_should_fail=False):
if password == "":
password = username
args = [binary, "--username", username, "--password", password]
@ -76,18 +77,14 @@ class Memgraph:
def start(self, **kwargs):
self.stop()
self._storage_directory = tempfile.TemporaryDirectory()
self._auth_module = os.path.join(self._storage_directory.name,
"ldap.py")
self._auth_config = os.path.join(self._storage_directory.name,
"ldap.yaml")
script_file = os.path.join(PROJECT_DIR, "src", "auth",
"reference_modules", "ldap.py")
self._auth_module = os.path.join(self._storage_directory.name, "ldap.py")
self._auth_config = os.path.join(self._storage_directory.name, "ldap.yaml")
script_file = os.path.join(PROJECT_DIR, "src", "auth", "reference_modules", "ldap.py")
virtualenv_bin = os.path.join(SCRIPT_DIR, "ve3", "bin", "python3")
with open(script_file) as fin:
data = fin.read()
data = data.replace("/usr/bin/python3", virtualenv_bin)
data = data.replace("/etc/memgraph/auth/ldap.yaml",
self._auth_config)
data = data.replace("/etc/memgraph/auth/ldap.yaml", self._auth_config)
with open(self._auth_module, "w") as fout:
fout.write(data)
os.chmod(self._auth_module, stat.S_IRWXU | stat.S_IRWXG)
@ -106,10 +103,13 @@ class Memgraph:
}
with open(self._auth_config, "w") as f:
f.write(CONFIG_TEMPLATE.format(**config))
args = [self._binary,
"--data-directory", self._storage_directory.name,
args = [
self._binary,
"--data-directory",
self._storage_directory.name,
"--auth-module-executable",
kwargs.pop("module_executable", self._auth_module)]
kwargs.pop("module_executable", self._auth_module),
]
for key, value in kwargs.items():
ldap_key = "--auth-module-" + key.replace("_", "-")
if isinstance(value, bool):
@ -119,26 +119,27 @@ class Memgraph:
args.append(value)
self._process = subprocess.Popen(args)
time.sleep(0.1)
assert self._process.poll() is None, "Memgraph process died " \
"prematurely!"
assert self._process.poll() is None, "Memgraph process died " "prematurely!"
wait_for_server(7687)
def stop(self, check=True):
if self._process is None:
return 0
self._process.terminate()
exitcode = self._process.wait()
self._process = None
pid = self._process.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
if check:
assert exitcode == 0, "Memgraph process didn't exit cleanly!"
return exitcode
assert False
return -1
time.sleep(1)
return 0
def initialize_test(memgraph, tester_binary, **kwargs):
memgraph.start(module_executable="")
execute_tester(tester_binary,
["CREATE USER root", "GRANT ALL PRIVILEGES TO root"])
execute_tester(tester_binary, ["CREATE USER root", "GRANT ALL PRIVILEGES TO root"])
check_login = kwargs.pop("check_login", True)
memgraph.restart(**kwargs)
if check_login:
@ -170,18 +171,15 @@ def test_role_mapping(memgraph, tester_binary):
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, [], "alice")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
query_should_fail=True)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
execute_tester(tester_binary, [], "bob")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob",
query_should_fail=True)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob", query_should_fail=True)
execute_tester(tester_binary, [], "carol")
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol",
query_should_fail=True)
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", query_should_fail=True)
execute_tester(tester_binary, ["GRANT CREATE TO admin"], "root")
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol")
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "dave")
@ -192,15 +190,13 @@ def test_role_mapping(memgraph, tester_binary):
def test_role_removal(memgraph, tester_binary):
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, [], "alice")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
query_should_fail=True)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.restart(manage_roles=False)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
execute_tester(tester_binary, ["CLEAR ROLE FOR alice"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
query_should_fail=True)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
memgraph.stop()
@ -229,28 +225,22 @@ def test_user_is_role(memgraph, tester_binary):
def test_user_permissions_persistancy(memgraph, tester_binary):
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary,
["CREATE USER alice", "GRANT MATCH TO alice"], "root")
execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop()
def test_role_permissions_persistancy(memgraph, tester_binary):
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary,
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
"root")
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop()
def test_only_authentication(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, manage_roles=False)
execute_tester(tester_binary,
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
"root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
query_should_fail=True)
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
memgraph.stop()
@ -267,22 +257,16 @@ def test_wrong_suffix(memgraph, tester_binary):
def test_suffix_with_spaces(memgraph, tester_binary):
initialize_test(memgraph, tester_binary,
suffix=", ou= people, dc = memgraph, dc = com")
execute_tester(tester_binary,
["CREATE USER alice", "GRANT MATCH TO alice"], "root")
initialize_test(memgraph, tester_binary, suffix=", ou= people, dc = memgraph, dc = com")
execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop()
def test_role_mapping_wrong_root_dn(memgraph, tester_binary):
initialize_test(memgraph, tester_binary,
root_dn="ou=invalid,dc=memgraph,dc=com")
execute_tester(tester_binary,
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
"root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
query_should_fail=True)
initialize_test(memgraph, tester_binary, root_dn="ou=invalid,dc=memgraph,dc=com")
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
memgraph.restart()
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop()
@ -290,11 +274,8 @@ def test_role_mapping_wrong_root_dn(memgraph, tester_binary):
def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, root_objectclass="person")
execute_tester(tester_binary,
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
"root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
query_should_fail=True)
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
memgraph.restart()
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop()
@ -302,11 +283,8 @@ def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary):
def test_role_mapping_wrong_user_attribute(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, user_attribute="cn")
execute_tester(tester_binary,
["CREATE ROLE moderator", "GRANT MATCH TO moderator"],
"root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice",
query_should_fail=True)
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
memgraph.restart()
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop()
@ -314,8 +292,7 @@ def test_role_mapping_wrong_user_attribute(memgraph, tester_binary):
def test_wrong_password(memgraph, tester_binary):
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, [], "root", password="sudo",
auth_should_fail=True)
execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True)
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
memgraph.stop()
@ -326,12 +303,10 @@ def test_password_persistancy(memgraph, tester_binary):
execute_tester(tester_binary, ["SHOW USERS"], "root", password="sudo")
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
memgraph.restart()
execute_tester(tester_binary, [], "root", password="sudo",
auth_should_fail=True)
execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True)
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
memgraph.restart(module_executable="")
execute_tester(tester_binary, [], "root", password="sudo",
auth_should_fail=True)
execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True)
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
memgraph.stop()
@ -339,33 +314,25 @@ def test_password_persistancy(memgraph, tester_binary):
def test_user_multiple_roles(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, check_login=False)
memgraph.restart()
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve",
query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root",
query_should_fail=True)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
memgraph.restart(manage_roles=False)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve",
query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root",
query_should_fail=True)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
memgraph.restart(manage_roles=False, root_dn="")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve",
query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root",
query_should_fail=True)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
memgraph.stop()
def test_starttls_failure(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, encryption="starttls",
check_login=False)
initialize_test(memgraph, tester_binary, encryption="starttls", check_login=False)
execute_tester(tester_binary, [], "root", auth_should_fail=True)
memgraph.stop()
def test_ssl_failure(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, encryption="ssl",
check_login=False)
initialize_test(memgraph, tester_binary, encryption="ssl", check_login=False)
execute_tester(tester_binary, [], "root", auth_should_fail=True)
memgraph.stop()
@ -375,22 +342,19 @@ def test_ssl_failure(memgraph, tester_binary):
if __name__ == "__main__":
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
tester_binary = os.path.join(PROJECT_DIR, "build", "tests",
"integration", "ldap", "tester")
tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "ldap", "tester")
parser = argparse.ArgumentParser()
parser.add_argument("--memgraph", default=memgraph_binary)
parser.add_argument("--tester", default=tester_binary)
parser.add_argument("--openldap-dir",
default=os.path.join(SCRIPT_DIR, "openldap-2.4.47"))
parser.add_argument("--openldap-dir", default=os.path.join(SCRIPT_DIR, "openldap-2.4.47"))
args = parser.parse_args()
# Setup Memgraph handler
memgraph = Memgraph(args.memgraph)
# Start the slapd binary
slapd_args = [os.path.join(args.openldap_dir, "exe", "libexec", "slapd"),
"-h", "ldap://127.0.0.1:1389/", "-d", "0"]
slapd_args = [os.path.join(args.openldap_dir, "exe", "libexec", "slapd"), "-h", "ldap://127.0.0.1:1389/", "-d", "0"]
slapd = subprocess.Popen(slapd_args)
time.sleep(0.1)
assert slapd.poll() is None, "slapd process died prematurely!"
@ -409,8 +373,7 @@ if __name__ == "__main__":
if slapd_stat != 0:
print("slapd process didn't exit cleanly!")
assert mg_stat == 0 and slapd_stat == 0, "Some of the processes " \
"(memgraph, slapd) crashed!"
assert mg_stat == 0 and slapd_stat == 0, "Some of the processes " "(memgraph, slapd) crashed!"
# Execute tests
names = sorted(globals().keys())

View File

@ -18,12 +18,13 @@ import subprocess
import sys
import tempfile
import time
import yaml
import yaml
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
BUILD_DIR = os.path.join(BASE_DIR, "build")
SIGNAL_SIGTERM = 15
def wait_for_server(port, delay=0.1):
@ -46,17 +47,14 @@ def list_to_string(data):
def verify_lifetime(memgraph_binary, mg_import_csv_binary):
print("\033[1;36m~~ Verifying that mg_import_csv can't be started while "
"memgraph is running ~~\033[0m")
print("\033[1;36m~~ Verifying that mg_import_csv can't be started while " "memgraph is running ~~\033[0m")
storage_directory = tempfile.TemporaryDirectory()
# Generate common args
common_args = ["--data-directory", storage_directory.name,
"--storage-properties-on-edges=false"]
common_args = ["--data-directory", storage_directory.name, "--storage-properties-on-edges=false"]
# Start the memgraph binary
memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + \
common_args
memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + common_args
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
time.sleep(0.1)
assert memgraph.poll() is None, "Memgraph process died prematurely!"
@ -66,47 +64,52 @@ def verify_lifetime(memgraph_binary, mg_import_csv_binary):
@atexit.register
def cleanup():
if memgraph.poll() is None:
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
# Execute mg_import_csv.
mg_import_csv_args = [mg_import_csv_binary, "--nodes", "/dev/null"] + \
common_args
mg_import_csv_args = [mg_import_csv_binary, "--nodes", "/dev/null"] + common_args
ret = subprocess.run(mg_import_csv_args)
# Check the return code
if ret.returncode == 0:
raise Exception(
"The importer was able to run while memgraph was running!")
raise Exception("The importer was able to run while memgraph was running!")
# Shutdown the memgraph binary
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
print("\033[1;32m~~ Test successful ~~\033[0m\n")
def execute_test(name, test_path, test_config, memgraph_binary,
mg_import_csv_binary, tester_binary, write_expected):
def execute_test(name, test_path, test_config, memgraph_binary, mg_import_csv_binary, tester_binary, write_expected):
print("\033[1;36m~~ Executing test", name, "~~\033[0m")
storage_directory = tempfile.TemporaryDirectory()
# Verify test configuration
if ("import_should_fail" not in test_config and
"expected" not in test_config) or \
("import_should_fail" in test_config and
"expected" in test_config):
raise Exception("The test should specify either 'import_should_fail' "
"or 'expected'!")
if ("import_should_fail" not in test_config and "expected" not in test_config) or (
"import_should_fail" in test_config and "expected" in test_config
):
raise Exception("The test should specify either 'import_should_fail' " "or 'expected'!")
expected_path = test_config.pop("expected", "")
import_should_fail = test_config.pop("import_should_fail", False)
# Generate common args
properties_on_edges = bool(test_config.pop("properties_on_edges", False))
common_args = ["--data-directory", storage_directory.name,
"--storage-properties-on-edges=" +
str(properties_on_edges).lower()]
common_args = [
"--data-directory",
storage_directory.name,
"--storage-properties-on-edges=" + str(properties_on_edges).lower(),
]
# Generate mg_import_csv args using flags specified in the test
mg_import_csv_args = [mg_import_csv_binary] + common_args
@ -125,19 +128,16 @@ def execute_test(name, test_path, test_config, memgraph_binary,
if import_should_fail:
if ret.returncode == 0:
raise Exception("The import should have failed, but it "
"succeeded instead!")
raise Exception("The import should have failed, but it " "succeeded instead!")
else:
print("\033[1;32m~~ Test successful ~~\033[0m\n")
return
else:
if ret.returncode != 0:
raise Exception("The import should have succeeded, but it "
"failed instead!")
raise Exception("The import should have succeeded, but it " "failed instead!")
# Start the memgraph binary
memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + \
common_args
memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + common_args
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
time.sleep(0.1)
assert memgraph.poll() is None, "Memgraph process died prematurely!"
@ -147,21 +147,29 @@ def execute_test(name, test_path, test_config, memgraph_binary,
@atexit.register
def cleanup():
if memgraph.poll() is None:
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
# Get the contents of the database
queries_got = extract_rows(subprocess.run(
[tester_binary], stdout=subprocess.PIPE,
check=True).stdout.decode("utf-8"))
queries_got = extract_rows(
subprocess.run([tester_binary], stdout=subprocess.PIPE, check=True).stdout.decode("utf-8")
)
# Shutdown the memgraph binary
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
if write_expected:
with open(os.path.join(test_path, expected_path), 'w') as expected:
expected.write('\n'.join(queries_got))
with open(os.path.join(test_path, expected_path), "w") as expected:
expected.write("\n".join(queries_got))
else:
if expected_path:
@ -173,18 +181,16 @@ def execute_test(name, test_path, test_config, memgraph_binary,
# Verify the queries
queries_expected.sort()
queries_got.sort()
assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" \
"{}".format(list_to_string(queries_got),
list_to_string(queries_expected))
assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" "{}".format(
list_to_string(queries_got), list_to_string(queries_expected)
)
print("\033[1;32m~~ Test successful ~~\033[0m\n")
if __name__ == "__main__":
memgraph_binary = os.path.join(BUILD_DIR, "memgraph")
mg_import_csv_binary = os.path.join(
BUILD_DIR, "src", "mg_import_csv")
tester_binary = os.path.join(
BUILD_DIR, "tests", "integration", "mg_import_csv", "tester")
mg_import_csv_binary = os.path.join(BUILD_DIR, "src", "mg_import_csv")
tester_binary = os.path.join(BUILD_DIR, "tests", "integration", "mg_import_csv", "tester")
parser = argparse.ArgumentParser()
parser.add_argument("--memgraph", default=memgraph_binary)
@ -193,7 +199,8 @@ if __name__ == "__main__":
parser.add_argument(
"--write-expected",
action="store_true",
help="Overwrite the expected values with the results of the current run")
help="Overwrite the expected values with the results of the current run",
)
args = parser.parse_args()
# First test whether the CSV importer can be started while the main
@ -211,7 +218,8 @@ if __name__ == "__main__":
testcases = yaml.safe_load(f)
for test_config in testcases:
test_name = name + "/" + test_config.pop("name")
execute_test(test_name, test_path, test_config, args.memgraph,
args.mg_import_csv, args.tester, args.write_expected)
execute_test(
test_name, test_path, test_config, args.memgraph, args.mg_import_csv, args.tester, args.write_expected
)
sys.exit(0)

View File

@ -23,6 +23,7 @@ from typing import List
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
SIGNAL_SIGTERM = 15
def wait_for_server(port: int, delay: float = 0.1) -> float:
@ -91,8 +92,12 @@ def check_config(tester_binary: str, flag: str, value: str) -> None:
def cleanup(memgraph: subprocess):
if memgraph.poll() is None:
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
def run_test(tester_binary: str, memgraph_args: List[str], server_name: str, query_tx: str):

View File

@ -22,6 +22,8 @@ assertion_queries = [
f"MATCH (n)-[e]->(m) WITH count(e) as cnt RETURN assert(cnt={len(edge_queries)});",
]
SIGNAL_SIGTERM = 15
def wait_for_server(port, delay=0.1):
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
@ -40,9 +42,12 @@ def prepare_memgraph(memgraph_args):
def terminate_memgraph(memgraph):
memgraph.terminate()
time.sleep(0.1)
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
def execute_tester(
@ -90,8 +95,12 @@ def execute_test_analytical_mode(memgraph_binary: str, tester_binary: str) -> No
execute_queries(assertion_queries)
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
def execute_test_switch_analytical_transactional(memgraph_binary: str, tester_binary: str) -> None:
@ -135,8 +144,12 @@ def execute_test_switch_analytical_transactional(memgraph_binary: str, tester_bi
execute_queries(assertion_queries)
print("\033[1;36m~~ Terminating memgraph ~~\033[0m\n")
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
def execute_test_switch_transactional_analytical(memgraph_binary: str, tester_binary: str) -> None:
@ -177,8 +190,12 @@ def execute_test_switch_transactional_analytical(memgraph_binary: str, tester_bi
execute_queries(assertion_queries)
print("\033[1;36m~~ Terminating memgraph ~~\033[0m\n")
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
pid = memgraph.pid
try:
os.kill(pid, SIGNAL_SIGTERM)
except os.OSError:
assert False, "Memgraph process didn't exit cleanly!"
time.sleep(1)
if __name__ == "__main__":

View File

@ -11,12 +11,13 @@
import logging
import os
import shutil
import subprocess
import tempfile
import time
from argparse import ArgumentParser
from collections import defaultdict
import tempfile
import shutil
import time
from common import get_absolute_path, set_cpus
try:
@ -36,13 +37,12 @@ class Memgraph:
"""
Knows how to start and stop memgraph.
"""
def __init__(self, args, num_workers):
self.log = logging.getLogger("MemgraphRunner")
argp = ArgumentParser("MemgraphArgumentParser")
argp.add_argument("--runner-bin",
default=get_absolute_path("memgraph", "build"))
argp.add_argument("--port", default="7687",
help="Database and client port")
argp.add_argument("--runner-bin", default=get_absolute_path("memgraph", "build"))
argp.add_argument("--port", default="7687", help="Database and client port")
argp.add_argument("--data-directory", default=None)
argp.add_argument("--storage-snapshot-on-exit", action="store_true")
argp.add_argument("--storage-recover-on-startup", action="store_true")
@ -55,8 +55,7 @@ class Memgraph:
def start(self):
self.log.info("start")
database_args = ["--bolt-port", self.args.port,
"--query-execution-timeout-sec", "0"]
database_args = ["--bolt-port", self.args.port, "--query-execution-timeout-sec", "0"]
if self.num_workers:
database_args += ["--bolt-num-workers", str(self.num_workers)]
if self.args.data_directory:
@ -82,15 +81,13 @@ class Neo:
"""
Knows how to start and stop neo4j.
"""
def __init__(self, args, config):
self.log = logging.getLogger("NeoRunner")
argp = ArgumentParser("NeoArgumentParser")
argp.add_argument("--runner-bin", default=get_absolute_path(
"neo4j/bin/neo4j", "libs"))
argp.add_argument("--port", default="7687",
help="Database and client port")
argp.add_argument("--http-port", default="7474",
help="Database and client port")
argp.add_argument("--runner-bin", default=get_absolute_path("neo4j/bin/neo4j", "libs"))
argp.add_argument("--port", default="7687", help="Database and client port")
argp.add_argument("--http-port", default="7474", help="Database and client port")
self.log.info("Initializing Runner with arguments %r", args)
self.args, _ = argp.parse_known_args(args)
self.config = config
@ -105,24 +102,22 @@ class Neo:
self.neo4j_home_path = tempfile.mkdtemp(dir="/dev/shm")
try:
os.symlink(os.path.join(get_absolute_path("neo4j", "libs"), "lib"),
os.path.join(self.neo4j_home_path, "lib"))
os.symlink(
os.path.join(get_absolute_path("neo4j", "libs"), "lib"), os.path.join(self.neo4j_home_path, "lib")
)
neo4j_conf_dir = os.path.join(self.neo4j_home_path, "conf")
neo4j_conf_file = os.path.join(neo4j_conf_dir, "neo4j.conf")
os.mkdir(neo4j_conf_dir)
shutil.copyfile(self.config, neo4j_conf_file)
with open(neo4j_conf_file, "a") as f:
f.write("\ndbms.connector.bolt.listen_address=:" +
self.args.port + "\n")
f.write("\ndbms.connector.http.listen_address=:" +
self.args.http_port + "\n")
f.write("\ndbms.connector.bolt.listen_address=:" + self.args.port + "\n")
f.write("\ndbms.connector.http.listen_address=:" + self.args.http_port + "\n")
# environment
cwd = os.path.dirname(self.args.runner_bin)
env = {"NEO4J_HOME": self.neo4j_home_path}
self.database_bin.run(self.args.runner_bin, args=["console"],
env=env, timeout=600, cwd=cwd)
self.database_bin.run(self.args.runner_bin, args=["console"], env=env, timeout=600, cwd=cwd)
except:
shutil.rmtree(self.neo4j_home_path)
raise Exception("Couldn't run Neo4j!")