memgraph/query_modules/mgp_networkx.py
Marko Budiselić 48587d6d5e
Improve NetworkX module import (#21)
* Improve NetworkX module import
* Add Networkx dependencies to Dockerfiles
2020-10-15 09:14:50 +02:00

290 lines
9.8 KiB
Python

import sys
import mgp
import collections
try:
import networkx as nx
except ImportError as import_error:
sys.stderr.write((
'\n'
'NOTE: Please install networkx to be able to use Memgraph NetworkX '
'wrappers. Using Python:\n'
+ sys.version +
'\n'))
raise import_error
class MemgraphAdjlistOuterDict(collections.abc.Mapping):
__slots__ = ('_ctx', '_succ', '_multi')
def __init__(self, ctx, succ=True, multi=True):
self._ctx = ctx
self._succ = succ
self._multi = multi
def __getitem__(self, key):
if key not in self:
raise KeyError
return MemgraphAdjlistInnerDict(key, succ=self._succ,
multi=self._multi)
def __iter__(self):
return iter(self._ctx.graph.vertices)
def __len__(self):
return len(self._ctx.graph.vertices)
def __contains__(self, key):
if not isinstance(key, mgp.Vertex):
raise TypeError
return key in self._ctx.graph.vertices
class MemgraphAdjlistInnerDict(collections.abc.Mapping):
__slots__ = ('_node', '_succ', '_multi', '_neighbors')
def __init__(self, node, succ=True, multi=True):
self._node = node
self._succ = succ
self._multi = multi
self._neighbors = None
def __getitem__(self, key):
# NOTE: NetworkX 2.4, classes/coreviews.py:143. UnionAtlas expects a
# KeyError when indexing with a vertex that is not a neighbor.
if key not in self:
raise KeyError
if not self._multi:
return UnhashableProperties(self._get_edge(key).properties)
return MemgraphEdgeKeyDict(self._node, key, self._succ)
def __iter__(self):
yield from self._get_neighbors()
def __len__(self):
return len(self._get_neighbors())
def __contains__(self, key):
if not isinstance(key, mgp.Vertex):
raise TypeError
return key in self._get_neighbors()
def _get_neighbors(self):
if not self._neighbors:
if self._succ:
self._neighbors = set(
e.to_vertex for e in self._node.out_edges)
else:
self._neighbors = set(
e.from_vertex for e in self._node.in_edges)
return self._neighbors
def _get_edge(self, neighbor):
if self._succ:
edge = list(filter(lambda e: e.to_vertex == neighbor,
self._node.out_edges))
else:
edge = list(filter(lambda e: e.from_vertex == neighbor,
self._node.in_edges))
assert len(edge) >= 1
if len(edge) > 1:
raise RuntimeError('Graph contains multiedges but '
'is of non-multigraph type: {}'.format(edge))
return edge[0]
class MemgraphEdgeKeyDict(collections.abc.Mapping):
__slots__ = ('_node', '_neighbor', '_succ', '_edges')
def __init__(self, node, neighbor, succ=True):
self._node = node
self._neighbor = neighbor
self._succ = succ
self._edges = None
def __getitem__(self, key):
if key not in self:
raise KeyError
return UnhashableProperties(key.properties)
def __iter__(self):
yield from self._get_edges()
def __len__(self):
return len(self._get_edges())
def __contains__(self, key):
if not isinstance(key, mgp.Edge):
raise TypeError
return key in self._get_edges()
def _get_edges(self):
if not self._edges:
if self._succ:
self._edges = list(filter(
lambda e: e.to_vertex == self._neighbor,
self._node.out_edges))
else:
self._edges = list(filter(
lambda e: e.from_vertex == self._neighbor,
self._node.in_edges))
return self._edges
class UnhashableProperties(collections.abc.Mapping):
__slots__ = ('_properties')
def __init__(self, properties):
self._properties = properties
def __getitem__(self, key):
return self._properties[key]
def __iter__(self):
yield from self._properties
def __len__(self):
return len(self._properties)
def __contains__(self, key):
return key in self._properties
# NOTE: Explicitly disable hashing. See the comment in MemgraphNodeDict.
__hash__ = None
class MemgraphNodeDict(collections.abc.Mapping):
__slots__ = ('_ctx',)
def __init__(self, ctx):
self._ctx = ctx
def __getitem__(self, key):
if key not in self:
raise KeyError
# NOTE: NetworkX 2.4, classes/digraph.py:484. NetworkX expects the
# tuples provided to add_nodes_from to be unhashable and cause a
# TypeError when trying to index the node dictionary. This happens
# because the data dictionary element of the tuple is unhashable. We do
# the same thing by returning an unhashable data dictionary.
return UnhashableProperties(key.properties)
def __iter__(self):
return iter(self._ctx.graph.vertices)
def __len__(self):
return len(self._ctx.graph.vertices)
def __contains__(self, key):
# NOTE: NetworkX 2.4, graph.py:425. Graph.__contains__ relies on
# self._node's (i.e. the dictionary produced by node_dict_factory)
# __contains__ to raise a TypeError when indexing with something weird,
# e.g. with sets. This is the behavior of dict.
if not isinstance(key, mgp.Vertex):
raise TypeError
return key in self._ctx.graph.vertices
class MemgraphDiGraphBase:
def __init__(self, incoming_graph_data=None, ctx=None, multi=True,
**kwargs):
# NOTE: We assume that our graph will never be given any initial data
# because we already pull our data from the Memgraph database. This
# assert is triggered by certain NetworkX procedures because they
# create a new instance of our class and try to populate it with their
# own data.
assert incoming_graph_data is None
# NOTE: We allow for ctx to be None in order to allow certain NetworkX
# procedures (such as `subgraph`) to work. Such procedures directly
# modify the graph's internal attributes and don't try to populate it
# with initial data or modify it.
self.node_dict_factory = lambda: MemgraphNodeDict(ctx) \
if ctx else self._error
self.node_attr_dict_factory = self._error
self.adjlist_outer_dict_factory = \
lambda: MemgraphAdjlistOuterDict(ctx, multi=multi) \
if ctx else self._error
self.adjlist_inner_dict_factory = self._error
self.edge_key_dict_factory = self._error
self.edge_attr_dict_factory = self._error
# NOTE: We forbid any mutating operations because our graph is
# immutable and pulls its data from the Memgraph database.
for f in ['add_node', 'add_nodes_from', 'remove_node',
'remove_nodes_from', 'add_edge', 'add_edges_from',
'add_weighted_edges_from', 'new_edge_key', 'remove_edge',
'remove_edges_from', 'update', 'clear']:
setattr(self, f, lambda *args, **kwargs: self._error())
super().__init__(None, **kwargs)
# NOTE: This is a necessary hack because NetworkX assumes that the
# customizable factory functions will only ever return *empty*
# dictionaries. In our case, the factory functions return our custom,
# already populated, dictionaries. Because self._pred and self._end are
# initialized by the same factory function, they end up storing the
# same adjacency lists which is not good. We correct that here.
self._pred = MemgraphAdjlistOuterDict(ctx, succ=False, multi=multi)
def _error(self):
raise RuntimeError('Modification operations are not supported')
class MemgraphMultiDiGraph(MemgraphDiGraphBase, nx.MultiDiGraph):
def __init__(self, incoming_graph_data=None, ctx=None, **kwargs):
super().__init__(incoming_graph_data=incoming_graph_data,
ctx=ctx, multi=True, **kwargs)
def MemgraphMultiGraph(incoming_graph_data=None, ctx=None, **kwargs):
return MemgraphMultiDiGraph(incoming_graph_data=incoming_graph_data,
ctx=ctx, **kwargs).to_undirected(as_view=True)
class MemgraphDiGraph(MemgraphDiGraphBase, nx.DiGraph):
def __init__(self, incoming_graph_data=None, ctx=None, **kwargs):
super().__init__(incoming_graph_data=incoming_graph_data,
ctx=ctx, multi=False, **kwargs)
def MemgraphGraph(incoming_graph_data=None, ctx=None, **kwargs):
return MemgraphDiGraph(incoming_graph_data=incoming_graph_data,
ctx=ctx, **kwargs).to_undirected(as_view=True)
class PropertiesDictionary(collections.abc.Mapping):
__slots__ = ('_ctx', '_prop', '_len')
def __init__(self, ctx, prop):
self._ctx = ctx
self._prop = prop
self._len = None
def __getitem__(self, vertex):
if vertex not in self:
raise KeyError
try:
return vertex.properties[self._prop]
except KeyError:
raise KeyError(("{} doesn\t have the required " +
"property '{}'").format(vertex, self._prop))
def __iter__(self):
for v in self._ctx.graph.vertices:
if self._prop in v.properties:
yield v
def __len__(self):
if not self._len:
self._len = sum(1 for _ in self)
return self._len
def __contains__(self, vertex):
if not isinstance(vertex, mgp.Vertex):
raise TypeError
return self._prop in vertex.properties