memgraph/include/_mgp_mock.py
2023-09-19 12:37:55 +02:00

351 lines
10 KiB
Python

import typing
from enum import Enum
import networkx as nx
NX_LABEL_ATTR = "labels"
NX_TYPE_ATTR = "type"
SOURCE_TYPE_KAFKA = "SOURCE_TYPE_KAFKA"
SOURCE_TYPE_PULSAR = "SOURCE_TYPE_PULSAR"
"""
This module provides helpers for the mock Python API, much like _mgp.py does for mgp.py.
"""
class InvalidArgumentError(Exception):
"""
Signals that some of the arguments have invalid values.
"""
pass
class ImmutableObjectError(Exception):
pass
class LogicErrorError(Exception):
pass
class DeletedObjectError(Exception):
pass
class EdgeConstants(Enum):
I_START = 0
I_END = 1
I_KEY = 2
class Graph:
"""Wrapper around a NetworkX MultiDiGraph instance."""
__slots__ = ("nx", "_highest_vertex_id", "_highest_edge_id", "_valid")
def __init__(self, graph: nx.MultiDiGraph) -> None:
if not isinstance(graph, nx.MultiDiGraph):
raise TypeError(f"Expected 'networkx.classes.multidigraph.MultiDiGraph', got '{type(graph)}'")
self.nx = graph
self._highest_vertex_id = None
self._highest_edge_id = None
self._valid = True
@property
def vertex_ids(self):
return self.nx.nodes
def vertex_is_isolate(self, vertex_id: int) -> bool:
return nx.is_isolate(self.nx, vertex_id)
@property
def vertices(self):
return (Vertex(node_id, self) for node_id in self.nx.nodes)
def has_node(self, node_id):
return self.nx.has_node(node_id)
@property
def edges(self):
return self.nx.edges
def is_valid(self) -> bool:
return self._valid
def get_vertex_by_id(self, vertex_id: int) -> "Vertex":
return Vertex(vertex_id, self)
def invalidate(self):
self._valid = False
def is_immutable(self) -> bool:
return nx.is_frozen(self.nx)
def make_immutable(self):
self.nx = nx.freeze(self.nx)
def _new_vertex_id(self):
if self._highest_vertex_id is None:
self._highest_vertex_id = max(vertex_id for vertex_id in self.nx.nodes)
return self._highest_vertex_id + 1
def _new_edge_id(self):
if self._highest_edge_id is None:
self._highest_edge_id = max(edge[EdgeConstants.I_KEY.value] for edge in self.nx.edges(keys=True))
return self._highest_edge_id + 1
def create_vertex(self) -> "Vertex":
vertex_id = self._new_vertex_id()
self.nx.add_node(vertex_id)
self._highest_vertex_id = vertex_id
return Vertex(vertex_id, self)
def create_edge(self, from_vertex: "Vertex", to_vertex: "Vertex", edge_type: str) -> "Edge":
if from_vertex.is_deleted() or to_vertex.is_deleted():
raise DeletedObjectError("Accessing deleted object.")
edge_id = self._new_edge_id()
from_id = from_vertex.id
to_id = to_vertex.id
self.nx.add_edge(from_id, to_id, key=edge_id, type=edge_type)
self._highest_edge_id = edge_id
return Edge((from_id, to_id, edge_id), self)
def delete_vertex(self, vertex_id: int):
self.nx.remove_node(vertex_id)
def delete_edge(self, from_vertex_id: int, to_vertex_id: int, edge_id: int):
self.nx.remove_edge(from_vertex_id, to_vertex_id, edge_id)
@property
def highest_vertex_id(self) -> int:
if self._highest_vertex_id is None:
self._highest_vertex_id = max(vertex_id for vertex_id in self.nx.nodes) + 1
return self._highest_vertex_id
@property
def highest_edge_id(self) -> int:
if self._highest_edge_id is None:
self._highest_edge_id = max(edge[EdgeConstants.I_KEY.value] for edge in self.nx.edges(keys=True))
return self._highest_edge_id + 1
class Vertex:
"""Represents a graph vertex."""
__slots__ = ("_id", "_graph")
def __init__(self, id: int, graph: Graph) -> None:
if not isinstance(id, int):
raise TypeError(f"Expected 'int', got '{type(id)}'")
if not isinstance(graph, Graph):
raise TypeError(f"Expected '_mgp_mock.Graph', got '{type(graph)}'")
if not graph.nx.has_node(id):
raise IndexError(f"Unable to find vertex with ID {id}.")
self._id = id
self._graph = graph
def is_valid(self) -> bool:
return self._graph.is_valid()
def is_deleted(self) -> bool:
return not self._graph.nx.has_node(self._id) and self._id <= self._graph.highest_vertex_id
@property
def underlying_graph(self) -> Graph:
return self._graph
def underlying_graph_is_mutable(self) -> bool:
return not nx.is_frozen(self._graph.nx)
@property
def labels(self) -> typing.List[int]:
return self._graph.nx.nodes[self._id][NX_LABEL_ATTR].split(":")
def add_label(self, label: str) -> None:
if nx.is_frozen(self._graph.nx):
raise ImmutableObjectError("Cannot modify immutable object.")
self._graph.nx.nodes[self._id][NX_LABEL_ATTR] += f":{label}"
def remove_label(self, label: str) -> None:
if nx.is_frozen(self._graph.nx):
raise ImmutableObjectError("Cannot modify immutable object.")
labels = self._graph.nx.nodes[self._id][NX_LABEL_ATTR]
if labels.startswith(f"{label}:"):
labels = "\n" + labels # pseudo-string starter
self._graph.nx.nodes[self._id][NX_LABEL_ATTR] = labels.replace(f"\n{label}:", "")
elif labels.endswith(f":{label}"):
labels += "\n" # pseudo-string terminator
self._graph.nx.nodes[self._id][NX_LABEL_ATTR] = labels.replace(f":{label}\n", "")
else:
self._graph.nx.nodes[self._id][NX_LABEL_ATTR] = labels.replace(f":{label}:", ":")
@property
def id(self) -> int:
return self._id
@property
def properties(self):
return (
(key, value)
for key, value in self._graph.nx.nodes[self._id].items()
if key not in (NX_LABEL_ATTR, NX_TYPE_ATTR)
)
def get_property(self, property_name: str):
return self._graph.nx.nodes[self._id][property_name]
def set_property(self, property_name: str, value: object):
self._graph.nx.nodes[self._id][property_name] = value
@property
def in_edges(self) -> typing.Iterable["Edge"]:
return [Edge(edge, self._graph) for edge in self._graph.nx.in_edges(self._id, keys=True)]
@property
def out_edges(self) -> typing.Iterable["Edge"]:
return [Edge(edge, self._graph) for edge in self._graph.nx.out_edges(self._id, keys=True)]
class Edge:
"""Represents a graph edge."""
__slots__ = ("_edge", "_graph")
def __init__(self, edge: typing.Tuple[int, int, int], graph: Graph) -> None:
if not isinstance(edge, typing.Tuple):
raise TypeError(f"Expected 'Tuple', got '{type(edge)}'")
if not isinstance(graph, Graph):
raise TypeError(f"Expected '_mgp_mock.Graph', got '{type(graph)}'")
if not graph.nx.has_edge(*edge):
raise IndexError(f"Unable to find edge with ID {edge[EdgeConstants.I_KEY.value]}.")
self._edge = edge
self._graph = graph
def is_valid(self) -> bool:
return self._graph.is_valid()
def is_deleted(self) -> bool:
return (
not self._graph.nx.has_edge(*self._edge)
and self._edge[EdgeConstants.I_KEY.value] <= self._graph.highest_edge_id
)
def underlying_graph_is_mutable(self) -> bool:
return not nx.is_frozen(self._graph.nx)
@property
def id(self) -> int:
return self._edge[EdgeConstants.I_KEY.value]
@property
def edge(self) -> typing.Tuple[int, int, int]:
return self._edge
@property
def start_id(self) -> int:
return self._edge[EdgeConstants.I_START.value]
@property
def end_id(self) -> int:
return self._edge[EdgeConstants.I_END.value]
def get_type_name(self):
return self._graph.nx.get_edge_data(*self._edge)[NX_TYPE_ATTR]
def from_vertex(self) -> Vertex:
return Vertex(self.start_id, self._graph)
def to_vertex(self) -> Vertex:
return Vertex(self.end_id, self._graph)
@property
def properties(self):
return (
(key, value)
for key, value in self._graph.nx.edges[self._edge].items()
if key not in (NX_LABEL_ATTR, NX_TYPE_ATTR)
)
def get_property(self, property_name: str):
return self._graph.nx.edges[self._edge][property_name]
def set_property(self, property_name: str, value: object):
self._graph.nx.edges[self._edge][property_name] = value
class Path:
"""Represents a path comprised of `Vertex` and `Edge` instances."""
__slots__ = ("_vertices", "_edges", "_graph")
__create_key = object()
def __init__(self, create_key, vertex_id: int, graph: Graph) -> None:
assert create_key == Path.__create_key, "Path objects must be created using Path.make_with_start"
self._vertices = [vertex_id]
self._edges = []
self._graph = graph
@classmethod
def make_with_start(cls, vertex: Vertex) -> "Path":
if not isinstance(vertex, Vertex):
raise TypeError(f"Expected 'Vertex', got '{type(vertex)}'")
if not isinstance(vertex.underlying_graph, Graph):
raise TypeError(f"Expected '_mgp_mock.Graph', got '{type(vertex.underlying_graph)}'")
if not vertex.underlying_graph.nx.has_node(vertex._id):
raise IndexError(f"Unable to find vertex with ID {vertex._id}.")
return Path(cls.__create_key, vertex._id, vertex.underlying_graph)
def is_valid(self) -> bool:
return self._graph.is_valid()
def underlying_graph_is_mutable(self) -> bool:
return not nx.is_frozen(self._graph.nx)
def expand(self, edge: Edge):
if edge.start_id != self._vertices[-1]:
raise LogicErrorError("Logic error.")
self._vertices.append(edge.end_id)
self._edges.append((edge.start_id, edge.end_id, edge.id))
def pop(self):
if not self._edges:
raise IndexError("Path contains no relationships.")
self._vertices.pop()
self._edges.pop()
def vertex_at(self, index: int) -> Vertex:
return Vertex(self._vertices[index], self._graph)
def edge_at(self, index: int) -> Edge:
return Edge(self._edges[index], self._graph)
def size(self) -> int:
return len(self._edges)