From a3c2492672ce4cad6f723e5ee586d7867ce27eb6 Mon Sep 17 00:00:00 2001
From: niko4299 <51059248+niko4299@users.noreply.github.com>
Date: Thu, 15 Sep 2022 21:33:15 +0200
Subject: [PATCH] Add fine-grained access control to mgbench (#522)

---
 tests/mgbench/benchmark.py       | 369 +++++++++++++++++++------------
 tests/mgbench/compare_results.py |  94 ++++----
 tests/mgbench/datasets.py        |  38 +---
 tests/mgbench/runners.py         |  24 +-
 4 files changed, 313 insertions(+), 212 deletions(-)

diff --git a/tests/mgbench/benchmark.py b/tests/mgbench/benchmark.py
index 5ce715571..15fb67d5f 100755
--- a/tests/mgbench/benchmark.py
+++ b/tests/mgbench/benchmark.py
@@ -26,6 +26,9 @@ import log
 import helpers
 import runners
 
+WITH_FINE_GRAINED_AUTHORIZATION = "with_fine_grained_authorization"
+WITHOUT_FINE_GRAINED_AUTHORIZATION = "without_fine_grained_authorization"
+
 
 def get_queries(gen, count):
     # Make the generator deterministic.
@@ -37,8 +40,7 @@ def get_queries(gen, count):
     return ret
 
 
-def match_patterns(dataset, variant, group, test, is_default_variant,
-                   patterns):
+def match_patterns(dataset, variant, group, test, is_default_variant, patterns):
     for pattern in patterns:
         verdict = [fnmatch.fnmatchcase(dataset, pattern[0])]
         if pattern[1] != "":
@@ -58,7 +60,7 @@ def filter_benchmarks(generators, patterns):
         pattern = patterns[i].split("/")
         if len(pattern) > 4 or len(pattern) == 0:
             raise Exception("Invalid benchmark description '" + pattern + "'!")
-        pattern.extend(["", "*", "*"][len(pattern) - 1:])
+        pattern.extend(["", "*", "*"][len(pattern) - 1 :])
         patterns[i] = pattern
     filtered = []
     for dataset in sorted(generators.keys()):
@@ -68,8 +70,7 @@ def filter_benchmarks(generators, patterns):
             current = collections.defaultdict(list)
             for group in tests:
                 for test_name, test_func in tests[group]:
-                    if match_patterns(dataset, variant, group, test_name,
-                                      is_default_variant, patterns):
+                    if match_patterns(dataset, variant, group, test_name, is_default_variant, patterns):
                         current[group].append((test_name, test_func))
             if len(current) > 0:
                 filtered.append((generator(variant), dict(current)))
@@ -79,43 +80,72 @@ def filter_benchmarks(generators, patterns):
 # Parse options.
 parser = argparse.ArgumentParser(
     description="Memgraph benchmark executor.",
-    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-parser.add_argument("benchmarks", nargs="*", default="",
-                    help="descriptions of benchmarks that should be run; "
-                    "multiple descriptions can be specified to run multiple "
-                    "benchmarks; the description is specified as "
-                    "dataset/variant/group/test; Unix shell-style wildcards "
-                    "can be used in the descriptions; variant, group and test "
-                    "are optional and they can be left out; the default "
-                    "variant is '' which selects the default dataset variant; "
-                    "the default group is '*' which selects all groups; the "
-                    "default test is '*' which selects all tests")
-parser.add_argument("--memgraph-binary",
-                    default=helpers.get_binary_path("memgraph"),
-                    help="Memgraph binary used for benchmarking")
-parser.add_argument("--client-binary",
-                    default=helpers.get_binary_path("tests/mgbench/client"),
-                    help="client binary used for benchmarking")
-parser.add_argument("--num-workers-for-import", type=int,
-                    default=multiprocessing.cpu_count() // 2,
-                    help="number of workers used to import the dataset")
-parser.add_argument("--num-workers-for-benchmark", type=int,
-                    default=1,
-                    help="number of workers used to execute the benchmark")
-parser.add_argument("--single-threaded-runtime-sec", type=int,
-                    default=10,
-                    help="single threaded duration of each test")
-parser.add_argument("--no-load-query-counts", action="store_true",
-                    help="disable loading of cached query counts")
-parser.add_argument("--no-save-query-counts", action="store_true",
-                    help="disable storing of cached query counts")
-parser.add_argument("--export-results", default="",
-                    help="file path into which results should be exported")
-parser.add_argument("--temporary-directory", default="/tmp",
-                    help="directory path where temporary data should "
-                    "be stored")
-parser.add_argument("--no-properties-on-edges", action="store_true",
-                    help="disable properties on edges")
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+)
+parser.add_argument(
+    "benchmarks",
+    nargs="*",
+    default="",
+    help="descriptions of benchmarks that should be run; "
+    "multiple descriptions can be specified to run multiple "
+    "benchmarks; the description is specified as "
+    "dataset/variant/group/test; Unix shell-style wildcards "
+    "can be used in the descriptions; variant, group and test "
+    "are optional and they can be left out; the default "
+    "variant is '' which selects the default dataset variant; "
+    "the default group is '*' which selects all groups; the "
+    "default test is '*' which selects all tests",
+)
+parser.add_argument(
+    "--memgraph-binary",
+    default=helpers.get_binary_path("memgraph"),
+    help="Memgraph binary used for benchmarking",
+)
+parser.add_argument(
+    "--client-binary",
+    default=helpers.get_binary_path("tests/mgbench/client"),
+    help="client binary used for benchmarking",
+)
+parser.add_argument(
+    "--num-workers-for-import",
+    type=int,
+    default=multiprocessing.cpu_count() // 2,
+    help="number of workers used to import the dataset",
+)
+parser.add_argument(
+    "--num-workers-for-benchmark",
+    type=int,
+    default=1,
+    help="number of workers used to execute the benchmark",
+)
+parser.add_argument(
+    "--single-threaded-runtime-sec",
+    type=int,
+    default=10,
+    help="single threaded duration of each test",
+)
+parser.add_argument(
+    "--no-load-query-counts",
+    action="store_true",
+    help="disable loading of cached query counts",
+)
+parser.add_argument(
+    "--no-save-query-counts",
+    action="store_true",
+    help="disable storing of cached query counts",
+)
+parser.add_argument(
+    "--export-results",
+    default="",
+    help="file path into which results should be exported",
+)
+parser.add_argument(
+    "--temporary-directory",
+    default="/tmp",
+    help="directory path where temporary data should be stored",
+)
+parser.add_argument("--no-properties-on-edges", action="store_true", help="disable properties on edges")
+parser.add_argument("--bolt-port", type=int, default=7687, help="Memgraph Bolt port")
 args = parser.parse_args()
 
 # Detect available datasets.
@@ -124,8 +154,7 @@ for key in dir(datasets):
     if key.startswith("_"):
         continue
     dataset = getattr(datasets, key)
-    if not inspect.isclass(dataset) or dataset == datasets.Dataset or \
-            not issubclass(dataset, datasets.Dataset):
+    if not inspect.isclass(dataset) or dataset == datasets.Dataset or not issubclass(dataset, datasets.Dataset):
         continue
     tests = collections.defaultdict(list)
     for funcname in dir(dataset):
@@ -135,8 +164,9 @@ for key in dir(datasets):
         tests[group].append((test, funcname))
     generators[dataset.NAME] = (dataset, dict(tests))
     if dataset.PROPERTIES_ON_EDGES and args.no_properties_on_edges:
-        raise Exception("The \"{}\" dataset requires properties on edges, "
-                        "but you have disabled them!".format(dataset.NAME))
+        raise Exception(
+            'The "{}" dataset requires properties on edges, but you have disabled them!'.format(dataset.NAME)
+        )
 
 # List datasets if there is no specified dataset.
 if len(args.benchmarks) == 0:
@@ -144,8 +174,11 @@ if len(args.benchmarks) == 0:
     for name in sorted(generators.keys()):
         print("Dataset:", name)
         dataset, tests = generators[name]
-        print("    Variants:", ", ".join(dataset.VARIANTS),
-              "(default: " + dataset.DEFAULT_VARIANT + ")")
+        print(
+            "    Variants:",
+            ", ".join(dataset.VARIANTS),
+            "(default: " + dataset.DEFAULT_VARIANT + ")",
+        )
         for group in sorted(tests.keys()):
             print("    Group:", group)
             for test_name, test_func in tests[group]:
@@ -162,34 +195,45 @@ results = helpers.RecursiveDict()
 
 # Filter out the generators.
 benchmarks = filter_benchmarks(generators, args.benchmarks)
-
 # Run all specified benchmarks.
 for dataset, tests in benchmarks:
-    log.init("Preparing", dataset.NAME + "/" + dataset.get_variant(),
-             "dataset")
-    dataset.prepare(cache.cache_directory("datasets", dataset.NAME,
-                                          dataset.get_variant()))
+    log.init("Preparing", dataset.NAME + "/" + dataset.get_variant(), "dataset")
+    dataset.prepare(cache.cache_directory("datasets", dataset.NAME, dataset.get_variant()))
 
     # Prepare runners and import the dataset.
-    memgraph = runners.Memgraph(args.memgraph_binary, args.temporary_directory,
-                                not args.no_properties_on_edges)
-    client = runners.Client(args.client_binary, args.temporary_directory)
+    memgraph = runners.Memgraph(
+        args.memgraph_binary,
+        args.temporary_directory,
+        not args.no_properties_on_edges,
+        args.bolt_port,
+    )
+    client = runners.Client(args.client_binary, args.temporary_directory, args.bolt_port)
     memgraph.start_preparation()
-    ret = client.execute(file_path=dataset.get_file(),
-                         num_workers=args.num_workers_for_import)
+    ret = client.execute(file_path=dataset.get_file(), num_workers=args.num_workers_for_import)
     usage = memgraph.stop()
 
     # Display import statistics.
     print()
     for row in ret:
-        print("Executed", row["count"], "queries in", row["duration"],
-              "seconds using", row["num_workers"],
-              "workers with a total throughput of", row["throughput"],
-              "queries/second.")
+        print(
+            "Executed",
+            row["count"],
+            "queries in",
+            row["duration"],
+            "seconds using",
+            row["num_workers"],
+            "workers with a total throughput of",
+            row["throughput"],
+            "queries/second.",
+        )
     print()
-    print("The database used", usage["cpu"],
-          "seconds of CPU time and peaked at",
-          usage["memory"] / 1024 / 1024, "MiB of RAM.")
+    print(
+        "The database used",
+        usage["cpu"],
+        "seconds of CPU time and peaked at",
+        usage["memory"] / 1024 / 1024,
+        "MiB of RAM.",
+    )
 
     # Save import results.
     import_key = [dataset.NAME, dataset.get_variant(), "__import__"]
@@ -198,87 +242,128 @@ for dataset, tests in benchmarks:
     # TODO: cache import data
 
     # Run all benchmarks in all available groups.
-    for group in sorted(tests.keys()):
-        for test, funcname in tests[group]:
-            log.info("Running test:", "{}/{}".format(group, test))
-            func = getattr(dataset, funcname)
 
-            # Get number of queries to execute.
-            # TODO: implement minimum number of queries, `max(10, num_workers)`
-            config_key = [dataset.NAME, dataset.get_variant(), group, test]
-            cached_count = config.get_value(*config_key)
-            if cached_count is None:
-                print("Determining the number of queries necessary for",
-                      args.single_threaded_runtime_sec,
-                      "seconds of single-threaded runtime...")
-                # First run to prime the query caches.
+    for with_fine_grained_authorization in [False, True]:
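+        # Each dataset's benchmarks run twice: first without, then with fine-grained authorization enabled.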
+        if with_fine_grained_authorization:
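+            # Re-import the dataset and create a "user" account that is granted all
+            # privileges plus CREATE_DELETE on every label and edge type, then switch
+            # the client to authenticate as that user.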
+            memgraph.start_preparation()
+            client.execute(file_path=dataset.get_file(), num_workers=args.num_workers_for_import)
+            client.execute(
+                queries=[
+                    ("CREATE USER user IDENTIFIED BY 'test';", {}),
+                    ("GRANT ALL PRIVILEGES TO user;", {}),
+                    ("GRANT CREATE_DELETE ON EDGE_TYPES * TO user;", {}),
+                    ("GRANT CREATE_DELETE ON LABELS * TO user;", {}),
+                ]
+            )
+            client = runners.Client(
+                args.client_binary,
+                args.temporary_directory,
+                args.bolt_port,
+                username="user",
+                password="test",
+            )
+            memgraph.stop()
+
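+        # test_type tags both the cached query counts and the exported results for this pass.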
+        test_type = (
+            WITH_FINE_GRAINED_AUTHORIZATION if with_fine_grained_authorization else WITHOUT_FINE_GRAINED_AUTHORIZATION
+        )
+
+        for group in sorted(tests.keys()):
+            for test, funcname in tests[group]:
+                log.info("Running test:", "{}/{}/{}".format(group, test, test_type))
+                func = getattr(dataset, funcname)
+
+                # Get number of queries to execute.
+                # TODO: implement minimum number of queries, `max(10, num_workers)`
+                config_key = [dataset.NAME, dataset.get_variant(), group, test, test_type]
+                cached_count = config.get_value(*config_key)
+                if cached_count is None:
+                    print(
+                        "Determining the number of queries necessary for",
+                        args.single_threaded_runtime_sec,
+                        "seconds of single-threaded runtime...",
+                    )
+                    # First run to prime the query caches.
+                    memgraph.start_benchmark()
+                    client.execute(queries=get_queries(func, 1), num_workers=1)
+                    # Get a sense of the runtime.
+                    count = 1
+                    while True:
+                        ret = client.execute(queries=get_queries(func, count), num_workers=1)
+                        duration = ret[0]["duration"]
+                        should_execute = int(args.single_threaded_runtime_sec / (duration / count))
+                        print(
+                            "executed_queries={}, total_duration={}, "
+                            "query_duration={}, estimated_count={}".format(
+                                count, duration, duration / count, should_execute
+                            )
+                        )
+                        # We don't have to execute the next iteration when
+                        # `should_execute` becomes the same order of magnitude as
+                        # `count * 10`.
+                        if should_execute / (count * 10) < 10:
+                            count = should_execute
+                            break
+                        else:
+                            count = count * 10
+                    memgraph.stop()
+                    config.set_value(
+                        *config_key,
+                        value={
+                            "count": count,
+                            "duration": args.single_threaded_runtime_sec,
+                        },
+                    )
+                else:
+                    print(
+                        "Using cached query count of",
+                        cached_count["count"],
+                        "queries for",
+                        cached_count["duration"],
+                        "seconds of single-threaded runtime.",
+                    )
+                    count = int(cached_count["count"] * args.single_threaded_runtime_sec / cached_count["duration"])
+
+                # Benchmark run.
+                print("Sample query:", get_queries(func, 1)[0][0])
+                print(
+                    "Executing benchmark with",
+                    count,
+                    "queries that should yield a single-threaded runtime of",
+                    args.single_threaded_runtime_sec,
+                    "seconds.",
+                )
+                print(
+                    "Queries are executed using",
+                    args.num_workers_for_benchmark,
+                    "concurrent clients.",
+                )
                 memgraph.start_benchmark()
-                client.execute(queries=get_queries(func, 1), num_workers=1)
-                # Get a sense of the runtime.
-                count = 1
-                while True:
-                    ret = client.execute(queries=get_queries(func, count),
-                                         num_workers=1)
-                    duration = ret[0]["duration"]
-                    should_execute = int(args.single_threaded_runtime_sec /
-                                         (duration / count))
-                    print("executed_queries={}, total_duration={}, "
-                          "query_duration={}, estimated_count={}".format(
-                              count, duration, duration / count,
-                              should_execute))
-                    # We don't have to execute the next iteration when
-                    # `should_execute` becomes the same order of magnitude as
-                    # `count * 10`.
-                    if should_execute / (count * 10) < 10:
-                        count = should_execute
-                        break
-                    else:
-                        count = count * 10
-                memgraph.stop()
-                config.set_value(*config_key, value={
-                    "count": count,
-                    "duration": args.single_threaded_runtime_sec})
-            else:
-                print("Using cached query count of", cached_count["count"],
-                      "queries for", cached_count["duration"],
-                      "seconds of single-threaded runtime.")
-                count = int(cached_count["count"] *
-                            args.single_threaded_runtime_sec /
-                            cached_count["duration"])
+                ret = client.execute(
+                    queries=get_queries(func, count),
+                    num_workers=args.num_workers_for_benchmark,
+                )[0]
+                usage = memgraph.stop()
+                ret["database"] = usage
 
-            # Benchmark run.
-            print("Sample query:", get_queries(func, 1)[0][0])
-            print("Executing benchmark with", count, "queries that should "
-                  "yield a single-threaded runtime of",
-                  args.single_threaded_runtime_sec, "seconds.")
-            print("Queries are executed using", args.num_workers_for_benchmark,
-                  "concurrent clients.")
-            memgraph.start_benchmark()
-            ret = client.execute(queries=get_queries(func, count),
-                                 num_workers=args.num_workers_for_benchmark)[0]
-            usage = memgraph.stop()
-            ret["database"] = usage
+                # Output summary.
+                print()
+                print("Executed", ret["count"], "queries in", ret["duration"], "seconds.")
+                print("Queries have been retried", ret["retries"], "times.")
+                print("Database used {:.3f} seconds of CPU time.".format(usage["cpu"]))
+                print("Database peaked at {:.3f} MiB of memory.".format(usage["memory"] / 1024.0 / 1024.0))
+                print("{:<31} {:>20} {:>20} {:>20}".format("Metadata:", "min", "avg", "max"))
+                metadata = ret["metadata"]
+                for key in sorted(metadata.keys()):
+                    print(
+                        "{name:>30}: {minimum:>20.06f} {average:>20.06f} "
+                        "{maximum:>20.06f}".format(name=key, **metadata[key])
+                    )
+                log.success("Throughput: {:02f} QPS".format(ret["throughput"]))
 
-            # Output summary.
-            print()
-            print("Executed", ret["count"], "queries in",
-                  ret["duration"], "seconds.")
-            print("Queries have been retried", ret["retries"], "times.")
-            print("Database used {:.3f} seconds of CPU time.".format(
-                  usage["cpu"]))
-            print("Database peaked at {:.3f} MiB of memory.".format(
-                  usage["memory"] / 1024.0 / 1024.0))
-            print("{:<31} {:>20} {:>20} {:>20}".format("Metadata:", "min",
-                                                       "avg", "max"))
-            metadata = ret["metadata"]
-            for key in sorted(metadata.keys()):
-                print("{name:>30}: {minimum:>20.06f} {average:>20.06f} "
-                      "{maximum:>20.06f}".format(name=key, **metadata[key]))
-            log.success("Throughput: {:02f} QPS".format(ret["throughput"]))
-
-            # Save results.
-            results_key = [dataset.NAME, dataset.get_variant(), group, test]
-            results.set_value(*results_key, value=ret)
+                # Save results.
+                results_key = [dataset.NAME, dataset.get_variant(), group, test, test_type]
+                results.set_value(*results_key, value=ret)
 
 # Save configuration.
 if not args.no_save_query_counts:
diff --git a/tests/mgbench/compare_results.py b/tests/mgbench/compare_results.py
index 2179bb408..17703b9dd 100755
--- a/tests/mgbench/compare_results.py
+++ b/tests/mgbench/compare_results.py
@@ -77,7 +77,7 @@ def recursive_get(data, *args, value=None):
     return data
 
 
-def compare_results(results_from, results_to, fields):
+def compare_results(results_from, results_to, fields, ignored):
     ret = {}
     for dataset, variants in results_to.items():
         for variant, groups in variants.items():
@@ -85,39 +85,44 @@ def compare_results(results_from, results_to, fields):
                 if group == "__import__":
                     continue
                 for scenario, summary_to in scenarios.items():
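+                    # Skip scenarios that were explicitly excluded from the comparison.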
-                    summary_from = recursive_get(
-                        results_from, dataset, variant, group, scenario,
-                        value={})
-                    if len(summary_from) > 0 and \
-                            summary_to["count"] != summary_from["count"] or \
-                            summary_to["num_workers"] != \
-                            summary_from["num_workers"]:
+                    if scenario in ignored:
+                        continue
+
+                    summary_from = recursive_get(results_from, dataset, variant, group, scenario, value={})
+                    if len(summary_from) > 0 and (
+                        summary_to["count"] != summary_from["count"]
+                        or summary_to["num_workers"] != summary_from["num_workers"]
+                    ):
                         raise Exception("Incompatible results!")
-                    testcode = "/".join([dataset, variant, group, scenario,
-                                         "{:02d}".format(
-                                             summary_to["num_workers"])])
+                    testcode = "/".join(
+                        [
+                            dataset,
+                            variant,
+                            group,
+                            scenario,
+                            "{:02d}".format(summary_to["num_workers"]),
+                        ]
+                    )
                     row = {}
                     performance_changed = False
                     for field in fields:
                         key = field["name"]
                         if key in summary_to:
-                            row[key] = compute_diff(
-                                summary_from.get(key, None),
-                                summary_to[key])
+                            row[key] = compute_diff(summary_from.get(key, None), summary_to[key])
                         elif key in summary_to["database"]:
                             row[key] = compute_diff(
-                                recursive_get(summary_from, "database", key,
-                                              value=None),
-                                summary_to["database"][key])
+                                recursive_get(summary_from, "database", key, value=None),
+                                summary_to["database"][key],
+                            )
                         else:
                             row[key] = compute_diff(
-                                recursive_get(summary_from, "metadata", key,
-                                              "average", value=None),
-                                summary_to["metadata"][key]["average"])
-                        if "diff" not in row[key] or \
-                                ("diff_treshold" in field and
-                                 abs(row[key]["diff"]) >=
-                                 field["diff_treshold"]):
+                                recursive_get(summary_from, "metadata", key, "average", value=None),
+                                summary_to["metadata"][key]["average"],
+                            )
+                        if "diff" not in row[key] or (
+                            "diff_treshold" in field and abs(row[key]["diff"]) >= field["diff_treshold"]
+                        ):
                             performance_changed = True
                     if performance_changed:
                         ret[testcode] = row
@@ -130,8 +135,15 @@ def generate_remarkup(fields, data):
         ret += "<table>\n"
         ret += "  <tr>\n"
         ret += "    <th>Testcode</th>\n"
-        ret += "\n".join(map(lambda x: "    <th>{}</th>".format(
-            x["name"].replace("_", " ").capitalize()), fields)) + "\n"
+        ret += (
+            "\n".join(
+                map(
+                    lambda x: "    <th>{}</th>".format(x["name"].replace("_", " ").capitalize()),
+                    fields,
+                )
+            )
+            + "\n"
+        )
         ret += "  </tr>\n"
         for testcode in sorted(data.keys()):
             ret += "  <tr>\n"
@@ -147,12 +159,9 @@ def generate_remarkup(fields, data):
                     else:
                         color = "red"
                     sign = "{{icon {} color={}}}".format(arrow, color)
-                    ret += "    <td>{:.3f}{} //({:+.2%})// {}</td>\n".format(
-                        value, field["unit"], diff, sign)
+                    ret += '    <td bgcolor="{}">{:.3f}{} ({:+.2%})</td>\n'.format(color, value, field["unit"], diff)
                 else:
-                    ret += "    <td>{:.3f}{} //(new)// " \
-                           "{{icon plus color=blue}}</td>\n".format(
-                               value, field["unit"])
+                    ret += '    <td bgcolor="blue">{:.3f}{} //(new)//</td>\n'.format(value, field["unit"])
             ret += "  </tr>\n"
         ret += "</table>\n"
     else:
@@ -161,22 +170,33 @@ def generate_remarkup(fields, data):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Compare results of multiple benchmark runs.")
-    parser.add_argument("--compare", action="append", nargs=2,
-                        metavar=("from", "to"),
-                        help="compare results between `from` and `to` files")
+    parser = argparse.ArgumentParser(description="Compare results of multiple benchmark runs.")
+    parser.add_argument(
+        "--compare",
+        action="append",
+        nargs=2,
+        metavar=("from", "to"),
+        help="compare results between `from` and `to` files",
+    )
     parser.add_argument("--output", default="", help="output file name")
+    # The file is read line by line; each line holds one test name to exclude.
+    parser.add_argument("--exclude_tests_file", help="file listing test names to be excluded")
     args = parser.parse_args()
 
     if args.compare is None or len(args.compare) == 0:
         raise Exception("You must specify at least one pair of files!")
 
+    if args.exclude_tests_file:
+        with open(args.exclude_tests_file, "r") as f:
+            ignored = [line.rstrip("\n") for line in f]
+    else:
+        ignored = []
+
     data = {}
     for file_from, file_to in args.compare:
         results_from = load_results(file_from)
         results_to = load_results(file_to)
-        data.update(compare_results(results_from, results_to, FIELDS))
+        data.update(compare_results(results_from, results_to, FIELDS, ignored))
 
     remarkup = generate_remarkup(FIELDS, data)
     if args.output:
diff --git a/tests/mgbench/datasets.py b/tests/mgbench/datasets.py
index c68fcac34..c0508e2d2 100644
--- a/tests/mgbench/datasets.py
+++ b/tests/mgbench/datasets.py
@@ -135,10 +135,7 @@ class Pokec(Dataset):
         return ("MATCH (n:User {id : $id}) RETURN n", {"id": self._get_random_vertex()})
 
     def benchmark__arango__single_vertex_write(self):
-        return (
-            "CREATE (n:UserTemp {id : $id}) RETURN n",
-            {"id": random.randint(1, self._num_vertices * 10)},
-        )
+        return ("CREATE (n:UserTemp {id : $id}) RETURN n", {"id": random.randint(1, self._num_vertices * 10)})
 
     def benchmark__arango__single_edge_write(self):
         vertex_from, vertex_to = self._get_random_from_to()
@@ -154,10 +151,7 @@ class Pokec(Dataset):
         return ("MATCH (n:User) WHERE n.age >= 18 RETURN n.age, COUNT(*)", {})
 
     def benchmark__arango__expansion_1(self):
-        return (
-            "MATCH (s:User {id: $id})-->(n:User) " "RETURN n.id",
-            {"id": self._get_random_vertex()},
-        )
+        return ("MATCH (s:User {id: $id})-->(n:User) RETURN n.id", {"id": self._get_random_vertex()})
 
     def benchmark__arango__expansion_1_with_filter(self):
         return (
@@ -166,10 +160,7 @@ class Pokec(Dataset):
         )
 
     def benchmark__arango__expansion_2(self):
-        return (
-            "MATCH (s:User {id: $id})-->()-->(n:User) " "RETURN DISTINCT n.id",
-            {"id": self._get_random_vertex()},
-        )
+        return ("MATCH (s:User {id: $id})-->()-->(n:User) RETURN DISTINCT n.id", {"id": self._get_random_vertex()})
 
     def benchmark__arango__expansion_2_with_filter(self):
         return (
@@ -202,10 +193,7 @@ class Pokec(Dataset):
         )
 
     def benchmark__arango__neighbours_2(self):
-        return (
-            "MATCH (s:User {id: $id})-[*1..2]->(n:User) " "RETURN DISTINCT n.id",
-            {"id": self._get_random_vertex()},
-        )
+        return ("MATCH (s:User {id: $id})-[*1..2]->(n:User) RETURN DISTINCT n.id", {"id": self._get_random_vertex()})
 
     def benchmark__arango__neighbours_2_with_filter(self):
         return (
@@ -282,10 +270,7 @@ class Pokec(Dataset):
         return ("MATCH (n) RETURN min(n.age), max(n.age), avg(n.age)", {})
 
     def benchmark__match__pattern_cycle(self):
-        return (
-            "MATCH (n:User {id: $id})-[e1]->(m)-[e2]->(n) " "RETURN e1, m, e2",
-            {"id": self._get_random_vertex()},
-        )
+        return ("MATCH (n:User {id: $id})-[e1]->(m)-[e2]->(n) RETURN e1, m, e2", {"id": self._get_random_vertex()})
 
     def benchmark__match__pattern_long(self):
         return (
@@ -294,19 +279,16 @@ class Pokec(Dataset):
         )
 
     def benchmark__match__pattern_short(self):
-        return (
-            "MATCH (n:User {id: $id})-[e]->(m) " "RETURN m LIMIT 1",
-            {"id": self._get_random_vertex()},
-        )
+        return ("MATCH (n:User {id: $id})-[e]->(m) RETURN m LIMIT 1", {"id": self._get_random_vertex()})
 
     def benchmark__match__vertex_on_label_property(self):
-        return (
-            "MATCH (n:User) WITH n WHERE n.id = $id RETURN n",
-            {"id": self._get_random_vertex()},
-        )
+        return ("MATCH (n:User) WITH n WHERE n.id = $id RETURN n", {"id": self._get_random_vertex()})
 
     def benchmark__match__vertex_on_label_property_index(self):
         return ("MATCH (n:User {id: $id}) RETURN n", {"id": self._get_random_vertex()})
 
     def benchmark__match__vertex_on_property(self):
         return ("MATCH (n {id: $id}) RETURN n", {"id": self._get_random_vertex()})
+
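+    # Update benchmark: overwrite a single property on a randomly chosen vertex.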
+    def benchmark__update__vertex_on_property(self):
+        return ("MATCH (n {id: $id}) SET n.property = -1", {"id": self._get_random_vertex()})
diff --git a/tests/mgbench/runners.py b/tests/mgbench/runners.py
index b6b727ed8..bc2142cba 100644
--- a/tests/mgbench/runners.py
+++ b/tests/mgbench/runners.py
@@ -51,11 +51,12 @@ def _get_usage(pid):
 
 
 class Memgraph:
-    def __init__(self, memgraph_binary, temporary_dir, properties_on_edges):
+    def __init__(self, memgraph_binary, temporary_dir, properties_on_edges, bolt_port):
         self._memgraph_binary = memgraph_binary
         self._directory = tempfile.TemporaryDirectory(dir=temporary_dir)
         self._properties_on_edges = properties_on_edges
         self._proc_mg = None
+        self._bolt_port = bolt_port
         atexit.register(self._cleanup)
 
         # Determine Memgraph version
@@ -69,6 +70,7 @@ class Memgraph:
 
     def _get_args(self, **kwargs):
         data_directory = os.path.join(self._directory.name, "memgraph")
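+        # Forward the configured Bolt port to the database process.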
+        kwargs["bolt_port"] = self._bolt_port
         if self._memgraph_version >= (0, 50, 0):
             kwargs["data_directory"] = data_directory
         else:
@@ -88,7 +90,7 @@ class Memgraph:
         if self._proc_mg.poll() is not None:
             self._proc_mg = None
             raise Exception("The database process died prematurely!")
-        wait_for_server(7687)
+        wait_for_server(self._bolt_port)
         ret = self._proc_mg.poll()
         assert ret is None, "The database process died prematurely " "({})!".format(ret)
 
@@ -121,9 +123,14 @@ class Memgraph:
 
 
 class Client:
-    def __init__(self, client_binary, temporary_directory):
+    def __init__(
+        self, client_binary: str, temporary_directory: str, bolt_port: int, username: str = "", password: str = ""
+    ):
         self._client_binary = client_binary
         self._directory = tempfile.TemporaryDirectory(dir=temporary_directory)
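+        # Optional credentials for the benchmark client; empty strings by default.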
+        self._username = username
+        self._password = password
+        self._bolt_port = bolt_port
 
     def _get_args(self, **kwargs):
         return _convert_args_to_flags(self._client_binary, **kwargs)
@@ -144,8 +151,15 @@ class Client:
                     json.dump(query, f)
                     f.write("\n")
 
-        args = self._get_args(input=file_path, num_workers=num_workers, queries_json=queries_json)
+        args = self._get_args(
+            input=file_path,
+            num_workers=num_workers,
+            queries_json=queries_json,
+            username=self._username,
+            password=self._password,
+            port=self._bolt_port,
+        )
         ret = subprocess.run(args, stdout=subprocess.PIPE, check=True)
         data = ret.stdout.decode("utf-8").strip().split("\n")
-        data = [x for x in data if not x.startswith("[")]
+        # data = [x for x in data if not x.startswith("[")]
         return list(map(json.loads, data))