Memgraph magic functions (#345)
* Extend mgp_module with include adding functions * Add return type to the function API * Change Cypher grammar * Add Python support for functions * Implement error handling * E2e tests for functions * Write cpp e2e functions * Create mg.functions() procedure * Implement case insensitivity for user-defined Magic Functions.
This commit is contained in:
parent
ea2806bd57
commit
4abaf27765
@ -1,4 +1,4 @@
|
||||
// Copyright 2021 Memgraph Ltd.
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -549,6 +549,8 @@ enum mgp_error mgp_path_equal(struct mgp_path *p1, struct mgp_path *p2, int *res
|
||||
struct mgp_result;
|
||||
/// Represents a record of resulting field values.
|
||||
struct mgp_result_record;
|
||||
/// Represents a return type for magic functions
|
||||
struct mgp_func_result;
|
||||
|
||||
/// Set the error as the result of the procedure.
|
||||
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE ff there's no memory for copying the error message.
|
||||
@ -1290,6 +1292,9 @@ struct mgp_module;
|
||||
/// Describes a procedure of a query module.
|
||||
struct mgp_proc;
|
||||
|
||||
/// Describes a Memgraph magic function.
|
||||
struct mgp_func;
|
||||
|
||||
/// Entry-point for a query module read procedure, invoked through openCypher.
|
||||
///
|
||||
/// Passed in arguments will not live longer than the callback's execution.
|
||||
@ -1502,6 +1507,84 @@ typedef void (*mgp_trans_cb)(struct mgp_messages *, struct mgp_graph *, struct m
|
||||
enum mgp_error mgp_module_add_transformation(struct mgp_module *module, const char *name, mgp_trans_cb cb);
|
||||
/// @}
|
||||
|
||||
/// @name Memgraph Magic Functions API
|
||||
///
|
||||
/// API for creating the Memgraph magic functions. It is used to create external-source stateless methods which can
|
||||
/// be called by using openCypher query language. These methods should not modify the original graph and should use only
|
||||
/// the values provided as arguments to the method.
|
||||
///
|
||||
///@{
|
||||
|
||||
/// Add a required argument to a function.
|
||||
///
|
||||
/// The order of the added arguments corresponds to the signature of the openCypher function.
|
||||
/// Note, that required arguments are followed by optional arguments.
|
||||
///
|
||||
/// The `name` must be a valid identifier, following the same rules as the
|
||||
/// function `name` in mgp_module_add_function.
|
||||
///
|
||||
/// Passed in `type` describes what kind of values can be used as the argument.
|
||||
///
|
||||
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for an argument.
|
||||
/// Return MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid argument name.
|
||||
/// Return MGP_ERROR_LOGIC_ERROR if the function already has any optional argument.
|
||||
enum mgp_error mgp_func_add_arg(struct mgp_func *func, const char *name, struct mgp_type *type);
|
||||
|
||||
/// Add an optional argument with a default value to a function.
|
||||
///
|
||||
/// The order of the added arguments corresponds to the signature of the openCypher function.
|
||||
/// Note, that required arguments are followed by optional arguments.
|
||||
///
|
||||
/// The `name` must be a valid identifier, following the same rules as the
|
||||
/// function `name` in mgp_module_add_function.
|
||||
///
|
||||
/// Passed in `type` describes what kind of values can be used as the argument.
|
||||
///
|
||||
/// `default_value` is copied and set as the default value for the argument.
|
||||
/// Don't forget to call mgp_value_destroy when you are done using
|
||||
/// `default_value`. When the function is called, if this argument is not
|
||||
/// provided, `default_value` will be used instead. `default_value` must not be
|
||||
/// a graph element (node, relationship, path) and it must satisfy the given
|
||||
/// `type`.
|
||||
///
|
||||
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for an argument.
|
||||
/// Return MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid argument name.
|
||||
/// Return MGP_ERROR_VALUE_CONVERSION if `default_value` is a graph element (vertex, edge or path).
|
||||
/// Return MGP_ERROR_LOGIC_ERROR if `default_value` does not satisfy `type`.
|
||||
enum mgp_error mgp_func_add_opt_arg(struct mgp_func *func, const char *name, struct mgp_type *type,
|
||||
struct mgp_value *default_value);
|
||||
|
||||
/// Entry-point for a custom Memgraph awesome function.
|
||||
///
|
||||
/// Passed in arguments will not live longer than the callback's execution.
|
||||
/// Therefore, you must not store them globally or use the passed in mgp_memory
|
||||
/// to allocate global resources.
|
||||
typedef void (*mgp_func_cb)(struct mgp_list *, struct mgp_func_context *, struct mgp_func_result *,
|
||||
struct mgp_memory *);
|
||||
|
||||
/// Register a Memgraph magic function
|
||||
///
|
||||
/// The `name` must be a sequence of digits, underscores, lowercase and
|
||||
/// uppercase Latin letters. The name must begin with a non-digit character.
|
||||
/// Note that Unicode characters are not allowed.
|
||||
///
|
||||
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for mgp_func.
|
||||
/// Return MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid function name.
|
||||
/// RETURN MGP_ERROR_LOGIC_ERROR if a function with the same name was already registered.
|
||||
enum mgp_error mgp_module_add_function(struct mgp_module *module, const char *name, mgp_func_cb cb,
|
||||
struct mgp_func **result);
|
||||
|
||||
/// Set an error message as an output to the Magic function
|
||||
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if there's no memory for copying the error message.
|
||||
enum mgp_error mgp_func_result_set_error_msg(struct mgp_func_result *result, const char *error_msg,
|
||||
struct mgp_memory *memory);
|
||||
|
||||
/// Set an output value for the Magic function
|
||||
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory to copy the mgp_value to mgp_func_result.
|
||||
enum mgp_error mgp_func_result_set_value(struct mgp_func_result *result, struct mgp_value *value,
|
||||
struct mgp_memory *memory);
|
||||
/// @}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
247
include/mgp.py
247
include/mgp.py
@ -40,6 +40,7 @@ class InvalidContextError(Exception):
|
||||
"""
|
||||
Signals using a graph element instance outside of the registered procedure.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -47,6 +48,7 @@ class UnknownError(_mgp.UnknownError):
|
||||
"""
|
||||
Signals unspecified failure.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -54,6 +56,7 @@ class UnableToAllocateError(_mgp.UnableToAllocateError):
|
||||
"""
|
||||
Signals failed memory allocation.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -61,6 +64,7 @@ class InsufficientBufferError(_mgp.InsufficientBufferError):
|
||||
"""
|
||||
Signals that some buffer is not big enough.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -69,6 +73,7 @@ class OutOfRangeError(_mgp.OutOfRangeError):
|
||||
Signals that an index-like parameter has a value that is outside its
|
||||
possible values.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -77,6 +82,7 @@ class LogicErrorError(_mgp.LogicErrorError):
|
||||
Signals faulty logic within the program such as violating logical
|
||||
preconditions or class invariants and may be preventable.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -84,6 +90,7 @@ class DeletedObjectError(_mgp.DeletedObjectError):
|
||||
"""
|
||||
Signals accessing an already deleted object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -91,6 +98,7 @@ class InvalidArgumentError(_mgp.InvalidArgumentError):
|
||||
"""
|
||||
Signals that some of the arguments have invalid values.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -98,6 +106,7 @@ class KeyAlreadyExistsError(_mgp.KeyAlreadyExistsError):
|
||||
"""
|
||||
Signals that a key already exists in a container-like object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -105,6 +114,7 @@ class ImmutableObjectError(_mgp.ImmutableObjectError):
|
||||
"""
|
||||
Signals modification of an immutable object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -112,6 +122,7 @@ class ValueConversionError(_mgp.ValueConversionError):
|
||||
"""
|
||||
Signals that the conversion failed between python and cypher values.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -120,12 +131,14 @@ class SerializationError(_mgp.SerializationError):
|
||||
Signals serialization error caused by concurrent modifications from
|
||||
different transactions.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Label:
|
||||
"""Label of a Vertex."""
|
||||
__slots__ = ('_name',)
|
||||
|
||||
__slots__ = ("_name",)
|
||||
|
||||
def __init__(self, name: str):
|
||||
self._name = name
|
||||
@ -145,19 +158,22 @@ class Label:
|
||||
# Named property value of a Vertex or an Edge.
|
||||
# It would be better to use typing.NamedTuple with typed fields, but that is
|
||||
# not available in Python 3.5.
|
||||
Property = namedtuple('Property', ('name', 'value'))
|
||||
Property = namedtuple("Property", ("name", "value"))
|
||||
|
||||
|
||||
class Properties:
|
||||
"""
|
||||
A collection of properties either on a Vertex or an Edge.
|
||||
"""
|
||||
__slots__ = ('_vertex_or_edge', '_len',)
|
||||
|
||||
__slots__ = (
|
||||
"_vertex_or_edge",
|
||||
"_len",
|
||||
)
|
||||
|
||||
def __init__(self, vertex_or_edge):
|
||||
if not isinstance(vertex_or_edge, (_mgp.Vertex, _mgp.Edge)):
|
||||
raise TypeError("Expected '_mgp.Vertex' or '_mgp.Edge', \
|
||||
got {}".format(type(vertex_or_edge)))
|
||||
raise TypeError("Expected '_mgp.Vertex' or '_mgp.Edge', got {}".format(type(vertex_or_edge)))
|
||||
self._len = None
|
||||
self._vertex_or_edge = vertex_or_edge
|
||||
|
||||
@ -330,7 +346,8 @@ class Properties:
|
||||
|
||||
class EdgeType:
|
||||
"""Type of an Edge."""
|
||||
__slots__ = ('_name',)
|
||||
|
||||
__slots__ = ("_name",)
|
||||
|
||||
def __init__(self, name):
|
||||
self._name = name
|
||||
@ -348,7 +365,7 @@ class EdgeType:
|
||||
|
||||
|
||||
if sys.version_info >= (3, 5, 2):
|
||||
EdgeId = typing.NewType('EdgeId', int)
|
||||
EdgeId = typing.NewType("EdgeId", int)
|
||||
else:
|
||||
EdgeId = int
|
||||
|
||||
@ -360,12 +377,12 @@ class Edge:
|
||||
a query. You should not globally store an instance of an Edge. Using an
|
||||
invalid Edge instance will raise InvalidContextError.
|
||||
"""
|
||||
__slots__ = ('_edge',)
|
||||
|
||||
__slots__ = ("_edge",)
|
||||
|
||||
def __init__(self, edge):
|
||||
if not isinstance(edge, _mgp.Edge):
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Edge', got '{}'".format(type(edge)))
|
||||
raise TypeError("Expected '_mgp.Edge', got '{}'".format(type(edge)))
|
||||
self._edge = edge
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
@ -408,7 +425,7 @@ class Edge:
|
||||
return EdgeType(self._edge.get_type_name())
|
||||
|
||||
@property
|
||||
def from_vertex(self) -> 'Vertex':
|
||||
def from_vertex(self) -> "Vertex":
|
||||
"""
|
||||
Get the source vertex.
|
||||
|
||||
@ -419,7 +436,7 @@ class Edge:
|
||||
return Vertex(self._edge.from_vertex())
|
||||
|
||||
@property
|
||||
def to_vertex(self) -> 'Vertex':
|
||||
def to_vertex(self) -> "Vertex":
|
||||
"""
|
||||
Get the destination vertex.
|
||||
|
||||
@ -453,7 +470,7 @@ class Edge:
|
||||
|
||||
|
||||
if sys.version_info >= (3, 5, 2):
|
||||
VertexId = typing.NewType('VertexId', int)
|
||||
VertexId = typing.NewType("VertexId", int)
|
||||
else:
|
||||
VertexId = int
|
||||
|
||||
@ -465,12 +482,12 @@ class Vertex:
|
||||
in a query. You should not globally store an instance of a Vertex. Using an
|
||||
invalid Vertex instance will raise InvalidContextError.
|
||||
"""
|
||||
__slots__ = ('_vertex',)
|
||||
|
||||
__slots__ = ("_vertex",)
|
||||
|
||||
def __init__(self, vertex):
|
||||
if not isinstance(vertex, _mgp.Vertex):
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Vertex', got '{}'".format(type(vertex)))
|
||||
raise TypeError("Expected '_mgp.Vertex', got '{}'".format(type(vertex)))
|
||||
self._vertex = vertex
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
@ -513,8 +530,7 @@ class Vertex:
|
||||
"""
|
||||
if not self.is_valid():
|
||||
raise InvalidContextError()
|
||||
return tuple(Label(self._vertex.label_at(i))
|
||||
for i in range(self._vertex.labels_count()))
|
||||
return tuple(Label(self._vertex.label_at(i)) for i in range(self._vertex.labels_count()))
|
||||
|
||||
def add_label(self, label: str) -> None:
|
||||
"""
|
||||
@ -615,7 +631,8 @@ class Vertex:
|
||||
|
||||
class Path:
|
||||
"""Path containing Vertex and Edge instances."""
|
||||
__slots__ = ('_path', '_vertices', '_edges')
|
||||
|
||||
__slots__ = ("_path", "_vertices", "_edges")
|
||||
|
||||
def __init__(self, starting_vertex_or_path: typing.Union[_mgp.Path, Vertex]):
|
||||
"""Initialize with a starting Vertex.
|
||||
@ -636,8 +653,7 @@ class Path:
|
||||
raise InvalidContextError()
|
||||
self._path = _mgp.Path.make_with_start(vertex)
|
||||
else:
|
||||
raise TypeError("Expected '_mgp.Vertex' or '_mgp.Path', got '{}'"
|
||||
.format(type(starting_vertex_or_path)))
|
||||
raise TypeError("Expected '_mgp.Vertex' or '_mgp.Path', got '{}'".format(type(starting_vertex_or_path)))
|
||||
|
||||
def __copy__(self):
|
||||
if not self.is_valid():
|
||||
@ -678,8 +694,7 @@ class Path:
|
||||
extension.
|
||||
"""
|
||||
if not isinstance(edge, Edge):
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Edge', got '{}'".format(type(edge)))
|
||||
raise TypeError("Expected '_mgp.Edge', got '{}'".format(type(edge)))
|
||||
if not self.is_valid() or not edge.is_valid():
|
||||
raise InvalidContextError()
|
||||
self._path.expand(edge._edge)
|
||||
@ -698,8 +713,7 @@ class Path:
|
||||
raise InvalidContextError()
|
||||
if self._vertices is None:
|
||||
num_vertices = self._path.size() + 1
|
||||
self._vertices = tuple(Vertex(self._path.vertex_at(i))
|
||||
for i in range(num_vertices))
|
||||
self._vertices = tuple(Vertex(self._path.vertex_at(i)) for i in range(num_vertices))
|
||||
return self._vertices
|
||||
|
||||
@property
|
||||
@ -713,14 +727,14 @@ class Path:
|
||||
raise InvalidContextError()
|
||||
if self._edges is None:
|
||||
num_edges = self._path.size()
|
||||
self._edges = tuple(Edge(self._path.edge_at(i))
|
||||
for i in range(num_edges))
|
||||
self._edges = tuple(Edge(self._path.edge_at(i)) for i in range(num_edges))
|
||||
return self._edges
|
||||
|
||||
|
||||
class Record:
|
||||
"""Represents a record of resulting field values."""
|
||||
__slots__ = ('fields',)
|
||||
|
||||
__slots__ = ("fields",)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize with name=value fields in kwargs."""
|
||||
@ -729,12 +743,12 @@ class Record:
|
||||
|
||||
class Vertices:
|
||||
"""Iterable over vertices in a graph."""
|
||||
__slots__ = ('_graph', '_len')
|
||||
|
||||
__slots__ = ("_graph", "_len")
|
||||
|
||||
def __init__(self, graph):
|
||||
if not isinstance(graph, _mgp.Graph):
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
self._graph = graph
|
||||
self._len = None
|
||||
|
||||
@ -791,12 +805,12 @@ class Vertices:
|
||||
|
||||
class Graph:
|
||||
"""State of the graph database in current ProcCtx."""
|
||||
__slots__ = ('_graph',)
|
||||
|
||||
__slots__ = ("_graph",)
|
||||
|
||||
def __init__(self, graph):
|
||||
if not isinstance(graph, _mgp.Graph):
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
self._graph = graph
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
@ -885,8 +899,7 @@ class Graph:
|
||||
raise InvalidContextError()
|
||||
self._graph.detach_delete_vertex(vertex._vertex)
|
||||
|
||||
def create_edge(self, from_vertex: Vertex, to_vertex: Vertex,
|
||||
edge_type: EdgeType) -> None:
|
||||
def create_edge(self, from_vertex: Vertex, to_vertex: Vertex, edge_type: EdgeType) -> None:
|
||||
"""
|
||||
Create an edge.
|
||||
|
||||
@ -899,8 +912,7 @@ class Graph:
|
||||
"""
|
||||
if not self.is_valid():
|
||||
raise InvalidContextError()
|
||||
return Edge(self._graph.create_edge(from_vertex._vertex,
|
||||
to_vertex._vertex, edge_type.name))
|
||||
return Edge(self._graph.create_edge(from_vertex._vertex, to_vertex._vertex, edge_type.name))
|
||||
|
||||
def delete_edge(self, edge: Edge) -> None:
|
||||
"""
|
||||
@ -918,6 +930,7 @@ class Graph:
|
||||
|
||||
class AbortError(Exception):
|
||||
"""Signals that the procedure was asked to abort its execution."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -927,12 +940,12 @@ class ProcCtx:
|
||||
Access to a ProcCtx is only valid during a single execution of a procedure
|
||||
in a query. You should not globally store a ProcCtx instance.
|
||||
"""
|
||||
__slots__ = ('_graph',)
|
||||
|
||||
__slots__ = ("_graph",)
|
||||
|
||||
def __init__(self, graph):
|
||||
if not isinstance(graph, _mgp.Graph):
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
self._graph = Graph(graph)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
@ -969,8 +982,7 @@ LocalDateTime = datetime.datetime
|
||||
|
||||
Duration = datetime.timedelta
|
||||
|
||||
Any = typing.Union[bool, str, Number, Map, Path,
|
||||
list, Date, LocalTime, LocalDateTime, Duration]
|
||||
Any = typing.Union[bool, str, Number, Map, Path, list, Date, LocalTime, LocalDateTime, Duration]
|
||||
|
||||
List = typing.List
|
||||
|
||||
@ -1003,7 +1015,7 @@ def _typing_to_cypher_type(type_):
|
||||
Date: _mgp.type_date(),
|
||||
LocalTime: _mgp.type_local_time(),
|
||||
LocalDateTime: _mgp.type_local_date_time(),
|
||||
Duration: _mgp.type_duration()
|
||||
Duration: _mgp.type_duration(),
|
||||
}
|
||||
try:
|
||||
return simple_types[type_]
|
||||
@ -1021,14 +1033,14 @@ def _typing_to_cypher_type(type_):
|
||||
if type(None) in type_args:
|
||||
types = tuple(t for t in type_args if t is not type(None)) # noqa E721
|
||||
if len(types) == 1:
|
||||
type_arg, = types
|
||||
(type_arg,) = types
|
||||
else:
|
||||
# We cannot do typing.Union[*types], so do the equivalent
|
||||
# with __getitem__ which does not even need arg unpacking.
|
||||
type_arg = typing.Union.__getitem__(types)
|
||||
return _mgp.type_nullable(_typing_to_cypher_type(type_arg))
|
||||
elif complex_type == list:
|
||||
type_arg, = type_args
|
||||
(type_arg,) = type_args
|
||||
return _mgp.type_list(_typing_to_cypher_type(type_arg))
|
||||
raise UnsupportedTypingError(type_)
|
||||
else:
|
||||
@ -1038,13 +1050,17 @@ def _typing_to_cypher_type(type_):
|
||||
# printed the same way. `typing.List[type]` is printed as such, while
|
||||
# `typing.Optional[type]` is printed as 'typing.Union[type, NoneType]'
|
||||
def parse_type_args(type_as_str):
|
||||
return tuple(map(str.strip,
|
||||
type_as_str[type_as_str.index('[') + 1: -1].split(',')))
|
||||
return tuple(
|
||||
map(
|
||||
str.strip,
|
||||
type_as_str[type_as_str.index("[") + 1 : -1].split(","),
|
||||
)
|
||||
)
|
||||
|
||||
def fully_qualified_name(cls):
|
||||
if cls.__module__ is None or cls.__module__ == 'builtins':
|
||||
if cls.__module__ is None or cls.__module__ == "builtins":
|
||||
return cls.__name__
|
||||
return cls.__module__ + '.' + cls.__name__
|
||||
return cls.__module__ + "." + cls.__name__
|
||||
|
||||
def get_simple_type(type_as_str):
|
||||
for simple_type, cypher_type in simple_types.items():
|
||||
@ -1060,28 +1076,26 @@ def _typing_to_cypher_type(type_):
|
||||
pass
|
||||
|
||||
def parse_typing(type_as_str):
|
||||
if type_as_str.startswith('typing.Union'):
|
||||
if type_as_str.startswith("typing.Union"):
|
||||
type_args_as_str = parse_type_args(type_as_str)
|
||||
none_type_as_str = type(None).__name__
|
||||
if none_type_as_str in type_args_as_str:
|
||||
types = tuple(
|
||||
t for t in type_args_as_str if t != none_type_as_str)
|
||||
types = tuple(t for t in type_args_as_str if t != none_type_as_str)
|
||||
if len(types) == 1:
|
||||
type_arg_as_str, = types
|
||||
(type_arg_as_str,) = types
|
||||
else:
|
||||
type_arg_as_str = 'typing.Union[' + \
|
||||
', '.join(types) + ']'
|
||||
type_arg_as_str = "typing.Union[" + ", ".join(types) + "]"
|
||||
simple_type = get_simple_type(type_arg_as_str)
|
||||
if simple_type is not None:
|
||||
return _mgp.type_nullable(simple_type)
|
||||
return _mgp.type_nullable(parse_typing(type_arg_as_str))
|
||||
elif type_as_str.startswith('typing.List'):
|
||||
elif type_as_str.startswith("typing.List"):
|
||||
type_arg_as_str = parse_type_args(type_as_str)
|
||||
|
||||
if len(type_arg_as_str) > 1:
|
||||
# Nested object could be a type consisting of a list of types (e.g. mgp.Map)
|
||||
# so we need to join the parts.
|
||||
type_arg_as_str = ', '.join(type_arg_as_str)
|
||||
type_arg_as_str = ", ".join(type_arg_as_str)
|
||||
else:
|
||||
type_arg_as_str = type_arg_as_str[0]
|
||||
|
||||
@ -1096,9 +1110,11 @@ def _typing_to_cypher_type(type_):
|
||||
|
||||
# Procedure registration
|
||||
|
||||
|
||||
class Deprecated:
|
||||
"""Annotate a resulting Record's field as deprecated."""
|
||||
__slots__ = ('field_type',)
|
||||
|
||||
__slots__ = ("field_type",)
|
||||
|
||||
def __init__(self, type_):
|
||||
self.field_type = type_
|
||||
@ -1106,8 +1122,7 @@ class Deprecated:
|
||||
|
||||
def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]):
|
||||
if not callable(func):
|
||||
raise TypeError("Expected a callable object, got an instance of '{}'"
|
||||
.format(type(func)))
|
||||
raise TypeError("Expected a callable object, got an instance of '{}'".format(type(func)))
|
||||
if inspect.iscoroutinefunction(func):
|
||||
raise TypeError("Callable must not be 'async def' function")
|
||||
if sys.version_info >= (3, 6):
|
||||
@ -1117,24 +1132,25 @@ def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]):
|
||||
raise NotImplementedError("Generator functions are not supported")
|
||||
|
||||
|
||||
def _register_proc(func: typing.Callable[..., Record],
|
||||
is_write: bool):
|
||||
def _register_proc(func: typing.Callable[..., Record], is_write: bool):
|
||||
raise_if_does_not_meet_requirements(func)
|
||||
register_func = (
|
||||
_mgp.Module.add_write_procedure if is_write
|
||||
else _mgp.Module.add_read_procedure)
|
||||
register_func = _mgp.Module.add_write_procedure if is_write else _mgp.Module.add_read_procedure
|
||||
sig = inspect.signature(func)
|
||||
params = tuple(sig.parameters.values())
|
||||
if params and params[0].annotation is ProcCtx:
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(graph, args):
|
||||
return func(ProcCtx(graph), *args)
|
||||
|
||||
params = params[1:]
|
||||
mgp_proc = register_func(_mgp._MODULE, wrapper)
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(graph, args):
|
||||
return func(*args)
|
||||
|
||||
mgp_proc = register_func(_mgp._MODULE, wrapper)
|
||||
for param in params:
|
||||
name = param.name
|
||||
@ -1149,8 +1165,7 @@ def _register_proc(func: typing.Callable[..., Record],
|
||||
if sig.return_annotation is not sig.empty:
|
||||
record = sig.return_annotation
|
||||
if not isinstance(record, Record):
|
||||
raise TypeError("Expected '{}' to return 'mgp.Record', got '{}'"
|
||||
.format(func.__name__, type(record)))
|
||||
raise TypeError("Expected '{}' to return 'mgp.Record', got '{}'".format(func.__name__, type(record)))
|
||||
for name, type_ in record.fields.items():
|
||||
if isinstance(type_, Deprecated):
|
||||
cypher_type = _typing_to_cypher_type(type_.field_type)
|
||||
@ -1257,20 +1272,22 @@ class InvalidMessageError(Exception):
|
||||
"""
|
||||
Signals using a message instance outside of the registered transformation.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
SOURCE_TYPE_KAFKA = _mgp.SOURCE_TYPE_KAFKA
|
||||
SOURCE_TYPE_PULSAR = _mgp.SOURCE_TYPE_PULSAR
|
||||
|
||||
|
||||
class Message:
|
||||
"""Represents a message from a stream."""
|
||||
__slots__ = ('_message',)
|
||||
|
||||
__slots__ = ("_message",)
|
||||
|
||||
def __init__(self, message):
|
||||
if not isinstance(message, _mgp.Message):
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Message', got '{}'".format(type(message)))
|
||||
raise TypeError("Expected '_mgp.Message', got '{}'".format(type(message)))
|
||||
self._message = message
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
@ -1353,17 +1370,18 @@ class Message:
|
||||
|
||||
class InvalidMessagesError(Exception):
|
||||
"""Signals using a messages instance outside of the registered transformation."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Messages:
|
||||
"""Represents a list of messages from a stream."""
|
||||
__slots__ = ('_messages',)
|
||||
|
||||
__slots__ = ("_messages",)
|
||||
|
||||
def __init__(self, messages):
|
||||
if not isinstance(messages, _mgp.Messages):
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Messages', got '{}'".format(type(messages)))
|
||||
raise TypeError("Expected '_mgp.Messages', got '{}'".format(type(messages)))
|
||||
self._messages = messages
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
@ -1395,12 +1413,12 @@ class TransCtx:
|
||||
Access to a TransCtx is only valid during a single execution of a transformation.
|
||||
You should not globally store a TransCtx instance.
|
||||
"""
|
||||
__slots__ = ('_graph')
|
||||
|
||||
__slots__ = "_graph"
|
||||
|
||||
def __init__(self, graph):
|
||||
if not isinstance(graph, _mgp.Graph):
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
self._graph = Graph(graph)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
@ -1420,21 +1438,76 @@ def transformation(func: typing.Callable[..., Record]):
|
||||
params = tuple(sig.parameters.values())
|
||||
if not params or not params[0].annotation is Messages:
|
||||
if not len(params) == 2 or not params[1].annotation is Messages:
|
||||
raise NotImplementedError(
|
||||
"Valid signatures for transformations are (TransCtx, Messages) or (Messages)")
|
||||
raise NotImplementedError("Valid signatures for transformations are (TransCtx, Messages) or (Messages)")
|
||||
if params[0].annotation is TransCtx:
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(graph, messages):
|
||||
return func(TransCtx(graph), messages)
|
||||
|
||||
_mgp._MODULE.add_transformation(wrapper)
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(graph, messages):
|
||||
return func(messages)
|
||||
|
||||
_mgp._MODULE.add_transformation(wrapper)
|
||||
return func
|
||||
|
||||
|
||||
class FuncCtx:
|
||||
"""Context of a function being executed.
|
||||
|
||||
Access to a FuncCtx is only valid during a single execution of a transformation.
|
||||
You should not globally store a FuncCtx instance.
|
||||
"""
|
||||
|
||||
__slots__ = "_graph"
|
||||
|
||||
def __init__(self, graph):
|
||||
if not isinstance(graph, _mgp.Graph):
|
||||
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
self._graph = Graph(graph)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
return self._graph.is_valid()
|
||||
|
||||
|
||||
def function(func: typing.Callable):
|
||||
raise_if_does_not_meet_requirements(func)
|
||||
register_func = _mgp.Module.add_function
|
||||
sig = inspect.signature(func)
|
||||
params = tuple(sig.parameters.values())
|
||||
if params and params[0].annotation is FuncCtx:
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(graph, args):
|
||||
return func(FuncCtx(graph), *args)
|
||||
|
||||
params = params[1:]
|
||||
mgp_func = register_func(_mgp._MODULE, wrapper)
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(graph, args):
|
||||
return func(*args)
|
||||
|
||||
mgp_func = register_func(_mgp._MODULE, wrapper)
|
||||
|
||||
for param in params:
|
||||
name = param.name
|
||||
type_ = param.annotation
|
||||
if type_ is param.empty:
|
||||
type_ = object
|
||||
cypher_type = _typing_to_cypher_type(type_)
|
||||
if param.default is param.empty:
|
||||
mgp_func.add_arg(name, cypher_type)
|
||||
else:
|
||||
mgp_func.add_opt_arg(name, cypher_type, param.default)
|
||||
return func
|
||||
|
||||
|
||||
def _wrap_exceptions():
|
||||
def wrap_function(func):
|
||||
@wraps(func)
|
||||
@ -1463,6 +1536,7 @@ def _wrap_exceptions():
|
||||
raise ValueConversionError(e)
|
||||
except _mgp.SerializationError as e:
|
||||
raise SerializationError(e)
|
||||
|
||||
return wrapped_func
|
||||
|
||||
def wrap_prop_func(func):
|
||||
@ -1473,11 +1547,16 @@ def _wrap_exceptions():
|
||||
if inspect.isfunction(obj):
|
||||
setattr(cls, name, wrap_function(obj))
|
||||
elif isinstance(obj, property):
|
||||
setattr(cls, name, property(
|
||||
wrap_prop_func(obj.fget),
|
||||
wrap_prop_func(obj.fset),
|
||||
wrap_prop_func(obj.fdel),
|
||||
obj.__doc__))
|
||||
setattr(
|
||||
cls,
|
||||
name,
|
||||
property(
|
||||
wrap_prop_func(obj.fget),
|
||||
wrap_prop_func(obj.fset),
|
||||
wrap_prop_func(obj.fdel),
|
||||
obj.__doc__,
|
||||
),
|
||||
)
|
||||
|
||||
def defined_in_this_module(obj: object):
|
||||
return getattr(obj, "__module__", "") == __name__
|
||||
|
@ -852,7 +852,9 @@ cpp<#
|
||||
: arguments_(arguments),
|
||||
function_name_(function_name),
|
||||
function_(NameToFunction(function_name_)) {
|
||||
DMG_ASSERT(function_, "Unexpected missing function: {}", function_name_);
|
||||
if (!function_) {
|
||||
throw SemanticException("Function '{}' doesn't exist.", function_name);
|
||||
}
|
||||
}
|
||||
cpp<#)
|
||||
(:private
|
||||
|
@ -2109,13 +2109,30 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(MemgraphCypher::Functio
|
||||
storage_->Create<Aggregation>(expressions[1], expressions[0], Aggregation::Op::COLLECT_MAP));
|
||||
}
|
||||
|
||||
auto function = NameToFunction(function_name);
|
||||
if (!function) throw SemanticException("Function '{}' doesn't exist.", function_name);
|
||||
auto is_user_defined_function = [](const std::string &function_name) {
|
||||
// Dots are present only in user-defined functions, since modules are case-sensitive, so must be user-defined
|
||||
// functions. Builtin functions should be case insensitive.
|
||||
return function_name.find('.') != std::string::npos;
|
||||
};
|
||||
|
||||
// Don't cache queries which call user-defined functions. User-defined function's return
|
||||
// types can vary depending on whether the module is reloaded, therefore the cache would
|
||||
// be invalid.
|
||||
if (is_user_defined_function(function_name)) {
|
||||
query_info_.is_cacheable = false;
|
||||
}
|
||||
|
||||
return static_cast<Expression *>(storage_->Create<Function>(function_name, expressions));
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitFunctionName(MemgraphCypher::FunctionNameContext *ctx) {
|
||||
return utils::ToUpperCase(ctx->getText());
|
||||
auto function_name = ctx->getText();
|
||||
// Dots are present only in user-defined functions, since modules are case-sensitive, so must be user-defined
|
||||
// functions. Builtin functions should be case insensitive.
|
||||
if (function_name.find('.') != std::string::npos) {
|
||||
return function_name;
|
||||
}
|
||||
return utils::ToUpperCase(function_name);
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitDoubleLiteral(MemgraphCypher::DoubleLiteralContext *ctx) {
|
||||
|
@ -279,7 +279,7 @@ idInColl : variable IN expression ;
|
||||
|
||||
functionInvocation : functionName '(' ( DISTINCT )? ( expression ( ',' expression )* )? ')' ;
|
||||
|
||||
functionName : symbolicName ;
|
||||
functionName : symbolicName ( '.' symbolicName )* ;
|
||||
|
||||
listComprehension : '[' filterExpression ( '|' expression )? ']' ;
|
||||
|
||||
|
@ -22,6 +22,9 @@
|
||||
|
||||
#include "query/db_accessor.hpp"
|
||||
#include "query/exceptions.hpp"
|
||||
#include "query/procedure/cypher_types.hpp"
|
||||
#include "query/procedure/mg_procedure_impl.hpp"
|
||||
#include "query/procedure/module.hpp"
|
||||
#include "query/typed_value.hpp"
|
||||
#include "utils/string.hpp"
|
||||
#include "utils/temporal.hpp"
|
||||
@ -1174,6 +1177,53 @@ TypedValue Duration(const TypedValue *args, int64_t nargs, const FunctionContext
|
||||
MapNumericParameters<Number>(parameter_mappings, args[0].ValueMap());
|
||||
return TypedValue(utils::Duration(duration_parameters), ctx.memory);
|
||||
}
|
||||
|
||||
std::function<TypedValue(const TypedValue *, const int64_t, const FunctionContext &)> UserFunction(
|
||||
const mgp_func &func, const std::string &fully_qualified_name) {
|
||||
return [func, fully_qualified_name](const TypedValue *args, int64_t nargs, const FunctionContext &ctx) -> TypedValue {
|
||||
/// Find function is called to aquire the lock on Module pointer while user-defined function is executed
|
||||
const auto &maybe_found =
|
||||
procedure::FindFunction(procedure::gModuleRegistry, fully_qualified_name, utils::NewDeleteResource());
|
||||
if (!maybe_found) {
|
||||
throw QueryRuntimeException(
|
||||
"Function '{}' has been unloaded. Please check query modules to confirm that function is loaded in Memgraph.",
|
||||
fully_qualified_name);
|
||||
}
|
||||
/// Explicit extraction of module pointer, to clearly state that the lock is aquired.
|
||||
// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
|
||||
const auto &module_ptr = (*maybe_found).first;
|
||||
|
||||
const auto &func_cb = func.cb;
|
||||
mgp_memory memory{ctx.memory};
|
||||
mgp_func_context functx{ctx.db_accessor, ctx.view};
|
||||
auto graph = mgp_graph::NonWritableGraph(*ctx.db_accessor, ctx.view);
|
||||
|
||||
std::vector<TypedValue> args_list;
|
||||
args_list.reserve(nargs);
|
||||
for (std::size_t i = 0; i < nargs; ++i) {
|
||||
args_list.emplace_back(args[i]);
|
||||
}
|
||||
|
||||
auto function_argument_list = mgp_list(ctx.memory);
|
||||
procedure::ConstructArguments(args_list, func, fully_qualified_name, function_argument_list, graph);
|
||||
|
||||
mgp_func_result maybe_res;
|
||||
func_cb(&function_argument_list, &functx, &maybe_res, &memory);
|
||||
if (maybe_res.error_msg) {
|
||||
throw QueryRuntimeException(*maybe_res.error_msg);
|
||||
}
|
||||
|
||||
if (!maybe_res.value) {
|
||||
throw QueryRuntimeException(
|
||||
"Function '{}' didn't set the result nor the error message. Please either set the result by using "
|
||||
"mgp_func_result_set_value or the error by using mgp_func_result_set_error_msg.",
|
||||
fully_qualified_name);
|
||||
}
|
||||
|
||||
return {*(maybe_res.value), ctx.memory};
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::function<TypedValue(const TypedValue *, int64_t, const FunctionContext &ctx)> NameToFunction(
|
||||
@ -1259,6 +1309,14 @@ std::function<TypedValue(const TypedValue *, int64_t, const FunctionContext &ctx
|
||||
if (function_name == "LOCALDATETIME") return LocalDateTime;
|
||||
if (function_name == "DURATION") return Duration;
|
||||
|
||||
const auto &maybe_found =
|
||||
procedure::FindFunction(procedure::gModuleRegistry, function_name, utils::NewDeleteResource());
|
||||
|
||||
if (maybe_found) {
|
||||
const auto *func = (*maybe_found).second;
|
||||
return UserFunction(*func, function_name);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -3705,46 +3705,12 @@ void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name,
|
||||
"containers aware of that");
|
||||
// Build and type check procedure arguments.
|
||||
mgp_list proc_args(memory);
|
||||
proc_args.elems.reserve(args.size());
|
||||
if (args.size() < proc.args.size() ||
|
||||
// Rely on `||` short circuit so we can avoid potential overflow of
|
||||
// proc.args.size() + proc.opt_args.size() by subtracting.
|
||||
(args.size() - proc.args.size() > proc.opt_args.size())) {
|
||||
if (proc.args.empty() && proc.opt_args.empty()) {
|
||||
throw QueryRuntimeException("'{}' requires no arguments.", fully_qualified_procedure_name);
|
||||
} else if (proc.opt_args.empty()) {
|
||||
throw QueryRuntimeException("'{}' requires exactly {} {}.", fully_qualified_procedure_name, proc.args.size(),
|
||||
proc.args.size() == 1U ? "argument" : "arguments");
|
||||
} else {
|
||||
throw QueryRuntimeException("'{}' requires between {} and {} arguments.", fully_qualified_procedure_name,
|
||||
proc.args.size(), proc.args.size() + proc.opt_args.size());
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
auto arg = args[i]->Accept(*evaluator);
|
||||
std::string_view name;
|
||||
const query::procedure::CypherType *type{nullptr};
|
||||
if (proc.args.size() > i) {
|
||||
name = proc.args[i].first;
|
||||
type = proc.args[i].second;
|
||||
} else {
|
||||
MG_ASSERT(proc.opt_args.size() > i - proc.args.size());
|
||||
name = std::get<0>(proc.opt_args[i - proc.args.size()]);
|
||||
type = std::get<1>(proc.opt_args[i - proc.args.size()]);
|
||||
}
|
||||
if (!type->SatisfiesType(arg)) {
|
||||
throw QueryRuntimeException("'{}' argument named '{}' at position {} must be of type {}.",
|
||||
fully_qualified_procedure_name, name, i, type->GetPresentableName());
|
||||
}
|
||||
proc_args.elems.emplace_back(std::move(arg), &graph);
|
||||
}
|
||||
// Fill missing optional arguments with their default values.
|
||||
MG_ASSERT(args.size() >= proc.args.size());
|
||||
size_t passed_in_opt_args = args.size() - proc.args.size();
|
||||
MG_ASSERT(passed_in_opt_args <= proc.opt_args.size());
|
||||
for (size_t i = passed_in_opt_args; i < proc.opt_args.size(); ++i) {
|
||||
proc_args.elems.emplace_back(std::get<2>(proc.opt_args[i]), &graph);
|
||||
std::vector<TypedValue> args_list;
|
||||
args_list.reserve(args.size());
|
||||
for (auto *expression : args) {
|
||||
args_list.emplace_back(expression->Accept(*evaluator));
|
||||
}
|
||||
procedure::ConstructArguments(args_list, proc, fully_qualified_procedure_name, proc_args, graph);
|
||||
if (memory_limit) {
|
||||
SPDLOG_INFO("Running '{}' with memory limit of {}", fully_qualified_procedure_name,
|
||||
utils::GetReadableSize(*memory_limit));
|
||||
@ -3832,7 +3798,7 @@ class CallProcedureCursor : public Cursor {
|
||||
// generator like procedures which yield a new result on each invocation.
|
||||
auto *memory = context.evaluation_context.memory;
|
||||
auto memory_limit = EvaluateMemoryLimit(&evaluator, self_->memory_limit_, self_->memory_scale_);
|
||||
mgp_graph graph{context.db_accessor, graph_view, &context};
|
||||
auto graph = mgp_graph::WritableGraph(*context.db_accessor, graph_view, context);
|
||||
CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit,
|
||||
&result_);
|
||||
|
||||
|
@ -187,7 +187,10 @@ template <typename TFunc, typename... Args>
|
||||
return MGP_ERROR_NO_ERROR;
|
||||
}
|
||||
|
||||
bool MgpGraphIsMutable(const mgp_graph &graph) noexcept { return graph.view == memgraph::storage::View::NEW; }
|
||||
// Graph mutations
|
||||
bool MgpGraphIsMutable(const mgp_graph &graph) noexcept {
|
||||
return graph.view == memgraph::storage::View::NEW && graph.ctx != nullptr;
|
||||
}
|
||||
|
||||
bool MgpVertexIsMutable(const mgp_vertex &vertex) { return MgpGraphIsMutable(*vertex.graph); }
|
||||
|
||||
@ -289,6 +292,7 @@ mgp_value_type FromTypedValueType(memgraph::query::TypedValue::Type type) {
|
||||
return MGP_VALUE_TYPE_DURATION;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
memgraph::query::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::MemoryResource *memory) {
|
||||
switch (val.type) {
|
||||
@ -345,8 +349,6 @@ memgraph::query::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
mgp_value::mgp_value(memgraph::utils::MemoryResource *m) noexcept : type(MGP_VALUE_TYPE_NULL), memory(m) {}
|
||||
|
||||
mgp_value::mgp_value(bool val, memgraph::utils::MemoryResource *m) noexcept
|
||||
@ -1451,6 +1453,14 @@ mgp_error mgp_result_record_insert(mgp_result_record *record, const char *field_
|
||||
});
|
||||
}
|
||||
|
||||
mgp_error mgp_func_result_set_error_msg(mgp_func_result *res, const char *msg, mgp_memory *memory) {
|
||||
return WrapExceptions([=] { res->error_msg.emplace(msg, memory->impl); });
|
||||
}
|
||||
|
||||
mgp_error mgp_func_result_set_value(mgp_func_result *res, mgp_value *value, mgp_memory *memory) {
|
||||
return WrapExceptions([=] { res->value = ToTypedValue(*value, memory->impl); });
|
||||
}
|
||||
|
||||
/// Graph Constructs
|
||||
|
||||
void mgp_properties_iterator_destroy(mgp_properties_iterator *it) { DeleteRawMgpObject(it); }
|
||||
@ -2382,31 +2392,53 @@ mgp_error mgp_module_add_write_procedure(mgp_module *module, const char *name, m
|
||||
return WrapExceptions([=] { return mgp_module_add_procedure(module, name, cb, {.is_write = true}); }, result);
|
||||
}
|
||||
|
||||
mgp_error mgp_proc_add_arg(mgp_proc *proc, const char *name, mgp_type *type) {
|
||||
return WrapExceptions([=] {
|
||||
if (!IsValidIdentifierName(name)) {
|
||||
throw std::invalid_argument{fmt::format("Invalid argument name for procedure '{}': {}", proc->name, name)};
|
||||
namespace {
|
||||
template <typename T>
|
||||
concept IsCallable = memgraph::utils::SameAsAnyOf<T, mgp_proc, mgp_func>;
|
||||
|
||||
template <IsCallable TCall>
|
||||
mgp_error MgpAddArg(TCall &callable, const std::string &name, mgp_type &type) {
|
||||
return WrapExceptions([&]() mutable {
|
||||
static constexpr std::string_view type_name = std::invoke([]() constexpr {
|
||||
if constexpr (std::is_same_v<TCall, mgp_proc>) {
|
||||
return "procedure";
|
||||
} else if constexpr (std::is_same_v<TCall, mgp_func>) {
|
||||
return "function";
|
||||
}
|
||||
});
|
||||
|
||||
if (!IsValidIdentifierName(name.c_str())) {
|
||||
throw std::invalid_argument{fmt::format("Invalid argument name for {} '{}': {}", type_name, callable.name, name)};
|
||||
}
|
||||
if (!proc->opt_args.empty()) {
|
||||
throw std::logic_error{fmt::format(
|
||||
"Cannot add required argument '{}' to procedure '{}' after adding any optional one", name, proc->name)};
|
||||
if (!callable.opt_args.empty()) {
|
||||
throw std::logic_error{fmt::format("Cannot add required argument '{}' to {} '{}' after adding any optional one",
|
||||
name, type_name, callable.name)};
|
||||
}
|
||||
proc->args.emplace_back(name, type->impl.get());
|
||||
callable.args.emplace_back(name, type.impl.get());
|
||||
});
|
||||
}
|
||||
|
||||
mgp_error mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, mgp_type *type, mgp_value *default_value) {
|
||||
return WrapExceptions([=] {
|
||||
if (!IsValidIdentifierName(name)) {
|
||||
throw std::invalid_argument{fmt::format("Invalid argument name for procedure '{}': {}", proc->name, name)};
|
||||
template <IsCallable TCall>
|
||||
mgp_error MgpAddOptArg(TCall &callable, const std::string name, mgp_type &type, mgp_value &default_value) {
|
||||
return WrapExceptions([&]() mutable {
|
||||
static constexpr std::string_view type_name = std::invoke([]() constexpr {
|
||||
if constexpr (std::is_same_v<TCall, mgp_proc>) {
|
||||
return "procedure";
|
||||
} else if constexpr (std::is_same_v<TCall, mgp_func>) {
|
||||
return "function";
|
||||
}
|
||||
});
|
||||
|
||||
if (!IsValidIdentifierName(name.c_str())) {
|
||||
throw std::invalid_argument{fmt::format("Invalid argument name for {} '{}': {}", type_name, callable.name, name)};
|
||||
}
|
||||
switch (MgpValueGetType(*default_value)) {
|
||||
switch (MgpValueGetType(default_value)) {
|
||||
case MGP_VALUE_TYPE_VERTEX:
|
||||
case MGP_VALUE_TYPE_EDGE:
|
||||
case MGP_VALUE_TYPE_PATH:
|
||||
// default_value must not be a graph element.
|
||||
throw ValueConversionException{
|
||||
"Default value of argument '{}' of procedure '{}' name must not be a graph element!", name, proc->name};
|
||||
throw ValueConversionException{"Default value of argument '{}' of {} '{}' name must not be a graph element!",
|
||||
name, type_name, callable.name};
|
||||
case MGP_VALUE_TYPE_NULL:
|
||||
case MGP_VALUE_TYPE_BOOL:
|
||||
case MGP_VALUE_TYPE_INT:
|
||||
@ -2421,16 +2453,32 @@ mgp_error mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, mgp_type *type,
|
||||
break;
|
||||
}
|
||||
// Default value must be of required `type`.
|
||||
if (!type->impl->SatisfiesType(*default_value)) {
|
||||
throw std::logic_error{
|
||||
fmt::format("The default value of argument '{}' for procedure '{}' doesn't satisfy type '{}'", name,
|
||||
proc->name, type->impl->GetPresentableName())};
|
||||
if (!type.impl->SatisfiesType(default_value)) {
|
||||
throw std::logic_error{fmt::format("The default value of argument '{}' for {} '{}' doesn't satisfy type '{}'",
|
||||
name, type_name, callable.name, type.impl->GetPresentableName())};
|
||||
}
|
||||
auto *memory = proc->opt_args.get_allocator().GetMemoryResource();
|
||||
proc->opt_args.emplace_back(memgraph::utils::pmr::string(name, memory), type->impl.get(),
|
||||
ToTypedValue(*default_value, memory));
|
||||
auto *memory = callable.opt_args.get_allocator().GetMemoryResource();
|
||||
callable.opt_args.emplace_back(memgraph::utils::pmr::string(name, memory), type.impl.get(),
|
||||
ToTypedValue(default_value, memory));
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
mgp_error mgp_proc_add_arg(mgp_proc *proc, const char *name, mgp_type *type) {
|
||||
return MgpAddArg(*proc, std::string(name), *type);
|
||||
}
|
||||
|
||||
mgp_error mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, mgp_type *type, mgp_value *default_value) {
|
||||
return MgpAddOptArg(*proc, std::string(name), *type, *default_value);
|
||||
}
|
||||
|
||||
mgp_error mgp_func_add_arg(mgp_func *func, const char *name, mgp_type *type) {
|
||||
return MgpAddArg(*func, std::string(name), *type);
|
||||
}
|
||||
|
||||
mgp_error mgp_func_add_opt_arg(mgp_func *func, const char *name, mgp_type *type, mgp_value *default_value) {
|
||||
return MgpAddOptArg(*func, std::string(name), *type, *default_value);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
@ -2545,6 +2593,22 @@ void PrintProcSignature(const mgp_proc &proc, std::ostream *stream) {
|
||||
(*stream) << ")";
|
||||
}
|
||||
|
||||
void PrintFuncSignature(const mgp_func &func, std::ostream &stream) {
|
||||
stream << func.name << "(";
|
||||
utils::PrintIterable(stream, func.args, ", ", [](auto &stream, const auto &arg) {
|
||||
stream << arg.first << " :: " << arg.second->GetPresentableName();
|
||||
});
|
||||
if (!func.args.empty() && !func.opt_args.empty()) {
|
||||
stream << ", ";
|
||||
}
|
||||
utils::PrintIterable(stream, func.opt_args, ", ", [](auto &stream, const auto &arg) {
|
||||
const auto &[name, type, default_val] = arg;
|
||||
stream << name << " = ";
|
||||
PrintValue(default_val, &stream) << " :: " << type->GetPresentableName();
|
||||
});
|
||||
stream << ")";
|
||||
}
|
||||
|
||||
bool IsValidIdentifierName(const char *name) {
|
||||
if (!name) return false;
|
||||
std::regex regex("[_[:alpha:]][_[:alnum:]]*");
|
||||
@ -2716,3 +2780,19 @@ mgp_error mgp_module_add_transformation(mgp_module *module, const char *name, mg
|
||||
module->transformations.emplace(name, mgp_trans(name, cb, memory));
|
||||
});
|
||||
}
|
||||
|
||||
mgp_error mgp_module_add_function(mgp_module *module, const char *name, mgp_func_cb cb, mgp_func **result) {
|
||||
return WrapExceptions(
|
||||
[=] {
|
||||
if (!IsValidIdentifierName(name)) {
|
||||
throw std::invalid_argument{fmt::format("Invalid function name: {}", name)};
|
||||
}
|
||||
if (module->functions.find(name) != module->functions.end()) {
|
||||
throw std::logic_error{fmt::format("Function with similar name already exists '{}'", name)};
|
||||
};
|
||||
auto *memory = module->functions.get_allocator().GetMemoryResource();
|
||||
|
||||
return &module->functions.emplace(name, mgp_func(name, cb, memory)).first->second;
|
||||
},
|
||||
result);
|
||||
}
|
||||
|
@ -562,14 +562,36 @@ struct mgp_result {
|
||||
std::optional<memgraph::utils::pmr::string> error_msg;
|
||||
};
|
||||
|
||||
struct mgp_func_result {
|
||||
mgp_func_result() {}
|
||||
/// Return Magic function result. If user forgets it, the error is raised
|
||||
std::optional<memgraph::query::TypedValue> value;
|
||||
/// Return Magic function result with potential error
|
||||
std::optional<memgraph::utils::pmr::string> error_msg;
|
||||
};
|
||||
|
||||
struct mgp_graph {
|
||||
memgraph::query::DbAccessor *impl;
|
||||
memgraph::storage::View view;
|
||||
// TODO: Merge `mgp_graph` and `mgp_memory` into a single `mgp_context`. The
|
||||
// `ctx` field is out of place here.
|
||||
memgraph::query::ExecutionContext *ctx;
|
||||
|
||||
static mgp_graph WritableGraph(memgraph::query::DbAccessor &acc, memgraph::storage::View view,
|
||||
memgraph::query::ExecutionContext &ctx) {
|
||||
return mgp_graph{&acc, view, &ctx};
|
||||
}
|
||||
|
||||
static mgp_graph NonWritableGraph(memgraph::query::DbAccessor &acc, memgraph::storage::View view) {
|
||||
return mgp_graph{&acc, view, nullptr};
|
||||
}
|
||||
};
|
||||
|
||||
// Prevents user to use ExecutionContext in writable callables
|
||||
struct mgp_func_context {
|
||||
memgraph::query::DbAccessor *impl;
|
||||
memgraph::storage::View view;
|
||||
};
|
||||
struct mgp_properties_iterator {
|
||||
using allocator_type = memgraph::utils::Allocator<mgp_properties_iterator>;
|
||||
|
||||
@ -779,18 +801,69 @@ struct mgp_trans {
|
||||
results;
|
||||
};
|
||||
|
||||
struct mgp_func {
|
||||
using allocator_type = memgraph::utils::Allocator<mgp_func>;
|
||||
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
mgp_func(const char *name, mgp_func_cb cb, memgraph::utils::MemoryResource *memory)
|
||||
: name(name, memory), cb(cb), args(memory), opt_args(memory) {}
|
||||
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
mgp_func(const char *name, std::function<void(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *)> cb,
|
||||
memgraph::utils::MemoryResource *memory)
|
||||
: name(name, memory), cb(cb), args(memory), opt_args(memory) {}
|
||||
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
mgp_func(const mgp_func &other, memgraph::utils::MemoryResource *memory)
|
||||
: name(other.name, memory), cb(other.cb), args(other.args, memory), opt_args(other.opt_args, memory) {}
|
||||
|
||||
mgp_func(mgp_func &&other, memgraph::utils::MemoryResource *memory)
|
||||
: name(std::move(other.name), memory),
|
||||
cb(std::move(other.cb)),
|
||||
args(std::move(other.args), memory),
|
||||
opt_args(std::move(other.opt_args), memory) {}
|
||||
|
||||
mgp_func(const mgp_func &other) = default;
|
||||
mgp_func(mgp_func &&other) = default;
|
||||
|
||||
mgp_func &operator=(const mgp_func &) = delete;
|
||||
mgp_func &operator=(mgp_func &&) = delete;
|
||||
|
||||
~mgp_func() = default;
|
||||
|
||||
/// Name of the function.
|
||||
memgraph::utils::pmr::string name;
|
||||
/// Entry-point for the function.
|
||||
std::function<void(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *)> cb;
|
||||
/// Required, positional arguments as a (name, type) pair.
|
||||
memgraph::utils::pmr::vector<std::pair<memgraph::utils::pmr::string, const memgraph::query::procedure::CypherType *>>
|
||||
args;
|
||||
/// Optional positional arguments as a (name, type, default_value) tuple.
|
||||
memgraph::utils::pmr::vector<std::tuple<memgraph::utils::pmr::string, const memgraph::query::procedure::CypherType *,
|
||||
memgraph::query::TypedValue>>
|
||||
opt_args;
|
||||
};
|
||||
|
||||
mgp_error MgpTransAddFixedResult(mgp_trans *trans) noexcept;
|
||||
|
||||
struct mgp_module {
|
||||
using allocator_type = memgraph::utils::Allocator<mgp_module>;
|
||||
|
||||
explicit mgp_module(memgraph::utils::MemoryResource *memory) : procedures(memory), transformations(memory) {}
|
||||
explicit mgp_module(memgraph::utils::MemoryResource *memory)
|
||||
: procedures(memory), transformations(memory), functions(memory) {}
|
||||
|
||||
mgp_module(const mgp_module &other, memgraph::utils::MemoryResource *memory)
|
||||
: procedures(other.procedures, memory), transformations(other.transformations, memory) {}
|
||||
: procedures(other.procedures, memory),
|
||||
transformations(other.transformations, memory),
|
||||
functions(other.functions, memory) {}
|
||||
|
||||
mgp_module(mgp_module &&other, memgraph::utils::MemoryResource *memory)
|
||||
: procedures(std::move(other.procedures), memory), transformations(std::move(other.transformations), memory) {}
|
||||
: procedures(std::move(other.procedures), memory),
|
||||
transformations(std::move(other.transformations), memory),
|
||||
functions(std::move(other.functions), memory) {}
|
||||
|
||||
mgp_module(const mgp_module &) = default;
|
||||
mgp_module(mgp_module &&) = default;
|
||||
@ -802,6 +875,7 @@ struct mgp_module {
|
||||
|
||||
memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_proc> procedures;
|
||||
memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_trans> transformations;
|
||||
memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_func> functions;
|
||||
};
|
||||
|
||||
namespace memgraph::query::procedure {
|
||||
@ -811,6 +885,11 @@ namespace memgraph::query::procedure {
|
||||
/// @throw anything std::ostream::operator<< may throw.
|
||||
void PrintProcSignature(const mgp_proc &, std::ostream *);
|
||||
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
/// @throw anything std::ostream::operator<< may throw.
|
||||
void PrintFuncSignature(const mgp_func &, std::ostream &);
|
||||
|
||||
bool IsValidIdentifierName(const char *name);
|
||||
|
||||
} // namespace memgraph::query::procedure
|
||||
@ -839,3 +918,5 @@ struct mgp_messages {
|
||||
|
||||
storage_type messages;
|
||||
};
|
||||
|
||||
memgraph::query::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::MemoryResource *memory);
|
||||
|
@ -52,6 +52,8 @@ class BuiltinModule final : public Module {
|
||||
|
||||
const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override;
|
||||
|
||||
const std::map<std::string, mgp_func, std::less<>> *Functions() const override;
|
||||
|
||||
void AddProcedure(std::string_view name, mgp_proc proc);
|
||||
|
||||
void AddTransformation(std::string_view name, mgp_trans trans);
|
||||
@ -62,6 +64,7 @@ class BuiltinModule final : public Module {
|
||||
/// Registered procedures
|
||||
std::map<std::string, mgp_proc, std::less<>> procedures_;
|
||||
std::map<std::string, mgp_trans, std::less<>> transformations_;
|
||||
std::map<std::string, mgp_func, std::less<>> functions_;
|
||||
};
|
||||
|
||||
BuiltinModule::BuiltinModule() {}
|
||||
@ -75,6 +78,7 @@ const std::map<std::string, mgp_proc, std::less<>> *BuiltinModule::Procedures()
|
||||
const std::map<std::string, mgp_trans, std::less<>> *BuiltinModule::Transformations() const {
|
||||
return &transformations_;
|
||||
}
|
||||
const std::map<std::string, mgp_func, std::less<>> *BuiltinModule::Functions() const { return &functions_; }
|
||||
|
||||
void BuiltinModule::AddProcedure(std::string_view name, mgp_proc proc) { procedures_.emplace(name, std::move(proc)); }
|
||||
|
||||
@ -300,6 +304,82 @@ void RegisterMgTransformations(const std::map<std::string, std::unique_ptr<Modul
|
||||
module->AddProcedure("transformations", std::move(procedures));
|
||||
}
|
||||
|
||||
void RegisterMgFunctions(
|
||||
// We expect modules to be sorted by name.
|
||||
const std::map<std::string, std::unique_ptr<Module>, std::less<>> *all_modules, BuiltinModule *module) {
|
||||
auto functions_cb = [all_modules](mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result *result,
|
||||
mgp_memory *memory) {
|
||||
// Iterating over all_modules assumes that the standard mechanism of magic
|
||||
// functions invocations takes the ModuleRegistry::lock_ with READ access.
|
||||
for (const auto &[module_name, module] : *all_modules) {
|
||||
// Return the results in sorted order by module and by function_name.
|
||||
static_assert(std::is_same_v<decltype(module->Functions()), const std::map<std::string, mgp_func, std::less<>> *>,
|
||||
"Expected module magic functions to be sorted by name");
|
||||
|
||||
const auto path = module->Path();
|
||||
const auto path_string = GetPathString(path);
|
||||
const auto is_editable = IsFileEditable(path);
|
||||
|
||||
for (const auto &[func_name, func] : *module->Functions()) {
|
||||
mgp_result_record *record{nullptr};
|
||||
|
||||
if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const auto path_value = GetStringValueOrSetError(path_string.c_str(), memory, result);
|
||||
if (!path_value) {
|
||||
return;
|
||||
}
|
||||
|
||||
MgpUniquePtr<mgp_value> is_editable_value{nullptr, mgp_value_destroy};
|
||||
if (!TryOrSetError([&] { return CreateMgpObject(is_editable_value, mgp_value_make_bool, is_editable, memory); },
|
||||
result)) {
|
||||
return;
|
||||
}
|
||||
|
||||
utils::pmr::string full_name(module_name, memory->impl);
|
||||
full_name.append(1, '.');
|
||||
full_name.append(func_name);
|
||||
const auto name_value = GetStringValueOrSetError(full_name.c_str(), memory, result);
|
||||
if (!name_value) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
ss << module_name << ".";
|
||||
PrintFuncSignature(func, ss);
|
||||
const auto signature = ss.str();
|
||||
const auto signature_value = GetStringValueOrSetError(signature.c_str(), memory, result);
|
||||
if (!signature_value) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!InsertResultOrSetError(result, record, "name", name_value.get())) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!InsertResultOrSetError(result, record, "signature", signature_value.get())) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!InsertResultOrSetError(result, record, "path", path_value.get())) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!InsertResultOrSetError(result, record, "is_editable", is_editable_value.get())) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
mgp_proc functions("functions", functions_cb, utils::NewDeleteResource());
|
||||
MG_ASSERT(mgp_proc_add_result(&functions, "name", Call<mgp_type *>(mgp_type_string)) == MGP_ERROR_NO_ERROR);
|
||||
MG_ASSERT(mgp_proc_add_result(&functions, "signature", Call<mgp_type *>(mgp_type_string)) == MGP_ERROR_NO_ERROR);
|
||||
MG_ASSERT(mgp_proc_add_result(&functions, "path", Call<mgp_type *>(mgp_type_string)) == MGP_ERROR_NO_ERROR);
|
||||
MG_ASSERT(mgp_proc_add_result(&functions, "is_editable", Call<mgp_type *>(mgp_type_bool)) == MGP_ERROR_NO_ERROR);
|
||||
module->AddProcedure("functions", std::move(functions));
|
||||
}
|
||||
namespace {
|
||||
bool IsAllowedExtension(const auto &extension) {
|
||||
static constexpr std::array<std::string_view, 1> allowed_extensions{".py"};
|
||||
@ -650,8 +730,8 @@ void RegisterMgDeleteModuleFile(ModuleRegistry *module_registry, utils::RWLock *
|
||||
// `mgp_module::transformations into `proc_map`. The return value of WithModuleRegistration
|
||||
// is the same as that of `fun`. Note, the return value need only be convertible to `bool`,
|
||||
// it does not have to be `bool` itself.
|
||||
template <class TProcMap, class TTransMap, class TFun>
|
||||
auto WithModuleRegistration(TProcMap *proc_map, TTransMap *trans_map, const TFun &fun) {
|
||||
template <class TProcMap, class TTransMap, class TFuncMap, class TFun>
|
||||
auto WithModuleRegistration(TProcMap *proc_map, TTransMap *trans_map, TFuncMap *func_map, const TFun &fun) {
|
||||
// We probably don't need more than 256KB for module initialization.
|
||||
static constexpr size_t stack_bytes = 256UL * 1024UL;
|
||||
unsigned char stack_memory[stack_bytes];
|
||||
@ -664,6 +744,8 @@ auto WithModuleRegistration(TProcMap *proc_map, TTransMap *trans_map, const TFun
|
||||
for (const auto &proc : module_def.procedures) proc_map->emplace(proc);
|
||||
// Copy transformations into resulting trans_map.
|
||||
for (const auto &trans : module_def.transformations) trans_map->emplace(trans);
|
||||
// Copy functions into resulting func_map.
|
||||
for (const auto &func : module_def.functions) func_map->emplace(func);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@ -687,6 +769,8 @@ class SharedLibraryModule final : public Module {
|
||||
|
||||
const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override;
|
||||
|
||||
const std::map<std::string, mgp_func, std::less<>> *Functions() const override;
|
||||
|
||||
std::optional<std::filesystem::path> Path() const override { return file_path_; }
|
||||
|
||||
private:
|
||||
@ -702,6 +786,8 @@ class SharedLibraryModule final : public Module {
|
||||
std::map<std::string, mgp_proc, std::less<>> procedures_;
|
||||
/// Registered transformations
|
||||
std::map<std::string, mgp_trans, std::less<>> transformations_;
|
||||
/// Registered functions
|
||||
std::map<std::string, mgp_func, std::less<>> functions_;
|
||||
};
|
||||
|
||||
SharedLibraryModule::SharedLibraryModule() : handle_(nullptr) {}
|
||||
@ -755,7 +841,7 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) {
|
||||
}
|
||||
return true;
|
||||
};
|
||||
if (!WithModuleRegistration(&procedures_, &transformations_, module_cb)) {
|
||||
if (!WithModuleRegistration(&procedures_, &transformations_, &functions_, module_cb)) {
|
||||
return false;
|
||||
}
|
||||
// Get optional mgp_shutdown_module
|
||||
@ -801,6 +887,13 @@ const std::map<std::string, mgp_trans, std::less<>> *SharedLibraryModule::Transf
|
||||
return &transformations_;
|
||||
}
|
||||
|
||||
const std::map<std::string, mgp_func, std::less<>> *SharedLibraryModule::Functions() const {
|
||||
MG_ASSERT(handle_,
|
||||
"Attempting to access functions of a module that has not "
|
||||
"been loaded...");
|
||||
return &functions_;
|
||||
}
|
||||
|
||||
class PythonModule final : public Module {
|
||||
public:
|
||||
PythonModule();
|
||||
@ -816,6 +909,7 @@ class PythonModule final : public Module {
|
||||
|
||||
const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override;
|
||||
const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override;
|
||||
const std::map<std::string, mgp_func, std::less<>> *Functions() const override;
|
||||
std::optional<std::filesystem::path> Path() const override { return file_path_; }
|
||||
|
||||
private:
|
||||
@ -823,6 +917,7 @@ class PythonModule final : public Module {
|
||||
py::Object py_module_;
|
||||
std::map<std::string, mgp_proc, std::less<>> procedures_;
|
||||
std::map<std::string, mgp_trans, std::less<>> transformations_;
|
||||
std::map<std::string, mgp_func, std::less<>> functions_;
|
||||
};
|
||||
|
||||
PythonModule::PythonModule() {}
|
||||
@ -853,7 +948,7 @@ bool PythonModule::Load(const std::filesystem::path &file_path) {
|
||||
};
|
||||
return result;
|
||||
};
|
||||
py_module_ = WithModuleRegistration(&procedures_, &transformations_, module_cb);
|
||||
py_module_ = WithModuleRegistration(&procedures_, &transformations_, &functions_, module_cb);
|
||||
if (py_module_) {
|
||||
spdlog::info("Loaded module {}", file_path);
|
||||
|
||||
@ -877,6 +972,7 @@ bool PythonModule::Close() {
|
||||
auto gil = py::EnsureGIL();
|
||||
procedures_.clear();
|
||||
transformations_.clear();
|
||||
functions_.clear();
|
||||
// Delete the module from the `sys.modules` directory so that the module will
|
||||
// be properly imported if imported again.
|
||||
py::Object sys(PyImport_ImportModule("sys"));
|
||||
@ -906,6 +1002,13 @@ const std::map<std::string, mgp_trans, std::less<>> *PythonModule::Transformatio
|
||||
"not been loaded...");
|
||||
return &transformations_;
|
||||
}
|
||||
|
||||
const std::map<std::string, mgp_func, std::less<>> *PythonModule::Functions() const {
|
||||
MG_ASSERT(py_module_,
|
||||
"Attempting to access functions of a module that has "
|
||||
"not been loaded...");
|
||||
return &functions_;
|
||||
}
|
||||
namespace {
|
||||
|
||||
std::unique_ptr<Module> LoadModuleFromFile(const std::filesystem::path &path) {
|
||||
@ -954,6 +1057,7 @@ ModuleRegistry::ModuleRegistry() {
|
||||
auto module = std::make_unique<BuiltinModule>();
|
||||
RegisterMgProcedures(&modules_, module.get());
|
||||
RegisterMgTransformations(&modules_, module.get());
|
||||
RegisterMgFunctions(&modules_, module.get());
|
||||
RegisterMgLoad(this, &lock_, module.get());
|
||||
RegisterMgGetModuleFiles(this, module.get());
|
||||
RegisterMgGetModuleFile(this, module.get());
|
||||
@ -1083,7 +1187,7 @@ std::optional<std::pair<std::string_view, std::string_view>> FindModuleNameAndPr
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
concept ModuleProperties = utils::SameAsAnyOf<T, mgp_proc, mgp_trans>;
|
||||
concept ModuleProperties = utils::SameAsAnyOf<T, mgp_proc, mgp_trans, mgp_func>;
|
||||
|
||||
template <ModuleProperties T>
|
||||
std::optional<std::pair<ModulePtr, const T *>> MakePairIfPropFound(const ModuleRegistry &module_registry,
|
||||
@ -1092,8 +1196,10 @@ std::optional<std::pair<ModulePtr, const T *>> MakePairIfPropFound(const ModuleR
|
||||
auto prop_fun = [](auto &module) {
|
||||
if constexpr (std::is_same_v<T, mgp_proc>) {
|
||||
return module->Procedures();
|
||||
} else {
|
||||
} else if constexpr (std::is_same_v<T, mgp_trans>) {
|
||||
return module->Transformations();
|
||||
} else if constexpr (std::is_same_v<T, mgp_func>) {
|
||||
return module->Functions();
|
||||
}
|
||||
};
|
||||
auto result = FindModuleNameAndProp(module_registry, fully_qualified_name, memory);
|
||||
@ -1121,4 +1227,10 @@ std::optional<std::pair<ModulePtr, const mgp_trans *>> FindTransformation(
|
||||
return MakePairIfPropFound<mgp_trans>(module_registry, fully_qualified_transformation_name, memory);
|
||||
}
|
||||
|
||||
std::optional<std::pair<ModulePtr, const mgp_func *>> FindFunction(const ModuleRegistry &module_registry,
|
||||
std::string_view fully_qualified_function_name,
|
||||
utils::MemoryResource *memory) {
|
||||
return MakePairIfPropFound<mgp_func>(module_registry, fully_qualified_function_name, memory);
|
||||
}
|
||||
|
||||
} // namespace memgraph::query::procedure
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "query/procedure/cypher_types.hpp"
|
||||
#include "query/procedure/mg_procedure_impl.hpp"
|
||||
#include "utils/memory.hpp"
|
||||
#include "utils/rw_lock.hpp"
|
||||
@ -45,6 +46,8 @@ class Module {
|
||||
virtual const std::map<std::string, mgp_proc, std::less<>> *Procedures() const = 0;
|
||||
/// Returns registered transformations of this module
|
||||
virtual const std::map<std::string, mgp_trans, std::less<>> *Transformations() const = 0;
|
||||
// /// Returns registered functions of this module
|
||||
virtual const std::map<std::string, mgp_func, std::less<>> *Functions() const = 0;
|
||||
|
||||
virtual std::optional<std::filesystem::path> Path() const = 0;
|
||||
};
|
||||
@ -147,4 +150,62 @@ std::optional<std::pair<procedure::ModulePtr, const mgp_proc *>> FindProcedure(
|
||||
std::optional<std::pair<procedure::ModulePtr, const mgp_trans *>> FindTransformation(
|
||||
const ModuleRegistry &module_registry, const std::string_view fully_qualified_transformation_name,
|
||||
utils::MemoryResource *memory);
|
||||
|
||||
/// Return the ModulePtr and `mgp_func *` of the found function after resolving
|
||||
/// `fully_qualified_function_name` if found. If there is no such function
|
||||
/// std::nullopt is returned. `memory` is used for temporary allocations
|
||||
/// inside this function. ModulePtr must be kept alive to make sure it won't be unloaded.
|
||||
std::optional<std::pair<procedure::ModulePtr, const mgp_func *>> FindFunction(
|
||||
const ModuleRegistry &module_registry, const std::string_view fully_qualified_function_name,
|
||||
utils::MemoryResource *memory);
|
||||
|
||||
template <typename T>
|
||||
concept IsCallable = utils::SameAsAnyOf<T, mgp_proc, mgp_func>;
|
||||
|
||||
template <IsCallable TCall>
|
||||
void ConstructArguments(const std::vector<TypedValue> &args, const TCall &callable,
|
||||
const std::string_view fully_qualified_name, mgp_list &args_list, mgp_graph &graph) {
|
||||
const auto n_args = args.size();
|
||||
const auto c_args_sz = callable.args.size();
|
||||
const auto c_opt_args_sz = callable.opt_args.size();
|
||||
|
||||
if (n_args < c_args_sz || (n_args - c_args_sz > c_opt_args_sz)) {
|
||||
if (callable.args.empty() && callable.opt_args.empty()) {
|
||||
throw QueryRuntimeException("'{}' requires no arguments.", fully_qualified_name);
|
||||
}
|
||||
|
||||
if (callable.opt_args.empty()) {
|
||||
throw QueryRuntimeException("'{}' requires exactly {} {}.", fully_qualified_name, c_args_sz,
|
||||
c_args_sz == 1U ? "argument" : "arguments");
|
||||
}
|
||||
|
||||
throw QueryRuntimeException("'{}' requires between {} and {} arguments.", fully_qualified_name, c_args_sz,
|
||||
c_args_sz + c_opt_args_sz);
|
||||
}
|
||||
args_list.elems.reserve(n_args);
|
||||
|
||||
auto is_not_optional_arg = [c_args_sz](int i) { return c_args_sz > i; };
|
||||
for (size_t i = 0; i < n_args; ++i) {
|
||||
auto arg = args[i];
|
||||
std::string_view name;
|
||||
const query::procedure::CypherType *type;
|
||||
if (is_not_optional_arg(i)) {
|
||||
name = callable.args[i].first;
|
||||
type = callable.args[i].second;
|
||||
} else {
|
||||
name = std::get<0>(callable.opt_args[i - c_args_sz]);
|
||||
type = std::get<1>(callable.opt_args[i - c_args_sz]);
|
||||
}
|
||||
if (!type->SatisfiesType(arg)) {
|
||||
throw QueryRuntimeException("'{}' argument named '{}' at position {} must be of type {}.", fully_qualified_name,
|
||||
name, i, type->GetPresentableName());
|
||||
}
|
||||
args_list.elems.emplace_back(std::move(arg), &graph);
|
||||
}
|
||||
// Fill missing optional arguments with their default values.
|
||||
const size_t passed_in_opt_args = n_args - c_args_sz;
|
||||
for (size_t i = passed_in_opt_args; i < c_opt_args_sz; ++i) {
|
||||
args_list.elems.emplace_back(std::get<2>(callable.opt_args[i]), &graph);
|
||||
}
|
||||
}
|
||||
} // namespace memgraph::query::procedure
|
||||
|
@ -447,62 +447,94 @@ PyObject *MakePyCypherType(mgp_type *type) {
|
||||
// clang-format off
|
||||
struct PyQueryProc {
|
||||
PyObject_HEAD
|
||||
mgp_proc *proc;
|
||||
mgp_proc *callable;
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
PyObject *PyQueryProcAddArg(PyQueryProc *self, PyObject *args) {
|
||||
MG_ASSERT(self->proc);
|
||||
// clang-format off
|
||||
struct PyMagicFunc{
|
||||
PyObject_HEAD
|
||||
mgp_func *callable;
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
template <typename T>
|
||||
concept IsCallable = utils::SameAsAnyOf<T, PyQueryProc, PyMagicFunc>;
|
||||
|
||||
template <IsCallable TCall>
|
||||
PyObject *PyCallableAddArg(TCall *self, PyObject *args) {
|
||||
MG_ASSERT(self->callable);
|
||||
const char *name = nullptr;
|
||||
PyCypherType *py_type = nullptr;
|
||||
if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr;
|
||||
auto *type = py_type->type;
|
||||
if (RaiseExceptionFromErrorCode(mgp_proc_add_arg(self->proc, name, type))) {
|
||||
return nullptr;
|
||||
|
||||
if constexpr (std::is_same_v<TCall, PyQueryProc>) {
|
||||
if (RaiseExceptionFromErrorCode(mgp_proc_add_arg(self->callable, name, type))) {
|
||||
return nullptr;
|
||||
}
|
||||
} else if constexpr (std::is_same_v<TCall, PyMagicFunc>) {
|
||||
if (RaiseExceptionFromErrorCode(mgp_func_add_arg(self->callable, name, type))) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) {
|
||||
MG_ASSERT(self->proc);
|
||||
template <IsCallable TCall>
|
||||
PyObject *PyCallableAddOptArg(TCall *self, PyObject *args) {
|
||||
MG_ASSERT(self->callable);
|
||||
const char *name = nullptr;
|
||||
PyCypherType *py_type = nullptr;
|
||||
PyObject *py_value = nullptr;
|
||||
if (!PyArg_ParseTuple(args, "sO!O", &name, &PyCypherTypeType, &py_type, &py_value)) return nullptr;
|
||||
auto *type = py_type->type;
|
||||
mgp_memory memory{self->proc->opt_args.get_allocator().GetMemoryResource()};
|
||||
mgp_memory memory{self->callable->opt_args.get_allocator().GetMemoryResource()};
|
||||
mgp_value *value = PyObjectToMgpValueWithPythonExceptions(py_value, &memory);
|
||||
if (value == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (RaiseExceptionFromErrorCode(mgp_proc_add_opt_arg(self->proc, name, type, value))) {
|
||||
mgp_value_destroy(value);
|
||||
return nullptr;
|
||||
if constexpr (std::is_same_v<TCall, PyQueryProc>) {
|
||||
if (RaiseExceptionFromErrorCode(mgp_proc_add_opt_arg(self->callable, name, type, value))) {
|
||||
mgp_value_destroy(value);
|
||||
return nullptr;
|
||||
}
|
||||
} else if constexpr (std::is_same_v<TCall, PyMagicFunc>) {
|
||||
if (RaiseExceptionFromErrorCode(mgp_func_add_opt_arg(self->callable, name, type, value))) {
|
||||
mgp_value_destroy(value);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
mgp_value_destroy(value);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject *PyQueryProcAddArg(PyQueryProc *self, PyObject *args) { return PyCallableAddArg(self, args); }
|
||||
|
||||
PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) { return PyCallableAddOptArg(self, args); }
|
||||
|
||||
PyObject *PyQueryProcAddResult(PyQueryProc *self, PyObject *args) {
|
||||
MG_ASSERT(self->proc);
|
||||
MG_ASSERT(self->callable);
|
||||
const char *name = nullptr;
|
||||
PyCypherType *py_type = nullptr;
|
||||
if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr;
|
||||
|
||||
auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
|
||||
if (RaiseExceptionFromErrorCode(mgp_proc_add_result(self->proc, name, type))) {
|
||||
if (RaiseExceptionFromErrorCode(mgp_proc_add_result(self->callable, name, type))) {
|
||||
return nullptr;
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject *PyQueryProcAddDeprecatedResult(PyQueryProc *self, PyObject *args) {
|
||||
MG_ASSERT(self->proc);
|
||||
MG_ASSERT(self->callable);
|
||||
const char *name = nullptr;
|
||||
PyCypherType *py_type = nullptr;
|
||||
if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr;
|
||||
auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
|
||||
if (RaiseExceptionFromErrorCode(mgp_proc_add_deprecated_result(self->proc, name, type))) {
|
||||
if (RaiseExceptionFromErrorCode(mgp_proc_add_deprecated_result(self->callable, name, type))) {
|
||||
return nullptr;
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
@ -532,6 +564,33 @@ static PyTypeObject PyQueryProcType = {
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
PyObject *PyMagicFuncAddArg(PyMagicFunc *self, PyObject *args) { return PyCallableAddArg(self, args); }
|
||||
|
||||
PyObject *PyMagicFuncAddOptArg(PyMagicFunc *self, PyObject *args) { return PyCallableAddOptArg(self, args); }
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
static PyMethodDef PyMagicFuncMethods[] = {
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"add_arg", reinterpret_cast<PyCFunction>(PyMagicFuncAddArg), METH_VARARGS,
|
||||
"Add a required argument to a function."},
|
||||
{"add_opt_arg", reinterpret_cast<PyCFunction>(PyMagicFuncAddOptArg), METH_VARARGS,
|
||||
"Add an optional argument with a default value to a function."},
|
||||
{nullptr},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
static PyTypeObject PyMagicFuncType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||
.tp_name = "_mgp.Func",
|
||||
.tp_basicsize = sizeof(PyMagicFunc),
|
||||
// NOLINTNEXTLINE(hicpp-signed-bitwise)
|
||||
.tp_flags = Py_TPFLAGS_DEFAULT,
|
||||
.tp_doc = "Wraps struct mgp_func.",
|
||||
.tp_methods = PyMagicFuncMethods,
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
struct PyQueryModule {
|
||||
PyObject_HEAD
|
||||
@ -796,7 +855,6 @@ py::Object MgpListToPyTuple(mgp_list *list, PyObject *py_graph) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result, py::Object py_record) {
|
||||
py::Object py_mgp(PyImport_ImportModule("mgp"));
|
||||
if (!py_mgp) return py::FetchError();
|
||||
@ -870,6 +928,33 @@ std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(mgp_result *result
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::function<void()> PyObjectCleanup(py::Object &py_object) {
|
||||
return [py_object]() {
|
||||
// Run `gc.collect` (reference cycle-detection) explicitly, so that we are
|
||||
// sure the procedure cleaned up everything it held references to. If the
|
||||
// user stored a reference to one of our `_mgp` instances then the
|
||||
// internally used `mgp_*` structs will stay unfreed and a memory leak
|
||||
// will be reported at the end of the query execution.
|
||||
py::Object gc(PyImport_ImportModule("gc"));
|
||||
if (!gc) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
}
|
||||
|
||||
if (!gc.CallMethod("collect")) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
}
|
||||
|
||||
// After making sure all references from our side have been cleared,
|
||||
// invalidate the `_mgp.Graph` object. If the user kept a reference to one
|
||||
// of our `_mgp` instances then this will prevent them from using those
|
||||
// objects (whose internal `mgp_*` pointers are now invalid and would cause
|
||||
// a crash).
|
||||
if (!py_object.CallMethod("invalidate")) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *graph, mgp_result *result,
|
||||
mgp_memory *memory) {
|
||||
auto gil = py::EnsureGIL();
|
||||
@ -895,31 +980,6 @@ void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *gra
|
||||
}
|
||||
};
|
||||
|
||||
auto cleanup = [](py::Object py_graph) {
|
||||
// Run `gc.collect` (reference cycle-detection) explicitly, so that we are
|
||||
// sure the procedure cleaned up everything it held references to. If the
|
||||
// user stored a reference to one of our `_mgp` instances then the
|
||||
// internally used `mgp_*` structs will stay unfreed and a memory leak
|
||||
// will be reported at the end of the query execution.
|
||||
py::Object gc(PyImport_ImportModule("gc"));
|
||||
if (!gc) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
}
|
||||
|
||||
if (!gc.CallMethod("collect")) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
}
|
||||
|
||||
// After making sure all references from our side have been cleared,
|
||||
// invalidate the `_mgp.Graph` object. If the user kept a reference to one
|
||||
// of our `_mgp` instances then this will prevent them from using those
|
||||
// objects (whose internal `mgp_*` pointers are now invalid and would cause
|
||||
// a crash).
|
||||
if (!py_graph.CallMethod("invalidate")) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
}
|
||||
};
|
||||
|
||||
// It is *VERY IMPORTANT* to note that this code takes great care not to keep
|
||||
// any extra references to any `_mgp` instances (except for `_mgp.Graph`), so
|
||||
// as not to introduce extra reference counts and prevent their deallocation.
|
||||
@ -932,14 +992,9 @@ void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *gra
|
||||
std::optional<std::string> maybe_msg;
|
||||
{
|
||||
py::Object py_graph(MakePyGraph(graph, memory));
|
||||
utils::OnScopeExit clean_up(PyObjectCleanup(py_graph));
|
||||
if (py_graph) {
|
||||
try {
|
||||
maybe_msg = error_to_msg(call(py_graph));
|
||||
cleanup(py_graph);
|
||||
} catch (...) {
|
||||
cleanup(py_graph);
|
||||
throw;
|
||||
}
|
||||
maybe_msg = error_to_msg(call(py_graph));
|
||||
} else {
|
||||
maybe_msg = error_to_msg(py::FetchError());
|
||||
}
|
||||
@ -972,32 +1027,58 @@ void CallPythonTransformation(const py::Object &py_cb, mgp_messages *msgs, mgp_g
|
||||
return AddRecordFromPython(result, py_res);
|
||||
};
|
||||
|
||||
auto cleanup = [](py::Object py_graph, py::Object py_messages) {
|
||||
// Run `gc.collect` (reference cycle-detection) explicitly, so that we are
|
||||
// sure the procedure cleaned up everything it held references to. If the
|
||||
// user stored a reference to one of our `_mgp` instances then the
|
||||
// internally used `mgp_*` structs will stay unfreed and a memory leak
|
||||
// will be reported at the end of the query execution.
|
||||
py::Object gc(PyImport_ImportModule("gc"));
|
||||
if (!gc) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
}
|
||||
// It is *VERY IMPORTANT* to note that this code takes great care not to keep
|
||||
// any extra references to any `_mgp` instances (except for `_mgp.Graph`), so
|
||||
// as not to introduce extra reference counts and prevent their deallocation.
|
||||
// In particular, the `ExceptionInfo` object has a `traceback` field that
|
||||
// contains references to the Python frames and their arguments, and therefore
|
||||
// our `_mgp` instances as well. Within this code we ensure not to keep the
|
||||
// `ExceptionInfo` object alive so that no extra reference counts are
|
||||
// introduced. We only fetch the error message and immediately destroy the
|
||||
// object.
|
||||
std::optional<std::string> maybe_msg;
|
||||
{
|
||||
py::Object py_graph(MakePyGraph(graph, memory));
|
||||
py::Object py_messages(MakePyMessages(msgs, memory));
|
||||
|
||||
if (!gc.CallMethod("collect")) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
}
|
||||
utils::OnScopeExit clean_up_graph(PyObjectCleanup(py_graph));
|
||||
utils::OnScopeExit clean_up_messages(PyObjectCleanup(py_messages));
|
||||
|
||||
// After making sure all references from our side have been cleared,
|
||||
// invalidate the `_mgp.Graph` object. If the user kept a reference to one
|
||||
// of our `_mgp` instances then this will prevent them from using those
|
||||
// objects (whose internal `mgp_*` pointers are now invalid and would cause
|
||||
// a crash).
|
||||
if (!py_graph.CallMethod("invalidate")) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
if (py_graph && py_messages) {
|
||||
maybe_msg = error_to_msg(call(py_graph, py_messages));
|
||||
} else {
|
||||
maybe_msg = error_to_msg(py::FetchError());
|
||||
}
|
||||
if (!py_messages.CallMethod("invalidate")) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
}
|
||||
|
||||
if (maybe_msg) {
|
||||
static_cast<void>(mgp_result_set_error_msg(result, maybe_msg->c_str()));
|
||||
}
|
||||
}
|
||||
|
||||
void CallPythonFunction(const py::Object &py_cb, mgp_list *args, mgp_graph *graph, mgp_func_result *result,
|
||||
mgp_memory *memory) {
|
||||
auto gil = py::EnsureGIL();
|
||||
|
||||
auto error_to_msg = [](const std::optional<py::ExceptionInfo> &exc_info) -> std::optional<std::string> {
|
||||
if (!exc_info) return std::nullopt;
|
||||
// Here we tell the traceback formatter to skip the first line of the
|
||||
// traceback because that line will always be our wrapper function in our
|
||||
// internal `mgp.py` file. With that line skipped, the user will always
|
||||
// get only the relevant traceback that happened in his Python code.
|
||||
return py::FormatException(*exc_info, /* skip_first_line = */ true);
|
||||
};
|
||||
|
||||
auto call = [&](py::Object py_graph) -> utils::BasicResult<std::optional<py::ExceptionInfo>, mgp_value *> {
|
||||
py::Object py_args(MgpListToPyTuple(args, py_graph.Ptr()));
|
||||
if (!py_args) return {py::FetchError()};
|
||||
auto py_res = py_cb.Call(py_graph, py_args);
|
||||
if (!py_res) return {py::FetchError()};
|
||||
mgp_value *ret_val = PyObjectToMgpValueWithPythonExceptions(py_res.Ptr(), memory);
|
||||
if (ret_val == nullptr) {
|
||||
return {py::FetchError()};
|
||||
}
|
||||
return ret_val;
|
||||
};
|
||||
|
||||
// It is *VERY IMPORTANT* to note that this code takes great care not to keep
|
||||
@ -1012,22 +1093,22 @@ void CallPythonTransformation(const py::Object &py_cb, mgp_messages *msgs, mgp_g
|
||||
std::optional<std::string> maybe_msg;
|
||||
{
|
||||
py::Object py_graph(MakePyGraph(graph, memory));
|
||||
py::Object py_messages(MakePyMessages(msgs, memory));
|
||||
if (py_graph && py_messages) {
|
||||
try {
|
||||
maybe_msg = error_to_msg(call(py_graph, py_messages));
|
||||
cleanup(py_graph, py_messages);
|
||||
} catch (...) {
|
||||
cleanup(py_graph, py_messages);
|
||||
throw;
|
||||
utils::OnScopeExit clean_up(PyObjectCleanup(py_graph));
|
||||
if (py_graph) {
|
||||
auto maybe_result = call(py_graph);
|
||||
if (!maybe_result.HasError()) {
|
||||
static_cast<void>(mgp_func_result_set_value(result, maybe_result.GetValue(), memory));
|
||||
return;
|
||||
}
|
||||
maybe_msg = error_to_msg(maybe_result.GetError());
|
||||
} else {
|
||||
maybe_msg = error_to_msg(py::FetchError());
|
||||
}
|
||||
}
|
||||
|
||||
if (maybe_msg) {
|
||||
static_cast<void>(mgp_result_set_error_msg(result, maybe_msg->c_str()));
|
||||
static_cast<void>(
|
||||
mgp_func_result_set_error_msg(result, maybe_msg->c_str(), memory)); // No error fetching if this fails
|
||||
}
|
||||
}
|
||||
|
||||
@ -1056,9 +1137,9 @@ PyObject *PyQueryModuleAddProcedure(PyQueryModule *self, PyObject *cb, bool is_w
|
||||
PyErr_SetString(PyExc_ValueError, "Already registered a procedure with the same name.");
|
||||
return nullptr;
|
||||
}
|
||||
auto *py_proc = PyObject_New(PyQueryProc, &PyQueryProcType);
|
||||
auto *py_proc = PyObject_New(PyQueryProc, &PyQueryProcType); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
|
||||
if (!py_proc) return nullptr;
|
||||
py_proc->proc = &proc_it->second;
|
||||
py_proc->callable = &proc_it->second;
|
||||
return reinterpret_cast<PyObject *>(py_proc);
|
||||
}
|
||||
} // namespace
|
||||
@ -1100,6 +1181,39 @@ PyObject *PyQueryModuleAddTransformation(PyQueryModule *self, PyObject *cb) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject *PyQueryModuleAddFunction(PyQueryModule *self, PyObject *cb) {
|
||||
MG_ASSERT(self->module);
|
||||
if (!PyCallable_Check(cb)) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a callable object.");
|
||||
return nullptr;
|
||||
}
|
||||
auto py_cb = py::Object::FromBorrow(cb);
|
||||
py::Object py_name(py_cb.GetAttr("__name__"));
|
||||
const auto *name = PyUnicode_AsUTF8(py_name.Ptr());
|
||||
if (!name) return nullptr;
|
||||
if (!IsValidIdentifierName(name)) {
|
||||
PyErr_SetString(PyExc_ValueError, "Function name is not a valid identifier");
|
||||
return nullptr;
|
||||
}
|
||||
auto *memory = self->module->functions.get_allocator().GetMemoryResource();
|
||||
mgp_func func(
|
||||
name,
|
||||
[py_cb](mgp_list *args, mgp_func_context *func_ctx, mgp_func_result *result, mgp_memory *memory) {
|
||||
auto graph = mgp_graph::NonWritableGraph(*(func_ctx->impl), func_ctx->view);
|
||||
return CallPythonFunction(py_cb, args, &graph, result, memory);
|
||||
},
|
||||
memory);
|
||||
const auto [func_it, did_insert] = self->module->functions.emplace(name, std::move(func));
|
||||
if (!did_insert) {
|
||||
PyErr_SetString(PyExc_ValueError, "Already registered a function with the same name.");
|
||||
return nullptr;
|
||||
}
|
||||
auto *py_func = PyObject_New(PyMagicFunc, &PyMagicFuncType); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
|
||||
if (!py_func) return nullptr;
|
||||
py_func->callable = &func_it->second;
|
||||
return reinterpret_cast<PyObject *>(py_func);
|
||||
}
|
||||
|
||||
static PyMethodDef PyQueryModuleMethods[] = {
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"add_read_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddReadProcedure), METH_O,
|
||||
@ -1108,6 +1222,8 @@ static PyMethodDef PyQueryModuleMethods[] = {
|
||||
"Register a writeable procedure with this module."},
|
||||
{"add_transformation", reinterpret_cast<PyCFunction>(PyQueryModuleAddTransformation), METH_O,
|
||||
"Register a transformation with this module."},
|
||||
{"add_function", reinterpret_cast<PyCFunction>(PyQueryModuleAddFunction), METH_O,
|
||||
"Register a function with this module."},
|
||||
{nullptr},
|
||||
};
|
||||
|
||||
@ -1980,6 +2096,7 @@ PyObject *PyInitMgpModule() {
|
||||
if (!register_type(&PyGraphType, "Graph")) return nullptr;
|
||||
if (!register_type(&PyEdgeType, "Edge")) return nullptr;
|
||||
if (!register_type(&PyQueryProcType, "Proc")) return nullptr;
|
||||
if (!register_type(&PyMagicFuncType, "Func")) return nullptr;
|
||||
if (!register_type(&PyQueryModuleType, "Module")) return nullptr;
|
||||
if (!register_type(&PyVertexType, "Vertex")) return nullptr;
|
||||
if (!register_type(&PyPathType, "Path")) return nullptr;
|
||||
|
@ -6,6 +6,14 @@ add_custom_target(memgraph__e2e__${TARGET_PREFIX}__${FILE_NAME} ALL
|
||||
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME})
|
||||
endfunction()
|
||||
|
||||
function(copy_e2e_cpp_files TARGET_PREFIX FILE_NAME)
|
||||
add_custom_target(memgraph__e2e__${TARGET_PREFIX}__${FILE_NAME} ALL
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME}
|
||||
${CMAKE_CURRENT_BINARY_DIR}/${FILE_NAME}
|
||||
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME})
|
||||
endfunction()
|
||||
|
||||
add_subdirectory(replication)
|
||||
add_subdirectory(memory)
|
||||
add_subdirectory(triggers)
|
||||
@ -13,6 +21,7 @@ add_subdirectory(isolation_levels)
|
||||
add_subdirectory(streams)
|
||||
add_subdirectory(temporal_types)
|
||||
add_subdirectory(write_procedures)
|
||||
add_subdirectory(magic_functions)
|
||||
add_subdirectory(module_file_manager)
|
||||
add_subdirectory(websocket)
|
||||
|
||||
|
17
tests/e2e/magic_functions/CMakeLists.txt
Normal file
17
tests/e2e/magic_functions/CMakeLists.txt
Normal file
@ -0,0 +1,17 @@
|
||||
# Set up C++ functions for e2e tests
|
||||
function(add_query_module target_name src)
|
||||
add_library(${target_name} SHARED ${src})
|
||||
SET_TARGET_PROPERTIES(${target_name} PROPERTIES PREFIX "")
|
||||
target_include_directories(${target_name} PRIVATE ${CMAKE_SOURCE_DIR}/include)
|
||||
endfunction()
|
||||
|
||||
# Set up Python functions for e2e tests
|
||||
function(copy_magic_functions_e2e_python_files FILE_NAME)
|
||||
copy_e2e_python_files(functions ${FILE_NAME})
|
||||
endfunction()
|
||||
|
||||
copy_magic_functions_e2e_python_files(common.py)
|
||||
copy_magic_functions_e2e_python_files(conftest.py)
|
||||
copy_magic_functions_e2e_python_files(function_example.py)
|
||||
|
||||
add_subdirectory(functions)
|
35
tests/e2e/magic_functions/common.py
Normal file
35
tests/e2e/magic_functions/common.py
Normal file
@ -0,0 +1,35 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
# License, and you may not use this file except in compliance with the Business Source License.
|
||||
#
|
||||
# As of the Change Date specified in that file, in accordance with
|
||||
# the Business Source License, use of this software will be governed
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import mgclient
|
||||
import typing
|
||||
|
||||
|
||||
def execute_and_fetch_all(
|
||||
cursor: mgclient.Cursor, query: str, params: dict = {}
|
||||
) -> typing.List[tuple]:
|
||||
cursor.execute(query, params)
|
||||
return cursor.fetchall()
|
||||
|
||||
|
||||
def connect(**kwargs) -> mgclient.Connection:
|
||||
connection = mgclient.connect(host="localhost", port=7687, **kwargs)
|
||||
connection.autocommit = True
|
||||
return connection
|
||||
|
||||
|
||||
def has_n_result_row(cursor: mgclient.Cursor, query: str, n: int):
|
||||
results = execute_and_fetch_all(cursor, query)
|
||||
return len(results) == n
|
||||
|
||||
|
||||
def has_one_result_row(cursor: mgclient.Cursor, query: str):
|
||||
return has_n_result_row(cursor, query, 1)
|
22
tests/e2e/magic_functions/conftest.py
Normal file
22
tests/e2e/magic_functions/conftest.py
Normal file
@ -0,0 +1,22 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
# License, and you may not use this file except in compliance with the Business Source License.
|
||||
#
|
||||
# As of the Change Date specified in that file, in accordance with
|
||||
# the Business Source License, use of this software will be governed
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import pytest
|
||||
|
||||
from common import execute_and_fetch_all, connect
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def connection():
|
||||
connection = connect()
|
||||
yield connection
|
||||
cursor = connection.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
|
122
tests/e2e/magic_functions/function_example.py
Normal file
122
tests/e2e/magic_functions/function_example.py
Normal file
@ -0,0 +1,122 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
# License, and you may not use this file except in compliance with the Business Source License.
|
||||
#
|
||||
# As of the Change Date specified in that file, in accordance with
|
||||
# the Business Source License, use of this software will be governed
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import typing
|
||||
import mgclient
|
||||
import sys
|
||||
import pytest
|
||||
from common import execute_and_fetch_all, has_n_result_row
|
||||
|
||||
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_return_argument(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
execute_and_fetch_all(cursor, "CREATE (n:Label {id: 1});")
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 1)
|
||||
result = execute_and_fetch_all(
|
||||
cursor,
|
||||
f"MATCH (n) RETURN {function_type}_read.return_function_argument(n) AS argument;",
|
||||
)
|
||||
vertex = result[0][0]
|
||||
assert isinstance(vertex, mgclient.Node)
|
||||
assert vertex.labels == set(["Label"])
|
||||
assert vertex.properties == {"id": 1}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_return_optional_argument(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
result = execute_and_fetch_all(
|
||||
cursor,
|
||||
f"RETURN {function_type}_read.return_optional_argument(42) AS argument;",
|
||||
)
|
||||
result = result[0][0]
|
||||
assert isinstance(result, int)
|
||||
assert result == 42
|
||||
|
||||
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_return_optional_argument_no_arg(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
result = execute_and_fetch_all(
|
||||
cursor,
|
||||
f"RETURN {function_type}_read.return_optional_argument() AS argument;",
|
||||
)
|
||||
result = result[0][0]
|
||||
assert isinstance(result, int)
|
||||
assert result == 42
|
||||
|
||||
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_add_two_numbers(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
result = execute_and_fetch_all(
|
||||
cursor,
|
||||
f"RETURN {function_type}_read.add_two_numbers(1, 5) AS total;",
|
||||
)
|
||||
result_sum = result[0][0]
|
||||
assert isinstance(result_sum, (float, int))
|
||||
assert result_sum == 6
|
||||
|
||||
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_return_null(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
result = execute_and_fetch_all(
|
||||
cursor,
|
||||
f"RETURN {function_type}_read.return_null() AS null;",
|
||||
)
|
||||
result_null = result[0][0]
|
||||
assert result_null is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_too_many_arguments(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
# Should raise too many arguments
|
||||
with pytest.raises(mgclient.DatabaseError):
|
||||
execute_and_fetch_all(
|
||||
cursor,
|
||||
f"RETURN {function_type}_read.return_null('parameter') AS null;",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_try_to_write(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
execute_and_fetch_all(cursor, "CREATE (n:Label {id: 1});")
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 1)
|
||||
# Should raise non mutable
|
||||
with pytest.raises(mgclient.DatabaseError):
|
||||
execute_and_fetch_all(
|
||||
cursor,
|
||||
f"MATCH (n) RETURN {function_type}_write.try_to_write(n, 'property', 1);",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("function_type", ["py", "c"])
|
||||
def test_case_sensitivity(connection, function_type):
|
||||
cursor = connection.cursor()
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
# Should raise function does not exist
|
||||
with pytest.raises(mgclient.DatabaseError):
|
||||
execute_and_fetch_all(
|
||||
cursor,
|
||||
f"RETURN {function_type}_read.ReTuRn_nUlL('parameter') AS null;",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__, "-rA"]))
|
5
tests/e2e/magic_functions/functions/CMakeLists.txt
Normal file
5
tests/e2e/magic_functions/functions/CMakeLists.txt
Normal file
@ -0,0 +1,5 @@
|
||||
copy_magic_functions_e2e_python_files(py_write.py)
|
||||
copy_magic_functions_e2e_python_files(py_read.py)
|
||||
|
||||
add_query_module(c_read c_read.cpp)
|
||||
add_query_module(c_write c_write.cpp)
|
178
tests/e2e/magic_functions/functions/c_read.cpp
Normal file
178
tests/e2e/magic_functions/functions/c_read.cpp
Normal file
@ -0,0 +1,178 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <functional>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "mg_procedure.h"
|
||||
|
||||
#include "utils/on_scope_exit.hpp"
|
||||
|
||||
namespace {
|
||||
static void ReturnFunctionArgument(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
|
||||
struct mgp_memory *memory) {
|
||||
mgp_value *value{nullptr};
|
||||
auto err_code = mgp_list_at(args, 0, &value);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory);
|
||||
return;
|
||||
}
|
||||
|
||||
err_code = mgp_func_result_set_value(result, value, memory);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
static void ReturnOptionalArgument(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
|
||||
struct mgp_memory *memory) {
|
||||
mgp_value *value{nullptr};
|
||||
auto err_code = mgp_list_at(args, 0, &value);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory);
|
||||
return;
|
||||
}
|
||||
|
||||
err_code = mgp_func_result_set_value(result, value, memory);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
double GetElementFromArg(struct mgp_list *args, int index) {
|
||||
mgp_value *value{nullptr};
|
||||
if (mgp_list_at(args, index, &value) != MGP_ERROR_NO_ERROR) {
|
||||
throw std::runtime_error("Error while argument fetching.");
|
||||
}
|
||||
|
||||
double result;
|
||||
int is_int;
|
||||
mgp_value_is_int(value, &is_int);
|
||||
|
||||
if (is_int) {
|
||||
int64_t result_int;
|
||||
mgp_value_get_int(value, &result_int);
|
||||
result = static_cast<double>(result_int);
|
||||
} else {
|
||||
mgp_value_get_double(value, &result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static void AddTwoNumbers(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
|
||||
struct mgp_memory *memory) {
|
||||
double first = 0;
|
||||
double second = 0;
|
||||
try {
|
||||
first = GetElementFromArg(args, 0);
|
||||
second = GetElementFromArg(args, 1);
|
||||
} catch (...) {
|
||||
mgp_func_result_set_error_msg(result, "Unable to fetch the result!", memory);
|
||||
return;
|
||||
}
|
||||
|
||||
mgp_value *value{nullptr};
|
||||
auto summation = first + second;
|
||||
mgp_value_make_double(summation, memory, &value);
|
||||
memgraph::utils::OnScopeExit delete_summation_value([&value] { mgp_value_destroy(value); });
|
||||
|
||||
auto err_code = mgp_func_result_set_value(result, value, memory);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
|
||||
}
|
||||
}
|
||||
|
||||
static void ReturnNull(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
|
||||
struct mgp_memory *memory) {
|
||||
mgp_value *value{nullptr};
|
||||
mgp_value_make_null(memory, &value);
|
||||
memgraph::utils::OnScopeExit delete_null([&value] { mgp_value_destroy(value); });
|
||||
|
||||
auto err_code = mgp_func_result_set_value(result, value, memory);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Each module needs to define mgp_init_module function.
|
||||
// Here you can register multiple functions/procedures your module supports.
|
||||
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
|
||||
{
|
||||
mgp_func *func{nullptr};
|
||||
auto err_code = mgp_module_add_function(module, "return_function_argument", ReturnFunctionArgument, &func);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
mgp_type *type_any{nullptr};
|
||||
mgp_type_any(&type_any);
|
||||
err_code = mgp_func_add_arg(func, "argument", type_any);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
mgp_func *func{nullptr};
|
||||
auto err_code = mgp_module_add_function(module, "return_optional_argument", ReturnOptionalArgument, &func);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
mgp_value *default_value{nullptr};
|
||||
mgp_value_make_int(42, memory, &default_value);
|
||||
memgraph::utils::OnScopeExit delete_summation_value([&default_value] { mgp_value_destroy(default_value); });
|
||||
|
||||
mgp_type *type_int{nullptr};
|
||||
mgp_type_int(&type_int);
|
||||
err_code = mgp_func_add_opt_arg(func, "opt_argument", type_int, default_value);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
mgp_func *func{nullptr};
|
||||
auto err_code = mgp_module_add_function(module, "add_two_numbers", AddTwoNumbers, &func);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
mgp_type *type_number{nullptr};
|
||||
mgp_type_number(&type_number);
|
||||
err_code = mgp_func_add_arg(func, "first", type_number);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
err_code = mgp_func_add_arg(func, "second", type_number);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
mgp_func *func{nullptr};
|
||||
auto err_code = mgp_module_add_function(module, "return_null", ReturnNull, &func);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// This is an optional function if you need to release any resources before the
|
||||
// module is unloaded. You will probably need this if you acquired some
|
||||
// resources in mgp_init_module.
|
||||
extern "C" int mgp_shutdown_module() { return 0; }
|
80
tests/e2e/magic_functions/functions/c_write.cpp
Normal file
80
tests/e2e/magic_functions/functions/c_write.cpp
Normal file
@ -0,0 +1,80 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include "mg_procedure.h"
|
||||
|
||||
static void TryToWrite(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
|
||||
struct mgp_memory *memory) {
|
||||
mgp_value *value{nullptr};
|
||||
mgp_vertex *vertex{nullptr};
|
||||
mgp_list_at(args, 0, &value);
|
||||
mgp_value_get_vertex(value, &vertex);
|
||||
|
||||
const char *name;
|
||||
mgp_list_at(args, 1, &value);
|
||||
mgp_value_get_string(value, &name);
|
||||
|
||||
mgp_list_at(args, 2, &value);
|
||||
|
||||
// Setting a property should set an error
|
||||
auto err_code = mgp_vertex_set_property(vertex, name, value);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
mgp_func_result_set_error_msg(result, "Cannot set property in the function!", memory);
|
||||
return;
|
||||
}
|
||||
|
||||
err_code = mgp_func_result_set_value(result, value, memory);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Each module needs to define mgp_init_module function.
|
||||
// Here you can register multiple functions/procedures your module supports.
|
||||
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
|
||||
{
|
||||
mgp_func *func{nullptr};
|
||||
auto err_code = mgp_module_add_function(module, "try_to_write", TryToWrite, &func);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
mgp_type *type_vertex{nullptr};
|
||||
mgp_type_node(&type_vertex);
|
||||
err_code = mgp_func_add_arg(func, "argument", type_vertex);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
mgp_type *type_string{nullptr};
|
||||
mgp_type_string(&type_string);
|
||||
err_code = mgp_func_add_arg(func, "name", type_string);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
mgp_type *any_type{nullptr};
|
||||
mgp_type_any(&any_type);
|
||||
mgp_type *nullable_type{nullptr};
|
||||
mgp_type_nullable(any_type, &nullable_type);
|
||||
err_code = mgp_func_add_arg(func, "value", nullable_type);
|
||||
if (err_code != MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// This is an optional function if you need to release any resources before the
|
||||
// module is unloaded. You will probably need this if you acquired some
|
||||
// resources in mgp_init_module.
|
||||
extern "C" int mgp_shutdown_module() { return 0; }
|
32
tests/e2e/magic_functions/functions/py_read.py
Normal file
32
tests/e2e/magic_functions/functions/py_read.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
# License, and you may not use this file except in compliance with the Business Source License.
|
||||
#
|
||||
# As of the Change Date specified in that file, in accordance with
|
||||
# the Business Source License, use of this software will be governed
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import mgp
|
||||
|
||||
|
||||
@mgp.function
|
||||
def return_function_argument(ctx: mgp.FuncCtx, argument: mgp.Any):
|
||||
return argument
|
||||
|
||||
|
||||
@mgp.function
|
||||
def return_optional_argument(ctx: mgp.FuncCtx, opt_argument: mgp.Number = 42):
|
||||
return opt_argument
|
||||
|
||||
|
||||
@mgp.function
|
||||
def add_two_numbers(ctx: mgp.FuncCtx, first: mgp.Number, second: mgp.Number):
|
||||
return first + second
|
||||
|
||||
|
||||
@mgp.function
|
||||
def return_null(ctx: mgp.FuncCtx):
|
||||
return None
|
17
tests/e2e/magic_functions/functions/py_write.py
Normal file
17
tests/e2e/magic_functions/functions/py_write.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
# License, and you may not use this file except in compliance with the Business Source License.
|
||||
#
|
||||
# As of the Change Date specified in that file, in accordance with
|
||||
# the Business Source License, use of this software will be governed
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import mgp
|
||||
|
||||
|
||||
@mgp.function
|
||||
def try_to_write(ctx: mgp.FuncCtx, argument: mgp.Vertex, name: str, value: mgp.Nullable[mgp.Any]):
|
||||
argument.properties.set(name, value)
|
14
tests/e2e/magic_functions/workloads.yaml
Normal file
14
tests/e2e/magic_functions/workloads.yaml
Normal file
@ -0,0 +1,14 @@
|
||||
template_cluster: &template_cluster
|
||||
cluster:
|
||||
main:
|
||||
args: ["--bolt-port", "7687", "--log-level=TRACE"]
|
||||
log_file: "magic-functions-e2e.log"
|
||||
setup_queries: []
|
||||
validation_queries: []
|
||||
|
||||
workloads:
|
||||
- name: "Magic functions runner"
|
||||
binary: "tests/e2e/pytest_runner.sh"
|
||||
proc: "tests/e2e/magic_functions/functions/"
|
||||
args: ["magic_functions/function_example.py"]
|
||||
<<: *template_cluster
|
@ -129,6 +129,12 @@ target_link_libraries(${test_prefix}query_serialization_property_value mg-query)
|
||||
add_unit_test(query_streams.cpp)
|
||||
target_link_libraries(${test_prefix}query_streams mg-query kafka-mock)
|
||||
|
||||
# Test query functions
|
||||
add_unit_test(query_function_mgp_module.cpp)
|
||||
target_link_libraries(${test_prefix}query_function_mgp_module mg-query)
|
||||
target_include_directories(${test_prefix}query_function_mgp_module PRIVATE ${CMAKE_SOURCE_DIR}/include)
|
||||
|
||||
|
||||
# Test query/procedure
|
||||
add_unit_test(query_procedure_mgp_type.cpp)
|
||||
target_link_libraries(${test_prefix}query_procedure_mgp_type mg-query)
|
||||
|
@ -211,13 +211,19 @@ class MockModule : public procedure::Module {
|
||||
const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override { return &procedures; }
|
||||
|
||||
const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override { return &transformations; }
|
||||
std::optional<std::filesystem::path> Path() const override { return std::nullopt; }
|
||||
|
||||
const std::map<std::string, mgp_func, std::less<>> *Functions() const override { return &functions; }
|
||||
|
||||
std::optional<std::filesystem::path> Path() const override { return std::nullopt; };
|
||||
|
||||
std::map<std::string, mgp_proc, std::less<>> procedures{};
|
||||
std::map<std::string, mgp_trans, std::less<>> transformations{};
|
||||
std::map<std::string, mgp_func, std::less<>> functions{};
|
||||
};
|
||||
|
||||
void DummyProcCallback(mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result * /*result*/, mgp_memory * /*memory*/){};
|
||||
void DummyFuncCallback(mgp_list * /*args*/, mgp_func_context * /*func_ctx*/, mgp_func_result * /*result*/,
|
||||
mgp_memory * /*memory*/){};
|
||||
|
||||
enum class ProcedureType { WRITE, READ };
|
||||
|
||||
@ -258,6 +264,15 @@ class CypherMainVisitorTest : public ::testing::TestWithParam<std::shared_ptr<Ba
|
||||
module.procedures.emplace(name, std::move(proc));
|
||||
}
|
||||
|
||||
static void AddFunc(MockModule &module, const char *name, const std::vector<std::string_view> &args) {
|
||||
memgraph::utils::MemoryResource *memory = memgraph::utils::NewDeleteResource();
|
||||
mgp_func func(name, DummyFuncCallback, memory);
|
||||
for (const auto arg : args) {
|
||||
func.args.emplace_back(memgraph::utils::pmr::string{arg, memory}, &any_type);
|
||||
}
|
||||
module.functions.emplace(name, std::move(func));
|
||||
}
|
||||
|
||||
std::string CreateProcByType(const ProcedureType type, const std::vector<std::string_view> &args) {
|
||||
const auto proc_name = std::string{"proc_"} + ToString(type);
|
||||
SCOPED_TRACE(proc_name);
|
||||
@ -858,6 +873,12 @@ TEST_P(CypherMainVisitorTest, UndefinedFunction) {
|
||||
SemanticException);
|
||||
}
|
||||
|
||||
TEST_P(CypherMainVisitorTest, MissingFunction) {
|
||||
AddFunc(*mock_module, "get", {});
|
||||
auto &ast_generator = *GetParam();
|
||||
ASSERT_THROW(ast_generator.ParseQuery("RETURN missing_function.get()"), SemanticException);
|
||||
}
|
||||
|
||||
TEST_P(CypherMainVisitorTest, Function) {
|
||||
auto &ast_generator = *GetParam();
|
||||
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN abs(n, 2)"));
|
||||
@ -871,6 +892,20 @@ TEST_P(CypherMainVisitorTest, Function) {
|
||||
ASSERT_TRUE(function->function_);
|
||||
}
|
||||
|
||||
TEST_P(CypherMainVisitorTest, MagicFunction) {
|
||||
AddFunc(*mock_module, "get", {});
|
||||
auto &ast_generator = *GetParam();
|
||||
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN mock_module.get()"));
|
||||
ASSERT_TRUE(query);
|
||||
ASSERT_TRUE(query->single_query_);
|
||||
auto *single_query = query->single_query_;
|
||||
auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]);
|
||||
ASSERT_EQ(return_clause->body_.named_expressions.size(), 1);
|
||||
auto *function = dynamic_cast<Function *>(return_clause->body_.named_expressions[0]->expression_);
|
||||
ASSERT_TRUE(function);
|
||||
ASSERT_TRUE(function->function_);
|
||||
}
|
||||
|
||||
TEST_P(CypherMainVisitorTest, StringLiteralDoubleQuotes) {
|
||||
auto &ast_generator = *GetParam();
|
||||
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN \"mi'rko\""));
|
||||
|
49
tests/unit/query_function_mgp_module.cpp
Normal file
49
tests/unit/query_function_mgp_module.cpp
Normal file
@ -0,0 +1,49 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <functional>
|
||||
#include <sstream>
|
||||
#include <string_view>
|
||||
|
||||
#include "query/procedure/mg_procedure_impl.hpp"
|
||||
|
||||
#include "test_utils.hpp"
|
||||
|
||||
static void DummyCallback(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *){};
|
||||
|
||||
TEST(Module, InvalidFunctionRegistration) {
|
||||
mgp_module module(memgraph::utils::NewDeleteResource());
|
||||
mgp_func *func{nullptr};
|
||||
// Other test cases are covered within the procedure API. This is only sanity check
|
||||
EXPECT_EQ(mgp_module_add_function(&module, "dashes-not-supported", DummyCallback, &func), MGP_ERROR_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
TEST(Module, RegisterSameFunctionMultipleTimes) {
|
||||
mgp_module module(memgraph::utils::NewDeleteResource());
|
||||
mgp_func *func{nullptr};
|
||||
EXPECT_EQ(module.functions.find("same_name"), module.functions.end());
|
||||
EXPECT_EQ(mgp_module_add_function(&module, "same_name", DummyCallback, &func), MGP_ERROR_NO_ERROR);
|
||||
EXPECT_NE(module.functions.find("same_name"), module.functions.end());
|
||||
EXPECT_EQ(mgp_module_add_function(&module, "same_name", DummyCallback, &func), MGP_ERROR_LOGIC_ERROR);
|
||||
EXPECT_EQ(mgp_module_add_function(&module, "same_name", DummyCallback, &func), MGP_ERROR_LOGIC_ERROR);
|
||||
EXPECT_NE(module.functions.find("same_name"), module.functions.end());
|
||||
}
|
||||
|
||||
TEST(Module, CaseSensitiveFunctionNames) {
|
||||
mgp_module module(memgraph::utils::NewDeleteResource());
|
||||
mgp_func *func{nullptr};
|
||||
EXPECT_EQ(mgp_module_add_function(&module, "not_same", DummyCallback, &func), MGP_ERROR_NO_ERROR);
|
||||
EXPECT_EQ(mgp_module_add_function(&module, "NoT_saME", DummyCallback, &func), MGP_ERROR_NO_ERROR);
|
||||
EXPECT_EQ(mgp_module_add_function(&module, "NOT_SAME", DummyCallback, &func), MGP_ERROR_NO_ERROR);
|
||||
EXPECT_EQ(module.functions.size(), 3U);
|
||||
}
|
Loading…
Reference in New Issue
Block a user