Add fine grained access control to mgbench (#522)

Authored by niko4299 on 2022-09-15 21:33:15 +02:00; committed by GitHub
parent a0b8871b36
commit a3c2492672
4 changed files with 313 additions and 212 deletions
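
In short: the benchmark executor now runs every scenario twice, once with an anonymous client and once as a user subject to Memgraph's fine-grained access control, and keys cached query counts and results by that mode. A --bolt-port flag is threaded through the Memgraph and Client runners, the comparison script gains an --exclude_tests_file option for skipping scenarios, and the Pokec dataset gains an update benchmark. A condensed sketch of the new control flow (names are taken from the diff below; everything else is elided):

    for with_fine_grained_authorization in [False, True]:
        test_type = (
            WITH_FINE_GRAINED_AUTHORIZATION
            if with_fine_grained_authorization
            else WITHOUT_FINE_GRAINED_AUTHORIZATION
        )
        # ...run every group/test pair, keying config and results by test_type...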

View File (likely tests/mgbench/benchmark.py)

@@ -26,6 +26,9 @@ import log
 import helpers
 import runners
 
+WITH_FINE_GRAINED_AUTHORIZATION = "with_fine_grained_authorization"
+WITHOUT_FINE_GRAINED_AUTHORIZATION = "without_fine_grained_authorization"
+
 
 def get_queries(gen, count):
     # Make the generator deterministic.
@@ -37,8 +40,7 @@ def get_queries(gen, count):
     return ret
 
 
-def match_patterns(dataset, variant, group, test, is_default_variant,
-                   patterns):
+def match_patterns(dataset, variant, group, test, is_default_variant, patterns):
     for pattern in patterns:
         verdict = [fnmatch.fnmatchcase(dataset, pattern[0])]
         if pattern[1] != "":
@@ -58,7 +60,7 @@ def filter_benchmarks(generators, patterns):
         pattern = patterns[i].split("/")
         if len(pattern) > 4 or len(pattern) == 0:
             raise Exception("Invalid benchmark description '" + pattern + "'!")
-        pattern.extend(["", "*", "*"][len(pattern) - 1:])
+        pattern.extend(["", "*", "*"][len(pattern) - 1 :])
         patterns[i] = pattern
     filtered = []
     for dataset in sorted(generators.keys()):
@@ -68,8 +70,7 @@ def filter_benchmarks(generators, patterns):
             current = collections.defaultdict(list)
             for group in tests:
                 for test_name, test_func in tests[group]:
-                    if match_patterns(dataset, variant, group, test_name,
-                                      is_default_variant, patterns):
+                    if match_patterns(dataset, variant, group, test_name, is_default_variant, patterns):
                         current[group].append((test_name, test_func))
             if len(current) > 0:
                 filtered.append((generator(variant), dict(current)))
@@ -79,43 +80,72 @@ def filter_benchmarks(generators, patterns):
 # Parse options.
-parser = argparse.ArgumentParser(
-    description="Memgraph benchmark executor.",
-    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-parser.add_argument("benchmarks", nargs="*", default="",
-                    help="descriptions of benchmarks that should be run; "
-                    "multiple descriptions can be specified to run multiple "
-                    "benchmarks; the description is specified as "
-                    "dataset/variant/group/test; Unix shell-style wildcards "
-                    "can be used in the descriptions; variant, group and test "
-                    "are optional and they can be left out; the default "
-                    "variant is '' which selects the default dataset variant; "
-                    "the default group is '*' which selects all groups; the "
-                    "default test is '*' which selects all tests")
-parser.add_argument("--memgraph-binary",
-                    default=helpers.get_binary_path("memgraph"),
-                    help="Memgraph binary used for benchmarking")
-parser.add_argument("--client-binary",
-                    default=helpers.get_binary_path("tests/mgbench/client"),
-                    help="client binary used for benchmarking")
-parser.add_argument("--num-workers-for-import", type=int,
-                    default=multiprocessing.cpu_count() // 2,
-                    help="number of workers used to import the dataset")
-parser.add_argument("--num-workers-for-benchmark", type=int,
-                    default=1,
-                    help="number of workers used to execute the benchmark")
-parser.add_argument("--single-threaded-runtime-sec", type=int,
-                    default=10,
-                    help="single threaded duration of each test")
-parser.add_argument("--no-load-query-counts", action="store_true",
-                    help="disable loading of cached query counts")
-parser.add_argument("--no-save-query-counts", action="store_true",
-                    help="disable storing of cached query counts")
-parser.add_argument("--export-results", default="",
-                    help="file path into which results should be exported")
-parser.add_argument("--temporary-directory", default="/tmp",
-                    help="directory path where temporary data should "
-                    "be stored")
-parser.add_argument("--no-properties-on-edges", action="store_true",
-                    help="disable properties on edges")
+parser = argparse.ArgumentParser(
+    description="Memgraph benchmark executor.",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+)
+parser.add_argument(
+    "benchmarks",
+    nargs="*",
+    default="",
+    help="descriptions of benchmarks that should be run; "
+    "multiple descriptions can be specified to run multiple "
+    "benchmarks; the description is specified as "
+    "dataset/variant/group/test; Unix shell-style wildcards "
+    "can be used in the descriptions; variant, group and test "
+    "are optional and they can be left out; the default "
+    "variant is '' which selects the default dataset variant; "
+    "the default group is '*' which selects all groups; the "
+    "default test is '*' which selects all tests",
+)
+parser.add_argument(
+    "--memgraph-binary",
+    default=helpers.get_binary_path("memgraph"),
+    help="Memgraph binary used for benchmarking",
+)
+parser.add_argument(
+    "--client-binary",
+    default=helpers.get_binary_path("tests/mgbench/client"),
+    help="client binary used for benchmarking",
+)
+parser.add_argument(
+    "--num-workers-for-import",
+    type=int,
+    default=multiprocessing.cpu_count() // 2,
+    help="number of workers used to import the dataset",
+)
+parser.add_argument(
+    "--num-workers-for-benchmark",
+    type=int,
+    default=1,
+    help="number of workers used to execute the benchmark",
+)
+parser.add_argument(
+    "--single-threaded-runtime-sec",
+    type=int,
+    default=10,
+    help="single threaded duration of each test",
+)
+parser.add_argument(
+    "--no-load-query-counts",
+    action="store_true",
+    help="disable loading of cached query counts",
+)
+parser.add_argument(
+    "--no-save-query-counts",
+    action="store_true",
+    help="disable storing of cached query counts",
+)
+parser.add_argument(
+    "--export-results",
+    default="",
+    help="file path into which results should be exported",
+)
+parser.add_argument(
+    "--temporary-directory",
+    default="/tmp",
+    help="directory path where temporary data should be stored",
+)
+parser.add_argument("--no-properties-on-edges", action="store_true", help="disable properties on edges")
+parser.add_argument("--bolt-port", default=7687, help="memgraph bolt port")
 args = parser.parse_args()
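
The new --bolt-port flag replaces the previously hard-coded port (note the old wait_for_server(7687) in the runners file below) and is handed to both the server wrapper and the benchmark client, so a run can target a non-default Bolt port.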
 # Detect available datasets.
@@ -124,8 +154,7 @@ for key in dir(datasets):
     if key.startswith("_"):
         continue
     dataset = getattr(datasets, key)
-    if not inspect.isclass(dataset) or dataset == datasets.Dataset or \
-            not issubclass(dataset, datasets.Dataset):
+    if not inspect.isclass(dataset) or dataset == datasets.Dataset or not issubclass(dataset, datasets.Dataset):
         continue
     tests = collections.defaultdict(list)
     for funcname in dir(dataset):
@@ -135,8 +164,9 @@ for key in dir(datasets):
             tests[group].append((test, funcname))
     generators[dataset.NAME] = (dataset, dict(tests))
     if dataset.PROPERTIES_ON_EDGES and args.no_properties_on_edges:
-        raise Exception("The \"{}\" dataset requires properties on edges, "
-                        "but you have disabled them!".format(dataset.NAME))
+        raise Exception(
+            'The "{}" dataset requires properties on edges, but you have disabled them!'.format(dataset.NAME)
+        )
 
 # List datasets if there is no specified dataset.
 if len(args.benchmarks) == 0:
@@ -144,8 +174,11 @@ if len(args.benchmarks) == 0:
     for name in sorted(generators.keys()):
         print("Dataset:", name)
         dataset, tests = generators[name]
-        print("    Variants:", ", ".join(dataset.VARIANTS),
-              "(default: " + dataset.DEFAULT_VARIANT + ")")
+        print(
+            "    Variants:",
+            ", ".join(dataset.VARIANTS),
+            "(default: " + dataset.DEFAULT_VARIANT + ")",
+        )
         for group in sorted(tests.keys()):
             print("    Group:", group)
             for test_name, test_func in tests[group]:
@@ -162,34 +195,45 @@ results = helpers.RecursiveDict()
 # Filter out the generators.
 benchmarks = filter_benchmarks(generators, args.benchmarks)
 
 # Run all specified benchmarks.
 for dataset, tests in benchmarks:
-    log.init("Preparing", dataset.NAME + "/" + dataset.get_variant(),
-             "dataset")
-    dataset.prepare(cache.cache_directory("datasets", dataset.NAME,
-                                          dataset.get_variant()))
+    log.init("Preparing", dataset.NAME + "/" + dataset.get_variant(), "dataset")
+    dataset.prepare(cache.cache_directory("datasets", dataset.NAME, dataset.get_variant()))
 
     # Prepare runners and import the dataset.
-    memgraph = runners.Memgraph(args.memgraph_binary, args.temporary_directory,
-                                not args.no_properties_on_edges)
-    client = runners.Client(args.client_binary, args.temporary_directory)
+    memgraph = runners.Memgraph(
+        args.memgraph_binary,
+        args.temporary_directory,
+        not args.no_properties_on_edges,
+        args.bolt_port,
+    )
+    client = runners.Client(args.client_binary, args.temporary_directory, args.bolt_port)
     memgraph.start_preparation()
-    ret = client.execute(file_path=dataset.get_file(),
-                         num_workers=args.num_workers_for_import)
+    ret = client.execute(file_path=dataset.get_file(), num_workers=args.num_workers_for_import)
     usage = memgraph.stop()
 
     # Display import statistics.
     print()
     for row in ret:
-        print("Executed", row["count"], "queries in", row["duration"],
-              "seconds using", row["num_workers"],
-              "workers with a total throughput of", row["throughput"],
-              "queries/second.")
+        print(
+            "Executed",
+            row["count"],
+            "queries in",
+            row["duration"],
+            "seconds using",
+            row["num_workers"],
+            "workers with a total throughput of",
+            row["throughput"],
+            "queries/second.",
+        )
     print()
-    print("The database used", usage["cpu"],
-          "seconds of CPU time and peaked at",
-          usage["memory"] / 1024 / 1024, "MiB of RAM.")
+    print(
+        "The database used",
+        usage["cpu"],
+        "seconds of CPU time and peaked at",
+        usage["memory"] / 1024 / 1024,
+        "MiB of RAM.",
+    )
 
     # Save import results.
     import_key = [dataset.NAME, dataset.get_variant(), "__import__"]
@@ -198,87 +242,128 @@ for dataset, tests in benchmarks:
     # TODO: cache import data
 
     # Run all benchmarks in all available groups.
-    for group in sorted(tests.keys()):
-        for test, funcname in tests[group]:
-            log.info("Running test:", "{}/{}".format(group, test))
-            func = getattr(dataset, funcname)
-
-            # Get number of queries to execute.
-            # TODO: implement minimum number of queries, `max(10, num_workers)`
-            config_key = [dataset.NAME, dataset.get_variant(), group, test]
-            cached_count = config.get_value(*config_key)
-            if cached_count is None:
-                print("Determining the number of queries necessary for",
-                      args.single_threaded_runtime_sec,
-                      "seconds of single-threaded runtime...")
-                # First run to prime the query caches.
-                memgraph.start_benchmark()
-                client.execute(queries=get_queries(func, 1), num_workers=1)
-                # Get a sense of the runtime.
-                count = 1
-                while True:
-                    ret = client.execute(queries=get_queries(func, count),
-                                         num_workers=1)
-                    duration = ret[0]["duration"]
-                    should_execute = int(args.single_threaded_runtime_sec /
-                                         (duration / count))
-                    print("executed_queries={}, total_duration={}, "
-                          "query_duration={}, estimated_count={}".format(
-                              count, duration, duration / count,
-                              should_execute))
-                    # We don't have to execute the next iteration when
-                    # `should_execute` becomes the same order of magnitude as
-                    # `count * 10`.
-                    if should_execute / (count * 10) < 10:
-                        count = should_execute
-                        break
-                    else:
-                        count = count * 10
-                memgraph.stop()
-                config.set_value(*config_key, value={
-                    "count": count,
-                    "duration": args.single_threaded_runtime_sec})
-            else:
-                print("Using cached query count of", cached_count["count"],
-                      "queries for", cached_count["duration"],
-                      "seconds of single-threaded runtime.")
-                count = int(cached_count["count"] *
-                            args.single_threaded_runtime_sec /
-                            cached_count["duration"])
-
-            # Benchmark run.
-            print("Sample query:", get_queries(func, 1)[0][0])
-            print("Executing benchmark with", count, "queries that should "
-                  "yield a single-threaded runtime of",
-                  args.single_threaded_runtime_sec, "seconds.")
-            print("Queries are executed using", args.num_workers_for_benchmark,
-                  "concurrent clients.")
-            memgraph.start_benchmark()
-            ret = client.execute(queries=get_queries(func, count),
-                                 num_workers=args.num_workers_for_benchmark)[0]
-            usage = memgraph.stop()
-            ret["database"] = usage
-
-            # Output summary.
-            print()
-            print("Executed", ret["count"], "queries in",
-                  ret["duration"], "seconds.")
-            print("Queries have been retried", ret["retries"], "times.")
-            print("Database used {:.3f} seconds of CPU time.".format(
-                usage["cpu"]))
-            print("Database peaked at {:.3f} MiB of memory.".format(
-                usage["memory"] / 1024.0 / 1024.0))
-            print("{:<31} {:>20} {:>20} {:>20}".format("Metadata:", "min",
-                                                       "avg", "max"))
-            metadata = ret["metadata"]
-            for key in sorted(metadata.keys()):
-                print("{name:>30}: {minimum:>20.06f} {average:>20.06f} "
-                      "{maximum:>20.06f}".format(name=key, **metadata[key]))
-            log.success("Throughput: {:02f} QPS".format(ret["throughput"]))
-
-            # Save results.
-            results_key = [dataset.NAME, dataset.get_variant(), group, test]
-            results.set_value(*results_key, value=ret)
+    for with_fine_grained_authorization in [False, True]:
+        if with_fine_grained_authorization:
+            memgraph.start_preparation()
+            client.execute(file_path=dataset.get_file(), num_workers=args.num_workers_for_import)
+            client.execute(
+                queries=[
+                    ("CREATE USER user IDENTIFIED BY 'test';", {}),
+                    ("GRANT ALL PRIVILEGES TO user;", {}),
+                    ("GRANT CREATE_DELETE ON EDGE_TYPES * TO user;", {}),
+                    ("GRANT CREATE_DELETE ON LABELS * TO user;", {}),
+                ]
+            )
+            client = runners.Client(
+                args.client_binary,
+                args.temporary_directory,
+                args.bolt_port,
+                username="user",
+                password="test",
+            )
+            memgraph.stop()
+
+        test_type = (
+            WITH_FINE_GRAINED_AUTHORIZATION if with_fine_grained_authorization else WITHOUT_FINE_GRAINED_AUTHORIZATION
+        )
+
+        for group in sorted(tests.keys()):
+            for test, funcname in tests[group]:
+                log.info("Running test:", "{}/{}/{}".format(group, test, test_type))
+                func = getattr(dataset, funcname)
+
+                # Get number of queries to execute.
+                # TODO: implement minimum number of queries, `max(10, num_workers)`
+                config_key = [dataset.NAME, dataset.get_variant(), group, test, test_type]
+                cached_count = config.get_value(*config_key)
+                if cached_count is None:
+                    print(
+                        "Determining the number of queries necessary for",
+                        args.single_threaded_runtime_sec,
+                        "seconds of single-threaded runtime...",
+                    )
+                    # First run to prime the query caches.
+                    memgraph.start_benchmark()
+                    client.execute(queries=get_queries(func, 1), num_workers=1)
+                    # Get a sense of the runtime.
+                    count = 1
+                    while True:
+                        ret = client.execute(queries=get_queries(func, count), num_workers=1)
+                        duration = ret[0]["duration"]
+                        should_execute = int(args.single_threaded_runtime_sec / (duration / count))
+                        print(
+                            "executed_queries={}, total_duration={}, "
+                            "query_duration={}, estimated_count={}".format(
+                                count, duration, duration / count, should_execute
+                            )
+                        )
+                        # We don't have to execute the next iteration when
+                        # `should_execute` becomes the same order of magnitude as
+                        # `count * 10`.
+                        if should_execute / (count * 10) < 10:
+                            count = should_execute
+                            break
+                        else:
+                            count = count * 10
+                    memgraph.stop()
+                    config.set_value(
+                        *config_key,
+                        value={
+                            "count": count,
+                            "duration": args.single_threaded_runtime_sec,
+                        },
+                    )
+                else:
+                    print(
+                        "Using cached query count of",
+                        cached_count["count"],
+                        "queries for",
+                        cached_count["duration"],
+                        "seconds of single-threaded runtime.",
+                    )
+                    count = int(cached_count["count"] * args.single_threaded_runtime_sec / cached_count["duration"])
+
+                # Benchmark run.
+                print("Sample query:", get_queries(func, 1)[0][0])
+                print(
+                    "Executing benchmark with",
+                    count,
+                    "queries that should yield a single-threaded runtime of",
+                    args.single_threaded_runtime_sec,
+                    "seconds.",
+                )
+                print(
+                    "Queries are executed using",
+                    args.num_workers_for_benchmark,
+                    "concurrent clients.",
+                )
+                memgraph.start_benchmark()
+                ret = client.execute(
+                    queries=get_queries(func, count),
+                    num_workers=args.num_workers_for_benchmark,
+                )[0]
+                usage = memgraph.stop()
+                ret["database"] = usage
+
+                # Output summary.
+                print()
+                print("Executed", ret["count"], "queries in", ret["duration"], "seconds.")
+                print("Queries have been retried", ret["retries"], "times.")
+                print("Database used {:.3f} seconds of CPU time.".format(usage["cpu"]))
+                print("Database peaked at {:.3f} MiB of memory.".format(usage["memory"] / 1024.0 / 1024.0))
+                print("{:<31} {:>20} {:>20} {:>20}".format("Metadata:", "min", "avg", "max"))
+                metadata = ret["metadata"]
+                for key in sorted(metadata.keys()):
+                    print(
+                        "{name:>30}: {minimum:>20.06f} {average:>20.06f} "
+                        "{maximum:>20.06f}".format(name=key, **metadata[key])
+                    )
+                log.success("Throughput: {:02f} QPS".format(ret["throughput"]))
+
+                # Save results.
+                results_key = [dataset.NAME, dataset.get_variant(), group, test, test_type]
+                results.set_value(*results_key, value=ret)
 
 # Save configuration.
 if not args.no_save_query_counts:
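
When the fine-grained pass starts, the dataset is re-imported and a user/test account is created with blanket privileges plus CREATE_DELETE on all labels and edge types; the anonymous client is then swapped for an authenticated one, so every benchmark query goes through the authorization checks. Since test_type is appended to both config_key and results_key, the two passes are cached and reported separately; an illustrative results key (dataset and test names hypothetical):

    results_key = ["pokec", "small", "match", "pattern_short", WITH_FINE_GRAINED_AUTHORIZATION]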

View File (likely tests/mgbench/compare_results.py)

@@ -77,7 +77,7 @@ def recursive_get(data, *args, value=None):
     return data
 
 
-def compare_results(results_from, results_to, fields):
+def compare_results(results_from, results_to, fields, ignored):
     ret = {}
     for dataset, variants in results_to.items():
         for variant, groups in variants.items():
@@ -85,39 +85,44 @@ def compare_results(results_from, results_to, fields):
                 if group == "__import__":
                     continue
                 for scenario, summary_to in scenarios.items():
-                    summary_from = recursive_get(
-                        results_from, dataset, variant, group, scenario,
-                        value={})
-                    if len(summary_from) > 0 and \
-                            summary_to["count"] != summary_from["count"] or \
-                            summary_to["num_workers"] != \
-                            summary_from["num_workers"]:
+                    if scenario in ignored:
+                        continue
+
+                    summary_from = recursive_get(results_from, dataset, variant, group, scenario, value={})
+                    if (
+                        len(summary_from) > 0
+                        and summary_to["count"] != summary_from["count"]
+                        or summary_to["num_workers"] != summary_from["num_workers"]
+                    ):
                         raise Exception("Incompatible results!")
-                    testcode = "/".join([dataset, variant, group, scenario,
-                                         "{:02d}".format(
-                                             summary_to["num_workers"])])
+                    testcode = "/".join(
+                        [
+                            dataset,
+                            variant,
+                            group,
+                            scenario,
+                            "{:02d}".format(summary_to["num_workers"]),
+                        ]
+                    )
                     row = {}
                     performance_changed = False
                     for field in fields:
                         key = field["name"]
                         if key in summary_to:
-                            row[key] = compute_diff(
-                                summary_from.get(key, None),
-                                summary_to[key])
+                            row[key] = compute_diff(summary_from.get(key, None), summary_to[key])
                         elif key in summary_to["database"]:
                             row[key] = compute_diff(
-                                recursive_get(summary_from, "database", key,
-                                              value=None),
-                                summary_to["database"][key])
+                                recursive_get(summary_from, "database", key, value=None),
+                                summary_to["database"][key],
+                            )
                         else:
                             row[key] = compute_diff(
-                                recursive_get(summary_from, "metadata", key,
-                                              "average", value=None),
-                                summary_to["metadata"][key]["average"])
-                        if "diff" not in row[key] or \
-                                ("diff_treshold" in field and
-                                 abs(row[key]["diff"]) >=
-                                 field["diff_treshold"]):
+                                recursive_get(summary_from, "metadata", key, "average", value=None),
+                                summary_to["metadata"][key]["average"],
+                            )
+                        if "diff" not in row[key] or (
+                            "diff_treshold" in field and abs(row[key]["diff"]) >= field["diff_treshold"]
+                        ):
                             performance_changed = True
                     if performance_changed:
                         ret[testcode] = row
@@ -130,8 +135,15 @@ def generate_remarkup(fields, data):
         ret += "<table>\n"
         ret += "  <tr>\n"
         ret += "    <th>Testcode</th>\n"
-        ret += "\n".join(map(lambda x: "    <th>{}</th>".format(
-            x["name"].replace("_", " ").capitalize()), fields)) + "\n"
+        ret += (
+            "\n".join(
+                map(
+                    lambda x: "    <th>{}</th>".format(x["name"].replace("_", " ").capitalize()),
+                    fields,
+                )
+            )
+            + "\n"
+        )
         ret += "  </tr>\n"
         for testcode in sorted(data.keys()):
             ret += "  <tr>\n"
@@ -147,12 +159,9 @@ def generate_remarkup(fields, data):
                     else:
                         color = "red"
                     sign = "{{icon {} color={}}}".format(arrow, color)
-                    ret += "    <td>{:.3f}{} //({:+.2%})// {}</td>\n".format(
-                        value, field["unit"], diff, sign)
+                    ret += '    <td bgcolor="{}">{:.3f}{} ({:+.2%})</td>\n'.format(color, value, field["unit"], diff)
                 else:
-                    ret += "    <td>{:.3f}{} //(new)// " \
-                           "{{icon plus color=blue}}</td>\n".format(
-                               value, field["unit"])
+                    ret += '<td bgcolor="blue">{:.3f}{} //(new)// </td>\n'.format(value, field["unit"])
             ret += "  </tr>\n"
         ret += "</table>\n"
     else:
@@ -161,22 +170,33 @@ def generate_remarkup(fields, data):
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Compare results of multiple benchmark runs.")
-    parser.add_argument("--compare", action="append", nargs=2,
-                        metavar=("from", "to"),
-                        help="compare results between `from` and `to` files")
+    parser = argparse.ArgumentParser(description="Compare results of multiple benchmark runs.")
+    parser.add_argument(
+        "--compare",
+        action="append",
+        nargs=2,
+        metavar=("from", "to"),
+        help="compare results between `from` and `to` files",
+    )
     parser.add_argument("--output", default="", help="output file name")
+    # file is read line by line, each representing one test name
+    parser.add_argument("--exclude_tests_file", help="file listing test names to be excluded")
     args = parser.parse_args()
 
     if args.compare is None or len(args.compare) == 0:
         raise Exception("You must specify at least one pair of files!")
 
+    if args.exclude_tests_file:
+        with open(args.exclude_tests_file, "r") as f:
+            ignored = [line.rstrip("\n") for line in f]
+    else:
+        ignored = []
+
     data = {}
     for file_from, file_to in args.compare:
         results_from = load_results(file_from)
         results_to = load_results(file_to)
-        data.update(compare_results(results_from, results_to, FIELDS))
+        data.update(compare_results(results_from, results_to, FIELDS, ignored))
 
     remarkup = generate_remarkup(FIELDS, data)
     if args.output:
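
The exclusion file is consumed line by line, one scenario name per line, and matched against the scenario keys inside compare_results(). A minimal sketch of preparing and using it (file and scenario names hypothetical):

    # Scenarios listed here are skipped via the new `ignored` parameter.
    with open("excluded_tests.txt", "w") as f:
        f.write("pattern_cycle\n")
        f.write("vertex_on_property\n")
    # then: --compare from.json to.json --exclude_tests_file excluded_tests.txt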

View File (likely tests/mgbench/datasets.py)

@@ -135,10 +135,7 @@ class Pokec(Dataset):
         return ("MATCH (n:User {id : $id}) RETURN n", {"id": self._get_random_vertex()})
 
     def benchmark__arango__single_vertex_write(self):
-        return (
-            "CREATE (n:UserTemp {id : $id}) RETURN n",
-            {"id": random.randint(1, self._num_vertices * 10)},
-        )
+        return ("CREATE (n:UserTemp {id : $id}) RETURN n", {"id": random.randint(1, self._num_vertices * 10)})
 
     def benchmark__arango__single_edge_write(self):
         vertex_from, vertex_to = self._get_random_from_to()
@@ -154,10 +151,7 @@ class Pokec(Dataset):
         return ("MATCH (n:User) WHERE n.age >= 18 RETURN n.age, COUNT(*)", {})
 
     def benchmark__arango__expansion_1(self):
-        return (
-            "MATCH (s:User {id: $id})-->(n:User) RETURN n.id",
-            {"id": self._get_random_vertex()},
-        )
+        return ("MATCH (s:User {id: $id})-->(n:User) RETURN n.id", {"id": self._get_random_vertex()})
 
     def benchmark__arango__expansion_1_with_filter(self):
         return (
@@ -166,10 +160,7 @@ class Pokec(Dataset):
         )
 
     def benchmark__arango__expansion_2(self):
-        return (
-            "MATCH (s:User {id: $id})-->()-->(n:User) RETURN DISTINCT n.id",
-            {"id": self._get_random_vertex()},
-        )
+        return ("MATCH (s:User {id: $id})-->()-->(n:User) RETURN DISTINCT n.id", {"id": self._get_random_vertex()})
 
     def benchmark__arango__expansion_2_with_filter(self):
         return (
@@ -202,10 +193,7 @@ class Pokec(Dataset):
         )
 
     def benchmark__arango__neighbours_2(self):
-        return (
-            "MATCH (s:User {id: $id})-[*1..2]->(n:User) RETURN DISTINCT n.id",
-            {"id": self._get_random_vertex()},
-        )
+        return ("MATCH (s:User {id: $id})-[*1..2]->(n:User) RETURN DISTINCT n.id", {"id": self._get_random_vertex()})
 
     def benchmark__arango__neighbours_2_with_filter(self):
         return (
@@ -282,10 +270,7 @@ class Pokec(Dataset):
         return ("MATCH (n) RETURN min(n.age), max(n.age), avg(n.age)", {})
 
     def benchmark__match__pattern_cycle(self):
-        return (
-            "MATCH (n:User {id: $id})-[e1]->(m)-[e2]->(n) RETURN e1, m, e2",
-            {"id": self._get_random_vertex()},
-        )
+        return ("MATCH (n:User {id: $id})-[e1]->(m)-[e2]->(n) RETURN e1, m, e2", {"id": self._get_random_vertex()})
 
     def benchmark__match__pattern_long(self):
         return (
@@ -294,19 +279,16 @@ class Pokec(Dataset):
         )
 
     def benchmark__match__pattern_short(self):
-        return (
-            "MATCH (n:User {id: $id})-[e]->(m) RETURN m LIMIT 1",
-            {"id": self._get_random_vertex()},
-        )
+        return ("MATCH (n:User {id: $id})-[e]->(m) RETURN m LIMIT 1", {"id": self._get_random_vertex()})
 
     def benchmark__match__vertex_on_label_property(self):
-        return (
-            "MATCH (n:User) WITH n WHERE n.id = $id RETURN n",
-            {"id": self._get_random_vertex()},
-        )
+        return ("MATCH (n:User) WITH n WHERE n.id = $id RETURN n", {"id": self._get_random_vertex()})
 
     def benchmark__match__vertex_on_label_property_index(self):
         return ("MATCH (n:User {id: $id}) RETURN n", {"id": self._get_random_vertex()})
 
     def benchmark__match__vertex_on_property(self):
         return ("MATCH (n {id: $id}) RETURN n", {"id": self._get_random_vertex()})
+
+    def benchmark__update__vertex_on_property(self):
+        return ("MATCH (n {id: $id}) SET n.property = -1", {"id": self._get_random_vertex()})

View File (likely tests/mgbench/runners.py)

@@ -51,11 +51,12 @@ def _get_usage(pid):
 class Memgraph:
-    def __init__(self, memgraph_binary, temporary_dir, properties_on_edges):
+    def __init__(self, memgraph_binary, temporary_dir, properties_on_edges, bolt_port):
         self._memgraph_binary = memgraph_binary
         self._directory = tempfile.TemporaryDirectory(dir=temporary_dir)
         self._properties_on_edges = properties_on_edges
         self._proc_mg = None
+        self._bolt_port = bolt_port
         atexit.register(self._cleanup)
 
         # Determine Memgraph version
@@ -69,6 +70,7 @@ class Memgraph:
     def _get_args(self, **kwargs):
         data_directory = os.path.join(self._directory.name, "memgraph")
+        kwargs["bolt_port"] = self._bolt_port
         if self._memgraph_version >= (0, 50, 0):
             kwargs["data_directory"] = data_directory
         else:
@@ -88,7 +90,7 @@ class Memgraph:
         if self._proc_mg.poll() is not None:
             self._proc_mg = None
             raise Exception("The database process died prematurely!")
-        wait_for_server(7687)
+        wait_for_server(self._bolt_port)
         ret = self._proc_mg.poll()
         assert ret is None, "The database process died prematurely ({})!".format(ret)
@@ -121,9 +123,14 @@ class Memgraph:
 class Client:
-    def __init__(self, client_binary, temporary_directory):
+    def __init__(
+        self, client_binary: str, temporary_directory: str, bolt_port: int, username: str = "", password: str = ""
+    ):
         self._client_binary = client_binary
         self._directory = tempfile.TemporaryDirectory(dir=temporary_directory)
+        self._username = username
+        self._password = password
+        self._bolt_port = bolt_port
 
     def _get_args(self, **kwargs):
         return _convert_args_to_flags(self._client_binary, **kwargs)
@@ -144,8 +151,15 @@ class Client:
             json.dump(query, f)
             f.write("\n")
 
-        args = self._get_args(input=file_path, num_workers=num_workers, queries_json=queries_json)
+        args = self._get_args(
+            input=file_path,
+            num_workers=num_workers,
+            queries_json=queries_json,
+            username=self._username,
+            password=self._password,
+            port=self._bolt_port,
+        )
         ret = subprocess.run(args, stdout=subprocess.PIPE, check=True)
         data = ret.stdout.decode("utf-8").strip().split("\n")
-        data = [x for x in data if not x.startswith("[")]
+        # data = [x for x in data if not x.startswith("[")]
        return list(map(json.loads, data))
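
From the caller's side, the updated runner signatures look like this (binary paths and credentials are placeholders; compare the constructors above):

    import runners

    # The Bolt port is now an explicit constructor argument...
    memgraph = runners.Memgraph("/path/to/memgraph", "/tmp", True, 7687)
    # ...and the client optionally authenticates with username/password.
    client = runners.Client("/path/to/client", "/tmp", 7687, username="user", password="test")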