345 lines
13 KiB
Plaintext
345 lines
13 KiB
Plaintext
|
#!/usr/bin/env python3
|
||
|
# -*- coding: utf-8 -*-
|
||
|
|
||
|
'''
|
||
|
Create a Memgraph recovery snapshot file from CSV.
|
||
|
'''
|
||
|
|
||
|
import argparse
|
||
|
import csv
|
||
|
import itertools as it
|
||
|
import logging
|
||
|
import struct
|
||
|
|
||
|
log = logging.getLogger(__name__)


def _csv_to_bool(value):
    '''Parse a CSV boolean cell into True/False.

    ``bool(value)`` cannot be used for this: it is True for every non-empty
    string, so the cell "false" would become True. Accepts "true"/"false"
    case-insensitively, ignoring surrounding whitespace, and raises ValueError
    for anything else (as ``int``/``float`` do for malformed cells).'''
    normalized = value.strip().lower()
    if normalized == 'true':
        return True
    if normalized == 'false':
        return False
    raise ValueError("Invalid boolean value '{}', expected 'true' or 'false'"
                     .format(value))


# Maps a (lower-case) CSV header type name to the callable converting a cell
# string into the corresponding Python value.
_CSV_TYPE_TO_PY_TYPE = {
    'int': int,
    'long': int,
    'float': float,
    'double': float,
    'boolean': _csv_to_bool,
    'byte': int,
    'short': int,
    'char': str,
    'string': str,
}
|
||
|
|
||
|
|
||
|
def csv_to_py_val(value, csv_type, array_delimiter):
    '''Convert a CSV cell to a Python value according to its column type.

    :param value: raw cell text.
    :param csv_type: type name from the header (e.g. ``'int'``), or an array
        type with a ``[]`` suffix (e.g. ``'string[]'``).
    :param array_delimiter: separator between array elements.
    :returns: the converted scalar or, for array types, a list with each
        element stripped of surrounding whitespace and then converted.
    :raises KeyError: if the (element) type name is unknown.'''
    element_type, is_array = csv_type, False
    if csv_type.endswith('[]'):
        element_type, is_array = csv_type[:-2], True
    convert = _CSV_TYPE_TO_PY_TYPE[element_type]
    if not is_array:
        return convert(value)
    elements = value.split(array_delimiter)
    return [convert(element.strip()) for element in elements]
|
||
|
|
||
|
|
||
|
class NodeId:
    '''Identifier of a node as written in the CSV input.

    Two IDs are equal only when both the ID string and the ID space match, so
    the same ID string may appear in different ID spaces.'''

    # One NodeId per input node is kept as a dict key; __slots__ avoids a
    # per-instance __dict__.
    __slots__ = ('id', 'id_space')

    def __init__(self, id_, id_space):
        '''
        :param id_: the ID cell value; must be non-empty.
        :param id_space: ID space name from the column header, or None.
        :raises ValueError: if ``id_`` is empty.'''
        if not id_:
            raise ValueError('ID must not be empty')
        self.id = id_
        self.id_space = id_space

    def __eq__(self, other):
        if not isinstance(other, NodeId):
            return NotImplemented
        return self.id == other.id and self.id_space == other.id_space

    def __hash__(self):
        return hash((self.id, self.id_space))

    def __str__(self):
        if self.id_space is None:
            return self.id
        return '{}({})'.format(self.id, self.id_space)

    def __repr__(self):
        return '{}({!r}, {!r})'.format(type(self).__name__, self.id,
                                       self.id_space)
|
||
|
|
||
|
|
||
|
class Hasher:
    '''Implementation of memgraph/src/durability/hasher.

    The API mimics hashlib, so that it will be easier to switch to something
    more sane (e.g. sha256).

    The running state is reduced modulo 2**64 after every byte, matching the
    wrap-around of a uint64_t state. Without the reduction the Python int would
    grow by about 11.6 bits (log2 of 3137) per byte hashed, making hashing a
    large snapshot quadratic. The digest is the same either way.'''

    _PRIME = 3137
    _MASK = 2**64 - 1

    def __init__(self):
        self._hash = 0

    def update(self, data):
        '''Feed ``data`` into the running hash.

        :raises TypeError: if ``data`` is not ``bytes``.'''
        if not isinstance(data, bytes):
            raise TypeError("Expected 'bytes', but got '{}'"
                            .format(type(data).__name__))
        # Local bindings: this loop runs once per byte written to the file.
        state = self._hash
        prime = self._PRIME
        mask = self._MASK
        for byte in data:
            state = (state * prime + byte + 1) & mask
        self._hash = state

    def digest(self):
        '''Return the digest value as an int (which fits in uint64_t) and
        *not* as bytes. (This is different from hashlib objects.)'''
        return self._hash
