#!/usr/bin/env python3
# -*- coding: utf-8 -*-

'''
Create a Memgraph recovery snapshot file from CSV.
'''

import argparse
import csv
import itertools as it
import logging
import struct

log = logging.getLogger(__name__)


def _parse_bool(value):
    '''Parse a CSV boolean cell.

    The string 'true' (in any letter case, surrounding whitespace ignored) is
    True; every other value is False. Plain ``bool(value)`` must not be used:
    every non-empty string, including 'false', is truthy in Python.'''
    return value.strip().lower() == 'true'


# Maps a (lower-cased) CSV column type to the callable that converts the raw
# cell text into the Python value that gets encoded.
_CSV_TYPE_TO_PY_TYPE = {
    'int': int,
    'long': int,
    'float': float,
    'double': float,
    'boolean': _parse_bool,
    'byte': int,
    'short': int,
    'char': str,
    'string': str,
}


def csv_to_py_val(value, csv_type, array_delimiter):
    '''Convert the CSV cell text `value` to a Python value of `csv_type`.

    `csv_type` is a key of _CSV_TYPE_TO_PY_TYPE, optionally suffixed with
    '[]' to denote an array whose elements are separated by
    `array_delimiter`; array elements are whitespace-stripped and an empty
    array cell yields an empty list.

    Raises KeyError for an unknown type and ValueError when the text cannot
    be converted (e.g. int('abc')).'''
    if not csv_type.endswith('[]'):
        return _CSV_TYPE_TO_PY_TYPE[csv_type](value)
    # Otherwise we have an array type, so convert it to a list.
    py_type = _CSV_TYPE_TO_PY_TYPE[csv_type[:-2]]
    if not value.strip():
        # ''.split(delimiter) would yield [''], which is not an empty array
        # and cannot be converted to a number.
        return []
    return [py_type(val.strip()) for val in value.split(array_delimiter)]


class NodeId:
    '''Identifier of a node in the CSV input: the raw ID text plus an
    optional ID space (None when the column declares no ID space).'''

    def __init__(self, id_, id_space):
        if not id_:
            raise ValueError('ID must not be empty')
        self.id = id_
        self.id_space = id_space

    def __eq__(self, other):
        if not isinstance(other, NodeId):
            return NotImplemented
        return self.id == other.id and self.id_space == other.id_space

    def __hash__(self):
        return hash((self.id, self.id_space))

    def __repr__(self):
        return '{}({!r}, {!r})'.format(type(self).__name__, self.id,
                                       self.id_space)

    def __str__(self):
        if self.id_space is None:
            return self.id
        return '{}({})'.format(self.id, self.id_space)


class Hasher:
    '''Implementation of memgraph/src/durability/hasher.

    The API mimics hashlib, so that it will be easier to switch to something
    more sane (e.g. sha256).'''

    _PRIME = 3137

    def __init__(self):
        self._hash = 0

    def update(self, data):
        '''Fold every byte of `data` into the running hash.

        Raises TypeError if `data` is not bytes.'''
        if not isinstance(data, bytes):
            raise TypeError("Expected 'bytes', but got '{}'"
                            .format(type(data).__name__))
        for byte in data:
            self._hash = self._hash * self._PRIME + byte + 1
            self._hash %= 2**64  # Make hash fit in uint64_t

    def digest(self):
        '''Return the digest value as an int (which fits in uint64_t) and
        *not* as bytes. (This is different from hashlib objects.)'''
        return self._hash


