#!/usr/bin/env python3
# -*- coding: utf-8 -*-

'''
Run the LDBC SNB interactive workload / benchmark.
The benchmark is executed with:
    * ldbc_driver -> workload executor
    * ldbc-snb-impls/snb-interactive-neo4j -> workload implementation
'''

import argparse
import os
import shutil
import subprocess
import tempfile
import time

# Directory that contains this script (symlinks resolved); datasets, configs
# and the LDBC driver are all located relative to it.
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
# Directory three levels above SCRIPT_DIR; "build/memgraph" and "libs/neo4j"
# are looked up below it.
BASE_DIR = os.path.normpath(
    os.path.join(SCRIPT_DIR, os.pardir, os.pardir, os.pardir))


def wait_for_server(port, delay=1.0):
    '''
    Block until something accepts TCP connections on 127.0.0.1:`port`.

    The port is probed with `nc -z` (1 second connect timeout) and the probe
    is retried every 0.5 seconds. There is no upper bound on the number of
    retries, so this call never returns if the server never comes up.

    Once the port is reachable the function sleeps for another `delay`
    seconds before returning.
    '''
    probe = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
    while True:
        if subprocess.call(probe) == 0:
            break
        time.sleep(0.5)
    # extra grace period after the port has become reachable
    time.sleep(delay)


class Memgraph:
    '''
    Controls a Memgraph server process started from BASE_DIR/build/memgraph.

    dataset     -- dataset directory; its "memgraph" subdirectory is passed
                   to the server as --data-directory
    port        -- Bolt port the server listens on (stored as a string)
    num_workers -- value of --bolt-num-workers (stored as a string)
    '''

    def __init__(self, dataset, port, num_workers):
        # handle of the running server process, None while not running
        self.proc = None
        self.dataset = dataset
        self.port = str(port)
        self.num_workers = str(num_workers)

    def start(self):
        '''
        Start the server and wait until its Bolt port accepts connections.
        '''
        # find executable path
        binary = os.path.join(BASE_DIR, "build", "memgraph")

        # database args
        database_args = [binary,
                         "--bolt-num-workers", self.num_workers,
                         "--data-directory", os.path.join(self.dataset,
                                                          "memgraph"),
                         "--storage-recover-on-startup", "true",
                         "--bolt-port", self.port]

        # database env; the child process sees ONLY this variable, the
        # environment of the calling process is not inherited
        env = {"MEMGRAPH_CONFIG": os.path.join(SCRIPT_DIR, "config",
                                               "memgraph.conf")}

        # start memgraph
        self.proc = subprocess.Popen(database_args, env=env)
        wait_for_server(self.port)

    def stop(self):
        '''
        Terminate the server and wait for it to exit.

        Does nothing when the server was never started (or was already
        stopped), so it is safe to call from a `finally` block even when
        `start` failed before the process was created.

        Raises Exception if the server exits with a non-zero exit code.
        '''
        if self.proc is None:
            return
        # forget the handle first so that a second stop() is a no-op
        proc, self.proc = self.proc, None
        proc.terminate()
        if proc.wait() != 0:
            raise Exception("Database exited with non-zero exit code!")


