memgraph/tests/stress/common.py

266 lines
7.6 KiB
Python
Raw Normal View History

# Copyright 2021 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
# -*- coding: utf-8 -*-
"""
Common helpers for writing graph database
integration tests in Python.

Only the Bolt communication protocol is supported.
"""
import contextlib
import os
2023-09-11 00:53:03 +08:00
import time
from argparse import ArgumentParser
from threading import Thread
2023-09-11 00:53:03 +08:00
from neo4j import TRUST_ALL_CERTIFICATES, GraphDatabase
class OutputData:
    """
    Collects test results: named status values and timed measurements.
    """

    def __init__(self):
        # measurement records, each a (name, time, unit) tuple
        self._measurements = []
        # status records, each a (name, status) tuple
        self._statuses = []

    def add_measurement(self, name, time, unit="s"):
        """
        Records a single measurement.
        :param name: str, measurement name
        :param time: float, measured time value
        :param unit: str, unit of the time value ("s" by default)
        """
        record = (name, time, unit)
        self._measurements.append(record)

    def add_status(self, name, status):
        """
        Records a single status data point.
        :param name: str, data point name
        :param status: any printable value
        """
        record = (name, status)
        self._statuses.append(record)

    def dump(self, print_f=print):
        """
        Emits all collected data through the given output function.
        A header line is emitted first, then all statuses, then all
        measurements, each group in insertion order.
        :param print_f: callable consuming one line of output per call;
            defaults to the built-in 'print' function.
        """
        print_f("Output data:")
        for status_name, status_value in self._statuses:
            print_f(" %s: %s" % (status_name, status_value))
        for measure_name, value, measure_unit in self._measurements:
            print_f(" %s: %s%s" % (measure_name, value, measure_unit))
def execute_till_success(session, query, max_retries=1000):
    """
    Executes a query within a Bolt session, retrying until it succeeds.

    Any exception raised while running the query or fetching its results
    counts as a failure and triggers another attempt immediately (there is
    no delay between attempts).

    :param session: active Bolt session used to run the query
    :param query: str, the query to execute
    :param max_retries: int, maximum allowed number of failed attempts
    :return: tuple (results_data_list, number_of_failures, result_summary)
    :raises Exception: once the query has failed ``max_retries`` times; the
        last underlying error is attached as ``__cause__``.
    """
    no_failures = 0
    while True:
        try:
            result = session.run(query)
            data = result.data()
            summary = result.consume()
        except Exception as error:
            no_failures += 1
            if no_failures >= max_retries:
                # Chain the last failure so its type and traceback are not lost.
                raise Exception("Query '%s' failed %d times, aborting" % (query, max_retries)) from error
        else:
            return data, no_failures, summary
def try_execute(session, query: str):
    """
    Executes a query within a Bolt session exactly once, without retries.

    Any exception raised by the session or result object propagates to the
    caller.

    :param session: active Bolt session used to run the query
    :param query: str, the query to execute
    :return: tuple (results_data_list, result_summary)
    """
    result = session.run(query)
    data = result.data()
    summary = result.consume()
    return data, summary
def batch(input, batch_size):
    """
    Lazily splits an iterable into lists of ``batch_size`` elements.

    Works with generators: elements are consumed only as batches are
    requested. The last batch may contain fewer than ``batch_size``
    elements, but never zero.

    :param input: iterable of elements
    :param batch_size: int, number of elements per batch; must be positive
    :return: generator yielding lists of elements
    """
    # `> 0` rather than `> 1`, so single-element batches are accepted,
    # as the message states.
    assert batch_size > 0, "Batch size must be greater than zero"
    current = []
    for element in input:
        current.append(element)
        if len(current) >= batch_size:
            yield current
            current = []
    if current:
        yield current
def render(template, iterable_arguments):
    """
    Lazily formats ``template`` once per item of ``iterable_arguments``.

    Each item is passed as the single positional argument to
    ``template.format``; the formatted strings are yielded in order.
    """
    yield from (template.format(item) for item in iterable_arguments)
def assert_equal(expected, actual, message):
    """
    Compares expected and actual values and aborts the test when they differ.

    Raises AssertionError explicitly instead of using an ``assert``
    statement, so the check still runs when Python is started with ``-O``.

    :param expected: expected value
    :param actual: actual value
    :param message: str, message used when the values differ; must contain
        two placeholders (%s) that receive ``expected`` and ``actual``
    :raises AssertionError: if ``expected == actual`` is false
    """
    if not expected == actual:
        raise AssertionError(message % (expected, actual))
