#!/usr/bin/env python3
# -*- coding: utf-8 -*-

'''
A script for transferring all data from a Neo4j database
into a Memgraph database.
'''

import logging
import json
import os
from time import time
from datetime import datetime
from argparse import ArgumentParser

from neo4j.v1 import GraphDatabase, basic_auth

# Module-level logger; logging itself is configured in main().
log = logging.getLogger(__name__)


# Name of the temporary property that holds each vertex's Neo4j id while the
# transfer runs; edges are re-attached in Memgraph by matching on it.
TEMP_ID = "__memgraph_temp_id_314235423"
# Name of the temporary label set on every Neo4j vertex by transfer(); the
# index is created on (TEMP_LABEL, TEMP_ID) and both are removed at the end.
TEMP_LABEL = "__memgraph_temp_label_1414213562"


def parse_args():
    """
    Parses the command-line arguments of the transfer script.

    :return: argparse.Namespace with the attributes neo_url, neo_user,
             neo_password, neo_ssl (bool), memgraph_url, logging and
             json_storage
    """
    def str_to_bool(value):
        # argparse hands over the raw command-line string. Without a
        # conversion that string is compared against the bool choices and
        # never matches, which made --neo-ssl impossible to set.
        lowered = value.lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False
        # argparse turns a ValueError into a regular usage error.
        raise ValueError("expected True or False, got %r" % (value,))

    argp = ArgumentParser(description=__doc__)
    argp.add_argument("--neo-url", default="127.0.0.1:7687",
                      help="Neo4j url, default 127.0.0.1:7687.")
    argp.add_argument("--neo-user", default="neo4j",
                      help="Neo4j username, default neo4j.")
    argp.add_argument("--neo-password", default="1234",
                      help="Neo4j password, default 1234.")
    argp.add_argument("--neo-ssl", default=False, type=str_to_bool,
                      choices=[True, False],
                      help="Encryption for neo4j auth data, default False")
    argp.add_argument("--memgraph-url", default="127.0.0.1:7688",
                      help="Memgraph url, default 127.0.0.1:7688.")
    argp.add_argument("--logging", default="DEBUG",
                      choices=["INFO", "DEBUG"],
                      help="Logging level, default debug.")
    argp.add_argument("--json-storage", default="json",
                      help="Storage for JSON files.")
    return argp.parse_args()


def create_vertex_cypher(vertex):
    """
    Builds the cypher CREATE clause for the given Bolt vertex.

    Note that the vertex is modified: its id is stored in its properties
    under TEMP_ID before the properties are rendered, so the created
    vertex carries that id as a property.

    :param vertex: Bolt vertex with ``labels``, ``properties`` and ``id``
    :return: str, a single "CREATE (...)" clause
    """
    label_part = (":" + ":".join(vertex.labels)) if vertex.labels else ""
    vertex.properties[TEMP_ID] = vertex.id
    rendered_pairs = ['%s: %r' % (name, value)
                      for name, value in vertex.properties.items()]
    return "CREATE (%s {%s})" % (label_part, ", ".join(rendered_pairs))


def create_edge_cypher(edge, edge_num):
    """
    Builds the two cypher fragments needed to recreate the given Bolt edge.

    :param edge: Bolt edge with ``start``, ``end``, ``type`` and
                 ``properties``
    :param edge_num: number appended to the "from"/"to" variable names so
                     that several edges can share one query
    :return: list of two str; the first is the pattern that selects both
             endpoints by TEMP_ID, the second is the CREATE clause that
             connects them
    """
    source_var = "from%s" % (edge_num,)
    target_var = "to%s" % (edge_num,)
    rendered_pairs = ['%s: %r' % (name, value)
                      for name, value in edge.properties.items()]
    endpoints_pattern = "(%s {%s: %r}), (%s {%s: %r})" % (
        source_var, TEMP_ID, edge.start, target_var, TEMP_ID, edge.end)
    create_clause = "CREATE (%s)-[:%s {%s}]->(%s)" % (
        source_var, edge.type, ", ".join(rendered_pairs), target_var)
    return [endpoints_pattern, create_clause]

def vertex_to_dict(vertex):
    """
    Describes one vertex as a plain dictionary.

    :param vertex: Graph vertex exposing ``labels`` and ``properties``
    :return: dict with a new list of the labels under "labels" and the
             vertex's own properties object under "properties"
    """
    label_list = [label for label in vertex.labels]
    return dict(labels=label_list, properties=vertex.properties)

