memgraph/tools/csv_to_snapshot

349 lines
13 KiB
Plaintext
Raw Normal View History

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
Create a Memgraph recovery snapshot file from CSV.
'''
import argparse
import csv
import itertools as it
import logging
import os
import struct
log = logging.getLogger(__name__)


def _parse_bool(value):
    '''Parse the text of a CSV boolean cell.

    ``bool(value)`` cannot be used: it is True for every non-empty string,
    including ``'false'``. Instead the literal text is matched,
    case-insensitively and ignoring surrounding whitespace. Anything other
    than true/false raises ValueError, the same way ``int('abc')`` fails for
    an integer column.'''
    normalized = value.strip().lower()
    if normalized == 'true':
        return True
    if normalized == 'false':
        return False
    raise ValueError("Invalid boolean value '{}'".format(value))


# Maps a (lower-cased) CSV column type to the callable that converts the
# cell text into the Python value stored in the snapshot.
_CSV_TYPE_TO_PY_TYPE = {
    'int': int,
    'long': int,
    'float': float,
    'double': float,
    'boolean': _parse_bool,
    'byte': int,
    'short': int,
    'char': str,
    'string': str,
}


def csv_to_py_val(value, csv_type, array_delimiter):
    '''Convert the text *value* of a CSV cell to a Python value.

    *csv_type* is one of the keys of ``_CSV_TYPE_TO_PY_TYPE``, optionally
    suffixed with ``[]`` for an array. Array cells are split on
    *array_delimiter*; each element is stripped and converted, giving a list.

    Raises ValueError for an unknown type or a value that does not parse as
    the requested type.'''
    is_array = csv_type.endswith('[]')
    element_type = csv_type[:-2] if is_array else csv_type
    try:
        convert = _CSV_TYPE_TO_PY_TYPE[element_type]
    except KeyError:
        raise ValueError("Unknown CSV type '{}'".format(csv_type)) from None
    if not is_array:
        return convert(value)
    return [convert(element.strip())
            for element in value.split(array_delimiter)]
class NodeId:
    '''Identifier of a node as written in the CSV input.

    It combines the raw ID text with an optional ID space (``ID(space)`` in
    the header), so the same ID text in two different spaces names two
    different nodes. Instances compare by value and are hashable, so they
    can be used as dictionary keys.

    Raises ValueError if the ID text is empty.'''

    # One instance is kept per node (as a mapping key), so drop the
    # per-instance __dict__ to save memory on large imports.
    __slots__ = ('id', 'id_space')

    def __init__(self, id_, id_space):
        if not id_:
            raise ValueError('ID must not be empty')
        self.id = id_
        self.id_space = id_space

    def __eq__(self, other):
        if not isinstance(other, NodeId):
            return NotImplemented
        return self.id == other.id and self.id_space == other.id_space

    def __hash__(self):
        return hash((self.id, self.id_space))

    def __str__(self):
        if self.id_space is None:
            return self.id
        return '{}({})'.format(self.id, self.id_space)

    def __repr__(self):
        return '{}({!r}, {!r})'.format(type(self).__name__, self.id,
                                       self.id_space)
class Hasher:
    '''Python port of the hash in memgraph/src/durability/hasher.

    Its interface imitates hashlib (``update`` then ``digest``) so that it
    can later be replaced by a stronger hash such as sha256.'''
    _PRIME = 3137

    def __init__(self):
        self._hash = 0

    def update(self, data):
        '''Fold every byte of *data* into the running hash.

        Each byte ``b`` maps ``h`` to ``(h * 3137 + b + 1) mod 2**64``.
        Raises TypeError unless *data* is ``bytes``.'''
        if isinstance(data, bytes):
            acc, prime = self._hash, self._PRIME
            for octet in data:
                # Reduce every step so the value always fits in uint64_t,
                # matching the C++ implementation's wrap-around.
                acc = (acc * prime + octet + 1) % 2**64
            self._hash = acc
            return
        raise TypeError("Expected 'bytes', but got '{}'"
                        .format(type(data).__name__))

    def digest(self):
        '''Return the current hash as an int that fits in uint64_t.

        Unlike hashlib objects, this is an int and *not* bytes.'''
        return self._hash
class BoltEncoder:
    '''Write Bolt-encoded values, nodes and relationships to a snapshot file.

    Every byte goes to *file*. All bytes except the summary are also fed to
    *hasher*, which must provide ``update(bytes)`` and ``digest()`` (e.g.
    ``Hasher``). Values always use the fixed-width encodings declared below:
    64-bit ints and floats, and 32-bit lengths for strings, lists and maps.

    Nodes get consecutive Memgraph IDs, starting at 0, in the order they are
    written. The CSV ``NodeId`` -> Memgraph ID mapping is kept so that
    relationships can refer to their endpoints. Relationships are numbered
    consecutively from 0 as well. If *skip_duplicate_nodes* is true, a node
    whose ``NodeId`` was already written is dropped; otherwise it is an
    error.'''
    # Type markers
    _NULL_MARKER = b'\xC0'
    _FLOAT64_MARKER = b'\xC1'
    _FALSE_MARKER = b'\xC2'
    _TRUE_MARKER = b'\xC3'
    _INT64_MARKER = b'\xCB'
    _STRING32_MARKER = b'\xD2'
    _LIST32_MARKER = b'\xD6'
    _MAP32_MARKER = b'\xDA'
    _NODE_MARKER = b'\xB3\x4E'
    _RELATIONSHIP_MARKER = b'\xB5\x52'
    # Struct formats (all big-endian)
    _INT64_STRUCT = struct.Struct('>q')
    _UINT32_STRUCT = struct.Struct('>I')
    _UINT64_STRUCT = struct.Struct('>Q')
    _FLOAT64_STRUCT = struct.Struct('>d')

    def __init__(self, file, hasher, skip_duplicate_nodes):
        self._file = file
        self._hasher = hasher
        self._relationship_id = 0  # next Memgraph relationship ID
        self._node_id = 0  # next Memgraph node ID
        self._csv_to_mg_node_id = {}  # NodeId -> Memgraph node ID
        self._skip_duplicate_nodes = skip_duplicate_nodes

    def write(self, value):
        '''Write *value*. ``None`` becomes null; any other value goes to the
        ``write_<type name>`` method matching its exact type name.'''
        if value is None:
            return self.write_null()
        write = getattr(self, 'write_' + type(value).__name__)
        write(value)

    def write_null(self):
        self._write(self._NULL_MARKER)

    def write_bool(self, value):
        if value:
            self._write(self._TRUE_MARKER)
        else:
            self._write(self._FALSE_MARKER)

    def write_int(self, value):
        '''Write *value* as a signed 64-bit int (struct.error if it does
        not fit).'''
        self._write(self._INT64_MARKER)
        self._write(self._INT64_STRUCT.pack(value))

    def write_float(self, value):
        self._write(self._FLOAT64_MARKER)
        self._write(self._FLOAT64_STRUCT.pack(value))

    def write_str(self, value):
        '''Write *value* UTF-8 encoded, prefixed by its byte length.'''
        self._write(self._STRING32_MARKER)
        data = value.encode('utf-8')
        self._write(self._UINT32_STRUCT.pack(len(data)))
        self._write(data)

    def write_list(self, values):
        self._write(self._LIST32_MARKER)
        self._write(self._UINT32_STRUCT.pack(len(values)))
        for value in values:
            self.write(value)

    def write_dict(self, dict_value):
        '''Write a map; keys are written as strings, in iteration order.'''
        self._write(self._MAP32_MARKER)
        self._write(self._UINT32_STRUCT.pack(len(dict_value)))
        for key, value in dict_value.items():
            self.write_str(key)
            self.write(value)

    def write_summary(self, node_count, relationship_count):
        '''Write the trailer: node count, relationship count and the hash of
        everything written so far, each as an unsigned 64-bit int.'''
        # It's a bit silly that the summary isn't considered for hashing
        # (see: memgraph/src/durability/file_writer_buffer)
        self._write(self._UINT64_STRUCT.pack(node_count), update_hash=False)
        self._write(self._UINT64_STRUCT.pack(relationship_count),
                    update_hash=False)
        self._write(self._UINT64_STRUCT.pack(self._hasher.digest()),
                    update_hash=False)

    def write_node(self, node_id, labels, properties):
        '''Write a node record and assign it the next Memgraph ID.

        :param node_id: ``NodeId`` of the node in the CSV input.
        :param labels: list of label strings.
        :param properties: dict of property values. The CSV ID text is
            stored under ``'id'``, replacing any such property. The dict
            passed in is not modified.
        :returns: True if the node was written, or False if it was a
            duplicate that was skipped.
        :raises ValueError: on a duplicate node when skipping is disabled.
        '''
        try:
            mg_id = self._add_node_id(node_id)
        except ValueError:
            if not self._skip_duplicate_nodes:
                raise
            return False
        # Copy instead of mutating the caller's dict; key order (and hence
        # the written bytes) is the same as assigning into it.
        properties = dict(properties, id=node_id.id)
        self._write(self._NODE_MARKER)
        self.write_int(mg_id)
        self.write_list(labels)
        self.write_dict(properties)
        return True

    def write_relationship(self, start_id, end_id, type_, properties):
        '''Write a relationship between two already written nodes.

        :raises ValueError: if either endpoint is not a known node. Both
            endpoints are resolved before anything is written, so no partial
            record is left in the file.
        '''
        start_mg_id = self._get_mg_node_id(start_id)
        end_mg_id = self._get_mg_node_id(end_id)
        self._write(self._RELATIONSHIP_MARKER)
        self.write_int(self._relationship_id)
        self.write_int(start_mg_id)
        self.write_int(end_mg_id)
        self._relationship_id += 1
        self.write_str(type_)
        self.write_dict(properties)

    def _write(self, byte_data, update_hash=True):
        # This runs for every value written, so only pay for the hex
        # formatting when DEBUG logging is actually enabled.
        if log.isEnabledFor(logging.DEBUG):
            log.debug("Writing bytes '0x{}'".format(byte_data.hex()))
        if update_hash:
            self._hasher.update(byte_data)
        self._file.write(byte_data)

    def _add_node_id(self, node_id):
        '''Add a new mapping from CSV node ID to Memgraph ID and
        return the Memgraph ID.'''
        if node_id in self._csv_to_mg_node_id:
            raise ValueError("Node '{}' already exists".format(node_id))
        id_ = self._node_id
        self._csv_to_mg_node_id[node_id] = id_
        self._node_id += 1
        return id_

    def _get_mg_node_id(self, node_id):
        '''Return the Memgraph ID of an already written node, or raise
        ValueError naming the unknown CSV node.'''
        try:
            return self._csv_to_mg_node_id[node_id]
        except KeyError:
            raise ValueError("Relationship references unknown node '{}'"
                             .format(node_id)) from None
def parse_args():
    '''Parse the command line and return the resulting namespace.'''
    argp = argparse.ArgumentParser(description=__doc__)
    argp.add_argument('-o', '--out', required=True,
                      help='Destination for the created snapshot file')
    argp.add_argument('-n', '--nodes', action='append', required=True,
                      help='CSV file containing graph nodes (vertices)')
    argp.add_argument('-r', '--relationships', default=[], action='append',
                      help='CSV file containing graph relationships (edges)')
    argp.add_argument('--overwrite', action='store_true', default=False,
                      help='Overwrite the output file if it exists')
    # '--log-level' matches the hyphenated spelling of every other option;
    # '--log_level' is kept so existing invocations keep working.
    argp.add_argument('--log_level', '--log-level', dest='log_level',
                      default='WARNING',
                      choices=['DEBUG', 'INFO', 'WARNING', 'ERROR',
                               'CRITICAL'],
                      help='Log level, default is WARNING')
    argp.add_argument('--array-delimiter', default=';',
                      help='Delimiter between elements of array values, '
                      "default is ';'")
    argp.add_argument('--csv-delimiter', default=',',
                      help='Delimiter between each field in the CSV, '
                      "default is ','")
    argp.add_argument('--skip-duplicate-nodes', action='store_true',
                      default=False,
                      help='Skip duplicate nodes or raise an error (default)')
    return argp.parse_args()
def get_field_name_and_type(field):
    '''Split a CSV header field of the form ``name:type``.

    Only the first ``:`` separates the two parts. The name is returned
    as-is; the type is stripped and lower-cased. If *field* contains no
    ``:`` at all, the type is None.'''
    name, colon, raw_type = field.partition(':')
    if colon:
        return name, raw_type.strip().lower()
    return name, None
def get_id_space(field_type):
    '''Return the ID space named in an ID column type, or None if it has none.

    *field_type* is a column type such as ``id``, ``id(users)``,
    ``start_id(users)`` or ``end_id(users)``. The ID space is the text after
    the first ``(``, without a trailing ``)`` and without surrounding
    whitespace. A missing closing parenthesis is tolerated; previously the
    last character of the name was silently dropped in that case.'''
    _, paren, space = field_type.partition('(')
    if not paren:
        return None
    space = space.strip()
    if space.endswith(')'):
        space = space[:-1]
    return space.strip()
def write_node_row(node_row, array_delimiter, encoder):
    '''Convert one CSV node row and pass it to ``encoder.write_node``.

    *node_row* maps header fields (``name:type``) to cell text; each cell is
    stripped first. The column types are handled as follows:

    - A type starting with ``id`` (optionally ``ID(space)``) gives the
      node's ID.
    - ``label`` holds labels separated by *array_delimiter*; blank labels
      are dropped.
    - ``ignore`` columns are skipped.
    - Every other column becomes a property. An untyped column is a string.

    Raises ValueError if the row has no ID column or more than one.'''
    found_id = None
    labels = []
    props = {}
    for header, raw in node_row.items():
        cell = raw.strip()
        prop_name, kind = get_field_name_and_type(header)
        if kind is not None and kind.startswith('id'):
            if found_id is not None:
                raise ValueError('Only one node ID must be specified')
            found_id = NodeId(cell, get_id_space(kind))
            continue
        if kind == 'label':
            stripped = (part.strip() for part in cell.split(array_delimiter))
            labels.extend(part for part in stripped if part)
            continue
        if kind == 'ignore':
            continue
        # No type (None or empty) means a plain string property.
        props[prop_name] = csv_to_py_val(cell, kind or 'string',
                                         array_delimiter)
    if found_id is None:
        raise ValueError('Node ID must be specified')
    encoder.write_node(found_id, labels, props)
def convert_nodes(node_filenames, csv_delimiter, array_delimiter, encoder):
    '''Write the nodes from every file in *node_filenames* through *encoder*.

    Files are read in order as UTF-8 CSV with a header row, using
    *csv_delimiter* between cells. Each data row goes through
    ``write_node_row``.

    Returns the number of nodes actually written. Rows dropped as duplicate
    nodes (when the encoder skips duplicates) are not counted, so the value
    matches the file contents and can go straight into the summary.'''
    # The encoder hands out consecutive Memgraph IDs only to the nodes it
    # actually writes, so the advance of its ID counter is the written-node
    # count. Counting rows would also count skipped duplicates.
    first_mg_id = encoder._node_id
    rows_read = 0
    for node_filename in node_filenames:
        with open(node_filename, newline='', encoding='utf-8') as node_file:
            for node_row in csv.DictReader(node_file,
                                           delimiter=csv_delimiter):
                write_node_row(node_row, array_delimiter, encoder)
                rows_read += 1
    written = encoder._node_id - first_mg_id
    skipped = rows_read - written
    if skipped:
        log.info('Skipped {} duplicate node(s)'.format(skipped))
    return written
def write_relationship_row(relationship_row, array_delimiter, encoder):
    '''Convert one CSV relationship row and pass it to
    ``encoder.write_relationship``.

    *relationship_row* maps header fields (``name:type``) to cell text; each
    cell is stripped first. The column types are handled as follows:

    - ``start_id`` / ``end_id`` (optionally with ``(space)``) name the
      endpoints.
    - ``type`` gives the relationship type.
    - ``ignore`` columns are skipped.
    - Every other column becomes a property. An untyped column is a string.

    Raises ValueError if an endpoint or the type is given twice or is
    missing.'''
    start = None
    end = None
    rel_type = None
    props = {}
    for header, raw in relationship_row.items():
        cell = raw.strip()
        prop_name, kind = get_field_name_and_type(header)
        typed = kind is not None
        if typed and kind.startswith('start_id'):
            if start is not None:
                raise ValueError('Only one node ID must be specified')
            start = NodeId(cell, get_id_space(kind))
            continue
        if typed and kind.startswith('end_id'):
            if end is not None:
                raise ValueError('Only one node ID must be specified')
            end = NodeId(cell, get_id_space(kind))
            continue
        if kind == 'type':
            if rel_type is not None:
                raise ValueError('Only one relationship TYPE must be specified')
            rel_type = cell
            continue
        if kind == 'ignore':
            continue
        # No type (None or empty) means a plain string property.
        props[prop_name] = csv_to_py_val(cell, kind or 'string',
                                         array_delimiter)
    if start is None or end is None or rel_type is None:
        raise ValueError('Relationship TYPE, START_ID and END_ID must be set')
    encoder.write_relationship(start, end, rel_type, props)
def convert_relationships(relationship_filenames, csv_delimiter,
                          array_delimiter, encoder):
    '''Write the relationships from every file in *relationship_filenames*.

    Each file is read as UTF-8 CSV with a header row, using *csv_delimiter*
    between cells. Every data row goes through ``write_relationship_row``.
    Returns the total number of rows written.'''
    written = 0
    for path in relationship_filenames:
        with open(path, newline='', encoding='utf-8') as csv_file:
            reader = csv.DictReader(csv_file, delimiter=csv_delimiter)
            for row in reader:
                write_relationship_row(row, array_delimiter, encoder)
                written += 1
    return written
def main():
    '''Command-line entry point: convert the CSV inputs into a snapshot.

    If anything fails after the output file has been opened, the incomplete
    snapshot is deleted. This way it cannot be mistaken for a valid one, and
    a rerun without --overwrite does not fail on the leftover file.'''
    args = parse_args()
    logging.basicConfig(level=args.log_level)
    all_input_names = ', '.join(it.chain(args.nodes, args.relationships))
    log.info("Converting {} to '{}'".format(all_input_names, args.out))
    # 'xb' refuses to replace an existing file unless --overwrite is given.
    # The file is opened outside the try below, so a FileExistsError never
    # leads to deleting a file this run did not create.
    dest_file = open(args.out, 'wb' if args.overwrite else 'xb')
    try:
        with dest_file:
            hasher = Hasher()
            encoder = BoltEncoder(dest_file, hasher,
                                  args.skip_duplicate_nodes)
            # Snapshot file has the following contents in order:
            # 1) list of label+property indexes
            # 2) all nodes, sequentially, but not encoded as a list
            # 3) all relationships, sequentially, but not encoded as a list
            # 4) summary with node count, relationship count and hash digest
            encoder.write_list([])  # Label + property indexes.
            node_count = convert_nodes(args.nodes, args.csv_delimiter,
                                       args.array_delimiter, encoder)
            relationship_count = convert_relationships(args.relationships,
                                                       args.csv_delimiter,
                                                       args.array_delimiter,
                                                       encoder)
            encoder.write_summary(node_count, relationship_count)
    except BaseException:
        # The `with` above has already closed the file at this point.
        log.error("Conversion failed, removing incomplete '{}'"
                  .format(args.out))
        try:
            os.remove(args.out)
        except OSError as err:
            log.warning("Could not remove '{}': {}".format(args.out, err))
        raise
    log.info("Created '{}'".format(args.out))


if __name__ == '__main__':
    main()