memgraph/query_modules/mgp_networkx.py

283 lines
9.4 KiB
Python

import sys
import mgp
import collections
# NetworkX is a hard requirement of this module. If it cannot be imported,
# print an installation hint (including the running interpreter version) to
# stderr and let the original ImportError propagate.
try:
    import networkx as nx
except ImportError as import_error:
    sys.stderr.write(
        "\nNOTE: Please install networkx to be able to use Memgraph NetworkX "
        "wrappers. Using Python:\n{}\n".format(sys.version)
    )
    raise import_error
class MemgraphAdjlistOuterDict(collections.abc.Mapping):
    """Read-only outer adjacency mapping: vertex -> inner adjacency mapping.

    The keys are the vertices of ``ctx.graph``. Indexing with a vertex builds
    a fresh ``MemgraphAdjlistInnerDict`` for it, forwarding ``succ`` and
    ``multi`` unchanged.
    """

    __slots__ = ("_ctx", "_succ", "_multi")

    def __init__(self, ctx, succ=True, multi=True):
        self._ctx = ctx
        self._succ = succ
        self._multi = multi

    def __contains__(self, key):
        # Anything that is not a vertex is rejected with a TypeError instead
        # of being reported as "absent".
        if isinstance(key, mgp.Vertex):
            return key in self._ctx.graph.vertices
        raise TypeError

    def __getitem__(self, key):
        if key in self:
            return MemgraphAdjlistInnerDict(key, succ=self._succ, multi=self._multi)
        raise KeyError

    def __len__(self):
        return len(self._ctx.graph.vertices)

    def __iter__(self):
        return iter(self._ctx.graph.vertices)
class MemgraphAdjlistInnerDict(collections.abc.Mapping):
    """Read-only mapping from the neighbors of one vertex to their edge data.

    With ``succ=True`` the neighbors are the ``to_vertex`` of the node's
    ``out_edges``; otherwise they are the ``from_vertex`` of its ``in_edges``.

    Indexing with a neighbor returns:
      * ``multi`` falsy -- the properties of the single connecting edge,
        wrapped in ``UnhashableProperties``;
      * ``multi`` truthy -- a ``MemgraphEdgeKeyDict`` over all connecting
        edges.

    Raises:
        TypeError: on membership tests / lookups with a non-``mgp.Vertex`` key.
        KeyError: when indexing with a vertex that is not a neighbor.
        RuntimeError: in non-multi mode, when more than one edge connects the
            node and the requested neighbor.
    """

    __slots__ = ("_node", "_succ", "_multi", "_neighbors")

    def __init__(self, node, succ=True, multi=True):
        self._node = node
        self._succ = succ
        self._multi = multi
        # Lazily computed by _get_neighbors(); None means "not computed yet".
        self._neighbors = None

    def __getitem__(self, key):
        # NOTE: NetworkX 2.4, classes/coreviews.py:143. UnionAtlas expects a
        # KeyError when indexing with a vertex that is not a neighbor.
        if key not in self:
            raise KeyError
        if not self._multi:
            return UnhashableProperties(self._get_edge(key).properties)
        return MemgraphEdgeKeyDict(self._node, key, self._succ)

    def __iter__(self):
        yield from self._get_neighbors()

    def __len__(self):
        return len(self._get_neighbors())

    def __contains__(self, key):
        if not isinstance(key, mgp.Vertex):
            raise TypeError
        return key in self._get_neighbors()

    def _get_neighbors(self):
        """Return the (cached) set of neighboring vertices."""
        # Compare against None rather than relying on truthiness: an empty
        # neighbor set is a valid result and must be cached too, otherwise
        # every call on an isolated vertex rescans its edges.
        if self._neighbors is None:
            if self._succ:
                self._neighbors = {e.to_vertex for e in self._node.out_edges}
            else:
                self._neighbors = {e.from_vertex for e in self._node.in_edges}
        return self._neighbors

    def _get_edge(self, neighbor):
        """Return the single edge between the node and ``neighbor``."""
        if self._succ:
            edge = [e for e in self._node.out_edges if e.to_vertex == neighbor]
        else:
            edge = [e for e in self._node.in_edges if e.from_vertex == neighbor]
        # Internal invariant: __getitem__ only gets here after confirming that
        # `neighbor` is in the neighbor set.
        assert len(edge) >= 1
        if len(edge) > 1:
            raise RuntimeError("Graph contains multiedges but " "is of non-multigraph type: {}".format(edge))
        return edge[0]
