Memgraph magic functions (#345)

* Extend mgp_module with include adding functions

* Add return type to the function API

* Change Cypher grammar

* Add Python support for functions

* Implement error handling

* E2e tests for functions

* Write cpp e2e functions

* Create mg.functions() procedure

* Implement case insensitivity for user-defined Magic Functions.
This commit is contained in:
Josip Matak 2022-04-21 15:45:31 +02:00 committed by GitHub
parent ea2806bd57
commit 4abaf27765
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 1523 additions and 246 deletions

View File

@ -1,4 +1,4 @@
// Copyright 2021 Memgraph Ltd.
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -549,6 +549,8 @@ enum mgp_error mgp_path_equal(struct mgp_path *p1, struct mgp_path *p2, int *res
struct mgp_result;
/// Represents a record of resulting field values.
struct mgp_result_record;
/// Represents a return type for magic functions
struct mgp_func_result;
/// Set the error as the result of the procedure.
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE ff there's no memory for copying the error message.
@ -1290,6 +1292,9 @@ struct mgp_module;
/// Describes a procedure of a query module.
struct mgp_proc;
/// Describes a Memgraph magic function.
struct mgp_func;
/// Entry-point for a query module read procedure, invoked through openCypher.
///
/// Passed in arguments will not live longer than the callback's execution.
@ -1502,6 +1507,84 @@ typedef void (*mgp_trans_cb)(struct mgp_messages *, struct mgp_graph *, struct m
enum mgp_error mgp_module_add_transformation(struct mgp_module *module, const char *name, mgp_trans_cb cb);
/// @}
/// @name Memgraph Magic Functions API
///
/// API for creating the Memgraph magic functions. It is used to create external-source stateless methods which can
/// be called by using openCypher query language. These methods should not modify the original graph and should use only
/// the values provided as arguments to the method.
///
///@{
/// Add a required argument to a function.
///
/// The order of the added arguments corresponds to the signature of the openCypher function.
/// Note, that required arguments are followed by optional arguments.
///
/// The `name` must be a valid identifier, following the same rules as the
/// function `name` in mgp_module_add_function.
///
/// Passed in `type` describes what kind of values can be used as the argument.
///
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for an argument.
/// Return MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid argument name.
/// Return MGP_ERROR_LOGIC_ERROR if the function already has any optional argument.
enum mgp_error mgp_func_add_arg(struct mgp_func *func, const char *name, struct mgp_type *type);
/// Add an optional argument with a default value to a function.
///
/// The order of the added arguments corresponds to the signature of the openCypher function.
/// Note, that required arguments are followed by optional arguments.
///
/// The `name` must be a valid identifier, following the same rules as the
/// function `name` in mgp_module_add_function.
///
/// Passed in `type` describes what kind of values can be used as the argument.
///
/// `default_value` is copied and set as the default value for the argument.
/// Don't forget to call mgp_value_destroy when you are done using
/// `default_value`. When the function is called, if this argument is not
/// provided, `default_value` will be used instead. `default_value` must not be
/// a graph element (node, relationship, path) and it must satisfy the given
/// `type`.
///
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for an argument.
/// Return MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid argument name.
/// Return MGP_ERROR_VALUE_CONVERSION if `default_value` is a graph element (vertex, edge or path).
/// Return MGP_ERROR_LOGIC_ERROR if `default_value` does not satisfy `type`.
enum mgp_error mgp_func_add_opt_arg(struct mgp_func *func, const char *name, struct mgp_type *type,
struct mgp_value *default_value);
/// Entry-point for a custom Memgraph awesome function.
///
/// Passed in arguments will not live longer than the callback's execution.
/// Therefore, you must not store them globally or use the passed in mgp_memory
/// to allocate global resources.
typedef void (*mgp_func_cb)(struct mgp_list *, struct mgp_func_context *, struct mgp_func_result *,
struct mgp_memory *);
/// Register a Memgraph magic function
///
/// The `name` must be a sequence of digits, underscores, lowercase and
/// uppercase Latin letters. The name must begin with a non-digit character.
/// Note that Unicode characters are not allowed.
///
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for mgp_func.
/// Return MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid function name.
/// RETURN MGP_ERROR_LOGIC_ERROR if a function with the same name was already registered.
enum mgp_error mgp_module_add_function(struct mgp_module *module, const char *name, mgp_func_cb cb,
struct mgp_func **result);
/// Set an error message as an output to the Magic function
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if there's no memory for copying the error message.
enum mgp_error mgp_func_result_set_error_msg(struct mgp_func_result *result, const char *error_msg,
struct mgp_memory *memory);
/// Set an output value for the Magic function
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory to copy the mgp_value to mgp_func_result.
enum mgp_error mgp_func_result_set_value(struct mgp_func_result *result, struct mgp_value *value,
struct mgp_memory *memory);
/// @}
#ifdef __cplusplus
} // extern "C"
#endif

View File

@ -40,6 +40,7 @@ class InvalidContextError(Exception):
"""
Signals using a graph element instance outside of the registered procedure.
"""
pass
@ -47,6 +48,7 @@ class UnknownError(_mgp.UnknownError):
"""
Signals unspecified failure.
"""
pass
@ -54,6 +56,7 @@ class UnableToAllocateError(_mgp.UnableToAllocateError):
"""
Signals failed memory allocation.
"""
pass
@ -61,6 +64,7 @@ class InsufficientBufferError(_mgp.InsufficientBufferError):
"""
Signals that some buffer is not big enough.
"""
pass
@ -69,6 +73,7 @@ class OutOfRangeError(_mgp.OutOfRangeError):
Signals that an index-like parameter has a value that is outside its
possible values.
"""
pass
@ -77,6 +82,7 @@ class LogicErrorError(_mgp.LogicErrorError):
Signals faulty logic within the program such as violating logical
preconditions or class invariants and may be preventable.
"""
pass
@ -84,6 +90,7 @@ class DeletedObjectError(_mgp.DeletedObjectError):
"""
Signals accessing an already deleted object.
"""
pass
@ -91,6 +98,7 @@ class InvalidArgumentError(_mgp.InvalidArgumentError):
"""
Signals that some of the arguments have invalid values.
"""
pass
@ -98,6 +106,7 @@ class KeyAlreadyExistsError(_mgp.KeyAlreadyExistsError):
"""
Signals that a key already exists in a container-like object.
"""
pass
@ -105,6 +114,7 @@ class ImmutableObjectError(_mgp.ImmutableObjectError):
"""
Signals modification of an immutable object.
"""
pass
@ -112,6 +122,7 @@ class ValueConversionError(_mgp.ValueConversionError):
"""
Signals that the conversion failed between python and cypher values.
"""
pass
@ -120,12 +131,14 @@ class SerializationError(_mgp.SerializationError):
Signals serialization error caused by concurrent modifications from
different transactions.
"""
pass
class Label:
"""Label of a Vertex."""
__slots__ = ('_name',)
__slots__ = ("_name",)
def __init__(self, name: str):
self._name = name
@ -145,19 +158,22 @@ class Label:
# Named property value of a Vertex or an Edge.
# It would be better to use typing.NamedTuple with typed fields, but that is
# not available in Python 3.5.
Property = namedtuple('Property', ('name', 'value'))
Property = namedtuple("Property", ("name", "value"))
class Properties:
"""
A collection of properties either on a Vertex or an Edge.
"""
__slots__ = ('_vertex_or_edge', '_len',)
__slots__ = (
"_vertex_or_edge",
"_len",
)
def __init__(self, vertex_or_edge):
if not isinstance(vertex_or_edge, (_mgp.Vertex, _mgp.Edge)):
raise TypeError("Expected '_mgp.Vertex' or '_mgp.Edge', \
got {}".format(type(vertex_or_edge)))
raise TypeError("Expected '_mgp.Vertex' or '_mgp.Edge', got {}".format(type(vertex_or_edge)))
self._len = None
self._vertex_or_edge = vertex_or_edge
@ -330,7 +346,8 @@ class Properties:
class EdgeType:
"""Type of an Edge."""
__slots__ = ('_name',)
__slots__ = ("_name",)
def __init__(self, name):
self._name = name
@ -348,7 +365,7 @@ class EdgeType:
if sys.version_info >= (3, 5, 2):
EdgeId = typing.NewType('EdgeId', int)
EdgeId = typing.NewType("EdgeId", int)
else:
EdgeId = int
@ -360,12 +377,12 @@ class Edge:
a query. You should not globally store an instance of an Edge. Using an
invalid Edge instance will raise InvalidContextError.
"""
__slots__ = ('_edge',)
__slots__ = ("_edge",)
def __init__(self, edge):
if not isinstance(edge, _mgp.Edge):
raise TypeError(
"Expected '_mgp.Edge', got '{}'".format(type(edge)))
raise TypeError("Expected '_mgp.Edge', got '{}'".format(type(edge)))
self._edge = edge
def __deepcopy__(self, memo):
@ -408,7 +425,7 @@ class Edge:
return EdgeType(self._edge.get_type_name())
@property
def from_vertex(self) -> 'Vertex':
def from_vertex(self) -> "Vertex":
"""
Get the source vertex.
@ -419,7 +436,7 @@ class Edge:
return Vertex(self._edge.from_vertex())
@property
def to_vertex(self) -> 'Vertex':
def to_vertex(self) -> "Vertex":
"""
Get the destination vertex.
@ -453,7 +470,7 @@ class Edge:
if sys.version_info >= (3, 5, 2):
VertexId = typing.NewType('VertexId', int)
VertexId = typing.NewType("VertexId", int)
else:
VertexId = int
@ -465,12 +482,12 @@ class Vertex:
in a query. You should not globally store an instance of a Vertex. Using an
invalid Vertex instance will raise InvalidContextError.
"""
__slots__ = ('_vertex',)
__slots__ = ("_vertex",)
def __init__(self, vertex):
if not isinstance(vertex, _mgp.Vertex):
raise TypeError(
"Expected '_mgp.Vertex', got '{}'".format(type(vertex)))
raise TypeError("Expected '_mgp.Vertex', got '{}'".format(type(vertex)))
self._vertex = vertex
def __deepcopy__(self, memo):
@ -513,8 +530,7 @@ class Vertex:
"""
if not self.is_valid():
raise InvalidContextError()
return tuple(Label(self._vertex.label_at(i))
for i in range(self._vertex.labels_count()))
return tuple(Label(self._vertex.label_at(i)) for i in range(self._vertex.labels_count()))
def add_label(self, label: str) -> None:
"""
@ -615,7 +631,8 @@ class Vertex:
class Path:
"""Path containing Vertex and Edge instances."""
__slots__ = ('_path', '_vertices', '_edges')
__slots__ = ("_path", "_vertices", "_edges")
def __init__(self, starting_vertex_or_path: typing.Union[_mgp.Path, Vertex]):
"""Initialize with a starting Vertex.
@ -636,8 +653,7 @@ class Path:
raise InvalidContextError()
self._path = _mgp.Path.make_with_start(vertex)
else:
raise TypeError("Expected '_mgp.Vertex' or '_mgp.Path', got '{}'"
.format(type(starting_vertex_or_path)))
raise TypeError("Expected '_mgp.Vertex' or '_mgp.Path', got '{}'".format(type(starting_vertex_or_path)))
def __copy__(self):
if not self.is_valid():
@ -678,8 +694,7 @@ class Path:
extension.
"""
if not isinstance(edge, Edge):
raise TypeError(
"Expected '_mgp.Edge', got '{}'".format(type(edge)))
raise TypeError("Expected '_mgp.Edge', got '{}'".format(type(edge)))
if not self.is_valid() or not edge.is_valid():
raise InvalidContextError()
self._path.expand(edge._edge)
@ -698,8 +713,7 @@ class Path:
raise InvalidContextError()
if self._vertices is None:
num_vertices = self._path.size() + 1
self._vertices = tuple(Vertex(self._path.vertex_at(i))
for i in range(num_vertices))
self._vertices = tuple(Vertex(self._path.vertex_at(i)) for i in range(num_vertices))
return self._vertices
@property
@ -713,14 +727,14 @@ class Path:
raise InvalidContextError()
if self._edges is None:
num_edges = self._path.size()
self._edges = tuple(Edge(self._path.edge_at(i))
for i in range(num_edges))
self._edges = tuple(Edge(self._path.edge_at(i)) for i in range(num_edges))
return self._edges
class Record:
"""Represents a record of resulting field values."""
__slots__ = ('fields',)
__slots__ = ("fields",)
def __init__(self, **kwargs):
"""Initialize with name=value fields in kwargs."""
@ -729,12 +743,12 @@ class Record:
class Vertices:
"""Iterable over vertices in a graph."""
__slots__ = ('_graph', '_len')
__slots__ = ("_graph", "_len")
def __init__(self, graph):
if not isinstance(graph, _mgp.Graph):
raise TypeError(
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = graph
self._len = None
@ -791,12 +805,12 @@ class Vertices:
class Graph:
"""State of the graph database in current ProcCtx."""
__slots__ = ('_graph',)
__slots__ = ("_graph",)
def __init__(self, graph):
if not isinstance(graph, _mgp.Graph):
raise TypeError(
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = graph
def __deepcopy__(self, memo):
@ -885,8 +899,7 @@ class Graph:
raise InvalidContextError()
self._graph.detach_delete_vertex(vertex._vertex)
def create_edge(self, from_vertex: Vertex, to_vertex: Vertex,
edge_type: EdgeType) -> None:
def create_edge(self, from_vertex: Vertex, to_vertex: Vertex, edge_type: EdgeType) -> None:
"""
Create an edge.
@ -899,8 +912,7 @@ class Graph:
"""
if not self.is_valid():
raise InvalidContextError()
return Edge(self._graph.create_edge(from_vertex._vertex,
to_vertex._vertex, edge_type.name))
return Edge(self._graph.create_edge(from_vertex._vertex, to_vertex._vertex, edge_type.name))
def delete_edge(self, edge: Edge) -> None:
"""
@ -918,6 +930,7 @@ class Graph:
class AbortError(Exception):
"""Signals that the procedure was asked to abort its execution."""
pass
@ -927,12 +940,12 @@ class ProcCtx:
Access to a ProcCtx is only valid during a single execution of a procedure
in a query. You should not globally store a ProcCtx instance.
"""
__slots__ = ('_graph',)
__slots__ = ("_graph",)
def __init__(self, graph):
if not isinstance(graph, _mgp.Graph):
raise TypeError(
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = Graph(graph)
def is_valid(self) -> bool:
@ -969,8 +982,7 @@ LocalDateTime = datetime.datetime
Duration = datetime.timedelta
Any = typing.Union[bool, str, Number, Map, Path,
list, Date, LocalTime, LocalDateTime, Duration]
Any = typing.Union[bool, str, Number, Map, Path, list, Date, LocalTime, LocalDateTime, Duration]
List = typing.List
@ -1003,7 +1015,7 @@ def _typing_to_cypher_type(type_):
Date: _mgp.type_date(),
LocalTime: _mgp.type_local_time(),
LocalDateTime: _mgp.type_local_date_time(),
Duration: _mgp.type_duration()
Duration: _mgp.type_duration(),
}
try:
return simple_types[type_]
@ -1021,14 +1033,14 @@ def _typing_to_cypher_type(type_):
if type(None) in type_args:
types = tuple(t for t in type_args if t is not type(None)) # noqa E721
if len(types) == 1:
type_arg, = types
(type_arg,) = types
else:
# We cannot do typing.Union[*types], so do the equivalent
# with __getitem__ which does not even need arg unpacking.
type_arg = typing.Union.__getitem__(types)
return _mgp.type_nullable(_typing_to_cypher_type(type_arg))
elif complex_type == list:
type_arg, = type_args
(type_arg,) = type_args
return _mgp.type_list(_typing_to_cypher_type(type_arg))
raise UnsupportedTypingError(type_)
else:
@ -1038,13 +1050,17 @@ def _typing_to_cypher_type(type_):
# printed the same way. `typing.List[type]` is printed as such, while
# `typing.Optional[type]` is printed as 'typing.Union[type, NoneType]'
def parse_type_args(type_as_str):
return tuple(map(str.strip,
type_as_str[type_as_str.index('[') + 1: -1].split(',')))
return tuple(
map(
str.strip,
type_as_str[type_as_str.index("[") + 1 : -1].split(","),
)
)
def fully_qualified_name(cls):
if cls.__module__ is None or cls.__module__ == 'builtins':
if cls.__module__ is None or cls.__module__ == "builtins":
return cls.__name__
return cls.__module__ + '.' + cls.__name__
return cls.__module__ + "." + cls.__name__
def get_simple_type(type_as_str):
for simple_type, cypher_type in simple_types.items():
@ -1060,28 +1076,26 @@ def _typing_to_cypher_type(type_):
pass
def parse_typing(type_as_str):
if type_as_str.startswith('typing.Union'):
if type_as_str.startswith("typing.Union"):
type_args_as_str = parse_type_args(type_as_str)
none_type_as_str = type(None).__name__
if none_type_as_str in type_args_as_str:
types = tuple(
t for t in type_args_as_str if t != none_type_as_str)
types = tuple(t for t in type_args_as_str if t != none_type_as_str)
if len(types) == 1:
type_arg_as_str, = types
(type_arg_as_str,) = types
else:
type_arg_as_str = 'typing.Union[' + \
', '.join(types) + ']'
type_arg_as_str = "typing.Union[" + ", ".join(types) + "]"
simple_type = get_simple_type(type_arg_as_str)
if simple_type is not None:
return _mgp.type_nullable(simple_type)
return _mgp.type_nullable(parse_typing(type_arg_as_str))
elif type_as_str.startswith('typing.List'):
elif type_as_str.startswith("typing.List"):
type_arg_as_str = parse_type_args(type_as_str)
if len(type_arg_as_str) > 1:
# Nested object could be a type consisting of a list of types (e.g. mgp.Map)
# so we need to join the parts.
type_arg_as_str = ', '.join(type_arg_as_str)
type_arg_as_str = ", ".join(type_arg_as_str)
else:
type_arg_as_str = type_arg_as_str[0]
@ -1096,9 +1110,11 @@ def _typing_to_cypher_type(type_):
# Procedure registration
class Deprecated:
"""Annotate a resulting Record's field as deprecated."""
__slots__ = ('field_type',)
__slots__ = ("field_type",)
def __init__(self, type_):
self.field_type = type_
@ -1106,8 +1122,7 @@ class Deprecated:
def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]):
if not callable(func):
raise TypeError("Expected a callable object, got an instance of '{}'"
.format(type(func)))
raise TypeError("Expected a callable object, got an instance of '{}'".format(type(func)))
if inspect.iscoroutinefunction(func):
raise TypeError("Callable must not be 'async def' function")
if sys.version_info >= (3, 6):
@ -1117,24 +1132,25 @@ def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]):
raise NotImplementedError("Generator functions are not supported")
def _register_proc(func: typing.Callable[..., Record],
is_write: bool):
def _register_proc(func: typing.Callable[..., Record], is_write: bool):
raise_if_does_not_meet_requirements(func)
register_func = (
_mgp.Module.add_write_procedure if is_write
else _mgp.Module.add_read_procedure)
register_func = _mgp.Module.add_write_procedure if is_write else _mgp.Module.add_read_procedure
sig = inspect.signature(func)
params = tuple(sig.parameters.values())
if params and params[0].annotation is ProcCtx:
@wraps(func)
def wrapper(graph, args):
return func(ProcCtx(graph), *args)
params = params[1:]
mgp_proc = register_func(_mgp._MODULE, wrapper)
else:
@wraps(func)
def wrapper(graph, args):
return func(*args)
mgp_proc = register_func(_mgp._MODULE, wrapper)
for param in params:
name = param.name
@ -1149,8 +1165,7 @@ def _register_proc(func: typing.Callable[..., Record],
if sig.return_annotation is not sig.empty:
record = sig.return_annotation
if not isinstance(record, Record):
raise TypeError("Expected '{}' to return 'mgp.Record', got '{}'"
.format(func.__name__, type(record)))
raise TypeError("Expected '{}' to return 'mgp.Record', got '{}'".format(func.__name__, type(record)))
for name, type_ in record.fields.items():
if isinstance(type_, Deprecated):
cypher_type = _typing_to_cypher_type(type_.field_type)
@ -1257,20 +1272,22 @@ class InvalidMessageError(Exception):
"""
Signals using a message instance outside of the registered transformation.
"""
pass
SOURCE_TYPE_KAFKA = _mgp.SOURCE_TYPE_KAFKA
SOURCE_TYPE_PULSAR = _mgp.SOURCE_TYPE_PULSAR
class Message:
"""Represents a message from a stream."""
__slots__ = ('_message',)
__slots__ = ("_message",)
def __init__(self, message):
if not isinstance(message, _mgp.Message):
raise TypeError(
"Expected '_mgp.Message', got '{}'".format(type(message)))
raise TypeError("Expected '_mgp.Message', got '{}'".format(type(message)))
self._message = message
def __deepcopy__(self, memo):
@ -1353,17 +1370,18 @@ class Message:
class InvalidMessagesError(Exception):
"""Signals using a messages instance outside of the registered transformation."""
pass
class Messages:
"""Represents a list of messages from a stream."""
__slots__ = ('_messages',)
__slots__ = ("_messages",)
def __init__(self, messages):
if not isinstance(messages, _mgp.Messages):
raise TypeError(
"Expected '_mgp.Messages', got '{}'".format(type(messages)))
raise TypeError("Expected '_mgp.Messages', got '{}'".format(type(messages)))
self._messages = messages
def __deepcopy__(self, memo):
@ -1395,12 +1413,12 @@ class TransCtx:
Access to a TransCtx is only valid during a single execution of a transformation.
You should not globally store a TransCtx instance.
"""
__slots__ = ('_graph')
__slots__ = "_graph"
def __init__(self, graph):
if not isinstance(graph, _mgp.Graph):
raise TypeError(
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = Graph(graph)
def is_valid(self) -> bool:
@ -1420,21 +1438,76 @@ def transformation(func: typing.Callable[..., Record]):
params = tuple(sig.parameters.values())
if not params or not params[0].annotation is Messages:
if not len(params) == 2 or not params[1].annotation is Messages:
raise NotImplementedError(
"Valid signatures for transformations are (TransCtx, Messages) or (Messages)")
raise NotImplementedError("Valid signatures for transformations are (TransCtx, Messages) or (Messages)")
if params[0].annotation is TransCtx:
@wraps(func)
def wrapper(graph, messages):
return func(TransCtx(graph), messages)
_mgp._MODULE.add_transformation(wrapper)
else:
@wraps(func)
def wrapper(graph, messages):
return func(messages)
_mgp._MODULE.add_transformation(wrapper)
return func
class FuncCtx:
"""Context of a function being executed.
Access to a FuncCtx is only valid during a single execution of a transformation.
You should not globally store a FuncCtx instance.
"""
__slots__ = "_graph"
def __init__(self, graph):
if not isinstance(graph, _mgp.Graph):
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = Graph(graph)
def is_valid(self) -> bool:
return self._graph.is_valid()
def function(func: typing.Callable):
raise_if_does_not_meet_requirements(func)
register_func = _mgp.Module.add_function
sig = inspect.signature(func)
params = tuple(sig.parameters.values())
if params and params[0].annotation is FuncCtx:
@wraps(func)
def wrapper(graph, args):
return func(FuncCtx(graph), *args)
params = params[1:]
mgp_func = register_func(_mgp._MODULE, wrapper)
else:
@wraps(func)
def wrapper(graph, args):
return func(*args)
mgp_func = register_func(_mgp._MODULE, wrapper)
for param in params:
name = param.name
type_ = param.annotation
if type_ is param.empty:
type_ = object
cypher_type = _typing_to_cypher_type(type_)
if param.default is param.empty:
mgp_func.add_arg(name, cypher_type)
else:
mgp_func.add_opt_arg(name, cypher_type, param.default)
return func
def _wrap_exceptions():
def wrap_function(func):
@wraps(func)
@ -1463,6 +1536,7 @@ def _wrap_exceptions():
raise ValueConversionError(e)
except _mgp.SerializationError as e:
raise SerializationError(e)
return wrapped_func
def wrap_prop_func(func):
@ -1473,11 +1547,16 @@ def _wrap_exceptions():
if inspect.isfunction(obj):
setattr(cls, name, wrap_function(obj))
elif isinstance(obj, property):
setattr(cls, name, property(
wrap_prop_func(obj.fget),
wrap_prop_func(obj.fset),
wrap_prop_func(obj.fdel),
obj.__doc__))
setattr(
cls,
name,
property(
wrap_prop_func(obj.fget),
wrap_prop_func(obj.fset),
wrap_prop_func(obj.fdel),
obj.__doc__,
),
)
def defined_in_this_module(obj: object):
return getattr(obj, "__module__", "") == __name__

View File

@ -852,7 +852,9 @@ cpp<#
: arguments_(arguments),
function_name_(function_name),
function_(NameToFunction(function_name_)) {
DMG_ASSERT(function_, "Unexpected missing function: {}", function_name_);
if (!function_) {
throw SemanticException("Function '{}' doesn't exist.", function_name);
}
}
cpp<#)
(:private

View File

@ -2109,13 +2109,30 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(MemgraphCypher::Functio
storage_->Create<Aggregation>(expressions[1], expressions[0], Aggregation::Op::COLLECT_MAP));
}
auto function = NameToFunction(function_name);
if (!function) throw SemanticException("Function '{}' doesn't exist.", function_name);
auto is_user_defined_function = [](const std::string &function_name) {
// Dots are present only in user-defined functions, since modules are case-sensitive, so must be user-defined
// functions. Builtin functions should be case insensitive.
return function_name.find('.') != std::string::npos;
};
// Don't cache queries which call user-defined functions. User-defined function's return
// types can vary depending on whether the module is reloaded, therefore the cache would
// be invalid.
if (is_user_defined_function(function_name)) {
query_info_.is_cacheable = false;
}
return static_cast<Expression *>(storage_->Create<Function>(function_name, expressions));
}
antlrcpp::Any CypherMainVisitor::visitFunctionName(MemgraphCypher::FunctionNameContext *ctx) {
return utils::ToUpperCase(ctx->getText());
auto function_name = ctx->getText();
// Dots are present only in user-defined functions, since modules are case-sensitive, so must be user-defined
// functions. Builtin functions should be case insensitive.
if (function_name.find('.') != std::string::npos) {
return function_name;
}
return utils::ToUpperCase(function_name);
}
antlrcpp::Any CypherMainVisitor::visitDoubleLiteral(MemgraphCypher::DoubleLiteralContext *ctx) {

View File

@ -279,7 +279,7 @@ idInColl : variable IN expression ;
functionInvocation : functionName '(' ( DISTINCT )? ( expression ( ',' expression )* )? ')' ;
functionName : symbolicName ;
functionName : symbolicName ( '.' symbolicName )* ;
listComprehension : '[' filterExpression ( '|' expression )? ']' ;

View File

@ -22,6 +22,9 @@
#include "query/db_accessor.hpp"
#include "query/exceptions.hpp"
#include "query/procedure/cypher_types.hpp"
#include "query/procedure/mg_procedure_impl.hpp"
#include "query/procedure/module.hpp"
#include "query/typed_value.hpp"
#include "utils/string.hpp"
#include "utils/temporal.hpp"
@ -1174,6 +1177,53 @@ TypedValue Duration(const TypedValue *args, int64_t nargs, const FunctionContext
MapNumericParameters<Number>(parameter_mappings, args[0].ValueMap());
return TypedValue(utils::Duration(duration_parameters), ctx.memory);
}
std::function<TypedValue(const TypedValue *, const int64_t, const FunctionContext &)> UserFunction(
const mgp_func &func, const std::string &fully_qualified_name) {
return [func, fully_qualified_name](const TypedValue *args, int64_t nargs, const FunctionContext &ctx) -> TypedValue {
/// Find function is called to aquire the lock on Module pointer while user-defined function is executed
const auto &maybe_found =
procedure::FindFunction(procedure::gModuleRegistry, fully_qualified_name, utils::NewDeleteResource());
if (!maybe_found) {
throw QueryRuntimeException(
"Function '{}' has been unloaded. Please check query modules to confirm that function is loaded in Memgraph.",
fully_qualified_name);
}
/// Explicit extraction of module pointer, to clearly state that the lock is aquired.
// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
const auto &module_ptr = (*maybe_found).first;
const auto &func_cb = func.cb;
mgp_memory memory{ctx.memory};
mgp_func_context functx{ctx.db_accessor, ctx.view};
auto graph = mgp_graph::NonWritableGraph(*ctx.db_accessor, ctx.view);
std::vector<TypedValue> args_list;
args_list.reserve(nargs);
for (std::size_t i = 0; i < nargs; ++i) {
args_list.emplace_back(args[i]);
}
auto function_argument_list = mgp_list(ctx.memory);
procedure::ConstructArguments(args_list, func, fully_qualified_name, function_argument_list, graph);
mgp_func_result maybe_res;
func_cb(&function_argument_list, &functx, &maybe_res, &memory);
if (maybe_res.error_msg) {
throw QueryRuntimeException(*maybe_res.error_msg);
}
if (!maybe_res.value) {
throw QueryRuntimeException(
"Function '{}' didn't set the result nor the error message. Please either set the result by using "
"mgp_func_result_set_value or the error by using mgp_func_result_set_error_msg.",
fully_qualified_name);
}
return {*(maybe_res.value), ctx.memory};
};
}
} // namespace
std::function<TypedValue(const TypedValue *, int64_t, const FunctionContext &ctx)> NameToFunction(
@ -1259,6 +1309,14 @@ std::function<TypedValue(const TypedValue *, int64_t, const FunctionContext &ctx
if (function_name == "LOCALDATETIME") return LocalDateTime;
if (function_name == "DURATION") return Duration;
const auto &maybe_found =
procedure::FindFunction(procedure::gModuleRegistry, function_name, utils::NewDeleteResource());
if (maybe_found) {
const auto *func = (*maybe_found).second;
return UserFunction(*func, function_name);
}
return nullptr;
}

View File

@ -3705,46 +3705,12 @@ void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name,
"containers aware of that");
// Build and type check procedure arguments.
mgp_list proc_args(memory);
proc_args.elems.reserve(args.size());
if (args.size() < proc.args.size() ||
// Rely on `||` short circuit so we can avoid potential overflow of
// proc.args.size() + proc.opt_args.size() by subtracting.
(args.size() - proc.args.size() > proc.opt_args.size())) {
if (proc.args.empty() && proc.opt_args.empty()) {
throw QueryRuntimeException("'{}' requires no arguments.", fully_qualified_procedure_name);
} else if (proc.opt_args.empty()) {
throw QueryRuntimeException("'{}' requires exactly {} {}.", fully_qualified_procedure_name, proc.args.size(),
proc.args.size() == 1U ? "argument" : "arguments");
} else {
throw QueryRuntimeException("'{}' requires between {} and {} arguments.", fully_qualified_procedure_name,
proc.args.size(), proc.args.size() + proc.opt_args.size());
}
}
for (size_t i = 0; i < args.size(); ++i) {
auto arg = args[i]->Accept(*evaluator);
std::string_view name;
const query::procedure::CypherType *type{nullptr};
if (proc.args.size() > i) {
name = proc.args[i].first;
type = proc.args[i].second;
} else {
MG_ASSERT(proc.opt_args.size() > i - proc.args.size());
name = std::get<0>(proc.opt_args[i - proc.args.size()]);
type = std::get<1>(proc.opt_args[i - proc.args.size()]);
}
if (!type->SatisfiesType(arg)) {
throw QueryRuntimeException("'{}' argument named '{}' at position {} must be of type {}.",
fully_qualified_procedure_name, name, i, type->GetPresentableName());
}
proc_args.elems.emplace_back(std::move(arg), &graph);
}
// Fill missing optional arguments with their default values.
MG_ASSERT(args.size() >= proc.args.size());
size_t passed_in_opt_args = args.size() - proc.args.size();
MG_ASSERT(passed_in_opt_args <= proc.opt_args.size());
for (size_t i = passed_in_opt_args; i < proc.opt_args.size(); ++i) {
proc_args.elems.emplace_back(std::get<2>(proc.opt_args[i]), &graph);
std::vector<TypedValue> args_list;
args_list.reserve(args.size());
for (auto *expression : args) {
args_list.emplace_back(expression->Accept(*evaluator));
}
procedure::ConstructArguments(args_list, proc, fully_qualified_procedure_name, proc_args, graph);
if (memory_limit) {
SPDLOG_INFO("Running '{}' with memory limit of {}", fully_qualified_procedure_name,
utils::GetReadableSize(*memory_limit));
@ -3832,7 +3798,7 @@ class CallProcedureCursor : public Cursor {
// generator like procedures which yield a new result on each invocation.
auto *memory = context.evaluation_context.memory;
auto memory_limit = EvaluateMemoryLimit(&evaluator, self_->memory_limit_, self_->memory_scale_);
mgp_graph graph{context.db_accessor, graph_view, &context};
auto graph = mgp_graph::WritableGraph(*context.db_accessor, graph_view, context);
CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit,
&result_);

View File

@ -187,7 +187,10 @@ template <typename TFunc, typename... Args>
return MGP_ERROR_NO_ERROR;
}
bool MgpGraphIsMutable(const mgp_graph &graph) noexcept { return graph.view == memgraph::storage::View::NEW; }
// Graph mutations
bool MgpGraphIsMutable(const mgp_graph &graph) noexcept {
return graph.view == memgraph::storage::View::NEW && graph.ctx != nullptr;
}
bool MgpVertexIsMutable(const mgp_vertex &vertex) { return MgpGraphIsMutable(*vertex.graph); }
@ -289,6 +292,7 @@ mgp_value_type FromTypedValueType(memgraph::query::TypedValue::Type type) {
return MGP_VALUE_TYPE_DURATION;
}
}
} // namespace
memgraph::query::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::MemoryResource *memory) {
switch (val.type) {
@ -345,8 +349,6 @@ memgraph::query::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::
}
}
} // namespace
mgp_value::mgp_value(memgraph::utils::MemoryResource *m) noexcept : type(MGP_VALUE_TYPE_NULL), memory(m) {}
mgp_value::mgp_value(bool val, memgraph::utils::MemoryResource *m) noexcept
@ -1451,6 +1453,14 @@ mgp_error mgp_result_record_insert(mgp_result_record *record, const char *field_
});
}
mgp_error mgp_func_result_set_error_msg(mgp_func_result *res, const char *msg, mgp_memory *memory) {
return WrapExceptions([=] { res->error_msg.emplace(msg, memory->impl); });
}
mgp_error mgp_func_result_set_value(mgp_func_result *res, mgp_value *value, mgp_memory *memory) {
return WrapExceptions([=] { res->value = ToTypedValue(*value, memory->impl); });
}
/// Graph Constructs
void mgp_properties_iterator_destroy(mgp_properties_iterator *it) { DeleteRawMgpObject(it); }
@ -2382,31 +2392,53 @@ mgp_error mgp_module_add_write_procedure(mgp_module *module, const char *name, m
return WrapExceptions([=] { return mgp_module_add_procedure(module, name, cb, {.is_write = true}); }, result);
}
mgp_error mgp_proc_add_arg(mgp_proc *proc, const char *name, mgp_type *type) {
return WrapExceptions([=] {
if (!IsValidIdentifierName(name)) {
throw std::invalid_argument{fmt::format("Invalid argument name for procedure '{}': {}", proc->name, name)};
namespace {
template <typename T>
concept IsCallable = memgraph::utils::SameAsAnyOf<T, mgp_proc, mgp_func>;
template <IsCallable TCall>
mgp_error MgpAddArg(TCall &callable, const std::string &name, mgp_type &type) {
return WrapExceptions([&]() mutable {
static constexpr std::string_view type_name = std::invoke([]() constexpr {
if constexpr (std::is_same_v<TCall, mgp_proc>) {
return "procedure";
} else if constexpr (std::is_same_v<TCall, mgp_func>) {
return "function";
}
});
if (!IsValidIdentifierName(name.c_str())) {
throw std::invalid_argument{fmt::format("Invalid argument name for {} '{}': {}", type_name, callable.name, name)};
}
if (!proc->opt_args.empty()) {
throw std::logic_error{fmt::format(
"Cannot add required argument '{}' to procedure '{}' after adding any optional one", name, proc->name)};
if (!callable.opt_args.empty()) {
throw std::logic_error{fmt::format("Cannot add required argument '{}' to {} '{}' after adding any optional one",
name, type_name, callable.name)};
}
proc->args.emplace_back(name, type->impl.get());
callable.args.emplace_back(name, type.impl.get());
});
}
mgp_error mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, mgp_type *type, mgp_value *default_value) {
return WrapExceptions([=] {
if (!IsValidIdentifierName(name)) {
throw std::invalid_argument{fmt::format("Invalid argument name for procedure '{}': {}", proc->name, name)};
template <IsCallable TCall>
mgp_error MgpAddOptArg(TCall &callable, const std::string name, mgp_type &type, mgp_value &default_value) {
return WrapExceptions([&]() mutable {
static constexpr std::string_view type_name = std::invoke([]() constexpr {
if constexpr (std::is_same_v<TCall, mgp_proc>) {
return "procedure";
} else if constexpr (std::is_same_v<TCall, mgp_func>) {
return "function";
}
});
if (!IsValidIdentifierName(name.c_str())) {
throw std::invalid_argument{fmt::format("Invalid argument name for {} '{}': {}", type_name, callable.name, name)};
}
switch (MgpValueGetType(*default_value)) {
switch (MgpValueGetType(default_value)) {
case MGP_VALUE_TYPE_VERTEX:
case MGP_VALUE_TYPE_EDGE:
case MGP_VALUE_TYPE_PATH:
// default_value must not be a graph element.
throw ValueConversionException{
"Default value of argument '{}' of procedure '{}' name must not be a graph element!", name, proc->name};
throw ValueConversionException{"Default value of argument '{}' of {} '{}' name must not be a graph element!",
name, type_name, callable.name};
case MGP_VALUE_TYPE_NULL:
case MGP_VALUE_TYPE_BOOL:
case MGP_VALUE_TYPE_INT:
@ -2421,16 +2453,32 @@ mgp_error mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, mgp_type *type,
break;
}
// Default value must be of required `type`.
if (!type->impl->SatisfiesType(*default_value)) {
throw std::logic_error{
fmt::format("The default value of argument '{}' for procedure '{}' doesn't satisfy type '{}'", name,
proc->name, type->impl->GetPresentableName())};
if (!type.impl->SatisfiesType(default_value)) {
throw std::logic_error{fmt::format("The default value of argument '{}' for {} '{}' doesn't satisfy type '{}'",
name, type_name, callable.name, type.impl->GetPresentableName())};
}
auto *memory = proc->opt_args.get_allocator().GetMemoryResource();
proc->opt_args.emplace_back(memgraph::utils::pmr::string(name, memory), type->impl.get(),
ToTypedValue(*default_value, memory));
auto *memory = callable.opt_args.get_allocator().GetMemoryResource();
callable.opt_args.emplace_back(memgraph::utils::pmr::string(name, memory), type.impl.get(),
ToTypedValue(default_value, memory));
});
}
} // namespace
mgp_error mgp_proc_add_arg(mgp_proc *proc, const char *name, mgp_type *type) {
return MgpAddArg(*proc, std::string(name), *type);
}
mgp_error mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, mgp_type *type, mgp_value *default_value) {
return MgpAddOptArg(*proc, std::string(name), *type, *default_value);
}
mgp_error mgp_func_add_arg(mgp_func *func, const char *name, mgp_type *type) {
return MgpAddArg(*func, std::string(name), *type);
}
mgp_error mgp_func_add_opt_arg(mgp_func *func, const char *name, mgp_type *type, mgp_value *default_value) {
return MgpAddOptArg(*func, std::string(name), *type, *default_value);
}
namespace {
@ -2545,6 +2593,22 @@ void PrintProcSignature(const mgp_proc &proc, std::ostream *stream) {
(*stream) << ")";
}
void PrintFuncSignature(const mgp_func &func, std::ostream &stream) {
stream << func.name << "(";
utils::PrintIterable(stream, func.args, ", ", [](auto &stream, const auto &arg) {
stream << arg.first << " :: " << arg.second->GetPresentableName();
});
if (!func.args.empty() && !func.opt_args.empty()) {
stream << ", ";
}
utils::PrintIterable(stream, func.opt_args, ", ", [](auto &stream, const auto &arg) {
const auto &[name, type, default_val] = arg;
stream << name << " = ";
PrintValue(default_val, &stream) << " :: " << type->GetPresentableName();
});
stream << ")";
}
bool IsValidIdentifierName(const char *name) {
if (!name) return false;
std::regex regex("[_[:alpha:]][_[:alnum:]]*");
@ -2716,3 +2780,19 @@ mgp_error mgp_module_add_transformation(mgp_module *module, const char *name, mg
module->transformations.emplace(name, mgp_trans(name, cb, memory));
});
}
mgp_error mgp_module_add_function(mgp_module *module, const char *name, mgp_func_cb cb, mgp_func **result) {
return WrapExceptions(
[=] {
if (!IsValidIdentifierName(name)) {
throw std::invalid_argument{fmt::format("Invalid function name: {}", name)};
}
if (module->functions.find(name) != module->functions.end()) {
throw std::logic_error{fmt::format("Function with similar name already exists '{}'", name)};
};
auto *memory = module->functions.get_allocator().GetMemoryResource();
return &module->functions.emplace(name, mgp_func(name, cb, memory)).first->second;
},
result);
}

View File

@ -562,14 +562,36 @@ struct mgp_result {
std::optional<memgraph::utils::pmr::string> error_msg;
};
struct mgp_func_result {
mgp_func_result() {}
/// Return Magic function result. If user forgets it, the error is raised
std::optional<memgraph::query::TypedValue> value;
/// Return Magic function result with potential error
std::optional<memgraph::utils::pmr::string> error_msg;
};
struct mgp_graph {
memgraph::query::DbAccessor *impl;
memgraph::storage::View view;
// TODO: Merge `mgp_graph` and `mgp_memory` into a single `mgp_context`. The
// `ctx` field is out of place here.
memgraph::query::ExecutionContext *ctx;
static mgp_graph WritableGraph(memgraph::query::DbAccessor &acc, memgraph::storage::View view,
memgraph::query::ExecutionContext &ctx) {
return mgp_graph{&acc, view, &ctx};
}
static mgp_graph NonWritableGraph(memgraph::query::DbAccessor &acc, memgraph::storage::View view) {
return mgp_graph{&acc, view, nullptr};
}
};
// Prevents user to use ExecutionContext in writable callables
struct mgp_func_context {
memgraph::query::DbAccessor *impl;
memgraph::storage::View view;
};
struct mgp_properties_iterator {
using allocator_type = memgraph::utils::Allocator<mgp_properties_iterator>;
@ -779,18 +801,69 @@ struct mgp_trans {
results;
};
struct mgp_func {
using allocator_type = memgraph::utils::Allocator<mgp_func>;
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_func(const char *name, mgp_func_cb cb, memgraph::utils::MemoryResource *memory)
: name(name, memory), cb(cb), args(memory), opt_args(memory) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_func(const char *name, std::function<void(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *)> cb,
memgraph::utils::MemoryResource *memory)
: name(name, memory), cb(cb), args(memory), opt_args(memory) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_func(const mgp_func &other, memgraph::utils::MemoryResource *memory)
: name(other.name, memory), cb(other.cb), args(other.args, memory), opt_args(other.opt_args, memory) {}
mgp_func(mgp_func &&other, memgraph::utils::MemoryResource *memory)
: name(std::move(other.name), memory),
cb(std::move(other.cb)),
args(std::move(other.args), memory),
opt_args(std::move(other.opt_args), memory) {}
mgp_func(const mgp_func &other) = default;
mgp_func(mgp_func &&other) = default;
mgp_func &operator=(const mgp_func &) = delete;
mgp_func &operator=(mgp_func &&) = delete;
~mgp_func() = default;
/// Name of the function.
memgraph::utils::pmr::string name;
/// Entry-point for the function.
std::function<void(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *)> cb;
/// Required, positional arguments as a (name, type) pair.
memgraph::utils::pmr::vector<std::pair<memgraph::utils::pmr::string, const memgraph::query::procedure::CypherType *>>
args;
/// Optional positional arguments as a (name, type, default_value) tuple.
memgraph::utils::pmr::vector<std::tuple<memgraph::utils::pmr::string, const memgraph::query::procedure::CypherType *,
memgraph::query::TypedValue>>
opt_args;
};
mgp_error MgpTransAddFixedResult(mgp_trans *trans) noexcept;
struct mgp_module {
using allocator_type = memgraph::utils::Allocator<mgp_module>;
explicit mgp_module(memgraph::utils::MemoryResource *memory) : procedures(memory), transformations(memory) {}
explicit mgp_module(memgraph::utils::MemoryResource *memory)
: procedures(memory), transformations(memory), functions(memory) {}
mgp_module(const mgp_module &other, memgraph::utils::MemoryResource *memory)
: procedures(other.procedures, memory), transformations(other.transformations, memory) {}
: procedures(other.procedures, memory),
transformations(other.transformations, memory),
functions(other.functions, memory) {}
mgp_module(mgp_module &&other, memgraph::utils::MemoryResource *memory)
: procedures(std::move(other.procedures), memory), transformations(std::move(other.transformations), memory) {}
: procedures(std::move(other.procedures), memory),
transformations(std::move(other.transformations), memory),
functions(std::move(other.functions), memory) {}
mgp_module(const mgp_module &) = default;
mgp_module(mgp_module &&) = default;
@ -802,6 +875,7 @@ struct mgp_module {
memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_proc> procedures;
memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_trans> transformations;
memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_func> functions;
};
namespace memgraph::query::procedure {
@ -811,6 +885,11 @@ namespace memgraph::query::procedure {
/// @throw anything std::ostream::operator<< may throw.
void PrintProcSignature(const mgp_proc &, std::ostream *);
/// @throw std::bad_alloc
/// @throw std::length_error
/// @throw anything std::ostream::operator<< may throw.
void PrintFuncSignature(const mgp_func &, std::ostream &);
bool IsValidIdentifierName(const char *name);
} // namespace memgraph::query::procedure
@ -839,3 +918,5 @@ struct mgp_messages {
storage_type messages;
};
memgraph::query::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::MemoryResource *memory);

View File

@ -52,6 +52,8 @@ class BuiltinModule final : public Module {
const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override;
const std::map<std::string, mgp_func, std::less<>> *Functions() const override;
void AddProcedure(std::string_view name, mgp_proc proc);
void AddTransformation(std::string_view name, mgp_trans trans);
@ -62,6 +64,7 @@ class BuiltinModule final : public Module {
/// Registered procedures
std::map<std::string, mgp_proc, std::less<>> procedures_;
std::map<std::string, mgp_trans, std::less<>> transformations_;
std::map<std::string, mgp_func, std::less<>> functions_;
};
BuiltinModule::BuiltinModule() {}
@ -75,6 +78,7 @@ const std::map<std::string, mgp_proc, std::less<>> *BuiltinModule::Procedures()
const std::map<std::string, mgp_trans, std::less<>> *BuiltinModule::Transformations() const {
return &transformations_;
}
const std::map<std::string, mgp_func, std::less<>> *BuiltinModule::Functions() const { return &functions_; }
void BuiltinModule::AddProcedure(std::string_view name, mgp_proc proc) { procedures_.emplace(name, std::move(proc)); }
@ -300,6 +304,82 @@ void RegisterMgTransformations(const std::map<std::string, std::unique_ptr<Modul
module->AddProcedure("transformations", std::move(procedures));
}
void RegisterMgFunctions(
// We expect modules to be sorted by name.
const std::map<std::string, std::unique_ptr<Module>, std::less<>> *all_modules, BuiltinModule *module) {
auto functions_cb = [all_modules](mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result *result,
mgp_memory *memory) {
// Iterating over all_modules assumes that the standard mechanism of magic
// functions invocations takes the ModuleRegistry::lock_ with READ access.
for (const auto &[module_name, module] : *all_modules) {
// Return the results in sorted order by module and by function_name.
static_assert(std::is_same_v<decltype(module->Functions()), const std::map<std::string, mgp_func, std::less<>> *>,
"Expected module magic functions to be sorted by name");
const auto path = module->Path();
const auto path_string = GetPathString(path);
const auto is_editable = IsFileEditable(path);
for (const auto &[func_name, func] : *module->Functions()) {
mgp_result_record *record{nullptr};
if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) {
return;
}
const auto path_value = GetStringValueOrSetError(path_string.c_str(), memory, result);
if (!path_value) {
return;
}
MgpUniquePtr<mgp_value> is_editable_value{nullptr, mgp_value_destroy};
if (!TryOrSetError([&] { return CreateMgpObject(is_editable_value, mgp_value_make_bool, is_editable, memory); },
result)) {
return;
}
utils::pmr::string full_name(module_name, memory->impl);
full_name.append(1, '.');
full_name.append(func_name);
const auto name_value = GetStringValueOrSetError(full_name.c_str(), memory, result);
if (!name_value) {
return;
}
std::stringstream ss;
ss << module_name << ".";
PrintFuncSignature(func, ss);
const auto signature = ss.str();
const auto signature_value = GetStringValueOrSetError(signature.c_str(), memory, result);
if (!signature_value) {
return;
}
if (!InsertResultOrSetError(result, record, "name", name_value.get())) {
return;
}
if (!InsertResultOrSetError(result, record, "signature", signature_value.get())) {
return;
}
if (!InsertResultOrSetError(result, record, "path", path_value.get())) {
return;
}
if (!InsertResultOrSetError(result, record, "is_editable", is_editable_value.get())) {
return;
}
}
}
};
mgp_proc functions("functions", functions_cb, utils::NewDeleteResource());
MG_ASSERT(mgp_proc_add_result(&functions, "name", Call<mgp_type *>(mgp_type_string)) == MGP_ERROR_NO_ERROR);
MG_ASSERT(mgp_proc_add_result(&functions, "signature", Call<mgp_type *>(mgp_type_string)) == MGP_ERROR_NO_ERROR);
MG_ASSERT(mgp_proc_add_result(&functions, "path", Call<mgp_type *>(mgp_type_string)) == MGP_ERROR_NO_ERROR);
MG_ASSERT(mgp_proc_add_result(&functions, "is_editable", Call<mgp_type *>(mgp_type_bool)) == MGP_ERROR_NO_ERROR);
module->AddProcedure("functions", std::move(functions));
}
namespace {
bool IsAllowedExtension(const auto &extension) {
static constexpr std::array<std::string_view, 1> allowed_extensions{".py"};
@ -650,8 +730,8 @@ void RegisterMgDeleteModuleFile(ModuleRegistry *module_registry, utils::RWLock *
// `mgp_module::transformations into `proc_map`. The return value of WithModuleRegistration
// is the same as that of `fun`. Note, the return value need only be convertible to `bool`,
// it does not have to be `bool` itself.
template <class TProcMap, class TTransMap, class TFun>
auto WithModuleRegistration(TProcMap *proc_map, TTransMap *trans_map, const TFun &fun) {
template <class TProcMap, class TTransMap, class TFuncMap, class TFun>
auto WithModuleRegistration(TProcMap *proc_map, TTransMap *trans_map, TFuncMap *func_map, const TFun &fun) {
// We probably don't need more than 256KB for module initialization.
static constexpr size_t stack_bytes = 256UL * 1024UL;
unsigned char stack_memory[stack_bytes];
@ -664,6 +744,8 @@ auto WithModuleRegistration(TProcMap *proc_map, TTransMap *trans_map, const TFun
for (const auto &proc : module_def.procedures) proc_map->emplace(proc);
// Copy transformations into resulting trans_map.
for (const auto &trans : module_def.transformations) trans_map->emplace(trans);
// Copy functions into resulting func_map.
for (const auto &func : module_def.functions) func_map->emplace(func);
}
return res;
}
@ -687,6 +769,8 @@ class SharedLibraryModule final : public Module {
const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override;
const std::map<std::string, mgp_func, std::less<>> *Functions() const override;
std::optional<std::filesystem::path> Path() const override { return file_path_; }
private:
@ -702,6 +786,8 @@ class SharedLibraryModule final : public Module {
std::map<std::string, mgp_proc, std::less<>> procedures_;
/// Registered transformations
std::map<std::string, mgp_trans, std::less<>> transformations_;
/// Registered functions
std::map<std::string, mgp_func, std::less<>> functions_;
};
SharedLibraryModule::SharedLibraryModule() : handle_(nullptr) {}
@ -755,7 +841,7 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) {
}
return true;
};
if (!WithModuleRegistration(&procedures_, &transformations_, module_cb)) {
if (!WithModuleRegistration(&procedures_, &transformations_, &functions_, module_cb)) {
return false;
}
// Get optional mgp_shutdown_module
@ -801,6 +887,13 @@ const std::map<std::string, mgp_trans, std::less<>> *SharedLibraryModule::Transf
return &transformations_;
}
const std::map<std::string, mgp_func, std::less<>> *SharedLibraryModule::Functions() const {
MG_ASSERT(handle_,
"Attempting to access functions of a module that has not "
"been loaded...");
return &functions_;
}
class PythonModule final : public Module {
public:
PythonModule();
@ -816,6 +909,7 @@ class PythonModule final : public Module {
const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override;
const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override;
const std::map<std::string, mgp_func, std::less<>> *Functions() const override;
std::optional<std::filesystem::path> Path() const override { return file_path_; }
private:
@ -823,6 +917,7 @@ class PythonModule final : public Module {
py::Object py_module_;
std::map<std::string, mgp_proc, std::less<>> procedures_;
std::map<std::string, mgp_trans, std::less<>> transformations_;
std::map<std::string, mgp_func, std::less<>> functions_;
};
PythonModule::PythonModule() {}
@ -853,7 +948,7 @@ bool PythonModule::Load(const std::filesystem::path &file_path) {
};
return result;
};
py_module_ = WithModuleRegistration(&procedures_, &transformations_, module_cb);
py_module_ = WithModuleRegistration(&procedures_, &transformations_, &functions_, module_cb);
if (py_module_) {
spdlog::info("Loaded module {}", file_path);
@ -877,6 +972,7 @@ bool PythonModule::Close() {
auto gil = py::EnsureGIL();
procedures_.clear();
transformations_.clear();
functions_.clear();
// Delete the module from the `sys.modules` directory so that the module will
// be properly imported if imported again.
py::Object sys(PyImport_ImportModule("sys"));
@ -906,6 +1002,13 @@ const std::map<std::string, mgp_trans, std::less<>> *PythonModule::Transformatio
"not been loaded...");
return &transformations_;
}
const std::map<std::string, mgp_func, std::less<>> *PythonModule::Functions() const {
MG_ASSERT(py_module_,
"Attempting to access functions of a module that has "
"not been loaded...");
return &functions_;
}
namespace {
std::unique_ptr<Module> LoadModuleFromFile(const std::filesystem::path &path) {
@ -954,6 +1057,7 @@ ModuleRegistry::ModuleRegistry() {
auto module = std::make_unique<BuiltinModule>();
RegisterMgProcedures(&modules_, module.get());
RegisterMgTransformations(&modules_, module.get());
RegisterMgFunctions(&modules_, module.get());
RegisterMgLoad(this, &lock_, module.get());
RegisterMgGetModuleFiles(this, module.get());
RegisterMgGetModuleFile(this, module.get());
@ -1083,7 +1187,7 @@ std::optional<std::pair<std::string_view, std::string_view>> FindModuleNameAndPr
}
template <typename T>
concept ModuleProperties = utils::SameAsAnyOf<T, mgp_proc, mgp_trans>;
concept ModuleProperties = utils::SameAsAnyOf<T, mgp_proc, mgp_trans, mgp_func>;
template <ModuleProperties T>
std::optional<std::pair<ModulePtr, const T *>> MakePairIfPropFound(const ModuleRegistry &module_registry,
@ -1092,8 +1196,10 @@ std::optional<std::pair<ModulePtr, const T *>> MakePairIfPropFound(const ModuleR
auto prop_fun = [](auto &module) {
if constexpr (std::is_same_v<T, mgp_proc>) {
return module->Procedures();
} else {
} else if constexpr (std::is_same_v<T, mgp_trans>) {
return module->Transformations();
} else if constexpr (std::is_same_v<T, mgp_func>) {
return module->Functions();
}
};
auto result = FindModuleNameAndProp(module_registry, fully_qualified_name, memory);
@ -1121,4 +1227,10 @@ std::optional<std::pair<ModulePtr, const mgp_trans *>> FindTransformation(
return MakePairIfPropFound<mgp_trans>(module_registry, fully_qualified_transformation_name, memory);
}
std::optional<std::pair<ModulePtr, const mgp_func *>> FindFunction(const ModuleRegistry &module_registry,
std::string_view fully_qualified_function_name,
utils::MemoryResource *memory) {
return MakePairIfPropFound<mgp_func>(module_registry, fully_qualified_function_name, memory);
}
} // namespace memgraph::query::procedure

View File

@ -21,6 +21,7 @@
#include <string_view>
#include <unordered_map>
#include "query/procedure/cypher_types.hpp"
#include "query/procedure/mg_procedure_impl.hpp"
#include "utils/memory.hpp"
#include "utils/rw_lock.hpp"
@ -45,6 +46,8 @@ class Module {
virtual const std::map<std::string, mgp_proc, std::less<>> *Procedures() const = 0;
/// Returns registered transformations of this module
virtual const std::map<std::string, mgp_trans, std::less<>> *Transformations() const = 0;
// /// Returns registered functions of this module
virtual const std::map<std::string, mgp_func, std::less<>> *Functions() const = 0;
virtual std::optional<std::filesystem::path> Path() const = 0;
};
@ -147,4 +150,62 @@ std::optional<std::pair<procedure::ModulePtr, const mgp_proc *>> FindProcedure(
std::optional<std::pair<procedure::ModulePtr, const mgp_trans *>> FindTransformation(
const ModuleRegistry &module_registry, const std::string_view fully_qualified_transformation_name,
utils::MemoryResource *memory);
/// Return the ModulePtr and `mgp_func *` of the found function after resolving
/// `fully_qualified_function_name` if found. If there is no such function
/// std::nullopt is returned. `memory` is used for temporary allocations
/// inside this function. ModulePtr must be kept alive to make sure it won't be unloaded.
std::optional<std::pair<procedure::ModulePtr, const mgp_func *>> FindFunction(
const ModuleRegistry &module_registry, const std::string_view fully_qualified_function_name,
utils::MemoryResource *memory);
template <typename T>
concept IsCallable = utils::SameAsAnyOf<T, mgp_proc, mgp_func>;
template <IsCallable TCall>
void ConstructArguments(const std::vector<TypedValue> &args, const TCall &callable,
const std::string_view fully_qualified_name, mgp_list &args_list, mgp_graph &graph) {
const auto n_args = args.size();
const auto c_args_sz = callable.args.size();
const auto c_opt_args_sz = callable.opt_args.size();
if (n_args < c_args_sz || (n_args - c_args_sz > c_opt_args_sz)) {
if (callable.args.empty() && callable.opt_args.empty()) {
throw QueryRuntimeException("'{}' requires no arguments.", fully_qualified_name);
}
if (callable.opt_args.empty()) {
throw QueryRuntimeException("'{}' requires exactly {} {}.", fully_qualified_name, c_args_sz,
c_args_sz == 1U ? "argument" : "arguments");
}
throw QueryRuntimeException("'{}' requires between {} and {} arguments.", fully_qualified_name, c_args_sz,
c_args_sz + c_opt_args_sz);
}
args_list.elems.reserve(n_args);
auto is_not_optional_arg = [c_args_sz](int i) { return c_args_sz > i; };
for (size_t i = 0; i < n_args; ++i) {
auto arg = args[i];
std::string_view name;
const query::procedure::CypherType *type;
if (is_not_optional_arg(i)) {
name = callable.args[i].first;
type = callable.args[i].second;
} else {
name = std::get<0>(callable.opt_args[i - c_args_sz]);
type = std::get<1>(callable.opt_args[i - c_args_sz]);
}
if (!type->SatisfiesType(arg)) {
throw QueryRuntimeException("'{}' argument named '{}' at position {} must be of type {}.", fully_qualified_name,
name, i, type->GetPresentableName());
}
args_list.elems.emplace_back(std::move(arg), &graph);
}
// Fill missing optional arguments with their default values.
const size_t passed_in_opt_args = n_args - c_args_sz;
for (size_t i = passed_in_opt_args; i < c_opt_args_sz; ++i) {
args_list.elems.emplace_back(std::get<2>(callable.opt_args[i]), &graph);
}
}
} // namespace memgraph::query::procedure

View File

@ -447,62 +447,94 @@ PyObject *MakePyCypherType(mgp_type *type) {
// clang-format off
struct PyQueryProc {
PyObject_HEAD
mgp_proc *proc;
mgp_proc *callable;
};
// clang-format on
PyObject *PyQueryProcAddArg(PyQueryProc *self, PyObject *args) {
MG_ASSERT(self->proc);
// clang-format off
struct PyMagicFunc{
PyObject_HEAD
mgp_func *callable;
};
// clang-format on
template <typename T>
concept IsCallable = utils::SameAsAnyOf<T, PyQueryProc, PyMagicFunc>;
template <IsCallable TCall>
PyObject *PyCallableAddArg(TCall *self, PyObject *args) {
MG_ASSERT(self->callable);
const char *name = nullptr;
PyCypherType *py_type = nullptr;
if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr;
auto *type = py_type->type;
if (RaiseExceptionFromErrorCode(mgp_proc_add_arg(self->proc, name, type))) {
return nullptr;
if constexpr (std::is_same_v<TCall, PyQueryProc>) {
if (RaiseExceptionFromErrorCode(mgp_proc_add_arg(self->callable, name, type))) {
return nullptr;
}
} else if constexpr (std::is_same_v<TCall, PyMagicFunc>) {
if (RaiseExceptionFromErrorCode(mgp_func_add_arg(self->callable, name, type))) {
return nullptr;
}
}
Py_RETURN_NONE;
}
PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) {
MG_ASSERT(self->proc);
template <IsCallable TCall>
PyObject *PyCallableAddOptArg(TCall *self, PyObject *args) {
MG_ASSERT(self->callable);
const char *name = nullptr;
PyCypherType *py_type = nullptr;
PyObject *py_value = nullptr;
if (!PyArg_ParseTuple(args, "sO!O", &name, &PyCypherTypeType, &py_type, &py_value)) return nullptr;
auto *type = py_type->type;
mgp_memory memory{self->proc->opt_args.get_allocator().GetMemoryResource()};
mgp_memory memory{self->callable->opt_args.get_allocator().GetMemoryResource()};
mgp_value *value = PyObjectToMgpValueWithPythonExceptions(py_value, &memory);
if (value == nullptr) {
return nullptr;
}
if (RaiseExceptionFromErrorCode(mgp_proc_add_opt_arg(self->proc, name, type, value))) {
mgp_value_destroy(value);
return nullptr;
if constexpr (std::is_same_v<TCall, PyQueryProc>) {
if (RaiseExceptionFromErrorCode(mgp_proc_add_opt_arg(self->callable, name, type, value))) {
mgp_value_destroy(value);
return nullptr;
}
} else if constexpr (std::is_same_v<TCall, PyMagicFunc>) {
if (RaiseExceptionFromErrorCode(mgp_func_add_opt_arg(self->callable, name, type, value))) {
mgp_value_destroy(value);
return nullptr;
}
}
mgp_value_destroy(value);
Py_RETURN_NONE;
}
PyObject *PyQueryProcAddArg(PyQueryProc *self, PyObject *args) { return PyCallableAddArg(self, args); }
PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) { return PyCallableAddOptArg(self, args); }
PyObject *PyQueryProcAddResult(PyQueryProc *self, PyObject *args) {
MG_ASSERT(self->proc);
MG_ASSERT(self->callable);
const char *name = nullptr;
PyCypherType *py_type = nullptr;
if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr;
auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
if (RaiseExceptionFromErrorCode(mgp_proc_add_result(self->proc, name, type))) {
if (RaiseExceptionFromErrorCode(mgp_proc_add_result(self->callable, name, type))) {
return nullptr;
}
Py_RETURN_NONE;
}
PyObject *PyQueryProcAddDeprecatedResult(PyQueryProc *self, PyObject *args) {
MG_ASSERT(self->proc);
MG_ASSERT(self->callable);
const char *name = nullptr;
PyCypherType *py_type = nullptr;
if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr;
auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
if (RaiseExceptionFromErrorCode(mgp_proc_add_deprecated_result(self->proc, name, type))) {
if (RaiseExceptionFromErrorCode(mgp_proc_add_deprecated_result(self->callable, name, type))) {
return nullptr;
}
Py_RETURN_NONE;
@ -532,6 +564,33 @@ static PyTypeObject PyQueryProcType = {
};
// clang-format on
PyObject *PyMagicFuncAddArg(PyMagicFunc *self, PyObject *args) { return PyCallableAddArg(self, args); }
PyObject *PyMagicFuncAddOptArg(PyMagicFunc *self, PyObject *args) { return PyCallableAddOptArg(self, args); }
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef PyMagicFuncMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
{"add_arg", reinterpret_cast<PyCFunction>(PyMagicFuncAddArg), METH_VARARGS,
"Add a required argument to a function."},
{"add_opt_arg", reinterpret_cast<PyCFunction>(PyMagicFuncAddOptArg), METH_VARARGS,
"Add an optional argument with a default value to a function."},
{nullptr},
};
// clang-format off
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static PyTypeObject PyMagicFuncType = {
PyVarObject_HEAD_INIT(nullptr, 0)
.tp_name = "_mgp.Func",
.tp_basicsize = sizeof(PyMagicFunc),
// NOLINTNEXTLINE(hicpp-signed-bitwise)
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_doc = "Wraps struct mgp_func.",
.tp_methods = PyMagicFuncMethods,
};
// clang-format on
// clang-format off
struct PyQueryModule {
PyObject_HEAD
@ -796,7 +855,6 @@ py::Object MgpListToPyTuple(mgp_list *list, PyObject *py_graph) {
}
namespace {
std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result, py::Object py_record) {
py::Object py_mgp(PyImport_ImportModule("mgp"));
if (!py_mgp) return py::FetchError();
@ -870,6 +928,33 @@ std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(mgp_result *result
return std::nullopt;
}
std::function<void()> PyObjectCleanup(py::Object &py_object) {
return [py_object]() {
// Run `gc.collect` (reference cycle-detection) explicitly, so that we are
// sure the procedure cleaned up everything it held references to. If the
// user stored a reference to one of our `_mgp` instances then the
// internally used `mgp_*` structs will stay unfreed and a memory leak
// will be reported at the end of the query execution.
py::Object gc(PyImport_ImportModule("gc"));
if (!gc) {
LOG_FATAL(py::FetchError().value());
}
if (!gc.CallMethod("collect")) {
LOG_FATAL(py::FetchError().value());
}
// After making sure all references from our side have been cleared,
// invalidate the `_mgp.Graph` object. If the user kept a reference to one
// of our `_mgp` instances then this will prevent them from using those
// objects (whose internal `mgp_*` pointers are now invalid and would cause
// a crash).
if (!py_object.CallMethod("invalidate")) {
LOG_FATAL(py::FetchError().value());
}
};
}
void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *graph, mgp_result *result,
mgp_memory *memory) {
auto gil = py::EnsureGIL();
@ -895,31 +980,6 @@ void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *gra
}
};
auto cleanup = [](py::Object py_graph) {
// Run `gc.collect` (reference cycle-detection) explicitly, so that we are
// sure the procedure cleaned up everything it held references to. If the
// user stored a reference to one of our `_mgp` instances then the
// internally used `mgp_*` structs will stay unfreed and a memory leak
// will be reported at the end of the query execution.
py::Object gc(PyImport_ImportModule("gc"));
if (!gc) {
LOG_FATAL(py::FetchError().value());
}
if (!gc.CallMethod("collect")) {
LOG_FATAL(py::FetchError().value());
}
// After making sure all references from our side have been cleared,
// invalidate the `_mgp.Graph` object. If the user kept a reference to one
// of our `_mgp` instances then this will prevent them from using those
// objects (whose internal `mgp_*` pointers are now invalid and would cause
// a crash).
if (!py_graph.CallMethod("invalidate")) {
LOG_FATAL(py::FetchError().value());
}
};
// It is *VERY IMPORTANT* to note that this code takes great care not to keep
// any extra references to any `_mgp` instances (except for `_mgp.Graph`), so
// as not to introduce extra reference counts and prevent their deallocation.
@ -932,14 +992,9 @@ void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *gra
std::optional<std::string> maybe_msg;
{
py::Object py_graph(MakePyGraph(graph, memory));
utils::OnScopeExit clean_up(PyObjectCleanup(py_graph));
if (py_graph) {
try {
maybe_msg = error_to_msg(call(py_graph));
cleanup(py_graph);
} catch (...) {
cleanup(py_graph);
throw;
}
maybe_msg = error_to_msg(call(py_graph));
} else {
maybe_msg = error_to_msg(py::FetchError());
}
@ -972,32 +1027,58 @@ void CallPythonTransformation(const py::Object &py_cb, mgp_messages *msgs, mgp_g
return AddRecordFromPython(result, py_res);
};
auto cleanup = [](py::Object py_graph, py::Object py_messages) {
// Run `gc.collect` (reference cycle-detection) explicitly, so that we are
// sure the procedure cleaned up everything it held references to. If the
// user stored a reference to one of our `_mgp` instances then the
// internally used `mgp_*` structs will stay unfreed and a memory leak
// will be reported at the end of the query execution.
py::Object gc(PyImport_ImportModule("gc"));
if (!gc) {
LOG_FATAL(py::FetchError().value());
}
// It is *VERY IMPORTANT* to note that this code takes great care not to keep
// any extra references to any `_mgp` instances (except for `_mgp.Graph`), so
// as not to introduce extra reference counts and prevent their deallocation.
// In particular, the `ExceptionInfo` object has a `traceback` field that
// contains references to the Python frames and their arguments, and therefore
// our `_mgp` instances as well. Within this code we ensure not to keep the
// `ExceptionInfo` object alive so that no extra reference counts are
// introduced. We only fetch the error message and immediately destroy the
// object.
std::optional<std::string> maybe_msg;
{
py::Object py_graph(MakePyGraph(graph, memory));
py::Object py_messages(MakePyMessages(msgs, memory));
if (!gc.CallMethod("collect")) {
LOG_FATAL(py::FetchError().value());
}
utils::OnScopeExit clean_up_graph(PyObjectCleanup(py_graph));
utils::OnScopeExit clean_up_messages(PyObjectCleanup(py_messages));
// After making sure all references from our side have been cleared,
// invalidate the `_mgp.Graph` object. If the user kept a reference to one
// of our `_mgp` instances then this will prevent them from using those
// objects (whose internal `mgp_*` pointers are now invalid and would cause
// a crash).
if (!py_graph.CallMethod("invalidate")) {
LOG_FATAL(py::FetchError().value());
if (py_graph && py_messages) {
maybe_msg = error_to_msg(call(py_graph, py_messages));
} else {
maybe_msg = error_to_msg(py::FetchError());
}
if (!py_messages.CallMethod("invalidate")) {
LOG_FATAL(py::FetchError().value());
}
if (maybe_msg) {
static_cast<void>(mgp_result_set_error_msg(result, maybe_msg->c_str()));
}
}
void CallPythonFunction(const py::Object &py_cb, mgp_list *args, mgp_graph *graph, mgp_func_result *result,
mgp_memory *memory) {
auto gil = py::EnsureGIL();
auto error_to_msg = [](const std::optional<py::ExceptionInfo> &exc_info) -> std::optional<std::string> {
if (!exc_info) return std::nullopt;
// Here we tell the traceback formatter to skip the first line of the
// traceback because that line will always be our wrapper function in our
// internal `mgp.py` file. With that line skipped, the user will always
// get only the relevant traceback that happened in his Python code.
return py::FormatException(*exc_info, /* skip_first_line = */ true);
};
auto call = [&](py::Object py_graph) -> utils::BasicResult<std::optional<py::ExceptionInfo>, mgp_value *> {
py::Object py_args(MgpListToPyTuple(args, py_graph.Ptr()));
if (!py_args) return {py::FetchError()};
auto py_res = py_cb.Call(py_graph, py_args);
if (!py_res) return {py::FetchError()};
mgp_value *ret_val = PyObjectToMgpValueWithPythonExceptions(py_res.Ptr(), memory);
if (ret_val == nullptr) {
return {py::FetchError()};
}
return ret_val;
};
// It is *VERY IMPORTANT* to note that this code takes great care not to keep
@ -1012,22 +1093,22 @@ void CallPythonTransformation(const py::Object &py_cb, mgp_messages *msgs, mgp_g
std::optional<std::string> maybe_msg;
{
py::Object py_graph(MakePyGraph(graph, memory));
py::Object py_messages(MakePyMessages(msgs, memory));
if (py_graph && py_messages) {
try {
maybe_msg = error_to_msg(call(py_graph, py_messages));
cleanup(py_graph, py_messages);
} catch (...) {
cleanup(py_graph, py_messages);
throw;
utils::OnScopeExit clean_up(PyObjectCleanup(py_graph));
if (py_graph) {
auto maybe_result = call(py_graph);
if (!maybe_result.HasError()) {
static_cast<void>(mgp_func_result_set_value(result, maybe_result.GetValue(), memory));
return;
}
maybe_msg = error_to_msg(maybe_result.GetError());
} else {
maybe_msg = error_to_msg(py::FetchError());
}
}
if (maybe_msg) {
static_cast<void>(mgp_result_set_error_msg(result, maybe_msg->c_str()));
static_cast<void>(
mgp_func_result_set_error_msg(result, maybe_msg->c_str(), memory)); // No error fetching if this fails
}
}
@ -1056,9 +1137,9 @@ PyObject *PyQueryModuleAddProcedure(PyQueryModule *self, PyObject *cb, bool is_w
PyErr_SetString(PyExc_ValueError, "Already registered a procedure with the same name.");
return nullptr;
}
auto *py_proc = PyObject_New(PyQueryProc, &PyQueryProcType);
auto *py_proc = PyObject_New(PyQueryProc, &PyQueryProcType); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
if (!py_proc) return nullptr;
py_proc->proc = &proc_it->second;
py_proc->callable = &proc_it->second;
return reinterpret_cast<PyObject *>(py_proc);
}
} // namespace
@ -1100,6 +1181,39 @@ PyObject *PyQueryModuleAddTransformation(PyQueryModule *self, PyObject *cb) {
Py_RETURN_NONE;
}
PyObject *PyQueryModuleAddFunction(PyQueryModule *self, PyObject *cb) {
MG_ASSERT(self->module);
if (!PyCallable_Check(cb)) {
PyErr_SetString(PyExc_TypeError, "Expected a callable object.");
return nullptr;
}
auto py_cb = py::Object::FromBorrow(cb);
py::Object py_name(py_cb.GetAttr("__name__"));
const auto *name = PyUnicode_AsUTF8(py_name.Ptr());
if (!name) return nullptr;
if (!IsValidIdentifierName(name)) {
PyErr_SetString(PyExc_ValueError, "Function name is not a valid identifier");
return nullptr;
}
auto *memory = self->module->functions.get_allocator().GetMemoryResource();
mgp_func func(
name,
[py_cb](mgp_list *args, mgp_func_context *func_ctx, mgp_func_result *result, mgp_memory *memory) {
auto graph = mgp_graph::NonWritableGraph(*(func_ctx->impl), func_ctx->view);
return CallPythonFunction(py_cb, args, &graph, result, memory);
},
memory);
const auto [func_it, did_insert] = self->module->functions.emplace(name, std::move(func));
if (!did_insert) {
PyErr_SetString(PyExc_ValueError, "Already registered a function with the same name.");
return nullptr;
}
auto *py_func = PyObject_New(PyMagicFunc, &PyMagicFuncType); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
if (!py_func) return nullptr;
py_func->callable = &func_it->second;
return reinterpret_cast<PyObject *>(py_func);
}
static PyMethodDef PyQueryModuleMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
{"add_read_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddReadProcedure), METH_O,
@ -1108,6 +1222,8 @@ static PyMethodDef PyQueryModuleMethods[] = {
"Register a writeable procedure with this module."},
{"add_transformation", reinterpret_cast<PyCFunction>(PyQueryModuleAddTransformation), METH_O,
"Register a transformation with this module."},
{"add_function", reinterpret_cast<PyCFunction>(PyQueryModuleAddFunction), METH_O,
"Register a function with this module."},
{nullptr},
};
@ -1980,6 +2096,7 @@ PyObject *PyInitMgpModule() {
if (!register_type(&PyGraphType, "Graph")) return nullptr;
if (!register_type(&PyEdgeType, "Edge")) return nullptr;
if (!register_type(&PyQueryProcType, "Proc")) return nullptr;
if (!register_type(&PyMagicFuncType, "Func")) return nullptr;
if (!register_type(&PyQueryModuleType, "Module")) return nullptr;
if (!register_type(&PyVertexType, "Vertex")) return nullptr;
if (!register_type(&PyPathType, "Path")) return nullptr;

View File

@ -6,6 +6,14 @@ add_custom_target(memgraph__e2e__${TARGET_PREFIX}__${FILE_NAME} ALL
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME})
endfunction()
function(copy_e2e_cpp_files TARGET_PREFIX FILE_NAME)
add_custom_target(memgraph__e2e__${TARGET_PREFIX}__${FILE_NAME} ALL
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME}
${CMAKE_CURRENT_BINARY_DIR}/${FILE_NAME}
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME})
endfunction()
add_subdirectory(replication)
add_subdirectory(memory)
add_subdirectory(triggers)
@ -13,6 +21,7 @@ add_subdirectory(isolation_levels)
add_subdirectory(streams)
add_subdirectory(temporal_types)
add_subdirectory(write_procedures)
add_subdirectory(magic_functions)
add_subdirectory(module_file_manager)
add_subdirectory(websocket)

View File

@ -0,0 +1,17 @@
# Set up C++ functions for e2e tests
function(add_query_module target_name src)
add_library(${target_name} SHARED ${src})
SET_TARGET_PROPERTIES(${target_name} PROPERTIES PREFIX "")
target_include_directories(${target_name} PRIVATE ${CMAKE_SOURCE_DIR}/include)
endfunction()
# Set up Python functions for e2e tests
function(copy_magic_functions_e2e_python_files FILE_NAME)
copy_e2e_python_files(functions ${FILE_NAME})
endfunction()
copy_magic_functions_e2e_python_files(common.py)
copy_magic_functions_e2e_python_files(conftest.py)
copy_magic_functions_e2e_python_files(function_example.py)
add_subdirectory(functions)

View File

@ -0,0 +1,35 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import mgclient
import typing
def execute_and_fetch_all(
cursor: mgclient.Cursor, query: str, params: dict = {}
) -> typing.List[tuple]:
cursor.execute(query, params)
return cursor.fetchall()
def connect(**kwargs) -> mgclient.Connection:
connection = mgclient.connect(host="localhost", port=7687, **kwargs)
connection.autocommit = True
return connection
def has_n_result_row(cursor: mgclient.Cursor, query: str, n: int):
results = execute_and_fetch_all(cursor, query)
return len(results) == n
def has_one_result_row(cursor: mgclient.Cursor, query: str):
return has_n_result_row(cursor, query, 1)

View File

@ -0,0 +1,22 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import pytest
from common import execute_and_fetch_all, connect
@pytest.fixture(autouse=True)
def connection():
connection = connect()
yield connection
cursor = connection.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")

View File

@ -0,0 +1,122 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import typing
import mgclient
import sys
import pytest
from common import execute_and_fetch_all, has_n_result_row
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_return_argument(connection, function_type):
cursor = connection.cursor()
execute_and_fetch_all(cursor, "CREATE (n:Label {id: 1});")
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 1)
result = execute_and_fetch_all(
cursor,
f"MATCH (n) RETURN {function_type}_read.return_function_argument(n) AS argument;",
)
vertex = result[0][0]
assert isinstance(vertex, mgclient.Node)
assert vertex.labels == set(["Label"])
assert vertex.properties == {"id": 1}
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_return_optional_argument(connection, function_type):
cursor = connection.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
result = execute_and_fetch_all(
cursor,
f"RETURN {function_type}_read.return_optional_argument(42) AS argument;",
)
result = result[0][0]
assert isinstance(result, int)
assert result == 42
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_return_optional_argument_no_arg(connection, function_type):
cursor = connection.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
result = execute_and_fetch_all(
cursor,
f"RETURN {function_type}_read.return_optional_argument() AS argument;",
)
result = result[0][0]
assert isinstance(result, int)
assert result == 42
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_add_two_numbers(connection, function_type):
cursor = connection.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
result = execute_and_fetch_all(
cursor,
f"RETURN {function_type}_read.add_two_numbers(1, 5) AS total;",
)
result_sum = result[0][0]
assert isinstance(result_sum, (float, int))
assert result_sum == 6
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_return_null(connection, function_type):
cursor = connection.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
result = execute_and_fetch_all(
cursor,
f"RETURN {function_type}_read.return_null() AS null;",
)
result_null = result[0][0]
assert result_null is None
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_too_many_arguments(connection, function_type):
cursor = connection.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
# Should raise too many arguments
with pytest.raises(mgclient.DatabaseError):
execute_and_fetch_all(
cursor,
f"RETURN {function_type}_read.return_null('parameter') AS null;",
)
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_try_to_write(connection, function_type):
cursor = connection.cursor()
execute_and_fetch_all(cursor, "CREATE (n:Label {id: 1});")
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 1)
# Should raise non mutable
with pytest.raises(mgclient.DatabaseError):
execute_and_fetch_all(
cursor,
f"MATCH (n) RETURN {function_type}_write.try_to_write(n, 'property', 1);",
)
@pytest.mark.parametrize("function_type", ["py", "c"])
def test_case_sensitivity(connection, function_type):
cursor = connection.cursor()
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
# Should raise function does not exist
with pytest.raises(mgclient.DatabaseError):
execute_and_fetch_all(
cursor,
f"RETURN {function_type}_read.ReTuRn_nUlL('parameter') AS null;",
)
if __name__ == "__main__":
sys.exit(pytest.main([__file__, "-rA"]))

View File

@ -0,0 +1,5 @@
copy_magic_functions_e2e_python_files(py_write.py)
copy_magic_functions_e2e_python_files(py_read.py)
add_query_module(c_read c_read.cpp)
add_query_module(c_write c_write.cpp)

View File

@ -0,0 +1,178 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <functional>
#include <stdexcept>
#include "mg_procedure.h"
#include "utils/on_scope_exit.hpp"
namespace {
static void ReturnFunctionArgument(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
struct mgp_memory *memory) {
mgp_value *value{nullptr};
auto err_code = mgp_list_at(args, 0, &value);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory);
return;
}
err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
return;
}
}
static void ReturnOptionalArgument(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
struct mgp_memory *memory) {
mgp_value *value{nullptr};
auto err_code = mgp_list_at(args, 0, &value);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory);
return;
}
err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
return;
}
}
double GetElementFromArg(struct mgp_list *args, int index) {
mgp_value *value{nullptr};
if (mgp_list_at(args, index, &value) != MGP_ERROR_NO_ERROR) {
throw std::runtime_error("Error while argument fetching.");
}
double result;
int is_int;
mgp_value_is_int(value, &is_int);
if (is_int) {
int64_t result_int;
mgp_value_get_int(value, &result_int);
result = static_cast<double>(result_int);
} else {
mgp_value_get_double(value, &result);
}
return result;
}
static void AddTwoNumbers(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
struct mgp_memory *memory) {
double first = 0;
double second = 0;
try {
first = GetElementFromArg(args, 0);
second = GetElementFromArg(args, 1);
} catch (...) {
mgp_func_result_set_error_msg(result, "Unable to fetch the result!", memory);
return;
}
mgp_value *value{nullptr};
auto summation = first + second;
mgp_value_make_double(summation, memory, &value);
memgraph::utils::OnScopeExit delete_summation_value([&value] { mgp_value_destroy(value); });
auto err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
}
}
static void ReturnNull(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
struct mgp_memory *memory) {
mgp_value *value{nullptr};
mgp_value_make_null(memory, &value);
memgraph::utils::OnScopeExit delete_null([&value] { mgp_value_destroy(value); });
auto err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory);
}
}
} // namespace
// Each module needs to define mgp_init_module function.
// Here you can register multiple functions/procedures your module supports.
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
{
mgp_func *func{nullptr};
auto err_code = mgp_module_add_function(module, "return_function_argument", ReturnFunctionArgument, &func);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
mgp_type *type_any{nullptr};
mgp_type_any(&type_any);
err_code = mgp_func_add_arg(func, "argument", type_any);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
}
{
mgp_func *func{nullptr};
auto err_code = mgp_module_add_function(module, "return_optional_argument", ReturnOptionalArgument, &func);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
mgp_value *default_value{nullptr};
mgp_value_make_int(42, memory, &default_value);
memgraph::utils::OnScopeExit delete_summation_value([&default_value] { mgp_value_destroy(default_value); });
mgp_type *type_int{nullptr};
mgp_type_int(&type_int);
err_code = mgp_func_add_opt_arg(func, "opt_argument", type_int, default_value);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
}
{
mgp_func *func{nullptr};
auto err_code = mgp_module_add_function(module, "add_two_numbers", AddTwoNumbers, &func);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
mgp_type *type_number{nullptr};
mgp_type_number(&type_number);
err_code = mgp_func_add_arg(func, "first", type_number);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
err_code = mgp_func_add_arg(func, "second", type_number);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
}
{
mgp_func *func{nullptr};
auto err_code = mgp_module_add_function(module, "return_null", ReturnNull, &func);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
}
return 0;
}
// This is an optional function if you need to release any resources before the
// module is unloaded. You will probably need this if you acquired some
// resources in mgp_init_module.
extern "C" int mgp_shutdown_module() { return 0; }

View File

@ -0,0 +1,80 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "mg_procedure.h"
static void TryToWrite(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result,
struct mgp_memory *memory) {
mgp_value *value{nullptr};
mgp_vertex *vertex{nullptr};
mgp_list_at(args, 0, &value);
mgp_value_get_vertex(value, &vertex);
const char *name;
mgp_list_at(args, 1, &value);
mgp_value_get_string(value, &name);
mgp_list_at(args, 2, &value);
// Setting a property should set an error
auto err_code = mgp_vertex_set_property(vertex, name, value);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Cannot set property in the function!", memory);
return;
}
err_code = mgp_func_result_set_value(result, value, memory);
if (err_code != MGP_ERROR_NO_ERROR) {
mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory);
return;
}
}
// Each module needs to define mgp_init_module function.
// Here you can register multiple functions/procedures your module supports.
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
{
mgp_func *func{nullptr};
auto err_code = mgp_module_add_function(module, "try_to_write", TryToWrite, &func);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
mgp_type *type_vertex{nullptr};
mgp_type_node(&type_vertex);
err_code = mgp_func_add_arg(func, "argument", type_vertex);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
mgp_type *type_string{nullptr};
mgp_type_string(&type_string);
err_code = mgp_func_add_arg(func, "name", type_string);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
mgp_type *any_type{nullptr};
mgp_type_any(&any_type);
mgp_type *nullable_type{nullptr};
mgp_type_nullable(any_type, &nullable_type);
err_code = mgp_func_add_arg(func, "value", nullable_type);
if (err_code != MGP_ERROR_NO_ERROR) {
return 1;
}
}
return 0;
}
// This is an optional function if you need to release any resources before the
// module is unloaded. You will probably need this if you acquired some
// resources in mgp_init_module.
extern "C" int mgp_shutdown_module() { return 0; }

View File

@ -0,0 +1,32 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import mgp
@mgp.function
def return_function_argument(ctx: mgp.FuncCtx, argument: mgp.Any):
return argument
@mgp.function
def return_optional_argument(ctx: mgp.FuncCtx, opt_argument: mgp.Number = 42):
return opt_argument
@mgp.function
def add_two_numbers(ctx: mgp.FuncCtx, first: mgp.Number, second: mgp.Number):
return first + second
@mgp.function
def return_null(ctx: mgp.FuncCtx):
return None

View File

@ -0,0 +1,17 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import mgp
@mgp.function
def try_to_write(ctx: mgp.FuncCtx, argument: mgp.Vertex, name: str, value: mgp.Nullable[mgp.Any]):
argument.properties.set(name, value)

View File

@ -0,0 +1,14 @@
template_cluster: &template_cluster
cluster:
main:
args: ["--bolt-port", "7687", "--log-level=TRACE"]
log_file: "magic-functions-e2e.log"
setup_queries: []
validation_queries: []
workloads:
- name: "Magic functions runner"
binary: "tests/e2e/pytest_runner.sh"
proc: "tests/e2e/magic_functions/functions/"
args: ["magic_functions/function_example.py"]
<<: *template_cluster

View File

@ -129,6 +129,12 @@ target_link_libraries(${test_prefix}query_serialization_property_value mg-query)
add_unit_test(query_streams.cpp)
target_link_libraries(${test_prefix}query_streams mg-query kafka-mock)
# Test query functions
add_unit_test(query_function_mgp_module.cpp)
target_link_libraries(${test_prefix}query_function_mgp_module mg-query)
target_include_directories(${test_prefix}query_function_mgp_module PRIVATE ${CMAKE_SOURCE_DIR}/include)
# Test query/procedure
add_unit_test(query_procedure_mgp_type.cpp)
target_link_libraries(${test_prefix}query_procedure_mgp_type mg-query)

View File

@ -211,13 +211,19 @@ class MockModule : public procedure::Module {
const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override { return &procedures; }
const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override { return &transformations; }
std::optional<std::filesystem::path> Path() const override { return std::nullopt; }
const std::map<std::string, mgp_func, std::less<>> *Functions() const override { return &functions; }
std::optional<std::filesystem::path> Path() const override { return std::nullopt; };
std::map<std::string, mgp_proc, std::less<>> procedures{};
std::map<std::string, mgp_trans, std::less<>> transformations{};
std::map<std::string, mgp_func, std::less<>> functions{};
};
void DummyProcCallback(mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result * /*result*/, mgp_memory * /*memory*/){};
void DummyFuncCallback(mgp_list * /*args*/, mgp_func_context * /*func_ctx*/, mgp_func_result * /*result*/,
mgp_memory * /*memory*/){};
enum class ProcedureType { WRITE, READ };
@ -258,6 +264,15 @@ class CypherMainVisitorTest : public ::testing::TestWithParam<std::shared_ptr<Ba
module.procedures.emplace(name, std::move(proc));
}
static void AddFunc(MockModule &module, const char *name, const std::vector<std::string_view> &args) {
memgraph::utils::MemoryResource *memory = memgraph::utils::NewDeleteResource();
mgp_func func(name, DummyFuncCallback, memory);
for (const auto arg : args) {
func.args.emplace_back(memgraph::utils::pmr::string{arg, memory}, &any_type);
}
module.functions.emplace(name, std::move(func));
}
std::string CreateProcByType(const ProcedureType type, const std::vector<std::string_view> &args) {
const auto proc_name = std::string{"proc_"} + ToString(type);
SCOPED_TRACE(proc_name);
@ -858,6 +873,12 @@ TEST_P(CypherMainVisitorTest, UndefinedFunction) {
SemanticException);
}
TEST_P(CypherMainVisitorTest, MissingFunction) {
AddFunc(*mock_module, "get", {});
auto &ast_generator = *GetParam();
ASSERT_THROW(ast_generator.ParseQuery("RETURN missing_function.get()"), SemanticException);
}
TEST_P(CypherMainVisitorTest, Function) {
auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN abs(n, 2)"));
@ -871,6 +892,20 @@ TEST_P(CypherMainVisitorTest, Function) {
ASSERT_TRUE(function->function_);
}
TEST_P(CypherMainVisitorTest, MagicFunction) {
AddFunc(*mock_module, "get", {});
auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN mock_module.get()"));
ASSERT_TRUE(query);
ASSERT_TRUE(query->single_query_);
auto *single_query = query->single_query_;
auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]);
ASSERT_EQ(return_clause->body_.named_expressions.size(), 1);
auto *function = dynamic_cast<Function *>(return_clause->body_.named_expressions[0]->expression_);
ASSERT_TRUE(function);
ASSERT_TRUE(function->function_);
}
TEST_P(CypherMainVisitorTest, StringLiteralDoubleQuotes) {
auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN \"mi'rko\""));

View File

@ -0,0 +1,49 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <gtest/gtest.h>
#include <functional>
#include <sstream>
#include <string_view>
#include "query/procedure/mg_procedure_impl.hpp"
#include "test_utils.hpp"
static void DummyCallback(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *){};
TEST(Module, InvalidFunctionRegistration) {
mgp_module module(memgraph::utils::NewDeleteResource());
mgp_func *func{nullptr};
// Other test cases are covered within the procedure API. This is only sanity check
EXPECT_EQ(mgp_module_add_function(&module, "dashes-not-supported", DummyCallback, &func), MGP_ERROR_INVALID_ARGUMENT);
}
TEST(Module, RegisterSameFunctionMultipleTimes) {
mgp_module module(memgraph::utils::NewDeleteResource());
mgp_func *func{nullptr};
EXPECT_EQ(module.functions.find("same_name"), module.functions.end());
EXPECT_EQ(mgp_module_add_function(&module, "same_name", DummyCallback, &func), MGP_ERROR_NO_ERROR);
EXPECT_NE(module.functions.find("same_name"), module.functions.end());
EXPECT_EQ(mgp_module_add_function(&module, "same_name", DummyCallback, &func), MGP_ERROR_LOGIC_ERROR);
EXPECT_EQ(mgp_module_add_function(&module, "same_name", DummyCallback, &func), MGP_ERROR_LOGIC_ERROR);
EXPECT_NE(module.functions.find("same_name"), module.functions.end());
}
TEST(Module, CaseSensitiveFunctionNames) {
mgp_module module(memgraph::utils::NewDeleteResource());
mgp_func *func{nullptr};
EXPECT_EQ(mgp_module_add_function(&module, "not_same", DummyCallback, &func), MGP_ERROR_NO_ERROR);
EXPECT_EQ(mgp_module_add_function(&module, "NoT_saME", DummyCallback, &func), MGP_ERROR_NO_ERROR);
EXPECT_EQ(mgp_module_add_function(&module, "NOT_SAME", DummyCallback, &func), MGP_ERROR_NO_ERROR);
EXPECT_EQ(module.functions.size(), 3U);
}