Memgraph magic functions (#345)

* Extend mgp_module with include adding functions

* Add return type to the function API

* Change Cypher grammar

* Add Python support for functions

* Implement error handling

* E2e tests for functions

* Write cpp e2e functions

* Create mg.functions() procedure

* Implement case insensitivity for user-defined Magic Functions.
This commit is contained in:
Josip Matak 2022-04-21 15:45:31 +02:00 committed by GitHub
parent ea2806bd57
commit 4abaf27765
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 1523 additions and 246 deletions

View File

@ -1,4 +1,4 @@
// Copyright 2021 Memgraph Ltd. // Copyright 2022 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -549,6 +549,8 @@ enum mgp_error mgp_path_equal(struct mgp_path *p1, struct mgp_path *p2, int *res
struct mgp_result; struct mgp_result;
/// Represents a record of resulting field values. /// Represents a record of resulting field values.
struct mgp_result_record; struct mgp_result_record;
/// Represents a return type for magic functions
struct mgp_func_result;
/// Set the error as the result of the procedure. /// Set the error as the result of the procedure.
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE ff there's no memory for copying the error message. /// Return MGP_ERROR_UNABLE_TO_ALLOCATE ff there's no memory for copying the error message.
@ -1290,6 +1292,9 @@ struct mgp_module;
/// Describes a procedure of a query module. /// Describes a procedure of a query module.
struct mgp_proc; struct mgp_proc;
/// Describes a Memgraph magic function.
struct mgp_func;
/// Entry-point for a query module read procedure, invoked through openCypher. /// Entry-point for a query module read procedure, invoked through openCypher.
/// ///
/// Passed in arguments will not live longer than the callback's execution. /// Passed in arguments will not live longer than the callback's execution.
@ -1502,6 +1507,84 @@ typedef void (*mgp_trans_cb)(struct mgp_messages *, struct mgp_graph *, struct m
enum mgp_error mgp_module_add_transformation(struct mgp_module *module, const char *name, mgp_trans_cb cb); enum mgp_error mgp_module_add_transformation(struct mgp_module *module, const char *name, mgp_trans_cb cb);
/// @} /// @}
/// @name Memgraph Magic Functions API
///
/// API for creating the Memgraph magic functions. It is used to create external-source stateless methods which can
/// be called by using openCypher query language. These methods should not modify the original graph and should use only
/// the values provided as arguments to the method.
///
///@{
/// Add a required argument to a function.
///
/// The order of the added arguments corresponds to the signature of the openCypher function.
/// Note, that required arguments are followed by optional arguments.
///
/// The `name` must be a valid identifier, following the same rules as the
/// function `name` in mgp_module_add_function.
///
/// Passed in `type` describes what kind of values can be used as the argument.
///
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for an argument.
/// Return MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid argument name.
/// Return MGP_ERROR_LOGIC_ERROR if the function already has any optional argument.
enum mgp_error mgp_func_add_arg(struct mgp_func *func, const char *name, struct mgp_type *type);
/// Add an optional argument with a default value to a function.
///
/// The order of the added arguments corresponds to the signature of the openCypher function.
/// Note, that required arguments are followed by optional arguments.
///
/// The `name` must be a valid identifier, following the same rules as the
/// function `name` in mgp_module_add_function.
///
/// Passed in `type` describes what kind of values can be used as the argument.
///
/// `default_value` is copied and set as the default value for the argument.
/// Don't forget to call mgp_value_destroy when you are done using
/// `default_value`. When the function is called, if this argument is not
/// provided, `default_value` will be used instead. `default_value` must not be
/// a graph element (node, relationship, path) and it must satisfy the given
/// `type`.
///
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for an argument.
/// Return MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid argument name.
/// Return MGP_ERROR_VALUE_CONVERSION if `default_value` is a graph element (vertex, edge or path).
/// Return MGP_ERROR_LOGIC_ERROR if `default_value` does not satisfy `type`.
enum mgp_error mgp_func_add_opt_arg(struct mgp_func *func, const char *name, struct mgp_type *type,
struct mgp_value *default_value);
/// Entry-point for a custom Memgraph awesome function.
///
/// Passed in arguments will not live longer than the callback's execution.
/// Therefore, you must not store them globally or use the passed in mgp_memory
/// to allocate global resources.
typedef void (*mgp_func_cb)(struct mgp_list *, struct mgp_func_context *, struct mgp_func_result *,
struct mgp_memory *);
/// Register a Memgraph magic function
///
/// The `name` must be a sequence of digits, underscores, lowercase and
/// uppercase Latin letters. The name must begin with a non-digit character.
/// Note that Unicode characters are not allowed.
///
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for mgp_func.
/// Return MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid function name.
/// RETURN MGP_ERROR_LOGIC_ERROR if a function with the same name was already registered.
enum mgp_error mgp_module_add_function(struct mgp_module *module, const char *name, mgp_func_cb cb,
struct mgp_func **result);
/// Set an error message as an output to the Magic function
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if there's no memory for copying the error message.
enum mgp_error mgp_func_result_set_error_msg(struct mgp_func_result *result, const char *error_msg,
struct mgp_memory *memory);
/// Set an output value for the Magic function
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory to copy the mgp_value to mgp_func_result.
enum mgp_error mgp_func_result_set_value(struct mgp_func_result *result, struct mgp_value *value,
struct mgp_memory *memory);
/// @}
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif

View File

@ -40,6 +40,7 @@ class InvalidContextError(Exception):
""" """
Signals using a graph element instance outside of the registered procedure. Signals using a graph element instance outside of the registered procedure.
""" """
pass pass
@ -47,6 +48,7 @@ class UnknownError(_mgp.UnknownError):
""" """
Signals unspecified failure. Signals unspecified failure.
""" """
pass pass
@ -54,6 +56,7 @@ class UnableToAllocateError(_mgp.UnableToAllocateError):
""" """
Signals failed memory allocation. Signals failed memory allocation.
""" """
pass pass
@ -61,6 +64,7 @@ class InsufficientBufferError(_mgp.InsufficientBufferError):
""" """
Signals that some buffer is not big enough. Signals that some buffer is not big enough.
""" """
pass pass
@ -69,6 +73,7 @@ class OutOfRangeError(_mgp.OutOfRangeError):
Signals that an index-like parameter has a value that is outside its Signals that an index-like parameter has a value that is outside its
possible values. possible values.
""" """
pass pass
@ -77,6 +82,7 @@ class LogicErrorError(_mgp.LogicErrorError):
Signals faulty logic within the program such as violating logical Signals faulty logic within the program such as violating logical
preconditions or class invariants and may be preventable. preconditions or class invariants and may be preventable.
""" """
pass pass
@ -84,6 +90,7 @@ class DeletedObjectError(_mgp.DeletedObjectError):
""" """
Signals accessing an already deleted object. Signals accessing an already deleted object.
""" """
pass pass
@ -91,6 +98,7 @@ class InvalidArgumentError(_mgp.InvalidArgumentError):
""" """
Signals that some of the arguments have invalid values. Signals that some of the arguments have invalid values.
""" """
pass pass
@ -98,6 +106,7 @@ class KeyAlreadyExistsError(_mgp.KeyAlreadyExistsError):
""" """
Signals that a key already exists in a container-like object. Signals that a key already exists in a container-like object.
""" """
pass pass
@ -105,6 +114,7 @@ class ImmutableObjectError(_mgp.ImmutableObjectError):
""" """
Signals modification of an immutable object. Signals modification of an immutable object.
""" """
pass pass
@ -112,6 +122,7 @@ class ValueConversionError(_mgp.ValueConversionError):
""" """
Signals that the conversion failed between python and cypher values. Signals that the conversion failed between python and cypher values.
""" """
pass pass
@ -120,12 +131,14 @@ class SerializationError(_mgp.SerializationError):
Signals serialization error caused by concurrent modifications from Signals serialization error caused by concurrent modifications from
different transactions. different transactions.
""" """
pass pass
class Label: class Label:
"""Label of a Vertex.""" """Label of a Vertex."""
__slots__ = ('_name',)
__slots__ = ("_name",)
def __init__(self, name: str): def __init__(self, name: str):
self._name = name self._name = name
@ -145,19 +158,22 @@ class Label:
# Named property value of a Vertex or an Edge. # Named property value of a Vertex or an Edge.
# It would be better to use typing.NamedTuple with typed fields, but that is # It would be better to use typing.NamedTuple with typed fields, but that is
# not available in Python 3.5. # not available in Python 3.5.
Property = namedtuple('Property', ('name', 'value')) Property = namedtuple("Property", ("name", "value"))
class Properties: class Properties:
""" """
A collection of properties either on a Vertex or an Edge. A collection of properties either on a Vertex or an Edge.
""" """
__slots__ = ('_vertex_or_edge', '_len',)
__slots__ = (
"_vertex_or_edge",
"_len",
)
def __init__(self, vertex_or_edge): def __init__(self, vertex_or_edge):
if not isinstance(vertex_or_edge, (_mgp.Vertex, _mgp.Edge)): if not isinstance(vertex_or_edge, (_mgp.Vertex, _mgp.Edge)):
raise TypeError("Expected '_mgp.Vertex' or '_mgp.Edge', \ raise TypeError("Expected '_mgp.Vertex' or '_mgp.Edge', got {}".format(type(vertex_or_edge)))
got {}".format(type(vertex_or_edge)))
self._len = None self._len = None
self._vertex_or_edge = vertex_or_edge self._vertex_or_edge = vertex_or_edge
@ -330,7 +346,8 @@ class Properties:
class EdgeType: class EdgeType:
"""Type of an Edge.""" """Type of an Edge."""
__slots__ = ('_name',)
__slots__ = ("_name",)
def __init__(self, name): def __init__(self, name):
self._name = name self._name = name
@ -348,7 +365,7 @@ class EdgeType:
if sys.version_info >= (3, 5, 2): if sys.version_info >= (3, 5, 2):
EdgeId = typing.NewType('EdgeId', int) EdgeId = typing.NewType("EdgeId", int)
else: else:
EdgeId = int EdgeId = int
@ -360,12 +377,12 @@ class Edge:
a query. You should not globally store an instance of an Edge. Using an a query. You should not globally store an instance of an Edge. Using an
invalid Edge instance will raise InvalidContextError. invalid Edge instance will raise InvalidContextError.
""" """
__slots__ = ('_edge',)
__slots__ = ("_edge",)
def __init__(self, edge): def __init__(self, edge):
if not isinstance(edge, _mgp.Edge): if not isinstance(edge, _mgp.Edge):
raise TypeError( raise TypeError("Expected '_mgp.Edge', got '{}'".format(type(edge)))
"Expected '_mgp.Edge', got '{}'".format(type(edge)))
self._edge = edge self._edge = edge
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
@ -408,7 +425,7 @@ class Edge:
return EdgeType(self._edge.get_type_name()) return EdgeType(self._edge.get_type_name())
@property @property
def from_vertex(self) -> 'Vertex': def from_vertex(self) -> "Vertex":
""" """
Get the source vertex. Get the source vertex.
@ -419,7 +436,7 @@ class Edge:
return Vertex(self._edge.from_vertex()) return Vertex(self._edge.from_vertex())
@property @property
def to_vertex(self) -> 'Vertex': def to_vertex(self) -> "Vertex":
""" """
Get the destination vertex. Get the destination vertex.
@ -453,7 +470,7 @@ class Edge:
if sys.version_info >= (3, 5, 2): if sys.version_info >= (3, 5, 2):
VertexId = typing.NewType('VertexId', int) VertexId = typing.NewType("VertexId", int)
else: else:
VertexId = int VertexId = int
@ -465,12 +482,12 @@ class Vertex:
in a query. You should not globally store an instance of a Vertex. Using an in a query. You should not globally store an instance of a Vertex. Using an
invalid Vertex instance will raise InvalidContextError. invalid Vertex instance will raise InvalidContextError.
""" """
__slots__ = ('_vertex',)
__slots__ = ("_vertex",)
def __init__(self, vertex): def __init__(self, vertex):
if not isinstance(vertex, _mgp.Vertex): if not isinstance(vertex, _mgp.Vertex):
raise TypeError( raise TypeError("Expected '_mgp.Vertex', got '{}'".format(type(vertex)))
"Expected '_mgp.Vertex', got '{}'".format(type(vertex)))
self._vertex = vertex self._vertex = vertex
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
@ -513,8 +530,7 @@ class Vertex:
""" """
if not self.is_valid(): if not self.is_valid():
raise InvalidContextError() raise InvalidContextError()
return tuple(Label(self._vertex.label_at(i)) return tuple(Label(self._vertex.label_at(i)) for i in range(self._vertex.labels_count()))
for i in range(self._vertex.labels_count()))
def add_label(self, label: str) -> None: def add_label(self, label: str) -> None:
""" """
@ -615,7 +631,8 @@ class Vertex:
class Path: class Path:
"""Path containing Vertex and Edge instances.""" """Path containing Vertex and Edge instances."""
__slots__ = ('_path', '_vertices', '_edges')
__slots__ = ("_path", "_vertices", "_edges")
def __init__(self, starting_vertex_or_path: typing.Union[_mgp.Path, Vertex]): def __init__(self, starting_vertex_or_path: typing.Union[_mgp.Path, Vertex]):
"""Initialize with a starting Vertex. """Initialize with a starting Vertex.
@ -636,8 +653,7 @@ class Path:
raise InvalidContextError() raise InvalidContextError()
self._path = _mgp.Path.make_with_start(vertex) self._path = _mgp.Path.make_with_start(vertex)
else: else:
raise TypeError("Expected '_mgp.Vertex' or '_mgp.Path', got '{}'" raise TypeError("Expected '_mgp.Vertex' or '_mgp.Path', got '{}'".format(type(starting_vertex_or_path)))
.format(type(starting_vertex_or_path)))
def __copy__(self): def __copy__(self):
if not self.is_valid(): if not self.is_valid():
@ -678,8 +694,7 @@ class Path:
extension. extension.
""" """
if not isinstance(edge, Edge): if not isinstance(edge, Edge):
raise TypeError( raise TypeError("Expected '_mgp.Edge', got '{}'".format(type(edge)))
"Expected '_mgp.Edge', got '{}'".format(type(edge)))
if not self.is_valid() or not edge.is_valid(): if not self.is_valid() or not edge.is_valid():
raise InvalidContextError() raise InvalidContextError()
self._path.expand(edge._edge) self._path.expand(edge._edge)
@ -698,8 +713,7 @@ class Path:
raise InvalidContextError() raise InvalidContextError()
if self._vertices is None: if self._vertices is None:
num_vertices = self._path.size() + 1 num_vertices = self._path.size() + 1
self._vertices = tuple(Vertex(self._path.vertex_at(i)) self._vertices = tuple(Vertex(self._path.vertex_at(i)) for i in range(num_vertices))
for i in range(num_vertices))
return self._vertices return self._vertices
@property @property
@ -713,14 +727,14 @@ class Path:
raise InvalidContextError() raise InvalidContextError()
if self._edges is None: if self._edges is None:
num_edges = self._path.size() num_edges = self._path.size()
self._edges = tuple(Edge(self._path.edge_at(i)) self._edges = tuple(Edge(self._path.edge_at(i)) for i in range(num_edges))
for i in range(num_edges))
return self._edges return self._edges
class Record: class Record:
"""Represents a record of resulting field values.""" """Represents a record of resulting field values."""
__slots__ = ('fields',)
__slots__ = ("fields",)
def __init__(self, **kwargs): def __init__(self, **kwargs):
"""Initialize with name=value fields in kwargs.""" """Initialize with name=value fields in kwargs."""
@ -729,12 +743,12 @@ class Record:
class Vertices: class Vertices:
"""Iterable over vertices in a graph.""" """Iterable over vertices in a graph."""
__slots__ = ('_graph', '_len')
__slots__ = ("_graph", "_len")
def __init__(self, graph): def __init__(self, graph):
if not isinstance(graph, _mgp.Graph): if not isinstance(graph, _mgp.Graph):
raise TypeError( raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = graph self._graph = graph
self._len = None self._len = None
@ -791,12 +805,12 @@ class Vertices:
class Graph: class Graph:
"""State of the graph database in current ProcCtx.""" """State of the graph database in current ProcCtx."""
__slots__ = ('_graph',)
__slots__ = ("_graph",)
def __init__(self, graph): def __init__(self, graph):
if not isinstance(graph, _mgp.Graph): if not isinstance(graph, _mgp.Graph):
raise TypeError( raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = graph self._graph = graph
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
@ -885,8 +899,7 @@ class Graph:
raise InvalidContextError() raise InvalidContextError()
self._graph.detach_delete_vertex(vertex._vertex) self._graph.detach_delete_vertex(vertex._vertex)
def create_edge(self, from_vertex: Vertex, to_vertex: Vertex, def create_edge(self, from_vertex: Vertex, to_vertex: Vertex, edge_type: EdgeType) -> None:
edge_type: EdgeType) -> None:
""" """
Create an edge. Create an edge.
@ -899,8 +912,7 @@ class Graph:
""" """
if not self.is_valid(): if not self.is_valid():
raise InvalidContextError() raise InvalidContextError()
return Edge(self._graph.create_edge(from_vertex._vertex, return Edge(self._graph.create_edge(from_vertex._vertex, to_vertex._vertex, edge_type.name))
to_vertex._vertex, edge_type.name))
def delete_edge(self, edge: Edge) -> None: def delete_edge(self, edge: Edge) -> None:
""" """
@ -918,6 +930,7 @@ class Graph:
class AbortError(Exception): class AbortError(Exception):
"""Signals that the procedure was asked to abort its execution.""" """Signals that the procedure was asked to abort its execution."""
pass pass
@ -927,12 +940,12 @@ class ProcCtx:
Access to a ProcCtx is only valid during a single execution of a procedure Access to a ProcCtx is only valid during a single execution of a procedure
in a query. You should not globally store a ProcCtx instance. in a query. You should not globally store a ProcCtx instance.
""" """
__slots__ = ('_graph',)
__slots__ = ("_graph",)
def __init__(self, graph): def __init__(self, graph):
if not isinstance(graph, _mgp.Graph): if not isinstance(graph, _mgp.Graph):
raise TypeError( raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = Graph(graph) self._graph = Graph(graph)
def is_valid(self) -> bool: def is_valid(self) -> bool:
@ -969,8 +982,7 @@ LocalDateTime = datetime.datetime
Duration = datetime.timedelta Duration = datetime.timedelta
Any = typing.Union[bool, str, Number, Map, Path, Any = typing.Union[bool, str, Number, Map, Path, list, Date, LocalTime, LocalDateTime, Duration]
list, Date, LocalTime, LocalDateTime, Duration]
List = typing.List List = typing.List
@ -1003,7 +1015,7 @@ def _typing_to_cypher_type(type_):
Date: _mgp.type_date(), Date: _mgp.type_date(),
LocalTime: _mgp.type_local_time(), LocalTime: _mgp.type_local_time(),
LocalDateTime: _mgp.type_local_date_time(), LocalDateTime: _mgp.type_local_date_time(),
Duration: _mgp.type_duration() Duration: _mgp.type_duration(),
} }
try: try:
return simple_types[type_] return simple_types[type_]
@ -1021,14 +1033,14 @@ def _typing_to_cypher_type(type_):
if type(None) in type_args: if type(None) in type_args:
types = tuple(t for t in type_args if t is not type(None)) # noqa E721 types = tuple(t for t in type_args if t is not type(None)) # noqa E721
if len(types) == 1: if len(types) == 1:
type_arg, = types (type_arg,) = types
else: else:
# We cannot do typing.Union[*types], so do the equivalent # We cannot do typing.Union[*types], so do the equivalent
# with __getitem__ which does not even need arg unpacking. # with __getitem__ which does not even need arg unpacking.
type_arg = typing.Union.__getitem__(types) type_arg = typing.Union.__getitem__(types)
return _mgp.type_nullable(_typing_to_cypher_type(type_arg)) return _mgp.type_nullable(_typing_to_cypher_type(type_arg))
elif complex_type == list: elif complex_type == list:
type_arg, = type_args (type_arg,) = type_args
return _mgp.type_list(_typing_to_cypher_type(type_arg)) return _mgp.type_list(_typing_to_cypher_type(type_arg))
raise UnsupportedTypingError(type_) raise UnsupportedTypingError(type_)
else: else:
@ -1038,13 +1050,17 @@ def _typing_to_cypher_type(type_):
# printed the same way. `typing.List[type]` is printed as such, while # printed the same way. `typing.List[type]` is printed as such, while
# `typing.Optional[type]` is printed as 'typing.Union[type, NoneType]' # `typing.Optional[type]` is printed as 'typing.Union[type, NoneType]'
def parse_type_args(type_as_str): def parse_type_args(type_as_str):
return tuple(map(str.strip, return tuple(
type_as_str[type_as_str.index('[') + 1: -1].split(','))) map(
str.strip,
type_as_str[type_as_str.index("[") + 1 : -1].split(","),
)
)
def fully_qualified_name(cls): def fully_qualified_name(cls):
if cls.__module__ is None or cls.__module__ == 'builtins': if cls.__module__ is None or cls.__module__ == "builtins":
return cls.__name__ return cls.__name__
return cls.__module__ + '.' + cls.__name__ return cls.__module__ + "." + cls.__name__
def get_simple_type(type_as_str): def get_simple_type(type_as_str):
for simple_type, cypher_type in simple_types.items(): for simple_type, cypher_type in simple_types.items():
@ -1060,28 +1076,26 @@ def _typing_to_cypher_type(type_):
pass pass
def parse_typing(type_as_str): def parse_typing(type_as_str):
if type_as_str.startswith('typing.Union'): if type_as_str.startswith("typing.Union"):
type_args_as_str = parse_type_args(type_as_str) type_args_as_str = parse_type_args(type_as_str)
none_type_as_str = type(None).__name__ none_type_as_str = type(None).__name__
if none_type_as_str in type_args_as_str: if none_type_as_str in type_args_as_str:
types = tuple( types = tuple(t for t in type_args_as_str if t != none_type_as_str)
t for t in type_args_as_str if t != none_type_as_str)
if len(types) == 1: if len(types) == 1:
type_arg_as_str, = types (type_arg_as_str,) = types
else: else:
type_arg_as_str = 'typing.Union[' + \ type_arg_as_str = "typing.Union[" + ", ".join(types) + "]"
', '.join(types) + ']'
simple_type = get_simple_type(type_arg_as_str) simple_type = get_simple_type(type_arg_as_str)
if simple_type is not None: if simple_type is not None:
return _mgp.type_nullable(simple_type) return _mgp.type_nullable(simple_type)
return _mgp.type_nullable(parse_typing(type_arg_as_str)) return _mgp.type_nullable(parse_typing(type_arg_as_str))
elif type_as_str.startswith('typing.List'): elif type_as_str.startswith("typing.List"):
type_arg_as_str = parse_type_args(type_as_str) type_arg_as_str = parse_type_args(type_as_str)
if len(type_arg_as_str) > 1: if len(type_arg_as_str) > 1:
# Nested object could be a type consisting of a list of types (e.g. mgp.Map) # Nested object could be a type consisting of a list of types (e.g. mgp.Map)
# so we need to join the parts. # so we need to join the parts.
type_arg_as_str = ', '.join(type_arg_as_str) type_arg_as_str = ", ".join(type_arg_as_str)
else: else:
type_arg_as_str = type_arg_as_str[0] type_arg_as_str = type_arg_as_str[0]
@ -1096,9 +1110,11 @@ def _typing_to_cypher_type(type_):
# Procedure registration # Procedure registration
class Deprecated: class Deprecated:
"""Annotate a resulting Record's field as deprecated.""" """Annotate a resulting Record's field as deprecated."""
__slots__ = ('field_type',)
__slots__ = ("field_type",)
def __init__(self, type_): def __init__(self, type_):
self.field_type = type_ self.field_type = type_
@ -1106,8 +1122,7 @@ class Deprecated:
def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]): def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]):
if not callable(func): if not callable(func):
raise TypeError("Expected a callable object, got an instance of '{}'" raise TypeError("Expected a callable object, got an instance of '{}'".format(type(func)))
.format(type(func)))
if inspect.iscoroutinefunction(func): if inspect.iscoroutinefunction(func):
raise TypeError("Callable must not be 'async def' function") raise TypeError("Callable must not be 'async def' function")
if sys.version_info >= (3, 6): if sys.version_info >= (3, 6):
@ -1117,24 +1132,25 @@ def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]):
raise NotImplementedError("Generator functions are not supported") raise NotImplementedError("Generator functions are not supported")
def _register_proc(func: typing.Callable[..., Record], def _register_proc(func: typing.Callable[..., Record], is_write: bool):
is_write: bool):
raise_if_does_not_meet_requirements(func) raise_if_does_not_meet_requirements(func)
register_func = ( register_func = _mgp.Module.add_write_procedure if is_write else _mgp.Module.add_read_procedure
_mgp.Module.add_write_procedure if is_write
else _mgp.Module.add_read_procedure)
sig = inspect.signature(func) sig = inspect.signature(func)
params = tuple(sig.parameters.values()) params = tuple(sig.parameters.values())
if params and params[0].annotation is ProcCtx: if params and params[0].annotation is ProcCtx:
@wraps(func) @wraps(func)
def wrapper(graph, args): def wrapper(graph, args):
return func(ProcCtx(graph), *args) return func(ProcCtx(graph), *args)
params = params[1:] params = params[1:]
mgp_proc = register_func(_mgp._MODULE, wrapper) mgp_proc = register_func(_mgp._MODULE, wrapper)
else: else:
@wraps(func) @wraps(func)
def wrapper(graph, args): def wrapper(graph, args):
return func(*args) return func(*args)
mgp_proc = register_func(_mgp._MODULE, wrapper) mgp_proc = register_func(_mgp._MODULE, wrapper)
for param in params: for param in params:
name = param.name name = param.name
@ -1149,8 +1165,7 @@ def _register_proc(func: typing.Callable[..., Record],
if sig.return_annotation is not sig.empty: if sig.return_annotation is not sig.empty:
record = sig.return_annotation record = sig.return_annotation
if not isinstance(record, Record): if not isinstance(record, Record):
raise TypeError("Expected '{}' to return 'mgp.Record', got '{}'" raise TypeError("Expected '{}' to return 'mgp.Record', got '{}'".format(func.__name__, type(record)))
.format(func.__name__, type(record)))
for name, type_ in record.fields.items(): for name, type_ in record.fields.items():
if isinstance(type_, Deprecated): if isinstance(type_, Deprecated):
cypher_type = _typing_to_cypher_type(type_.field_type) cypher_type = _typing_to_cypher_type(type_.field_type)
@ -1257,20 +1272,22 @@ class InvalidMessageError(Exception):
""" """
Signals using a message instance outside of the registered transformation. Signals using a message instance outside of the registered transformation.
""" """
pass pass
SOURCE_TYPE_KAFKA = _mgp.SOURCE_TYPE_KAFKA SOURCE_TYPE_KAFKA = _mgp.SOURCE_TYPE_KAFKA
SOURCE_TYPE_PULSAR = _mgp.SOURCE_TYPE_PULSAR SOURCE_TYPE_PULSAR = _mgp.SOURCE_TYPE_PULSAR
class Message: class Message:
"""Represents a message from a stream.""" """Represents a message from a stream."""
__slots__ = ('_message',)
__slots__ = ("_message",)
def __init__(self, message): def __init__(self, message):
if not isinstance(message, _mgp.Message): if not isinstance(message, _mgp.Message):
raise TypeError( raise TypeError("Expected '_mgp.Message', got '{}'".format(type(message)))
"Expected '_mgp.Message', got '{}'".format(type(message)))
self._message = message self._message = message
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
@ -1353,17 +1370,18 @@ class Message:
class InvalidMessagesError(Exception): class InvalidMessagesError(Exception):
"""Signals using a messages instance outside of the registered transformation.""" """Signals using a messages instance outside of the registered transformation."""
pass pass
class Messages: class Messages:
"""Represents a list of messages from a stream.""" """Represents a list of messages from a stream."""
__slots__ = ('_messages',)
__slots__ = ("_messages",)
def __init__(self, messages): def __init__(self, messages):
if not isinstance(messages, _mgp.Messages): if not isinstance(messages, _mgp.Messages):
raise TypeError( raise TypeError("Expected '_mgp.Messages', got '{}'".format(type(messages)))
"Expected '_mgp.Messages', got '{}'".format(type(messages)))
self._messages = messages self._messages = messages
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
@ -1395,12 +1413,12 @@ class TransCtx:
Access to a TransCtx is only valid during a single execution of a transformation. Access to a TransCtx is only valid during a single execution of a transformation.
You should not globally store a TransCtx instance. You should not globally store a TransCtx instance.
""" """
__slots__ = ('_graph')
__slots__ = "_graph"
def __init__(self, graph): def __init__(self, graph):
if not isinstance(graph, _mgp.Graph): if not isinstance(graph, _mgp.Graph):
raise TypeError( raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = Graph(graph) self._graph = Graph(graph)
def is_valid(self) -> bool: def is_valid(self) -> bool:
@ -1420,21 +1438,76 @@ def transformation(func: typing.Callable[..., Record]):
params = tuple(sig.parameters.values()) params = tuple(sig.parameters.values())
if not params or not params[0].annotation is Messages: if not params or not params[0].annotation is Messages:
if not len(params) == 2 or not params[1].annotation is Messages: if not len(params) == 2 or not params[1].annotation is Messages:
raise NotImplementedError( raise NotImplementedError("Valid signatures for transformations are (TransCtx, Messages) or (Messages)")
"Valid signatures for transformations are (TransCtx, Messages) or (Messages)")
if params[0].annotation is TransCtx: if params[0].annotation is TransCtx:
@wraps(func) @wraps(func)
def wrapper(graph, messages): def wrapper(graph, messages):
return func(TransCtx(graph), messages) return func(TransCtx(graph), messages)
_mgp._MODULE.add_transformation(wrapper) _mgp._MODULE.add_transformation(wrapper)
else: else:
@wraps(func) @wraps(func)
def wrapper(graph, messages): def wrapper(graph, messages):
return func(messages) return func(messages)
_mgp._MODULE.add_transformation(wrapper) _mgp._MODULE.add_transformation(wrapper)
return func return func
class FuncCtx:
"""Context of a function being executed.
Access to a FuncCtx is only valid during a single execution of a transformation.
You should not globally store a FuncCtx instance.
"""
__slots__ = "_graph"
def __init__(self, graph):
if not isinstance(graph, _mgp.Graph):
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = Graph(graph)
def is_valid(self) -> bool:
return self._graph.is_valid()
def function(func: typing.Callable):
raise_if_does_not_meet_requirements(func)
register_func = _mgp.Module.add_function
sig = inspect.signature(func)
params = tuple(sig.parameters.values())
if params and params[0].annotation is FuncCtx:
@wraps(func)
def wrapper(graph, args):
return func(FuncCtx(graph), *args)
params = params[1:]
mgp_func = register_func(_mgp._MODULE, wrapper)
else:
@wraps(func)
def wrapper(graph, args):
return func(*args)
mgp_func = register_func(_mgp._MODULE, wrapper)
for param in params:
name = param.name
type_ = param.annotation
if type_ is param.empty:
type_ = object
cypher_type = _typing_to_cypher_type(type_)
if param.default is param.empty:
mgp_func.add_arg(name, cypher_type)
else:
mgp_func.add_opt_arg(name, cypher_type, param.default)
return func
def _wrap_exceptions(): def _wrap_exceptions():
def wrap_function(func): def wrap_function(func):
@wraps(func) @wraps(func)
@ -1463,6 +1536,7 @@ def _wrap_exceptions():
raise ValueConversionError(e) raise ValueConversionError(e)
except _mgp.SerializationError as e: except _mgp.SerializationError as e:
raise SerializationError(e) raise SerializationError(e)
return wrapped_func return wrapped_func
def wrap_prop_func(func): def wrap_prop_func(func):
@ -1473,11 +1547,16 @@ def _wrap_exceptions():
if inspect.isfunction(obj): if inspect.isfunction(obj):
setattr(cls, name, wrap_function(obj)) setattr(cls, name, wrap_function(obj))
elif isinstance(obj, property): elif isinstance(obj, property):
setattr(cls, name, property( setattr(
wrap_prop_func(obj.fget), cls,
wrap_prop_func(obj.fset), name,
wrap_prop_func(obj.fdel), property(
obj.__doc__)) wrap_prop_func(obj.fget),
wrap_prop_func(obj.fset),
wrap_prop_func(obj.fdel),
obj.__doc__,
),
)
def defined_in_this_module(obj: object): def defined_in_this_module(obj: object):
return getattr(obj, "__module__", "") == __name__ return getattr(obj, "__module__", "") == __name__

View File

@ -852,7 +852,9 @@ cpp<#
: arguments_(arguments), : arguments_(arguments),
function_name_(function_name), function_name_(function_name),
function_(NameToFunction(function_name_)) { function_(NameToFunction(function_name_)) {
DMG_ASSERT(function_, "Unexpected missing function: {}", function_name_); if (!function_) {
throw SemanticException("Function '{}' doesn't exist.", function_name);
}
} }
cpp<#) cpp<#)
(:private (:private

View File

@ -2109,13 +2109,30 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(MemgraphCypher::Functio
storage_->Create<Aggregation>(expressions[1], expressions[0], Aggregation::Op::COLLECT_MAP)); storage_->Create<Aggregation>(expressions[1], expressions[0], Aggregation::Op::COLLECT_MAP));
} }
auto function = NameToFunction(function_name); auto is_user_defined_function = [](const std::string &function_name) {
if (!function) throw SemanticException("Function '{}' doesn't exist.", function_name); // Dots are present only in user-defined functions, since modules are case-sensitive, so must be user-defined
// functions. Builtin functions should be case insensitive.
return function_name.find('.') != std::string::npos;
};
// Don't cache queries which call user-defined functions. User-defined function's return
// types can vary depending on whether the module is reloaded, therefore the cache would
// be invalid.
if (is_user_defined_function(function_name)) {
query_info_.is_cacheable = false;
}
return static_cast<Expression *>(storage_->Create<Function>(function_name, expressions)); return static_cast<Expression *>(storage_->Create<Function>(function_name, expressions));
} }
antlrcpp::Any CypherMainVisitor::visitFunctionName(MemgraphCypher::FunctionNameContext *ctx) { antlrcpp::Any CypherMainVisitor::visitFunctionName(MemgraphCypher::FunctionNameContext *ctx) {
return utils::ToUpperCase(ctx->getText()); auto function_name = ctx->getText();
// Dots are present only in user-defined functions, since modules are case-sensitive, so must be user-defined
// functions. Builtin functions should be case insensitive.
if (function_name.find('.') != std::string::npos) {
return function_name;
}
return utils::ToUpperCase(function_name);
} }
antlrcpp::Any CypherMainVisitor::visitDoubleLiteral(MemgraphCypher::DoubleLiteralContext *ctx) { antlrcpp::Any CypherMainVisitor::visitDoubleLiteral(MemgraphCypher::DoubleLiteralContext *ctx) {

View File

@ -279,7 +279,7 @@ idInColl : variable IN expression ;
functionInvocation : functionName '(' ( DISTINCT )? ( expression ( ',' expression )* )? ')' ; functionInvocation : functionName '(' ( DISTINCT )? ( expression ( ',' expression )* )? ')' ;
functionName : symbolicName ; functionName : symbolicName ( '.' symbolicName )* ;
listComprehension : '[' filterExpression ( '|' expression )? ']' ; listComprehension : '[' filterExpression ( '|' expression )? ']' ;

View File

@ -22,6 +22,9 @@
#include "query/db_accessor.hpp" #include "query/db_accessor.hpp"
#include "query/exceptions.hpp" #include "query/exceptions.hpp"
#include "query/procedure/cypher_types.hpp"
#include "query/procedure/mg_procedure_impl.hpp"
#include "query/procedure/module.hpp"
#include "query/typed_value.hpp" #include "query/typed_value.hpp"
#include "utils/string.hpp" #include "utils/string.hpp"
#include "utils/temporal.hpp" #include "utils/temporal.hpp"
@ -1174,6 +1177,53 @@ TypedValue Duration(const TypedValue *args, int64_t nargs, const FunctionContext
MapNumericParameters<Number>(parameter_mappings, args[0].ValueMap()); MapNumericParameters<Number>(parameter_mappings, args[0].ValueMap());
return TypedValue(utils::Duration(duration_parameters), ctx.memory); return TypedValue(utils::Duration(duration_parameters), ctx.memory);
} }
std::function<TypedValue(const TypedValue *, const int64_t, const FunctionContext &)> UserFunction(
const mgp_func &func, const std::string &fully_qualified_name) {
return [func, fully_qualified_name](const TypedValue *args, int64_t nargs, const FunctionContext &ctx) -> TypedValue {
/// Find function is called to aquire the lock on Module pointer while user-defined function is executed
const auto &maybe_found =
procedure::FindFunction(procedure::gModuleRegistry, fully_qualified_name, utils::NewDeleteResource());
if (!maybe_found) {
throw QueryRuntimeException(
"Function '{}' has been unloaded. Please check query modules to confirm that function is loaded in Memgraph.",
fully_qualified_name);
}
/// Explicit extraction of module pointer, to clearly state that the lock is aquired.
// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
const auto &module_ptr = (*maybe_found).first;
const auto &func_cb = func.cb;
mgp_memory memory{ctx.memory};
mgp_func_context functx{ctx.db_accessor, ctx.view};
auto graph = mgp_graph::NonWritableGraph(*ctx.db_accessor, ctx.view);
std::vector<TypedValue> args_list;
args_list.reserve(nargs);
for (std::size_t i = 0; i < nargs; ++i) {
args_list.emplace_back(args[i]);
}
auto function_argument_list = mgp_list(ctx.memory);
procedure::ConstructArguments(args_list, func, fully_qualified_name, function_argument_list, graph);
mgp_func_result maybe_res;
func_cb(&function_argument_list, &functx, &maybe_res, &memory);
if (maybe_res.error_msg) {
throw QueryRuntimeException(*maybe_res.error_msg);
}
if (!maybe_res.value) {
throw QueryRuntimeException(
"Function '{}' didn't set the result nor the error message. Please either set the result by using "
"mgp_func_result_set_value or the error by using mgp_func_result_set_error_msg.",
fully_qualified_name);
}
return {*(maybe_res.value), ctx.memory};
};
}
} // namespace } // namespace
std::function<TypedValue(const TypedValue *, int64_t, const FunctionContext &ctx)> NameToFunction( std::function<TypedValue(const TypedValue *, int64_t, const FunctionContext &ctx)> NameToFunction(
@ -1259,6 +1309,14 @@ std::function<TypedValue(const TypedValue *, int64_t, const FunctionContext &ctx
if (function_name == "LOCALDATETIME") return LocalDateTime; if (function_name == "LOCALDATETIME") return LocalDateTime;
if (function_name == "DURATION") return Duration; if (function_name == "DURATION") return Duration;
const auto &maybe_found =
procedure::FindFunction(procedure::gModuleRegistry, function_name, utils::NewDeleteResource());
if (maybe_found) {
const auto *func = (*maybe_found).second;
return UserFunction(*func, function_name);
}
return nullptr; return nullptr;
} }

View File

@ -3705,46 +3705,12 @@ void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name,
"containers aware of that"); "containers aware of that");
// Build and type check procedure arguments. // Build and type check procedure arguments.
mgp_list proc_args(memory); mgp_list proc_args(memory);
proc_args.elems.reserve(args.size()); std::vector<TypedValue> args_list;
if (args.size() < proc.args.size() || args_list.reserve(args.size());
// Rely on `||` short circuit so we can avoid potential overflow of for (auto *expression : args) {
// proc.args.size() + proc.opt_args.size() by subtracting. args_list.emplace_back(expression->Accept(*evaluator));
(args.size() - proc.args.size() > proc.opt_args.size())) {
if (proc.args.empty() && proc.opt_args.empty()) {
throw QueryRuntimeException("'{}' requires no arguments.", fully_qualified_procedure_name);
} else if (proc.opt_args.empty()) {
throw QueryRuntimeException("'{}' requires exactly {} {}.", fully_qualified_procedure_name, proc.args.size(),
proc.args.size() == 1U ? "argument" : "arguments");
} else {
throw QueryRuntimeException("'{}' requires between {} and {} arguments.", fully_qualified_procedure_name,
proc.args.size(), proc.args.size() + proc.opt_args.size());
}
}
for (size_t i = 0; i < args.size(); ++i) {
auto arg = args[i]->Accept(*evaluator);
std::string_view name;
const query::procedure::CypherType *type{nullptr};
if (proc.args.size() > i) {
name = proc.args[i].first;
type = proc.args[i].second;
} else {
MG_ASSERT(proc.opt_args.size() > i - proc.args.size());
name = std::get<0>(proc.opt_args[i - proc.args.size()]);
type = std::get<1>(proc.opt_args[i - proc.args.size()]);
}
if (!type->SatisfiesType(arg)) {
throw QueryRuntimeException("'{}' argument named '{}' at position {} must be of type {}.",
fully_qualified_procedure_name, name, i, type->GetPresentableName());
}
proc_args.elems.emplace_back(std::move(arg), &graph);
}
// Fill missing optional arguments with their default values.
MG_ASSERT(args.size() >= proc.args.size());
size_t passed_in_opt_args = args.size() - proc.args.size();
MG_ASSERT(passed_in_opt_args <= proc.opt_args.size());
for (size_t i = passed_in_opt_args; i < proc.opt_args.size(); ++i) {
proc_args.elems.emplace_back(std::get<2>(proc.opt_args[i]), &graph);
} }
procedure::ConstructArguments(args_list, proc, fully_qualified_procedure_name, proc_args, graph);
if (memory_limit) { if (memory_limit) {
SPDLOG_INFO("Running '{}' with memory limit of {}", fully_qualified_procedure_name, SPDLOG_INFO("Running '{}' with memory limit of {}", fully_qualified_procedure_name,
utils::GetReadableSize(*memory_limit)); utils::GetReadableSize(*memory_limit));
@ -3832,7 +3798,7 @@ class CallProcedureCursor : public Cursor {
// generator like procedures which yield a new result on each invocation. // generator like procedures which yield a new result on each invocation.
auto *memory = context.evaluation_context.memory; auto *memory = context.evaluation_context.memory;
auto memory_limit = EvaluateMemoryLimit(&evaluator, self_->memory_limit_, self_->memory_scale_); auto memory_limit = EvaluateMemoryLimit(&evaluator, self_->memory_limit_, self_->memory_scale_);
mgp_graph graph{context.db_accessor, graph_view, &context}; auto graph = mgp_graph::WritableGraph(*context.db_accessor, graph_view, context);
CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit, CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit,
&result_); &result_);

View File

@ -187,7 +187,10 @@ template <typename TFunc, typename... Args>
return MGP_ERROR_NO_ERROR; return MGP_ERROR_NO_ERROR;
} }
bool MgpGraphIsMutable(const mgp_graph &graph) noexcept { return graph.view == memgraph::storage::View::NEW; } // Graph mutations
bool MgpGraphIsMutable(const mgp_graph &graph) noexcept {
return graph.view == memgraph::storage::View::NEW && graph.ctx != nullptr;
}
bool MgpVertexIsMutable(const mgp_vertex &vertex) { return MgpGraphIsMutable(*vertex.graph); } bool MgpVertexIsMutable(const mgp_vertex &vertex) { return MgpGraphIsMutable(*vertex.graph); }
@ -289,6 +292,7 @@ mgp_value_type FromTypedValueType(memgraph::query::TypedValue::Type type) {
return MGP_VALUE_TYPE_DURATION; return MGP_VALUE_TYPE_DURATION;
} }
} }
} // namespace
memgraph::query::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::MemoryResource *memory) { memgraph::query::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::MemoryResource *memory) {
switch (val.type) { switch (val.type) {
@ -345,8 +349,6 @@ memgraph::query::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::
} }
} }
} // namespace
mgp_value::mgp_value(memgraph::utils::MemoryResource *m) noexcept : type(MGP_VALUE_TYPE_NULL), memory(m) {} mgp_value::mgp_value(memgraph::utils::MemoryResource *m) noexcept : type(MGP_VALUE_TYPE_NULL), memory(m) {}
mgp_value::mgp_value(bool val, memgraph::utils::MemoryResource *m) noexcept mgp_value::mgp_value(bool val, memgraph::utils::MemoryResource *m) noexcept
@ -1451,6 +1453,14 @@ mgp_error mgp_result_record_insert(mgp_result_record *record, const char *field_
}); });
} }
mgp_error mgp_func_result_set_error_msg(mgp_func_result *res, const char *msg, mgp_memory *memory) {
return WrapExceptions([=] { res->error_msg.emplace(msg, memory->impl); });
}
mgp_error mgp_func_result_set_value(mgp_func_result *res, mgp_value *value, mgp_memory *memory) {
return WrapExceptions([=] { res->value = ToTypedValue(*value, memory->impl); });
}
/// Graph Constructs /// Graph Constructs
void mgp_properties_iterator_destroy(mgp_properties_iterator *it) { DeleteRawMgpObject(it); } void mgp_properties_iterator_destroy(mgp_properties_iterator *it) { DeleteRawMgpObject(it); }
@ -2382,31 +2392,53 @@ mgp_error mgp_module_add_write_procedure(mgp_module *module, const char *name, m
return WrapExceptions([=] { return mgp_module_add_procedure(module, name, cb, {.is_write = true}); }, result); return WrapExceptions([=] { return mgp_module_add_procedure(module, name, cb, {.is_write = true}); }, result);
} }
mgp_error mgp_proc_add_arg(mgp_proc *proc, const char *name, mgp_type *type) { namespace {
return WrapExceptions([=] { template <typename T>
if (!IsValidIdentifierName(name)) { concept IsCallable = memgraph::utils::SameAsAnyOf<T, mgp_proc, mgp_func>;
throw std::invalid_argument{fmt::format("Invalid argument name for procedure '{}': {}", proc->name, name)};
template <IsCallable TCall>
mgp_error MgpAddArg(TCall &callable, const std::string &name, mgp_type &type) {
return WrapExceptions([&]() mutable {
static constexpr std::string_view type_name = std::invoke([]() constexpr {
if constexpr (std::is_same_v<TCall, mgp_proc>) {
return "procedure";
} else if constexpr (std::is_same_v<TCall, mgp_func>) {
return "function";
}
});
if (!IsValidIdentifierName(name.c_str())) {
throw std::invalid_argument{fmt::format("Invalid argument name for {} '{}': {}", type_name, callable.name, name)};
} }
if (!proc->opt_args.empty()) { if (!callable.opt_args.empty()) {
throw std::logic_error{fmt::format( throw std::logic_error{fmt::format("Cannot add required argument '{}' to {} '{}' after adding any optional one",
"Cannot add required argument '{}' to procedure '{}' after adding any optional one", name, proc->name)}; name, type_name, callable.name)};
} }
proc->args.emplace_back(name, type->impl.get()); callable.args.emplace_back(name, type.impl.get());
}); });
} }
mgp_error mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, mgp_type *type, mgp_value *default_value) { template <IsCallable TCall>
return WrapExceptions([=] { mgp_error MgpAddOptArg(TCall &callable, const std::string name, mgp_type &type, mgp_value &default_value) {
if (!IsValidIdentifierName(name)) { return WrapExceptions([&]() mutable {
throw std::invalid_argument{fmt::format("Invalid argument name for procedure '{}': {}", proc->name, name)}; static constexpr std::string_view type_name = std::invoke([]() constexpr {
if constexpr (std::is_same_v<TCall, mgp_proc>) {
return "procedure";
} else if constexpr (std::is_same_v<TCall, mgp_func>) {
return "function";
}
});
if (!IsValidIdentifierName(name.c_str())) {
throw std::invalid_argument{fmt::format("Invalid argument name for {} '{}': {}", type_name, callable.name, name)};
} }
switch (MgpValueGetType(*default_value)) { switch (MgpValueGetType(default_value)) {
case MGP_VALUE_TYPE_VERTEX: case MGP_VALUE_TYPE_VERTEX:
case MGP_VALUE_TYPE_EDGE: case MGP_VALUE_TYPE_EDGE:
case MGP_VALUE_TYPE_PATH: case MGP_VALUE_TYPE_PATH:
// default_value must not be a graph element. // default_value must not be a graph element.
throw ValueConversionException{ throw ValueConversionException{"Default value of argument '{}' of {} '{}' name must not be a graph element!",
"Default value of argument '{}' of procedure '{}' name must not be a graph element!", name, proc->name}; name, type_name, callable.name};
case MGP_VALUE_TYPE_NULL: case MGP_VALUE_TYPE_NULL:
case MGP_VALUE_TYPE_BOOL: case MGP_VALUE_TYPE_BOOL:
case MGP_VALUE_TYPE_INT: case MGP_VALUE_TYPE_INT:
@ -2421,16 +2453,32 @@ mgp_error mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, mgp_type *type,
break; break;
} }
// Default value must be of required `type`. // Default value must be of required `type`.
if (!type->impl->SatisfiesType(*default_value)) { if (!type.impl->SatisfiesType(default_value)) {
throw std::logic_error{ throw std::logic_error{fmt::format("The default value of argument '{}' for {} '{}' doesn't satisfy type '{}'",
fmt::format("The default value of argument '{}' for procedure '{}' doesn't satisfy type '{}'", name, name, type_name, callable.name, type.impl->GetPresentableName())};
proc->name, type->impl->GetPresentableName())};
} }
auto *memory = proc->opt_args.get_allocator().GetMemoryResource(); auto *memory = callable.opt_args.get_allocator().GetMemoryResource();
proc->opt_args.emplace_back(memgraph::utils::pmr::string(name, memory), type->impl.get(), callable.opt_args.emplace_back(memgraph::utils::pmr::string(name, memory), type.impl.get(),
ToTypedValue(*default_value, memory)); ToTypedValue(default_value, memory));
}); });
} }
} // namespace
mgp_error mgp_proc_add_arg(mgp_proc *proc, const char *name, mgp_type *type) {
return MgpAddArg(*proc, std::string(name), *type);
}
mgp_error mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, mgp_type *type, mgp_value *default_value) {
return MgpAddOptArg(*proc, std::string(name), *type, *default_value);
}
mgp_error mgp_func_add_arg(mgp_func *func, const char *name, mgp_type *type) {
return MgpAddArg(*func, std::string(name), *type);
}
mgp_error mgp_func_add_opt_arg(mgp_func *func, const char *name, mgp_type *type, mgp_value *default_value) {
return MgpAddOptArg(*func, std::string(name), *type, *default_value);
}
namespace { namespace {
@ -2545,6 +2593,22 @@ void PrintProcSignature(const mgp_proc &proc, std::ostream *stream) {
(*stream) << ")"; (*stream) << ")";
} }
void PrintFuncSignature(const mgp_func &func, std::ostream &stream) {
stream << func.name << "(";
utils::PrintIterable(stream, func.args, ", ", [](auto &stream, const auto &arg) {
stream << arg.first << " :: " << arg.second->GetPresentableName();
});
if (!func.args.empty() && !func.opt_args.empty()) {
stream << ", ";
}
utils::PrintIterable(stream, func.opt_args, ", ", [](auto &stream, const auto &arg) {
const auto &[name, type, default_val] = arg;
stream << name << " = ";
PrintValue(default_val, &stream) << " :: " << type->GetPresentableName();
});
stream << ")";
}
bool IsValidIdentifierName(const char *name) { bool IsValidIdentifierName(const char *name) {
if (!name) return false; if (!name) return false;
std::regex regex("[_[:alpha:]][_[:alnum:]]*"); std::regex regex("[_[:alpha:]][_[:alnum:]]*");
@ -2716,3 +2780,19 @@ mgp_error mgp_module_add_transformation(mgp_module *module, const char *name, mg
module->transformations.emplace(name, mgp_trans(name, cb, memory)); module->transformations.emplace(name, mgp_trans(name, cb, memory));
}); });
} }
mgp_error mgp_module_add_function(mgp_module *module, const char *name, mgp_func_cb cb, mgp_func **result) {
return WrapExceptions(
[=] {
if (!IsValidIdentifierName(name)) {
throw std::invalid_argument{fmt::format("Invalid function name: {}", name)};
}
if (module->functions.find(name) != module->functions.end()) {
throw std::logic_error{fmt::format("Function with similar name already exists '{}'", name)};
};
auto *memory = module->functions.get_allocator().GetMemoryResource();
return &module->functions.emplace(name, mgp_func(name, cb, memory)).first->second;
},
result);
}

