#!/usr/bin/env python3
# -*- coding: utf-8 -*-

'''
A script for transferring all data from a Neo4j database
into a Memgraph database.
'''

import logging
import json
import os
from time import time
from datetime import datetime
from argparse import ArgumentParser

from neo4j.v1 import GraphDatabase, basic_auth

log = logging.getLogger(__name__)


# Temporary label and property used during the transfer: every Neo4j node is
# tagged with its internal id so that edges can later be matched against their
# endpoints in Memgraph. The label and property are stripped again at the end
# of the transfer.
TEMP_ID = "__memgraph_temp_id_314235423"
TEMP_LABEL = "__memgraph_temp_label_1414213562"


def parse_args():
    argp = ArgumentParser(description=__doc__)
    argp.add_argument("--neo-url", default="127.0.0.1:7687",
                      help="Neo4j URL, default 127.0.0.1:7687.")
    argp.add_argument("--neo-user", default="neo4j",
                      help="Neo4j username, default neo4j.")
    argp.add_argument("--neo-password", default="1234",
                      help="Neo4j password, default 1234.")
    argp.add_argument("--neo-ssl", action="store_true",
                      help="Use an encrypted connection to Neo4j, "
                           "off by default.")
    argp.add_argument("--memgraph-url", default="127.0.0.1:7688",
                      help="Memgraph URL, default 127.0.0.1:7688.")
    argp.add_argument("--logging", default="DEBUG",
                      choices=["INFO", "DEBUG"],
                      help="Logging level, default DEBUG.")
    argp.add_argument("--json-storage", default="json",
                      help="Directory where JSON backup files are stored.")
    return argp.parse_args()
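
# Example invocation (a sketch only; the script file name below is a
# placeholder, and the hosts/credentials are just the defaults from
# parse_args):
#   python3 neo4j_to_memgraph.py --neo-url 127.0.0.1:7687 --neo-user neo4j \
#       --neo-password 1234 --memgraph-url 127.0.0.1:7688 --json-storage json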


def create_vertex_cypher(vertex):
    """
    Helper function that generates a Cypher query for creating
    a vertex based on the given Bolt vertex.
    """
    labels = ""
    if vertex.labels:
        labels += ":" + ":".join(vertex.labels)
    vertex.properties[TEMP_ID] = vertex.id
    properties = ", ".join('%s: %r' % kv for kv
                           in vertex.properties.items())
    return "CREATE (%s {%s})" % (labels, properties)


def create_edge_cypher(edge, edge_num):
    """
    Helper function that generates a Cypher query for creating
    an edge based on the given Bolt edge.
    """
    properties = ", ".join('%s: %r' % kv for kv
                           in edge.properties.items())
    return ["(from%s {%s: %r}), (to%s {%s: %r})"
            % (edge_num, TEMP_ID, edge.start, edge_num, TEMP_ID, edge.end),
            "CREATE (from%s)-[:%s {%s}]->(to%s)"
            % (edge_num, edge.type, properties, edge_num)]
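
# For illustration only (made-up ids): for a KNOWS edge from node 0 to node 1
# with properties {"since": 2010} and edge_num 0, the returned pair is roughly
#   ["(from0 {__memgraph_temp_id_314235423: 0}), "
#    "(to0 {__memgraph_temp_id_314235423: 1})",
#    "CREATE (from0)-[:KNOWS {since: 2010}]->(to0)"]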


def vertex_to_dict(vertex):
    """
    Returns a dictionary which represents one vertex with its labels and
    properties.

    :param vertex: Graph vertex
    """
    return {"labels": list(vertex.labels), "properties": vertex.properties}


def edge_to_dict(edge):
    """
    Returns a dictionary which represents one edge with its type, start
    vertex, end vertex and properties.

    :param edge: Graph edge
    """
    return {"type": edge.type, "properties": edge.properties,
            "start": edge.start, "end": edge.end}


def create_json_file(storage, timestamp, element, batch_index, content):
    """
    Creates a JSON file with the given content at the path
    storage/timestamp/element, using batch_index.json as the file name.
    Creates the directories on that path if they don't exist.

    :param storage: str, path where all JSON files are stored
    :param timestamp: str, timestamp of the current transfer
    :param element: str, expected "vertices" or "edges", depending on which
                    elements are stored in the JSON file
    :param batch_index: int, index of the current batch, used in the file name
    :param content: list, content which will be dumped into the file
    """
    json_file = os.path.join(storage, timestamp, element,
                             str(batch_index) + ".json")
    os.makedirs(os.path.dirname(json_file), exist_ok=True)
    log.debug("Writing batch %d content: %s", batch_index, content)
    with open(json_file, 'w') as f:
        json.dump(content, f, indent=2)
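
# For illustration only (hypothetical timestamp): with the default
# --json-storage value, the first vertex batch ends up in something like
#   json/2017_06_01__12_00_00/vertices/0.json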


def transfer(storage, neo_driver, memgraph_driver):
    """ Copies all the data from Neo4j to Memgraph. """

    # TODO add error handling
    neo_session = neo_driver.session()
    memgraph_session = memgraph_driver.session()

    # Create the temporary indexes and tag every Neo4j node with its
    # internal id.
    log.debug("Creating memgraph index on TEMP_LABEL and TEMP_ID.")
    memgraph_session.run("CREATE INDEX ON :%s(%s)" % (TEMP_LABEL, TEMP_ID))
    neo_session.run("MATCH (n) SET n :%s, n.%s = ID(n)" % (TEMP_LABEL, TEMP_ID))
    neo_session.run("CREATE INDEX ON :%s(%s)" % (TEMP_LABEL, TEMP_ID))

    read_vertex_batch = 2
    write_vertex_batch = 3
    read_edge_batch = 2
    write_edge_batch = 3
    vertex_count = 0
    edge_count = 0

    cypher_query = ""
    batch_count = 0
    timestamp = datetime.fromtimestamp(time()).strftime("%Y_%m_%d__%H_%M_%S")

    def write_vertices(vertices):
        nonlocal batch_count
        cypher_query = ""
        vertices_list = []
        for vertex in vertices:
            cypher_query += create_vertex_cypher(vertex)
            vertices_list.append(vertex_to_dict(vertex))
        create_json_file(storage, timestamp, "vertices", batch_count,
                         vertices_list)
        log.debug("Vertex create Cypher: %s", cypher_query)
        memgraph_session.run(cypher_query).consume()
        batch_count += 1
        vertices[:] = []

    def write_edges(edges):
        nonlocal batch_count
        cypher_query = ""
        edge_num = 0
        edges_list = []
        for edge in edges:
            edges_list.append(edge_to_dict(edge))
            edge_queries = create_edge_cypher(edge, edge_num)
            edge_num += 1
            if cypher_query:
                cypher_query = edge_queries[0] + ", " + cypher_query + \
                    " " + edge_queries[1]
            else:
                cypher_query = ' '.join(edge_queries)
        create_json_file(storage, timestamp, "edges", batch_count, edges_list)
        cypher_query = "MATCH " + cypher_query
        log.debug("Edge create Cypher: %s", cypher_query)
        memgraph_session.run(cypher_query).consume()
        batch_count += 1
        edges[:] = []
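
    # For illustration only (made-up ids; <TEMP_ID> stands for the TEMP_ID
    # property name): for a batch of two KNOWS edges, 0 -> 1 and 2 -> 3,
    # write_edges sends a single combined query roughly of the form
    #   MATCH (from1 {<TEMP_ID>: 2}), (to1 {<TEMP_ID>: 3}),
    #         (from0 {<TEMP_ID>: 0}), (to0 {<TEMP_ID>: 1})
    #   CREATE (from0)-[:KNOWS {}]->(to0) CREATE (from1)-[:KNOWS {}]->(to1)
    # i.e. all endpoints are matched by the temporary id property first, and
    # then all edges of the batch are created in one statement.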

    # Vertex transfer
    start_id = 0
    vertices_batch = []
    while True:
        read_vertices_in_batch = 0
        vertices = neo_session.run("MATCH (n) WHERE n.%s>=%s RETURN n "
                                   "ORDER BY ID(n) LIMIT %s"
                                   % (TEMP_ID, start_id, read_vertex_batch))
        for vertex in vertices:
            vertex = vertex['n']
            vertices_batch.append(vertex)
            if len(vertices_batch) >= write_vertex_batch:
                write_vertices(vertices_batch)
            start_id = vertex.id
            read_vertices_in_batch += 1
            vertex_count += 1
        start_id += 1
        if read_vertices_in_batch != read_vertex_batch:
            break
    if len(vertices_batch) > 0:
        write_vertices(vertices_batch)

    max_id = neo_session.run("MATCH (n) RETURN MAX(ID(n)) AS id").peek()['id']

    # Edge transfer
    start_id = 0
    batch_count = 0
    edges_batch = []
    while start_id <= max_id:
        edges = neo_session.run("MATCH (n)-[r]->() WHERE n.%s>=%s AND "
                                "n.%s<%s RETURN r"
                                % (TEMP_ID, start_id, TEMP_ID,
                                   start_id + read_edge_batch))
        start_id = start_id + read_edge_batch
        for edge in edges:
            edge_count += 1
            edge = edge['r']
            edges_batch.append(edge)
            if len(edges_batch) >= write_edge_batch:
                write_edges(edges_batch)
    if len(edges_batch) > 0:
        write_edges(edges_batch)

    # TODO Drop the index in Memgraph once dropping indexes is supported
    log.debug("Removing TEMP_LABEL and TEMP_ID")
    memgraph_session.run("MATCH (n) REMOVE n:%s, n.%s" % (TEMP_LABEL, TEMP_ID))
    neo_session.run("MATCH (n) REMOVE n:%s, n.%s" % (TEMP_LABEL, TEMP_ID))
    neo_session.run("DROP INDEX ON :%s(%s)" % (TEMP_LABEL, TEMP_ID))
    log.info("Created %d vertices and %d edges", vertex_count, edge_count)


def main():
    args = parse_args()
    if args.logging:
        logging.basicConfig(level=args.logging)
        logging.getLogger("neo4j").setLevel(logging.WARNING)

    log.info("Memgraph from Neo4j data import tool")

    neo_driver = GraphDatabase.driver(
        "bolt://" + args.neo_url,
        auth=basic_auth(args.neo_user, args.neo_password),
        encrypted=args.neo_ssl)
    memgraph_driver = GraphDatabase.driver(
        "bolt://" + args.memgraph_url,
        auth=basic_auth("", ""),
        encrypted=False)

    start_time = time()
    transfer(args.json_storage, neo_driver, memgraph_driver)
    log.info("Import complete in %.2f seconds", time() - start_time)


if __name__ == '__main__':
    main()