class BoltEncoder:
    '''Serialize values, nodes and relationships into the snapshot file.

    Every written byte except the trailing summary is fed to `hasher`. CSV
    node IDs are mapped to consecutive Memgraph node IDs (starting at 0) in
    the order nodes are written; relationships are numbered the same way.'''

    # Type markers
    _NULL_MARKER = b'\xC0'
    _FLOAT64_MARKER = b'\xC1'
    _FALSE_MARKER = b'\xC2'
    _TRUE_MARKER = b'\xC3'
    _INT64_MARKER = b'\xCB'
    _STRING32_MARKER = b'\xD2'
    _LIST32_MARKER = b'\xD6'
    _MAP32_MARKER = b'\xDA'
    _NODE_MARKER = b'\xB3\x4E'
    _RELATIONSHIP_MARKER = b'\xB5\x52'

    # Struct formats
    _INT64_STRUCT = struct.Struct('>q')
    _UINT32_STRUCT = struct.Struct('>I')
    _UINT64_STRUCT = struct.Struct('>Q')
    _FLOAT64_STRUCT = struct.Struct('>d')

    def __init__(self, file, hasher, skip_duplicate_nodes):
        self._file = file
        self._hasher = hasher
        self._relationship_id = 0
        self._node_id = 0
        self._csv_to_mg_node_id = {}
        self._skip_duplicate_nodes = skip_duplicate_nodes

    def write(self, value):
        '''Write `value`, dispatching on its type name to write_<type>.'''
        if value is None:
            return self.write_null()
        write = getattr(self, 'write_' + type(value).__name__)
        write(value)

    def write_null(self):
        self._write(self._NULL_MARKER)

    def write_bool(self, value):
        if value:
            self._write(self._TRUE_MARKER)
        else:
            self._write(self._FALSE_MARKER)

    def write_int(self, value):
        self._write(self._INT64_MARKER)
        self._write(self._INT64_STRUCT.pack(value))

    def write_float(self, value):
        self._write(self._FLOAT64_MARKER)
        self._write(self._FLOAT64_STRUCT.pack(value))

    def write_str(self, value):
        self._write(self._STRING32_MARKER)
        data = value.encode('utf-8')
        # The length prefix counts encoded bytes, not characters.
        self._write(self._UINT32_STRUCT.pack(len(data)))
        self._write(data)

    def write_list(self, values):
        self._write(self._LIST32_MARKER)
        self._write(self._UINT32_STRUCT.pack(len(values)))
        for value in values:
            self.write(value)

    def write_dict(self, dict_value):
        self._write(self._MAP32_MARKER)
        self._write(self._UINT32_STRUCT.pack(len(dict_value)))
        for key, value in dict_value.items():
            self.write_str(key)
            self.write(value)

    def write_summary(self, node_count, relationship_count):
        '''Write node count, relationship count and the hash digest.'''
        # It's a bit silly that the summary isn't considered for hashing
        # (see: memgraph/src/durability/file_writer_buffer)
        self._write(self._UINT64_STRUCT.pack(node_count),
                    update_hash=False)
        self._write(self._UINT64_STRUCT.pack(relationship_count),
                    update_hash=False)
        self._write(self._UINT64_STRUCT.pack(self._hasher.digest()),
                    update_hash=False)

    def write_node(self, node_id, labels, properties):
        '''Write a node with the given CSV `node_id`.

        Returns True if the node was written, or False if it was a duplicate
        that was skipped (only when skip_duplicate_nodes is enabled).
        Raises ValueError for a duplicate node otherwise.'''
        try:
            id_ = self._add_node_id(node_id)
        except ValueError:
            if not self._skip_duplicate_nodes:
                raise
            log.debug("Skipping duplicate node '%s'", node_id)
            return False
        self._write(self._NODE_MARKER)
        self.write_int(id_)
        self.write_list(labels)
        self.write_dict(properties)
        return True

    def write_relationship(self, start_id, end_id, type_, properties):
        '''Write a relationship between two previously written nodes.

        Raises ValueError (and writes nothing) if either endpoint is unknown.'''
        # Resolve both endpoints before emitting anything, so a missing node
        # cannot leave a half-written relationship in the file.
        start_mg_id = self._get_mg_node_id(start_id)
        end_mg_id = self._get_mg_node_id(end_id)
        self._write(self._RELATIONSHIP_MARKER)
        self.write_int(self._relationship_id)
        self.write_int(start_mg_id)
        self.write_int(end_mg_id)
        self._relationship_id += 1
        self.write_str(type_)
        self.write_dict(properties)

    def _write(self, byte_data, update_hash=True):
        # Hex-encoding every chunk is wasted work unless DEBUG is enabled.
        if log.isEnabledFor(logging.DEBUG):
            log.debug("Writing bytes '0x{}'".format(byte_data.hex()))
        if update_hash:
            self._hasher.update(byte_data)
        self._file.write(byte_data)

    def _add_node_id(self, node_id):
        '''Add a new mapping from CSV node ID to Memgraph ID and return the
        Memgraph ID. Raises ValueError if the CSV node ID is already mapped.'''
        if node_id in self._csv_to_mg_node_id:
            raise ValueError("Node '{}' already exists".format(node_id))
        id_ = self._node_id
        self._csv_to_mg_node_id[node_id] = id_
        self._node_id += 1
        return id_

    def _get_mg_node_id(self, node_id):
        '''Return the Memgraph ID of an already written CSV node ID.'''
        try:
            return self._csv_to_mg_node_id[node_id]
        except KeyError:
            raise ValueError(
                "Node '{}' does not exist".format(node_id)) from None


