# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.

import collections
import copy
import fnmatch
import importlib
import inspect
import json
import os
import subprocess
import sys
from pathlib import Path

import workloads
from benchmark_context import BenchmarkContext
from workloads import *
from workloads import base

SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))

def get_binary_path(path, base=""):
|
|
dirpath = os.path.normpath(os.path.join(SCRIPT_DIR, "..", ".."))
|
|
if os.path.exists(os.path.join(dirpath, "build_release")):
|
|
dirpath = os.path.join(dirpath, "build_release")
|
|
else:
|
|
dirpath = os.path.join(dirpath, "build")
|
|
return os.path.join(dirpath, path)
|
|
|
|
|
|
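# A minimal usage sketch: assuming a checkout rooted two directory levels
# above this script with a "build_release" directory present, the (illustrative)
# binary name below resolves into that directory.
def _example_binary_path():
    return get_binary_path("memgraph")  # -> <repo>/build_release/memgraph

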
def download_file(url, path):
    """Download ``url`` into ``path`` via wget and return the saved file's path."""
    ret = subprocess.run(["wget", "-nv", "--content-disposition", url], stderr=subprocess.PIPE, cwd=path, check=True)
    data = ret.stderr.decode("utf-8")
    # wget -nv reports the saved file name on stderr as: ... -> "name" ...
    tmp = data.split("->")[1]
    name = tmp[tmp.index('"') + 1 : tmp.rindex('"')]
    return os.path.join(path, name)

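# Sketch of the stderr parsing above. The sample line mimics GNU wget's
# no-verbose output format (assumed here, not captured from a real run):
def _example_wget_name_parse():
    data = '2023-01-01 00:00:00 URL:https://example.com/x.gz [10/10] -> "x.gz" [1]'
    tmp = data.split("->")[1]
    return tmp[tmp.index('"') + 1 : tmp.rindex('"')]  # -> 'x.gz'

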
def unpack_gz_and_move_file(input_path, output_path):
    if input_path.endswith(".gz"):
        subprocess.run(["gunzip", input_path], stdout=subprocess.DEVNULL, check=True)
        input_path = input_path[:-3]
    os.rename(input_path, output_path)

def unpack_gz(input_path: Path):
    if input_path.suffix == ".gz":
        subprocess.run(["gzip", "-d", input_path], capture_output=True, check=True)
        input_path = input_path.with_suffix("")
    return input_path

def unpack_zip(input_path: Path):
    if input_path.suffix == ".zip":
        subprocess.run(["unzip", input_path], capture_output=True, check=True, cwd=input_path.parent)
        input_path = input_path.with_suffix("")
    return input_path

def unpack_tar_zst(input_path: Path):
    if input_path.suffix == ".zst":
        subprocess.run(
            ["tar", "--use-compress-program=unzstd", "-xvf", input_path],
            cwd=input_path.parent,
            capture_output=True,
            check=True,
        )
        input_path = input_path.with_suffix("").with_suffix("")
    return input_path

def unpack_tar_gz(input_path: Path):
    if input_path.suffix == ".gz":
        subprocess.run(
            ["tar", "-xvf", input_path],
            cwd=input_path.parent,
            capture_output=True,
            check=True,
        )
        input_path = input_path.with_suffix("").with_suffix("")
    return input_path

def unpack_tar_zst_and_move(input_path: Path, output_path: Path):
    if input_path.suffix == ".zst":
        subprocess.run(
            ["tar", "--use-compress-program=unzstd", "-xvf", input_path],
            cwd=input_path.parent,
            capture_output=True,
            check=True,
        )
        input_path = input_path.with_suffix("").with_suffix("")
    return input_path.rename(output_path)

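# Sketch of the double suffix stripping used by the tar helpers above:
# with_suffix("") removes one extension per call, so "data.tar.zst" becomes
# "data.tar" and then "data" (the file name is illustrative).
def _example_suffix_strip():
    return Path("data.tar.zst").with_suffix("").with_suffix("")  # -> Path("data")

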
def ensure_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)
    if not os.path.isdir(path):
        raise Exception("The path '{}' should be a directory!".format(path))

def get_available_workloads(customer_workloads: str = None) -> dict:
    """Collect all Workload subclasses and their benchmark__<group>__<query> methods."""
    generators = {}
    for module in map(workloads.__dict__.get, workloads.__all__):
        for key in dir(module):
            if key.startswith("_"):
                continue
            base_class = getattr(module, key)
            if not inspect.isclass(base_class) or not issubclass(base_class, base.Workload):
                continue
            queries = collections.defaultdict(list)
            for funcname in dir(base_class):
                if not funcname.startswith("benchmark__"):
                    continue
                group, query = funcname.split("__")[1:]
                queries[group].append((query, funcname))
            generators[base_class.NAME] = (base_class, dict(queries))

    if customer_workloads:
        # Mirror the discovery above for a customer-supplied workload module.
        head_tail = os.path.split(customer_workloads)
        path_without_dataset_name = head_tail[0]
        dataset_name = head_tail[1].split(".")[0]
        sys.path.append(path_without_dataset_name)
        dataset_to_use = importlib.import_module(dataset_name)

        for key in dir(dataset_to_use):
            if key.startswith("_"):
                continue
            base_class = getattr(dataset_to_use, key)
            if not inspect.isclass(base_class) or not issubclass(base_class, base.Workload):
                continue
            queries = collections.defaultdict(list)
            for funcname in dir(base_class):
                if not funcname.startswith("benchmark__"):
                    continue
                group, query = funcname.split("__")[1:]
                queries[group].append((query, funcname))
            generators[base_class.NAME] = (base_class, dict(queries))

    return generators

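# Sketch of the returned mapping's shape; "pokec" and the method name below
# are hypothetical, not guaranteed entries of the real workloads package:
#   generators["pokec"] == (Pokec, {"arango": [
#       ("single_vertex_read", "benchmark__arango__single_vertex_read")]})
def _example_list_generator_names():
    return sorted(get_available_workloads().keys())

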
def list_available_workloads(customer_workloads: str = None):
    generators = get_available_workloads(customer_workloads)
    for name in sorted(generators.keys()):
        print("Dataset:", name)
        dataset, queries = generators[name]
        print(
            "    Variants:",
            ", ".join(dataset.VARIANTS),
            "(default: " + dataset.DEFAULT_VARIANT + ")",
        )
        for group in sorted(queries.keys()):
            print("    Group:", group)
            for query_name, query_func in queries[group]:
                print("        Query:", query_name)

def match_patterns(workload, variant, group, query, is_default_variant, patterns):
    for pattern in patterns:
        verdict = [fnmatch.fnmatchcase(workload, pattern[0])]
        if pattern[1] != "":
            verdict.append(fnmatch.fnmatchcase(variant, pattern[1]))
        else:
            # An empty variant field matches only the workload's default variant.
            verdict.append(is_default_variant)
        verdict.append(fnmatch.fnmatchcase(group, pattern[2]))
        verdict.append(fnmatch.fnmatchcase(query, pattern[3]))
        if all(verdict):
            return True
    return False

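# Sketch of the matching semantics, with made-up names: an empty variant
# field constrains the match to the default variant, while "*" fields match
# anything.
def _example_match():
    return match_patterns(
        workload="pokec",
        variant="small",
        group="arango",
        query="single_vertex_read",
        is_default_variant=True,
        patterns=[["pokec", "", "arango", "*"]],
    )  # -> True

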
def filter_workloads(available_workloads: dict, benchmark_context: BenchmarkContext) -> list:
    patterns = benchmark_context.benchmark_target_workload
    for i in range(len(patterns)):
        pattern = patterns[i].split("/")
        if len(pattern) > 5 or len(pattern) == 0:
            # Use the original string here; concatenating the split list would
            # itself raise a TypeError instead of the intended message.
            raise Exception("Invalid benchmark description '" + patterns[i] + "'!")
        # Pad a short workload[/variant[/group[/query]]] pattern out to four fields.
        pattern.extend(["", "*", "*"][len(pattern) - 1 :])
        patterns[i] = pattern
    filtered = []
    for workload in sorted(available_workloads.keys()):
        generator, queries = available_workloads[workload]
        for variant in generator.VARIANTS:
            is_default_variant = variant == generator.DEFAULT_VARIANT
            current = collections.defaultdict(list)
            for group in queries:
                for query_name, query_func in queries[group]:
                    if match_patterns(
                        workload,
                        variant,
                        group,
                        query_name,
                        is_default_variant,
                        patterns,
                    ):
                        current[group].append((query_name, query_func))
            if len(current) == 0:
                continue

            # Ignore benchgraph "basic" queries in standard CI/CD run
            for pattern in patterns:
                res = pattern.count("*")
                key = "basic"
                if res >= 2 and key in current.keys():
                    current.pop(key)

            filtered.append((generator(variant=variant, benchmark_context=benchmark_context), dict(current)))
    return filtered

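# Sketch of how a target string is normalized above; "pokec/small" is a
# hypothetical target, not a guaranteed workload name.
def _example_pattern_expansion():
    pattern = "pokec/small".split("/")            # ["pokec", "small"]
    pattern.extend(["", "*", "*"][len(pattern) - 1 :])
    return pattern                                # ["pokec", "small", "*", "*"]

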
def parse_kwargs(items):
    """Parse a series of key=value pairs and return a dictionary."""
    d = {}
    if items:
        for item in items:
            # Split on the first "=" only, so values may themselves contain "=".
            key, value = item.split("=", 1)
            d[key] = value
    return d

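# Example of the CLI-style parsing above; the keys are arbitrary.
def _example_parse_kwargs():
    return parse_kwargs(["scale=1", "edges=true"])  # -> {"scale": "1", "edges": "true"}

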
class Directory:
    def __init__(self, path):
        self._path = path

    def get_path(self):
        return self._path

    def get_file(self, name):
        path = os.path.join(self._path, name)
        if os.path.exists(path) and not os.path.isfile(path):
            raise Exception("The path '{}' should be a file!".format(path))
        return (path, os.path.isfile(path))

class RecursiveDict:
    def __init__(self, data=None):
        # Deep-copy the input so the caller's dict is never mutated, and avoid
        # a mutable default argument.
        self._data = copy.deepcopy(data) if data is not None else {}

    def _get_obj_and_key(self, *args):
        key = args[-1]
        obj = self._data
        for item in args[:-1]:
            if item not in obj:
                obj[item] = {}
            obj = obj[item]
        return (obj, key)

    def get_value(self, *args):
        obj, key = self._get_obj_and_key(*args)
        return obj.get(key, None)

    def set_value(self, *args, value=None):
        obj, key = self._get_obj_and_key(*args)
        obj[key] = value

    def get_data(self):
        return copy.deepcopy(self._data)

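# Sketch of nested access: set_value walks the key path, creating intermediate
# dicts as needed. The key names are illustrative.
def _example_recursive_dict():
    d = RecursiveDict()
    d.set_value("memgraph", "pokec", "small", value=42)
    return d.get_value("memgraph", "pokec", "small")  # -> 42

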
class Cache:
    def __init__(self):
        self._directory = os.path.join(SCRIPT_DIR, ".cache")
        ensure_directory(self._directory)
        self._config = os.path.join(self._directory, "config.json")

    def cache_directory(self, *args):
        if len(args) == 0:
            raise ValueError("At least one directory level must be supplied!")
        path = os.path.join(self._directory, *args)
        ensure_directory(path)
        return Directory(path)

    def get_default_cache_directory(self):
        return self._directory

    def load_config(self):
        if not os.path.isfile(self._config):
            return RecursiveDict()
        with open(self._config) as f:
            return RecursiveDict(json.load(f))

    def save_config(self, config):
        with open(self._config, "w") as f:
            json.dump(config.get_data(), f)
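

# Sketch of the intended config round-trip; note that constructing Cache
# creates a ".cache" directory next to this script, so this is illustrative
# rather than side-effect free. Key names are hypothetical.
def _example_cache_roundtrip():
    cache = Cache()
    config = cache.load_config()
    config.set_value("last_run", "dataset", value="pokec")
    cache.save_config(config)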