memgraph/tests/integration/mg_import_csv/runner.py

159 lines
5.5 KiB
Python
Raw Normal View History

#!/usr/bin/python3 -u
import argparse
import atexit
import os
import subprocess
import sys
import tempfile
import time
import yaml
# Absolute directory containing this runner (symlinks resolved).
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
# Three levels above SCRIPT_DIR; presumably the repository root that holds
# the build_* directories probed by get_build_dir() -- confirm if moved.
BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, *([os.pardir] * 3)))
def wait_for_server(port, delay=0.1, timeout=None):
    """Block until something accepts TCP connections on 127.0.0.1:*port*.

    The port is probed repeatedly, sleeping 10 ms between failed attempts;
    once a connection succeeds, an extra *delay* seconds are slept to give the
    server time to finish starting up.

    Args:
        port: TCP port to probe on the loopback interface.
        delay: Seconds to sleep after the port first accepts a connection.
        timeout: Optional upper bound (seconds) on the polling phase. ``None``
            (the default) keeps polling forever, as the original did.

    Raises:
        TimeoutError: if *timeout* is given and the port never accepted a
            connection within that time.
    """
    # Local import: the stdlib socket module replaces the external `nc -z`
    # probe, which is not installed everywhere and whose `-z` flag is not
    # supported by every netcat flavour (causing an endless loop).
    import socket

    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        try:
            # Same check as `nc -z -w 1 127.0.0.1 <port>`: connect, then close.
            with socket.create_connection(("127.0.0.1", port), timeout=1):
                break
        except OSError:
            # Refused / timed out: the server is not up yet, keep polling.
            pass
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                "Server on port {} did not start within {} seconds".format(
                    port, timeout))
        time.sleep(0.01)
    time.sleep(delay)
def get_build_dir():
    """Return the first existing build directory under BASE_DIR.

    Candidates are checked in order: build_release, build_debug,
    build_community. If none of them exists, BASE_DIR/build is returned
    without checking whether it exists.
    """
    for candidate in ("build_release", "build_debug", "build_community"):
        path = os.path.join(BASE_DIR, candidate)
        if os.path.exists(path):
            return path
    return os.path.join(BASE_DIR, "build")
def extract_rows(data):
    """Split *data* on newlines after trimming it, trimming every row too.

    Note that an empty (or whitespace-only) input yields ``[""]``.
    """
    return [line.strip() for line in data.strip().split("\n")]
def list_to_string(data):
    """Render the strings in *data* as a bracketed block, one per line.

    Used to build readable assertion messages.
    """
    body = "".join(" " + row + "\n" for row in data)
    return "[\n" + body + "]"
def _mg_import_csv_args(mg_import_csv_binary, common_args, flags):
    """Build the mg_import_csv command line from the remaining test flags.

    Each key becomes ``--key-with-dashes``. List values repeat the flag once
    per item, booleans are passed as ``--flag=true``/``--flag=false``, and
    every other value is passed as a separate ``str()`` argument.
    """
    args = [mg_import_csv_binary] + common_args
    for key, value in flags.items():
        flag = "--" + key.replace("_", "-")
        if isinstance(value, list):
            for item in value:
                args.extend([flag, str(item)])
        elif isinstance(value, bool):
            args.append(flag + "=" + str(value).lower())
        else:
            args.extend([flag, str(value)])
    return args


def _dump_database(memgraph_binary, common_args, tester_binary):
    """Start memgraph on the imported data, dump it via the tester, stop it.

    Returns the tester's stdout split into rows. Memgraph is always stopped
    and reaped before this function returns or raises, so it never outlives
    the test (and never holds the storage directory open).
    """
    memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + \
        common_args
    memgraph = subprocess.Popen(list(map(str, memgraph_args)))
    try:
        time.sleep(0.1)
        assert memgraph.poll() is None, "Memgraph process died prematurely!"
        wait_for_server(7687)
        output = subprocess.run(
            [tester_binary], stdout=subprocess.PIPE,
            check=True).stdout.decode("utf-8")
    finally:
        if memgraph.poll() is None:
            memgraph.terminate()
        exit_code = memgraph.wait()
    # Only checked on the success path; on failure the original error wins.
    assert exit_code == 0, "Memgraph process didn't exit cleanly!"
    return extract_rows(output)