def parse_args():
    '''Parse and return the command line arguments.'''
    argp = argparse.ArgumentParser(description=__doc__)
    argp.add_argument('-o', '--out', required=True,
                      help='Destination for the created snapshot file')
    argp.add_argument('-n', '--nodes', action='append', required=True,
                      help='CSV file containing graph nodes (vertices)')
    argp.add_argument('-r', '--relationships', default=[], action='append',
                      help='CSV file containing graph relationships (edges)')
    argp.add_argument('--overwrite', action='store_true', default=False,
                      help='Overwrite the output file if it exists')
    # '--log_level' is kept for backward compatibility; '--log-level' matches
    # the dashed spelling of the other options.
    argp.add_argument('--log-level', '--log_level', dest='log_level',
                      default='WARNING',
                      choices=['INFO', 'WARNING', 'DEBUG'],
                      help='Log level, default is WARNING')
    argp.add_argument('--array-delimiter', default=';',
                      help='Delimiter between elements of array values, '
                      "default is ';'")
    argp.add_argument('--csv-delimiter', default=',',
                      help='Delimiter between each field in the CSV, '
                      "default is ','")
    argp.add_argument('--skip-duplicate-nodes', action='store_true',
                      default=False,
                      help='Skip duplicate nodes or raise an error (default)')
    return argp.parse_args()


def get_field_name_and_type(field):
    '''Return (field_name, field_type) from the field string.

    The type is stripped and lower-cased. If there is no type, field_type is
    returned as None.'''
    field_name_and_type = field.split(':', maxsplit=1)
    name = field_name_and_type[0]
    if len(field_name_and_type) == 1:
        return name, None
    field_type = field_name_and_type[1].strip().lower()
    return name, field_type


def get_id_space(field_type):
    '''Return the ID space of a field type such as 'id(person)', or None if
    the type declares no ID space.

    Raises ValueError if the opening '(' has no matching ')'.'''
    group_start = field_type.find('(')
    if group_start == -1:
        return None
    group_end = field_type.rfind(')')
    if group_end < group_start:
        raise ValueError("Unterminated ID space in '{}'".format(field_type))
    return field_type[group_start + 1:group_end]


def _iter_row_fields(row):
    '''Yield (field_name, field_type, stripped_value) for each field of a
    csv.DictReader row.

    Raises ValueError if the row has more or fewer fields than the header
    (csv.DictReader reports those with a None key or a None value).'''
    for field, value in row.items():
        if field is None:
            raise ValueError('Row has more fields than the header')
        if value is None:
            raise ValueError(
                "Row is missing a value for field '{}'".format(field))
        name, field_type = get_field_name_and_type(field)
        yield name, field_type, value.strip()


def _property_value(value, field_type, array_delimiter):
    '''Convert a property cell; a missing field_type defaults to string.'''
    return csv_to_py_val(value, field_type or 'string', array_delimiter)