View File

@ -562,14 +562,36 @@ struct mgp_result {
std::optional<memgraph::utils::pmr::string> error_msg; std::optional<memgraph::utils::pmr::string> error_msg;
}; };
struct mgp_func_result {
mgp_func_result() {}
/// Return Magic function result. If user forgets it, the error is raised
std::optional<memgraph::query::TypedValue> value;
/// Return Magic function result with potential error
std::optional<memgraph::utils::pmr::string> error_msg;
};
struct mgp_graph { struct mgp_graph {
memgraph::query::DbAccessor *impl; memgraph::query::DbAccessor *impl;
memgraph::storage::View view; memgraph::storage::View view;
// TODO: Merge `mgp_graph` and `mgp_memory` into a single `mgp_context`. The // TODO: Merge `mgp_graph` and `mgp_memory` into a single `mgp_context`. The
// `ctx` field is out of place here. // `ctx` field is out of place here.
memgraph::query::ExecutionContext *ctx; memgraph::query::ExecutionContext *ctx;
static mgp_graph WritableGraph(memgraph::query::DbAccessor &acc, memgraph::storage::View view,
memgraph::query::ExecutionContext &ctx) {
return mgp_graph{&acc, view, &ctx};
}
static mgp_graph NonWritableGraph(memgraph::query::DbAccessor &acc, memgraph::storage::View view) {
return mgp_graph{&acc, view, nullptr};
}
}; };
// Prevents user to use ExecutionContext in writable callables
struct mgp_func_context {
memgraph::query::DbAccessor *impl;
memgraph::storage::View view;
};
struct mgp_properties_iterator { struct mgp_properties_iterator {
using allocator_type = memgraph::utils::Allocator<mgp_properties_iterator>; using allocator_type = memgraph::utils::Allocator<mgp_properties_iterator>;
@ -779,18 +801,69 @@ struct mgp_trans {
results; results;
}; };
struct mgp_func {
using allocator_type = memgraph::utils::Allocator<mgp_func>;
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_func(const char *name, mgp_func_cb cb, memgraph::utils::MemoryResource *memory)
: name(name, memory), cb(cb), args(memory), opt_args(memory) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_func(const char *name, std::function<void(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *)> cb,
memgraph::utils::MemoryResource *memory)
: name(name, memory), cb(cb), args(memory), opt_args(memory) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_func(const mgp_func &other, memgraph::utils::MemoryResource *memory)
: name(other.name, memory), cb(other.cb), args(other.args, memory), opt_args(other.opt_args, memory) {}
mgp_func(mgp_func &&other, memgraph::utils::MemoryResource *memory)
: name(std::move(other.name), memory),
cb(std::move(other.cb)),
args(std::move(other.args), memory),
opt_args(std::move(other.opt_args), memory) {}
mgp_func(const mgp_func &other) = default;
mgp_func(mgp_func &&other) = default;
mgp_func &operator=(const mgp_func &) = delete;
mgp_func &operator=(mgp_func &&) = delete;
~mgp_func() = default;
/// Name of the function.
memgraph::utils::pmr::string name;
/// Entry-point for the function.
std::function<void(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *)> cb;
/// Required, positional arguments as a (name, type) pair.
memgraph::utils::pmr::vector<std::pair<memgraph::utils::pmr::string, const memgraph::query::procedure::CypherType *>>
args;
/// Optional positional arguments as a (name, type, default_value) tuple.
memgraph::utils::pmr::vector<std::tuple<memgraph::utils::pmr::string, const memgraph::query::procedure::CypherType *,
memgraph::query::TypedValue>>
opt_args;
};
mgp_error MgpTransAddFixedResult(mgp_trans *trans) noexcept; mgp_error MgpTransAddFixedResult(mgp_trans *trans) noexcept;
struct mgp_module { struct mgp_module {
using allocator_type = memgraph::utils::Allocator<mgp_module>; using allocator_type = memgraph::utils::Allocator<mgp_module>;
explicit mgp_module(memgraph::utils::MemoryResource *memory) : procedures(memory), transformations(memory) {} explicit mgp_module(memgraph::utils::MemoryResource *memory)
: procedures(memory), transformations(memory), functions(memory) {}
mgp_module(const mgp_module &other, memgraph::utils::MemoryResource *memory) mgp_module(const mgp_module &other, memgraph::utils::MemoryResource *memory)
: procedures(other.procedures, memory), transformations(other.transformations, memory) {} : procedures(other.procedures, memory),
transformations(other.transformations, memory),
functions(other.functions, memory) {}
mgp_module(mgp_module &&other, memgraph::utils::MemoryResource *memory) mgp_module(mgp_module &&other, memgraph::utils::MemoryResource *memory)
: procedures(std::move(other.procedures), memory), transformations(std::move(other.transformations), memory) {} : procedures(std::move(other.procedures), memory),
transformations(std::move(other.transformations), memory),
functions(std::move(other.functions), memory) {}
mgp_module(const mgp_module &) = default; mgp_module(const mgp_module &) = default;
mgp_module(mgp_module &&) = default; mgp_module(mgp_module &&) = default;
@ -802,6 +875,7 @@ struct mgp_module {
memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_proc> procedures; memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_proc> procedures;
memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_trans> transformations; memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_trans> transformations;
memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_func> functions;
}; };
namespace memgraph::query::procedure { namespace memgraph::query::procedure {
@ -811,6 +885,11 @@ namespace memgraph::query::procedure {
/// @throw anything std::ostream::operator<< may throw. /// @throw anything std::ostream::operator<< may throw.
void PrintProcSignature(const mgp_proc &, std::ostream *); void PrintProcSignature(const mgp_proc &, std::ostream *);
/// @throw std::bad_alloc
/// @throw std::length_error
/// @throw anything std::ostream::operator<< may throw.
void PrintFuncSignature(const mgp_func &, std::ostream &);
bool IsValidIdentifierName(const char *name); bool IsValidIdentifierName(const char *name);
} // namespace memgraph::query::procedure } // namespace memgraph::query::procedure
@ -839,3 +918,5 @@ struct mgp_messages {
storage_type messages; storage_type messages;
}; };
memgraph::query::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::MemoryResource *memory);

View File

@ -52,6 +52,8 @@ class BuiltinModule final : public Module {
const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override; const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override;
const std::map<std::string, mgp_func, std::less<>> *Functions() const override;
void AddProcedure(std::string_view name, mgp_proc proc); void AddProcedure(std::string_view name, mgp_proc proc);
void AddTransformation(std::string_view name, mgp_trans trans); void AddTransformation(std::string_view name, mgp_trans trans);
@ -62,6 +64,7 @@ class BuiltinModule final : public Module {
/// Registered procedures /// Registered procedures
std::map<std::string, mgp_proc, std::less<>> procedures_; std::map<std::string, mgp_proc, std::less<>> procedures_;
std::map<std::string, mgp_trans, std::less<>> transformations_; std::map<std::string, mgp_trans, std::less<>> transformations_;
std::map<std::string, mgp_func, std::less<>> functions_;
}; };
BuiltinModule::BuiltinModule() {} BuiltinModule::BuiltinModule() {}
@ -75,6 +78,7 @@ const std::map<std::string, mgp_proc, std::less<>> *BuiltinModule::Procedures()
const std::map<std::string, mgp_trans, std::less<>> *BuiltinModule::Transformations() const { const std::map<std::string, mgp_trans, std::less<>> *BuiltinModule::Transformations() const {
return &transformations_; return &transformations_;
} }
const std::map<std::string, mgp_func, std::less<>> *BuiltinModule::Functions() const { return &functions_; }
void BuiltinModule::AddProcedure(std::string_view name, mgp_proc proc) { procedures_.emplace(name, std::move(proc)); } void BuiltinModule::AddProcedure(std::string_view name, mgp_proc proc) { procedures_.emplace(name, std::move(proc)); }
@ -300,6 +304,82 @@ void RegisterMgTransformations(const std::map<std::string, std::unique_ptr<Modul
module->AddProcedure("transformations", std::move(procedures)); module->AddProcedure("transformations", std::move(procedures));
} }
void RegisterMgFunctions(
// We expect modules to be sorted by name.
const std::map<std::string, std::unique_ptr<Module>, std::less<>> *all_modules, BuiltinModule *module) {
auto functions_cb = [all_modules](mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result *result,
mgp_memory *memory) {
// Iterating over all_modules assumes that the standard mechanism of magic
// functions invocations takes the ModuleRegistry::lock_ with READ access.
for (const auto &[module_name, module] : *all_modules) {
// Return the results in sorted order by module and by function_name.
static_assert(std::is_same_v<decltype(module->Functions()), const std::map<std::string, mgp_func, std::less<>> *>,
"Expected module magic functions to be sorted by name");
const auto path = module->Path();
const auto path_string = GetPathString(path);
const auto is_editable = IsFileEditable(path);
for (const auto &[func_name, func] : *module->Functions()) {
mgp_result_record *record{nullptr};
if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) {
return;
}
const auto path_value = GetStringValueOrSetError(path_string.c_str(), memory, result);
if (!path_value) {
return;
}
MgpUniquePtr<mgp_value> is_editable_value{nullptr, mgp_value_destroy};
if (!TryOrSetError([&] { return CreateMgpObject(is_editable_value, mgp_value_make_bool, is_editable, memory); },
result)) {
return;
}
utils::pmr::string full_name(module_name, memory->impl);
full_name.append(1, '.');
full_name.append(func_name);
const auto name_value = GetStringValueOrSetError(full_name.c_str(), memory, result);
if (!name_value) {
return;
}
std::stringstream ss;
ss << module_name << ".";
PrintFuncSignature(func, ss);
const auto signature = ss.str();
const auto signature_value = GetStringValueOrSetError(signature.c_str(), memory, result);
if (!signature_value) {
return;
}
if (!InsertResultOrSetError(result, record, "name", name_value.get())) {
return;
}
if (!InsertResultOrSetError(result, record, "signature", signature_value.get())) {
return;
}
if (!InsertResultOrSetError(result, record, "path", path_value.get())) {
return;
}
if (!InsertResultOrSetError(result, record, "is_editable", is_editable_value.get())) {
return;
}
}
}
};
mgp_proc functions("functions", functions_cb, utils::NewDeleteResource());
MG_ASSERT(mgp_proc_add_result(&functions, "name", Call<mgp_type *>(mgp_type_string)) == MGP_ERROR_NO_ERROR);
MG_ASSERT(mgp_proc_add_result(&functions, "signature", Call<mgp_type *>(mgp_type_string)) == MGP_ERROR_NO_ERROR);
MG_ASSERT(mgp_proc_add_result(&functions, "path", Call<mgp_type *>(mgp_type_string)) == MGP_ERROR_NO_ERROR);
MG_ASSERT(mgp_proc_add_result(&functions, "is_editable", Call<mgp_type *>(mgp_type_bool)) == MGP_ERROR_NO_ERROR);
module->AddProcedure("functions", std::move(functions));
}
namespace { namespace {
bool IsAllowedExtension(const auto &extension) { bool IsAllowedExtension(const auto &extension) {
static constexpr std::array<std::string_view, 1> allowed_extensions{".py"}; static constexpr std::array<std::string_view, 1> allowed_extensions{".py"};
@ -650,8 +730,8 @@ void RegisterMgDeleteModuleFile(ModuleRegistry *module_registry, utils::RWLock *
// `mgp_module::transformations into `proc_map`. The return value of WithModuleRegistration // `mgp_module::transformations into `proc_map`. The return value of WithModuleRegistration
// is the same as that of `fun`. Note, the return value need only be convertible to `bool`, // is the same as that of `fun`. Note, the return value need only be convertible to `bool`,
// it does not have to be `bool` itself. // it does not have to be `bool` itself.
template <class TProcMap, class TTransMap, class TFun> template <class TProcMap, class TTransMap, class TFuncMap, class TFun>
auto WithModuleRegistration(TProcMap *proc_map, TTransMap *trans_map, const TFun &fun) { auto WithModuleRegistration(TProcMap *proc_map, TTransMap *trans_map, TFuncMap *func_map, const TFun &fun) {
// We probably don't need more than 256KB for module initialization. // We probably don't need more than 256KB for module initialization.
static constexpr size_t stack_bytes = 256UL * 1024UL; static constexpr size_t stack_bytes = 256UL * 1024UL;
unsigned char stack_memory[stack_bytes]; unsigned char stack_memory[stack_bytes];
@ -664,6 +744,8 @@ auto WithModuleRegistration(TProcMap *proc_map, TTransMap *trans_map, const TFun
for (const auto &proc : module_def.procedures) proc_map->emplace(proc); for (const auto &proc : module_def.procedures) proc_map->emplace(proc);
// Copy transformations into resulting trans_map. // Copy transformations into resulting trans_map.
for (const auto &trans : module_def.transformations) trans_map->emplace(trans); for (const auto &trans : module_def.transformations) trans_map->emplace(trans);
// Copy functions into resulting func_map.
for (const auto &func : module_def.functions) func_map->emplace(func);
} }
return res; return res;
} }
@ -687,6 +769,8 @@ class SharedLibraryModule final : public Module {
const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override; const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override;
const std::map<std::string, mgp_func, std::less<>> *Functions() const override;
std::optional<std::filesystem::path> Path() const override { return file_path_; } std::optional<std::filesystem::path> Path() const override { return file_path_; }
private: private:
@ -702,6 +786,8 @@ class SharedLibraryModule final : public Module {
std::map<std::string, mgp_proc, std::less<>> procedures_; std::map<std::string, mgp_proc, std::less<>> procedures_;
/// Registered transformations /// Registered transformations
std::map<std::string, mgp_trans, std::less<>> transformations_; std::map<std::string, mgp_trans, std::less<>> transformations_;
/// Registered functions
std::map<std::string, mgp_func, std::less<>> functions_;
}; };
SharedLibraryModule::SharedLibraryModule() : handle_(nullptr) {} SharedLibraryModule::SharedLibraryModule() : handle_(nullptr) {}
@ -755,7 +841,7 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) {
} }
return true; return true;
}; };
if (!WithModuleRegistration(&procedures_, &transformations_, module_cb)) { if (!WithModuleRegistration(&procedures_, &transformations_, &functions_, module_cb)) {
return false; return false;
} }
// Get optional mgp_shutdown_module // Get optional mgp_shutdown_module
@ -801,6 +887,13 @@ const std::map<std::string, mgp_trans, std::less<>> *SharedLibraryModule::Transf
return &transformations_; return &transformations_;
} }
const std::map<std::string, mgp_func, std::less<>> *SharedLibraryModule::Functions() const {
MG_ASSERT(handle_,
"Attempting to access functions of a module that has not "
"been loaded...");
return &functions_;
}
class PythonModule final : public Module { class PythonModule final : public Module {
public: public:
PythonModule(); PythonModule();
@ -816,6 +909,7 @@ class PythonModule final : public Module {
const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override; const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override;
const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override; const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override;
const std::map<std::string, mgp_func, std::less<>> *Functions() const override;
std::optional<std::filesystem::path> Path() const override { return file_path_; } std::optional<std::filesystem::path> Path() const override { return file_path_; }
private: private:
@ -823,6 +917,7 @@ class PythonModule final : public Module {
py::Object py_module_; py::Object py_module_;
std::map<std::string, mgp_proc, std::less<>> procedures_; std::map<std::string, mgp_proc, std::less<>> procedures_;
std::map<std::string, mgp_trans, std::less<>> transformations_; std::map<std::string, mgp_trans, std::less<>> transformations_;
std::map<std::string, mgp_func, std::less<>> functions_;
}; };
PythonModule::PythonModule() {} PythonModule::PythonModule() {}
@ -853,7 +948,7 @@ bool PythonModule::Load(const std::filesystem::path &file_path) {
}; };
return result; return result;
}; };
py_module_ = WithModuleRegistration(&procedures_, &transformations_, module_cb); py_module_ = WithModuleRegistration(&procedures_, &transformations_, &functions_, module_cb);
if (py_module_) { if (py_module_) {
spdlog::info("Loaded module {}", file_path); spdlog::info("Loaded module {}", file_path);
@ -877,6 +972,7 @@ bool PythonModule::Close() {
auto gil = py::EnsureGIL(); auto gil = py::EnsureGIL();
procedures_.clear(); procedures_.clear();
transformations_.clear(); transformations_.clear();
functions_.clear();
// Delete the module from the `sys.modules` directory so that the module will // Delete the module from the `sys.modules` directory so that the module will
// be properly imported if imported again. // be properly imported if imported again.
py::Object sys(PyImport_ImportModule("sys")); py::Object sys(PyImport_ImportModule("sys"));
@ -906,6 +1002,13 @@ const std::map<std::string, mgp_trans, std::less<>> *PythonModule::Transformatio
"not been loaded..."); "not been loaded...");
return &transformations_; return &transformations_;
} }
const std::map<std::string, mgp_func, std::less<>> *PythonModule::Functions() const {
MG_ASSERT(py_module_,
"Attempting to access functions of a module that has "
"not been loaded...");
return &functions_;
}
namespace { namespace {
std::unique_ptr<Module> LoadModuleFromFile(const std::filesystem::path &path) { std::unique_ptr<Module> LoadModuleFromFile(const std::filesystem::path &path) {
@ -954,6 +1057,7 @@ ModuleRegistry::ModuleRegistry() {
auto module = std::make_unique<BuiltinModule>(); auto module = std::make_unique<BuiltinModule>();
RegisterMgProcedures(&modules_, module.get()); RegisterMgProcedures(&modules_, module.get());
RegisterMgTransformations(&modules_, module.get()); RegisterMgTransformations(&modules_, module.get());
RegisterMgFunctions(&modules_, module.get());
RegisterMgLoad(this, &lock_, module.get()); RegisterMgLoad(this, &lock_, module.get());
RegisterMgGetModuleFiles(this, module.get()); RegisterMgGetModuleFiles(this, module.get());
RegisterMgGetModuleFile(this, module.get()); RegisterMgGetModuleFile(this, module.get());
@ -1083,7 +1187,7 @@ std::optional<std::pair<std::string_view, std::string_view>> FindModuleNameAndPr
} }
template <typename T> template <typename T>
concept ModuleProperties = utils::SameAsAnyOf<T, mgp_proc, mgp_trans>; concept ModuleProperties = utils::SameAsAnyOf<T, mgp_proc, mgp_trans, mgp_func>;
template <ModuleProperties T> template <ModuleProperties T>
std::optional<std::pair<ModulePtr, const T *>> MakePairIfPropFound(const ModuleRegistry &module_registry, std::optional<std::pair<ModulePtr, const T *>> MakePairIfPropFound(const ModuleRegistry &module_registry,
@ -1092,8 +1196,10 @@ std::optional<std::pair<ModulePtr, const T *>> MakePairIfPropFound(const ModuleR
auto prop_fun = [](auto &module) { auto prop_fun = [](auto &module) {
if constexpr (std::is_same_v<T, mgp_proc>) { if constexpr (std::is_same_v<T, mgp_proc>) {
return module->Procedures(); return module->Procedures();
} else { } else if constexpr (std::is_same_v<T, mgp_trans>) {
return module->Transformations(); return module->Transformations();
} else if constexpr (std::is_same_v<T, mgp_func>) {
return module->Functions();
} }
}; };
auto result = FindModuleNameAndProp(module_registry, fully_qualified_name, memory); auto result = FindModuleNameAndProp(module_registry, fully_qualified_name, memory);
@ -1121,4 +1227,10 @@ std::optional<std::pair<ModulePtr, const mgp_trans *>> FindTransformation(
return MakePairIfPropFound<mgp_trans>(module_registry, fully_qualified_transformation_name, memory); return MakePairIfPropFound<mgp_trans>(module_registry, fully_qualified_transformation_name, memory);
} }
std::optional<std::pair<ModulePtr, const mgp_func *>> FindFunction(const ModuleRegistry &module_registry,
std::string_view fully_qualified_function_name,
utils::MemoryResource *memory) {
return MakePairIfPropFound<mgp_func>(module_registry, fully_qualified_function_name, memory);
}
} // namespace memgraph::query::procedure } // namespace memgraph::query::procedure

View File

@ -21,6 +21,7 @@
#include <string_view> #include <string_view>
#include <unordered_map> #include <unordered_map>
#include "query/procedure/cypher_types.hpp"
#include "query/procedure/mg_procedure_impl.hpp" #include "query/procedure/mg_procedure_impl.hpp"
#include "utils/memory.hpp" #include "utils/memory.hpp"
#include "utils/rw_lock.hpp" #include "utils/rw_lock.hpp"
@ -45,6 +46,8 @@ class Module {
virtual const std::map<std::string, mgp_proc, std::less<>> *Procedures() const = 0; virtual const std::map<std::string, mgp_proc, std::less<>> *Procedures() const = 0;
/// Returns registered transformations of this module /// Returns registered transformations of this module
virtual const std::map<std::string, mgp_trans, std::less<>> *Transformations() const = 0; virtual const std::map<std::string, mgp_trans, std::less<>> *Transformations() const = 0;
// /// Returns registered functions of this module
virtual const std::map<std::string, mgp_func, std::less<>> *Functions() const = 0;
virtual std::optional<std::filesystem::path> Path() const = 0; virtual std::optional<std::filesystem::path> Path() const = 0;
}; };
@ -147,4 +150,62 @@ std::optional<std::pair<procedure::ModulePtr, const mgp_proc *>> FindProcedure(
std::optional<std::pair<procedure::ModulePtr, const mgp_trans *>> FindTransformation( std::optional<std::pair<procedure::ModulePtr, const mgp_trans *>> FindTransformation(
const ModuleRegistry &module_registry, const std::string_view fully_qualified_transformation_name, const ModuleRegistry &module_registry, const std::string_view fully_qualified_transformation_name,
utils::MemoryResource *memory); utils::MemoryResource *memory);
/// Return the ModulePtr and `mgp_func *` of the found function after resolving
/// `fully_qualified_function_name` if found. If there is no such function
/// std::nullopt is returned. `memory` is used for temporary allocations
/// inside this function. ModulePtr must be kept alive to make sure it won't be unloaded.
std::optional<std::pair<procedure::ModulePtr, const mgp_func *>> FindFunction(
const ModuleRegistry &module_registry, const std::string_view fully_qualified_function_name,
utils::MemoryResource *memory);
template <typename T>
concept IsCallable = utils::SameAsAnyOf<T, mgp_proc, mgp_func>;
template <IsCallable TCall>
void ConstructArguments(const std::vector<TypedValue> &args, const TCall &callable,
const std::string_view fully_qualified_name, mgp_list &args_list, mgp_graph &graph) {
const auto n_args = args.size();
const auto c_args_sz = callable.args.size();
const auto c_opt_args_sz = callable.opt_args.size();
if (n_args < c_args_sz || (n_args - c_args_sz > c_opt_args_sz)) {
if (callable.args.empty() && callable.opt_args.empty()) {
throw QueryRuntimeException("'{}' requires no arguments.", fully_qualified_name);
}
if (callable.opt_args.empty()) {
throw QueryRuntimeException("'{}' requires exactly {} {}.", fully_qualified_name, c_args_sz,
c_args_sz == 1U ? "argument" : "arguments");
}
throw QueryRuntimeException("'{}' requires between {} and {} arguments.", fully_qualified_name, c_args_sz,
c_args_sz + c_opt_args_sz);
}
args_list.elems.reserve(n_args);
auto is_not_optional_arg = [c_args_sz](int i) { return c_args_sz > i; };
for (size_t i = 0; i < n_args; ++i) {
auto arg = args[i];
std::string_view name;
const query::procedure::CypherType *type;
if (is_not_optional_arg(i)) {
name = callable.args[i].first;
type = callable.args[i].second;
} else {
name = std::get<0>(callable.opt_args[i - c_args_sz]);
type = std::get<1>(callable.opt_args[i - c_args_sz]);
}
if (!type->SatisfiesType(arg)) {
throw QueryRuntimeException("'{}' argument named '{}' at position {} must be of type {}.", fully_qualified_name,
name, i, type->GetPresentableName());
}
args_list.elems.emplace_back(std::move(arg), &graph);
}
// Fill missing optional arguments with their default values.
const size_t passed_in_opt_args = n_args - c_args_sz;
for (size_t i = passed_in_opt_args; i < c_opt_args_sz; ++i) {
args_list.elems.emplace_back(std::get<2>(callable.opt_args[i]), &graph);
}
}
} // namespace memgraph::query::procedure } // namespace memgraph::query::procedure

View File

@ -447,62 +447,94 @@ PyObject *MakePyCypherType(mgp_type *type) {
// clang-format off // clang-format off
struct PyQueryProc { struct PyQueryProc {
PyObject_HEAD PyObject_HEAD
mgp_proc *proc; mgp_proc *callable;
}; };
// clang-format on // clang-format on
PyObject *PyQueryProcAddArg(PyQueryProc *self, PyObject *args) { // clang-format off
MG_ASSERT(self->proc); struct PyMagicFunc{
PyObject_HEAD
mgp_func *callable;
};
// clang-format on
template <typename T>
concept IsCallable = utils::SameAsAnyOf<T, PyQueryProc, PyMagicFunc>;
template <IsCallable TCall>
PyObject *PyCallableAddArg(TCall *self, PyObject *args) {
MG_ASSERT(self->callable);
const char *name = nullptr; const char *name = nullptr;
PyCypherType *py_type = nullptr; PyCypherType *py_type = nullptr;
if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr; if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr;
auto *type = py_type->type; auto *type = py_type->type;
if (RaiseExceptionFromErrorCode(mgp_proc_add_arg(self->proc, name, type))) {
return nullptr; if constexpr (std::is_same_v<TCall, PyQueryProc>) {
if (RaiseExceptionFromErrorCode(mgp_proc_add_arg(self->callable, name, type))) {
return nullptr;
}
} else if constexpr (std::is_same_v<TCall, PyMagicFunc>) {
if (RaiseExceptionFromErrorCode(mgp_func_add_arg(self->callable, name, type))) {
return nullptr;
}
} }
Py_RETURN_NONE; Py_RETURN_NONE;
} }
PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) { template <IsCallable TCall>
MG_ASSERT(self->proc); PyObject *PyCallableAddOptArg(TCall *self, PyObject *args) {
MG_ASSERT(self->callable);
const char *name = nullptr; const char *name = nullptr;
PyCypherType *py_type = nullptr; PyCypherType *py_type = nullptr;
PyObject *py_value = nullptr; PyObject *py_value = nullptr;
if (!PyArg_ParseTuple(args, "sO!O", &name, &PyCypherTypeType, &py_type, &py_value)) return nullptr; if (!PyArg_ParseTuple(args, "sO!O", &name, &PyCypherTypeType, &py_type, &py_value)) return nullptr;
auto *type = py_type->type; auto *type = py_type->type;
mgp_memory memory{self->proc->opt_args.get_allocator().GetMemoryResource()}; mgp_memory memory{self->callable->opt_args.get_allocator().GetMemoryResource()};
mgp_value *value = PyObjectToMgpValueWithPythonExceptions(py_value, &memory); mgp_value *value = PyObjectToMgpValueWithPythonExceptions(py_value, &memory);
if (value == nullptr) { if (value == nullptr) {
return nullptr; return nullptr;
} }
if (RaiseExceptionFromErrorCode(mgp_proc_add_opt_arg(self->proc, name, type, value))) { if constexpr (std::is_same_v<TCall, PyQueryProc>) {
mgp_value_destroy(value); if (RaiseExceptionFromErrorCode(mgp_proc_add_opt_arg(self->callable, name, type, value))) {
return nullptr; mgp_value_destroy(value);
return nullptr;
}
} else if constexpr (std::is_same_v<TCall, PyMagicFunc>) {
if (RaiseExceptionFromErrorCode(mgp_func_add_opt_arg(self->callable, name, type, value))) {
mgp_value_destroy(value);
return nullptr;
}
} }
mgp_value_destroy(value); mgp_value_destroy(value);
Py_RETURN_NONE; Py_RETURN_NONE;
} }
PyObject *PyQueryProcAddArg(PyQueryProc *self, PyObject *args) { return PyCallableAddArg(self, args); }
PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) { return PyCallableAddOptArg(self, args); }
PyObject *PyQueryProcAddResult(PyQueryProc *self, PyObject *args) { PyObject *PyQueryProcAddResult(PyQueryProc *self, PyObject *args) {
MG_ASSERT(self->proc); MG_ASSERT(self->callable);
const char *name = nullptr; const char *name = nullptr;
PyCypherType *py_type = nullptr; PyCypherType *py_type = nullptr;
if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr; if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr;
auto *type = reinterpret_cast<PyCypherType *>(py_type)->type; auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
if (RaiseExceptionFromErrorCode(mgp_proc_add_result(self->proc, name, type))) { if (RaiseExceptionFromErrorCode(mgp_proc_add_result(self->callable, name, type))) {
return nullptr; return nullptr;
} }
Py_RETURN_NONE; Py_RETURN_NONE;
} }
PyObject *PyQueryProcAddDeprecatedResult(PyQueryProc *self, PyObject *args) { PyObject *PyQueryProcAddDeprecatedResult(PyQueryProc *self, PyObject *args) {
MG_ASSERT(self->proc); MG_ASSERT(self->callable);
const char *name = nullptr; const char *name = nullptr;
PyCypherType *py_type = nullptr; PyCypherType *py_type = nullptr;
if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr; if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr;
auto *type = reinterpret_cast<PyCypherType *>(py_type)->type; auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
if (RaiseExceptionFromErrorCode(mgp_proc_add_deprecated_result(self->proc, name, type))) { if (RaiseExceptionFromErrorCode(mgp_proc_add_deprecated_result(self->callable, name, type))) {
return nullptr; return nullptr;
} }
Py_RETURN_NONE; Py_RETURN_NONE;
@ -532,6 +564,33 @@ static PyTypeObject PyQueryProcType = {
}; };
// clang-format on // clang-format on
PyObject *PyMagicFuncAddArg(PyMagicFunc *self, PyObject *args) { return PyCallableAddArg(self, args); }
PyObject *PyMagicFuncAddOptArg(PyMagicFunc *self, PyObject *args) { return PyCallableAddOptArg(self, args); }
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef PyMagicFuncMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
{"add_arg", reinterpret_cast<PyCFunction>(PyMagicFuncAddArg), METH_VARARGS,
"Add a required argument to a function."},
{"add_opt_arg", reinterpret_cast<PyCFunction>(PyMagicFuncAddOptArg), METH_VARARGS,
"Add an optional argument with a default value to a function."},
{nullptr},
};
// clang-format off
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static PyTypeObject PyMagicFuncType = {
PyVarObject_HEAD_INIT(nullptr, 0)
.tp_name = "_mgp.Func",
.tp_basicsize = sizeof(PyMagicFunc),
// NOLINTNEXTLINE(hicpp-signed-bitwise)
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_doc = "Wraps struct mgp_func.",
.tp_methods = PyMagicFuncMethods,
};
// clang-format on
// clang-format off // clang-format off
struct PyQueryModule { struct PyQueryModule {
PyObject_HEAD PyObject_HEAD
@ -796,7 +855,6 @@ py::Object MgpListToPyTuple(mgp_list *list, PyObject *py_graph) {
} }
namespace { namespace {
std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result, py::Object py_record) { std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result, py::Object py_record) {
py::Object py_mgp(PyImport_ImportModule("mgp")); py::Object py_mgp(PyImport_ImportModule("mgp"));
if (!py_mgp) return py::FetchError(); if (!py_mgp) return py::FetchError();
@ -870,6 +928,33 @@ std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(mgp_result *result
return std::nullopt; return std::nullopt;
} }
std::function<void()> PyObjectCleanup(py::Object &py_object) {
return [py_object]() {
// Run `gc.collect` (reference cycle-detection) explicitly, so that we are
// sure the procedure cleaned up everything it held references to. If the
// user stored a reference to one of our `_mgp` instances then the
// internally used `mgp_*` structs will stay unfreed and a memory leak
// will be reported at the end of the query execution.
py::Object gc(PyImport_ImportModule("gc"));
if (!gc) {
LOG_FATAL(py::FetchError().value());
}
if (!gc.CallMethod("collect")) {
LOG_FATAL(py::FetchError().value());
}
// After making sure all references from our side have been cleared,
// invalidate the `_mgp.Graph` object. If the user kept a reference to one
// of our `_mgp` instances then this will prevent them from using those
// objects (whose internal `mgp_*` pointers are now invalid and would cause
// a crash).
if (!py_object.CallMethod("invalidate")) {
LOG_FATAL(py::FetchError().value());
}
};
}
void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *graph, mgp_result *result, void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *graph, mgp_result *result,
mgp_memory *memory) { mgp_memory *memory) {
auto gil = py::EnsureGIL(); auto gil = py::EnsureGIL();
@ -895,31 +980,6 @@ void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *gra
} }
}; };
auto cleanup = [](py::Object py_graph) {
// Run `gc.collect` (reference cycle-detection) explicitly, so that we are
// sure the procedure cleaned up everything it held references to. If the
// user stored a reference to one of our `_mgp` instances then the
// internally used `mgp_*` structs will stay unfreed and a memory leak
// will be reported at the end of the query execution.
py::Object gc(PyImport_ImportModule("gc"));
if (!gc) {
LOG_FATAL(py::FetchError().value());
}
if (!gc.CallMethod("collect")) {
LOG_FATAL(py::FetchError().value());
}
// After making sure all references from our side have been cleared,
// invalidate the `_mgp.Graph` object. If the user kept a reference to one
// of our `_mgp` instances then this will prevent them from using those
// objects (whose internal `mgp_*` pointers are now invalid and would cause
// a crash).
if (!py_graph.CallMethod("invalidate")) {
LOG_FATAL(py::FetchError().value());
}
};
// It is *VERY IMPORTANT* to note that this code takes great care not to keep // It is *VERY IMPORTANT* to note that this code takes great care not to keep
// any extra references to any `_mgp` instances (except for `_mgp.Graph`), so // any extra references to any `_mgp` instances (except for `_mgp.Graph`), so
// as not to introduce extra reference counts and prevent their deallocation. // as not to introduce extra reference counts and prevent their deallocation.
@ -932,14 +992,9 @@ void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *gra
std::optional<std::string> maybe_msg; std::optional<std::string> maybe_msg;
{ {
py::Object py_graph(MakePyGraph(graph, memory)); py::Object py_graph(MakePyGraph(graph, memory));
utils::OnScopeExit clean_up(PyObjectCleanup(py_graph));
if (py_graph) { if (py_graph) {
try { maybe_msg = error_to_msg(call(py_graph));
maybe_msg = error_to_msg(call(py_graph));
cleanup(py_graph);
} catch (...) {
cleanup(py_graph);
throw;
}
} else { } else {
maybe_msg = error_to_msg(py::FetchError()); maybe_msg = error_to_msg(py::FetchError());
} }
@ -972,32 +1027,58 @@ void CallPythonTransformation(const py::Object &py_cb, mgp_messages *msgs, mgp_g
return AddRecordFromPython(result, py_res); return AddRecordFromPython(result, py_res);
}; };
auto cleanup = [](py::Object py_graph, py::Object py_messages) { // It is *VERY IMPORTANT* to note that this code takes great care not to keep
// Run `gc.collect` (reference cycle-detection) explicitly, so that we are // any extra references to any `_mgp` instances (except for `_mgp.Graph`), so
// sure the procedure cleaned up everything it held references to. If the // as not to introduce extra reference counts and prevent their deallocation.
// user stored a reference to one of our `_mgp` instances then the // In particular, the `ExceptionInfo` object has a `traceback` field that
// internally used `mgp_*` structs will stay unfreed and a memory leak // contains references to the Python frames and their arguments, and therefore
// will be reported at the end of the query execution. // our `_mgp` instances as well. Within this code we ensure not to keep the
py::Object gc(PyImport_ImportModule("gc")); // `ExceptionInfo` object alive so that no extra reference counts are
if (!gc) { // introduced. We only fetch the error message and immediately destroy the
LOG_FATAL(py::FetchError().value()); // object.
} std::optional<std::string> maybe_msg;
{
py::Object py_graph(MakePyGraph(graph, memory));
py::Object py_messages(MakePyMessages(msgs, memory));
if (!gc.CallMethod("collect")) { utils::OnScopeExit clean_up_graph(PyObjectCleanup(py_graph));
LOG_FATAL(py::FetchError().value()); utils::OnScopeExit clean_up_messages(PyObjectCleanup(py_messages));
}
// After making sure all references from our side have been cleared, if (py_graph && py_messages) {
// invalidate the `_mgp.Graph` object. If the user kept a reference to one maybe_msg = error_to_msg(call(py_graph, py_messages));
// of our `_mgp` instances then this will prevent them from using those } else {
// objects (whose internal `mgp_*` pointers are now invalid and would cause maybe_msg = error_to_msg(py::FetchError());
// a crash).
if (!py_graph.CallMethod("invalidate")) {
LOG_FATAL(py::FetchError().value());
} }
if (!py_messages.CallMethod("invalidate")) { }
LOG_FATAL(py::FetchError().value());
if (maybe_msg) {
static_cast<void>(mgp_result_set_error_msg(result, maybe_msg->c_str()));
}
}
void CallPythonFunction(const py::Object &py_cb, mgp_list *args, mgp_graph *graph, mgp_func_result *result,
mgp_memory *memory) {
auto gil = py::EnsureGIL();
auto error_to_msg = [](const std::optional<py::ExceptionInfo> &exc_info) -> std::optional<std::string> {
if (!exc_info) return std::nullopt;
// Here we tell the traceback formatter to skip the first line of the
// traceback because that line will always be our wrapper function in our
// internal `mgp.py` file. With that line skipped, the user will always
// get only the relevant traceback that happened in his Python code.
return py::FormatException(*exc_info, /* skip_first_line = */ true);
};
auto call = [&](py::Object py_graph) -> utils::BasicResult<std::optional<py::ExceptionInfo>, mgp_value *> {
py::Object py_args(MgpListToPyTuple(args, py_graph.Ptr()));
if (!py_args) return {py::FetchError()};
auto py_res = py_cb.Call(py_graph, py_args);
if (!py_res) return {py::FetchError()};
mgp_value *ret_val = PyObjectToMgpValueWithPythonExceptions(py_res.Ptr(), memory);
if (ret_val == nullptr) {
return {py::FetchError()};
} }
return ret_val;
}; };
// It is *VERY IMPORTANT* to note that this code takes great care not to keep // It is *VERY IMPORTANT* to note that this code takes great care not to keep
@ -1012,22 +1093,22 @@ void CallPythonTransformation(const py::Object &py_cb, mgp_messages *msgs, mgp_g
std::optional<std::string> maybe_msg; std::optional<std::string> maybe_msg;
{ {
py::Object py_graph(MakePyGraph(graph, memory)); py::Object py_graph(MakePyGraph(graph, memory));
py::Object py_messages(MakePyMessages(msgs, memory)); utils::OnScopeExit clean_up(PyObjectCleanup(py_graph));
if (py_graph && py_messages) { if (py_graph) {
try { auto maybe_result = call(py_graph);
maybe_msg = error_to_msg(call(py_graph, py_messages)); if (!maybe_result.HasError()) {
cleanup(py_graph, py_messages); static_cast<void>(mgp_func_result_set_value(result, maybe_result.GetValue(), memory));
} catch (...) { return;
cleanup(py_graph, py_messages);
throw;
} }
maybe_msg = error_to_msg(maybe_result.GetError());
} else { } else {
maybe_msg = error_to_msg(py::FetchError()); maybe_msg = error_to_msg(py::FetchError());
} }
} }
if (maybe_msg) { if (maybe_msg) {
static_cast<void>(mgp_result_set_error_msg(result, maybe_msg->c_str())); static_cast<void>(
mgp_func_result_set_error_msg(result, maybe_msg->c_str(), memory)); // No error fetching if this fails
} }
} }
@ -1056,9 +1137,9 @@ PyObject *PyQueryModuleAddProcedure(PyQueryModule *self, PyObject *cb, bool is_w
PyErr_SetString(PyExc_ValueError, "Already registered a procedure with the same name."); PyErr_SetString(PyExc_ValueError, "Already registered a procedure with the same name.");
return nullptr; return nullptr;
} }
auto *py_proc = PyObject_New(PyQueryProc, &PyQueryProcType); auto *py_proc = PyObject_New(PyQueryProc, &PyQueryProcType); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
if (!py_proc) return nullptr; if (!py_proc) return nullptr;
py_proc->proc = &proc_it->second; py_proc->callable = &proc_it->second;
return reinterpret_cast<PyObject *>(py_proc); return reinterpret_cast<PyObject *>(py_proc);
} }
} // namespace } // namespace
@ -1100,6 +1181,39 @@ PyObject *PyQueryModuleAddTransformation(PyQueryModule *self, PyObject *cb) {
Py_RETURN_NONE; Py_RETURN_NONE;
} }
PyObject *PyQueryModuleAddFunction(PyQueryModule *self, PyObject *cb) {
MG_ASSERT(self->module);
if (!PyCallable_Check(cb)) {
PyErr_SetString(PyExc_TypeError, "Expected a callable object.");
return nullptr;
}
auto py_cb = py::Object::FromBorrow(cb);
py::Object py_name(py_cb.GetAttr("__name__"));
const auto *name = PyUnicode_AsUTF8(py_name.Ptr());
if (!name) return nullptr;
if (!IsValidIdentifierName(name)) {
PyErr_SetString(PyExc_ValueError, "Function name is not a valid identifier");
return nullptr;
}
auto *memory = self->module->functions.get_allocator().GetMemoryResource();
mgp_func func(
name,
[py_cb](mgp_list *args, mgp_func_context *func_ctx, mgp_func_result *result, mgp_memory *memory) {
auto graph = mgp_graph::NonWritableGraph(*(func_ctx->impl), func_ctx->view);
return CallPythonFunction(py_cb, args, &graph, result, memory);
},
memory);
const auto [func_it, did_insert] = self->module->functions.emplace(name, std::move(func));
if (!did_insert) {
PyErr_SetString(PyExc_ValueError, "Already registered a function with the same name.");
return nullptr;
}
auto *py_func = PyObject_New(PyMagicFunc, &PyMagicFuncType); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
if (!py_func) return nullptr;
py_func->callable = &func_it->second;
return reinterpret_cast<PyObject *>(py_func);
}
static PyMethodDef PyQueryModuleMethods[] = { static PyMethodDef PyQueryModuleMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, {"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
{"add_read_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddReadProcedure), METH_O, {"add_read_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddReadProcedure), METH_O,
@ -1108,6 +1222,8 @@ static PyMethodDef PyQueryModuleMethods[] = {
"Register a writeable procedure with this module."}, "Register a writeable procedure with this module."},
{"add_transformation", reinterpret_cast<PyCFunction>(PyQueryModuleAddTransformation), METH_O, {"add_transformation", reinterpret_cast<PyCFunction>(PyQueryModuleAddTransformation), METH_O,
"Register a transformation with this module."}, "Register a transformation with this module."},
{"add_function", reinterpret_cast<PyCFunction>(PyQueryModuleAddFunction), METH_O,
"Register a function with this module."},
{nullptr}, {nullptr},
}; };
@ -1980,6 +2096,7 @@ PyObject *PyInitMgpModule() {
if (!register_type(&PyGraphType, "Graph")) return nullptr; if (!register_type(&PyGraphType, "Graph")) return nullptr;
if (!register_type(&PyEdgeType, "Edge")) return nullptr; if (!register_type(&PyEdgeType, "Edge")) return nullptr;
if (!register_type(&PyQueryProcType, "Proc")) return nullptr; if (!register_type(&PyQueryProcType, "Proc")) return nullptr;
if (!register_type(&PyMagicFuncType, "Func")) return nullptr;
if (!register_type(&PyQueryModuleType, "Module")) return nullptr; if (!register_type(&PyQueryModuleType, "Module")) return nullptr;
if (!register_type(&PyVertexType, "Vertex")) return nullptr; if (!register_type(&PyVertexType, "Vertex")) return nullptr;
if (!register_type(&PyPathType, "Path")) return nullptr; if (!register_type(&PyPathType, "Path")) return nullptr;

View File

@ -6,6 +6,14 @@ add_custom_target(memgraph__e2e__${TARGET_PREFIX}__${FILE_NAME} ALL
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME}) DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME})
endfunction() endfunction()
function(copy_e2e_cpp_files TARGET_PREFIX FILE_NAME)
add_custom_target(memgraph__e2e__${TARGET_PREFIX}__${FILE_NAME} ALL
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME}
${CMAKE_CURRENT_BINARY_DIR}/${FILE_NAME}
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME})
endfunction()
add_subdirectory(replication) add_subdirectory(replication)
add_subdirectory(memory) add_subdirectory(memory)
add_subdirectory(triggers) add_subdirectory(triggers)
@ -13,6 +21,7 @@ add_subdirectory(isolation_levels)
add_subdirectory(streams) add_subdirectory(streams)
add_subdirectory(temporal_types) add_subdirectory(temporal_types)
add_subdirectory(write_procedures) add_subdirectory(write_procedures)
add_subdirectory(magic_functions)
add_subdirectory(module_file_manager) add_subdirectory(module_file_manager)
add_subdirectory(websocket) add_subdirectory(websocket)

