#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Durability Stress Test
"""

# TODO (mg_durability_stress_test): extend once we add full durability mode

import atexit
import multiprocessing
import os
import random
import shutil
import subprocess
import threading
import time
from multiprocessing import Manager, Pool

from common import SessionCache, connection_argument_parser

# Paths and tunables shared by the rest of the script
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
# BASE_DIR is the directory two levels above the one holding this script.
BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, os.pardir, os.pardir))
BUILD_DIR = os.path.join(BASE_DIR, "build")
DATA_DIR = os.path.join(BUILD_DIR, "mg_data")
# Worker count handed to Memgraph: the THREADS environment variable (kept as
# the raw string) wins when it is set, otherwise the local CPU count is used.
DB_WORKERS = os.environ.get("THREADS", multiprocessing.cpu_count())
# Snapshot period in seconds. It has to be in sync with the
# --storage-snapshot-interval-sec flag passed to Memgraph.
SNAPSHOT_CYCLE_SEC = 5

# Parse command line arguments.
# connection_argument_parser() is imported from the project's `common` module;
# presumably it supplies the connection options (args.endpoint is read later
# in this file) -- confirm against common.py. The options below are added on
# top of whatever it already defines.
parser = connection_argument_parser()

parser.add_argument("--memgraph", default=os.path.join(BUILD_DIR, "memgraph"))
parser.add_argument("--log-file", default="")
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--data-directory", default=DATA_DIR)
# type=int is required here: the value is passed to Pool(...) and used in
# range(...) arithmetic, and without it argparse would hand over a string
# whenever the flag is given explicitly on the command line.
parser.add_argument("--num-clients", type=int, default=multiprocessing.cpu_count())
parser.add_argument("--num-steps", type=int, default=5)
args = parser.parse_args()

# Memgraph run command construction.
# Memgraph is started from the directory that contains its binary.
cwd = os.path.dirname(args.memgraph)
cmd = [
    args.memgraph,
    "--bolt-num-workers=" + str(DB_WORKERS),
    "--storage-properties-on-edges=true",
    "--storage-snapshot-on-exit=false",
    # Built from SNAPSHOT_CYCLE_SEC so that the kill timing in main() and the
    # snapshot interval Memgraph actually uses cannot drift apart.
    "--storage-snapshot-interval-sec=" + str(SNAPSHOT_CYCLE_SEC),
    "--storage-snapshot-retention-count=2",
    "--storage-wal-enabled=true",
    "--storage-recover-on-startup=true",
    "--query-execution-timeout-sec=600",
    "--bolt-server-name-for-init=Neo4j/v5.11.0 compatible graph database server - Memgraph",
]
# Optional flags are appended only when the corresponding argument asks for it.
if not args.verbose:
    cmd += ["--log-level", "WARNING"]
if args.log_file:
    cmd += ["--log-file", args.log_file]
if args.data_directory:
    cmd += ["--data-directory", args.data_directory]

# Dictionary shared between the client processes through a Manager.
# It maps a client id to the number of nodes that client has created so far.
data_manager = Manager()
data = data_manager.dict()

# Pretest cleanup.
# Wipe the directory Memgraph is actually started with (--data-directory),
# not just the default one: leftovers from a previous run would be recovered
# on startup and distort the node counts this test compares. When no directory
# is given, fall back to DATA_DIR as before.
data_dir_to_clean = args.data_directory or DATA_DIR
if os.path.exists(data_dir_to_clean):
    shutil.rmtree(data_dir_to_clean)

# Let SessionCache release whatever it holds when the interpreter exits.
atexit.register(SessionCache.cleanup)


@atexit.register
def clean_memgraph():
    """Kill the Memgraph process stored in the global ``proc_mg``, if any.

    Registered as an atexit hook and also called directly from main() to
    simulate a crash. Does nothing when no process has been started yet or
    when the process has already exited.
    """
    # proc_mg only exists after run_memgraph() has assigned it. Look it up
    # defensively so that an early exit (for example a failure before the
    # first start) does not raise NameError inside the atexit hook.
    proc = globals().get("proc_mg")
    if proc is None or proc.poll() is not None:
        return
    # kill() rather than a graceful stop: the test wants an abrupt shutdown.
    proc.kill()
    proc.wait()


def wait_for_server(port, delay=0.1):
    """Block until a TCP server accepts connections on 127.0.0.1:``port``.

    The port is probed repeatedly, with a 10 ms pause after each failed
    attempt and no upper bound on the total wait.

    Args:
        port: Port to probe, either an int or a numeric string.
        delay: Extra seconds to sleep once the port has become reachable.
    """
    # Local import keeps this function self-contained.
    import socket

    address = ("127.0.0.1", int(port))
    while True:
        try:
            # Equivalent of `nc -z -w 1`: open a connection with a one second
            # timeout and close it right away, without depending on an
            # external netcat binary.
            with socket.create_connection(address, timeout=1):
                break
        except OSError:
            time.sleep(0.01)
    time.sleep(delay)


def run_memgraph():
    """Start Memgraph and block until its port accepts connections.

    The process handle is stored in the module-level ``proc_mg`` so that
    clean_memgraph() can kill it later.
    """
    global proc_mg
    proc_mg = subprocess.Popen(cmd, cwd=cwd)
    # A reachable port is taken as the sign that Memgraph has finished its
    # recovery process.
    # Assumes args.endpoint has the form "host:port" -- TODO confirm against
    # the parser defined in the `common` module.
    endpoint_parts = args.endpoint.split(":")
    server_port = endpoint_parts[1]
    wait_for_server(server_port)


def run_client(id, data):
    """Write nodes labelled ``Client<id>`` until the database stops responding.

    Args:
        id: Numeric client identifier; used in the node label and as the key
            into ``data``.
        data: Shared mapping from client id to the number of nodes that client
            has created so far. The entry is created with 0 if missing and is
            updated once the write loop ends.
    """
    # Set up the session and pick up the counter left by earlier steps.
    session = SessionCache.argument_session(args)
    data.setdefault(id, 0)
    counter = data[id]
    label = "Client%d" % id

    # Report how many of this client's nodes the database holds after
    # recovery, next to the number the client believes it has written.
    count_query = "MATCH (n:%s) RETURN count(n) as cnt" % label
    count_rows = session.run(count_query).data()
    num_nodes_db = count_rows[0]["cnt"]
    print("Client%d DB: %d; ACTUAL: %d" % (id, num_nodes_db, data[id]))

    # Keep creating nodes; any exception is treated as "the server is gone".
    while True:
        try:
            create_query = "CREATE (n:%s {id:%s}) RETURN n;" % (label, counter)
            session.run(create_query).consume()
            counter += 1
            if not counter % 100000:
                print("Client %d executed %d" % (id, counter))
        except Exception:
            print("DB isn't reachable anymore")
            break

    # Publish the final count so that later steps and the final check see it.
    data[id] = counter
    print("Client %d executed %d" % (id, counter))


def run_step():
    """Run args.num_clients clients in a process pool and wait for all of them.

    Client ids run from 1 to args.num_clients; every client receives the
    shared ``data`` dictionary.
    """
    with Pool(args.num_clients) as workers:
        client_ids = range(1, args.num_clients + 1)
        jobs = [(client_id, data) for client_id in client_ids]
        workers.starmap(run_client, jobs)


def main():
    """Run the durability stress test.

    For each of args.num_steps steps: start Memgraph, let the clients write
    in a background thread, kill Memgraph after a random number of seconds
    and wait for the clients to record how many nodes they created. At the
    end Memgraph is started once more and the number of recovered nodes is
    compared with the number the clients reported.

    Raises:
        AssertionError: If the recovered node count is not greater than
            99.9% of the number of nodes the clients created.
    """
    for step in range(1, args.num_steps + 1):
        print("#### Step %d" % step)
        run_memgraph()
        thread = threading.Thread(target=run_step)
        thread.daemon = True
        thread.start()
        # Kill Memgraph at an arbitrary point in time. It makes sense to
        # ensure that at least one snapshot has been generated because we
        # want to test the entire recovery process (Snapshot + WAL).
        # It also makes sense to run this test for a longer period of time
        # but with a small snapshot interval, to force Memgraph to be killed
        # in the middle of generating a snapshot.
        time.sleep(random.randint(2 * SNAPSHOT_CYCLE_SEC, 3 * SNAPSHOT_CYCLE_SEC))
        clean_memgraph()
        # Wait for the clients to notice the dead server and store their
        # counters. Otherwise the next step (or the final check) could read
        # stale values from `data` while the old clients are still writing.
        thread.join()

    # Final check
    run_memgraph()
    session = SessionCache.argument_session(args)
    num_nodes_db = session.run("MATCH (n) RETURN count(n) AS cnt").data()[0]["cnt"]
    num_nodes = sum(data.values())
    # Check that more than 99.9% of data is recoverable.
    # NOTE: default WAL flush interval is 1ms.
    # Once the full sync durability mode is introduced,
    # this test has to be extended.
    # Raised explicitly (instead of `assert`) so that the verdict survives
    # running Python with -O.
    if not num_nodes_db > 0.999 * num_nodes:
        raise AssertionError("Memgraph lost more than 0.001 of data!")


# Script entry point: run the stress test when this file is executed directly.
if __name__ == "__main__":
    main()