From ad32db5168edc62d5a3262692a77718abb0751a8 Mon Sep 17 00:00:00 2001
From: János Benjamin Antal
Date: Mon, 5 Jul 2021 11:42:40 +0200
Subject: [PATCH] Use transformations in streams and CHECK STREAM (#185)

* Use the correct transformation result type
* Execute the result queries in streams
* Change the result type of parameters to nullable map
* Serialize transformation name
* Fix order of transformation parameters
* Use actual transformation in Streams
* Clear the Python transformations under GIL
* Add CHECK STREAM query
* Handle missing record fields properly
---
 include/mgp.py                                |  65 ++++---
 src/integrations/kafka/consumer.cpp           |   2 +
 src/memgraph.cpp                              |   9 +-
 src/py/py.hpp                                 |   4 +-
 src/query/discard_value_stream.hpp            |  13 ++
 src/query/frontend/ast/ast.lcp                |   5 +-
 .../frontend/ast/cypher_main_visitor.cpp      |  15 ++
 .../frontend/ast/cypher_main_visitor.hpp      |   5 +
 .../opencypher/grammar/MemgraphCypher.g4      |   8 +-
 .../opencypher/grammar/MemgraphCypherLexer.g4 |   3 +-
 .../frontend/stripped_lexer_constants.hpp     |   4 +-
 src/query/interpreter.cpp                     |  12 +-
 src/query/plan/operator.cpp                   |   2 +-
 src/query/procedure/mg_procedure_impl.cpp     |   2 +-
 src/query/procedure/mg_procedure_impl.hpp     |   2 +-
 src/query/procedure/module.cpp                |   6 +-
 src/query/procedure/py_module.cpp             |   2 +-
 src/query/streams.cpp                         | 179 ++++++++++++++----
 src/query/streams.hpp                         |  11 +-
 src/query/trigger.cpp                         |   7 +-
 tests/unit/cypher_main_visitor.cpp            |  40 +++-
 tests/unit/query_streams.cpp                  |   8 +-
 22 files changed, 289 insertions(+), 115 deletions(-)
 create mode 100644 src/query/discard_value_stream.hpp

diff --git a/include/mgp.py b/include/mgp.py
index 025d9fd21..97a930949 100644
--- a/include/mgp.py
+++ b/include/mgp.py
@@ -190,7 +190,8 @@ class Edge:

     def __init__(self, edge):
         if not isinstance(edge, _mgp.Edge):
-            raise TypeError("Expected '_mgp.Edge', got '{}'".format(type(edge)))
+            raise TypeError(
+                "Expected '_mgp.Edge', got '{}'".format(type(edge)))
         self._edge = edge

     def __deepcopy__(self, memo):
@@ -268,7 +269,8 @@ class Vertex:

     def __init__(self, vertex):
         if not isinstance(vertex, _mgp.Vertex):
-            raise TypeError("Expected '_mgp.Vertex', got '{}'".format(type(vertex)))
+            raise TypeError(
+                "Expected '_mgp.Vertex', got '{}'".format(type(vertex)))
         self._vertex = vertex

     def __deepcopy__(self, memo):
@@ -404,7 +406,8 @@ class Path:
             passed in edge is invalid.
        '''
        if not isinstance(edge, Edge):
-            raise TypeError("Expected '_mgp.Edge', got '{}'".format(type(edge)))
+            raise TypeError(
+                "Expected '_mgp.Edge', got '{}'".format(type(edge)))
         if not self.is_valid() or not edge.is_valid():
             raise InvalidContextError()
         self._path.expand(edge._edge)
@@ -454,7 +457,8 @@ class Vertices:

     def __init__(self, graph):
         if not isinstance(graph, _mgp.Graph):
-            raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
+            raise TypeError(
+                "Expected '_mgp.Graph', got '{}'".format(type(graph)))
         self._graph = graph
         self._len = None
@@ -499,7 +503,8 @@ class Graph:

     def __init__(self, graph):
         if not isinstance(graph, _mgp.Graph):
-            raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
+            raise TypeError(
+                "Expected '_mgp.Graph', got '{}'".format(type(graph)))
         self._graph = graph

     def __deepcopy__(self, memo):
@@ -557,7 +562,8 @@ class ProcCtx:

     def __init__(self, graph):
         if not isinstance(graph, _mgp.Graph):
-            raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
+            raise TypeError(
+                "Expected '_mgp.Graph', got '{}'".format(type(graph)))
         self._graph = Graph(graph)

     def is_valid(self) -> bool:
@@ -676,11 +682,13 @@ def _typing_to_cypher_type(type_):
         type_args_as_str = parse_type_args(type_as_str)
         none_type_as_str = type(None).__name__
         if none_type_as_str in type_args_as_str:
-            types = tuple(t for t in type_args_as_str if t != none_type_as_str)
+            types = tuple(
+                t for t in type_args_as_str if t != none_type_as_str)
             if len(types) == 1:
                 type_arg_as_str, = types
             else:
-                type_arg_as_str = 'typing.Union[' + ', '.join(types) + ']'
+                type_arg_as_str = 'typing.Union[' + \
+                    ', '.join(types) + ']'
             simple_type = get_simple_type(type_arg_as_str)
             if simple_type is not None:
                 return _mgp.type_nullable(simple_type)
@@ -726,6 +734,7 @@ def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]):
     if inspect.isgeneratorfunction(func):
         raise NotImplementedError("Generator functions are not supported")

+
 def read_proc(func: typing.Callable[..., Record]):
     '''
     Register `func` as a read-only procedure of the current module.
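The new `transformation` decorator defined further down in this file runs its argument through the same `raise_if_does_not_meet_requirements` gate as `read_proc`, and its result signature is fixed on the C++ side by MgpTransAddFixedResult to (query, parameters). A minimal sketch of what the gate accepts and rejects; the module and function names here are illustrative only, not part of this patch:

    import mgp

    @mgp.transformation
    def accepted(messages: mgp.Messages):
        # (Messages) is one of the two accepted signatures.
        return mgp.Record(query="RETURN 1", parameters=None)

    @mgp.transformation
    def rejected(messages: mgp.Messages):
        # Generator functions fail the gate at decoration time with
        # NotImplementedError("Generator functions are not supported").
        yield mgp.Record(query="RETURN 1", parameters=None)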
@@ -803,17 +812,20 @@ def read_proc(func: typing.Callable[..., Record]):
             mgp_proc.add_result(name, _typing_to_cypher_type(type_))
     return func

+
 class InvalidMessageError(Exception):
     '''Signals using a message instance outside of the registered transformation.'''
     pass

-class Message:
+
+class Message:
     '''Represents a message from a stream.'''
     __slots__ = ('_message',)

     def __init__(self, message):
         if not isinstance(message, _mgp.Message):
-            raise TypeError("Expected '_mgp.Message', got '{}'".format(type(message)))
+            raise TypeError(
+                "Expected '_mgp.Message', got '{}'".format(type(message)))
         self._message = message

     def __deepcopy__(self, memo):
@@ -829,34 +841,37 @@ class Message:
     def payload(self) -> bytes:
         if not self.is_valid():
             raise InvalidMessageError()
-        return self._messages._payload(_message)
+        return self._message.payload()

     def topic_name(self) -> str:
         if not self.is_valid():
             raise InvalidMessageError()
-        return self._messages._topic_name(_message)
+        return self._message.topic_name()

-    def key() -> bytes:
+    def key(self) -> bytes:
         if not self.is_valid():
             raise InvalidMessageError()
-        return self._messages.key(_message)
-
-    def timestamp() -> int:
+        return self._message.key()
+
+    def timestamp(self) -> int:
         if not self.is_valid():
             raise InvalidMessageError()
-        return self._messages.timestamp(_message)
+        return self._message.timestamp()
+

 class InvalidMessagesError(Exception):
     '''Signals using a messages instance outside of the registered transformation.'''
     pass

+
 class Messages:
     '''Represents a list of messages from a stream.'''
     __slots__ = ('_messages',)

     def __init__(self, messages):
         if not isinstance(messages, _mgp.Messages):
-            raise TypeError("Expected '_mgp.Messages', got '{}'".format(type(messages)))
+            raise TypeError(
+                "Expected '_mgp.Messages', got '{}'".format(type(messages)))
         self._messages = messages

     def __deepcopy__(self, memo):
@@ -869,18 +884,19 @@ class Messages:
         '''Return True if `self` is in valid context and may be used.'''
         return self._messages.is_valid()

-    def message_at(self, id : int) -> Message:
+    def message_at(self, id: int) -> Message:
         '''Raise InvalidMessagesError if context is invalid.'''
         if not self.is_valid():
             raise InvalidMessagesError()
         return Message(self._messages.message_at(id))

-    def total_messages() -> int:
+    def total_messages(self) -> int:
         '''Raise InvalidContextError if context is invalid.'''
         if not self.is_valid():
             raise InvalidMessagesError()
         return self._messages.total_messages()

+
 class TransCtx:
     '''Context of a transformation being executed.
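With the accessor fixes above (payload(), topic_name(), key() and timestamp() now bind self and call through _message), a transformation can walk a batch like this. A sketch only; it assumes UTF-8 encoded payloads and a module that imports mgp:

    import mgp

    @mgp.transformation
    def show_batch(context: mgp.TransCtx, messages: mgp.Messages):
        # (TransCtx, Messages) is the other accepted signature.
        records = []
        for i in range(messages.total_messages()):
            message = messages.message_at(i)
            payload = message.payload().decode("utf-8")  # payload() returns bytes
            records.append(mgp.Record(
                query="CREATE (n:MESSAGE {payload: $payload, topic: $topic})",
                parameters={"payload": payload,
                            "topic": message.topic_name()}))
        return records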
@@ -891,8 +907,9 @@ class TransCtx:

     def __init__(self, graph):
         if not isinstance(graph, _mgp.Graph):
-            raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
-            self._graph = Graph(graph)
+            raise TypeError(
+                "Expected '_mgp.Graph', got '{}'".format(type(graph)))
+        self._graph = Graph(graph)

     def is_valid(self) -> bool:
         return self._graph.is_valid()
@@ -904,13 +921,15 @@ class TransCtx:
             raise InvalidContextError()
         return self._graph

+
 def transformation(func: typing.Callable[..., Record]):
     raise_if_does_not_meet_requirements(func)
     sig = inspect.signature(func)
     params = tuple(sig.parameters.values())
     if not params or not params[0].annotation is Messages:
         if not len(params) == 2 or not params[1].annotation is Messages:
-            raise NotImplementedError("Valid signatures for transformations are (TransCtx, Messages) or (Messages)")
+            raise NotImplementedError(
+                "Valid signatures for transformations are (TransCtx, Messages) or (Messages)")
     if params[0].annotation is TransCtx:
         @functools.wraps(func)
         def wrapper(graph, messages):
diff --git a/src/integrations/kafka/consumer.cpp b/src/integrations/kafka/consumer.cpp
index 66f446633..929c153a6 100644
--- a/src/integrations/kafka/consumer.cpp
+++ b/src/integrations/kafka/consumer.cpp
@@ -44,6 +44,7 @@ utils::BasicResult<std::string, std::vector<Message>> GetBatch(RdKafka::KafkaCon
       break;

     default:
+      // TODO(antaljanosbenjamin): handle RD_KAFKA_RESP_ERR__MAX_POLL_EXCEEDED
       auto error = msg->errstr();
       spdlog::warn("Unexpected error while consuming message in consumer {}, error: {}!", info.consumer_name,
                    msg->errstr());
@@ -307,6 +308,7 @@ void Consumer::StartConsuming() {
         spdlog::warn("Error happened in consumer {} while processing a batch: {}!", info_.consumer_name, e.what());
         break;
       }
+      spdlog::info("Kafka consumer {} finished processing", info_.consumer_name);
     }
     is_running_.store(false);
   });
diff --git a/src/memgraph.cpp b/src/memgraph.cpp
index 05970eb05..8add7d27b 100644
--- a/src/memgraph.cpp
+++ b/src/memgraph.cpp
@@ -23,6 +23,7 @@
 #include "communication/bolt/v1/constants.hpp"
 #include "helpers.hpp"
 #include "py/py.hpp"
+#include "query/discard_value_stream.hpp"
 #include "query/exceptions.hpp"
 #include "query/interpreter.hpp"
 #include "query/plan/operator.hpp"
@@ -433,7 +434,7 @@ class BoltSession final : public communication::bolt::Session<communication::InputStream, communication::OutputStream>
   std::map<std::string, communication::bolt::Value> Discard(std::optional<int> n, std::optional<int> qid) override {
-    DiscardValueResultStream stream;
+    query::DiscardValueResultStream stream;
     return PullResults(stream, n, qid);
   }
@@ -517,12 +518,6 @@ class BoltSession final : public communication::bolt::Session<communication::InputStream, communication::OutputStream>
-  struct DiscardValueResultStream {
-    void Result(const std::vector<communication::bolt::Value> &) {
-      // do nothing
-    }
-  };
-
   // NOTE: Needed only for ToBoltValue conversions
   const storage::Storage *db_;
   query::Interpreter interpreter_;
diff --git a/src/py/py.hpp b/src/py/py.hpp
index 366991630..cb37fe872 100644
--- a/src/py/py.hpp
+++ b/src/py/py.hpp
@@ -29,8 +29,8 @@ class EnsureGIL final {
   PyGILState_STATE gil_state_;

  public:
-  EnsureGIL() : gil_state_(PyGILState_Ensure()) {}
-  ~EnsureGIL() { PyGILState_Release(gil_state_); }
+  EnsureGIL() noexcept : gil_state_(PyGILState_Ensure()) {}
+  ~EnsureGIL() noexcept { PyGILState_Release(gil_state_); }
   EnsureGIL(const EnsureGIL &) = delete;
   EnsureGIL(EnsureGIL &&) = delete;
   EnsureGIL &operator=(const EnsureGIL &) = delete;
diff --git a/src/query/discard_value_stream.hpp b/src/query/discard_value_stream.hpp
new file mode 100644
index 000000000..170e33ed1
--- /dev/null
+++ b/src/query/discard_value_stream.hpp
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <vector>
+
+#include "query/typed_value.hpp"
+
+namespace query {
+struct DiscardValueResultStream final {
+  void Result(const std::vector<TypedValue> & /*values*/) {
+    // do nothing
+  }
+};
+}  // namespace query
diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp
index ca75dc74d..fc7c80203 100644
--- a/src/query/frontend/ast/ast.lcp
+++ b/src/query/frontend/ast/ast.lcp
@@ -2476,12 +2476,15 @@ cpp<#
    :slk-save #'slk-save-ast-pointer
    :slk-load (slk-load-ast-pointer "Expression"))
   (batch_size "Expression *" :initval "nullptr" :scope :public
+   :slk-save #'slk-save-ast-pointer
+   :slk-load (slk-load-ast-pointer "Expression"))
+  (batch_limit "Expression *" :initval "nullptr" :scope :public
    :slk-save #'slk-save-ast-pointer
    :slk-load (slk-load-ast-pointer "Expression")))
  (:public
   (lcp:define-enum action
-   (create-stream drop-stream start-stream stop-stream start-all-streams stop-all-streams show-streams test-stream)
+   (create-stream drop-stream start-stream stop-stream start-all-streams stop-all-streams show-streams check-stream)
   (:serialize))
  #>cpp
  StreamQuery() = default;
diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp
index 9d341aab0..cacad8148 100644
--- a/src/query/frontend/ast/cypher_main_visitor.cpp
+++ b/src/query/frontend/ast/cypher_main_visitor.cpp
@@ -472,6 +472,7 @@ antlrcpp::Any CypherMainVisitor::visitCreateStream(MemgraphCypher::CreateStreamC
   auto *topic_names_ctx = ctx->topicNames();
   MG_ASSERT(topic_names_ctx != nullptr);
+  // TODO(antaljanosbenjamin): Add dash
   auto topic_names = topic_names_ctx->symbolicNameWithDots();
   MG_ASSERT(!topic_names.empty());
   stream_query->topic_names_.reserve(topic_names.size());
@@ -540,6 +541,20 @@ antlrcpp::Any CypherMainVisitor::visitShowStreams(MemgraphCypher::ShowStreamsCon
   return stream_query;
 }

+antlrcpp::Any CypherMainVisitor::visitCheckStream(MemgraphCypher::CheckStreamContext *ctx) {
+  auto *stream_query = storage_->Create<StreamQuery>();
+  stream_query->action_ = StreamQuery::Action::CHECK_STREAM;
+  stream_query->stream_name_ = ctx->streamName()->symbolicName()->accept(this).as<std::string>();
+
+  if (ctx->BATCH_LIMIT()) {
+    if (!ctx->batchLimit->numberLiteral() || !ctx->batchLimit->numberLiteral()->integerLiteral()) {
+      throw SemanticException("Batch limit should be an integer literal!");
+    }
+    stream_query->batch_limit_ = ctx->batchLimit->accept(this);
+  }
+  return stream_query;
+}
+
 antlrcpp::Any CypherMainVisitor::visitCypherUnion(MemgraphCypher::CypherUnionContext *ctx) {
   bool distinct = !ctx->ALL();
   auto *cypher_union = storage_->Create<CypherUnion>(distinct);
diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp
index de108909a..f33a48bd4 100644
--- a/src/query/frontend/ast/cypher_main_visitor.hpp
+++ b/src/query/frontend/ast/cypher_main_visitor.hpp
@@ -288,6 +288,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
    */
   antlrcpp::Any visitShowStreams(MemgraphCypher::ShowStreamsContext *ctx) override;

+  /**
+   * @return StreamQuery*
+   */
+  antlrcpp::Any visitCheckStream(MemgraphCypher::CheckStreamContext *ctx) override;
+
   /**
    * @return CypherUnion*
    */
diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4
index d638df7e2..d0613d38a 100644
--- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4
+++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4
@@ -12,10 +12,11 @@ memgraphCypherKeyword : cypherKeyword
                       | ASYNC
                       | AUTH
                       | BAD
-                      | BATCHES
                       | BATCH_INTERVAL
+                      | BATCH_LIMIT
                       | BATCH_SIZE
                       | BEFORE
+                      | CHECK
                       | CLEAR
                      | COMMIT
                      | COMMITTED
@@ -141,7 +142,8 @@ clause : cypherMatch
        | loadCsv
        ;

-streamQuery : createStream
+streamQuery : checkStream
+            | createStream
             | dropStream
             | startStream
             | startAllStreams
@@ -289,3 +291,5 @@ stopStream : STOP STREAM streamName ;
 stopAllStreams : STOP ALL STREAMS ;

 showStreams : SHOW STREAMS ;
+
+checkStream : CHECK STREAM streamName ( BATCH_LIMIT batchLimit=literal ) ? ;
diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4
index 7669a5eaa..9277efd2d 100644
--- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4
+++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4
@@ -17,10 +17,11 @@ ALTER : A L T E R ;
 ASYNC : A S Y N C ;
 AUTH : A U T H ;
 BAD : B A D ;
-BATCHES : B A T C H E S ;
 BATCH_INTERVAL : B A T C H UNDERSCORE I N T E R V A L ;
+BATCH_LIMIT : B A T C H UNDERSCORE L I M I T ;
 BATCH_SIZE : B A T C H UNDERSCORE S I Z E ;
 BEFORE : B E F O R E ;
+CHECK : C H E C K ;
 CLEAR : C L E A R ;
 COMMIT : C O M M I T ;
 COMMITTED : C O M M I T T E D ;
diff --git a/src/query/frontend/stripped_lexer_constants.hpp b/src/query/frontend/stripped_lexer_constants.hpp
index ebd327d79..73bd83346 100644
--- a/src/query/frontend/stripped_lexer_constants.hpp
+++ b/src/query/frontend/stripped_lexer_constants.hpp
@@ -127,11 +127,11 @@ const trie::Trie kKeywords = {"union",
                               "all",         "level",          "next",
                               "read",        "session",        "snapshot",
                               "transaction",
-                              "batches",     "batch_interval",
+                              "batch_limit", "batch_interval",
                               "batch_size",  "consumer_group", "start",
                               "stream",      "streams",        "transform",
-                              "topics"};
+                              "topics",      "check"};

 // Unicode codepoints that are allowed at the start of the unescaped name.
 const std::bitset kUnescapedNameAllowedStarts(
diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp
index a56863cbc..3cbd8d7aa 100644
--- a/src/query/interpreter.cpp
+++ b/src/query/interpreter.cpp
@@ -490,7 +490,7 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters &paramete
                            .consumer_group = std::move(consumer_group),
                            .batch_interval = batch_interval,
                            .batch_size = batch_size,
-                           .transformation_name = "transform.trans"});
+                           .transformation_name = transformation_name});
        return std::vector<std::vector<TypedValue>>{};
      };
      return callback;
@@ -575,8 +575,14 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters &paramete
        return results;
      };
      return callback;
-    case StreamQuery::Action::TEST_STREAM:
-      throw std::logic_error("not implemented");
+    }
+    case StreamQuery::Action::CHECK_STREAM: {
+      callback.header = {"query", "parameters"};
+      callback.fn = [interpreter_context, stream_name = stream_query->stream_name_,
+                     batch_limit = GetOptionalValue<int64_t>(stream_query->batch_limit_, evaluator)]() mutable {
+        return interpreter_context->streams.Test(stream_name, batch_limit);
+      };
+      return callback;
+    }
   }
 }
diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp
index fc9d3faa6..8dd413c27 100644
--- a/src/query/plan/operator.cpp
+++ b/src/query/plan/operator.cpp
@@ -3761,7 +3761,7 @@ class CallProcedureCursor : public Cursor {
       std::string_view field_name(self_->result_fields_[i]);
       auto result_it = values.find(field_name);
       if (result_it == values.end()) {
-        throw QueryRuntimeException("Procedure '{}' does not yield a record with '{}' field.", self_->procedure_name_,
+        throw QueryRuntimeException("Procedure '{}' did not yield a record with '{}' field.", self_->procedure_name_,
                                     field_name);
       }
       frame[self_->result_symbols_[i]] = result_it->second;
diff --git a/src/query/procedure/mg_procedure_impl.cpp b/src/query/procedure/mg_procedure_impl.cpp
index 7f4d9e95a..cf345b371 100644
--- a/src/query/procedure/mg_procedure_impl.cpp
+++ b/src/query/procedure/mg_procedure_impl.cpp
@@ -1372,7 +1372,7 @@ bool MgpTransAddFixedResult(mgp_trans *trans) {
   if (int err = AddResultToProp(trans, "query", mgp_type_string(), false); err != 1) {
     return err;
   }
-  return AddResultToProp(trans, "parameters", mgp_type_nullable(mgp_type_list(mgp_type_any())), false);
+  return AddResultToProp(trans, "parameters", mgp_type_nullable(mgp_type_map()), false);
 }

 int mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name, const mgp_type *type) {
diff --git a/src/query/procedure/mg_procedure_impl.hpp b/src/query/procedure/mg_procedure_impl.hpp
index b39651dc3..313d581e3 100644
--- a/src/query/procedure/mg_procedure_impl.hpp
+++ b/src/query/procedure/mg_procedure_impl.hpp
@@ -549,7 +549,7 @@ bool IsValidIdentifierName(const char *name);
 } // namespace query::procedure

 struct mgp_message {
-  integrations::kafka::Message *msg;
+  const integrations::kafka::Message *msg;
 };

 struct mgp_messages {
diff --git a/src/query/procedure/module.cpp b/src/query/procedure/module.cpp
index 8050cba5f..4d4d291f0 100644
--- a/src/query/procedure/module.cpp
+++ b/src/query/procedure/module.cpp
@@ -418,11 +418,11 @@ bool PythonModule::Load(const std::filesystem::path &file_path) {

 bool PythonModule::Close() {
   MG_ASSERT(py_module_, "Attempting to close a module that has not been loaded...");
   spdlog::info("Closing module {}...", file_path_);
-  // The procedures are closures which hold references to the Python callbacks.
-  // Releasing these references might result in deallocations so we need to take
-  // the GIL.
+  // The procedures and transformations are closures which hold references to the Python callbacks.
+  // Releasing these references might result in deallocations so we need to take the GIL.
   auto gil = py::EnsureGIL();
   procedures_.clear();
+  transformations_.clear();
   // Delete the module from the `sys.modules` directory so that the module will
   // be properly imported if imported again.
   py::Object sys(PyImport_ImportModule("sys"));
diff --git a/src/query/procedure/py_module.cpp b/src/query/procedure/py_module.cpp
index 157a47cb6..09867af73 100644
--- a/src/query/procedure/py_module.cpp
+++ b/src/query/procedure/py_module.cpp
@@ -782,7 +782,7 @@ void CallPythonTransformation(const py::Object &py_cb, const mgp_messages *msgs,
   };

   auto call = [&](py::Object py_graph, py::Object py_messages) -> std::optional<py::ExceptionInfo> {
-    auto py_res = py_cb.Call(py_messages, py_graph);
+    auto py_res = py_cb.Call(py_graph, py_messages);
     if (!py_res) return py::FetchError();
     if (PySequence_Check(py_res.Ptr())) {
       return AddMultipleRecordsFromPython(result, py_res);
diff --git a/src/query/streams.cpp b/src/query/streams.cpp
index 5e0fe6f22..6f2250b34 100644
--- a/src/query/streams.cpp
+++ b/src/query/streams.cpp
@@ -6,23 +6,92 @@
 #include
 #include

+#include "query/db_accessor.hpp"
+#include "query/discard_value_stream.hpp"
 #include "query/interpreter.hpp"
+#include "query/procedure/mg_procedure_impl.hpp"
+#include "query/procedure/module.hpp"
+#include "query/typed_value.hpp"
+#include "utils/memory.hpp"
 #include "utils/on_scope_exit.hpp"
+#include "utils/pmr/string.hpp"

 namespace query {
+using Consumer = integrations::kafka::Consumer;
+using ConsumerInfo = integrations::kafka::ConsumerInfo;
+using Message = integrations::kafka::Message;
 namespace {
+constexpr auto kExpectedTransformationResultSize = 2;
+const utils::pmr::string query_param_name{"query", utils::NewDeleteResource()};
+const utils::pmr::string params_param_name{"parameters", utils::NewDeleteResource()};
+const std::map<std::string, storage::PropertyValue> empty_parameters{};
+
 auto GetStream(auto &map, const std::string &stream_name) {
   if (auto it = map.find(stream_name); it != map.end()) {
     return it;
   }
   throw StreamsException("Couldn't find stream '{}'", stream_name);
 }
-} // namespace

-using Consumer = integrations::kafka::Consumer;
-using ConsumerInfo = integrations::kafka::ConsumerInfo;
-using Message = integrations::kafka::Message;
+void CallCustomTransformation(const std::string &transformation_name, const std::vector<Message> &messages,
+                              mgp_result &result, storage::Storage::Accessor &storage_accessor,
+                              utils::MemoryResource &memory_resource, const std::string &stream_name) {
+  DbAccessor db_accessor{&storage_accessor};
+  {
+    auto maybe_transformation =
+        procedure::FindTransformation(procedure::gModuleRegistry, transformation_name, utils::NewDeleteResource());
+
+    if (!maybe_transformation) {
+      throw StreamsException("Couldn't find transformation {} for stream '{}'", transformation_name, stream_name);
+    };
+    const auto &trans = *maybe_transformation->second;
+    mgp_messages mgp_messages{mgp_messages::storage_type{&memory_resource}};
+    std::transform(messages.begin(), messages.end(), std::back_inserter(mgp_messages.messages),
+                   [](const integrations::kafka::Message &message) { return mgp_message{&message}; });
+    mgp_graph graph{&db_accessor, storage::View::OLD, nullptr};
+    mgp_memory memory{&memory_resource};
+    result.rows.clear();
+    result.error_msg.reset();
+    result.signature = &trans.results;
+
+    MG_ASSERT(result.signature->size() == kExpectedTransformationResultSize);
+    MG_ASSERT(result.signature->contains(query_param_name));
+    MG_ASSERT(result.signature->contains(params_param_name));
+
+    spdlog::trace("Calling transformation in stream '{}'", stream_name);
+    trans.cb(&mgp_messages, &graph, &result, &memory);
+  }
+  if (result.error_msg.has_value()) {
+    throw StreamsException(result.error_msg->c_str());
+  }
+}
+
+std::pair<TypedValue, TypedValue> ExtractTransformationResult(
+    utils::pmr::map<utils::pmr::string, TypedValue> &&values, const std::string_view transformation_name,
+    const std::string_view stream_name) {
+  if (values.size() != kExpectedTransformationResultSize) {
+    throw StreamsException(
+        "Transformation '{}' in stream '{}' did not yield all fields (query, parameters) as required.",
+        transformation_name, stream_name);
+  }
+
+  auto get_value = [&](const utils::pmr::string &field_name) mutable -> TypedValue & {
+    auto it = values.find(field_name);
+    if (it == values.end()) {
+      throw StreamsException{"Transformation '{}' in stream '{}' did not yield a record with '{}' field.",
+                             transformation_name, stream_name, field_name};
+    };
+    return it->second;
+  };
+
+  auto &query_value = get_value(query_param_name);
+  MG_ASSERT(query_value.IsString());
+  auto &params_value = get_value(params_param_name);
+  MG_ASSERT(params_value.IsNull() || params_value.IsMap());
+  return {std::move(query_value), std::move(params_value)};
+}
+} // namespace

 // nlohmann::json doesn't support string_view access yet
 const std::string kStreamName{"name"};
@@ -31,6 +100,7 @@ const std::string kConsumerGroupKey{"consumer_group"};
 const std::string kBatchIntervalKey{"batch_interval"};
 const std::string kBatchSizeKey{"batch_size"};
 const std::string kIsRunningKey{"is_running"};
+const std::string kTransformationName{"transformation_name"};

 void to_json(nlohmann::json &data, StreamStatus &&status) {
   auto &info = status.info;
@@ -51,6 +121,7 @@ void to_json(nlohmann::json &data, StreamStatus &&status) {
   }

   data[kIsRunningKey] = status.is_running;
+  data[kTransformationName] = status.info.transformation_name;
 }

 void from_json(const nlohmann::json &data, StreamStatus &status) {
@@ -75,6 +146,7 @@ void from_json(const nlohmann::json &data, StreamStatus &status) {
   }

   data.at(kIsRunningKey).get_to(status.is_running);
+  data.at(kTransformationName).get_to(status.info.transformation_name);
 }

 Streams::Streams(InterpreterContext *interpreter_context, std::string bootstrap_servers,
@@ -89,8 +161,8 @@ void Streams::RestoreStreams() {
   MG_ASSERT(locked_streams_map->empty(), "Cannot restore streams when some streams already exist!");

   for (const auto &[stream_name, stream_data] : storage_) {
-    const auto get_failed_message = [](const std::string_view stream_name, const std::string_view message,
-                                       const std::string_view nested_message) {
+    const auto get_failed_message = [&stream_name = stream_name](const std::string_view message,
+                                                                 const std::string_view nested_message) {
       return fmt::format("Failed to load stream '{}', because: {} caused by {}", stream_name, message,
                          nested_message);
     };

     StreamStatus status;
     try {
       nlohmann::json::parse(stream_data).get_to(status);
     } catch (const nlohmann::json::type_error &exception) {
-      spdlog::warn(get_failed_message(stream_name, "invalid type conversion", exception.what()));
+      spdlog::warn(get_failed_message("invalid type conversion", exception.what()));
       continue;
     } catch (const nlohmann::json::out_of_range &exception) {
-      spdlog::warn(get_failed_message(stream_name, "non existing field", exception.what()));
+      spdlog::warn(get_failed_message("non existing field", exception.what()));
       continue;
     }
     MG_ASSERT(status.name == stream_name, "Expected stream name is '{}', but got '{}'", status.name, stream_name);

     try {
       auto it = CreateConsumer(*locked_streams_map, stream_name, std::move(status.info));
       if (status.is_running) {
         it->second.consumer->Lock()->Start();
       }
+
spdlog::info("Stream '{}' is loaded", stream_name); } catch (const utils::BasicException &exception) { - spdlog::warn(get_failed_message(stream_name, "unexpected error", exception.what())); + spdlog::warn(get_failed_message("unexpected error", exception.what())); } } } @@ -200,27 +273,38 @@ std::vector Streams::GetStreamInfo() const { } TransformationResult Streams::Test(const std::string &stream_name, std::optional batch_limit) const { - TransformationResult result; - auto consumer_function = [&result](const std::vector &messages) { - for (const auto &message : messages) { - // TODO(antaljanosbenjamin) Update the logic with using the transform from modules - const auto payload = message.Payload(); - const std::string_view payload_as_string_view{payload.data(), payload.size()}; - spdlog::info("CREATE (n:MESSAGE {{payload: '{}'}})", payload_as_string_view); - result[fmt::format("CREATE (n:MESSAGE {{payload: '{}'}})", payload_as_string_view)] = "replace with params"; + // This depends on the fact that Drop will first acquire a write lock to the consumer, and erase it only after that + auto [locked_consumer, + transformation_name] = [this, &stream_name]() -> std::pair { + auto locked_streams = streams_.ReadLock(); + auto it = GetStream(*locked_streams, stream_name); + return {it->second.consumer->ReadLock(), it->second.transformation_name}; + }(); + + auto *memory_resource = utils::NewDeleteResource(); + mgp_result result{nullptr, memory_resource}; + TransformationResult test_result; + + auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, &stream_name, + &transformation_name = transformation_name, &result, + &test_result](const std::vector &messages) mutable { + auto accessor = interpreter_context->db->Access(); + CallCustomTransformation(transformation_name, messages, result, accessor, *memory_resource, stream_name); + + for (auto &row : result.rows) { + auto [query, parameters] = ExtractTransformationResult(std::move(row.values), transformation_name, stream_name); + std::vector result_row; + result_row.reserve(kExpectedTransformationResultSize); + result_row.push_back(std::move(query)); + result_row.push_back(std::move(parameters)); + + test_result.push_back(std::move(result_row)); } }; - // This depends on the fact that Drop will first acquire a write lock to the consumer, and erase it only after that - auto locked_consumer = [this, &stream_name] { - auto locked_streams = streams_.ReadLock(); - auto it = GetStream(*locked_streams, stream_name); - return it->second.consumer->ReadLock(); - }(); - locked_consumer->Test(batch_limit, consumer_function); - return result; + return test_result; } StreamStatus Streams::CreateStatus(const std::string &name, const std::string &transformation_name, @@ -243,24 +327,37 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std throw StreamsException{"Stream already exists with name '{}'", stream_name}; } - auto consumer_function = [interpreter_context = - interpreter_context_](const std::vector &messages) { - Interpreter interpreter = Interpreter{interpreter_context}; - TransformationResult result; + auto *memory_resource = utils::NewDeleteResource(); - for (const auto &message : messages) { - // TODO(antaljanosbenjamin) Update the logic with using the transform from modules - const auto payload = message.Payload(); - const std::string_view payload_as_string_view{payload.data(), payload.size()}; - result[fmt::format("CREATE (n:MESSAGE {{payload: '{}'}})", payload_as_string_view)] = "replace 
with params"; + auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, stream_name, + transformation_name = stream_info.transformation_name, + interpreter = std::make_shared(interpreter_context_), + result = mgp_result{nullptr, memory_resource}]( + const std::vector &messages) mutable { + auto accessor = interpreter_context->db->Access(); + CallCustomTransformation(transformation_name, messages, result, accessor, *memory_resource, stream_name); + + DiscardValueResultStream stream; + + spdlog::trace("Start transaction in stream '{}'", stream_name); + interpreter->BeginTransaction(); + + for (auto &row : result.rows) { + spdlog::trace("Processing row in stream '{}'", stream_name); + auto [query_value, params_value] = + ExtractTransformationResult(std::move(row.values), transformation_name, stream_name); + storage::PropertyValue params_prop{params_value}; + + std::string query{query_value.ValueString()}; + spdlog::trace("Executing query '{}' in stream '{}'", query, stream_name); + auto prepare_result = + interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap()); + interpreter->PullAll(&stream); } - for (const auto &[query, params] : result) { - // auto prepared_query = interpreter.Prepare(query, {}); - spdlog::info("Executing query '{}'", query); - // TODO(antaljanosbenjamin) run the query in real life, try not to copy paste the whole execution code, but - // extract it to a function that can be called from multiple places (e.g: triggers) - } + spdlog::trace("Commit transaction in stream '{}'", stream_name); + interpreter->CommitTransaction(); + result.rows.clear(); }; ConsumerInfo consumer_info{ diff --git a/src/query/streams.hpp b/src/query/streams.hpp index 42117b015..68f559b57 100644 --- a/src/query/streams.hpp +++ b/src/query/streams.hpp @@ -8,6 +8,7 @@ #include "integrations/kafka/consumer.hpp" #include "kvstore/kvstore.hpp" +#include "query/typed_value.hpp" #include "utils/exceptions.hpp" #include "utils/rw_lock.hpp" #include "utils/synchronized.hpp" @@ -19,8 +20,7 @@ class StreamsException : public utils::BasicException { using BasicException::BasicException; }; -// TODO(antaljanosbenjamin) Replace this with mgp_trans related thing -using TransformationResult = std::map; +using TransformationResult = std::vector>; using TransformFunction = std::function &)>; struct StreamInfo { @@ -28,7 +28,6 @@ struct StreamInfo { std::string consumer_group; std::optional batch_interval; std::optional batch_size; - // TODO(antaljanosbenjamin) How to reference the transformation in a better way? std::string transformation_name; }; @@ -41,9 +40,7 @@ struct StreamStatus { using SynchronizedConsumer = utils::Synchronized; struct StreamData { - // TODO(antaljanosbenjamin) How to reference the transformation in a better way? std::string transformation_name; - // TODO(antaljanosbenjamin) consider propagate_const std::unique_ptr consumer; }; @@ -120,8 +117,8 @@ class Streams final { /// @param stream_name name of the stream we want to test /// @param batch_limit number of batches we want to test before stopping /// - /// TODO(antaljanosbenjamin) add type of parameters - /// @returns A vector of pairs consisting of the query (std::string) and its parameters ... + /// @returns A vector of vectors of TypedValue. Each subvector contains two elements, the query string and the + /// nullable parameters map. 
   ///
   /// @throws StreamsException if the stream doesn't exist
-  /// @throws ConsumerRunningException if the consumer is alredy running
+  /// @throws ConsumerRunningException if the consumer is already running
diff --git a/src/query/trigger.cpp b/src/query/trigger.cpp
index bf58e0ea1..1c5bd39ff 100644
--- a/src/query/trigger.cpp
+++ b/src/query/trigger.cpp
@@ -244,12 +244,11 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache
   spdlog::info("Loading triggers...");

   for (const auto &[trigger_name, trigger_data] : storage_) {
-    // structured binding cannot be captured in lambda
-    const auto get_failed_message = [](const std::string_view trigger_name, const std::string_view message) {
+    const auto get_failed_message = [&trigger_name = trigger_name](const std::string_view message) {
       return fmt::format("Failed to load trigger '{}'. {}", trigger_name, message);
     };

-    const auto invalid_state_message = get_failed_message(trigger_name, "Invalid state of the trigger data.");
+    const auto invalid_state_message = get_failed_message("Invalid state of the trigger data.");

     spdlog::debug("Loading trigger '{}'", trigger_name);
     auto json_trigger_data = nlohmann::json::parse(trigger_data);
@@ -259,7 +258,7 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache
       continue;
     }
     if (json_trigger_data["version"] != kVersion) {
-      spdlog::warn(get_failed_message(trigger_name, "Invalid version of the trigger data."));
+      spdlog::warn(get_failed_message("Invalid version of the trigger data."));
       continue;
     }
diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp
index 18b5465a9..b1b5cec1f 100644
--- a/tests/unit/cypher_main_visitor.cpp
+++ b/tests/unit/cypher_main_visitor.cpp
@@ -3207,8 +3207,16 @@ TEST_P(CypherMainVisitorTest, CreateSnapshotQuery) {
   ASSERT_TRUE(dynamic_cast<CreateSnapshotQuery *>(ast_generator.ParseQuery("CREATE SNAPSHOT")));
 }

+void CheckOptionalExpression(Base &ast_generator, Expression *expression, const std::optional<TypedValue> &expected) {
+  EXPECT_EQ(expression != nullptr, expected.has_value());
+  if (expected.has_value()) {
+    EXPECT_NO_FATAL_FAILURE(ast_generator.CheckLiteral(expression, *expected));
+  }
+};
+
 void ValidateMostlyEmptyStreamQuery(Base &ast_generator, const std::string &query_string,
-                                    const StreamQuery::Action action, const std::string_view stream_name) {
+                                    const StreamQuery::Action action, const std::string_view stream_name,
+                                    const std::optional<TypedValue> &batch_limit = std::nullopt) {
   auto *parsed_query = dynamic_cast<StreamQuery *>(ast_generator.ParseQuery(query_string));
   ASSERT_NE(parsed_query, nullptr);
   EXPECT_EQ(parsed_query->action_, action);
@@ -3218,6 +3226,7 @@ void ValidateMostlyEmptyStreamQuery(Base &ast_generator, const std::string &quer
   EXPECT_TRUE(parsed_query->consumer_group_.empty());
   EXPECT_EQ(parsed_query->batch_interval_, nullptr);
   EXPECT_EQ(parsed_query->batch_size_, nullptr);
+  EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_limit_, batch_limit));
 }

 TEST_P(CypherMainVisitorTest, DropStream) {
@@ -3292,18 +3301,13 @@ void ValidateCreateStreamQuery(Base &ast_generator, const std::string &query_str
   ASSERT_NO_THROW(parsed_query = dynamic_cast<StreamQuery *>(ast_generator.ParseQuery(query_string))) << query_string;
   ASSERT_NE(parsed_query, nullptr);
   EXPECT_EQ(parsed_query->stream_name_, stream_name);
-  auto check_expression = [&](Expression *expression, const std::optional<TypedValue> &expected) {
-    EXPECT_EQ(expression != nullptr, expected.has_value());
-    if (expected.has_value()) {
-      EXPECT_NO_FATAL_FAILURE(ast_generator.CheckLiteral(expression, *expected));
-    }
-  };
   EXPECT_EQ(parsed_query->topic_names_, topic_names);
   EXPECT_EQ(parsed_query->transform_name_, transform_name);
   EXPECT_EQ(parsed_query->consumer_group_, consumer_group);
-  EXPECT_NO_FATAL_FAILURE(check_expression(parsed_query->batch_interval_, batch_interval));
-  EXPECT_NO_FATAL_FAILURE(check_expression(parsed_query->batch_size_, batch_size));
+  EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_interval_, batch_interval));
+  EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_size_, batch_size));
+  EXPECT_EQ(parsed_query->batch_limit_, nullptr);
 }

 TEST_P(CypherMainVisitorTest, CreateStream) {
@@ -3380,4 +3384,22 @@ TEST_P(CypherMainVisitorTest, CreateStream) {
   EXPECT_NO_FATAL_FAILURE(check_topic_names({topic_name1, topic_name2}));
 }

+TEST_P(CypherMainVisitorTest, CheckStream) {
+  auto &ast_generator = *GetParam();
+
+  TestInvalidQuery("CHECK STREAM", ast_generator);
+  TestInvalidQuery("CHECK STREAMS", ast_generator);
+  TestInvalidQuery("CHECK STREAMS something", ast_generator);
+  TestInvalidQuery("CHECK STREAM something,something", ast_generator);
+  TestInvalidQuery("CHECK STREAM something BATCH LIMIT 1", ast_generator);
+  TestInvalidQuery("CHECK STREAM something BATCH_LIMIT", ast_generator);
+  TestInvalidQuery("CHECK STREAM something BATCH_LIMIT 'it should be an integer'", ast_generator);
+  TestInvalidQuery("CHECK STREAM something BATCH_LIMIT 2.5", ast_generator);
+
+  ValidateMostlyEmptyStreamQuery(ast_generator, "CHECK STREAM checkedStream", StreamQuery::Action::CHECK_STREAM,
+                                 "checkedStream");
+  ValidateMostlyEmptyStreamQuery(ast_generator, "CHECK STREAM checkedStream bAtCH_LIMIT 42",
+                                 StreamQuery::Action::CHECK_STREAM, "checkedStream", TypedValue(42));
+}
+
 } // namespace
diff --git a/tests/unit/query_streams.cpp b/tests/unit/query_streams.cpp
index e5cc20ad2..3d254c836 100644
--- a/tests/unit/query_streams.cpp
+++ b/tests/unit/query_streams.cpp
@@ -32,8 +32,7 @@ StreamInfo CreateDefaultStreamInfo() {
       .consumer_group = "ConsumerGroup " + GetDefaultStreamName(),
       .batch_interval = std::nullopt,
       .batch_size = std::nullopt,
-      // TODO(antaljanosbenjamin) Add proper reference once Streams supports that
-      .transformation_name = "not yet used",
+      .transformation_name = "not used in the tests",
   };
 }

@@ -79,8 +78,7 @@ class StreamsTest : public ::testing::Test {
     EXPECT_EQ(check_data.info.consumer_group, status.info.consumer_group);
     EXPECT_EQ(check_data.info.batch_interval, status.info.batch_interval);
     EXPECT_EQ(check_data.info.batch_size, status.info.batch_size);
-    // TODO(antaljanosbenjamin) Add proper reference once Streams supports that
-    // EXPECT_EQ(check_data.info.transformation_name, status.info.transformation_name);
+    EXPECT_EQ(check_data.info.transformation_name, status.info.transformation_name);
     EXPECT_EQ(check_data.is_running, status.is_running);
   }

@@ -227,5 +225,3 @@ TEST_F(StreamsTest, RestoreStreams) {
     check_restore_logic();
   }
 }
-
-// TODO(antaljanosbenjamin) Add tests for Streams::Test method and transformation
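Taken together, the changes wire a user-defined transformation into both the live consumer loop and the new dry-run path. A sketch of the workflow from a client's point of view; the stream and module names are hypothetical, Memgraph is assumed to be listening on the default Bolt port, and the neo4j driver is just one possible Bolt client:

    from neo4j import GraphDatabase

    driver = GraphDatabase.driver("bolt://localhost:7687", auth=("", ""))
    with driver.session() as session:
        # Wire a Kafka topic to a transformation from a loaded module.
        session.run("CREATE STREAM checkedStream "
                    "TOPICS some_topic "
                    "TRANSFORM my_module.show_batch "
                    "BATCH_SIZE 10")
        # Dry run: consume up to two batches and show the (query, parameters)
        # rows the transformation would produce, without executing them.
        for record in session.run("CHECK STREAM checkedStream BATCH_LIMIT 2"):
            print(record["query"], record["parameters"])
        # Start executing the produced queries for real.
        session.run("START STREAM checkedStream")
    driver.close()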