#!/usr/bin/env python3
import concurrent.futures
import itertools
import json
import logging
import multiprocessing
import os
import subprocess
from argparse import ArgumentParser
from collections import namedtuple
from pathlib import Path

"""
This script provides helper actions for the Strata card fraud demo:
    1) Copy memgraph to cluster
    2) Copy durability directories to cluster
    3) Start master and workers
    4) Stop memgraph cluster
    5) Clean memgraph directories
    6) Clean durability directories
    7) Collect memgraph logs
    8) Start and stop tcpdump

The cluster config should be provided in a separate JSON file.

Example usage:
    ```./manage_distributed_card_fraud memgraph-copy --config config.json```

Example json config:
{
  "workload_machines":
  [
    {
      "host" : "distcardfraud2.memgraph.io",
      "type" : "master",
      "address" : "10.1.13.5",
      "port" : 10000,
      "num_workers" : 4,
      "rpc_num_workers" : 4,
      "ssh_port" : 60022
    },
    {
      "host" : "distcardfraud3.memgraph.io",
      "type" : "worker",
      "address" : "10.1.13.6",
      "port" : 10001,
      "num_workers" : 2,
      "rpc_num_workers" : 2,
      "ssh_port" : 60022
    },
    {
      "host" : "distcardfraud4.memgraph.io",
      "type" : "worker",
      "address" : "10.1.13.7",
      "port" : 10002,
      "num_workers" : 2,
      "rpc_num_workers" : 2,
      "ssh_port" : 60022
    }
  ],
  "benchmark_machines":["distcardfraud1.memgraph.io"],
  "statsd_address": "10.1.13.4",
  "statsd_port" : "2500",
  "remote_user": "memgraph",
  "ssh_public_key" : "~/.ssh/id_rsa",
  "memgraph_build_dir" : "~/Development/memgraph/build",
  "memgraph_remote_dir" : "/home/memgraph/memgraph_remote",
  "durability_dir" : "~/Development/memgraph/build/tests/manual/test_dir",
  "durability_remote_dir" : "/home/memgraph/memgraph_remote/durability",
  "logs_dir" : "~/Development/memgraph/logs",
  "logs_remote_dir" : "/home/memgraph/memgraph_remote/logs",
  "tcpdump_dir" : "~/Development/memgraph/tcpdumps",
  "tcpdump_remote_dir" : "/home/memgraph/tcpdumps"
}
"""

# Log at INFO level: every executed command and its result is reported
# through log.info, so a lower verbosity would hide them.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("RemoteRunner")


def parse_args():
    """
    Build the command line parser and parse the process arguments.

    The accepted actions are the public callables of RemoteRunner, spelled
    with dashes instead of underscores. Returns the parsed namespace with
    the attributes "action" and "config".
    """
    action_names = []
    for attr_name in dir(RemoteRunner):
        if not callable(getattr(RemoteRunner, attr_name)):
            continue
        if attr_name.startswith('_'):
            # private helpers are not exposed on the command line
            continue
        action_names.append(attr_name.replace('_', '-'))

    arg_parser = ArgumentParser(description=__doc__)
    arg_parser.add_argument("action",
                            metavar="action",
                            choices=action_names,
                            help=", ".join(action_names))
    arg_parser.add_argument("--config",
                            default="config.json",
                            help="Config for cluster.")
    return arg_parser.parse_args()


class Machine(namedtuple('Machine', ['host', 'type', 'address',
                                     'port', 'num_workers',
                                     'rpc_num_workers', 'ssh_port'])):
    """
    Immutable description of a single workload machine.

    The fields mirror the keys of one entry of "workload_machines" in the
    json config. port, num_workers, rpc_num_workers and ssh_port must be
    integers; an AssertionError is raised on construction otherwise.
    """
    __slots__ = ()  # prevent creation of instance dictionaries

    def __init__(self, *args, **kwargs):
        # The tuple itself is filled in by namedtuple's __new__; the
        # arguments are accepted here only so that both positional and
        # keyword construction reach the validation below.
        for field in ('port', 'num_workers', 'rpc_num_workers', 'ssh_port'):
            assert isinstance(getattr(self, field), int), \
                "{} must be an integer".format(field)