View File

@ -0,0 +1,17 @@
# Set up C++ functions for e2e tests
function(add_query_module target_name src)
add_library(${target_name} SHARED ${src})
SET_TARGET_PROPERTIES(${target_name} PROPERTIES PREFIX "")
target_include_directories(${target_name} PRIVATE ${CMAKE_SOURCE_DIR}/include)
endfunction()
# Set up Python functions for e2e tests
function(copy_magic_functions_e2e_python_files FILE_NAME)
copy_e2e_python_files(functions ${FILE_NAME})
endfunction()
copy_magic_functions_e2e_python_files(common.py)
copy_magic_functions_e2e_python_files(conftest.py)
copy_magic_functions_e2e_python_files(function_example.py)
add_subdirectory(functions)

View File

@ -0,0 +1,35 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import mgclient
import typing
def execute_and_fetch_all(
cursor: mgclient.Cursor, query: str, params: dict = {}
) -> typing.List[tuple]:
cursor.execute(query, params)
return cursor.fetchall()
def connect(**kwargs) -> mgclient.Connection:
connection = mgclient.connect(host="localhost", port=7687, **kwargs)
connection.autocommit = True
return connection
def has_n_result_row(cursor: mgclient.Cursor, query: str, n: int):
results = execute_and_fetch_all(cursor, query)
return len(results) == n
def has_one_result_row(cursor: mgclient.Cursor, query: str):
return has_n_result_row(cursor, query, 1)