def edge_to_dict(edge):
    """
    Describes one edge as a plain dictionary.

    :param edge: Graph edge exposing ``type``, ``properties``, ``start``
                 and ``end``
    :return: dict with the keys "type", "properties", "start" and "end"
             holding the corresponding edge attributes
    """
    described = {}
    for field in ("type", "properties", "start", "end"):
        described[field] = getattr(edge, field)
    return described

def create_json_file(storage, timestamp, element, batch_index, content):
    """
    Dumps content as JSON into the file
    storage/timestamp/element/<batch_index>.json. Directories on that path
    are created if they don't exist; an existing file is overwritten.

    :param storage: str, path where all json files are stored
    :param timestamp: str, timestamp of the current transfer
    :param element: str, expected vertex or edge, which elements are
                    stored in json
    :param batch_index: int, index of the current batch, used in file name
    :param content: list, content which will be dumped in file
    """
    json_file = os.path.join(storage, timestamp, element,
                             str(batch_index) + ".json")
    os.makedirs(os.path.dirname(json_file), exist_ok=True)
    # The batch is deliberately not echoed to stdout: the JSON file is the
    # record of what was transferred.
    with open(json_file, 'w', encoding='utf-8') as f:
        json.dump(content, f, indent=2)


def transfer(storage, neo_driver, memgraph_driver):
    """
    Copies all the data from Neo4j to Memgraph.

    Every Neo4j vertex is first tagged with TEMP_LABEL and with a TEMP_ID
    property holding its internal id. Vertices are then read in pages and
    recreated in Memgraph, after which edges are recreated by matching
    their endpoints on TEMP_ID. Every written batch is also dumped through
    create_json_file under storage/<timestamp>/vertices or .../edges.
    At the end the temporary label and property are removed from both
    databases. Both sessions are closed when the function returns or
    raises.

    :param storage: str, path where the json files are stored
    :param neo_driver: Bolt driver for the source Neo4j database
    :param memgraph_driver: Bolt driver for the target Memgraph database
    """
    neo_session = neo_driver.session()
    try:
        memgraph_session = memgraph_driver.session()
    except Exception:
        # Don't leak the Neo4j session if Memgraph is unreachable.
        neo_session.close()
        raise

    # TODO remove the temporary label/property when the transfer fails
    # midway; for now only the sessions are released on error.
    try:
        # Creating index
        log.debug("Creating memgraph index on TEMP_LABEL and TEMP_ID.")
        memgraph_session.run(
            "CREATE INDEX ON :%s(%s)" % (TEMP_LABEL, TEMP_ID))
        neo_session.run(
            "MATCH(n) SET n :%s, n.%s = ID(n)" % (TEMP_LABEL, TEMP_ID))
        neo_session.run("CREATE INDEX ON :%s(%s)" % (TEMP_LABEL, TEMP_ID))

        read_vertex_batch = 2
        write_vertex_batch = 3
        read_edge_batch = 2
        write_edge_batch = 3
        vertex_count = 0
        edge_count = 0

        # Index of the next JSON file; shared by both writers below and
        # reset before the edge phase.
        batch_count = 0
        timestamp = datetime.fromtimestamp(time()).strftime(
            "%Y_%m_%d__%H_%M_%S")

        def write_vertices(vertices):
            # Writes one batch of vertices to JSON and to Memgraph, then
            # empties the caller's list in place.
            nonlocal batch_count
            cypher_query = ""
            vertices_list = []
            for vertex in vertices:
                cypher_query += create_vertex_cypher(vertex)
                vertices_list.append(vertex_to_dict(vertex))
            create_json_file(storage, timestamp, "vertices", batch_count,
                             vertices_list)
            log.debug("Vertex create on cypher: %s", cypher_query)
            memgraph_session.run(cypher_query).consume()
            batch_count += 1
            vertices[:] = []

        def write_edges(edges):
            # Writes one batch of edges to JSON and to Memgraph, then
            # empties the caller's list in place. The query has the shape
            # "MATCH <all endpoint patterns> CREATE ... CREATE ...", so
            # each new endpoint pattern is prepended and each new CREATE
            # clause is appended.
            nonlocal batch_count
            cypher_query = ""
            edges_list = []
            for edge_num, edge in enumerate(edges):
                edges_list.append(edge_to_dict(edge))
                edge_queries = create_edge_cypher(edge, edge_num)
                if cypher_query:
                    cypher_query = edge_queries[0] + ", " + cypher_query + \
                            " " + edge_queries[1]
                else:
                    cypher_query = ' '.join(edge_queries)
            create_json_file(storage, timestamp, "edges", batch_count,
                             edges_list)
            cypher_query = "MATCH " + cypher_query
            log.debug("Edge create on cypher: %s", cypher_query)
            memgraph_session.run(cypher_query).consume()
            batch_count += 1
            edges[:] = []

        # Vertex transfer: page through the vertices by TEMP_ID, resuming
        # after the last id seen, until a page comes back short.
        start_id = 0
        vertices_batch = []
        while True:
            read_vertices_in_batch = 0
            vertices = neo_session.run(
                "MATCH(n) WHERE n.%s>=%s RETURN n "
                "ORDER BY ID(n) LIMIT %s"
                % (TEMP_ID, start_id, read_vertex_batch))
            for vertex in vertices:
                vertex = vertex['n']
                vertices_batch.append(vertex)
                if len(vertices_batch) >= write_vertex_batch:
                    write_vertices(vertices_batch)
                start_id = vertex.id
                read_vertices_in_batch += 1
                vertex_count += 1
            start_id += 1
            if read_vertices_in_batch != read_vertex_batch:
                break
        if vertices_batch:
            write_vertices(vertices_batch)

        max_id = neo_session.run(
            "MATCH(n) RETURN MAX(ID(n)) AS id").peek()['id']
        if max_id is None:
            # No vertices at all, hence no maximum id and no edges to
            # copy; comparing an int with None would raise TypeError.
            max_id = -1

        # Edge transfer: walk the source-vertex id range in fixed windows.
        start_id = 0
        batch_count = 0
        edges_batch = []
        while start_id <= max_id:
            edges = neo_session.run(
                "MATCH (n)-[r]->() WHERE n.%s>=%s AND "
                "n.%s<%s RETURN r"
                % (TEMP_ID, start_id, TEMP_ID, start_id + read_edge_batch))
            start_id = start_id + read_edge_batch
            for edge in edges:
                edge_count += 1
                edge = edge['r']
                edges_batch.append(edge)
                if len(edges_batch) >= write_edge_batch:
                    write_edges(edges_batch)
        if edges_batch:
            write_edges(edges_batch)

        # TODO Drop index in memgraph when it will be supported
        log.debug("Removing TEMP_LABEL and TEMP_ID")
        memgraph_session.run(
            "MATCH (n) REMOVE n:%s, n.%s" % (TEMP_LABEL, TEMP_ID))
        neo_session.run(
            "MATCH (n) REMOVE n:%s, n.%s" % (TEMP_LABEL, TEMP_ID))
        neo_session.run("DROP INDEX ON :%s(%s)" % (TEMP_LABEL, TEMP_ID))
        log.info("Created %d vertiecs and %d edges", vertex_count,
                 edge_count)
    finally:
        # Release both sessions even if one close() fails.
        try:
            memgraph_session.close()
        finally:
            neo_session.close()


def main():
    """
    Script entry point: parses the arguments, configures logging, opens
    Bolt drivers for Neo4j and Memgraph, runs the transfer and logs how
    long it took. Both drivers are closed before returning, also when the
    transfer fails.
    """
    args = parse_args()
    if args.logging:
        logging.basicConfig(level=args.logging)
        # Keep the driver's own logging down to warnings.
        logging.getLogger("neo4j").setLevel(logging.WARNING)

    log.info("Memgraph from Neo4j data import tool")

    neo_driver = GraphDatabase.driver(
        "bolt://" + args.neo_url,
        auth=basic_auth(args.neo_user, args.neo_password),
        encrypted=args.neo_ssl)
    try:
        memgraph_driver = GraphDatabase.driver(
            "bolt://" + args.memgraph_url,
            auth=basic_auth("", ""),
            encrypted=False)
        try:
            start_time = time()
            transfer(args.json_storage, neo_driver, memgraph_driver)
            log.info("Import complete in %.2f seconds",
                     time() - start_time)
        finally:
            memgraph_driver.close()
    finally:
        neo_driver.close()

# Run the transfer only when this file is executed as a script.
if __name__ == '__main__':
    main()