class Config:
    """
    Cluster configuration loaded from a json file.

    Local paths (ssh key, build, durability, logs and tcpdump directories)
    have a leading "~" expanded to the home directory; remote paths are
    kept exactly as written in the file.
    """

    def __init__(self, config_file):
        """
        Load the configuration from config_file.

        Raises KeyError when a required key is missing from the file.
        """
        # read through a context manager so the file handle is closed
        with open(config_file) as config_stream:
            data = json.load(config_stream)

        self.workload_machines = [Machine(**config)
                                  for config in data["workload_machines"]]
        self.benchmark_machines = data["benchmark_machines"]
        self.statsd_address = data["statsd_address"]
        self.statsd_port = data["statsd_port"]
        self.remote_user = data["remote_user"]
        self.ssh_public_key = self._expand_path(data["ssh_public_key"])

        self.memgraph_build_dir = self._expand_path(
            data["memgraph_build_dir"])
        self.memgraph_remote_dir = data["memgraph_remote_dir"]

        self.durability_dir = self._expand_path(data["durability_dir"])
        self.durability_remote_dir = data["durability_remote_dir"]

        self.logs_dir = self._expand_path(data["logs_dir"])
        self.logs_remote_dir = data["logs_remote_dir"]

        self.tcpdump_dir = self._expand_path(data["tcpdump_dir"])
        self.tcpdump_remote_dir = data["tcpdump_remote_dir"]

        log.info("Initializing with config\n{}".format(
            json.dumps(data, indent=4, sort_keys=True)))

    @staticmethod
    def _expand_path(path):
        """
        Return path as a string with a leading "~" expanded.
        """
        return str(Path(path).expanduser())

    def master(self):
        """
        Return the first machine of type "master", or None if there is none.
        """
        return next(filter(lambda m: m.type == "master",
                           self.workload_machines), None)

    def workers(self):
        """
        Return the list of machines of type "worker", in config order.
        """
        return list(filter(lambda m: m.type == "worker",
                           self.workload_machines))