View File

@ -0,0 +1,22 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import pytest
from common import execute_and_fetch_all, connect
@pytest.fixture(autouse=True)
def connection():
connection = connect()
yield connection
cursor = connection.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")

View File

@ -0,0 +1,122 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import typing
import mgclient
import sys
import pytest
from common import execute_and_fetch_all, has_n_result_row
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_return_argument(connection, function_type):
cursor = connection.cursor()
execute_and_fetch_all(cursor, "CREATE (n:Label {id: 1});")
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 1)
result = execute_and_fetch_all(
cursor,
f"MATCH (n) RETURN {function_type}_read.return_function_argument(n) AS argument;",
)
vertex = result[0][0]
assert isinstance(vertex, mgclient.Node)
assert vertex.labels == set(["Label"])
assert vertex.properties == {"id": 1}
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_return_optional_argument(connection, function_type):
cursor = connection.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
result = execute_and_fetch_all(
cursor,
f"RETURN {function_type}_read.return_optional_argument(42) AS argument;",
)
result = result[0][0]
assert isinstance(result, int)
assert result == 42
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_return_optional_argument_no_arg(connection, function_type):
cursor = connection.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
result = execute_and_fetch_all(
cursor,
f"RETURN {function_type}_read.return_optional_argument() AS argument;",
)
result = result[0][0]
assert isinstance(result, int)
assert result == 42
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_add_two_numbers(connection, function_type):
cursor = connection.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
result = execute_and_fetch_all(
cursor,
f"RETURN {function_type}_read.add_two_numbers(1, 5) AS total;",
)
result_sum = result[0][0]
assert isinstance(result_sum, (float, int))
assert result_sum == 6
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_return_null(connection, function_type):
cursor = connection.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
result = execute_and_fetch_all(
cursor,
f"RETURN {function_type}_read.return_null() AS null;",
)
result_null = result[0][0]
assert result_null is None
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_too_many_arguments(connection, function_type):
cursor = connection.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
# Should raise too many arguments
with pytest.raises(mgclient.DatabaseError):
execute_and_fetch_all(
cursor,
f"RETURN {function_type}_read.return_null('parameter') AS null;",
)
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_try_to_write(connection, function_type):
cursor = connection.cursor()
execute_and_fetch_all(cursor, "CREATE (n:Label {id: 1});")
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 1)
# Should raise non mutable
with pytest.raises(mgclient.DatabaseError):
execute_and_fetch_all(
cursor,
f"MATCH (n) RETURN {function_type}_write.try_to_write(n, 'property', 1);",
)
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_case_sensitivity(connection, function_type):
cursor = connection.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
# Should raise function does not exist
with pytest.raises(mgclient.DatabaseError):
execute_and_fetch_all(
cursor,
f"RETURN {function_type}_read.ReTuRn_nUlL('parameter') AS null;",
)
if __name__ == "__main__":
sys.exit(pytest.main([__file__, "-rA"]))