|
||
|
|
||
|
|
||
|
class BoltEncoder:
    '''Write values, nodes and relationships to a Memgraph snapshot file.

    Values are encoded with the Bolt (PackStream) markers below, always using
    the 64-bit number forms and 32-bit size prefixes. Every byte except the
    trailing summary is also fed to the hasher.

    Each distinct CSV ``NodeId`` is assigned the next Memgraph ID, counting
    from 0. Relationships refer to their endpoints through that mapping, so a
    relationship can only be written after both of its nodes.'''

    # Type markers
    _NULL_MARKER = b'\xC0'
    _FLOAT64_MARKER = b'\xC1'
    _FALSE_MARKER = b'\xC2'
    _TRUE_MARKER = b'\xC3'
    _INT64_MARKER = b'\xCB'
    _STRING32_MARKER = b'\xD2'
    _LIST32_MARKER = b'\xD6'
    _MAP32_MARKER = b'\xDA'
    # Structure headers: 0xB3/0xB5 = structure with 3/5 fields (matching what
    # write_node/write_relationship emit), then the signature byte 'N'/'R'.
    _NODE_MARKER = b'\xB3\x4E'
    _RELATIONSHIP_MARKER = b'\xB5\x52'

    # Struct formats (big-endian)
    _INT64_STRUCT = struct.Struct('>q')
    _UINT32_STRUCT = struct.Struct('>I')
    _UINT64_STRUCT = struct.Struct('>Q')
    _FLOAT64_STRUCT = struct.Struct('>d')

    # Exact value type -> name of the method encoding it. Exact types are used
    # (not isinstance) so that bool, a subclass of int, is written as a bool.
    _WRITER_NAMES = {
        bool: 'write_bool',
        int: 'write_int',
        float: 'write_float',
        str: 'write_str',
        list: 'write_list',
        dict: 'write_dict',
    }

    def __init__(self, file, hasher, skip_duplicate_nodes):
        '''
        :param file: binary file-like object receiving the snapshot bytes.
        :param hasher: object with ``update(bytes)`` and ``digest() -> int``.
        :param skip_duplicate_nodes: if true, a node whose CSV ID was already
            written is skipped; otherwise ``write_node`` raises ValueError.'''
        self._file = file
        self._hasher = hasher
        # Next Memgraph IDs to assign; equal to the number written so far.
        self._relationship_id = 0
        self._node_id = 0
        # CSV NodeId -> Memgraph node ID.
        self._csv_to_mg_node_id = {}
        self._skip_duplicate_nodes = skip_duplicate_nodes

    def write(self, value):
        '''Encode ``value``, dispatching on its exact type.

        Supported: None, bool, int, float, str, list and dict (str keys).

        :raises TypeError: for any other type.'''
        if value is None:
            return self.write_null()
        writer_name = self._WRITER_NAMES.get(type(value))
        if writer_name is None:
            raise TypeError("Cannot encode value of type '{}'"
                            .format(type(value).__name__))
        getattr(self, writer_name)(value)

    def write_null(self):
        self._write(self._NULL_MARKER)

    def write_bool(self, value):
        self._write(self._TRUE_MARKER if value else self._FALSE_MARKER)

    def write_int(self, value):
        '''Write ``value`` as int64 (struct.error if it does not fit).'''
        self._write(self._INT64_MARKER)
        self._write(self._INT64_STRUCT.pack(value))

    def write_float(self, value):
        self._write(self._FLOAT64_MARKER)
        self._write(self._FLOAT64_STRUCT.pack(value))

    def write_str(self, value):
        # The size prefix counts UTF-8 *bytes*; len(value) counts characters
        # and undercounts any non-ASCII text, so encode first.
        encoded = value.encode('utf-8')
        self._write(self._STRING32_MARKER)
        self._write(self._UINT32_STRUCT.pack(len(encoded)))
        self._write(encoded)

    def write_list(self, values):
        self._write(self._LIST32_MARKER)
        self._write(self._UINT32_STRUCT.pack(len(values)))
        for value in values:
            self.write(value)

    def write_dict(self, dict_value):
        '''Write a map; keys must be ``str``.'''
        self._write(self._MAP32_MARKER)
        self._write(self._UINT32_STRUCT.pack(len(dict_value)))
        for key, value in dict_value.items():
            self.write_str(key)
            self.write(value)

    def write_summary(self, node_count, relationship_count):
        '''Write the trailer: node count, relationship count and hash digest.

        The counts written are the numbers of nodes and relationships this
        encoder actually emitted, so the summary always matches the file. A
        caller that counts input rows passes a larger ``node_count`` when
        duplicate nodes were skipped; any difference is logged.'''
        written = (self._node_id, self._relationship_id)
        if (node_count, relationship_count) != written:
            log.info('Summary uses written counts (%d nodes, %d '
                     'relationships) instead of given counts (%d, %d)',
                     self._node_id, self._relationship_id,
                     node_count, relationship_count)
        # It's a bit silly that the summary isn't considered for hashing
        # (see: memgraph/src/durability/file_writer_buffer)
        self._write(self._UINT64_STRUCT.pack(self._node_id),
                    update_hash=False)
        self._write(self._UINT64_STRUCT.pack(self._relationship_id),
                    update_hash=False)
        self._write(self._UINT64_STRUCT.pack(self._hasher.digest()),
                    update_hash=False)

    def write_node(self, node_id, labels, properties):
        '''Write a node under a freshly assigned Memgraph ID.

        :param node_id: CSV ``NodeId`` of the node.
        :param labels: list of label strings.
        :param properties: dict of property name -> value.
        :raises ValueError: if ``node_id`` was already written and duplicates
            are not being skipped.'''
        try:
            id_ = self._add_node_id(node_id)
        except ValueError:
            if not self._skip_duplicate_nodes:
                raise
            log.debug("Skipping duplicate node '%s'", node_id)
            return
        self._write(self._NODE_MARKER)
        self.write_int(id_)
        self.write_list(labels)
        self.write_dict(properties)

    def write_relationship(self, start_id, end_id, type_, properties):
        '''Write a relationship between two already-written nodes.

        :raises ValueError: if either endpoint was never written.'''
        # Resolve both endpoints before emitting any bytes, so an unknown node
        # does not leave a half-written relationship behind.
        start = self._lookup_node_id(start_id)
        end = self._lookup_node_id(end_id)
        self._write(self._RELATIONSHIP_MARKER)
        self.write_int(self._relationship_id)
        self.write_int(start)
        self.write_int(end)
        self._relationship_id += 1
        self.write_str(type_)
        self.write_dict(properties)

    def _write(self, byte_data, update_hash=True):
        # Building the hex dump is costly for every chunk; only do it when
        # DEBUG logging is actually enabled.
        if log.isEnabledFor(logging.DEBUG):
            log.debug("Writing bytes '0x%s'", byte_data.hex())
        if update_hash:
            self._hasher.update(byte_data)
        self._file.write(byte_data)

    def _add_node_id(self, node_id):
        '''Add a new mapping from CSV node ID to Memgraph ID and
        return the Memgraph ID.

        :raises ValueError: if ``node_id`` is already mapped.'''
        if node_id in self._csv_to_mg_node_id:
            raise ValueError("Node '{}' already exists".format(node_id))
        id_ = self._node_id
        self._csv_to_mg_node_id[node_id] = id_
        self._node_id += 1
        return id_

    def _lookup_node_id(self, node_id):
        '''Return the Memgraph ID of an already-written CSV node ID.'''
        try:
            return self._csv_to_mg_node_id[node_id]
        except KeyError:
            raise ValueError("Relationship refers to unknown node '{}'"
                             .format(node_id)) from None