def execute_test(name, test_path, test_config, memgraph_binary,
                 mg_import_csv_binary, tester_binary):
    """Run one mg_import_csv test case and verify its outcome.

    Args:
        name: Display name of the test case.
        test_path: Directory of the test suite; mg_import_csv runs with it
            as the working directory and ``expected`` is resolved against it.
        test_config: The test's YAML mapping (without ``name``). Exactly one
            of ``import_should_fail`` / ``expected`` must be present;
            ``properties_on_edges`` is optional; every other key is passed
            to mg_import_csv as a command-line flag. The dict is not modified.
        memgraph_binary, mg_import_csv_binary, tester_binary: Executables.

    Raises:
        Exception: on an invalid configuration or an unexpected import
            result.
        AssertionError: if memgraph misbehaves or the dumped database does
            not match the expected queries.
    """
    print("\033[1;36m~~ Executing test", name, "~~\033[0m")
    # Work on a copy: keys are popped below and the caller's dict should
    # not be consumed as a side effect.
    test_config = dict(test_config)

    # Exactly one of the two outcome keys must be given.
    if ("import_should_fail" in test_config) == ("expected" in test_config):
        raise Exception("The test should specify either 'import_should_fail' "
                        "or 'expected'!")

    import_should_fail = test_config.pop("import_should_fail", False)
    expected_path = test_config.pop("expected", "")
    if not import_should_fail and not expected_path:
        # e.g. `import_should_fail: false` without `expected`: previously this
        # crashed later with AttributeError when sorting the "" placeholder.
        raise Exception("The test should specify a non-empty 'expected' "
                        "when the import is supposed to succeed!")
    queries_expected = []
    if expected_path:
        with open(os.path.join(test_path, expected_path)) as f:
            queries_expected = extract_rows(f.read())

    properties_on_edges = bool(test_config.pop("properties_on_edges", False))

    # The storage directory is deleted as soon as the test case finishes.
    with tempfile.TemporaryDirectory() as storage_directory:
        common_args = ["--data-directory", storage_directory,
                       "--storage-properties-on-edges=" +
                       str(properties_on_edges).lower()]

        # All remaining config keys are forwarded to mg_import_csv as flags.
        ret = subprocess.run(
            _mg_import_csv_args(mg_import_csv_binary, common_args,
                                test_config),
            cwd=test_path)

        if import_should_fail:
            if ret.returncode == 0:
                raise Exception("The import should have failed, but it "
                                "succeeded instead!")
            print("\033[1;32m~~ Test successful ~~\033[0m\n")
            return
        if ret.returncode != 0:
            raise Exception("The import should have succeeded, but it "
                            "failed instead!")

        queries_got = _dump_database(memgraph_binary, common_args,
                                     tester_binary)

    # Row order of the dump is not significant, so compare sorted rows.
    queries_expected.sort()
    queries_got.sort()
    assert queries_got == queries_expected, \
        "Expected\n{}\nto be equal to\n{}".format(
            list_to_string(queries_got), list_to_string(queries_expected))
    print("\033[1;32m~~ Test successful ~~\033[0m\n")
if __name__ == "__main__":
    # Resolve the build directory once and derive all default binaries.
    build_dir = get_build_dir()

    parser = argparse.ArgumentParser()
    parser.add_argument("--memgraph",
                        default=os.path.join(build_dir, "memgraph"))
    parser.add_argument("--mg-import-csv",
                        default=os.path.join(build_dir, "src",
                                             "mg_import_csv"))
    parser.add_argument("--tester",
                        default=os.path.join(build_dir, "tests",
                                             "integration", "mg_import_csv",
                                             "tester"))
    args = parser.parse_args()

    # Every subdirectory of tests/ is a suite described by its test.yaml.
    tests_root = os.path.join(SCRIPT_DIR, "tests")
    for suite in sorted(os.listdir(tests_root)):
        print("\033[1;34m~~ Processing tests from", suite, "~~\033[0m\n")
        suite_path = os.path.join(tests_root, suite)
        with open(os.path.join(suite_path, "test.yaml")) as f:
            testcases = yaml.safe_load(f)
        for test_config in testcases:
            execute_test(suite + "/" + test_config.pop("name"), suite_path,
                         test_config, args.memgraph, args.mg_import_csv,
                         args.tester)

    sys.exit(0)