View File

@ -0,0 +1,5 @@
copy_magic_functions_e2e_python_files(py_write.py)
copy_magic_functions_e2e_python_files(py_read.py)
add_query_module(c_read c_read.cpp)
add_query_module(c_write c_write.cpp)

View File

@ -0,0 +1,178 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <functional>
#include <stdexcept>
#include "mg_procedure.h"
#include "utils/on_scope_exit.hpp"
namespace {
static void ReturnFunctionArgument(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
struct mgp_memory *memory) {
mgp_value *value{nullptr};
auto err_code = mgp_list_at(args, 0, &value);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory);
return;
}
err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
return;
}
}
static void ReturnOptionalArgument(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
struct mgp_memory *memory) {
mgp_value *value{nullptr};
auto err_code = mgp_list_at(args, 0, &value);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory);
return;
}
err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
return;
}
}
double GetElementFromArg(struct mgp_list *args, int index) {
mgp_value *value{nullptr};
if (mgp_list_at(args, index, &value) != MGP_ERROR_NO_ERROR) {
throw std::runtime_error("Error while argument fetching.");
}
double result;
int is_int;
mgp_value_is_int(value, &is_int);
if (is_int) {
int64_t result_int;
mgp_value_get_int(value, &result_int);
result = static_cast<double>(result_int);
} else {
mgp_value_get_double(value, &result);
}
return result;
}
static void AddTwoNumbers(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
struct mgp_memory *memory) {
double first = 0;
double second = 0;
try {
first = GetElementFromArg(args, 0);
second = GetElementFromArg(args, 1);
} catch (...) {
mgp_func_result_set_error_msg(result, "Unable to fetch the result!", memory);
return;
}
mgp_value *value{nullptr};
auto summation = first + second;
mgp_value_make_double(summation, memory, &value);
memgraph::utils::OnScopeExit delete_summation_value([&value] { mgp_value_destroy(value); });
auto err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
}
}
static void ReturnNull(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
struct mgp_memory *memory) {
mgp_value *value{nullptr};
mgp_value_make_null(memory, &value);
memgraph::utils::OnScopeExit delete_null([&value] { mgp_value_destroy(value); });
auto err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory);
}
}
} // namespace
// Each module needs to define mgp_init_module function.
// Here you can register multiple functions/procedures your module supports.
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
{
mgp_func *func{nullptr};
auto err_code = mgp_module_add_function(module, "return_function_argument", ReturnFunctionArgument, &func);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
mgp_type *type_any{nullptr};
mgp_type_any(&type_any);
err_code = mgp_func_add_arg(func, "argument", type_any);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
}
{
mgp_func *func{nullptr};
auto err_code = mgp_module_add_function(module, "return_optional_argument", ReturnOptionalArgument, &func);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
mgp_value *default_value{nullptr};
mgp_value_make_int(42, memory, &default_value);
memgraph::utils::OnScopeExit delete_summation_value([&default_value] { mgp_value_destroy(default_value); });
mgp_type *type_int{nullptr};
mgp_type_int(&type_int);
err_code = mgp_func_add_opt_arg(func, "opt_argument", type_int, default_value);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
}
{
mgp_func *func{nullptr};
auto err_code = mgp_module_add_function(module, "add_two_numbers", AddTwoNumbers, &func);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
mgp_type *type_number{nullptr};
mgp_type_number(&type_number);
err_code = mgp_func_add_arg(func, "first", type_number);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
err_code = mgp_func_add_arg(func, "second", type_number);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
}
{
mgp_func *func{nullptr};
auto err_code = mgp_module_add_function(module, "return_null", ReturnNull, &func);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
}
return 0;
}
// This is an optional function if you need to release any resources before the
// module is unloaded. You will probably need this if you acquired some
// resources in mgp_init_module.
extern "C" int mgp_shutdown_module() { return 0; }