|
||
|
|
||
|
|
||
|
def parse_args(argv=None):
    '''Parse the command-line arguments.

    :param argv: list of arguments to parse; None (the default) makes
        argparse read ``sys.argv[1:]``.'''
    argp = argparse.ArgumentParser(description=__doc__)
    argp.add_argument('-o', '--out', required=True,
                      help='Destination for the created snapshot file')
    argp.add_argument('-n', '--nodes', action='append', required=True,
                      help='CSV file containing graph nodes (vertices)')
    argp.add_argument('-r', '--relationships', default=[], action='append',
                      help='CSV file containing graph relationships (edges)')
    argp.add_argument('--overwrite', action='store_true', default=False,
                      help='Overwrite the output file if it exists')
    # '--log-level' matches the other hyphenated flags; '--log_level' is kept
    # so existing invocations keep working.
    argp.add_argument('--log-level', '--log_level', dest='log_level',
                      default='WARNING',
                      choices=['DEBUG', 'INFO', 'WARNING', 'ERROR',
                               'CRITICAL'],
                      help='Log level, default is WARNING')
    argp.add_argument('--array-delimiter', default=';',
                      help='Delimiter between elements of array values, '
                           "default is ';'")
    argp.add_argument('--csv-delimiter', default=',',
                      help='Delimiter between each field in the CSV, '
                           "default is ','")
    argp.add_argument('--skip-duplicate-nodes', action='store_true',
                      default=False,
                      help='Skip duplicate nodes or raise an error (default)')
    return argp.parse_args(argv)