class RemoteRunner:
    """
    Executes cluster management actions on the configured machines.

    Remote commands are run through ssh and files are transferred with scp,
    both using the key file from the config. Actions that touch several
    machines are fanned out on a thread pool.

    parse_args offers every public callable of this class as a command line
    action, so helper methods must keep their leading underscore.
    """

    def __init__(self, config):
        """
        config -- Config instance describing the cluster
        """
        self._config = config
        self._ssh_args = ["-i", self._config.ssh_public_key]
        self._scp_args = ["-i", self._config.ssh_public_key]
        self._ssh_nohost_cmd = ["ssh"] + self._ssh_args + ["-p"]
        self.executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count())

    # ------------------------------------------------------------------
    # generic helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _decode(output):
        """
        Return captured process output as text ("" if nothing was captured).
        """
        if not output:
            return ""
        if isinstance(output, bytes):
            return output.decode("utf-8", errors="replace")
        return output

    @staticmethod
    def _in_background(cmd):
        """
        Wrap cmd so that the remote shell starts it detached from the ssh
        session (all streams redirected, job put in the background).
        """
        return ["eval", "\""] + cmd + ["\"", "&> /dev/null < /dev/null &"]

    def _run_cmd(self, cmd, stdout=None, stderr=subprocess.PIPE,
                 host=None, user=None, ssh_port=None, **kwargs):
        """
        Run a command and return its exit code.

        Any command other than scp is executed on the remote host through
        ssh, in which case host and ssh_port are required. scp commands are
        run as given.

        cmd -- command as a list of arguments
        stdout -- destination of stdout; discarded when not given. With
                  subprocess.PIPE the return value becomes the tuple
                  (returncode, output).
        stderr -- destination of stderr; when it is captured, non-empty
                  output is logged as an error
        host, user, ssh_port -- ssh destination
        kwargs -- forwarded to subprocess.run
        """
        if not stdout:
            # discard the output without opening (and leaking) /dev/null
            stdout = subprocess.DEVNULL

        if cmd[0] != "scp":
            assert host is not None, "Remote host not specified"
            assert ssh_port is not None, "ssh port not specified"
            where = "{}@{}".format(user, host)
            cmd = self._ssh_nohost_cmd + [str(ssh_port)] + [where] + cmd

        log.info("Command: {}".format(cmd))
        ret = subprocess.run(cmd, stdout=stdout, stderr=stderr, **kwargs)
        # ret.stderr is None when the caller did not ask for it to be piped
        err = self._decode(ret.stderr)
        if err != "":
            log.error("Command: {} - ERROR: {}".format(cmd, err))
        if stdout == subprocess.PIPE:
            return (ret.returncode, self._decode(ret.stdout))
        return ret.returncode

    def _ssh(self, machine, cmd, **kwargs):
        """
        Run cmd on machine as the configured remote user.
        """
        return self._run_cmd(cmd, user=self._config.remote_user,
                             host=machine.host, ssh_port=machine.ssh_port,
                             **kwargs)

    def _remote_path(self, machine, path):
        """
        Return path on machine in the user@host:path form used by scp.
        """
        return self._config.remote_user + "@" + machine.host + ":" + path

    def _scp(self, machine, source, destination, recursive=False):
        """
        Copy source to destination with scp over the machine's ssh port.
        """
        args = ["scp"] + (["-r"] if recursive else []) + self._scp_args
        args += ["-P", str(machine.ssh_port), source, destination]
        return self._run_cmd(args)

    def _parallel(self, func, *iterables):
        """
        Call func over the zipped iterables on the thread pool and wait for
        all calls; an exception raised by a call is re-raised here.
        """
        for _ in self.executor.map(func, *iterables):
            pass

    def _remove_remote_dir(self, machine, action, remote_dir):
        """
        Recursively delete remote_dir on machine; action is used for logging.
        """
        log.info("{} on {}".format(action, machine.host))
        ret = self._ssh(machine, ["rm", "-rf", remote_dir])
        log.info(ret)

    def _mkdir_machine(self, machine, dir_name):
        """
        Create dir_name (including missing parents) on a single machine.
        """
        log.info("mkdir {} on {}".format(dir_name, machine.host))
        ret = self._ssh(machine, ["mkdir", "-p", dir_name])
        log.info(ret)

    def _mkdir(self, dir_name="tmp"):
        """
        Create dir_name on every workload machine.
        """
        log.info("mkdir {}".format(dir_name))
        self._parallel(self._mkdir_machine,
                       self._config.workload_machines,
                       itertools.repeat(dir_name))

    def _listdir(self):
        """
        List the remote home directory of the first machine.
        """
        # for testing purposes only
        log.info("listdir")
        machine = self._config.workload_machines[0]
        ret = self._ssh(machine, ["ls", "-la"], stdout=subprocess.PIPE)
        log.info("\n{}".format(ret[1]))

    # ------------------------------------------------------------------
    # memgraph binary
    # ------------------------------------------------------------------

    def _memgraph_symlink(self, memgraph_version, machine):
        """
        Point <memgraph_remote_dir>/memgraph at the copied binary.
        """
        log.info("memgraph-symlink on {}".format(machine.host))
        if memgraph_version == "memgraph":
            # the binary already has the target name; linking it to itself
            # could only fail or replace the binary with a dangling link
            return
        # create or update a symlink if already exists
        ret = self._ssh(machine,
                        ["ln", "-sf",
                         os.path.join(self._config.memgraph_remote_dir,
                                      memgraph_version),
                         os.path.join(self._config.memgraph_remote_dir,
                                      "memgraph")])
        log.info(ret)

    def _memgraph_copy_machine(self, machine, memgraph_binary):
        """
        Copy the memgraph binary to machine and update the symlink to it.
        """
        log.info("memgraph-copy on {}".format(machine.host))
        ret = self._scp(machine, memgraph_binary,
                        self._remote_path(machine,
                                          self._config.memgraph_remote_dir))
        log.info(ret)
        self._memgraph_symlink(os.path.basename(memgraph_binary), machine)

    def memgraph_copy(self):
        """
        Create the remote directories and copy memgraph to every machine.
        """
        log.info("memgraph-copy")
        self._mkdir(self._config.memgraph_remote_dir)
        self._mkdir(self._config.durability_remote_dir)
        self._mkdir(self._config.logs_remote_dir)
        # resolve the build symlink so the real (versioned) file is copied
        memgraph_binary = os.path.realpath(
            os.path.join(self._config.memgraph_build_dir, "memgraph"))
        self._parallel(self._memgraph_copy_machine,
                       self._config.workload_machines,
                       itertools.repeat(memgraph_binary))

    def _memgraph_clean_machine(self, machine):
        """
        Delete the remote memgraph directory on machine.
        """
        self._remove_remote_dir(machine, "memgraph-clean",
                                self._config.memgraph_remote_dir)

    def memgraph_clean(self):
        """
        Delete the remote memgraph directory on every machine.
        """
        log.info("memgraph-clean")
        self._parallel(self._memgraph_clean_machine,
                       self._config.workload_machines)

    # ------------------------------------------------------------------
    # durability
    # ------------------------------------------------------------------

    def _durability_copy_machine(self, machine, i):
        """
        Copy the local snapshots of worker i to machine.
        """
        log.info("durability-copy on {}".format(machine.host))
        # copy durability dir
        ret = self._scp(machine,
                        os.path.join(self._config.durability_dir,
                                     "worker_{}/snapshots".format(i)),
                        self._remote_path(machine,
                                          self._config.durability_remote_dir),
                        recursive=True)
        log.info(ret)

    def durability_copy(self):
        """
        Copy the snapshots to every machine; the machine at position i of
        the config receives the snapshots of worker_i.
        """
        log.info("durability-copy")
        self._mkdir(self._config.durability_remote_dir)
        self._parallel(self._durability_copy_machine,
                       self._config.workload_machines,
                       itertools.count(0))

    def _durability_clean_machine(self, machine):
        """
        Delete the remote durability directory on machine.
        """
        self._remove_remote_dir(machine, "durability-clean",
                                self._config.durability_remote_dir)

    def durability_clean(self):
        """
        Delete the remote durability directory on every machine.
        """
        log.info("durability-clean")
        self._parallel(self._durability_clean_machine,
                       self._config.workload_machines)

    # ------------------------------------------------------------------
    # starting memgraph
    # ------------------------------------------------------------------

    def _is_process_running(self, process_name, machine):
        """
        Return True if pgrep finds process_name on machine.
        """
        ret = self._ssh(machine, ["pgrep", process_name])
        log.info(ret)
        return ret == 0

    def _common_flags(self, machine):
        """
        Return the memgraph flags shared by every start mode.
        """
        return ["--durability-directory={}".format(
                    self._config.durability_remote_dir),
                "--db-recover-on-startup=true",
                "--num-workers", str(machine.num_workers),
                "--rpc-num-workers", str(machine.rpc_num_workers)]

    def _statsd_flags(self):
        """
        Return the memgraph flags selecting the statsd endpoint.
        """
        return ["--statsd-address", str(self._config.statsd_address),
                "--statsd-port", str(self._config.statsd_port)]

    def _log_flags(self, i):
        """
        Return the memgraph flags selecting the log file of worker i.
        """
        return ["--log-file",
                os.path.join(self._config.logs_remote_dir,
                             "log_worker_{}".format(i))]

    def _start_memgraph(self, machine, memgraph_cmd):
        """
        Start memgraph_cmd in the background from the remote memgraph
        directory and verify that a memgraph process is running afterwards.

        Raises AssertionError if memgraph is already running on machine or
        is not running after the start.
        """
        is_running = self._is_process_running("memgraph", machine)
        assert not is_running, "Memgraph already running on machine: "\
            " {}".format(machine)
        dir_cmd = ["cd", self._config.memgraph_remote_dir]
        cmd = self._in_background(dir_cmd + ["&&"] + memgraph_cmd)
        ret = self._ssh(machine, cmd)
        log.info(ret)
        is_running = self._is_process_running("memgraph", machine)
        assert is_running, "Memgraph failed to start on machine: "\
            " {}".format(machine)

    def memgraph_single_node(self):
        """
        Start a standalone memgraph on the master machine.
        """
        log.info("memgraph-single-node")
        machine = self._config.master()
        assert machine is not None, "Unable to fetch memgraph config"
        memgraph_cmd = ["nohup", "./memgraph"]
        memgraph_cmd += self._common_flags(machine)
        memgraph_cmd += self._log_flags(0)
        memgraph_cmd += ["--query-vertex-count-to-expand-existing", "-1"]
        self._start_memgraph(machine, memgraph_cmd)

    def memgraph_master(self):
        """
        Start the distributed memgraph master.
        """
        log.info("memgraph-master")
        machine = self._config.master()
        assert machine is not None, "Unable to fetch master config"
        memgraph_cmd = ["nohup", "./memgraph",
                        "--master",
                        "--master-host", machine.address,
                        "--master-port", str(machine.port)]
        memgraph_cmd += self._common_flags(machine)
        memgraph_cmd += self._statsd_flags()
        memgraph_cmd += self._log_flags(0)
        memgraph_cmd += ["--query-vertex-count-to-expand-existing", "-1"]
        self._start_memgraph(machine, memgraph_cmd)

    def _memgraph_worker(self, machine, master, i):
        """
        Start a memgraph worker with id i on machine, attached to master.
        """
        worker_cmd = ["nohup", "./memgraph",
                      "--worker",
                      "--master-host", master.address,
                      "--master-port", str(master.port),
                      "--worker-id", str(i),
                      "--worker-port", str(machine.port),
                      "--worker-host", machine.address]
        worker_cmd += self._common_flags(machine)
        worker_cmd += self._statsd_flags()
        worker_cmd += self._log_flags(i)
        self._start_memgraph(machine, worker_cmd)

    def memgraph_workers(self):
        """
        Start a memgraph worker on every machine of type "worker".

        Worker ids are assigned in config order starting at 1 (id 0 is the
        master).
        """
        log.info("memgraph-workers")
        master = self._config.master()
        assert master is not None, "Unable to fetch master config"
        workers = self._config.workers()
        assert workers, "Unable to fetch workers config"
        # iterate the machines of type "worker" instead of assuming that
        # the master is the first entry of the config
        self._parallel(self._memgraph_worker,
                       workers,
                       itertools.repeat(master),
                       itertools.count(1))

    # ------------------------------------------------------------------
    # stopping memgraph
    # ------------------------------------------------------------------

    def _kill_process(self, machine, process, signal):
        """
        Send signal to every process named process on machine (via sudo).
        """
        ret = self._ssh(machine,
                        ["sudo", "kill", "-" + str(signal),
                         "$(pgrep {})".format(process)])
        log.info(ret)

    def memgraph_terminate(self):
        """
        Gracefully stop the cluster by sending SIGTERM to the master.
        """
        log.info("memgraph-terminate")
        # only kill master - master stops workers
        machine = self._config.master()
        assert machine is not None, "Unable to fetch master config"
        self._kill_process(machine, "memgraph", 15)

    def memgraph_kill(self):
        """
        Send SIGKILL to memgraph on every machine.
        """
        log.info("memgraph-kill")
        self._parallel(self._kill_process,
                       self._config.workload_machines,
                       itertools.repeat("memgraph"),
                       itertools.repeat(9))

    # ------------------------------------------------------------------
    # logs
    # ------------------------------------------------------------------

    def _collect_log(self, machine, i):
        """
        Copy the log of worker i from machine into the local logs directory.
        """
        log.info("collect-log from {}".format(machine.host))
        ret = self._scp(machine,
                        self._remote_path(
                            machine,
                            os.path.join(self._config.logs_remote_dir,
                                         "log_worker_{}".format(i))),
                        self._config.logs_dir)
        log.info(ret)

    def collect_logs(self):
        """
        Collect the memgraph log of every machine.
        """
        log.info("collect-logs")
        # without an existing target directory scp would write every log
        # to one file named like the directory
        os.makedirs(self._config.logs_dir, exist_ok=True)
        self._parallel(self._collect_log,
                       self._config.workload_machines,
                       itertools.count(0))

    # ------------------------------------------------------------------
    # tcpdump
    # ------------------------------------------------------------------

    def _tcpdump_machine(self, machine, i):
        """
        Start tcpdump on eth0 of machine in the background, writing to
        dump_worker_<i> in the remote tcpdump directory.
        """
        cmd = ["sudo", "tcpdump", "-i", "eth0", "-w", os.path.join(
                                 self._config.tcpdump_remote_dir,
                                 "dump_worker_{}".format(i))]
        ret = self._ssh(machine, self._in_background(cmd))
        log.info(ret)

    def tcpdump_start(self):
        """
        Start tcpdump on every machine.
        """
        log.info("tcpdump-start")
        self._mkdir(self._config.tcpdump_remote_dir)
        self._parallel(self._tcpdump_machine,
                       self._config.workload_machines,
                       itertools.count(0))

    def _tcpdump_collect_machine(self, machine, i):
        """
        Stop tcpdump on machine and copy its dump file to the local
        tcpdump directory.
        """
        log.info("collect-tcpdump from {}".format(machine.host))
        # first kill tcpdump process
        ret = self._ssh(machine,
                        ["nohup", "sudo", "kill", "-15",
                         "$(pgrep tcpdump)"])
        log.info(ret)

        ret = self._scp(machine,
                        self._remote_path(
                            machine,
                            os.path.join(self._config.tcpdump_remote_dir,
                                         "dump_worker_{}".format(i))),
                        self._config.tcpdump_dir)
        log.info(ret)

    def tcpdump_collect(self):
        """
        Stop tcpdump on every machine and collect the dump files.
        """
        log.info("tcpdump-collect")
        # make sure scp copies into a directory instead of a single file
        os.makedirs(self._config.tcpdump_dir, exist_ok=True)
        self._parallel(self._tcpdump_collect_machine,
                       self._config.workload_machines,
                       itertools.count(0))

    def _tcpdump_clean_machine(self, machine):
        """
        Delete the remote tcpdump directory on machine.
        """
        self._remove_remote_dir(machine, "tcpdump-clean",
                                self._config.tcpdump_remote_dir)

    def tcpdump_clean(self):
        """
        Delete the remote tcpdump directory on every machine.
        """
        log.info("tcpdump-clean")
        self._parallel(self._tcpdump_clean_machine,
                       self._config.workload_machines)


def main():
    """
    Entry point: load the cluster config and run the requested action.
    """
    cli = parse_args()
    cluster_config = Config(cli.config)
    # actions are spelled with dashes on the command line and with
    # underscores as RemoteRunner method names
    method_name = cli.action.replace("-", "_")
    remote_runner = RemoteRunner(cluster_config)
    getattr(remote_runner, method_name)()


# Run the requested action only when executed as a script, not on import.
if __name__ == "__main__":
    main()