View File

@ -0,0 +1,80 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "mg_procedure.h"
static void TryToWrite(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
struct mgp_memory *memory) {
mgp_value *value{nullptr};
mgp_vertex *vertex{nullptr};
mgp_list_at(args, 0, &value);
mgp_value_get_vertex(value, &vertex);
const char *name;
mgp_list_at(args, 1, &value);
mgp_value_get_string(value, &name);
mgp_list_at(args, 2, &value);
// Setting a property should set an error
auto err_code = mgp_vertex_set_property(vertex, name, value);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Cannot set property in the function!", memory);
return;
}
err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
return;
}
}
// Each module needs to define mgp_init_module function.
// Here you can register multiple functions/procedures your module supports.
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
{
mgp_func *func{nullptr};
auto err_code = mgp_module_add_function(module, "try_to_write", TryToWrite, &func);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
mgp_type *type_vertex{nullptr};
mgp_type_node(&type_vertex);
err_code = mgp_func_add_arg(func, "argument", type_vertex);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
mgp_type *type_string{nullptr};
mgp_type_string(&type_string);
err_code = mgp_func_add_arg(func, "name", type_string);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
mgp_type *any_type{nullptr};
mgp_type_any(&any_type);
mgp_type *nullable_type{nullptr};
mgp_type_nullable(any_type, &nullable_type);
err_code = mgp_func_add_arg(func, "value", nullable_type);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
}
return 0;
}
// This is an optional function if you need to release any resources before the
// module is unloaded. You will probably need this if you acquired some
// resources in mgp_init_module.
extern "C" int mgp_shutdown_module() { return 0; }