|
||
|
|
||
|
|
||
|
def get_field_name_and_type(field):
    '''Split a CSV header field of the form ``name[:type]``.

    Only the first ``:`` separates the two parts. The returned type is
    stripped and lower-cased; it is None when the field contains no ``:``.'''
    name, separator, raw_type = field.partition(':')
    if not separator:
        return name, None
    return name, raw_type.strip().lower()
|
||
|
|
||
|
|
||
|
def get_id_space(field_type):
    '''Return the ID space named in an ID column type.

    E.g. ``'users'`` for ``'id(users)'`` or ``'start_id(users)'``; None when
    the type has no ``(`` group.

    :raises ValueError: if a ``(`` group is opened but the type does not end
        with ``)``.'''
    group_start = field_type.find('(')
    if group_start == -1:
        return None
    # Without this check the last character would be dropped silently,
    # e.g. 'id(users' would yield the ID space 'user'.
    if not field_type.endswith(')'):
        raise ValueError("Malformed ID space in field type '{}', expected "
                         "'<type>(<id space>)'".format(field_type))
    return field_type[group_start + 1:-1]
|
||
|
|
||
|
|
||
|
def write_node_row(node_row, array_delimiter, encoder):
    '''Parse one node row read by ``csv.DictReader`` and pass it to
    ``encoder.write_node``.

    Header fields have the form ``name:type`` (type matched
    case-insensitively). Exactly one ``ID``/``ID(space)`` column gives the
    node ID. ``LABEL`` columns hold ``array_delimiter``-separated labels.
    ``IGNORE`` columns are skipped. Every other column becomes a property, of
    type ``string`` when the header gives no type.

    :raises ValueError: if the row has more or fewer cells than the header,
        or if there is no ID column or more than one. Conversion errors from
        ``csv_to_py_val`` propagate.'''
    # With default settings csv.DictReader puts surplus cells under the key
    # None and fills missing cells with None; reject both here instead of
    # failing later with an opaque AttributeError.
    if None in node_row:
        raise ValueError('Node row has more fields than the header')
    if any(value is None for value in node_row.values()):
        raise ValueError('Node row has fewer fields than the header')
    node_id = None
    node_labels = []
    properties = {}
    for field, raw_value in node_row.items():
        value = raw_value.strip()
        name, field_type = get_field_name_and_type(field)
        if field_type is not None and field_type.startswith('id'):
            if node_id is not None:
                raise ValueError('Only one node ID must be specified')
            node_id = NodeId(value, get_id_space(field_type))
        elif field_type == 'label':
            labels = map(str.strip, value.split(array_delimiter))
            node_labels.extend(label for label in labels if label)
        elif field_type != 'ignore':
            # Everything else is a property; a missing type means string.
            properties[name] = csv_to_py_val(value, field_type or 'string',
                                             array_delimiter)
    if node_id is None:
        raise ValueError('Node ID must be specified')
    encoder.write_node(node_id, node_labels, properties)
|
||
|
|
||
|
|
||
|
def convert_nodes(node_filenames, csv_delimiter, array_delimiter, encoder):
    '''Write every node row of every file in ``node_filenames`` via
    ``encoder``.

    :returns: the number of rows processed. This exceeds the number of nodes
        written when ``encoder`` skips duplicate nodes.'''
    node_count = 0
    for node_filename in node_filenames:
        # newline='' is required by the csv module for quoted fields that
        # contain line breaks. Input is decoded as UTF-8; 'utf-8-sig' also
        # drops a leading byte-order mark so it cannot end up in the first
        # header name.
        with open(node_filename, newline='',
                  encoding='utf-8-sig') as node_file:
            for node in csv.DictReader(node_file, delimiter=csv_delimiter):
                write_node_row(node, array_delimiter, encoder)
                node_count += 1
    return node_count
