From 4abaf277653da0934e465670c937c4eee17f64f1 Mon Sep 17 00:00:00 2001 From: Josip Matak <31473001+jmatak@users.noreply.github.com> Date: Thu, 21 Apr 2022 15:45:31 +0200 Subject: [PATCH] Memgraph magic functions (#345) * Extend mgp_module with include adding functions * Add return type to the function API * Change Cypher grammar * Add Python support for functions * Implement error handling * E2e tests for functions * Write cpp e2e functions * Create mg.functions() procedure * Implement case insensitivity for user-defined Magic Functions. --- include/mg_procedure.h | 85 +++++- include/mgp.py | 247 ++++++++++------ src/query/frontend/ast/ast.lcp | 4 +- .../frontend/ast/cypher_main_visitor.cpp | 23 +- .../frontend/opencypher/grammar/Cypher.g4 | 2 +- .../interpret/awesome_memgraph_functions.cpp | 58 ++++ src/query/plan/operator.cpp | 46 +-- src/query/procedure/mg_procedure_impl.cpp | 130 ++++++-- src/query/procedure/mg_procedure_impl.hpp | 87 +++++- src/query/procedure/module.cpp | 124 +++++++- src/query/procedure/module.hpp | 61 ++++ src/query/procedure/py_module.cpp | 279 +++++++++++++----- tests/e2e/CMakeLists.txt | 9 + tests/e2e/magic_functions/CMakeLists.txt | 17 ++ tests/e2e/magic_functions/common.py | 35 +++ tests/e2e/magic_functions/conftest.py | 22 ++ tests/e2e/magic_functions/function_example.py | 122 ++++++++ .../magic_functions/functions/CMakeLists.txt | 5 + .../e2e/magic_functions/functions/c_read.cpp | 178 +++++++++++ .../e2e/magic_functions/functions/c_write.cpp | 80 +++++ .../e2e/magic_functions/functions/py_read.py | 32 ++ .../e2e/magic_functions/functions/py_write.py | 17 ++ tests/e2e/magic_functions/workloads.yaml | 14 + tests/unit/CMakeLists.txt | 6 + tests/unit/cypher_main_visitor.cpp | 37 ++- tests/unit/query_function_mgp_module.cpp | 49 +++ 26 files changed, 1523 insertions(+), 246 deletions(-) create mode 100644 tests/e2e/magic_functions/CMakeLists.txt create mode 100644 tests/e2e/magic_functions/common.py create mode 100644 
tests/e2e/magic_functions/conftest.py create mode 100644 tests/e2e/magic_functions/function_example.py create mode 100644 tests/e2e/magic_functions/functions/CMakeLists.txt create mode 100644 tests/e2e/magic_functions/functions/c_read.cpp create mode 100644 tests/e2e/magic_functions/functions/c_write.cpp create mode 100644 tests/e2e/magic_functions/functions/py_read.py create mode 100644 tests/e2e/magic_functions/functions/py_write.py create mode 100644 tests/e2e/magic_functions/workloads.yaml create mode 100644 tests/unit/query_function_mgp_module.cpp diff --git a/include/mg_procedure.h b/include/mg_procedure.h index bf47d9477..8bf29afeb 100644 --- a/include/mg_procedure.h +++ b/include/mg_procedure.h @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -549,6 +549,8 @@ enum mgp_error mgp_path_equal(struct mgp_path *p1, struct mgp_path *p2, int *res struct mgp_result; /// Represents a record of resulting field values. struct mgp_result_record; +/// Represents a return type for magic functions +struct mgp_func_result; /// Set the error as the result of the procedure. /// Return MGP_ERROR_UNABLE_TO_ALLOCATE ff there's no memory for copying the error message. @@ -1290,6 +1292,9 @@ struct mgp_module; /// Describes a procedure of a query module. struct mgp_proc; +/// Describes a Memgraph magic function. +struct mgp_func; + /// Entry-point for a query module read procedure, invoked through openCypher. /// /// Passed in arguments will not live longer than the callback's execution. 
@@ -1502,6 +1507,84 @@ typedef void (*mgp_trans_cb)(struct mgp_messages *, struct mgp_graph *, struct m enum mgp_error mgp_module_add_transformation(struct mgp_module *module, const char *name, mgp_trans_cb cb); /// @} +/// @name Memgraph Magic Functions API +/// +/// API for creating the Memgraph magic functions. It is used to create external-source stateless methods which can +/// be called by using openCypher query language. These methods should not modify the original graph and should use only +/// the values provided as arguments to the method. +/// +///@{ + +/// Add a required argument to a function. +/// +/// The order of the added arguments corresponds to the signature of the openCypher function. +/// Note, that required arguments are followed by optional arguments. +/// +/// The `name` must be a valid identifier, following the same rules as the +/// function `name` in mgp_module_add_function. +/// +/// Passed in `type` describes what kind of values can be used as the argument. +/// +/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for an argument. +/// Return MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid argument name. +/// Return MGP_ERROR_LOGIC_ERROR if the function already has any optional argument. +enum mgp_error mgp_func_add_arg(struct mgp_func *func, const char *name, struct mgp_type *type); + +/// Add an optional argument with a default value to a function. +/// +/// The order of the added arguments corresponds to the signature of the openCypher function. +/// Note, that required arguments are followed by optional arguments. +/// +/// The `name` must be a valid identifier, following the same rules as the +/// function `name` in mgp_module_add_function. +/// +/// Passed in `type` describes what kind of values can be used as the argument. +/// +/// `default_value` is copied and set as the default value for the argument. +/// Don't forget to call mgp_value_destroy when you are done using +/// `default_value`. 
When the function is called, if this argument is not +/// provided, `default_value` will be used instead. `default_value` must not be +/// a graph element (node, relationship, path) and it must satisfy the given +/// `type`. +/// +/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for an argument. +/// Return MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid argument name. +/// Return MGP_ERROR_VALUE_CONVERSION if `default_value` is a graph element (vertex, edge or path). +/// Return MGP_ERROR_LOGIC_ERROR if `default_value` does not satisfy `type`. +enum mgp_error mgp_func_add_opt_arg(struct mgp_func *func, const char *name, struct mgp_type *type, + struct mgp_value *default_value); + +/// Entry-point for a custom Memgraph awesome function. +/// +/// Passed in arguments will not live longer than the callback's execution. +/// Therefore, you must not store them globally or use the passed in mgp_memory +/// to allocate global resources. +typedef void (*mgp_func_cb)(struct mgp_list *, struct mgp_func_context *, struct mgp_func_result *, + struct mgp_memory *); + +/// Register a Memgraph magic function +/// +/// The `name` must be a sequence of digits, underscores, lowercase and +/// uppercase Latin letters. The name must begin with a non-digit character. +/// Note that Unicode characters are not allowed. +/// +/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for mgp_func. +/// Return MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid function name. +/// Return MGP_ERROR_LOGIC_ERROR if a function with the same name was already registered. +enum mgp_error mgp_module_add_function(struct mgp_module *module, const char *name, mgp_func_cb cb, + struct mgp_func **result); + +/// Set an error message as an output to the Magic function +/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if there's no memory for copying the error message. 
+enum mgp_error mgp_func_result_set_error_msg(struct mgp_func_result *result, const char *error_msg, + struct mgp_memory *memory); + +/// Set an output value for the Magic function +/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory to copy the mgp_value to mgp_func_result. +enum mgp_error mgp_func_result_set_value(struct mgp_func_result *result, struct mgp_value *value, + struct mgp_memory *memory); +/// @} + #ifdef __cplusplus } // extern "C" #endif diff --git a/include/mgp.py b/include/mgp.py index ec2262644..7b9d2dcb9 100644 --- a/include/mgp.py +++ b/include/mgp.py @@ -40,6 +40,7 @@ class InvalidContextError(Exception): """ Signals using a graph element instance outside of the registered procedure. """ + pass @@ -47,6 +48,7 @@ class UnknownError(_mgp.UnknownError): """ Signals unspecified failure. """ + pass @@ -54,6 +56,7 @@ class UnableToAllocateError(_mgp.UnableToAllocateError): """ Signals failed memory allocation. """ + pass @@ -61,6 +64,7 @@ class InsufficientBufferError(_mgp.InsufficientBufferError): """ Signals that some buffer is not big enough. """ + pass @@ -69,6 +73,7 @@ class OutOfRangeError(_mgp.OutOfRangeError): Signals that an index-like parameter has a value that is outside its possible values. """ + pass @@ -77,6 +82,7 @@ class LogicErrorError(_mgp.LogicErrorError): Signals faulty logic within the program such as violating logical preconditions or class invariants and may be preventable. """ + pass @@ -84,6 +90,7 @@ class DeletedObjectError(_mgp.DeletedObjectError): """ Signals accessing an already deleted object. """ + pass @@ -91,6 +98,7 @@ class InvalidArgumentError(_mgp.InvalidArgumentError): """ Signals that some of the arguments have invalid values. """ + pass @@ -98,6 +106,7 @@ class KeyAlreadyExistsError(_mgp.KeyAlreadyExistsError): """ Signals that a key already exists in a container-like object. 
""" + pass @@ -105,6 +114,7 @@ class ImmutableObjectError(_mgp.ImmutableObjectError): """ Signals modification of an immutable object. """ + pass @@ -112,6 +122,7 @@ class ValueConversionError(_mgp.ValueConversionError): """ Signals that the conversion failed between python and cypher values. """ + pass @@ -120,12 +131,14 @@ class SerializationError(_mgp.SerializationError): Signals serialization error caused by concurrent modifications from different transactions. """ + pass class Label: """Label of a Vertex.""" - __slots__ = ('_name',) + + __slots__ = ("_name",) def __init__(self, name: str): self._name = name @@ -145,19 +158,22 @@ class Label: # Named property value of a Vertex or an Edge. # It would be better to use typing.NamedTuple with typed fields, but that is # not available in Python 3.5. -Property = namedtuple('Property', ('name', 'value')) +Property = namedtuple("Property", ("name", "value")) class Properties: """ A collection of properties either on a Vertex or an Edge. """ - __slots__ = ('_vertex_or_edge', '_len',) + + __slots__ = ( + "_vertex_or_edge", + "_len", + ) def __init__(self, vertex_or_edge): if not isinstance(vertex_or_edge, (_mgp.Vertex, _mgp.Edge)): - raise TypeError("Expected '_mgp.Vertex' or '_mgp.Edge', \ - got {}".format(type(vertex_or_edge))) + raise TypeError("Expected '_mgp.Vertex' or '_mgp.Edge', got {}".format(type(vertex_or_edge))) self._len = None self._vertex_or_edge = vertex_or_edge @@ -330,7 +346,8 @@ class Properties: class EdgeType: """Type of an Edge.""" - __slots__ = ('_name',) + + __slots__ = ("_name",) def __init__(self, name): self._name = name @@ -348,7 +365,7 @@ class EdgeType: if sys.version_info >= (3, 5, 2): - EdgeId = typing.NewType('EdgeId', int) + EdgeId = typing.NewType("EdgeId", int) else: EdgeId = int @@ -360,12 +377,12 @@ class Edge: a query. You should not globally store an instance of an Edge. Using an invalid Edge instance will raise InvalidContextError. 
""" - __slots__ = ('_edge',) + + __slots__ = ("_edge",) def __init__(self, edge): if not isinstance(edge, _mgp.Edge): - raise TypeError( - "Expected '_mgp.Edge', got '{}'".format(type(edge))) + raise TypeError("Expected '_mgp.Edge', got '{}'".format(type(edge))) self._edge = edge def __deepcopy__(self, memo): @@ -408,7 +425,7 @@ class Edge: return EdgeType(self._edge.get_type_name()) @property - def from_vertex(self) -> 'Vertex': + def from_vertex(self) -> "Vertex": """ Get the source vertex. @@ -419,7 +436,7 @@ class Edge: return Vertex(self._edge.from_vertex()) @property - def to_vertex(self) -> 'Vertex': + def to_vertex(self) -> "Vertex": """ Get the destination vertex. @@ -453,7 +470,7 @@ class Edge: if sys.version_info >= (3, 5, 2): - VertexId = typing.NewType('VertexId', int) + VertexId = typing.NewType("VertexId", int) else: VertexId = int @@ -465,12 +482,12 @@ class Vertex: in a query. You should not globally store an instance of a Vertex. Using an invalid Vertex instance will raise InvalidContextError. """ - __slots__ = ('_vertex',) + + __slots__ = ("_vertex",) def __init__(self, vertex): if not isinstance(vertex, _mgp.Vertex): - raise TypeError( - "Expected '_mgp.Vertex', got '{}'".format(type(vertex))) + raise TypeError("Expected '_mgp.Vertex', got '{}'".format(type(vertex))) self._vertex = vertex def __deepcopy__(self, memo): @@ -513,8 +530,7 @@ class Vertex: """ if not self.is_valid(): raise InvalidContextError() - return tuple(Label(self._vertex.label_at(i)) - for i in range(self._vertex.labels_count())) + return tuple(Label(self._vertex.label_at(i)) for i in range(self._vertex.labels_count())) def add_label(self, label: str) -> None: """ @@ -615,7 +631,8 @@ class Vertex: class Path: """Path containing Vertex and Edge instances.""" - __slots__ = ('_path', '_vertices', '_edges') + + __slots__ = ("_path", "_vertices", "_edges") def __init__(self, starting_vertex_or_path: typing.Union[_mgp.Path, Vertex]): """Initialize with a starting Vertex. 
@@ -636,8 +653,7 @@ class Path: raise InvalidContextError() self._path = _mgp.Path.make_with_start(vertex) else: - raise TypeError("Expected '_mgp.Vertex' or '_mgp.Path', got '{}'" - .format(type(starting_vertex_or_path))) + raise TypeError("Expected '_mgp.Vertex' or '_mgp.Path', got '{}'".format(type(starting_vertex_or_path))) def __copy__(self): if not self.is_valid(): @@ -678,8 +694,7 @@ class Path: extension. """ if not isinstance(edge, Edge): - raise TypeError( - "Expected '_mgp.Edge', got '{}'".format(type(edge))) + raise TypeError("Expected '_mgp.Edge', got '{}'".format(type(edge))) if not self.is_valid() or not edge.is_valid(): raise InvalidContextError() self._path.expand(edge._edge) @@ -698,8 +713,7 @@ class Path: raise InvalidContextError() if self._vertices is None: num_vertices = self._path.size() + 1 - self._vertices = tuple(Vertex(self._path.vertex_at(i)) - for i in range(num_vertices)) + self._vertices = tuple(Vertex(self._path.vertex_at(i)) for i in range(num_vertices)) return self._vertices @property @@ -713,14 +727,14 @@ class Path: raise InvalidContextError() if self._edges is None: num_edges = self._path.size() - self._edges = tuple(Edge(self._path.edge_at(i)) - for i in range(num_edges)) + self._edges = tuple(Edge(self._path.edge_at(i)) for i in range(num_edges)) return self._edges class Record: """Represents a record of resulting field values.""" - __slots__ = ('fields',) + + __slots__ = ("fields",) def __init__(self, **kwargs): """Initialize with name=value fields in kwargs.""" @@ -729,12 +743,12 @@ class Record: class Vertices: """Iterable over vertices in a graph.""" - __slots__ = ('_graph', '_len') + + __slots__ = ("_graph", "_len") def __init__(self, graph): if not isinstance(graph, _mgp.Graph): - raise TypeError( - "Expected '_mgp.Graph', got '{}'".format(type(graph))) + raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph))) self._graph = graph self._len = None @@ -791,12 +805,12 @@ class Vertices: class Graph: """State 
of the graph database in current ProcCtx.""" - __slots__ = ('_graph',) + + __slots__ = ("_graph",) def __init__(self, graph): if not isinstance(graph, _mgp.Graph): - raise TypeError( - "Expected '_mgp.Graph', got '{}'".format(type(graph))) + raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph))) self._graph = graph def __deepcopy__(self, memo): @@ -885,8 +899,7 @@ class Graph: raise InvalidContextError() self._graph.detach_delete_vertex(vertex._vertex) - def create_edge(self, from_vertex: Vertex, to_vertex: Vertex, - edge_type: EdgeType) -> None: + def create_edge(self, from_vertex: Vertex, to_vertex: Vertex, edge_type: EdgeType) -> None: """ Create an edge. @@ -899,8 +912,7 @@ class Graph: """ if not self.is_valid(): raise InvalidContextError() - return Edge(self._graph.create_edge(from_vertex._vertex, - to_vertex._vertex, edge_type.name)) + return Edge(self._graph.create_edge(from_vertex._vertex, to_vertex._vertex, edge_type.name)) def delete_edge(self, edge: Edge) -> None: """ @@ -918,6 +930,7 @@ class Graph: class AbortError(Exception): """Signals that the procedure was asked to abort its execution.""" + pass @@ -927,12 +940,12 @@ class ProcCtx: Access to a ProcCtx is only valid during a single execution of a procedure in a query. You should not globally store a ProcCtx instance. 
""" - __slots__ = ('_graph',) + + __slots__ = ("_graph",) def __init__(self, graph): if not isinstance(graph, _mgp.Graph): - raise TypeError( - "Expected '_mgp.Graph', got '{}'".format(type(graph))) + raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph))) self._graph = Graph(graph) def is_valid(self) -> bool: @@ -969,8 +982,7 @@ LocalDateTime = datetime.datetime Duration = datetime.timedelta -Any = typing.Union[bool, str, Number, Map, Path, - list, Date, LocalTime, LocalDateTime, Duration] +Any = typing.Union[bool, str, Number, Map, Path, list, Date, LocalTime, LocalDateTime, Duration] List = typing.List @@ -1003,7 +1015,7 @@ def _typing_to_cypher_type(type_): Date: _mgp.type_date(), LocalTime: _mgp.type_local_time(), LocalDateTime: _mgp.type_local_date_time(), - Duration: _mgp.type_duration() + Duration: _mgp.type_duration(), } try: return simple_types[type_] @@ -1021,14 +1033,14 @@ def _typing_to_cypher_type(type_): if type(None) in type_args: types = tuple(t for t in type_args if t is not type(None)) # noqa E721 if len(types) == 1: - type_arg, = types + (type_arg,) = types else: # We cannot do typing.Union[*types], so do the equivalent # with __getitem__ which does not even need arg unpacking. type_arg = typing.Union.__getitem__(types) return _mgp.type_nullable(_typing_to_cypher_type(type_arg)) elif complex_type == list: - type_arg, = type_args + (type_arg,) = type_args return _mgp.type_list(_typing_to_cypher_type(type_arg)) raise UnsupportedTypingError(type_) else: @@ -1038,13 +1050,17 @@ def _typing_to_cypher_type(type_): # printed the same way. 
`typing.List[type]` is printed as such, while # `typing.Optional[type]` is printed as 'typing.Union[type, NoneType]' def parse_type_args(type_as_str): - return tuple(map(str.strip, - type_as_str[type_as_str.index('[') + 1: -1].split(','))) + return tuple( + map( + str.strip, + type_as_str[type_as_str.index("[") + 1 : -1].split(","), + ) + ) def fully_qualified_name(cls): - if cls.__module__ is None or cls.__module__ == 'builtins': + if cls.__module__ is None or cls.__module__ == "builtins": return cls.__name__ - return cls.__module__ + '.' + cls.__name__ + return cls.__module__ + "." + cls.__name__ def get_simple_type(type_as_str): for simple_type, cypher_type in simple_types.items(): @@ -1060,28 +1076,26 @@ def _typing_to_cypher_type(type_): pass def parse_typing(type_as_str): - if type_as_str.startswith('typing.Union'): + if type_as_str.startswith("typing.Union"): type_args_as_str = parse_type_args(type_as_str) none_type_as_str = type(None).__name__ if none_type_as_str in type_args_as_str: - types = tuple( - t for t in type_args_as_str if t != none_type_as_str) + types = tuple(t for t in type_args_as_str if t != none_type_as_str) if len(types) == 1: - type_arg_as_str, = types + (type_arg_as_str,) = types else: - type_arg_as_str = 'typing.Union[' + \ - ', '.join(types) + ']' + type_arg_as_str = "typing.Union[" + ", ".join(types) + "]" simple_type = get_simple_type(type_arg_as_str) if simple_type is not None: return _mgp.type_nullable(simple_type) return _mgp.type_nullable(parse_typing(type_arg_as_str)) - elif type_as_str.startswith('typing.List'): + elif type_as_str.startswith("typing.List"): type_arg_as_str = parse_type_args(type_as_str) if len(type_arg_as_str) > 1: # Nested object could be a type consisting of a list of types (e.g. mgp.Map) # so we need to join the parts. 
- type_arg_as_str = ', '.join(type_arg_as_str) + type_arg_as_str = ", ".join(type_arg_as_str) else: type_arg_as_str = type_arg_as_str[0] @@ -1096,9 +1110,11 @@ def _typing_to_cypher_type(type_): # Procedure registration + class Deprecated: """Annotate a resulting Record's field as deprecated.""" - __slots__ = ('field_type',) + + __slots__ = ("field_type",) def __init__(self, type_): self.field_type = type_ @@ -1106,8 +1122,7 @@ class Deprecated: def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]): if not callable(func): - raise TypeError("Expected a callable object, got an instance of '{}'" - .format(type(func))) + raise TypeError("Expected a callable object, got an instance of '{}'".format(type(func))) if inspect.iscoroutinefunction(func): raise TypeError("Callable must not be 'async def' function") if sys.version_info >= (3, 6): @@ -1117,24 +1132,25 @@ def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]): raise NotImplementedError("Generator functions are not supported") -def _register_proc(func: typing.Callable[..., Record], - is_write: bool): +def _register_proc(func: typing.Callable[..., Record], is_write: bool): raise_if_does_not_meet_requirements(func) - register_func = ( - _mgp.Module.add_write_procedure if is_write - else _mgp.Module.add_read_procedure) + register_func = _mgp.Module.add_write_procedure if is_write else _mgp.Module.add_read_procedure sig = inspect.signature(func) params = tuple(sig.parameters.values()) if params and params[0].annotation is ProcCtx: + @wraps(func) def wrapper(graph, args): return func(ProcCtx(graph), *args) + params = params[1:] mgp_proc = register_func(_mgp._MODULE, wrapper) else: + @wraps(func) def wrapper(graph, args): return func(*args) + mgp_proc = register_func(_mgp._MODULE, wrapper) for param in params: name = param.name @@ -1149,8 +1165,7 @@ def _register_proc(func: typing.Callable[..., Record], if sig.return_annotation is not sig.empty: record = sig.return_annotation if 
not isinstance(record, Record): - raise TypeError("Expected '{}' to return 'mgp.Record', got '{}'" - .format(func.__name__, type(record))) + raise TypeError("Expected '{}' to return 'mgp.Record', got '{}'".format(func.__name__, type(record))) for name, type_ in record.fields.items(): if isinstance(type_, Deprecated): cypher_type = _typing_to_cypher_type(type_.field_type) @@ -1257,20 +1272,22 @@ class InvalidMessageError(Exception): """ Signals using a message instance outside of the registered transformation. """ + pass SOURCE_TYPE_KAFKA = _mgp.SOURCE_TYPE_KAFKA SOURCE_TYPE_PULSAR = _mgp.SOURCE_TYPE_PULSAR + class Message: """Represents a message from a stream.""" - __slots__ = ('_message',) + + __slots__ = ("_message",) def __init__(self, message): if not isinstance(message, _mgp.Message): - raise TypeError( - "Expected '_mgp.Message', got '{}'".format(type(message))) + raise TypeError("Expected '_mgp.Message', got '{}'".format(type(message))) self._message = message def __deepcopy__(self, memo): @@ -1353,17 +1370,18 @@ class Message: class InvalidMessagesError(Exception): """Signals using a messages instance outside of the registered transformation.""" + pass class Messages: """Represents a list of messages from a stream.""" - __slots__ = ('_messages',) + + __slots__ = ("_messages",) def __init__(self, messages): if not isinstance(messages, _mgp.Messages): - raise TypeError( - "Expected '_mgp.Messages', got '{}'".format(type(messages))) + raise TypeError("Expected '_mgp.Messages', got '{}'".format(type(messages))) self._messages = messages def __deepcopy__(self, memo): @@ -1395,12 +1413,12 @@ class TransCtx: Access to a TransCtx is only valid during a single execution of a transformation. You should not globally store a TransCtx instance. 
""" - __slots__ = ('_graph') + + __slots__ = "_graph" def __init__(self, graph): if not isinstance(graph, _mgp.Graph): - raise TypeError( - "Expected '_mgp.Graph', got '{}'".format(type(graph))) + raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph))) self._graph = Graph(graph) def is_valid(self) -> bool: @@ -1420,21 +1438,76 @@ def transformation(func: typing.Callable[..., Record]): params = tuple(sig.parameters.values()) if not params or not params[0].annotation is Messages: if not len(params) == 2 or not params[1].annotation is Messages: - raise NotImplementedError( - "Valid signatures for transformations are (TransCtx, Messages) or (Messages)") + raise NotImplementedError("Valid signatures for transformations are (TransCtx, Messages) or (Messages)") if params[0].annotation is TransCtx: + @wraps(func) def wrapper(graph, messages): return func(TransCtx(graph), messages) + _mgp._MODULE.add_transformation(wrapper) else: + @wraps(func) def wrapper(graph, messages): return func(messages) + _mgp._MODULE.add_transformation(wrapper) return func +class FuncCtx: + """Context of a function being executed. + + Access to a FuncCtx is only valid during a single execution of a function. + You should not globally store a FuncCtx instance. 
+ """ + + __slots__ = "_graph" + + def __init__(self, graph): + if not isinstance(graph, _mgp.Graph): + raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph))) + self._graph = Graph(graph) + + def is_valid(self) -> bool: + return self._graph.is_valid() + + +def function(func: typing.Callable): + raise_if_does_not_meet_requirements(func) + register_func = _mgp.Module.add_function + sig = inspect.signature(func) + params = tuple(sig.parameters.values()) + if params and params[0].annotation is FuncCtx: + + @wraps(func) + def wrapper(graph, args): + return func(FuncCtx(graph), *args) + + params = params[1:] + mgp_func = register_func(_mgp._MODULE, wrapper) + else: + + @wraps(func) + def wrapper(graph, args): + return func(*args) + + mgp_func = register_func(_mgp._MODULE, wrapper) + + for param in params: + name = param.name + type_ = param.annotation + if type_ is param.empty: + type_ = object + cypher_type = _typing_to_cypher_type(type_) + if param.default is param.empty: + mgp_func.add_arg(name, cypher_type) + else: + mgp_func.add_opt_arg(name, cypher_type, param.default) + return func + + def _wrap_exceptions(): def wrap_function(func): @wraps(func) @@ -1463,6 +1536,7 @@ def _wrap_exceptions(): raise ValueConversionError(e) except _mgp.SerializationError as e: raise SerializationError(e) + return wrapped_func def wrap_prop_func(func): @@ -1473,11 +1547,16 @@ def _wrap_exceptions(): if inspect.isfunction(obj): setattr(cls, name, wrap_function(obj)) elif isinstance(obj, property): - setattr(cls, name, property( - wrap_prop_func(obj.fget), - wrap_prop_func(obj.fset), - wrap_prop_func(obj.fdel), - obj.__doc__)) + setattr( + cls, + name, + property( + wrap_prop_func(obj.fget), + wrap_prop_func(obj.fset), + wrap_prop_func(obj.fdel), + obj.__doc__, + ), + ) def defined_in_this_module(obj: object): return getattr(obj, "__module__", "") == __name__ diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index 6e959fc95..29b15a9c5 100644 
--- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -852,7 +852,9 @@ cpp<# : arguments_(arguments), function_name_(function_name), function_(NameToFunction(function_name_)) { - DMG_ASSERT(function_, "Unexpected missing function: {}", function_name_); + if (!function_) { + throw SemanticException("Function '{}' doesn't exist.", function_name); + } } cpp<#) (:private diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 2544ea663..3eac72d95 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -2109,13 +2109,30 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(MemgraphCypher::Functio storage_->Create(expressions[1], expressions[0], Aggregation::Op::COLLECT_MAP)); } - auto function = NameToFunction(function_name); - if (!function) throw SemanticException("Function '{}' doesn't exist.", function_name); + auto is_user_defined_function = [](const std::string &function_name) { + // Dots are present only in user-defined functions, since modules are case-sensitive, so must be user-defined + // functions. Builtin functions should be case insensitive. + return function_name.find('.') != std::string::npos; + }; + + // Don't cache queries which call user-defined functions. User-defined function's return + // types can vary depending on whether the module is reloaded, therefore the cache would + // be invalid. + if (is_user_defined_function(function_name)) { + query_info_.is_cacheable = false; + } + return static_cast(storage_->Create(function_name, expressions)); } antlrcpp::Any CypherMainVisitor::visitFunctionName(MemgraphCypher::FunctionNameContext *ctx) { - return utils::ToUpperCase(ctx->getText()); + auto function_name = ctx->getText(); + // Dots are present only in user-defined functions, since modules are case-sensitive, so must be user-defined + // functions. Builtin functions should be case insensitive. 
+ if (function_name.find('.') != std::string::npos) { + return function_name; + } + return utils::ToUpperCase(function_name); } antlrcpp::Any CypherMainVisitor::visitDoubleLiteral(MemgraphCypher::DoubleLiteralContext *ctx) { diff --git a/src/query/frontend/opencypher/grammar/Cypher.g4 b/src/query/frontend/opencypher/grammar/Cypher.g4 index b2276389b..6ce84db1a 100644 --- a/src/query/frontend/opencypher/grammar/Cypher.g4 +++ b/src/query/frontend/opencypher/grammar/Cypher.g4 @@ -279,7 +279,7 @@ idInColl : variable IN expression ; functionInvocation : functionName '(' ( DISTINCT )? ( expression ( ',' expression )* )? ')' ; -functionName : symbolicName ; +functionName : symbolicName ( '.' symbolicName )* ; listComprehension : '[' filterExpression ( '|' expression )? ']' ; diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp index 8f09113c9..8d4dcd465 100644 --- a/src/query/interpret/awesome_memgraph_functions.cpp +++ b/src/query/interpret/awesome_memgraph_functions.cpp @@ -22,6 +22,9 @@ #include "query/db_accessor.hpp" #include "query/exceptions.hpp" +#include "query/procedure/cypher_types.hpp" +#include "query/procedure/mg_procedure_impl.hpp" +#include "query/procedure/module.hpp" #include "query/typed_value.hpp" #include "utils/string.hpp" #include "utils/temporal.hpp" @@ -1174,6 +1177,53 @@ TypedValue Duration(const TypedValue *args, int64_t nargs, const FunctionContext MapNumericParameters(parameter_mappings, args[0].ValueMap()); return TypedValue(utils::Duration(duration_parameters), ctx.memory); } + +std::function UserFunction( + const mgp_func &func, const std::string &fully_qualified_name) { + return [func, fully_qualified_name](const TypedValue *args, int64_t nargs, const FunctionContext &ctx) -> TypedValue { + /// Find function is called to acquire the lock on Module pointer while user-defined function is executed + const auto &maybe_found = + procedure::FindFunction(procedure::gModuleRegistry, 
fully_qualified_name, utils::NewDeleteResource()); + if (!maybe_found) { + throw QueryRuntimeException( + "Function '{}' has been unloaded. Please check query modules to confirm that function is loaded in Memgraph.", + fully_qualified_name); + } + /// Explicit extraction of module pointer, to clearly state that the lock is acquired. + // NOLINTNEXTLINE(clang-diagnostic-unused-variable) + const auto &module_ptr = (*maybe_found).first; + + const auto &func_cb = func.cb; + mgp_memory memory{ctx.memory}; + mgp_func_context functx{ctx.db_accessor, ctx.view}; + auto graph = mgp_graph::NonWritableGraph(*ctx.db_accessor, ctx.view); + + std::vector args_list; + args_list.reserve(nargs); + for (std::size_t i = 0; i < nargs; ++i) { + args_list.emplace_back(args[i]); + } + + auto function_argument_list = mgp_list(ctx.memory); + procedure::ConstructArguments(args_list, func, fully_qualified_name, function_argument_list, graph); + + mgp_func_result maybe_res; + func_cb(&function_argument_list, &functx, &maybe_res, &memory); + if (maybe_res.error_msg) { + throw QueryRuntimeException(*maybe_res.error_msg); + } + + if (!maybe_res.value) { + throw QueryRuntimeException( + "Function '{}' didn't set the result nor the error message. Please either set the result by using " + "mgp_func_result_set_value or the error by using mgp_func_result_set_error_msg.", + fully_qualified_name); + } + + return {*(maybe_res.value), ctx.memory}; + }; +} + } // namespace std::function NameToFunction( @@ -1259,6 +1309,14 @@ std::function proc.opt_args.size())) { - if (proc.args.empty() && proc.opt_args.empty()) { - throw QueryRuntimeException("'{}' requires no arguments.", fully_qualified_procedure_name); - } else if (proc.opt_args.empty()) { - throw QueryRuntimeException("'{}' requires exactly {} {}.", fully_qualified_procedure_name, proc.args.size(), - proc.args.size() == 1U ? 
"argument" : "arguments"); - } else { - throw QueryRuntimeException("'{}' requires between {} and {} arguments.", fully_qualified_procedure_name, - proc.args.size(), proc.args.size() + proc.opt_args.size()); - } - } - for (size_t i = 0; i < args.size(); ++i) { - auto arg = args[i]->Accept(*evaluator); - std::string_view name; - const query::procedure::CypherType *type{nullptr}; - if (proc.args.size() > i) { - name = proc.args[i].first; - type = proc.args[i].second; - } else { - MG_ASSERT(proc.opt_args.size() > i - proc.args.size()); - name = std::get<0>(proc.opt_args[i - proc.args.size()]); - type = std::get<1>(proc.opt_args[i - proc.args.size()]); - } - if (!type->SatisfiesType(arg)) { - throw QueryRuntimeException("'{}' argument named '{}' at position {} must be of type {}.", - fully_qualified_procedure_name, name, i, type->GetPresentableName()); - } - proc_args.elems.emplace_back(std::move(arg), &graph); - } - // Fill missing optional arguments with their default values. - MG_ASSERT(args.size() >= proc.args.size()); - size_t passed_in_opt_args = args.size() - proc.args.size(); - MG_ASSERT(passed_in_opt_args <= proc.opt_args.size()); - for (size_t i = passed_in_opt_args; i < proc.opt_args.size(); ++i) { - proc_args.elems.emplace_back(std::get<2>(proc.opt_args[i]), &graph); + std::vector args_list; + args_list.reserve(args.size()); + for (auto *expression : args) { + args_list.emplace_back(expression->Accept(*evaluator)); } + procedure::ConstructArguments(args_list, proc, fully_qualified_procedure_name, proc_args, graph); if (memory_limit) { SPDLOG_INFO("Running '{}' with memory limit of {}", fully_qualified_procedure_name, utils::GetReadableSize(*memory_limit)); @@ -3832,7 +3798,7 @@ class CallProcedureCursor : public Cursor { // generator like procedures which yield a new result on each invocation. 
auto *memory = context.evaluation_context.memory; auto memory_limit = EvaluateMemoryLimit(&evaluator, self_->memory_limit_, self_->memory_scale_); - mgp_graph graph{context.db_accessor, graph_view, &context}; + auto graph = mgp_graph::WritableGraph(*context.db_accessor, graph_view, context); CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit, &result_); diff --git a/src/query/procedure/mg_procedure_impl.cpp b/src/query/procedure/mg_procedure_impl.cpp index 3474f7dbd..f38f4c6c6 100644 --- a/src/query/procedure/mg_procedure_impl.cpp +++ b/src/query/procedure/mg_procedure_impl.cpp @@ -187,7 +187,10 @@ template return MGP_ERROR_NO_ERROR; } -bool MgpGraphIsMutable(const mgp_graph &graph) noexcept { return graph.view == memgraph::storage::View::NEW; } +// Graph mutations +bool MgpGraphIsMutable(const mgp_graph &graph) noexcept { + return graph.view == memgraph::storage::View::NEW && graph.ctx != nullptr; +} bool MgpVertexIsMutable(const mgp_vertex &vertex) { return MgpGraphIsMutable(*vertex.graph); } @@ -289,6 +292,7 @@ mgp_value_type FromTypedValueType(memgraph::query::TypedValue::Type type) { return MGP_VALUE_TYPE_DURATION; } } +} // namespace memgraph::query::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::MemoryResource *memory) { switch (val.type) { @@ -345,8 +349,6 @@ memgraph::query::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils:: } } -} // namespace - mgp_value::mgp_value(memgraph::utils::MemoryResource *m) noexcept : type(MGP_VALUE_TYPE_NULL), memory(m) {} mgp_value::mgp_value(bool val, memgraph::utils::MemoryResource *m) noexcept @@ -1451,6 +1453,14 @@ mgp_error mgp_result_record_insert(mgp_result_record *record, const char *field_ }); } +mgp_error mgp_func_result_set_error_msg(mgp_func_result *res, const char *msg, mgp_memory *memory) { + return WrapExceptions([=] { res->error_msg.emplace(msg, memory->impl); }); +} + +mgp_error mgp_func_result_set_value(mgp_func_result 
*res, mgp_value *value, mgp_memory *memory) { + return WrapExceptions([=] { res->value = ToTypedValue(*value, memory->impl); }); +} + /// Graph Constructs void mgp_properties_iterator_destroy(mgp_properties_iterator *it) { DeleteRawMgpObject(it); } @@ -2382,31 +2392,53 @@ mgp_error mgp_module_add_write_procedure(mgp_module *module, const char *name, m return WrapExceptions([=] { return mgp_module_add_procedure(module, name, cb, {.is_write = true}); }, result); } -mgp_error mgp_proc_add_arg(mgp_proc *proc, const char *name, mgp_type *type) { - return WrapExceptions([=] { - if (!IsValidIdentifierName(name)) { - throw std::invalid_argument{fmt::format("Invalid argument name for procedure '{}': {}", proc->name, name)}; +namespace { +template +concept IsCallable = memgraph::utils::SameAsAnyOf; + +template +mgp_error MgpAddArg(TCall &callable, const std::string &name, mgp_type &type) { + return WrapExceptions([&]() mutable { + static constexpr std::string_view type_name = std::invoke([]() constexpr { + if constexpr (std::is_same_v) { + return "procedure"; + } else if constexpr (std::is_same_v) { + return "function"; + } + }); + + if (!IsValidIdentifierName(name.c_str())) { + throw std::invalid_argument{fmt::format("Invalid argument name for {} '{}': {}", type_name, callable.name, name)}; } - if (!proc->opt_args.empty()) { - throw std::logic_error{fmt::format( - "Cannot add required argument '{}' to procedure '{}' after adding any optional one", name, proc->name)}; + if (!callable.opt_args.empty()) { + throw std::logic_error{fmt::format("Cannot add required argument '{}' to {} '{}' after adding any optional one", + name, type_name, callable.name)}; } - proc->args.emplace_back(name, type->impl.get()); + callable.args.emplace_back(name, type.impl.get()); }); } -mgp_error mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, mgp_type *type, mgp_value *default_value) { - return WrapExceptions([=] { - if (!IsValidIdentifierName(name)) { - throw 
std::invalid_argument{fmt::format("Invalid argument name for procedure '{}': {}", proc->name, name)}; +template +mgp_error MgpAddOptArg(TCall &callable, const std::string name, mgp_type &type, mgp_value &default_value) { + return WrapExceptions([&]() mutable { + static constexpr std::string_view type_name = std::invoke([]() constexpr { + if constexpr (std::is_same_v) { + return "procedure"; + } else if constexpr (std::is_same_v) { + return "function"; + } + }); + + if (!IsValidIdentifierName(name.c_str())) { + throw std::invalid_argument{fmt::format("Invalid argument name for {} '{}': {}", type_name, callable.name, name)}; } - switch (MgpValueGetType(*default_value)) { + switch (MgpValueGetType(default_value)) { case MGP_VALUE_TYPE_VERTEX: case MGP_VALUE_TYPE_EDGE: case MGP_VALUE_TYPE_PATH: // default_value must not be a graph element. - throw ValueConversionException{ - "Default value of argument '{}' of procedure '{}' name must not be a graph element!", name, proc->name}; + throw ValueConversionException{"Default value of argument '{}' of {} '{}' name must not be a graph element!", + name, type_name, callable.name}; case MGP_VALUE_TYPE_NULL: case MGP_VALUE_TYPE_BOOL: case MGP_VALUE_TYPE_INT: @@ -2421,16 +2453,32 @@ mgp_error mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, mgp_type *type, break; } // Default value must be of required `type`. 
- if (!type->impl->SatisfiesType(*default_value)) { - throw std::logic_error{ - fmt::format("The default value of argument '{}' for procedure '{}' doesn't satisfy type '{}'", name, - proc->name, type->impl->GetPresentableName())}; + if (!type.impl->SatisfiesType(default_value)) { + throw std::logic_error{fmt::format("The default value of argument '{}' for {} '{}' doesn't satisfy type '{}'", + name, type_name, callable.name, type.impl->GetPresentableName())}; } - auto *memory = proc->opt_args.get_allocator().GetMemoryResource(); - proc->opt_args.emplace_back(memgraph::utils::pmr::string(name, memory), type->impl.get(), - ToTypedValue(*default_value, memory)); + auto *memory = callable.opt_args.get_allocator().GetMemoryResource(); + callable.opt_args.emplace_back(memgraph::utils::pmr::string(name, memory), type.impl.get(), + ToTypedValue(default_value, memory)); }); } +} // namespace + +mgp_error mgp_proc_add_arg(mgp_proc *proc, const char *name, mgp_type *type) { + return MgpAddArg(*proc, std::string(name), *type); +} + +mgp_error mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, mgp_type *type, mgp_value *default_value) { + return MgpAddOptArg(*proc, std::string(name), *type, *default_value); +} + +mgp_error mgp_func_add_arg(mgp_func *func, const char *name, mgp_type *type) { + return MgpAddArg(*func, std::string(name), *type); +} + +mgp_error mgp_func_add_opt_arg(mgp_func *func, const char *name, mgp_type *type, mgp_value *default_value) { + return MgpAddOptArg(*func, std::string(name), *type, *default_value); +} namespace { @@ -2545,6 +2593,22 @@ void PrintProcSignature(const mgp_proc &proc, std::ostream *stream) { (*stream) << ")"; } +void PrintFuncSignature(const mgp_func &func, std::ostream &stream) { + stream << func.name << "("; + utils::PrintIterable(stream, func.args, ", ", [](auto &stream, const auto &arg) { + stream << arg.first << " :: " << arg.second->GetPresentableName(); + }); + if (!func.args.empty() && !func.opt_args.empty()) { + stream << ", 
"; + } + utils::PrintIterable(stream, func.opt_args, ", ", [](auto &stream, const auto &arg) { + const auto &[name, type, default_val] = arg; + stream << name << " = "; + PrintValue(default_val, &stream) << " :: " << type->GetPresentableName(); + }); + stream << ")"; +} + bool IsValidIdentifierName(const char *name) { if (!name) return false; std::regex regex("[_[:alpha:]][_[:alnum:]]*"); @@ -2716,3 +2780,19 @@ mgp_error mgp_module_add_transformation(mgp_module *module, const char *name, mg module->transformations.emplace(name, mgp_trans(name, cb, memory)); }); } + +mgp_error mgp_module_add_function(mgp_module *module, const char *name, mgp_func_cb cb, mgp_func **result) { + return WrapExceptions( + [=] { + if (!IsValidIdentifierName(name)) { + throw std::invalid_argument{fmt::format("Invalid function name: {}", name)}; + } + if (module->functions.find(name) != module->functions.end()) { + throw std::logic_error{fmt::format("Function with similar name already exists '{}'", name)}; + }; + auto *memory = module->functions.get_allocator().GetMemoryResource(); + + return &module->functions.emplace(name, mgp_func(name, cb, memory)).first->second; + }, + result); +} diff --git a/src/query/procedure/mg_procedure_impl.hpp b/src/query/procedure/mg_procedure_impl.hpp index d28ca699d..fd0292fe2 100644 --- a/src/query/procedure/mg_procedure_impl.hpp +++ b/src/query/procedure/mg_procedure_impl.hpp @@ -562,14 +562,36 @@ struct mgp_result { std::optional error_msg; }; +struct mgp_func_result { + mgp_func_result() {} + /// Return Magic function result. If user forgets it, the error is raised + std::optional value; + /// Return Magic function result with potential error + std::optional error_msg; +}; + struct mgp_graph { memgraph::query::DbAccessor *impl; memgraph::storage::View view; // TODO: Merge `mgp_graph` and `mgp_memory` into a single `mgp_context`. The // `ctx` field is out of place here. 
memgraph::query::ExecutionContext *ctx; + + static mgp_graph WritableGraph(memgraph::query::DbAccessor &acc, memgraph::storage::View view, + memgraph::query::ExecutionContext &ctx) { + return mgp_graph{&acc, view, &ctx}; + } + + static mgp_graph NonWritableGraph(memgraph::query::DbAccessor &acc, memgraph::storage::View view) { + return mgp_graph{&acc, view, nullptr}; + } }; +// Prevents the user from using ExecutionContext in writable callables +struct mgp_func_context { + memgraph::query::DbAccessor *impl; + memgraph::storage::View view; +}; struct mgp_properties_iterator { using allocator_type = memgraph::utils::Allocator; @@ -779,18 +801,69 @@ struct mgp_trans { results; }; +struct mgp_func { + using allocator_type = memgraph::utils::Allocator; + + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_func(const char *name, mgp_func_cb cb, memgraph::utils::MemoryResource *memory) + : name(name, memory), cb(cb), args(memory), opt_args(memory) {} + + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_func(const char *name, std::function cb, + memgraph::utils::MemoryResource *memory) + : name(name, memory), cb(cb), args(memory), opt_args(memory) {} + + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_func(const mgp_func &other, memgraph::utils::MemoryResource *memory) + : name(other.name, memory), cb(other.cb), args(other.args, memory), opt_args(other.opt_args, memory) {} + + mgp_func(mgp_func &&other, memgraph::utils::MemoryResource *memory) + : name(std::move(other.name), memory), + cb(std::move(other.cb)), + args(std::move(other.args), memory), + opt_args(std::move(other.opt_args), memory) {} + + mgp_func(const mgp_func &other) = default; + mgp_func(mgp_func &&other) = default; + + mgp_func &operator=(const mgp_func &) = delete; + mgp_func &operator=(mgp_func &&) = delete; + + ~mgp_func() = default; + + /// Name of the function. + memgraph::utils::pmr::string name; + /// Entry-point for the function. 
+ std::function cb; + /// Required, positional arguments as a (name, type) pair. + memgraph::utils::pmr::vector> + args; + /// Optional positional arguments as a (name, type, default_value) tuple. + memgraph::utils::pmr::vector> + opt_args; +}; + mgp_error MgpTransAddFixedResult(mgp_trans *trans) noexcept; struct mgp_module { using allocator_type = memgraph::utils::Allocator; - explicit mgp_module(memgraph::utils::MemoryResource *memory) : procedures(memory), transformations(memory) {} + explicit mgp_module(memgraph::utils::MemoryResource *memory) + : procedures(memory), transformations(memory), functions(memory) {} mgp_module(const mgp_module &other, memgraph::utils::MemoryResource *memory) - : procedures(other.procedures, memory), transformations(other.transformations, memory) {} + : procedures(other.procedures, memory), + transformations(other.transformations, memory), + functions(other.functions, memory) {} mgp_module(mgp_module &&other, memgraph::utils::MemoryResource *memory) - : procedures(std::move(other.procedures), memory), transformations(std::move(other.transformations), memory) {} + : procedures(std::move(other.procedures), memory), + transformations(std::move(other.transformations), memory), + functions(std::move(other.functions), memory) {} mgp_module(const mgp_module &) = default; mgp_module(mgp_module &&) = default; @@ -802,6 +875,7 @@ struct mgp_module { memgraph::utils::pmr::map procedures; memgraph::utils::pmr::map transformations; + memgraph::utils::pmr::map functions; }; namespace memgraph::query::procedure { @@ -811,6 +885,11 @@ namespace memgraph::query::procedure { /// @throw anything std::ostream::operator<< may throw. void PrintProcSignature(const mgp_proc &, std::ostream *); +/// @throw std::bad_alloc +/// @throw std::length_error +/// @throw anything std::ostream::operator<< may throw. 
+void PrintFuncSignature(const mgp_func &, std::ostream &); + bool IsValidIdentifierName(const char *name); } // namespace memgraph::query::procedure @@ -839,3 +918,5 @@ struct mgp_messages { storage_type messages; }; + +memgraph::query::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::MemoryResource *memory); diff --git a/src/query/procedure/module.cpp b/src/query/procedure/module.cpp index b4fff79cb..357dfba31 100644 --- a/src/query/procedure/module.cpp +++ b/src/query/procedure/module.cpp @@ -52,6 +52,8 @@ class BuiltinModule final : public Module { const std::map> *Transformations() const override; + const std::map> *Functions() const override; + void AddProcedure(std::string_view name, mgp_proc proc); void AddTransformation(std::string_view name, mgp_trans trans); @@ -62,6 +64,7 @@ class BuiltinModule final : public Module { /// Registered procedures std::map> procedures_; std::map> transformations_; + std::map> functions_; }; BuiltinModule::BuiltinModule() {} @@ -75,6 +78,7 @@ const std::map> *BuiltinModule::Procedures() const std::map> *BuiltinModule::Transformations() const { return &transformations_; } +const std::map> *BuiltinModule::Functions() const { return &functions_; } void BuiltinModule::AddProcedure(std::string_view name, mgp_proc proc) { procedures_.emplace(name, std::move(proc)); } @@ -300,6 +304,82 @@ void RegisterMgTransformations(const std::mapAddProcedure("transformations", std::move(procedures)); } +void RegisterMgFunctions( + // We expect modules to be sorted by name. + const std::map, std::less<>> *all_modules, BuiltinModule *module) { + auto functions_cb = [all_modules](mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result *result, + mgp_memory *memory) { + // Iterating over all_modules assumes that the standard mechanism of magic + // functions invocations takes the ModuleRegistry::lock_ with READ access. 
+ for (const auto &[module_name, module] : *all_modules) { + // Return the results in sorted order by module and by function_name. + static_assert(std::is_same_vFunctions()), const std::map> *>, + "Expected module magic functions to be sorted by name"); + + const auto path = module->Path(); + const auto path_string = GetPathString(path); + const auto is_editable = IsFileEditable(path); + + for (const auto &[func_name, func] : *module->Functions()) { + mgp_result_record *record{nullptr}; + + if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; + } + + const auto path_value = GetStringValueOrSetError(path_string.c_str(), memory, result); + if (!path_value) { + return; + } + + MgpUniquePtr is_editable_value{nullptr, mgp_value_destroy}; + if (!TryOrSetError([&] { return CreateMgpObject(is_editable_value, mgp_value_make_bool, is_editable, memory); }, + result)) { + return; + } + + utils::pmr::string full_name(module_name, memory->impl); + full_name.append(1, '.'); + full_name.append(func_name); + const auto name_value = GetStringValueOrSetError(full_name.c_str(), memory, result); + if (!name_value) { + return; + } + + std::stringstream ss; + ss << module_name << "."; + PrintFuncSignature(func, ss); + const auto signature = ss.str(); + const auto signature_value = GetStringValueOrSetError(signature.c_str(), memory, result); + if (!signature_value) { + return; + } + + if (!InsertResultOrSetError(result, record, "name", name_value.get())) { + return; + } + + if (!InsertResultOrSetError(result, record, "signature", signature_value.get())) { + return; + } + + if (!InsertResultOrSetError(result, record, "path", path_value.get())) { + return; + } + + if (!InsertResultOrSetError(result, record, "is_editable", is_editable_value.get())) { + return; + } + } + } + }; + mgp_proc functions("functions", functions_cb, utils::NewDeleteResource()); + MG_ASSERT(mgp_proc_add_result(&functions, "name", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); 
+ MG_ASSERT(mgp_proc_add_result(&functions, "signature", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&functions, "path", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&functions, "is_editable", Call(mgp_type_bool)) == MGP_ERROR_NO_ERROR); + module->AddProcedure("functions", std::move(functions)); +} namespace { bool IsAllowedExtension(const auto &extension) { static constexpr std::array allowed_extensions{".py"}; @@ -650,8 +730,8 @@ void RegisterMgDeleteModuleFile(ModuleRegistry *module_registry, utils::RWLock * // `mgp_module::transformations into `proc_map`. The return value of WithModuleRegistration // is the same as that of `fun`. Note, the return value need only be convertible to `bool`, // it does not have to be `bool` itself. -template -auto WithModuleRegistration(TProcMap *proc_map, TTransMap *trans_map, const TFun &fun) { +template +auto WithModuleRegistration(TProcMap *proc_map, TTransMap *trans_map, TFuncMap *func_map, const TFun &fun) { // We probably don't need more than 256KB for module initialization. static constexpr size_t stack_bytes = 256UL * 1024UL; unsigned char stack_memory[stack_bytes]; @@ -664,6 +744,8 @@ auto WithModuleRegistration(TProcMap *proc_map, TTransMap *trans_map, const TFun for (const auto &proc : module_def.procedures) proc_map->emplace(proc); // Copy transformations into resulting trans_map. for (const auto &trans : module_def.transformations) trans_map->emplace(trans); + // Copy functions into resulting func_map. 
+ for (const auto &func : module_def.functions) func_map->emplace(func); } return res; } @@ -687,6 +769,8 @@ class SharedLibraryModule final : public Module { const std::map> *Transformations() const override; + const std::map> *Functions() const override; + std::optional Path() const override { return file_path_; } private: @@ -702,6 +786,8 @@ class SharedLibraryModule final : public Module { std::map> procedures_; /// Registered transformations std::map> transformations_; + /// Registered functions + std::map> functions_; }; SharedLibraryModule::SharedLibraryModule() : handle_(nullptr) {} @@ -755,7 +841,7 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) { } return true; }; - if (!WithModuleRegistration(&procedures_, &transformations_, module_cb)) { + if (!WithModuleRegistration(&procedures_, &transformations_, &functions_, module_cb)) { return false; } // Get optional mgp_shutdown_module @@ -801,6 +887,13 @@ const std::map> *SharedLibraryModule::Transf return &transformations_; } +const std::map> *SharedLibraryModule::Functions() const { + MG_ASSERT(handle_, + "Attempting to access functions of a module that has not " + "been loaded..."); + return &functions_; +} + class PythonModule final : public Module { public: PythonModule(); @@ -816,6 +909,7 @@ class PythonModule final : public Module { const std::map> *Procedures() const override; const std::map> *Transformations() const override; + const std::map> *Functions() const override; std::optional Path() const override { return file_path_; } private: @@ -823,6 +917,7 @@ class PythonModule final : public Module { py::Object py_module_; std::map> procedures_; std::map> transformations_; + std::map> functions_; }; PythonModule::PythonModule() {} @@ -853,7 +948,7 @@ bool PythonModule::Load(const std::filesystem::path &file_path) { }; return result; }; - py_module_ = WithModuleRegistration(&procedures_, &transformations_, module_cb); + py_module_ = WithModuleRegistration(&procedures_, 
&transformations_, &functions_, module_cb); if (py_module_) { spdlog::info("Loaded module {}", file_path); @@ -877,6 +972,7 @@ bool PythonModule::Close() { auto gil = py::EnsureGIL(); procedures_.clear(); transformations_.clear(); + functions_.clear(); // Delete the module from the `sys.modules` directory so that the module will // be properly imported if imported again. py::Object sys(PyImport_ImportModule("sys")); @@ -906,6 +1002,13 @@ const std::map> *PythonModule::Transformatio "not been loaded..."); return &transformations_; } + +const std::map> *PythonModule::Functions() const { + MG_ASSERT(py_module_, + "Attempting to access functions of a module that has " + "not been loaded..."); + return &functions_; +} namespace { std::unique_ptr LoadModuleFromFile(const std::filesystem::path &path) { @@ -954,6 +1057,7 @@ ModuleRegistry::ModuleRegistry() { auto module = std::make_unique(); RegisterMgProcedures(&modules_, module.get()); RegisterMgTransformations(&modules_, module.get()); + RegisterMgFunctions(&modules_, module.get()); RegisterMgLoad(this, &lock_, module.get()); RegisterMgGetModuleFiles(this, module.get()); RegisterMgGetModuleFile(this, module.get()); @@ -1083,7 +1187,7 @@ std::optional> FindModuleNameAndPr } template -concept ModuleProperties = utils::SameAsAnyOf; +concept ModuleProperties = utils::SameAsAnyOf; template std::optional> MakePairIfPropFound(const ModuleRegistry &module_registry, @@ -1092,8 +1196,10 @@ std::optional> MakePairIfPropFound(const ModuleR auto prop_fun = [](auto &module) { if constexpr (std::is_same_v) { return module->Procedures(); - } else { + } else if constexpr (std::is_same_v) { return module->Transformations(); + } else if constexpr (std::is_same_v) { + return module->Functions(); } }; auto result = FindModuleNameAndProp(module_registry, fully_qualified_name, memory); @@ -1121,4 +1227,10 @@ std::optional> FindTransformation( return MakePairIfPropFound(module_registry, fully_qualified_transformation_name, memory); } 
+std::optional> FindFunction(const ModuleRegistry &module_registry, + std::string_view fully_qualified_function_name, + utils::MemoryResource *memory) { + return MakePairIfPropFound(module_registry, fully_qualified_function_name, memory); +} + } // namespace memgraph::query::procedure diff --git a/src/query/procedure/module.hpp b/src/query/procedure/module.hpp index 4d52f915d..c3403b636 100644 --- a/src/query/procedure/module.hpp +++ b/src/query/procedure/module.hpp @@ -21,6 +21,7 @@ #include #include +#include "query/procedure/cypher_types.hpp" #include "query/procedure/mg_procedure_impl.hpp" #include "utils/memory.hpp" #include "utils/rw_lock.hpp" @@ -45,6 +46,8 @@ class Module { virtual const std::map> *Procedures() const = 0; /// Returns registered transformations of this module virtual const std::map> *Transformations() const = 0; + /// Returns registered functions of this module + virtual const std::map> *Functions() const = 0; virtual std::optional Path() const = 0; }; @@ -147,4 +150,62 @@ std::optional> FindProcedure( std::optional> FindTransformation( const ModuleRegistry &module_registry, const std::string_view fully_qualified_transformation_name, utils::MemoryResource *memory); + +/// Return the ModulePtr and `mgp_func *` of the found function after resolving +/// `fully_qualified_function_name` if found. If there is no such function +/// std::nullopt is returned. `memory` is used for temporary allocations +/// inside this function. ModulePtr must be kept alive to make sure it won't be unloaded. 
+std::optional> FindFunction( + const ModuleRegistry &module_registry, const std::string_view fully_qualified_function_name, + utils::MemoryResource *memory); + +template +concept IsCallable = utils::SameAsAnyOf; + +template +void ConstructArguments(const std::vector &args, const TCall &callable, + const std::string_view fully_qualified_name, mgp_list &args_list, mgp_graph &graph) { + const auto n_args = args.size(); + const auto c_args_sz = callable.args.size(); + const auto c_opt_args_sz = callable.opt_args.size(); + + if (n_args < c_args_sz || (n_args - c_args_sz > c_opt_args_sz)) { + if (callable.args.empty() && callable.opt_args.empty()) { + throw QueryRuntimeException("'{}' requires no arguments.", fully_qualified_name); + } + + if (callable.opt_args.empty()) { + throw QueryRuntimeException("'{}' requires exactly {} {}.", fully_qualified_name, c_args_sz, + c_args_sz == 1U ? "argument" : "arguments"); + } + + throw QueryRuntimeException("'{}' requires between {} and {} arguments.", fully_qualified_name, c_args_sz, + c_args_sz + c_opt_args_sz); + } + args_list.elems.reserve(n_args); + + auto is_not_optional_arg = [c_args_sz](int i) { return c_args_sz > i; }; + for (size_t i = 0; i < n_args; ++i) { + auto arg = args[i]; + std::string_view name; + const query::procedure::CypherType *type; + if (is_not_optional_arg(i)) { + name = callable.args[i].first; + type = callable.args[i].second; + } else { + name = std::get<0>(callable.opt_args[i - c_args_sz]); + type = std::get<1>(callable.opt_args[i - c_args_sz]); + } + if (!type->SatisfiesType(arg)) { + throw QueryRuntimeException("'{}' argument named '{}' at position {} must be of type {}.", fully_qualified_name, + name, i, type->GetPresentableName()); + } + args_list.elems.emplace_back(std::move(arg), &graph); + } + // Fill missing optional arguments with their default values. 
+ const size_t passed_in_opt_args = n_args - c_args_sz; + for (size_t i = passed_in_opt_args; i < c_opt_args_sz; ++i) { + args_list.elems.emplace_back(std::get<2>(callable.opt_args[i]), &graph); + } +} } // namespace memgraph::query::procedure diff --git a/src/query/procedure/py_module.cpp b/src/query/procedure/py_module.cpp index e00e097fc..6aa74d0ab 100644 --- a/src/query/procedure/py_module.cpp +++ b/src/query/procedure/py_module.cpp @@ -447,62 +447,94 @@ PyObject *MakePyCypherType(mgp_type *type) { // clang-format off struct PyQueryProc { PyObject_HEAD - mgp_proc *proc; + mgp_proc *callable; }; // clang-format on -PyObject *PyQueryProcAddArg(PyQueryProc *self, PyObject *args) { - MG_ASSERT(self->proc); +// clang-format off +struct PyMagicFunc{ + PyObject_HEAD + mgp_func *callable; +}; +// clang-format on + +template +concept IsCallable = utils::SameAsAnyOf; + +template +PyObject *PyCallableAddArg(TCall *self, PyObject *args) { + MG_ASSERT(self->callable); const char *name = nullptr; PyCypherType *py_type = nullptr; if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr; auto *type = py_type->type; - if (RaiseExceptionFromErrorCode(mgp_proc_add_arg(self->proc, name, type))) { - return nullptr; + + if constexpr (std::is_same_v) { + if (RaiseExceptionFromErrorCode(mgp_proc_add_arg(self->callable, name, type))) { + return nullptr; + } + } else if constexpr (std::is_same_v) { + if (RaiseExceptionFromErrorCode(mgp_func_add_arg(self->callable, name, type))) { + return nullptr; + } } + Py_RETURN_NONE; } -PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) { - MG_ASSERT(self->proc); +template +PyObject *PyCallableAddOptArg(TCall *self, PyObject *args) { + MG_ASSERT(self->callable); const char *name = nullptr; PyCypherType *py_type = nullptr; PyObject *py_value = nullptr; if (!PyArg_ParseTuple(args, "sO!O", &name, &PyCypherTypeType, &py_type, &py_value)) return nullptr; auto *type = py_type->type; - mgp_memory 
memory{self->proc->opt_args.get_allocator().GetMemoryResource()}; + mgp_memory memory{self->callable->opt_args.get_allocator().GetMemoryResource()}; mgp_value *value = PyObjectToMgpValueWithPythonExceptions(py_value, &memory); if (value == nullptr) { return nullptr; } - if (RaiseExceptionFromErrorCode(mgp_proc_add_opt_arg(self->proc, name, type, value))) { - mgp_value_destroy(value); - return nullptr; + if constexpr (std::is_same_v) { + if (RaiseExceptionFromErrorCode(mgp_proc_add_opt_arg(self->callable, name, type, value))) { + mgp_value_destroy(value); + return nullptr; + } + } else if constexpr (std::is_same_v) { + if (RaiseExceptionFromErrorCode(mgp_func_add_opt_arg(self->callable, name, type, value))) { + mgp_value_destroy(value); + return nullptr; + } } + mgp_value_destroy(value); Py_RETURN_NONE; } +PyObject *PyQueryProcAddArg(PyQueryProc *self, PyObject *args) { return PyCallableAddArg(self, args); } + +PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) { return PyCallableAddOptArg(self, args); } + PyObject *PyQueryProcAddResult(PyQueryProc *self, PyObject *args) { - MG_ASSERT(self->proc); + MG_ASSERT(self->callable); const char *name = nullptr; PyCypherType *py_type = nullptr; if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr; auto *type = reinterpret_cast(py_type)->type; - if (RaiseExceptionFromErrorCode(mgp_proc_add_result(self->proc, name, type))) { + if (RaiseExceptionFromErrorCode(mgp_proc_add_result(self->callable, name, type))) { return nullptr; } Py_RETURN_NONE; } PyObject *PyQueryProcAddDeprecatedResult(PyQueryProc *self, PyObject *args) { - MG_ASSERT(self->proc); + MG_ASSERT(self->callable); const char *name = nullptr; PyCypherType *py_type = nullptr; if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr; auto *type = reinterpret_cast(py_type)->type; - if (RaiseExceptionFromErrorCode(mgp_proc_add_deprecated_result(self->proc, name, type))) { + if 
(RaiseExceptionFromErrorCode(mgp_proc_add_deprecated_result(self->callable, name, type))) { return nullptr; } Py_RETURN_NONE; @@ -532,6 +564,33 @@ static PyTypeObject PyQueryProcType = { }; // clang-format on +PyObject *PyMagicFuncAddArg(PyMagicFunc *self, PyObject *args) { return PyCallableAddArg(self, args); } + +PyObject *PyMagicFuncAddOptArg(PyMagicFunc *self, PyObject *args) { return PyCallableAddOptArg(self, args); } + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static PyMethodDef PyMagicFuncMethods[] = { + {"__reduce__", reinterpret_cast(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, + {"add_arg", reinterpret_cast(PyMagicFuncAddArg), METH_VARARGS, + "Add a required argument to a function."}, + {"add_opt_arg", reinterpret_cast(PyMagicFuncAddOptArg), METH_VARARGS, + "Add an optional argument with a default value to a function."}, + {nullptr}, +}; + +// clang-format off +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static PyTypeObject PyMagicFuncType = { + PyVarObject_HEAD_INIT(nullptr, 0) + .tp_name = "_mgp.Func", + .tp_basicsize = sizeof(PyMagicFunc), + // NOLINTNEXTLINE(hicpp-signed-bitwise) + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_func.", + .tp_methods = PyMagicFuncMethods, +}; +// clang-format on + // clang-format off struct PyQueryModule { PyObject_HEAD @@ -796,7 +855,6 @@ py::Object MgpListToPyTuple(mgp_list *list, PyObject *py_graph) { } namespace { - std::optional AddRecordFromPython(mgp_result *result, py::Object py_record) { py::Object py_mgp(PyImport_ImportModule("mgp")); if (!py_mgp) return py::FetchError(); @@ -870,6 +928,33 @@ std::optional AddMultipleRecordsFromPython(mgp_result *result return std::nullopt; } +std::function PyObjectCleanup(py::Object &py_object) { + return [py_object]() { + // Run `gc.collect` (reference cycle-detection) explicitly, so that we are + // sure the procedure cleaned up everything it held references to. 
If the + // user stored a reference to one of our `_mgp` instances then the + // internally used `mgp_*` structs will stay unfreed and a memory leak + // will be reported at the end of the query execution. + py::Object gc(PyImport_ImportModule("gc")); + if (!gc) { + LOG_FATAL(py::FetchError().value()); + } + + if (!gc.CallMethod("collect")) { + LOG_FATAL(py::FetchError().value()); + } + + // After making sure all references from our side have been cleared, + // invalidate the `_mgp.Graph` object. If the user kept a reference to one + // of our `_mgp` instances then this will prevent them from using those + // objects (whose internal `mgp_*` pointers are now invalid and would cause + // a crash). + if (!py_object.CallMethod("invalidate")) { + LOG_FATAL(py::FetchError().value()); + } + }; +} + void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *graph, mgp_result *result, mgp_memory *memory) { auto gil = py::EnsureGIL(); @@ -895,31 +980,6 @@ void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *gra } }; - auto cleanup = [](py::Object py_graph) { - // Run `gc.collect` (reference cycle-detection) explicitly, so that we are - // sure the procedure cleaned up everything it held references to. If the - // user stored a reference to one of our `_mgp` instances then the - // internally used `mgp_*` structs will stay unfreed and a memory leak - // will be reported at the end of the query execution. - py::Object gc(PyImport_ImportModule("gc")); - if (!gc) { - LOG_FATAL(py::FetchError().value()); - } - - if (!gc.CallMethod("collect")) { - LOG_FATAL(py::FetchError().value()); - } - - // After making sure all references from our side have been cleared, - // invalidate the `_mgp.Graph` object. If the user kept a reference to one - // of our `_mgp` instances then this will prevent them from using those - // objects (whose internal `mgp_*` pointers are now invalid and would cause - // a crash). 
- if (!py_graph.CallMethod("invalidate")) { - LOG_FATAL(py::FetchError().value()); - } - }; - // It is *VERY IMPORTANT* to note that this code takes great care not to keep // any extra references to any `_mgp` instances (except for `_mgp.Graph`), so // as not to introduce extra reference counts and prevent their deallocation. @@ -932,14 +992,9 @@ void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *gra std::optional maybe_msg; { py::Object py_graph(MakePyGraph(graph, memory)); + utils::OnScopeExit clean_up(PyObjectCleanup(py_graph)); if (py_graph) { - try { - maybe_msg = error_to_msg(call(py_graph)); - cleanup(py_graph); - } catch (...) { - cleanup(py_graph); - throw; - } + maybe_msg = error_to_msg(call(py_graph)); } else { maybe_msg = error_to_msg(py::FetchError()); } @@ -972,32 +1027,58 @@ void CallPythonTransformation(const py::Object &py_cb, mgp_messages *msgs, mgp_g return AddRecordFromPython(result, py_res); }; - auto cleanup = [](py::Object py_graph, py::Object py_messages) { - // Run `gc.collect` (reference cycle-detection) explicitly, so that we are - // sure the procedure cleaned up everything it held references to. If the - // user stored a reference to one of our `_mgp` instances then the - // internally used `mgp_*` structs will stay unfreed and a memory leak - // will be reported at the end of the query execution. - py::Object gc(PyImport_ImportModule("gc")); - if (!gc) { - LOG_FATAL(py::FetchError().value()); - } + // It is *VERY IMPORTANT* to note that this code takes great care not to keep + // any extra references to any `_mgp` instances (except for `_mgp.Graph`), so + // as not to introduce extra reference counts and prevent their deallocation. + // In particular, the `ExceptionInfo` object has a `traceback` field that + // contains references to the Python frames and their arguments, and therefore + // our `_mgp` instances as well. 
Within this code we ensure not to keep the + // `ExceptionInfo` object alive so that no extra reference counts are + // introduced. We only fetch the error message and immediately destroy the + // object. + std::optional maybe_msg; + { + py::Object py_graph(MakePyGraph(graph, memory)); + py::Object py_messages(MakePyMessages(msgs, memory)); - if (!gc.CallMethod("collect")) { - LOG_FATAL(py::FetchError().value()); - } + utils::OnScopeExit clean_up_graph(PyObjectCleanup(py_graph)); + utils::OnScopeExit clean_up_messages(PyObjectCleanup(py_messages)); - // After making sure all references from our side have been cleared, - // invalidate the `_mgp.Graph` object. If the user kept a reference to one - // of our `_mgp` instances then this will prevent them from using those - // objects (whose internal `mgp_*` pointers are now invalid and would cause - // a crash). - if (!py_graph.CallMethod("invalidate")) { - LOG_FATAL(py::FetchError().value()); + if (py_graph && py_messages) { + maybe_msg = error_to_msg(call(py_graph, py_messages)); + } else { + maybe_msg = error_to_msg(py::FetchError()); } - if (!py_messages.CallMethod("invalidate")) { - LOG_FATAL(py::FetchError().value()); + } + + if (maybe_msg) { + static_cast(mgp_result_set_error_msg(result, maybe_msg->c_str())); + } +} + +void CallPythonFunction(const py::Object &py_cb, mgp_list *args, mgp_graph *graph, mgp_func_result *result, + mgp_memory *memory) { + auto gil = py::EnsureGIL(); + + auto error_to_msg = [](const std::optional &exc_info) -> std::optional { + if (!exc_info) return std::nullopt; + // Here we tell the traceback formatter to skip the first line of the + // traceback because that line will always be our wrapper function in our + // internal `mgp.py` file. With that line skipped, the user will always + // get only the relevant traceback that happened in his Python code. 
+ return py::FormatException(*exc_info, /* skip_first_line = */ true); + }; + + auto call = [&](py::Object py_graph) -> utils::BasicResult, mgp_value *> { + py::Object py_args(MgpListToPyTuple(args, py_graph.Ptr())); + if (!py_args) return {py::FetchError()}; + auto py_res = py_cb.Call(py_graph, py_args); + if (!py_res) return {py::FetchError()}; + mgp_value *ret_val = PyObjectToMgpValueWithPythonExceptions(py_res.Ptr(), memory); + if (ret_val == nullptr) { + return {py::FetchError()}; } + return ret_val; }; // It is *VERY IMPORTANT* to note that this code takes great care not to keep @@ -1012,22 +1093,22 @@ void CallPythonTransformation(const py::Object &py_cb, mgp_messages *msgs, mgp_g std::optional maybe_msg; { py::Object py_graph(MakePyGraph(graph, memory)); - py::Object py_messages(MakePyMessages(msgs, memory)); - if (py_graph && py_messages) { - try { - maybe_msg = error_to_msg(call(py_graph, py_messages)); - cleanup(py_graph, py_messages); - } catch (...) { - cleanup(py_graph, py_messages); - throw; + utils::OnScopeExit clean_up(PyObjectCleanup(py_graph)); + if (py_graph) { + auto maybe_result = call(py_graph); + if (!maybe_result.HasError()) { + static_cast(mgp_func_result_set_value(result, maybe_result.GetValue(), memory)); + return; } + maybe_msg = error_to_msg(maybe_result.GetError()); } else { maybe_msg = error_to_msg(py::FetchError()); } } if (maybe_msg) { - static_cast(mgp_result_set_error_msg(result, maybe_msg->c_str())); + static_cast( + mgp_func_result_set_error_msg(result, maybe_msg->c_str(), memory)); // No error fetching if this fails } } @@ -1056,9 +1137,9 @@ PyObject *PyQueryModuleAddProcedure(PyQueryModule *self, PyObject *cb, bool is_w PyErr_SetString(PyExc_ValueError, "Already registered a procedure with the same name."); return nullptr; } - auto *py_proc = PyObject_New(PyQueryProc, &PyQueryProcType); + auto *py_proc = PyObject_New(PyQueryProc, &PyQueryProcType); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) if (!py_proc) return 
nullptr; - py_proc->proc = &proc_it->second; + py_proc->callable = &proc_it->second; return reinterpret_cast(py_proc); } } // namespace @@ -1100,6 +1181,39 @@ PyObject *PyQueryModuleAddTransformation(PyQueryModule *self, PyObject *cb) { Py_RETURN_NONE; } +PyObject *PyQueryModuleAddFunction(PyQueryModule *self, PyObject *cb) { + MG_ASSERT(self->module); + if (!PyCallable_Check(cb)) { + PyErr_SetString(PyExc_TypeError, "Expected a callable object."); + return nullptr; + } + auto py_cb = py::Object::FromBorrow(cb); + py::Object py_name(py_cb.GetAttr("__name__")); + const auto *name = PyUnicode_AsUTF8(py_name.Ptr()); + if (!name) return nullptr; + if (!IsValidIdentifierName(name)) { + PyErr_SetString(PyExc_ValueError, "Function name is not a valid identifier"); + return nullptr; + } + auto *memory = self->module->functions.get_allocator().GetMemoryResource(); + mgp_func func( + name, + [py_cb](mgp_list *args, mgp_func_context *func_ctx, mgp_func_result *result, mgp_memory *memory) { + auto graph = mgp_graph::NonWritableGraph(*(func_ctx->impl), func_ctx->view); + return CallPythonFunction(py_cb, args, &graph, result, memory); + }, + memory); + const auto [func_it, did_insert] = self->module->functions.emplace(name, std::move(func)); + if (!did_insert) { + PyErr_SetString(PyExc_ValueError, "Already registered a function with the same name."); + return nullptr; + } + auto *py_func = PyObject_New(PyMagicFunc, &PyMagicFuncType); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) + if (!py_func) return nullptr; + py_func->callable = &func_it->second; + return reinterpret_cast(py_func); +} + static PyMethodDef PyQueryModuleMethods[] = { {"__reduce__", reinterpret_cast(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, {"add_read_procedure", reinterpret_cast(PyQueryModuleAddReadProcedure), METH_O, @@ -1108,6 +1222,8 @@ static PyMethodDef PyQueryModuleMethods[] = { "Register a writeable procedure with this module."}, {"add_transformation", 
reinterpret_cast(PyQueryModuleAddTransformation), METH_O, "Register a transformation with this module."}, + {"add_function", reinterpret_cast(PyQueryModuleAddFunction), METH_O, + "Register a function with this module."}, {nullptr}, }; @@ -1980,6 +2096,7 @@ PyObject *PyInitMgpModule() { if (!register_type(&PyGraphType, "Graph")) return nullptr; if (!register_type(&PyEdgeType, "Edge")) return nullptr; if (!register_type(&PyQueryProcType, "Proc")) return nullptr; + if (!register_type(&PyMagicFuncType, "Func")) return nullptr; if (!register_type(&PyQueryModuleType, "Module")) return nullptr; if (!register_type(&PyVertexType, "Vertex")) return nullptr; if (!register_type(&PyPathType, "Path")) return nullptr; diff --git a/tests/e2e/CMakeLists.txt b/tests/e2e/CMakeLists.txt index 12d9969e1..4a53a0836 100644 --- a/tests/e2e/CMakeLists.txt +++ b/tests/e2e/CMakeLists.txt @@ -6,6 +6,14 @@ add_custom_target(memgraph__e2e__${TARGET_PREFIX}__${FILE_NAME} ALL DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME}) endfunction() +function(copy_e2e_cpp_files TARGET_PREFIX FILE_NAME) +add_custom_target(memgraph__e2e__${TARGET_PREFIX}__${FILE_NAME} ALL + COMMAND ${CMAKE_COMMAND} -E copy + ${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME} + ${CMAKE_CURRENT_BINARY_DIR}/${FILE_NAME} + DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME}) +endfunction() + add_subdirectory(replication) add_subdirectory(memory) add_subdirectory(triggers) @@ -13,6 +21,7 @@ add_subdirectory(isolation_levels) add_subdirectory(streams) add_subdirectory(temporal_types) add_subdirectory(write_procedures) +add_subdirectory(magic_functions) add_subdirectory(module_file_manager) add_subdirectory(websocket) diff --git a/tests/e2e/magic_functions/CMakeLists.txt b/tests/e2e/magic_functions/CMakeLists.txt new file mode 100644 index 000000000..fb4986724 --- /dev/null +++ b/tests/e2e/magic_functions/CMakeLists.txt @@ -0,0 +1,17 @@ +# Set up C++ functions for e2e tests +function(add_query_module target_name src) + add_library(${target_name} 
SHARED ${src}) + SET_TARGET_PROPERTIES(${target_name} PROPERTIES PREFIX "") + target_include_directories(${target_name} PRIVATE ${CMAKE_SOURCE_DIR}/include) +endfunction() + +# Set up Python functions for e2e tests +function(copy_magic_functions_e2e_python_files FILE_NAME) + copy_e2e_python_files(functions ${FILE_NAME}) +endfunction() + +copy_magic_functions_e2e_python_files(common.py) +copy_magic_functions_e2e_python_files(conftest.py) +copy_magic_functions_e2e_python_files(function_example.py) + +add_subdirectory(functions) diff --git a/tests/e2e/magic_functions/common.py b/tests/e2e/magic_functions/common.py new file mode 100644 index 000000000..9e2a8abc1 --- /dev/null +++ b/tests/e2e/magic_functions/common.py @@ -0,0 +1,35 @@ +# Copyright 2022 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. 
+ +import mgclient +import typing + + +def execute_and_fetch_all( + cursor: mgclient.Cursor, query: str, params: dict = {} +) -> typing.List[tuple]: + cursor.execute(query, params) + return cursor.fetchall() + + +def connect(**kwargs) -> mgclient.Connection: + connection = mgclient.connect(host="localhost", port=7687, **kwargs) + connection.autocommit = True + return connection + + +def has_n_result_row(cursor: mgclient.Cursor, query: str, n: int): + results = execute_and_fetch_all(cursor, query) + return len(results) == n + + +def has_one_result_row(cursor: mgclient.Cursor, query: str): + return has_n_result_row(cursor, query, 1) diff --git a/tests/e2e/magic_functions/conftest.py b/tests/e2e/magic_functions/conftest.py new file mode 100644 index 000000000..762a093b0 --- /dev/null +++ b/tests/e2e/magic_functions/conftest.py @@ -0,0 +1,22 @@ +# Copyright 2022 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + +import pytest + +from common import execute_and_fetch_all, connect + + +@pytest.fixture(autouse=True) +def connection(): + connection = connect() + yield connection + cursor = connection.cursor() + execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n") diff --git a/tests/e2e/magic_functions/function_example.py b/tests/e2e/magic_functions/function_example.py new file mode 100644 index 000000000..0ca208b7d --- /dev/null +++ b/tests/e2e/magic_functions/function_example.py @@ -0,0 +1,122 @@ +# Copyright 2022 Memgraph Ltd. 
+# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + +import typing +import mgclient +import sys +import pytest +from common import execute_and_fetch_all, has_n_result_row + + +@pytest.mark.parametrize("function_type", ["py", "c"]) +def test_return_argument(connection, function_type): + cursor = connection.cursor() + execute_and_fetch_all(cursor, "CREATE (n:Label {id: 1});") + assert has_n_result_row(cursor, "MATCH (n) RETURN n", 1) + result = execute_and_fetch_all( + cursor, + f"MATCH (n) RETURN {function_type}_read.return_function_argument(n) AS argument;", + ) + vertex = result[0][0] + assert isinstance(vertex, mgclient.Node) + assert vertex.labels == set(["Label"]) + assert vertex.properties == {"id": 1} + + +@pytest.mark.parametrize("function_type", ["py", "c"]) +def test_return_optional_argument(connection, function_type): + cursor = connection.cursor() + assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) + result = execute_and_fetch_all( + cursor, + f"RETURN {function_type}_read.return_optional_argument(42) AS argument;", + ) + result = result[0][0] + assert isinstance(result, int) + assert result == 42 + + +@pytest.mark.parametrize("function_type", ["py", "c"]) +def test_return_optional_argument_no_arg(connection, function_type): + cursor = connection.cursor() + assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) + result = execute_and_fetch_all( + cursor, + f"RETURN {function_type}_read.return_optional_argument() AS argument;", + ) + result = result[0][0] + assert isinstance(result, int) + assert 
result == 42 + + +@pytest.mark.parametrize("function_type", ["py", "c"]) +def test_add_two_numbers(connection, function_type): + cursor = connection.cursor() + assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) + result = execute_and_fetch_all( + cursor, + f"RETURN {function_type}_read.add_two_numbers(1, 5) AS total;", + ) + result_sum = result[0][0] + assert isinstance(result_sum, (float, int)) + assert result_sum == 6 + + +@pytest.mark.parametrize("function_type", ["py", "c"]) +def test_return_null(connection, function_type): + cursor = connection.cursor() + assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) + result = execute_and_fetch_all( + cursor, + f"RETURN {function_type}_read.return_null() AS null;", + ) + result_null = result[0][0] + assert result_null is None + + +@pytest.mark.parametrize("function_type", ["py", "c"]) +def test_too_many_arguments(connection, function_type): + cursor = connection.cursor() + assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) + # Should raise too many arguments + with pytest.raises(mgclient.DatabaseError): + execute_and_fetch_all( + cursor, + f"RETURN {function_type}_read.return_null('parameter') AS null;", + ) + + +@pytest.mark.parametrize("function_type", ["py", "c"]) +def test_try_to_write(connection, function_type): + cursor = connection.cursor() + execute_and_fetch_all(cursor, "CREATE (n:Label {id: 1});") + assert has_n_result_row(cursor, "MATCH (n) RETURN n", 1) + # Should raise non mutable + with pytest.raises(mgclient.DatabaseError): + execute_and_fetch_all( + cursor, + f"MATCH (n) RETURN {function_type}_write.try_to_write(n, 'property', 1);", + ) + +@pytest.mark.parametrize("function_type", ["py", "c"]) +def test_case_sensitivity(connection, function_type): + cursor = connection.cursor() + assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) + # Should raise function does not exist + with pytest.raises(mgclient.DatabaseError): + execute_and_fetch_all( + cursor, + f"RETURN 
{function_type}_read.ReTuRn_nUlL('parameter') AS null;", + ) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/magic_functions/functions/CMakeLists.txt b/tests/e2e/magic_functions/functions/CMakeLists.txt new file mode 100644 index 000000000..ef07acc14 --- /dev/null +++ b/tests/e2e/magic_functions/functions/CMakeLists.txt @@ -0,0 +1,5 @@ +copy_magic_functions_e2e_python_files(py_write.py) +copy_magic_functions_e2e_python_files(py_read.py) + +add_query_module(c_read c_read.cpp) +add_query_module(c_write c_write.cpp) diff --git a/tests/e2e/magic_functions/functions/c_read.cpp b/tests/e2e/magic_functions/functions/c_read.cpp new file mode 100644 index 000000000..87c42bc04 --- /dev/null +++ b/tests/e2e/magic_functions/functions/c_read.cpp @@ -0,0 +1,178 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include +#include + +#include "mg_procedure.h" + +#include "utils/on_scope_exit.hpp" + +namespace { +static void ReturnFunctionArgument(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result, + struct mgp_memory *memory) { + mgp_value *value{nullptr}; + auto err_code = mgp_list_at(args, 0, &value); + if (err_code != MGP_ERROR_NO_ERROR) { + mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory); + return; + } + + err_code = mgp_func_result_set_value(result, value, memory); + if (err_code != MGP_ERROR_NO_ERROR) { + mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory); + return; + } +} + +static void ReturnOptionalArgument(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result, + struct mgp_memory *memory) { + mgp_value *value{nullptr}; + auto err_code = mgp_list_at(args, 0, &value); + if (err_code != MGP_ERROR_NO_ERROR) { + mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory); + return; + } + + err_code = mgp_func_result_set_value(result, value, memory); + if (err_code != MGP_ERROR_NO_ERROR) { + mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory); + return; + } +} + +double GetElementFromArg(struct mgp_list *args, int index) { + mgp_value *value{nullptr}; + if (mgp_list_at(args, index, &value) != MGP_ERROR_NO_ERROR) { + throw std::runtime_error("Error while argument fetching."); + } + + double result; + int is_int; + mgp_value_is_int(value, &is_int); + + if (is_int) { + int64_t result_int; + mgp_value_get_int(value, &result_int); + result = static_cast(result_int); + } else { + mgp_value_get_double(value, &result); + } + return result; +} + +static void AddTwoNumbers(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result, + struct mgp_memory *memory) { + double first = 0; + double second = 0; + try { + first = GetElementFromArg(args, 0); + second = GetElementFromArg(args, 1); + } catch (...) 
{ + mgp_func_result_set_error_msg(result, "Unable to fetch the result!", memory); + return; + } + + mgp_value *value{nullptr}; + auto summation = first + second; + mgp_value_make_double(summation, memory, &value); + memgraph::utils::OnScopeExit delete_summation_value([&value] { mgp_value_destroy(value); }); + + auto err_code = mgp_func_result_set_value(result, value, memory); + if (err_code != MGP_ERROR_NO_ERROR) { + mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory); + } +} + +static void ReturnNull(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result, + struct mgp_memory *memory) { + mgp_value *value{nullptr}; + mgp_value_make_null(memory, &value); + memgraph::utils::OnScopeExit delete_null([&value] { mgp_value_destroy(value); }); + + auto err_code = mgp_func_result_set_value(result, value, memory); + if (err_code != MGP_ERROR_NO_ERROR) { + mgp_func_result_set_error_msg(result, "Failed to fetch list!", memory); + } +} +} // namespace + +// Each module needs to define mgp_init_module function. +// Here you can register multiple functions/procedures your module supports. 
+extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) { + { + mgp_func *func{nullptr}; + auto err_code = mgp_module_add_function(module, "return_function_argument", ReturnFunctionArgument, &func); + if (err_code != MGP_ERROR_NO_ERROR) { + return 1; + } + + mgp_type *type_any{nullptr}; + mgp_type_any(&type_any); + err_code = mgp_func_add_arg(func, "argument", type_any); + if (err_code != MGP_ERROR_NO_ERROR) { + return 1; + } + } + + { + mgp_func *func{nullptr}; + auto err_code = mgp_module_add_function(module, "return_optional_argument", ReturnOptionalArgument, &func); + if (err_code != MGP_ERROR_NO_ERROR) { + return 1; + } + + mgp_value *default_value{nullptr}; + mgp_value_make_int(42, memory, &default_value); + memgraph::utils::OnScopeExit delete_summation_value([&default_value] { mgp_value_destroy(default_value); }); + + mgp_type *type_int{nullptr}; + mgp_type_int(&type_int); + err_code = mgp_func_add_opt_arg(func, "opt_argument", type_int, default_value); + if (err_code != MGP_ERROR_NO_ERROR) { + return 1; + } + } + + { + mgp_func *func{nullptr}; + auto err_code = mgp_module_add_function(module, "add_two_numbers", AddTwoNumbers, &func); + if (err_code != MGP_ERROR_NO_ERROR) { + return 1; + } + + mgp_type *type_number{nullptr}; + mgp_type_number(&type_number); + err_code = mgp_func_add_arg(func, "first", type_number); + if (err_code != MGP_ERROR_NO_ERROR) { + return 1; + } + err_code = mgp_func_add_arg(func, "second", type_number); + if (err_code != MGP_ERROR_NO_ERROR) { + return 1; + } + } + + { + mgp_func *func{nullptr}; + auto err_code = mgp_module_add_function(module, "return_null", ReturnNull, &func); + if (err_code != MGP_ERROR_NO_ERROR) { + return 1; + } + } + + return 0; +} + +// This is an optional function if you need to release any resources before the +// module is unloaded. You will probably need this if you acquired some +// resources in mgp_init_module. 
+extern "C" int mgp_shutdown_module() { return 0; } diff --git a/tests/e2e/magic_functions/functions/c_write.cpp b/tests/e2e/magic_functions/functions/c_write.cpp new file mode 100644 index 000000000..a95ab2192 --- /dev/null +++ b/tests/e2e/magic_functions/functions/c_write.cpp @@ -0,0 +1,80 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "mg_procedure.h" + +static void TryToWrite(struct mgp_list *args, mgp_func_context *ctx, mgp_func_result *result, + struct mgp_memory *memory) { + mgp_value *value{nullptr}; + mgp_vertex *vertex{nullptr}; + mgp_list_at(args, 0, &value); + mgp_value_get_vertex(value, &vertex); + + const char *name; + mgp_list_at(args, 1, &value); + mgp_value_get_string(value, &name); + + mgp_list_at(args, 2, &value); + + // Setting a property should set an error + auto err_code = mgp_vertex_set_property(vertex, name, value); + if (err_code != MGP_ERROR_NO_ERROR) { + mgp_func_result_set_error_msg(result, "Cannot set property in the function!", memory); + return; + } + + err_code = mgp_func_result_set_value(result, value, memory); + if (err_code != MGP_ERROR_NO_ERROR) { + mgp_func_result_set_error_msg(result, "Failed to construct return value!", memory); + return; + } +} + +// Each module needs to define mgp_init_module function. +// Here you can register multiple functions/procedures your module supports. 
+extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) { + { + mgp_func *func{nullptr}; + auto err_code = mgp_module_add_function(module, "try_to_write", TryToWrite, &func); + if (err_code != MGP_ERROR_NO_ERROR) { + return 1; + } + + mgp_type *type_vertex{nullptr}; + mgp_type_node(&type_vertex); + err_code = mgp_func_add_arg(func, "argument", type_vertex); + if (err_code != MGP_ERROR_NO_ERROR) { + return 1; + } + + mgp_type *type_string{nullptr}; + mgp_type_string(&type_string); + err_code = mgp_func_add_arg(func, "name", type_string); + if (err_code != MGP_ERROR_NO_ERROR) { + return 1; + } + + mgp_type *any_type{nullptr}; + mgp_type_any(&any_type); + mgp_type *nullable_type{nullptr}; + mgp_type_nullable(any_type, &nullable_type); + err_code = mgp_func_add_arg(func, "value", nullable_type); + if (err_code != MGP_ERROR_NO_ERROR) { + return 1; + } + } + return 0; +} + +// This is an optional function if you need to release any resources before the +// module is unloaded. You will probably need this if you acquired some +// resources in mgp_init_module. +extern "C" int mgp_shutdown_module() { return 0; } diff --git a/tests/e2e/magic_functions/functions/py_read.py b/tests/e2e/magic_functions/functions/py_read.py new file mode 100644 index 000000000..80700f5bc --- /dev/null +++ b/tests/e2e/magic_functions/functions/py_read.py @@ -0,0 +1,32 @@ +# Copyright 2022 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. 
+ +import mgp + + +@mgp.function +def return_function_argument(ctx: mgp.FuncCtx, argument: mgp.Any): + return argument + + +@mgp.function +def return_optional_argument(ctx: mgp.FuncCtx, opt_argument: mgp.Number = 42): + return opt_argument + + +@mgp.function +def add_two_numbers(ctx: mgp.FuncCtx, first: mgp.Number, second: mgp.Number): + return first + second + + +@mgp.function +def return_null(ctx: mgp.FuncCtx): + return None diff --git a/tests/e2e/magic_functions/functions/py_write.py b/tests/e2e/magic_functions/functions/py_write.py new file mode 100644 index 000000000..81c36444d --- /dev/null +++ b/tests/e2e/magic_functions/functions/py_write.py @@ -0,0 +1,17 @@ +# Copyright 2022 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. 
+
+import mgp
+
+
+@mgp.function
+def try_to_write(ctx: mgp.FuncCtx, argument: mgp.Vertex, name: str, value: mgp.Nullable[mgp.Any]):
+    argument.properties.set(name, value)
diff --git a/tests/e2e/magic_functions/workloads.yaml b/tests/e2e/magic_functions/workloads.yaml
new file mode 100644
index 000000000..1f130099d
--- /dev/null
+++ b/tests/e2e/magic_functions/workloads.yaml
@@ -0,0 +1,14 @@
+template_cluster: &template_cluster
+  cluster:
+    main:
+      args: ["--bolt-port", "7687", "--log-level=TRACE"]
+      log_file: "magic-functions-e2e.log"
+      setup_queries: []
+      validation_queries: []
+
+workloads:
+  - name: "Magic functions runner"
+    binary: "tests/e2e/pytest_runner.sh"
+    proc: "tests/e2e/magic_functions/functions/"
+    args: ["magic_functions/function_example.py"]
+    <<: *template_cluster
diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt
index acfb535b8..1a6ae24a1 100644
--- a/tests/unit/CMakeLists.txt
+++ b/tests/unit/CMakeLists.txt
@@ -129,6 +129,12 @@ target_link_libraries(${test_prefix}query_serialization_property_value mg-query)
 add_unit_test(query_streams.cpp)
 target_link_libraries(${test_prefix}query_streams mg-query kafka-mock)
 
+# Test query functions
+add_unit_test(query_function_mgp_module.cpp)
+target_link_libraries(${test_prefix}query_function_mgp_module mg-query)
+target_include_directories(${test_prefix}query_function_mgp_module PRIVATE ${CMAKE_SOURCE_DIR}/include)
+
+
 # Test query/procedure
 add_unit_test(query_procedure_mgp_type.cpp)
 target_link_libraries(${test_prefix}query_procedure_mgp_type mg-query)
diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp
index 3800a9e01..814f3fa7c 100644
--- a/tests/unit/cypher_main_visitor.cpp
+++ b/tests/unit/cypher_main_visitor.cpp
@@ -211,13 +211,19 @@ class MockModule : public procedure::Module {
   const std::map> *Procedures() const override { return &procedures; }
 
   const std::map> *Transformations() const override { return &transformations; }
-  std::optional Path() const override { return std::nullopt; }
+
+  const std::map> *Functions() const override { return &functions; }
+
+  std::optional Path() const override { return std::nullopt; };
 
   std::map> procedures{};
   std::map> transformations{};
+  std::map> functions{};
 };
 
 void DummyProcCallback(mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result * /*result*/, mgp_memory * /*memory*/){};
+void DummyFuncCallback(mgp_list * /*args*/, mgp_func_context * /*func_ctx*/, mgp_func_result * /*result*/,
+                       mgp_memory * /*memory*/){};
 
 enum class ProcedureType { WRITE, READ };
 
@@ -258,6 +264,15 @@ class CypherMainVisitorTest : public ::testing::TestWithParam &args) {
+    memgraph::utils::MemoryResource *memory = memgraph::utils::NewDeleteResource();
+    mgp_func func(name, DummyFuncCallback, memory);
+    for (const auto arg : args) {
+      func.args.emplace_back(memgraph::utils::pmr::string{arg, memory}, &any_type);
+    }
+    module.functions.emplace(name, std::move(func));
+  }
+
   std::string CreateProcByType(const ProcedureType type, const std::vector &args) {
     const auto proc_name = std::string{"proc_"} + ToString(type);
     SCOPED_TRACE(proc_name);
@@ -858,6 +873,12 @@ TEST_P(CypherMainVisitorTest, UndefinedFunction) {
                SemanticException);
 }
 
+TEST_P(CypherMainVisitorTest, MissingFunction) {
+  AddFunc(*mock_module, "get", {});
+  auto &ast_generator = *GetParam();
+  ASSERT_THROW(ast_generator.ParseQuery("RETURN missing_function.get()"), SemanticException);
+}
+
 TEST_P(CypherMainVisitorTest, Function) {
   auto &ast_generator = *GetParam();
   auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN abs(n, 2)"));
@@ -871,6 +892,20 @@ TEST_P(CypherMainVisitorTest, Function) {
   ASSERT_TRUE(function->function_);
 }
 
+TEST_P(CypherMainVisitorTest, MagicFunction) {
+  AddFunc(*mock_module, "get", {});
+  auto &ast_generator = *GetParam();
+  auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN mock_module.get()"));
+  ASSERT_TRUE(query);
+  ASSERT_TRUE(query->single_query_);
+  auto *single_query = query->single_query_;
+  auto *return_clause = dynamic_cast(single_query->clauses_[0]);
+  ASSERT_EQ(return_clause->body_.named_expressions.size(), 1);
+  auto *function = dynamic_cast(return_clause->body_.named_expressions[0]->expression_);
+  ASSERT_TRUE(function);
+  ASSERT_TRUE(function->function_);
+}
+
 TEST_P(CypherMainVisitorTest, StringLiteralDoubleQuotes) {
   auto &ast_generator = *GetParam();
   auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN \"mi'rko\""));
diff --git a/tests/unit/query_function_mgp_module.cpp b/tests/unit/query_function_mgp_module.cpp
new file mode 100644
index 000000000..457d6a236
--- /dev/null
+++ b/tests/unit/query_function_mgp_module.cpp
@@ -0,0 +1,49 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#include
+
+#include
+#include
+#include
+
+#include "query/procedure/mg_procedure_impl.hpp"
+
+#include "test_utils.hpp"
+
+static void DummyCallback(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *){};
+
+TEST(Module, InvalidFunctionRegistration) {
+  mgp_module module(memgraph::utils::NewDeleteResource());
+  mgp_func *func{nullptr};
+  // Other test cases are covered within the procedure API. This is only a sanity check.
+  EXPECT_EQ(mgp_module_add_function(&module, "dashes-not-supported", DummyCallback, &func), MGP_ERROR_INVALID_ARGUMENT);
+}
+
+TEST(Module, RegisterSameFunctionMultipleTimes) {
+  mgp_module module(memgraph::utils::NewDeleteResource());
+  mgp_func *func{nullptr};
+  EXPECT_EQ(module.functions.find("same_name"), module.functions.end());
+  EXPECT_EQ(mgp_module_add_function(&module, "same_name", DummyCallback, &func), MGP_ERROR_NO_ERROR);
+  EXPECT_NE(module.functions.find("same_name"), module.functions.end());
+  EXPECT_EQ(mgp_module_add_function(&module, "same_name", DummyCallback, &func), MGP_ERROR_LOGIC_ERROR);
+  EXPECT_EQ(mgp_module_add_function(&module, "same_name", DummyCallback, &func), MGP_ERROR_LOGIC_ERROR);
+  EXPECT_NE(module.functions.find("same_name"), module.functions.end());
+}
+
+TEST(Module, CaseSensitiveFunctionNames) {
+  mgp_module module(memgraph::utils::NewDeleteResource());
+  mgp_func *func{nullptr};
+  EXPECT_EQ(mgp_module_add_function(&module, "not_same", DummyCallback, &func), MGP_ERROR_NO_ERROR);
+  EXPECT_EQ(mgp_module_add_function(&module, "NoT_saME", DummyCallback, &func), MGP_ERROR_NO_ERROR);
+  EXPECT_EQ(mgp_module_add_function(&module, "NOT_SAME", DummyCallback, &func), MGP_ERROR_NO_ERROR);
+  EXPECT_EQ(module.functions.size(), 3U);
+}