def connection_argument_parser():
    """
    Builds an argument parser preconfigured with the options needed to
    connect to the database: endpoint, credentials and an SSL toggle.
    :return: ArgumentParser instance (arguments are not parsed yet)
    """
    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "--endpoint",
        type=str,
        default="127.0.0.1:7687",
        help="DBMS instance endpoint. Bolt protocol is the only option.",
    )
    credential_options = (
        ("--username", "neo4j", "DBMS instance username."),
        ("--password", "1234", "DBMS instance password."),
    )
    for option, default_value, help_text in credential_options:
        parser.add_argument(option, type=str, default=default_value, help=help_text)
    parser.add_argument("--use-ssl", action="store_true", help="Is SSL enabled?")
    return parser
@contextlib.contextmanager
def bolt_session(url, auth, ssl=False):
    """
    Context manager that opens a Bolt driver and session, then closes both.

    The driver is closed even if creating the session fails.

    :param url: str, e.g. "bolt://127.0.0.1:7687"
    :param auth: auth method, passed directly to the Bolt driver constructor
    :param ssl: bool, whether SSL is enabled
    """
    driver = GraphDatabase.driver(url, auth=auth, encrypted=ssl, trust=TRUST_ALL_CERTIFICATES)
    try:
        session = driver.session()
        try:
            yield session
        finally:
            session.close()
    finally:
        driver.close()
# If you are using sessions with multiprocessing, take a look at SessionCache
# in bipartite for an idea of how to reuse sessions.
def argument_session(args):
    """
    Creates a Bolt session context manager from parsed program arguments.
    :param args: namespace with ``endpoint``, ``username``, ``password`` and
        ``use_ssl`` attributes (e.g. as parsed by connection_argument_parser)
    :return: context manager yielding a Bolt session
    """
    url = "bolt://" + args.endpoint
    credentials = (args.username, args.password)
    return bolt_session(url, credentials, args.use_ssl)
def argument_driver(args):
    """
    Creates a Bolt driver from parsed program arguments.
    The driver is created with trust=TRUST_ALL_CERTIFICATES (presumably
    disabling certificate verification; confirm against the neo4j driver).
    :param args: namespace with ``endpoint``, ``username``, ``password`` and
        ``use_ssl`` attributes
    :return: driver instance; the caller is responsible for closing it
    """
    endpoint_url = "bolt://" + args.endpoint
    return GraphDatabase.driver(
        endpoint_url,
        auth=(args.username, args.password),
        encrypted=args.use_ssl,
        trust=TRUST_ALL_CERTIFICATES,
    )
# This class creates and caches sessions. A session is cached by the args
# used to create it and the pid of the process in which it was created. This
# makes it easy to reuse sessions with Python multiprocessing primitives like pmap.
class SessionCache:
    """
    Process-local cache of (driver, session) pairs keyed by connection args.
    """

    # Maps (args items..., pid) -> (driver, session).
    cache = {}

    @staticmethod
    def argument_session(args):
        """
        Returns the cached session for ``args`` in the current process,
        creating and caching a new driver and session on first use.
        :param args: parsed program arguments (argparse namespace); all of
            its attribute values must be hashable
        :return: Bolt session
        """
        # The pid is part of the key, so a different process (e.g. a forked
        # worker) gets its own driver and session.
        key = tuple(vars(args).items()) + (os.getpid(),)
        cached = SessionCache.cache.get(key)
        if cached is not None:
            return cached[1]
        driver = argument_driver(args)
        try:
            session = driver.session()
        except Exception:
            # Do not leak the driver if the session cannot be opened.
            driver.close()
            raise
        SessionCache.cache[key] = (driver, session)
        return session

    @staticmethod
    def cleanup():
        """
        Closes every cached session and driver and empties the cache, so
        later argument_session calls create fresh sessions instead of
        returning closed ones.
        """
        entries = list(SessionCache.cache.values())
        SessionCache.cache.clear()
        for driver, session in entries:
            try:
                session.close()
            finally:
                # Close the driver even if closing its session failed.
                driver.close()
def periodically_execute(callable, args, interval, daemon=True):
    """
    Repeatedly calls ``callable(*args)`` on a background thread, forever.

    The thread sleeps ``interval`` seconds before each call, so the first
    call happens after one interval. An exception raised by ``callable``
    ends the loop (and the thread).

    :param callable: the callable to call
    :param args: iterable of positional arguments passed to ``callable``
        (None is treated as no arguments)
    :param interval: float, time (in seconds) between two calls
    :param daemon: bool, whether the execution thread should be a daemon
    :return: the started Thread object
    """
    call_args = () if args is None else tuple(args)

    def periodic_call():
        while True:
            time.sleep(interval)
            # The arguments belong to `callable`, not to this loop function.
            callable(*call_args)

    thread = Thread(target=periodic_call, daemon=daemon)
    thread.start()
    return thread