class Neo:
    '''
    Controls a Neo4j server process started from BASE_DIR/libs/neo4j.

    Every run gets a private, temporary NEO4J_HOME created inside the
    dataset directory; it holds a copy of the dataset's "neo4j" directory
    and a generated configuration file, and is removed again on `stop`.

    dataset -- dataset directory (must contain a "neo4j" subdirectory)
    port    -- Bolt port; the HTTP connector listens on `port + 7474`
    '''

    def __init__(self, dataset, port):
        # handle of the running server process, None while not running
        self.proc = None
        self.dataset = dataset
        self.port = str(port)
        self.http_port = str(int(port) + 7474)
        # temporary NEO4J_HOME, None while it doesn't exist
        self.home_dir = None

    def start(self):
        '''
        Prepare the temporary home directory, start Neo4j in console mode
        and wait until its HTTP port accepts connections.

        Raises Exception("Couldn't run Neo4j!") (chained to the original
        error) if the preparation or the process start fails; the temporary
        home directory is removed in that case.
        '''
        # create home directory
        self.home_dir = tempfile.mkdtemp(dir=self.dataset)

        neo4j_dir = os.path.join(BASE_DIR, "libs", "neo4j")

        try:
            os.symlink(os.path.join(neo4j_dir, "lib"),
                       os.path.join(self.home_dir, "lib"))
            shutil.copytree(os.path.join(self.dataset, "neo4j"),
                            os.path.join(self.home_dir, "data"))
            conf_dir = os.path.join(self.home_dir, "conf")
            conf_file = os.path.join(conf_dir, "neo4j.conf")
            os.mkdir(conf_dir)
            shutil.copyfile(os.path.join(SCRIPT_DIR, "config", "neo4j.conf"),
                            conf_file)
            # append the listen addresses to the copied base configuration
            with open(conf_file, "a") as f:
                f.write("\ndbms.connector.bolt.listen_address=:" +
                        self.port + "\n")
                f.write("\ndbms.connector.http.listen_address=:" +
                        self.http_port + "\n")

            # environment; the child process sees ONLY this variable
            env = {"NEO4J_HOME": self.home_dir}

            self.proc = subprocess.Popen([os.path.join(neo4j_dir, "bin",
                                                       "neo4j"),
                                          "console"], env=env, cwd=neo4j_dir)
        except Exception as exc:
            self._remove_home_dir()
            raise Exception("Couldn't run Neo4j!") from exc
        except BaseException:
            # KeyboardInterrupt / SystemExit: clean up, but let them
            # propagate unchanged instead of disguising them as a failure
            self._remove_home_dir()
            raise

        wait_for_server(self.http_port, 2.0)

    def stop(self):
        '''
        Terminate the server, wait for it to exit and remove the temporary
        home directory.

        Safe to call when the server was never started (or was already
        stopped): only the cleanup that is still pending is performed.

        Raises Exception if the server exits with a non-zero exit code.
        '''
        ret = 0
        # forget the handle first so that a second stop() is a no-op
        proc, self.proc = self.proc, None
        try:
            if proc is not None:
                proc.terminate()
                ret = proc.wait()
        finally:
            self._remove_home_dir()
        if ret != 0:
            raise Exception("Database exited with non-zero exit code!")

    def _remove_home_dir(self):
        '''Delete the temporary home directory if it (still) exists.'''
        if self.home_dir is not None and os.path.exists(self.home_dir):
            shutil.rmtree(self.home_dir)
        self.home_dir = None


def parse_args():
    '''
    Parse the command line of the benchmark runner.

    The module docstring is used verbatim as the program description.
    Returns the `argparse.Namespace` produced from `sys.argv`.
    '''
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # dataset and database location
    parser.add_argument(
        '--scale', type=int, default=1,
        help='Dataset scale to use for benchmarking.')
    parser.add_argument(
        '--host', default='127.0.0.1',
        help='Database host.')
    parser.add_argument(
        '--port', default='7687',
        help='Database port.')

    # LDBC driver tuning
    parser.add_argument(
        '--time-compression-ratio', type=float, default=0.001,
        help='Compress/stretch durations between operation start '
             'times to increase/decrease benchmark load. '
             'E.g. 2.0 = run benchmark 2x slower, 0.1 = run '
             'benchmark 10x faster. Default is 0.001.')
    parser.add_argument(
        '--operation-count', type=int, default=1000,
        help='Number of operations to generate during benchmark '
             'execution.')
    parser.add_argument(
        '--thread-count', type=int, default=8,
        help='Thread pool size to use for executing operation '
             'handlers.')
    parser.add_argument(
        '--time-unit', default='microseconds',
        choices=('nanoseconds', 'microseconds', 'milliseconds',
                 'seconds', 'minutes'),
        help='Time unit to use for measuring performance metrics')

    # result naming and test selection
    parser.add_argument(
        '--result-file-prefix', default='',
        help='Result file name prefix')
    parser.add_argument(
        '--test-type', default='reads', choices=('reads', 'updates'),
        help='Test queries of type')

    # optional database management done by this script
    parser.add_argument(
        '--run-db', choices=('memgraph', 'neo4j'),
        help='Run the database before starting LDBC')
    parser.add_argument(
        '--create-index', action='store_true',
        help='Create index in the running database.')

    return parser.parse_args()


# Jar (with dependencies) of the snb-interactive-neo4j workload
# implementation; it is appended to the Java classpath of the LDBC driver.
LDBC_INTERACTIVE_NEO4J = os.path.join(
    SCRIPT_DIR, 'ldbc-snb-impls', 'snb-interactive-neo4j', 'target',
    'snb-interactive-neo4j-1.0.0-jar-with-dependencies.jar')
# Default properties file handed to the LDBC driver with `-P`.
LDBC_DEFAULT_PROPERTIES = os.path.join(
    SCRIPT_DIR, 'ldbc_driver', 'configuration',
    'ldbc_driver_default.properties')