def write_node_row(node_row, array_delimiter, encoder):
    '''Encode one CSV node row.

    Returns True if the node was written, False if it was a skipped
    duplicate. Raises ValueError for a row without exactly one ID column.'''
    node_id = None
    node_labels = []
    properties = {}
    for name, field_type, value in _iter_row_fields(node_row):
        if field_type is not None and field_type.startswith('id'):
            if node_id is not None:
                raise ValueError('Only one node ID must be specified')
            node_id = NodeId(value, get_id_space(field_type))
        elif field_type == 'label':
            labels = map(str.strip, value.split(array_delimiter))
            node_labels.extend(label for label in labels if label)
        elif field_type != 'ignore':
            # Everything else is a property.
            properties[name] = _property_value(value, field_type,
                                               array_delimiter)
    if node_id is None:
        raise ValueError('Node ID must be specified')
    return encoder.write_node(node_id, node_labels, properties)


def convert_nodes(node_filenames, csv_delimiter, array_delimiter, encoder):
    '''Encode all nodes from the given CSV files.

    Returns the number of nodes actually written; skipped duplicates are not
    counted, so the count matches the snapshot contents.'''
    node_count = 0
    for node_filename in node_filenames:
        with open(node_filename, newline='', encoding='utf-8') as node_file:
            nodes = csv.DictReader(node_file, delimiter=csv_delimiter)
            for node in nodes:
                if write_node_row(node, array_delimiter, encoder):
                    node_count += 1
    return node_count


def write_relationship_row(relationship_row, array_delimiter, encoder):
    '''Encode one CSV relationship row.

    Raises ValueError unless exactly one START_ID, END_ID and non-empty TYPE
    column are present.'''
    start_id = None
    end_id = None
    relationship_type = None
    properties = {}
    for name, field_type, value in _iter_row_fields(relationship_row):
        if field_type is not None and field_type.startswith('start_id'):
            if start_id is not None:
                raise ValueError('Only one START_ID must be specified')
            start_id = NodeId(value, get_id_space(field_type))
        elif field_type is not None and field_type.startswith('end_id'):
            if end_id is not None:
                raise ValueError('Only one END_ID must be specified')
            end_id = NodeId(value, get_id_space(field_type))
        elif field_type == 'type':
            if relationship_type is not None:
                raise ValueError('Only one relationship TYPE must be specified')
            relationship_type = value
        elif field_type != 'ignore':
            # Everything else is a property.
            properties[name] = _property_value(value, field_type,
                                               array_delimiter)
    if start_id is None or end_id is None or not relationship_type:
        raise ValueError('Relationship TYPE, START_ID and END_ID must be set')
    encoder.write_relationship(start_id, end_id, relationship_type, properties)


def convert_relationships(relationship_filenames, csv_delimiter,
                          array_delimiter, encoder):
    '''Encode all relationships from the given CSV files and return how many
    were written.'''
    relationship_count = 0
    for relationship_filename in relationship_filenames:
        with open(relationship_filename, newline='', encoding='utf-8') as \
                relationship_file:
            relationships = csv.DictReader(relationship_file,
                                           delimiter=csv_delimiter)
            for relationship in relationships:
                write_relationship_row(relationship, array_delimiter, encoder)
                relationship_count += 1
    return relationship_count


def main():
    '''Convert the CSV files named on the command line into a snapshot.'''
    args = parse_args()
    logging.basicConfig(level=args.log_level)
    all_input_names = ', '.join(it.chain(args.nodes, args.relationships))
    log.info("Converting {} to '{}'".format(all_input_names, args.out))
    # Mode 'xb' refuses to clobber an existing file unless --overwrite.
    with open(args.out, 'wb' if args.overwrite else 'xb') as dest_file:
        hasher = Hasher()
        encoder = BoltEncoder(dest_file, hasher, args.skip_duplicate_nodes)
        # Snapshot file has the following contents in order:
        #   1) list of label+property indexes
        #   2) all nodes, sequentially, but not encoded as a list
        #   3) all relationships, sequentially, but not encoded as a list
        #   4) summary with node count, relationship count and hash digest
        encoder.write_list([])  # Label + property indexes.
        node_count = convert_nodes(args.nodes, args.csv_delimiter,
                                   args.array_delimiter, encoder)
        relationship_count = convert_relationships(args.relationships,
                                                   args.csv_delimiter,
                                                   args.array_delimiter,
                                                   encoder)
        encoder.write_summary(node_count, relationship_count)
    log.info("Created '{}'".format(args.out))


if __name__ == '__main__':
    main()