#!/usr/bin/env python3
# -*- coding: utf-8 -*-

'''
Durability Stress Test

Repeatedly starts Memgraph with snapshots and the write-ahead log enabled,
hammers it with concurrent write clients, kills it at a random point in
time, and finally verifies that recovery (snapshot + WAL) restores at least
99.9% of the created nodes.
'''
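
# Usage sketch, assuming common.py's connection_argument_parser() provides
# an --endpoint flag of the form host:port (args.endpoint is parsed below):
#   python3 <this_script> --num-clients 8 --num-steps 5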

# TODO (mg_durability_stress_test): extend once we add full durability mode

import atexit
import multiprocessing
import os
import random
import shutil
import subprocess
import threading
import time

from common import connection_argument_parser, SessionCache
from multiprocessing import Pool, Manager

# Constants and args
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", ".."))
BUILD_DIR = os.path.join(BASE_DIR, "build")
DATA_DIR = os.path.join(BUILD_DIR, "mg_data")
if "THREADS" in os.environ:
    DB_WORKERS = os.environ["THREADS"]
else:
    DB_WORKERS = multiprocessing.cpu_count()
# The snapshot interval has to be in sync with Memgraph's parameter
# --snapshot-cycle-sec
SNAPSHOT_CYCLE_SEC = 5

# Parse command line arguments
parser = connection_argument_parser()

parser.add_argument("--memgraph", default=os.path.join(BUILD_DIR,
                                                       "memgraph"))
parser.add_argument("--log-file", default="")
parser.add_argument("--verbose", action="store_const",
                    const=True, default=False)
parser.add_argument("--data-directory", default=DATA_DIR)
parser.add_argument("--num-clients", default=multiprocessing.cpu_count())
parser.add_argument("--num-steps", type=int, default=5)
args = parser.parse_args()

# Memgraph run command construction
cwd = os.path.dirname(args.memgraph)
cmd = [args.memgraph, "--bolt-num-workers=" + str(DB_WORKERS),
       "--storage-properties-on-edges=true",
       "--storage-snapshot-on-exit=false",
       "--storage-snapshot-interval-sec=" + str(SNAPSHOT_CYCLE_SEC),
       "--storage-snapshot-retention-count=2",
       "--storage-wal-enabled=true",
       "--storage-recover-on-startup=true",
       "--query-execution-timeout-sec=600"]
if not args.verbose:
    cmd += ["--log-level", "WARNING"]
if args.log_file:
    cmd += ["--log-file", args.log_file]
if args.data_directory:
    cmd += ["--data-directory", args.data_directory]

# Shared dict (client id -> number of successfully executed writes); a
# multiprocessing.Manager proxy so the Pool worker processes can update it.
data_manager = Manager()
data = data_manager.dict()

# Pretest cleanup
if os.path.exists(DATA_DIR):
    shutil.rmtree(DATA_DIR)

atexit.register(SessionCache.cleanup)

# Handle to the Memgraph subprocess; initialized to None so clean_memgraph()
# doesn't raise a NameError if the process was never started.
proc_mg = None


@atexit.register
def clean_memgraph():
    global proc_mg
    if proc_mg is None:
        return
    if proc_mg.poll() is not None:
        return
    proc_mg.kill()
    proc_mg.wait()


def wait_for_server(port, delay=0.1):
    # Poll with netcat until the Bolt port accepts connections. A distinct
    # name is used for the probe command so the module-level `cmd` (the
    # Memgraph run command) isn't shadowed.
    probe_cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
    while subprocess.call(probe_cmd) != 0:
        time.sleep(0.01)
    time.sleep(delay)
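

# A pure-Python fallback for wait_for_server(), in case `nc` (netcat) isn't
# installed on the machine running the test. This is only a sketch and isn't
# wired into run_memgraph() by default.
def wait_for_server_socket(port, delay=0.1):
    import socket
    while True:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1)
            # connect_ex returns 0 once something accepts connections.
            if sock.connect_ex(("127.0.0.1", int(port))) == 0:
                break
        time.sleep(0.01)
    time.sleep(delay)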


def run_memgraph():
    global proc_mg
    proc_mg = subprocess.Popen(cmd, cwd=cwd)
    # Wait for Memgraph to finish the recovery process
    wait_for_server(args.endpoint.split(":")[1])


def run_client(client_id, data):
    # init
    session = SessionCache.argument_session(args)
    data.setdefault(client_id, 0)
    counter = data[client_id]

    # Check recovery for this client: compare the node count in the database
    # with the count this client is known to have committed.
    num_nodes_db = session.run(
        "MATCH (n:Client%d) RETURN count(n) AS cnt" % client_id
    ).data()[0]["cnt"]
    print("Client%d DB: %d; ACTUAL: %d" %
          (client_id, num_nodes_db, data[client_id]))

    # Execute a new set of write queries until the database dies.
    while True:
        try:
            session.run("CREATE (n:Client%d {id:%s}) RETURN n;" %
                        (client_id, counter)).consume()
            counter += 1
            if counter % 100000 == 0:
                print("Client %d executed %d" % (client_id, counter))
        except Exception:
            print("DB isn't reachable anymore")
            break
    data[client_id] = counter
    print("Client %d executed %d" % (client_id, counter))


def run_step():
    with Pool(args.num_clients) as p:
        p.starmap(run_client,
                  [(client_id, data)
                   for client_id in range(1, args.num_clients + 1)])


def main():
    for step in range(1, args.num_steps + 1):
        print("#### Step %d" % step)
        run_memgraph()
        thread = threading.Thread(target=run_step)
        thread.daemon = True
        thread.start()
        # Kill Memgraph at an arbitrary point in time. It makes sense to
        # ensure that at least one snapshot has been generated, because we
        # want to test the entire recovery process (snapshot + WAL).
        # It also makes sense to run this test for a longer period of time
        # but with a small --storage-snapshot-interval-sec, to force Memgraph
        # to be killed in the middle of generating a snapshot.
        time.sleep(
            random.randint(2 * SNAPSHOT_CYCLE_SEC, 3 * SNAPSHOT_CYCLE_SEC))
        clean_memgraph()

    # Final check
    run_memgraph()
    session = SessionCache.argument_session(args)
    num_nodes_db = \
        session.run("MATCH (n) RETURN count(n) AS cnt").data()[0]["cnt"]
    num_nodes = sum(data.values())
    # Check that more than 99.9% of the data is recoverable.
    # NOTE: the default WAL flush interval is 1ms.
    # Once the full sync durability mode is introduced,
    # this test has to be extended.
    assert num_nodes_db > 0.999 * num_nodes, \
        "Memgraph lost more than 0.1% of the data!"


if __name__ == "__main__":
    main()