|
||
|
|
||
|
|
||
|
def write_relationship_row(relationship_row, array_delimiter, encoder):
    '''Parse one relationship row read by ``csv.DictReader`` and pass it to
    ``encoder.write_relationship``.

    Header fields have the form ``name:type`` (type matched
    case-insensitively). The row must have exactly one ``START_ID`` column
    and exactly one ``END_ID`` column, each optionally with ``(space)``.
    Exactly one ``TYPE`` column gives the relationship type. ``IGNORE``
    columns are skipped. Every other column becomes a property, of type
    ``string`` when the header gives no type.

    :raises ValueError: if the row has more or fewer cells than the header,
        if a START_ID, END_ID or TYPE column is repeated, or if START_ID,
        END_ID or a non-empty TYPE is missing. Conversion errors from
        ``csv_to_py_val`` propagate.'''
    # With default settings csv.DictReader puts surplus cells under the key
    # None and fills missing cells with None; reject both here instead of
    # failing later with an opaque AttributeError.
    if None in relationship_row:
        raise ValueError('Relationship row has more fields than the header')
    if any(value is None for value in relationship_row.values()):
        raise ValueError('Relationship row has fewer fields than the header')
    start_id = None
    end_id = None
    relationship_type = None
    properties = {}
    for field, raw_value in relationship_row.items():
        value = raw_value.strip()
        name, field_type = get_field_name_and_type(field)
        if field_type is not None and field_type.startswith('start_id'):
            if start_id is not None:
                raise ValueError('Only one START_ID must be specified')
            start_id = NodeId(value, get_id_space(field_type))
        elif field_type is not None and field_type.startswith('end_id'):
            if end_id is not None:
                raise ValueError('Only one END_ID must be specified')
            end_id = NodeId(value, get_id_space(field_type))
        elif field_type == 'type':
            if relationship_type is not None:
                raise ValueError('Only one relationship TYPE must be specified')
            relationship_type = value
        elif field_type != 'ignore':
            # Everything else is a property; a missing type means string.
            properties[name] = csv_to_py_val(value, field_type or 'string',
                                             array_delimiter)
    # An empty TYPE cell is treated as missing: a relationship needs a type.
    if start_id is None or end_id is None or not relationship_type:
        raise ValueError('Relationship TYPE, START_ID and END_ID must be set')
    encoder.write_relationship(start_id, end_id, relationship_type, properties)
|
||
|
|
||
|
|
||
|
def convert_relationships(relationship_filenames, csv_delimiter,
                          array_delimiter, encoder):
    '''Write every relationship row of every file in
    ``relationship_filenames`` via ``encoder``.

    :returns: the number of relationship rows processed.'''
    relationship_count = 0
    for relationship_filename in relationship_filenames:
        # newline='' is required by the csv module for quoted fields that
        # contain line breaks. Input is decoded as UTF-8; 'utf-8-sig' also
        # drops a leading byte-order mark so it cannot end up in the first
        # header name.
        with open(relationship_filename, newline='',
                  encoding='utf-8-sig') as relationship_file:
            relationships = csv.DictReader(relationship_file,
                                           delimiter=csv_delimiter)
            for relationship in relationships:
                write_relationship_row(relationship, array_delimiter, encoder)
                relationship_count += 1
    return relationship_count
|
||
|
|
||
|
|
||
|
def main():
    '''Command-line entry point: convert the CSV files given on the command
    line into a Memgraph snapshot file.

    The output is created exclusively ('xb') unless --overwrite is given. If
    the conversion fails, the partially written output is deleted before the
    error is re-raised. This means no truncated snapshot is left behind, and
    a rerun does not need --overwrite.'''
    import os  # only needed to clean up after a failed conversion

    args = parse_args()
    logging.basicConfig(level=args.log_level)
    all_input_names = ', '.join(it.chain(args.nodes, args.relationships))
    log.info("Converting %s to '%s'", all_input_names, args.out)
    # Opened outside the try: if opening fails (e.g. the file exists and
    # --overwrite was not given), nothing was written and an existing file
    # must not be deleted.
    dest_file = open(args.out, 'wb' if args.overwrite else 'xb')
    try:
        with dest_file:
            hasher = Hasher()
            encoder = BoltEncoder(dest_file, hasher, args.skip_duplicate_nodes)
            # Snapshot file has the following contents in order:
            # 1) list of label+property indexes
            # 2) all nodes, sequentially, but not encoded as a list
            # 3) all relationships, sequentially, but not encoded as a list
            # 4) summary with node count, relationship count and hash digest
            encoder.write_list([])  # Label + property indexes.
            node_count = convert_nodes(args.nodes, args.csv_delimiter,
                                       args.array_delimiter, encoder)
            relationship_count = convert_relationships(args.relationships,
                                                       args.csv_delimiter,
                                                       args.array_delimiter,
                                                       encoder)
            encoder.write_summary(node_count, relationship_count)
    except BaseException:
        # The with-block has already closed the file at this point.
        try:
            os.remove(args.out)
        except OSError as err:
            log.warning("Could not remove incomplete output '%s': %s",
                        args.out, err)
        raise
    log.info("Created '%s'", args.out)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Run the converter only when executed as a script, not on import.
    main()
|