diff --git a/tests/mgbench/benchmark.py b/tests/mgbench/benchmark.py index 5ce715571..15fb67d5f 100755 --- a/tests/mgbench/benchmark.py +++ b/tests/mgbench/benchmark.py @@ -26,6 +26,9 @@ import log import helpers import runners +WITH_FINE_GRAINED_AUTHORIZATION = "with_fine_grained_authorization" +WITHOUT_FINE_GRAINED_AUTHORIZATION = "without_fine_grained_authorization" + def get_queries(gen, count): # Make the generator deterministic. @@ -37,8 +40,7 @@ def get_queries(gen, count): return ret -def match_patterns(dataset, variant, group, test, is_default_variant, - patterns): +def match_patterns(dataset, variant, group, test, is_default_variant, patterns): for pattern in patterns: verdict = [fnmatch.fnmatchcase(dataset, pattern[0])] if pattern[1] != "": @@ -58,7 +60,7 @@ def filter_benchmarks(generators, patterns): pattern = patterns[i].split("/") if len(pattern) > 4 or len(pattern) == 0: raise Exception("Invalid benchmark description '" + pattern + "'!") - pattern.extend(["", "*", "*"][len(pattern) - 1:]) + pattern.extend(["", "*", "*"][len(pattern) - 1 :]) patterns[i] = pattern filtered = [] for dataset in sorted(generators.keys()): @@ -68,8 +70,7 @@ def filter_benchmarks(generators, patterns): current = collections.defaultdict(list) for group in tests: for test_name, test_func in tests[group]: - if match_patterns(dataset, variant, group, test_name, - is_default_variant, patterns): + if match_patterns(dataset, variant, group, test_name, is_default_variant, patterns): current[group].append((test_name, test_func)) if len(current) > 0: filtered.append((generator(variant), dict(current))) @@ -79,43 +80,72 @@ def filter_benchmarks(generators, patterns): # Parse options. parser = argparse.ArgumentParser( description="Memgraph benchmark executor.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument("benchmarks", nargs="*", default="", - help="descriptions of benchmarks that should be run; " - "multiple descriptions can be specified to run multiple " - "benchmarks; the description is specified as " - "dataset/variant/group/test; Unix shell-style wildcards " - "can be used in the descriptions; variant, group and test " - "are optional and they can be left out; the default " - "variant is '' which selects the default dataset variant; " - "the default group is '*' which selects all groups; the " - "default test is '*' which selects all tests") -parser.add_argument("--memgraph-binary", - default=helpers.get_binary_path("memgraph"), - help="Memgraph binary used for benchmarking") -parser.add_argument("--client-binary", - default=helpers.get_binary_path("tests/mgbench/client"), - help="client binary used for benchmarking") -parser.add_argument("--num-workers-for-import", type=int, - default=multiprocessing.cpu_count() // 2, - help="number of workers used to import the dataset") -parser.add_argument("--num-workers-for-benchmark", type=int, - default=1, - help="number of workers used to execute the benchmark") -parser.add_argument("--single-threaded-runtime-sec", type=int, - default=10, - help="single threaded duration of each test") -parser.add_argument("--no-load-query-counts", action="store_true", - help="disable loading of cached query counts") -parser.add_argument("--no-save-query-counts", action="store_true", - help="disable storing of cached query counts") -parser.add_argument("--export-results", default="", - help="file path into which results should be exported") -parser.add_argument("--temporary-directory", default="/tmp", - help="directory path where temporary data 
should " - "be stored") -parser.add_argument("--no-properties-on-edges", action="store_true", - help="disable properties on edges") + formatter_class=argparse.ArgumentDefaultsHelpFormatter, +) +parser.add_argument( + "benchmarks", + nargs="*", + default="", + help="descriptions of benchmarks that should be run; " + "multiple descriptions can be specified to run multiple " + "benchmarks; the description is specified as " + "dataset/variant/group/test; Unix shell-style wildcards " + "can be used in the descriptions; variant, group and test " + "are optional and they can be left out; the default " + "variant is '' which selects the default dataset variant; " + "the default group is '*' which selects all groups; the " + "default test is '*' which selects all tests", +) +parser.add_argument( + "--memgraph-binary", + default=helpers.get_binary_path("memgraph"), + help="Memgraph binary used for benchmarking", +) +parser.add_argument( + "--client-binary", + default=helpers.get_binary_path("tests/mgbench/client"), + help="client binary used for benchmarking", +) +parser.add_argument( + "--num-workers-for-import", + type=int, + default=multiprocessing.cpu_count() // 2, + help="number of workers used to import the dataset", +) +parser.add_argument( + "--num-workers-for-benchmark", + type=int, + default=1, + help="number of workers used to execute the benchmark", +) +parser.add_argument( + "--single-threaded-runtime-sec", + type=int, + default=10, + help="single threaded duration of each test", +) +parser.add_argument( + "--no-load-query-counts", + action="store_true", + help="disable loading of cached query counts", +) +parser.add_argument( + "--no-save-query-counts", + action="store_true", + help="disable storing of cached query counts", +) +parser.add_argument( + "--export-results", + default="", + help="file path into which results should be exported", +) +parser.add_argument( + "--temporary-directory", + default="/tmp", + help="directory path where temporary data should " "be stored", +) +parser.add_argument("--no-properties-on-edges", action="store_true", help="disable properties on edges") +parser.add_argument("--bolt-port", default=7687, help="memgraph bolt port") args = parser.parse_args() # Detect available datasets. @@ -124,8 +154,7 @@ for key in dir(datasets): if key.startswith("_"): continue dataset = getattr(datasets, key) - if not inspect.isclass(dataset) or dataset == datasets.Dataset or \ - not issubclass(dataset, datasets.Dataset): + if not inspect.isclass(dataset) or dataset == datasets.Dataset or not issubclass(dataset, datasets.Dataset): continue tests = collections.defaultdict(list) for funcname in dir(dataset): @@ -135,8 +164,9 @@ for key in dir(datasets): tests[group].append((test, funcname)) generators[dataset.NAME] = (dataset, dict(tests)) if dataset.PROPERTIES_ON_EDGES and args.no_properties_on_edges: - raise Exception("The \"{}\" dataset requires properties on edges, " - "but you have disabled them!".format(dataset.NAME)) + raise Exception( + 'The "{}" dataset requires properties on edges, ' "but you have disabled them!".format(dataset.NAME) + ) # List datasets if there is no specified dataset. 
if len(args.benchmarks) == 0: @@ -144,8 +174,11 @@ if len(args.benchmarks) == 0: for name in sorted(generators.keys()): print("Dataset:", name) dataset, tests = generators[name] - print(" Variants:", ", ".join(dataset.VARIANTS), - "(default: " + dataset.DEFAULT_VARIANT + ")") + print( + " Variants:", + ", ".join(dataset.VARIANTS), + "(default: " + dataset.DEFAULT_VARIANT + ")", + ) for group in sorted(tests.keys()): print(" Group:", group) for test_name, test_func in tests[group]: @@ -162,34 +195,45 @@ results = helpers.RecursiveDict() # Filter out the generators. benchmarks = filter_benchmarks(generators, args.benchmarks) - # Run all specified benchmarks. for dataset, tests in benchmarks: - log.init("Preparing", dataset.NAME + "/" + dataset.get_variant(), - "dataset") - dataset.prepare(cache.cache_directory("datasets", dataset.NAME, - dataset.get_variant())) + log.init("Preparing", dataset.NAME + "/" + dataset.get_variant(), "dataset") + dataset.prepare(cache.cache_directory("datasets", dataset.NAME, dataset.get_variant())) # Prepare runners and import the dataset. - memgraph = runners.Memgraph(args.memgraph_binary, args.temporary_directory, - not args.no_properties_on_edges) - client = runners.Client(args.client_binary, args.temporary_directory) + memgraph = runners.Memgraph( + args.memgraph_binary, + args.temporary_directory, + not args.no_properties_on_edges, + args.bolt_port, + ) + client = runners.Client(args.client_binary, args.temporary_directory, args.bolt_port) memgraph.start_preparation() - ret = client.execute(file_path=dataset.get_file(), - num_workers=args.num_workers_for_import) + ret = client.execute(file_path=dataset.get_file(), num_workers=args.num_workers_for_import) usage = memgraph.stop() # Display import statistics. print() for row in ret: - print("Executed", row["count"], "queries in", row["duration"], - "seconds using", row["num_workers"], - "workers with a total throughput of", row["throughput"], - "queries/second.") + print( + "Executed", + row["count"], + "queries in", + row["duration"], + "seconds using", + row["num_workers"], + "workers with a total throughput of", + row["throughput"], + "queries/second.", + ) print() - print("The database used", usage["cpu"], - "seconds of CPU time and peaked at", - usage["memory"] / 1024 / 1024, "MiB of RAM.") + print( + "The database used", + usage["cpu"], + "seconds of CPU time and peaked at", + usage["memory"] / 1024 / 1024, + "MiB of RAM.", + ) # Save import results. import_key = [dataset.NAME, dataset.get_variant(), "__import__"] @@ -198,87 +242,128 @@ for dataset, tests in benchmarks: # TODO: cache import data # Run all benchmarks in all available groups. - for group in sorted(tests.keys()): - for test, funcname in tests[group]: - log.info("Running test:", "{}/{}".format(group, test)) - func = getattr(dataset, funcname) - # Get number of queries to execute. - # TODO: implement minimum number of queries, `max(10, num_workers)` - config_key = [dataset.NAME, dataset.get_variant(), group, test] - cached_count = config.get_value(*config_key) - if cached_count is None: - print("Determining the number of queries necessary for", - args.single_threaded_runtime_sec, - "seconds of single-threaded runtime...") - # First run to prime the query caches. 
+ for with_fine_grained_authorization in [False, True]: + if with_fine_grained_authorization: + memgraph.start_preparation() + client.execute(file_path=dataset.get_file(), num_workers=args.num_workers_for_import) + client.execute( + queries=[ + ("CREATE USER user IDENTIFIED BY 'test';", {}), + ("GRANT ALL PRIVILEGES TO user;", {}), + ("GRANT CREATE_DELETE ON EDGE_TYPES * TO user;", {}), + ("GRANT CREATE_DELETE ON LABELS * TO user;", {}), + ] + ) + client = runners.Client( + args.client_binary, + args.temporary_directory, + args.bolt_port, + username="user", + password="test", + ) + memgraph.stop() + + test_type = ( + WITH_FINE_GRAINED_AUTHORIZATION if with_fine_grained_authorization else WITHOUT_FINE_GRAINED_AUTHORIZATION + ) + + for group in sorted(tests.keys()): + for test, funcname in tests[group]: + log.info("Running test:", "{}/{}/{}".format(group, test, test_type)) + func = getattr(dataset, funcname) + + # Get number of queries to execute. + # TODO: implement minimum number of queries, `max(10, num_workers)` + config_key = [dataset.NAME, dataset.get_variant(), group, test, test_type] + cached_count = config.get_value(*config_key) + if cached_count is None: + print( + "Determining the number of queries necessary for", + args.single_threaded_runtime_sec, + "seconds of single-threaded runtime...", + ) + # First run to prime the query caches. + memgraph.start_benchmark() + client.execute(queries=get_queries(func, 1), num_workers=1) + # Get a sense of the runtime. + count = 1 + while True: + ret = client.execute(queries=get_queries(func, count), num_workers=1) + duration = ret[0]["duration"] + should_execute = int(args.single_threaded_runtime_sec / (duration / count)) + print( + "executed_queries={}, total_duration={}, " + "query_duration={}, estimated_count={}".format( + count, duration, duration / count, should_execute + ) + ) + # We don't have to execute the next iteration when + # `should_execute` becomes the same order of magnitude as + # `count * 10`. + if should_execute / (count * 10) < 10: + count = should_execute + break + else: + count = count * 10 + memgraph.stop() + config.set_value( + *config_key, + value={ + "count": count, + "duration": args.single_threaded_runtime_sec, + }, + ) + else: + print( + "Using cached query count of", + cached_count["count"], + "queries for", + cached_count["duration"], + "seconds of single-threaded runtime.", + ) + count = int(cached_count["count"] * args.single_threaded_runtime_sec / cached_count["duration"]) + + # Benchmark run. + print("Sample query:", get_queries(func, 1)[0][0]) + print( + "Executing benchmark with", + count, + "queries that should " "yield a single-threaded runtime of", + args.single_threaded_runtime_sec, + "seconds.", + ) + print( + "Queries are executed using", + args.num_workers_for_benchmark, + "concurrent clients.", + ) memgraph.start_benchmark() - client.execute(queries=get_queries(func, 1), num_workers=1) - # Get a sense of the runtime. - count = 1 - while True: - ret = client.execute(queries=get_queries(func, count), - num_workers=1) - duration = ret[0]["duration"] - should_execute = int(args.single_threaded_runtime_sec / - (duration / count)) - print("executed_queries={}, total_duration={}, " - "query_duration={}, estimated_count={}".format( - count, duration, duration / count, - should_execute)) - # We don't have to execute the next iteration when - # `should_execute` becomes the same order of magnitude as - # `count * 10`. 
- if should_execute / (count * 10) < 10: - count = should_execute - break - else: - count = count * 10 - memgraph.stop() - config.set_value(*config_key, value={ - "count": count, - "duration": args.single_threaded_runtime_sec}) - else: - print("Using cached query count of", cached_count["count"], - "queries for", cached_count["duration"], - "seconds of single-threaded runtime.") - count = int(cached_count["count"] * - args.single_threaded_runtime_sec / - cached_count["duration"]) + ret = client.execute( + queries=get_queries(func, count), + num_workers=args.num_workers_for_benchmark, + )[0] + usage = memgraph.stop() + ret["database"] = usage - # Benchmark run. - print("Sample query:", get_queries(func, 1)[0][0]) - print("Executing benchmark with", count, "queries that should " - "yield a single-threaded runtime of", - args.single_threaded_runtime_sec, "seconds.") - print("Queries are executed using", args.num_workers_for_benchmark, - "concurrent clients.") - memgraph.start_benchmark() - ret = client.execute(queries=get_queries(func, count), - num_workers=args.num_workers_for_benchmark)[0] - usage = memgraph.stop() - ret["database"] = usage + # Output summary. + print() + print("Executed", ret["count"], "queries in", ret["duration"], "seconds.") + print("Queries have been retried", ret["retries"], "times.") + print("Database used {:.3f} seconds of CPU time.".format(usage["cpu"])) + print("Database peaked at {:.3f} MiB of memory.".format(usage["memory"] / 1024.0 / 1024.0)) + print("{:<31} {:>20} {:>20} {:>20}".format("Metadata:", "min", "avg", "max")) + metadata = ret["metadata"] + for key in sorted(metadata.keys()): + print( + "{name:>30}: {minimum:>20.06f} {average:>20.06f} " + "{maximum:>20.06f}".format(name=key, **metadata[key]) + ) + log.success("Throughput: {:02f} QPS".format(ret["throughput"])) - # Output summary. - print() - print("Executed", ret["count"], "queries in", - ret["duration"], "seconds.") - print("Queries have been retried", ret["retries"], "times.") - print("Database used {:.3f} seconds of CPU time.".format( - usage["cpu"])) - print("Database peaked at {:.3f} MiB of memory.".format( - usage["memory"] / 1024.0 / 1024.0)) - print("{:<31} {:>20} {:>20} {:>20}".format("Metadata:", "min", - "avg", "max")) - metadata = ret["metadata"] - for key in sorted(metadata.keys()): - print("{name:>30}: {minimum:>20.06f} {average:>20.06f} " - "{maximum:>20.06f}".format(name=key, **metadata[key])) - log.success("Throughput: {:02f} QPS".format(ret["throughput"])) - - # Save results. - results_key = [dataset.NAME, dataset.get_variant(), group, test] - results.set_value(*results_key, value=ret) + # Save results. + results_key = [dataset.NAME, dataset.get_variant(), group, test, test_type] + results.set_value(*results_key, value=ret) # Save configuration. 
if not args.no_save_query_counts: diff --git a/tests/mgbench/compare_results.py b/tests/mgbench/compare_results.py index 2179bb408..17703b9dd 100755 --- a/tests/mgbench/compare_results.py +++ b/tests/mgbench/compare_results.py @@ -77,7 +77,7 @@ def recursive_get(data, *args, value=None): return data -def compare_results(results_from, results_to, fields): +def compare_results(results_from, results_to, fields, ignored): ret = {} for dataset, variants in results_to.items(): for variant, groups in variants.items(): @@ -85,39 +85,44 @@ def compare_results(results_from, results_to, fields): if group == "__import__": continue for scenario, summary_to in scenarios.items(): - summary_from = recursive_get( - results_from, dataset, variant, group, scenario, - value={}) - if len(summary_from) > 0 and \ - summary_to["count"] != summary_from["count"] or \ - summary_to["num_workers"] != \ - summary_from["num_workers"]: + if scenario in ignored: + continue + + summary_from = recursive_get(results_from, dataset, variant, group, scenario, value={}) + if ( + len(summary_from) > 0 + and summary_to["count"] != summary_from["count"] + or summary_to["num_workers"] != summary_from["num_workers"] + ): raise Exception("Incompatible results!") - testcode = "/".join([dataset, variant, group, scenario, - "{:02d}".format( - summary_to["num_workers"])]) + testcode = "/".join( + [ + dataset, + variant, + group, + scenario, + "{:02d}".format(summary_to["num_workers"]), + ] + ) row = {} performance_changed = False for field in fields: key = field["name"] if key in summary_to: - row[key] = compute_diff( - summary_from.get(key, None), - summary_to[key]) + row[key] = compute_diff(summary_from.get(key, None), summary_to[key]) elif key in summary_to["database"]: row[key] = compute_diff( - recursive_get(summary_from, "database", key, - value=None), - summary_to["database"][key]) + recursive_get(summary_from, "database", key, value=None), + summary_to["database"][key], + ) else: row[key] = compute_diff( - recursive_get(summary_from, "metadata", key, - "average", value=None), - summary_to["metadata"][key]["average"]) - if "diff" not in row[key] or \ - ("diff_treshold" in field and - abs(row[key]["diff"]) >= - field["diff_treshold"]): + recursive_get(summary_from, "metadata", key, "average", value=None), + summary_to["metadata"][key]["average"], + ) + if "diff" not in row[key] or ( + "diff_treshold" in field and abs(row[key]["diff"]) >= field["diff_treshold"] + ): performance_changed = True if performance_changed: ret[testcode] = row @@ -130,8 +135,15 @@ def generate_remarkup(fields, data): ret += "\n" ret += " \n" ret += " \n" - ret += "\n".join(map(lambda x: " ".format( - x["name"].replace("_", " ").capitalize()), fields)) + "\n" + ret += ( + "\n".join( + map( + lambda x: " ".format(x["name"].replace("_", " ").capitalize()), + fields, + ) + ) + + "\n" + ) ret += " \n" for testcode in sorted(data.keys()): ret += " \n" @@ -147,12 +159,9 @@ def generate_remarkup(fields, data): else: color = "red" sign = "{{icon {} color={}}}".format(arrow, color) - ret += " \n".format( - value, field["unit"], diff, sign) + ret += ' \n'.format(color, value, field["unit"], diff) else: - ret += " \n".format( - value, field["unit"]) + ret += '\n'.format(value, field["unit"]) ret += " \n" ret += "
</table>
\n" else: @@ -161,22 +170,33 @@ def generate_remarkup(fields, data): if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Compare results of multiple benchmark runs.") - parser.add_argument("--compare", action="append", nargs=2, - metavar=("from", "to"), - help="compare results between `from` and `to` files") + parser = argparse.ArgumentParser(description="Compare results of multiple benchmark runs.") + parser.add_argument( + "--compare", + action="append", + nargs=2, + metavar=("from", "to"), + help="compare results between `from` and `to` files", + ) parser.add_argument("--output", default="", help="output file name") + # file is read line by line, each representing one test name + parser.add_argument("--exclude_tests_file", help="file listing test names to be excluded") args = parser.parse_args() if args.compare is None or len(args.compare) == 0: raise Exception("You must specify at least one pair of files!") + if args.exclude_tests_file: + with open(args.exclude_tests_file, "r") as f: + ignored = [line.rstrip("\n") for line in f] + else: + ignored = [] + data = {} for file_from, file_to in args.compare: results_from = load_results(file_from) results_to = load_results(file_to) - data.update(compare_results(results_from, results_to, FIELDS)) + data.update(compare_results(results_from, results_to, FIELDS, ignored)) remarkup = generate_remarkup(FIELDS, data) if args.output: diff --git a/tests/mgbench/datasets.py b/tests/mgbench/datasets.py index c68fcac34..c0508e2d2 100644 --- a/tests/mgbench/datasets.py +++ b/tests/mgbench/datasets.py @@ -135,10 +135,7 @@ class Pokec(Dataset): return ("MATCH (n:User {id : $id}) RETURN n", {"id": self._get_random_vertex()}) def benchmark__arango__single_vertex_write(self): - return ( - "CREATE (n:UserTemp {id : $id}) RETURN n", - {"id": random.randint(1, self._num_vertices * 10)}, - ) + return ("CREATE (n:UserTemp {id : $id}) RETURN n", {"id": random.randint(1, self._num_vertices * 10)}) def benchmark__arango__single_edge_write(self): vertex_from, vertex_to = self._get_random_from_to() @@ -154,10 +151,7 @@ class Pokec(Dataset): return ("MATCH (n:User) WHERE n.age >= 18 RETURN n.age, COUNT(*)", {}) def benchmark__arango__expansion_1(self): - return ( - "MATCH (s:User {id: $id})-->(n:User) " "RETURN n.id", - {"id": self._get_random_vertex()}, - ) + return ("MATCH (s:User {id: $id})-->(n:User) " "RETURN n.id", {"id": self._get_random_vertex()}) def benchmark__arango__expansion_1_with_filter(self): return ( @@ -166,10 +160,7 @@ class Pokec(Dataset): ) def benchmark__arango__expansion_2(self): - return ( - "MATCH (s:User {id: $id})-->()-->(n:User) " "RETURN DISTINCT n.id", - {"id": self._get_random_vertex()}, - ) + return ("MATCH (s:User {id: $id})-->()-->(n:User) " "RETURN DISTINCT n.id", {"id": self._get_random_vertex()}) def benchmark__arango__expansion_2_with_filter(self): return ( @@ -202,10 +193,7 @@ class Pokec(Dataset): ) def benchmark__arango__neighbours_2(self): - return ( - "MATCH (s:User {id: $id})-[*1..2]->(n:User) " "RETURN DISTINCT n.id", - {"id": self._get_random_vertex()}, - ) + return ("MATCH (s:User {id: $id})-[*1..2]->(n:User) " "RETURN DISTINCT n.id", {"id": self._get_random_vertex()}) def benchmark__arango__neighbours_2_with_filter(self): return ( @@ -282,10 +270,7 @@ class Pokec(Dataset): return ("MATCH (n) RETURN min(n.age), max(n.age), avg(n.age)", {}) def benchmark__match__pattern_cycle(self): - return ( - "MATCH (n:User {id: $id})-[e1]->(m)-[e2]->(n) " "RETURN e1, m, e2", - {"id": self._get_random_vertex()}, - 
) + return ("MATCH (n:User {id: $id})-[e1]->(m)-[e2]->(n) " "RETURN e1, m, e2", {"id": self._get_random_vertex()}) def benchmark__match__pattern_long(self): return ( @@ -294,19 +279,16 @@ class Pokec(Dataset): ) def benchmark__match__pattern_short(self): - return ( - "MATCH (n:User {id: $id})-[e]->(m) " "RETURN m LIMIT 1", - {"id": self._get_random_vertex()}, - ) + return ("MATCH (n:User {id: $id})-[e]->(m) " "RETURN m LIMIT 1", {"id": self._get_random_vertex()}) def benchmark__match__vertex_on_label_property(self): - return ( - "MATCH (n:User) WITH n WHERE n.id = $id RETURN n", - {"id": self._get_random_vertex()}, - ) + return ("MATCH (n:User) WITH n WHERE n.id = $id RETURN n", {"id": self._get_random_vertex()}) def benchmark__match__vertex_on_label_property_index(self): return ("MATCH (n:User {id: $id}) RETURN n", {"id": self._get_random_vertex()}) def benchmark__match__vertex_on_property(self): return ("MATCH (n {id: $id}) RETURN n", {"id": self._get_random_vertex()}) + + def benchmark__update__vertex_on_property(self): + return ("MATCH (n {id: $id}) SET n.property = -1", {"id": self._get_random_vertex()}) diff --git a/tests/mgbench/runners.py b/tests/mgbench/runners.py index b6b727ed8..bc2142cba 100644 --- a/tests/mgbench/runners.py +++ b/tests/mgbench/runners.py @@ -51,11 +51,12 @@ def _get_usage(pid): class Memgraph: - def __init__(self, memgraph_binary, temporary_dir, properties_on_edges): + def __init__(self, memgraph_binary, temporary_dir, properties_on_edges, bolt_port): self._memgraph_binary = memgraph_binary self._directory = tempfile.TemporaryDirectory(dir=temporary_dir) self._properties_on_edges = properties_on_edges self._proc_mg = None + self._bolt_port = bolt_port atexit.register(self._cleanup) # Determine Memgraph version @@ -69,6 +70,7 @@ class Memgraph: def _get_args(self, **kwargs): data_directory = os.path.join(self._directory.name, "memgraph") + kwargs["bolt_port"] = self._bolt_port if self._memgraph_version >= (0, 50, 0): kwargs["data_directory"] = data_directory else: @@ -88,7 +90,7 @@ class Memgraph: if self._proc_mg.poll() is not None: self._proc_mg = None raise Exception("The database process died prematurely!") - wait_for_server(7687) + wait_for_server(self._bolt_port) ret = self._proc_mg.poll() assert ret is None, "The database process died prematurely " "({})!".format(ret) @@ -121,9 +123,14 @@ class Memgraph: class Client: - def __init__(self, client_binary, temporary_directory): + def __init__( + self, client_binary: str, temporary_directory: str, bolt_port: int, username: str = "", password: str = "" + ): self._client_binary = client_binary self._directory = tempfile.TemporaryDirectory(dir=temporary_directory) + self._username = username + self._password = password + self._bolt_port = bolt_port def _get_args(self, **kwargs): return _convert_args_to_flags(self._client_binary, **kwargs) @@ -144,8 +151,15 @@ class Client: json.dump(query, f) f.write("\n") - args = self._get_args(input=file_path, num_workers=num_workers, queries_json=queries_json) + args = self._get_args( + input=file_path, + num_workers=num_workers, + queries_json=queries_json, + username=self._username, + password=self._password, + port=self._bolt_port, + ) ret = subprocess.run(args, stdout=subprocess.PIPE, check=True) data = ret.stdout.decode("utf-8").strip().split("\n") - data = [x for x in data if not x.startswith("[")] + # data = [x for x in data if not x.startswith("[")] return list(map(json.loads, data))
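
For reference, the new `bolt_port`, `username`, and `password` runner parameters introduced above compose roughly as follows. This is a minimal sketch, assuming it runs from `tests/mgbench/` with built binaries, and reusing the example port and the `user`/`test` credentials that `benchmark.py` hard-codes; it mirrors the preparation step that creates the benchmark user before the authorized run (the real flow also replays the dataset file via `client.execute(file_path=...)`).

import helpers
import runners

BOLT_PORT = 7687  # example value; benchmark.py takes this from the new --bolt-port flag

# The server is told which Bolt port to listen on, and wait_for_server()
# now polls that port instead of a hard-coded 7687.
memgraph = runners.Memgraph(
    helpers.get_binary_path("memgraph"),
    "/tmp",
    True,  # properties on edges enabled
    BOLT_PORT,
)

# An unauthenticated client creates the benchmark user and grants privileges.
client = runners.Client(helpers.get_binary_path("tests/mgbench/client"), "/tmp", BOLT_PORT)
memgraph.start_preparation()
client.execute(
    queries=[
        ("CREATE USER user IDENTIFIED BY 'test';", {}),
        ("GRANT ALL PRIVILEGES TO user;", {}),
        ("GRANT CREATE_DELETE ON EDGE_TYPES * TO user;", {}),
        ("GRANT CREATE_DELETE ON LABELS * TO user;", {}),
    ]
)
memgraph.stop()

# The authorized benchmark run then authenticates as that user; the
# credentials are forwarded to the client binary by Client._get_args().
client = runners.Client(
    helpers.get_binary_path("tests/mgbench/client"),
    "/tmp",
    BOLT_PORT,
    username="user",
    password="test",
)

On the comparison side, the new `--exclude_tests_file` option of `compare_results.py` expects a plain-text file with one scenario name per line; any scenario listed there is skipped when the `from` and `to` results are diffed.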