def create_index(port, database):
    '''
    Create the indexes listed in the workload's `indexCreation.neo4j` file.

    Runs the `index_creation` program with `ve3/bin/python3` (both paths
    are relative to SCRIPT_DIR, which is used as the working directory)
    and then sleeps for one second.

    port     -- database port, forwarded as `--port` (must be a string)
    database -- database name, forwarded as `--database`

    Raises subprocess.CalledProcessError if the program exits with a
    non-zero exit code.
    '''
    statements_file = os.path.join(
        SCRIPT_DIR, 'ldbc-snb-impls', 'snb-interactive-neo4j', 'scripts',
        'indexCreation.neo4j')
    command = ['ve3/bin/python3', 'index_creation',
               '--port', port,
               '--database', database,
               statements_file]
    subprocess.check_call(command, cwd=SCRIPT_DIR)
    time.sleep(1.0)


def main():
    '''
    Run one LDBC SNB interactive benchmark and store its results.

    Steps:
        1. optionally start a local database (`--run-db`) on the dataset
           directory of the requested scale
        2. optionally create the indexes (`--create-index`)
        3. run the LDBC driver (Java) against the database
        4. copy the driver's `LDBC-results.json` into SCRIPT_DIR/results
           under a name built from the prefix, database name and scale

    A database started by this script is always stopped again, even when
    one of the steps fails.

    Raises Exception on an invalid combination of arguments or when the
    database exits with a non-zero exit code, and
    subprocess.CalledProcessError when the index creation or the LDBC
    driver fails.
    '''
    args = parse_args()
    dataset = os.path.join(SCRIPT_DIR, "datasets", "scale_" + str(args.scale))

    # Validate the arguments before anything is started. The index creation
    # program is told which database it talks to via `--database`, and that
    # name is only known when this script starts the database itself.
    if args.create_index and not args.run_db:
        raise Exception("Index creation requires the --run-db parameter "
                        "because the database type must be known!")

    db = None
    if args.run_db:
        if args.host != "127.0.0.1":
            raise Exception("Host parameter must point to localhost when "
                            "this script starts the database!")
        # argparse `choices` guarantee one of the two lowercase names
        if args.run_db == 'memgraph':
            db = Memgraph(dataset, args.port, args.thread_count)
        elif args.run_db == 'neo4j':
            db = Neo(dataset, args.port)

    try:
        if db:
            db.start()
            time.sleep(10.0)

        if args.create_index:
            create_index(args.port, args.run_db)
            time.sleep(10.0)

        # Run LDBC driver. The driver jar path is relative to the driver
        # directory, which is used as the working directory below.
        cp = 'target/jeeves-0.3-SNAPSHOT.jar:{}'.format(LDBC_INTERACTIVE_NEO4J)
        updates_dir = os.path.join(dataset, 'social_network')
        parameters_dir = os.path.join(dataset, 'substitution_parameters')
        workload_properties = os.path.join(
            SCRIPT_DIR,
            "ldbc-snb-impls-{}.properties".format(args.test_type))
        java_cmd = ('java', '-cp', cp, 'com.ldbc.driver.Client',
                    '-P', LDBC_DEFAULT_PROPERTIES,
                    '-P', workload_properties,
                    '-p', 'ldbc.snb.interactive.updates_dir', updates_dir,
                    '-p', 'host', args.host, '-p', 'port', args.port,
                    '-db', 'net.ellitron.ldbcsnbimpls.interactive.neo4j.Neo4jDb',
                    '-p', 'ldbc.snb.interactive.parameters_dir', parameters_dir,
                    '--time_compression_ratio', str(args.time_compression_ratio),
                    '--operation_count', str(args.operation_count),
                    '--thread_count', str(args.thread_count),
                    '--time_unit', args.time_unit.upper())
        subprocess.check_call(java_cmd, cwd=os.path.join(SCRIPT_DIR, 'ldbc_driver'))

        # Copy the results to results dir.
        ldbc_results = os.path.join(SCRIPT_DIR, 'ldbc_driver', 'results',
                                    'LDBC-results.json')
        results_dir = os.path.join(SCRIPT_DIR, 'results')
        # Don't lose the results of a finished run just because the target
        # directory hasn't been created yet.
        os.makedirs(results_dir, exist_ok=True)

        # [<prefix>-]<database|external>-scale_<N>-LDBC-results.json
        name_parts = []
        if args.result_file_prefix:
            name_parts.append(args.result_file_prefix)
        name_parts.append(args.run_db if args.run_db else "external")
        name_parts.extend(["scale_" + str(args.scale), "LDBC", "results.json"])
        results_copy = os.path.join(results_dir, "-".join(name_parts))
        shutil.copyfile(ldbc_results, results_copy)

        print("Results saved to:", results_copy)

    finally:
        if db:
            db.stop()

    print("Done!")


# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()