class MemgraphEdgeKeyDict(collections.abc.Mapping):
    """Read-only mapping from the edges between two vertices to their data.

    The keys are the edges that connect ``node`` and ``neighbor``: with
    ``succ=True`` the node's ``out_edges`` whose ``to_vertex`` equals
    ``neighbor``, otherwise its ``in_edges`` whose ``from_vertex`` equals
    ``neighbor``. Each value is the edge's properties wrapped in
    ``UnhashableProperties``.

    Raises:
        TypeError: on membership tests / lookups with a non-``mgp.Edge`` key.
        KeyError: when indexing with an edge that is not one of the keys.
    """

    __slots__ = ("_node", "_neighbor", "_succ", "_edges")

    def __init__(self, node, neighbor, succ=True):
        self._node = node
        self._neighbor = neighbor
        self._succ = succ
        # Lazily computed by _get_edges(); None means "not computed yet".
        self._edges = None

    def __getitem__(self, key):
        if key not in self:
            raise KeyError
        return UnhashableProperties(key.properties)

    def __iter__(self):
        yield from self._get_edges()

    def __len__(self):
        return len(self._get_edges())

    def __contains__(self, key):
        if not isinstance(key, mgp.Edge):
            raise TypeError
        return key in self._get_edges()

    def _get_edges(self):
        """Return the (cached) list of edges between node and neighbor."""
        # Compare against None rather than relying on truthiness so that an
        # empty result is cached as well instead of being recomputed on every
        # call.
        if self._edges is None:
            if self._succ:
                self._edges = [e for e in self._node.out_edges if e.to_vertex == self._neighbor]
            else:
                self._edges = [e for e in self._node.in_edges if e.from_vertex == self._neighbor]
        return self._edges
class UnhashableProperties(collections.abc.Mapping):
    """Read-only mapping facade over a properties object that cannot be hashed.

    Every operation (lookup, iteration, length, membership) is delegated
    directly to the wrapped ``properties`` object.
    """

    __slots__ = "_properties"

    # NOTE: Explicitly disable hashing. See the comment in MemgraphNodeDict.
    __hash__ = None

    def __init__(self, properties):
        self._properties = properties

    def __contains__(self, key):
        return key in self._properties

    def __getitem__(self, key):
        return self._properties[key]

    def __len__(self):
        return len(self._properties)

    def __iter__(self):
        return iter(self._properties)
class MemgraphNodeDict(collections.abc.Mapping):
    """Read-only node mapping: vertex -> (unhashable) vertex properties.

    The keys are the vertices of ``ctx.graph``.
    """

    __slots__ = ("_ctx",)

    def __init__(self, ctx):
        self._ctx = ctx

    def __contains__(self, key):
        # NOTE: NetworkX 2.4, graph.py:425. Graph.__contains__ relies on
        # self._node's (i.e. the dictionary produced by node_dict_factory)
        # __contains__ to raise a TypeError when indexing with something weird,
        # e.g. with sets. This is the behavior of dict.
        if isinstance(key, mgp.Vertex):
            return key in self._ctx.graph.vertices
        raise TypeError

    def __getitem__(self, key):
        if key in self:
            # NOTE: NetworkX 2.4, classes/digraph.py:484. NetworkX expects the
            # tuples provided to add_nodes_from to be unhashable and cause a
            # TypeError when trying to index the node dictionary. This happens
            # because the data dictionary element of the tuple is unhashable.
            # We do the same thing by returning an unhashable data dictionary.
            return UnhashableProperties(key.properties)
        raise KeyError

    def __len__(self):
        return len(self._ctx.graph.vertices)

    def __iter__(self):
        return iter(self._ctx.graph.vertices)