View File

@ -0,0 +1,32 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import mgp
@mgp.function
def return_function_argument(ctx: mgp.FuncCtx, argument: mgp.Any):
return argument
@mgp.function
def return_optional_argument(ctx: mgp.FuncCtx, opt_argument: mgp.Number = 42):
return opt_argument
@mgp.function
def add_two_numbers(ctx: mgp.FuncCtx, first: mgp.Number, second: mgp.Number):
return first + second
@mgp.function
def return_null(ctx: mgp.FuncCtx):
return None

View File

@ -0,0 +1,17 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import mgp
@mgp.function
def try_to_write(ctx: mgp.FuncCtx, argument: mgp.Vertex, name: str, value: mgp.Nullable[mgp.Any]):
argument.properties.set(name, value)

View File

@ -0,0 +1,14 @@
template_cluster: &template_cluster
cluster:
main:
args: ["--bolt-port", "7687", "--log-level=TRACE"]
log_file: "magic-functions-e2e.log"
setup_queries: []
validation_queries: []
workloads:
- name: "Magic functions runner"
binary: "tests/e2e/pytest_runner.sh"
proc: "tests/e2e/magic_functions/functions/"
args: ["magic_functions/function_example.py"]
<<: *template_cluster

View File

@ -129,6 +129,12 @@ target_link_libraries(${test_prefix}query_serialization_property_value mg-query)
add_unit_test(query_streams.cpp) add_unit_test(query_streams.cpp)
target_link_libraries(${test_prefix}query_streams mg-query kafka-mock) target_link_libraries(${test_prefix}query_streams mg-query kafka-mock)
# Test query functions
add_unit_test(query_function_mgp_module.cpp)
target_link_libraries(${test_prefix}query_function_mgp_module mg-query)
target_include_directories(${test_prefix}query_function_mgp_module PRIVATE ${CMAKE_SOURCE_DIR}/include)
# Test query/procedure # Test query/procedure
add_unit_test(query_procedure_mgp_type.cpp) add_unit_test(query_procedure_mgp_type.cpp)
target_link_libraries(${test_prefix}query_procedure_mgp_type mg-query) target_link_libraries(${test_prefix}query_procedure_mgp_type mg-query)

