Improve connection handling in tests/e2e (#1012)

This commit is contained in:
Andi 2023-06-26 22:43:34 +02:00 committed by GitHub
parent 3b781bf525
commit 0f1ca745e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 24 deletions

View File

@ -34,16 +34,16 @@ import atexit
import logging import logging
import os import os
import subprocess import subprocess
from argparse import ArgumentParser import sys
from pathlib import Path
import tempfile import tempfile
import time import time
import sys from argparse import ArgumentParser
from inspect import signature from inspect import signature
from pathlib import Path
import yaml import yaml
from memgraph import MemgraphInstanceRunner
from memgraph import extract_bolt_port from memgraph import MemgraphInstanceRunner, extract_bolt_port
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..")) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", ".."))
@ -104,7 +104,7 @@ def is_port_in_use(port: int) -> bool:
return s.connect_ex(("localhost", port)) == 0 return s.connect_ex(("localhost", port)) == 0
def _start_instance(name, args, log_file, queries, use_ssl, procdir, data_directory): def _start_instance(name, args, log_file, setup_queries, use_ssl, procdir, data_directory):
assert ( assert (
name not in MEMGRAPH_INSTANCES.keys() name not in MEMGRAPH_INSTANCES.keys()
), "If this raises, you are trying to start an instance with the same name than one already running." ), "If this raises, you are trying to start an instance with the same name than one already running."
@ -120,10 +120,7 @@ def _start_instance(name, args, log_file, queries, use_ssl, procdir, data_direct
if len(procdir) != 0: if len(procdir) != 0:
binary_args.append("--query-modules-directory=" + procdir) binary_args.append("--query-modules-directory=" + procdir)
mg_instance.start(args=binary_args) mg_instance.start(args=binary_args, setup_queries=setup_queries)
for query in queries:
mg_instance.query(query)
assert mg_instance.is_running(), "An error occured after starting Memgraph instance: application stopped running." assert mg_instance.is_running(), "An error occured after starting Memgraph instance: application stopped running."

View File

@ -62,19 +62,47 @@ class MemgraphInstanceRunner:
self.binary_path = binary_path self.binary_path = binary_path
self.args = None self.args = None
self.proc_mg = None self.proc_mg = None
self.conn = None
self.ssl = use_ssl self.ssl = use_ssl
def query(self, query): def execute_setup_queries(self, setup_queries):
cursor = self.conn.cursor() if setup_queries is None:
cursor.execute(query) return
return cursor.fetchall() # An assumption being database instance is fresh, no need for the auth.
conn = mgclient.connect(host=self.host, port=self.bolt_port, sslmode=self.ssl)
conn.autocommit = True
cursor = conn.cursor()
for query in setup_queries:
cursor.execute(query)
cursor.close()
conn.close()
def start(self, restart=False, args=[]): # NOTE: Both query and get_connection may esablish new connection -> auth
# details required -> username/password should be optional arguments.
def query(self, query, conn=None, username="", password=""):
new_conn = conn is None
if new_conn:
conn = self.get_connection(username, password)
cursor = conn.cursor()
cursor.execute(query)
data = cursor.fetchall()
cursor.close()
if new_conn:
conn.close()
return data
def get_connection(self, username="", password=""):
conn = mgclient.connect(
host=self.host, port=self.bolt_port, sslmode=self.ssl, username=username, password=password
)
conn.autocommit = True
return conn
def start(self, restart=False, args=None, setup_queries=None):
if not restart and self.is_running(): if not restart and self.is_running():
return return
self.stop() self.stop()
self.args = copy.deepcopy(args) if args is not None:
self.args = copy.deepcopy(args)
self.args = [replace_paths(arg) for arg in self.args] self.args = [replace_paths(arg) for arg in self.args]
args_mg = [ args_mg = [
self.binary_path, self.binary_path,
@ -86,8 +114,7 @@ class MemgraphInstanceRunner:
self.bolt_port = extract_bolt_port(args_mg) self.bolt_port = extract_bolt_port(args_mg)
self.proc_mg = subprocess.Popen(args_mg) self.proc_mg = subprocess.Popen(args_mg)
wait_for_server(self.bolt_port) wait_for_server(self.bolt_port)
self.conn = mgclient.connect(host=self.host, port=self.bolt_port, sslmode=self.ssl) self.execute_setup_queries(setup_queries)
self.conn.autocommit = True
assert self.is_running(), "The Memgraph process died!" assert self.is_running(), "The Memgraph process died!"
def is_running(self): def is_running(self):

View File

@ -16,9 +16,8 @@ import subprocess
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
import yaml
import interactive_mg_runner import interactive_mg_runner
import yaml
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..")) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", ".."))
@ -49,6 +48,7 @@ def run(args):
if args.workload_name is not None and args.workload_name != workload_name: if args.workload_name is not None and args.workload_name != workload_name:
continue continue
log.info("%s STARTED.", workload_name) log.info("%s STARTED.", workload_name)
# Setup. # Setup.
@atexit.register @atexit.register
def cleanup(): def cleanup():
@ -66,10 +66,21 @@ def run(args):
# Validation. # Validation.
if "cluster" in workload: if "cluster" in workload:
for name, config in workload["cluster"].items(): for name, config in workload["cluster"].items():
for validation in config.get("validation_queries", []): mg_instance = interactive_mg_runner.MEMGRAPH_INSTANCES[name]
mg_instance = interactive_mg_runner.MEMGRAPH_INSTANCES[name] # Explicitely check if there are validation queries and skip if
data = mg_instance.query(validation["query"])[0][0] # nothing is to validate. If setup queries are dealing with
# users, any new connection requires auth details.
validation_queries = config.get("validation_queries", [])
if len(validation_queries) == 0:
continue
# NOTE: If the setup quries create users AND there are some
# validation queries, the connection here has to get the right
# username/password.
conn = mg_instance.get_connection()
for validation in validation_queries:
data = mg_instance.query(validation["query"], conn)[0][0]
assert data == validation["expected"] assert data == validation["expected"]
conn.close()
cleanup() cleanup()
log.info("%s PASSED.", workload_name) log.info("%s PASSED.", workload_name)