class MemgraphDiGraphBase:
    """Mixin that backs a NetworkX directed graph class with Memgraph data.

    It is meant to be listed before a NetworkX graph class in the bases of a
    subclass: it installs the NetworkX dictionary factories and disables the
    mutating methods, then delegates to the next ``__init__`` in the MRO.
    """

    def __init__(self, incoming_graph_data=None, ctx=None, multi=True, **kwargs):
        # NOTE: We assume that our graph will never be given any initial data
        # because we already pull our data from the Memgraph database. This
        # assert is triggered by certain NetworkX procedures because they
        # create a new instance of our class and try to populate it with their
        # own data.
        assert incoming_graph_data is None

        # NOTE: We allow for ctx to be None in order to allow certain NetworkX
        # procedures (such as `subgraph`) to work. Such procedures directly
        # modify the graph's internal attributes and don't try to populate it
        # with initial data or modify it. When ctx is falsy, the two factories
        # below *return* the bound `_error` method; they do not call it.
        def make_node_dict():
            return MemgraphNodeDict(ctx) if ctx else self._error

        def make_adjlist_outer_dict():
            return MemgraphAdjlistOuterDict(ctx, multi=multi) if ctx else self._error

        self.node_dict_factory = make_node_dict
        self.adjlist_outer_dict_factory = make_adjlist_outer_dict

        # The remaining factories would only be used to create new, mutable
        # containers, so invoking any of them raises.
        for factory_name in (
            "node_attr_dict_factory",
            "adjlist_inner_dict_factory",
            "edge_key_dict_factory",
            "edge_attr_dict_factory",
        ):
            setattr(self, factory_name, self._error)

        # NOTE: We forbid any mutating operations because our graph is
        # immutable and pulls its data from the Memgraph database.
        forbidden_methods = (
            "add_node",
            "add_nodes_from",
            "remove_node",
            "remove_nodes_from",
            "add_edge",
            "add_edges_from",
            "add_weighted_edges_from",
            "new_edge_key",
            "remove_edge",
            "remove_edges_from",
            "update",
            "clear",
        )
        for method_name in forbidden_methods:
            setattr(self, method_name, lambda *args, **kwargs: self._error())

        super().__init__(None, **kwargs)

        # NOTE: This is a necessary hack because NetworkX assumes that the
        # customizable factory functions will only ever return *empty*
        # dictionaries. In our case, the factory functions return our custom,
        # already populated, dictionaries. Because self._pred and self._succ
        # are initialized by the same factory function, they end up storing the
        # same adjacency lists which is not good. We correct that here.
        self._pred = MemgraphAdjlistOuterDict(ctx, succ=False, multi=multi)

    def _error(self):
        """Raise unconditionally; stands in for every unsupported operation."""
        raise RuntimeError("Modification operations are not supported")
class MemgraphMultiDiGraph(MemgraphDiGraphBase, nx.MultiDiGraph):
    """``nx.MultiDiGraph`` whose data comes from Memgraph (``multi=True``)."""

    def __init__(self, incoming_graph_data=None, ctx=None, **kwargs):
        # Positional order follows MemgraphDiGraphBase.__init__:
        # (incoming_graph_data, ctx, multi).
        super().__init__(incoming_graph_data, ctx, True, **kwargs)
def MemgraphMultiGraph(incoming_graph_data=None, ctx=None, **kwargs):
    """Return an undirected view (``as_view=True``) of a MemgraphMultiDiGraph.

    All arguments are forwarded unchanged to ``MemgraphMultiDiGraph``.
    """
    directed = MemgraphMultiDiGraph(incoming_graph_data=incoming_graph_data, ctx=ctx, **kwargs)
    return directed.to_undirected(as_view=True)
class MemgraphDiGraph(MemgraphDiGraphBase, nx.DiGraph):
    """``nx.DiGraph`` whose data comes from Memgraph (``multi=False``)."""

    def __init__(self, incoming_graph_data=None, ctx=None, **kwargs):
        # Positional order follows MemgraphDiGraphBase.__init__:
        # (incoming_graph_data, ctx, multi).
        super().__init__(incoming_graph_data, ctx, False, **kwargs)
def MemgraphGraph(incoming_graph_data=None, ctx=None, **kwargs):
    """Return an undirected view (``as_view=True``) of a MemgraphDiGraph.

    All arguments are forwarded unchanged to ``MemgraphDiGraph``.
    """
    directed = MemgraphDiGraph(incoming_graph_data=incoming_graph_data, ctx=ctx, **kwargs)
    return directed.to_undirected(as_view=True)
class PropertiesDictionary(collections.abc.Mapping):
    """Read-only mapping from vertices to the value of one named property.

    The keys are the vertices of ``ctx.graph`` that have the property
    ``prop``; each value is that vertex's value for the property.

    Raises:
        TypeError: on membership tests / lookups with a non-``mgp.Vertex`` key.
        KeyError: when indexing with a vertex that lacks the property.
    """

    __slots__ = ("_ctx", "_prop", "_len")

    def __init__(self, ctx, prop):
        self._ctx = ctx
        self._prop = prop
        # Lazily computed by __len__(); None means "not computed yet".
        self._len = None

    def __getitem__(self, vertex):
        if vertex not in self:
            raise KeyError
        try:
            return vertex.properties[self._prop]
        except KeyError:
            raise KeyError(
                "{} doesn't have the required property '{}'".format(vertex, self._prop)
            ) from None

    def __iter__(self):
        for v in self._ctx.graph.vertices:
            if self._prop in v.properties:
                yield v

    def __len__(self):
        # Compare against None rather than relying on truthiness: a count of 0
        # is a valid result and must be cached too, otherwise every call
        # rescans all vertices.
        if self._len is None:
            self._len = sum(1 for _ in self)
        return self._len

    def __contains__(self, vertex):
        if not isinstance(vertex, mgp.Vertex):
            raise TypeError
        return self._prop in vertex.properties