View File

@ -211,13 +211,19 @@ class MockModule : public procedure::Module {
const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override { return &procedures; } const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override { return &procedures; }
const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override { return &transformations; } const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override { return &transformations; }
std::optional<std::filesystem::path> Path() const override { return std::nullopt; }
const std::map<std::string, mgp_func, std::less<>> *Functions() const override { return &functions; }
std::optional<std::filesystem::path> Path() const override { return std::nullopt; };
std::map<std::string, mgp_proc, std::less<>> procedures{}; std::map<std::string, mgp_proc, std::less<>> procedures{};
std::map<std::string, mgp_trans, std::less<>> transformations{}; std::map<std::string, mgp_trans, std::less<>> transformations{};
std::map<std::string, mgp_func, std::less<>> functions{};
}; };
void DummyProcCallback(mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result * /*result*/, mgp_memory * /*memory*/){}; void DummyProcCallback(mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result * /*result*/, mgp_memory * /*memory*/){};
void DummyFuncCallback(mgp_list * /*args*/, mgp_func_context * /*func_ctx*/, mgp_func_result * /*result*/,
mgp_memory * /*memory*/){};
enum class ProcedureType { WRITE, READ }; enum class ProcedureType { WRITE, READ };
@ -258,6 +264,15 @@ class CypherMainVisitorTest : public ::testing::TestWithParam<std::shared_ptr<Ba
module.procedures.emplace(name, std::move(proc)); module.procedures.emplace(name, std::move(proc));
} }
static void AddFunc(MockModule &module, const char *name, const std::vector<std::string_view> &args) {
memgraph::utils::MemoryResource *memory = memgraph::utils::NewDeleteResource();
mgp_func func(name, DummyFuncCallback, memory);
for (const auto arg : args) {
func.args.emplace_back(memgraph::utils::pmr::string{arg, memory}, &any_type);
}
module.functions.emplace(name, std::move(func));
}
std::string CreateProcByType(const ProcedureType type, const std::vector<std::string_view> &args) { std::string CreateProcByType(const ProcedureType type, const std::vector<std::string_view> &args) {
const auto proc_name = std::string{"proc_"} + ToString(type); const auto proc_name = std::string{"proc_"} + ToString(type);
SCOPED_TRACE(proc_name); SCOPED_TRACE(proc_name);
@ -858,6 +873,12 @@ TEST_P(CypherMainVisitorTest, UndefinedFunction) {
SemanticException); SemanticException);
} }
TEST_P(CypherMainVisitorTest, MissingFunction) {
AddFunc(*mock_module, "get", {});
auto &ast_generator = *GetParam();
ASSERT_THROW(ast_generator.ParseQuery("RETURN missing_function.get()"), SemanticException);
}
TEST_P(CypherMainVisitorTest, Function) { TEST_P(CypherMainVisitorTest, Function) {
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN abs(n, 2)")); auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN abs(n, 2)"));
@ -871,6 +892,20 @@ TEST_P(CypherMainVisitorTest, Function) {
ASSERT_TRUE(function->function_); ASSERT_TRUE(function->function_);
} }
TEST_P(CypherMainVisitorTest, MagicFunction) {
AddFunc(*mock_module, "get", {});
auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN mock_module.get()"));
ASSERT_TRUE(query);
ASSERT_TRUE(query->single_query_);
auto *single_query = query->single_query_;
auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]);
ASSERT_EQ(return_clause->body_.named_expressions.size(), 1);
auto *function = dynamic_cast<Function *>(return_clause->body_.named_expressions[0]->expression_);
ASSERT_TRUE(function);
ASSERT_TRUE(function->function_);
}
TEST_P(CypherMainVisitorTest, StringLiteralDoubleQuotes) { TEST_P(CypherMainVisitorTest, StringLiteralDoubleQuotes) {
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN \"mi'rko\"")); auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN \"mi'rko\""));

View File

@ -0,0 +1,49 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <gtest/gtest.h>
#include <functional>
#include <sstream>
#include <string_view>
#include "query/procedure/mg_procedure_impl.hpp"
#include "test_utils.hpp"
static void DummyCallback(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *){};
TEST(Module, InvalidFunctionRegistration) {
mgp_module module(memgraph::utils::NewDeleteResource());
mgp_func *func{nullptr};
// Other test cases are covered within the procedure API. This is only sanity check
EXPECT_EQ(mgp_module_add_function(&module, "dashes-not-supported", DummyCallback, &func), MGP_ERROR_INVALID_ARGUMENT);
}
TEST(Module, RegisterSameFunctionMultipleTimes) {
mgp_module module(memgraph::utils::NewDeleteResource());
mgp_func *func{nullptr};
EXPECT_EQ(module.functions.find("same_name"), module.functions.end());
EXPECT_EQ(mgp_module_add_function(&module, "same_name", DummyCallback, &func), MGP_ERROR_NO_ERROR);
EXPECT_NE(module.functions.find("same_name"), module.functions.end());
EXPECT_EQ(mgp_module_add_function(&module, "same_name", DummyCallback, &func), MGP_ERROR_LOGIC_ERROR);
EXPECT_EQ(mgp_module_add_function(&module, "same_name", DummyCallback, &func), MGP_ERROR_LOGIC_ERROR);
EXPECT_NE(module.functions.find("same_name"), module.functions.end());
}
TEST(Module, CaseSensitiveFunctionNames) {
mgp_module module(memgraph::utils::NewDeleteResource());
mgp_func *func{nullptr};
EXPECT_EQ(mgp_module_add_function(&module, "not_same", DummyCallback, &func), MGP_ERROR_NO_ERROR);
EXPECT_EQ(mgp_module_add_function(&module, "NoT_saME", DummyCallback, &func), MGP_ERROR_NO_ERROR);
EXPECT_EQ(mgp_module_add_function(&module, "NOT_SAME", DummyCallback, &func), MGP_ERROR_NO_ERROR);
EXPECT_EQ(module.functions.size(), 3U);
}