Use transformations in streams and CHECK STREAM (#185)
* Use the correct transformation result type * Execute the result queries in streams * Change the result type of parameters to nullable map * Serialize transformation name * Fix order of transformation parameters * Use actual transformation in Streams * Clear the Python transformations under GIL * Add CHECK STREAM query * Handle missing record fields properly
This commit is contained in:
parent
a37755ce43
commit
ad32db5168
@ -190,7 +190,8 @@ class Edge:
|
||||
|
||||
def __init__(self, edge):
|
||||
if not isinstance(edge, _mgp.Edge):
|
||||
raise TypeError("Expected '_mgp.Edge', got '{}'".format(type(edge)))
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Edge', got '{}'".format(type(edge)))
|
||||
self._edge = edge
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
@ -268,7 +269,8 @@ class Vertex:
|
||||
|
||||
def __init__(self, vertex):
|
||||
if not isinstance(vertex, _mgp.Vertex):
|
||||
raise TypeError("Expected '_mgp.Vertex', got '{}'".format(type(vertex)))
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Vertex', got '{}'".format(type(vertex)))
|
||||
self._vertex = vertex
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
@ -404,7 +406,8 @@ class Path:
|
||||
passed in edge is invalid.
|
||||
'''
|
||||
if not isinstance(edge, Edge):
|
||||
raise TypeError("Expected '_mgp.Edge', got '{}'".format(type(edge)))
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Edge', got '{}'".format(type(edge)))
|
||||
if not self.is_valid() or not edge.is_valid():
|
||||
raise InvalidContextError()
|
||||
self._path.expand(edge._edge)
|
||||
@ -454,7 +457,8 @@ class Vertices:
|
||||
|
||||
def __init__(self, graph):
|
||||
if not isinstance(graph, _mgp.Graph):
|
||||
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
self._graph = graph
|
||||
self._len = None
|
||||
|
||||
@ -499,7 +503,8 @@ class Graph:
|
||||
|
||||
def __init__(self, graph):
|
||||
if not isinstance(graph, _mgp.Graph):
|
||||
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
self._graph = graph
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
@ -557,7 +562,8 @@ class ProcCtx:
|
||||
|
||||
def __init__(self, graph):
|
||||
if not isinstance(graph, _mgp.Graph):
|
||||
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
self._graph = Graph(graph)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
@ -676,11 +682,13 @@ def _typing_to_cypher_type(type_):
|
||||
type_args_as_str = parse_type_args(type_as_str)
|
||||
none_type_as_str = type(None).__name__
|
||||
if none_type_as_str in type_args_as_str:
|
||||
types = tuple(t for t in type_args_as_str if t != none_type_as_str)
|
||||
types = tuple(
|
||||
t for t in type_args_as_str if t != none_type_as_str)
|
||||
if len(types) == 1:
|
||||
type_arg_as_str, = types
|
||||
else:
|
||||
type_arg_as_str = 'typing.Union[' + ', '.join(types) + ']'
|
||||
type_arg_as_str = 'typing.Union[' + \
|
||||
', '.join(types) + ']'
|
||||
simple_type = get_simple_type(type_arg_as_str)
|
||||
if simple_type is not None:
|
||||
return _mgp.type_nullable(simple_type)
|
||||
@ -726,6 +734,7 @@ def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]):
|
||||
if inspect.isgeneratorfunction(func):
|
||||
raise NotImplementedError("Generator functions are not supported")
|
||||
|
||||
|
||||
def read_proc(func: typing.Callable[..., Record]):
|
||||
'''
|
||||
Register `func` as a a read-only procedure of the current module.
|
||||
@ -803,17 +812,20 @@ def read_proc(func: typing.Callable[..., Record]):
|
||||
mgp_proc.add_result(name, _typing_to_cypher_type(type_))
|
||||
return func
|
||||
|
||||
|
||||
class InvalidMessageError(Exception):
|
||||
'''Signals using a message instance outside of the registered transformation.'''
|
||||
pass
|
||||
|
||||
|
||||
class Message:
|
||||
'''Represents a message from a stream.'''
|
||||
__slots__ = ('_message',)
|
||||
|
||||
def __init__(self, message):
|
||||
if not isinstance(message, _mgp.Message):
|
||||
raise TypeError("Expected '_mgp.Message', got '{}'".format(type(message)))
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Message', got '{}'".format(type(message)))
|
||||
self._message = message
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
@ -829,34 +841,37 @@ class Message:
|
||||
def payload(self) -> bytes:
|
||||
if not self.is_valid():
|
||||
raise InvalidMessageError()
|
||||
return self._messages._payload(_message)
|
||||
return self._message.payload()
|
||||
|
||||
def topic_name(self) -> str:
|
||||
if not self.is_valid():
|
||||
raise InvalidMessageError()
|
||||
return self._messages._topic_name(_message)
|
||||
return self._message.topic_name()
|
||||
|
||||
def key() -> bytes:
|
||||
def key(self) -> bytes:
|
||||
if not self.is_valid():
|
||||
raise InvalidMessageError()
|
||||
return self._messages.key(_message)
|
||||
return self._message.key()
|
||||
|
||||
def timestamp() -> int:
|
||||
def timestamp(self) -> int:
|
||||
if not self.is_valid():
|
||||
raise InvalidMessageError()
|
||||
return self._messages.timestamp(_message)
|
||||
return self._message.timestamp()
|
||||
|
||||
|
||||
class InvalidMessagesError(Exception):
|
||||
'''Signals using a messages instance outside of the registered transformation.'''
|
||||
pass
|
||||
|
||||
|
||||
class Messages:
|
||||
'''Represents a list of messages from a stream.'''
|
||||
__slots__ = ('_messages',)
|
||||
|
||||
def __init__(self, messages):
|
||||
if not isinstance(messages, _mgp.Messages):
|
||||
raise TypeError("Expected '_mgp.Messages', got '{}'".format(type(messages)))
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Messages', got '{}'".format(type(messages)))
|
||||
self._messages = messages
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
@ -869,18 +884,19 @@ class Messages:
|
||||
'''Return True if `self` is in valid context and may be used.'''
|
||||
return self._messages.is_valid()
|
||||
|
||||
def message_at(self, id : int) -> Message:
|
||||
def message_at(self, id: int) -> Message:
|
||||
'''Raise InvalidMessagesError if context is invalid.'''
|
||||
if not self.is_valid():
|
||||
raise InvalidMessagesError()
|
||||
return Message(self._messages.message_at(id))
|
||||
|
||||
def total_messages() -> int:
|
||||
def total_messages(self) -> int:
|
||||
'''Raise InvalidContextError if context is invalid.'''
|
||||
if not self.is_valid():
|
||||
raise InvalidMessagesError()
|
||||
return self._messages.total_messages()
|
||||
|
||||
|
||||
class TransCtx:
|
||||
'''Context of a transformation being executed.
|
||||
|
||||
@ -891,7 +907,8 @@ class TransCtx:
|
||||
|
||||
def __init__(self, graph):
|
||||
if not isinstance(graph, _mgp.Graph):
|
||||
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
raise TypeError(
|
||||
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
self._graph = Graph(graph)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
@ -904,13 +921,15 @@ class TransCtx:
|
||||
raise InvalidContextError()
|
||||
return self._graph
|
||||
|
||||
|
||||
def transformation(func: typing.Callable[..., Record]):
|
||||
raise_if_does_not_meet_requirements(func)
|
||||
sig = inspect.signature(func)
|
||||
params = tuple(sig.parameters.values())
|
||||
if not params or not params[0].annotation is Messages:
|
||||
if not len(params) == 2 or not params[1].annotation is Messages:
|
||||
raise NotImplementedError("Valid signatures for transformations are (TransCtx, Messages) or (Messages)")
|
||||
raise NotImplementedError(
|
||||
"Valid signatures for transformations are (TransCtx, Messages) or (Messages)")
|
||||
if params[0].annotation is TransCtx:
|
||||
@functools.wraps(func)
|
||||
def wrapper(graph, messages):
|
||||
|
@ -44,6 +44,7 @@ utils::BasicResult<std::string, std::vector<Message>> GetBatch(RdKafka::KafkaCon
|
||||
break;
|
||||
|
||||
default:
|
||||
// TODO(antaljanosbenjamin): handle RD_KAFKA_RESP_ERR__MAX_POLL_EXCEEDED
|
||||
auto error = msg->errstr();
|
||||
spdlog::warn("Unexpected error while consuming message in consumer {}, error: {}!", info.consumer_name,
|
||||
msg->errstr());
|
||||
@ -307,6 +308,7 @@ void Consumer::StartConsuming() {
|
||||
spdlog::warn("Error happened in consumer {} while processing a batch: {}!", info_.consumer_name, e.what());
|
||||
break;
|
||||
}
|
||||
spdlog::info("Kafka consumer {} finished processing", info_.consumer_name);
|
||||
}
|
||||
is_running_.store(false);
|
||||
});
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include "communication/bolt/v1/constants.hpp"
|
||||
#include "helpers.hpp"
|
||||
#include "py/py.hpp"
|
||||
#include "query/discard_value_stream.hpp"
|
||||
#include "query/exceptions.hpp"
|
||||
#include "query/interpreter.hpp"
|
||||
#include "query/plan/operator.hpp"
|
||||
@ -433,7 +434,7 @@ class BoltSession final : public communication::bolt::Session<communication::Inp
|
||||
}
|
||||
|
||||
std::map<std::string, communication::bolt::Value> Discard(std::optional<int> n, std::optional<int> qid) override {
|
||||
DiscardValueResultStream stream;
|
||||
query::DiscardValueResultStream stream;
|
||||
return PullResults(stream, n, qid);
|
||||
}
|
||||
|
||||
@ -517,12 +518,6 @@ class BoltSession final : public communication::bolt::Session<communication::Inp
|
||||
const storage::Storage *db_;
|
||||
};
|
||||
|
||||
struct DiscardValueResultStream {
|
||||
void Result(const std::vector<query::TypedValue> &) {
|
||||
// do nothing
|
||||
}
|
||||
};
|
||||
|
||||
// NOTE: Needed only for ToBoltValue conversions
|
||||
const storage::Storage *db_;
|
||||
query::Interpreter interpreter_;
|
||||
|
@ -29,8 +29,8 @@ class EnsureGIL final {
|
||||
PyGILState_STATE gil_state_;
|
||||
|
||||
public:
|
||||
EnsureGIL() : gil_state_(PyGILState_Ensure()) {}
|
||||
~EnsureGIL() { PyGILState_Release(gil_state_); }
|
||||
EnsureGIL() noexcept : gil_state_(PyGILState_Ensure()) {}
|
||||
~EnsureGIL() noexcept { PyGILState_Release(gil_state_); }
|
||||
EnsureGIL(const EnsureGIL &) = delete;
|
||||
EnsureGIL(EnsureGIL &&) = delete;
|
||||
EnsureGIL &operator=(const EnsureGIL &) = delete;
|
||||
|
13
src/query/discard_value_stream.hpp
Normal file
13
src/query/discard_value_stream.hpp
Normal file
@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "query/typed_value.hpp"
|
||||
|
||||
namespace query {
|
||||
struct DiscardValueResultStream final {
|
||||
void Result(const std::vector<query::TypedValue> & /*values*/) {
|
||||
// do nothing
|
||||
}
|
||||
};
|
||||
} // namespace query
|
@ -2476,12 +2476,15 @@ cpp<#
|
||||
:slk-save #'slk-save-ast-pointer
|
||||
:slk-load (slk-load-ast-pointer "Expression"))
|
||||
(batch_size "Expression *" :initval "nullptr" :scope :public
|
||||
:slk-save #'slk-save-ast-pointer
|
||||
:slk-load (slk-load-ast-pointer "Expression"))
|
||||
(batch_limit "Expression *" :initval "nullptr" :scope :public
|
||||
:slk-save #'slk-save-ast-pointer
|
||||
:slk-load (slk-load-ast-pointer "Expression")))
|
||||
|
||||
(:public
|
||||
(lcp:define-enum action
|
||||
(create-stream drop-stream start-stream stop-stream start-all-streams stop-all-streams show-streams test-stream)
|
||||
(create-stream drop-stream start-stream stop-stream start-all-streams stop-all-streams show-streams check-stream)
|
||||
(:serialize))
|
||||
#>cpp
|
||||
StreamQuery() = default;
|
||||
|
@ -472,6 +472,7 @@ antlrcpp::Any CypherMainVisitor::visitCreateStream(MemgraphCypher::CreateStreamC
|
||||
|
||||
auto *topic_names_ctx = ctx->topicNames();
|
||||
MG_ASSERT(topic_names_ctx != nullptr);
|
||||
// TODO(antaljanosbenjamin): Add dash
|
||||
auto topic_names = topic_names_ctx->symbolicNameWithDots();
|
||||
MG_ASSERT(!topic_names.empty());
|
||||
stream_query->topic_names_.reserve(topic_names.size());
|
||||
@ -540,6 +541,20 @@ antlrcpp::Any CypherMainVisitor::visitShowStreams(MemgraphCypher::ShowStreamsCon
|
||||
return stream_query;
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitCheckStream(MemgraphCypher::CheckStreamContext *ctx) {
|
||||
auto *stream_query = storage_->Create<StreamQuery>();
|
||||
stream_query->action_ = StreamQuery::Action::CHECK_STREAM;
|
||||
stream_query->stream_name_ = ctx->streamName()->symbolicName()->accept(this).as<std::string>();
|
||||
|
||||
if (ctx->BATCH_LIMIT()) {
|
||||
if (!ctx->batchLimit->numberLiteral() || !ctx->batchLimit->numberLiteral()->integerLiteral()) {
|
||||
throw SemanticException("Batch limit should be an integer literal!");
|
||||
}
|
||||
stream_query->batch_limit_ = ctx->batchLimit->accept(this);
|
||||
}
|
||||
return stream_query;
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitCypherUnion(MemgraphCypher::CypherUnionContext *ctx) {
|
||||
bool distinct = !ctx->ALL();
|
||||
auto *cypher_union = storage_->Create<CypherUnion>(distinct);
|
||||
|
@ -288,6 +288,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
*/
|
||||
antlrcpp::Any visitShowStreams(MemgraphCypher::ShowStreamsContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return StreamQuery*
|
||||
*/
|
||||
antlrcpp::Any visitCheckStream(MemgraphCypher::CheckStreamContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return CypherUnion*
|
||||
*/
|
||||
|
@ -12,10 +12,11 @@ memgraphCypherKeyword : cypherKeyword
|
||||
| ASYNC
|
||||
| AUTH
|
||||
| BAD
|
||||
| BATCHES
|
||||
| BATCH_INTERVAL
|
||||
| BATCH_LIMIT
|
||||
| BATCH_SIZE
|
||||
| BEFORE
|
||||
| CHECK
|
||||
| CLEAR
|
||||
| COMMIT
|
||||
| COMMITTED
|
||||
@ -141,7 +142,8 @@ clause : cypherMatch
|
||||
| loadCsv
|
||||
;
|
||||
|
||||
streamQuery : createStream
|
||||
streamQuery : checkStream
|
||||
| createStream
|
||||
| dropStream
|
||||
| startStream
|
||||
| startAllStreams
|
||||
@ -289,3 +291,5 @@ stopStream : STOP STREAM streamName ;
|
||||
stopAllStreams : STOP ALL STREAMS ;
|
||||
|
||||
showStreams : SHOW STREAMS ;
|
||||
|
||||
checkStream : CHECK STREAM streamName ( BATCH_LIMIT batchLimit=literal ) ? ;
|
||||
|
@ -17,10 +17,11 @@ ALTER : A L T E R ;
|
||||
ASYNC : A S Y N C ;
|
||||
AUTH : A U T H ;
|
||||
BAD : B A D ;
|
||||
BATCHES : B A T C H E S ;
|
||||
BATCH_INTERVAL : B A T C H UNDERSCORE I N T E R V A L ;
|
||||
BATCH_LIMIT : B A T C H UNDERSCORE L I M I T ;
|
||||
BATCH_SIZE : B A T C H UNDERSCORE S I Z E ;
|
||||
BEFORE : B E F O R E ;
|
||||
CHECK : C H E C K ;
|
||||
CLEAR : C L E A R ;
|
||||
COMMIT : C O M M I T ;
|
||||
COMMITTED : C O M M I T T E D ;
|
||||
|
@ -127,11 +127,11 @@ const trie::Trie kKeywords = {"union", "all",
|
||||
"level", "next",
|
||||
"read", "session",
|
||||
"snapshot", "transaction",
|
||||
"batches", "batch_interval",
|
||||
"batch_limit", "batch_interval",
|
||||
"batch_size", "consumer_group",
|
||||
"start", "stream",
|
||||
"streams", "transform",
|
||||
"topics"};
|
||||
"topics", "check"};
|
||||
|
||||
// Unicode codepoints that are allowed at the start of the unescaped name.
|
||||
const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts(
|
||||
|
@ -490,7 +490,7 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete
|
||||
.consumer_group = std::move(consumer_group),
|
||||
.batch_interval = batch_interval,
|
||||
.batch_size = batch_size,
|
||||
.transformation_name = "transform.trans"});
|
||||
.transformation_name = transformation_name});
|
||||
return std::vector<std::vector<TypedValue>>{};
|
||||
};
|
||||
return callback;
|
||||
@ -575,8 +575,14 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete
|
||||
return results;
|
||||
};
|
||||
return callback;
|
||||
case StreamQuery::Action::TEST_STREAM:
|
||||
throw std::logic_error("not implemented");
|
||||
}
|
||||
case StreamQuery::Action::CHECK_STREAM: {
|
||||
callback.header = {"query", "parameters"};
|
||||
callback.fn = [interpreter_context, stream_name = stream_query->stream_name_,
|
||||
batch_limit = GetOptionalValue<int64_t>(stream_query->batch_limit_, evaluator)]() mutable {
|
||||
return interpreter_context->streams.Test(stream_name, batch_limit);
|
||||
};
|
||||
return callback;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3761,7 +3761,7 @@ class CallProcedureCursor : public Cursor {
|
||||
std::string_view field_name(self_->result_fields_[i]);
|
||||
auto result_it = values.find(field_name);
|
||||
if (result_it == values.end()) {
|
||||
throw QueryRuntimeException("Procedure '{}' does not yield a record with '{}' field.", self_->procedure_name_,
|
||||
throw QueryRuntimeException("Procedure '{}' did not yield a record with '{}' field.", self_->procedure_name_,
|
||||
field_name);
|
||||
}
|
||||
frame[self_->result_symbols_[i]] = result_it->second;
|
||||
|
@ -1372,7 +1372,7 @@ bool MgpTransAddFixedResult(mgp_trans *trans) {
|
||||
if (int err = AddResultToProp(trans, "query", mgp_type_string(), false); err != 1) {
|
||||
return err;
|
||||
}
|
||||
return AddResultToProp(trans, "parameters", mgp_type_nullable(mgp_type_list(mgp_type_any())), false);
|
||||
return AddResultToProp(trans, "parameters", mgp_type_nullable(mgp_type_map()), false);
|
||||
}
|
||||
|
||||
int mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name, const mgp_type *type) {
|
||||
|
@ -549,7 +549,7 @@ bool IsValidIdentifierName(const char *name);
|
||||
} // namespace query::procedure
|
||||
|
||||
struct mgp_message {
|
||||
integrations::kafka::Message *msg;
|
||||
const integrations::kafka::Message *msg;
|
||||
};
|
||||
|
||||
struct mgp_messages {
|
||||
|
@ -418,11 +418,11 @@ bool PythonModule::Load(const std::filesystem::path &file_path) {
|
||||
bool PythonModule::Close() {
|
||||
MG_ASSERT(py_module_, "Attempting to close a module that has not been loaded...");
|
||||
spdlog::info("Closing module {}...", file_path_);
|
||||
// The procedures are closures which hold references to the Python callbacks.
|
||||
// Releasing these references might result in deallocations so we need to take
|
||||
// the GIL.
|
||||
// The procedures and transformations are closures which hold references to the Python callbacks.
|
||||
// Releasing these references might result in deallocations so we need to take the GIL.
|
||||
auto gil = py::EnsureGIL();
|
||||
procedures_.clear();
|
||||
transformations_.clear();
|
||||
// Delete the module from the `sys.modules` directory so that the module will
|
||||
// be properly imported if imported again.
|
||||
py::Object sys(PyImport_ImportModule("sys"));
|
||||
|
@ -782,7 +782,7 @@ void CallPythonTransformation(const py::Object &py_cb, const mgp_messages *msgs,
|
||||
};
|
||||
|
||||
auto call = [&](py::Object py_graph, py::Object py_messages) -> std::optional<py::ExceptionInfo> {
|
||||
auto py_res = py_cb.Call(py_messages, py_graph);
|
||||
auto py_res = py_cb.Call(py_graph, py_messages);
|
||||
if (!py_res) return py::FetchError();
|
||||
if (PySequence_Check(py_res.Ptr())) {
|
||||
return AddMultipleRecordsFromPython(result, py_res);
|
||||
|
@ -6,23 +6,92 @@
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <json/json.hpp>
|
||||
#include "query/db_accessor.hpp"
|
||||
#include "query/discard_value_stream.hpp"
|
||||
#include "query/interpreter.hpp"
|
||||
#include "query/procedure/mg_procedure_impl.hpp"
|
||||
#include "query/procedure/module.hpp"
|
||||
#include "query/typed_value.hpp"
|
||||
#include "utils/memory.hpp"
|
||||
#include "utils/on_scope_exit.hpp"
|
||||
#include "utils/pmr/string.hpp"
|
||||
|
||||
namespace query {
|
||||
|
||||
using Consumer = integrations::kafka::Consumer;
|
||||
using ConsumerInfo = integrations::kafka::ConsumerInfo;
|
||||
using Message = integrations::kafka::Message;
|
||||
namespace {
|
||||
constexpr auto kExpectedTransformationResultSize = 2;
|
||||
const utils::pmr::string query_param_name{"query", utils::NewDeleteResource()};
|
||||
const utils::pmr::string params_param_name{"parameters", utils::NewDeleteResource()};
|
||||
const std::map<std::string, storage::PropertyValue> empty_parameters{};
|
||||
|
||||
auto GetStream(auto &map, const std::string &stream_name) {
|
||||
if (auto it = map.find(stream_name); it != map.end()) {
|
||||
return it;
|
||||
}
|
||||
throw StreamsException("Couldn't find stream '{}'", stream_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
using Consumer = integrations::kafka::Consumer;
|
||||
using ConsumerInfo = integrations::kafka::ConsumerInfo;
|
||||
using Message = integrations::kafka::Message;
|
||||
void CallCustomTransformation(const std::string &transformation_name, const std::vector<Message> &messages,
|
||||
mgp_result &result, storage::Storage::Accessor &storage_accessor,
|
||||
utils::MemoryResource &memory_resource, const std::string &stream_name) {
|
||||
DbAccessor db_accessor{&storage_accessor};
|
||||
{
|
||||
auto maybe_transformation =
|
||||
procedure::FindTransformation(procedure::gModuleRegistry, transformation_name, utils::NewDeleteResource());
|
||||
|
||||
if (!maybe_transformation) {
|
||||
throw StreamsException("Couldn't find transformation {} for stream '{}'", transformation_name, stream_name);
|
||||
};
|
||||
const auto &trans = *maybe_transformation->second;
|
||||
mgp_messages mgp_messages{mgp_messages::storage_type{&memory_resource}};
|
||||
std::transform(messages.begin(), messages.end(), std::back_inserter(mgp_messages.messages),
|
||||
[](const integrations::kafka::Message &message) { return mgp_message{&message}; });
|
||||
mgp_graph graph{&db_accessor, storage::View::OLD, nullptr};
|
||||
mgp_memory memory{&memory_resource};
|
||||
result.rows.clear();
|
||||
result.error_msg.reset();
|
||||
result.signature = &trans.results;
|
||||
|
||||
MG_ASSERT(result.signature->size() == kExpectedTransformationResultSize);
|
||||
MG_ASSERT(result.signature->contains(query_param_name));
|
||||
MG_ASSERT(result.signature->contains(params_param_name));
|
||||
|
||||
spdlog::trace("Calling transformation in stream '{}'", stream_name);
|
||||
trans.cb(&mgp_messages, &graph, &result, &memory);
|
||||
}
|
||||
if (result.error_msg.has_value()) {
|
||||
throw StreamsException(result.error_msg->c_str());
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<TypedValue /*query*/, TypedValue /*parameters*/> ExtractTransformationResult(
|
||||
utils::pmr::map<utils::pmr::string, TypedValue> &&values, const std::string_view transformation_name,
|
||||
const std::string_view stream_name) {
|
||||
if (values.size() != kExpectedTransformationResultSize) {
|
||||
throw StreamsException(
|
||||
"Transformation '{}' in stream '{}' did not yield all fields (query, parameters) as required.",
|
||||
transformation_name, stream_name);
|
||||
}
|
||||
|
||||
auto get_value = [&](const utils::pmr::string &field_name) mutable -> TypedValue & {
|
||||
auto it = values.find(field_name);
|
||||
if (it == values.end()) {
|
||||
throw StreamsException{"Transformation '{}' in stream '{}' did not yield a record with '{}' field.",
|
||||
transformation_name, stream_name, field_name};
|
||||
};
|
||||
return it->second;
|
||||
};
|
||||
|
||||
auto &query_value = get_value(query_param_name);
|
||||
MG_ASSERT(query_value.IsString());
|
||||
auto ¶ms_value = get_value(params_param_name);
|
||||
MG_ASSERT(params_value.IsNull() || params_value.IsMap());
|
||||
return {std::move(query_value), std::move(params_value)};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// nlohmann::json doesn't support string_view access yet
|
||||
const std::string kStreamName{"name"};
|
||||
@ -31,6 +100,7 @@ const std::string kConsumerGroupKey{"consumer_group"};
|
||||
const std::string kBatchIntervalKey{"batch_interval"};
|
||||
const std::string kBatchSizeKey{"batch_size"};
|
||||
const std::string kIsRunningKey{"is_running"};
|
||||
const std::string kTransformationName{"transformation_name"};
|
||||
|
||||
void to_json(nlohmann::json &data, StreamStatus &&status) {
|
||||
auto &info = status.info;
|
||||
@ -51,6 +121,7 @@ void to_json(nlohmann::json &data, StreamStatus &&status) {
|
||||
}
|
||||
|
||||
data[kIsRunningKey] = status.is_running;
|
||||
data[kTransformationName] = status.info.transformation_name;
|
||||
}
|
||||
|
||||
void from_json(const nlohmann::json &data, StreamStatus &status) {
|
||||
@ -75,6 +146,7 @@ void from_json(const nlohmann::json &data, StreamStatus &status) {
|
||||
}
|
||||
|
||||
data.at(kIsRunningKey).get_to(status.is_running);
|
||||
data.at(kTransformationName).get_to(status.info.transformation_name);
|
||||
}
|
||||
|
||||
Streams::Streams(InterpreterContext *interpreter_context, std::string bootstrap_servers,
|
||||
@ -89,8 +161,8 @@ void Streams::RestoreStreams() {
|
||||
MG_ASSERT(locked_streams_map->empty(), "Cannot restore streams when some streams already exist!");
|
||||
|
||||
for (const auto &[stream_name, stream_data] : storage_) {
|
||||
const auto get_failed_message = [](const std::string_view stream_name, const std::string_view message,
|
||||
const std::string_view nested_message) {
|
||||
const auto get_failed_message = [&stream_name = stream_name](const std::string_view message,
|
||||
const std::string_view nested_message) {
|
||||
return fmt::format("Failed to load stream '{}', because: {} caused by {}", stream_name, message, nested_message);
|
||||
};
|
||||
|
||||
@ -98,21 +170,22 @@ void Streams::RestoreStreams() {
|
||||
try {
|
||||
nlohmann::json::parse(stream_data).get_to(status);
|
||||
} catch (const nlohmann::json::type_error &exception) {
|
||||
spdlog::warn(get_failed_message(stream_name, "invalid type conversion", exception.what()));
|
||||
spdlog::warn(get_failed_message("invalid type conversion", exception.what()));
|
||||
continue;
|
||||
} catch (const nlohmann::json::out_of_range &exception) {
|
||||
spdlog::warn(get_failed_message(stream_name, "non existing field", exception.what()));
|
||||
spdlog::warn(get_failed_message("non existing field", exception.what()));
|
||||
continue;
|
||||
}
|
||||
MG_ASSERT(status.name == stream_name, "Expected stream name is '{}', but got '{}'", stream_name, status.name);
|
||||
MG_ASSERT(status.name == stream_name, "Expected stream name is '{}', but got '{}'", status.name, stream_name);
|
||||
|
||||
try {
|
||||
auto it = CreateConsumer(*locked_streams_map, stream_name, std::move(status.info));
|
||||
if (status.is_running) {
|
||||
it->second.consumer->Lock()->Start();
|
||||
}
|
||||
spdlog::info("Stream '{}' is loaded", stream_name);
|
||||
} catch (const utils::BasicException &exception) {
|
||||
spdlog::warn(get_failed_message(stream_name, "unexpected error", exception.what()));
|
||||
spdlog::warn(get_failed_message("unexpected error", exception.what()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -200,27 +273,38 @@ std::vector<StreamStatus> Streams::GetStreamInfo() const {
|
||||
}
|
||||
|
||||
TransformationResult Streams::Test(const std::string &stream_name, std::optional<int64_t> batch_limit) const {
|
||||
TransformationResult result;
|
||||
auto consumer_function = [&result](const std::vector<Message> &messages) {
|
||||
for (const auto &message : messages) {
|
||||
// TODO(antaljanosbenjamin) Update the logic with using the transform from modules
|
||||
const auto payload = message.Payload();
|
||||
const std::string_view payload_as_string_view{payload.data(), payload.size()};
|
||||
spdlog::info("CREATE (n:MESSAGE {{payload: '{}'}})", payload_as_string_view);
|
||||
result[fmt::format("CREATE (n:MESSAGE {{payload: '{}'}})", payload_as_string_view)] = "replace with params";
|
||||
// This depends on the fact that Drop will first acquire a write lock to the consumer, and erase it only after that
|
||||
auto [locked_consumer,
|
||||
transformation_name] = [this, &stream_name]() -> std::pair<SynchronizedConsumer::ReadLockedPtr, std::string> {
|
||||
auto locked_streams = streams_.ReadLock();
|
||||
auto it = GetStream(*locked_streams, stream_name);
|
||||
return {it->second.consumer->ReadLock(), it->second.transformation_name};
|
||||
}();
|
||||
|
||||
auto *memory_resource = utils::NewDeleteResource();
|
||||
mgp_result result{nullptr, memory_resource};
|
||||
TransformationResult test_result;
|
||||
|
||||
auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, &stream_name,
|
||||
&transformation_name = transformation_name, &result,
|
||||
&test_result](const std::vector<Message> &messages) mutable {
|
||||
auto accessor = interpreter_context->db->Access();
|
||||
CallCustomTransformation(transformation_name, messages, result, accessor, *memory_resource, stream_name);
|
||||
|
||||
for (auto &row : result.rows) {
|
||||
auto [query, parameters] = ExtractTransformationResult(std::move(row.values), transformation_name, stream_name);
|
||||
std::vector<TypedValue> result_row;
|
||||
result_row.reserve(kExpectedTransformationResultSize);
|
||||
result_row.push_back(std::move(query));
|
||||
result_row.push_back(std::move(parameters));
|
||||
|
||||
test_result.push_back(std::move(result_row));
|
||||
}
|
||||
};
|
||||
|
||||
// This depends on the fact that Drop will first acquire a write lock to the consumer, and erase it only after that
|
||||
auto locked_consumer = [this, &stream_name] {
|
||||
auto locked_streams = streams_.ReadLock();
|
||||
auto it = GetStream(*locked_streams, stream_name);
|
||||
return it->second.consumer->ReadLock();
|
||||
}();
|
||||
|
||||
locked_consumer->Test(batch_limit, consumer_function);
|
||||
|
||||
return result;
|
||||
return test_result;
|
||||
}
|
||||
|
||||
StreamStatus Streams::CreateStatus(const std::string &name, const std::string &transformation_name,
|
||||
@ -243,24 +327,37 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
|
||||
throw StreamsException{"Stream already exists with name '{}'", stream_name};
|
||||
}
|
||||
|
||||
auto consumer_function = [interpreter_context =
|
||||
interpreter_context_](const std::vector<integrations::kafka::Message> &messages) {
|
||||
Interpreter interpreter = Interpreter{interpreter_context};
|
||||
TransformationResult result;
|
||||
auto *memory_resource = utils::NewDeleteResource();
|
||||
|
||||
for (const auto &message : messages) {
|
||||
// TODO(antaljanosbenjamin) Update the logic with using the transform from modules
|
||||
const auto payload = message.Payload();
|
||||
const std::string_view payload_as_string_view{payload.data(), payload.size()};
|
||||
result[fmt::format("CREATE (n:MESSAGE {{payload: '{}'}})", payload_as_string_view)] = "replace with params";
|
||||
auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, stream_name,
|
||||
transformation_name = stream_info.transformation_name,
|
||||
interpreter = std::make_shared<Interpreter>(interpreter_context_),
|
||||
result = mgp_result{nullptr, memory_resource}](
|
||||
const std::vector<integrations::kafka::Message> &messages) mutable {
|
||||
auto accessor = interpreter_context->db->Access();
|
||||
CallCustomTransformation(transformation_name, messages, result, accessor, *memory_resource, stream_name);
|
||||
|
||||
DiscardValueResultStream stream;
|
||||
|
||||
spdlog::trace("Start transaction in stream '{}'", stream_name);
|
||||
interpreter->BeginTransaction();
|
||||
|
||||
for (auto &row : result.rows) {
|
||||
spdlog::trace("Processing row in stream '{}'", stream_name);
|
||||
auto [query_value, params_value] =
|
||||
ExtractTransformationResult(std::move(row.values), transformation_name, stream_name);
|
||||
storage::PropertyValue params_prop{params_value};
|
||||
|
||||
std::string query{query_value.ValueString()};
|
||||
spdlog::trace("Executing query '{}' in stream '{}'", query, stream_name);
|
||||
auto prepare_result =
|
||||
interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap());
|
||||
interpreter->PullAll(&stream);
|
||||
}
|
||||
|
||||
for (const auto &[query, params] : result) {
|
||||
// auto prepared_query = interpreter.Prepare(query, {});
|
||||
spdlog::info("Executing query '{}'", query);
|
||||
// TODO(antaljanosbenjamin) run the query in real life, try not to copy paste the whole execution code, but
|
||||
// extract it to a function that can be called from multiple places (e.g: triggers)
|
||||
}
|
||||
spdlog::trace("Commit transaction in stream '{}'", stream_name);
|
||||
interpreter->CommitTransaction();
|
||||
result.rows.clear();
|
||||
};
|
||||
|
||||
ConsumerInfo consumer_info{
|
||||
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include "integrations/kafka/consumer.hpp"
|
||||
#include "kvstore/kvstore.hpp"
|
||||
#include "query/typed_value.hpp"
|
||||
#include "utils/exceptions.hpp"
|
||||
#include "utils/rw_lock.hpp"
|
||||
#include "utils/synchronized.hpp"
|
||||
@ -19,8 +20,7 @@ class StreamsException : public utils::BasicException {
|
||||
using BasicException::BasicException;
|
||||
};
|
||||
|
||||
// TODO(antaljanosbenjamin) Replace this with mgp_trans related thing
|
||||
using TransformationResult = std::map<std::string, std::string>;
|
||||
using TransformationResult = std::vector<std::vector<TypedValue>>;
|
||||
using TransformFunction = std::function<TransformationResult(const std::vector<integrations::kafka::Message> &)>;
|
||||
|
||||
struct StreamInfo {
|
||||
@ -28,7 +28,6 @@ struct StreamInfo {
|
||||
std::string consumer_group;
|
||||
std::optional<std::chrono::milliseconds> batch_interval;
|
||||
std::optional<int64_t> batch_size;
|
||||
// TODO(antaljanosbenjamin) How to reference the transformation in a better way?
|
||||
std::string transformation_name;
|
||||
};
|
||||
|
||||
@ -41,9 +40,7 @@ struct StreamStatus {
|
||||
using SynchronizedConsumer = utils::Synchronized<integrations::kafka::Consumer, utils::WritePrioritizedRWLock>;
|
||||
|
||||
struct StreamData {
|
||||
// TODO(antaljanosbenjamin) How to reference the transformation in a better way?
|
||||
std::string transformation_name;
|
||||
// TODO(antaljanosbenjamin) consider propagate_const
|
||||
std::unique_ptr<SynchronizedConsumer> consumer;
|
||||
};
|
||||
|
||||
@ -120,8 +117,8 @@ class Streams final {
|
||||
/// @param stream_name name of the stream we want to test
|
||||
/// @param batch_limit number of batches we want to test before stopping
|
||||
///
|
||||
/// TODO(antaljanosbenjamin) add type of parameters
|
||||
/// @returns A vector of pairs consisting of the query (std::string) and its parameters ...
|
||||
/// @returns A vector of vectors of TypedValue. Each subvector contains two elements, the query string and the
|
||||
/// nullable parameters map.
|
||||
///
|
||||
/// @throws StreamsException if the stream doesn't exist
|
||||
/// @throws ConsumerRunningException if the consumer is alredy running
|
||||
|
@ -244,12 +244,11 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache
|
||||
spdlog::info("Loading triggers...");
|
||||
|
||||
for (const auto &[trigger_name, trigger_data] : storage_) {
|
||||
// structured binding cannot be captured in lambda
|
||||
const auto get_failed_message = [](const std::string_view trigger_name, const std::string_view message) {
|
||||
const auto get_failed_message = [&trigger_name = trigger_name](const std::string_view message) {
|
||||
return fmt::format("Failed to load trigger '{}'. {}", trigger_name, message);
|
||||
};
|
||||
|
||||
const auto invalid_state_message = get_failed_message(trigger_name, "Invalid state of the trigger data.");
|
||||
const auto invalid_state_message = get_failed_message("Invalid state of the trigger data.");
|
||||
|
||||
spdlog::debug("Loading trigger '{}'", trigger_name);
|
||||
auto json_trigger_data = nlohmann::json::parse(trigger_data);
|
||||
@ -259,7 +258,7 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache
|
||||
continue;
|
||||
}
|
||||
if (json_trigger_data["version"] != kVersion) {
|
||||
spdlog::warn(get_failed_message(trigger_name, "Invalid version of the trigger data."));
|
||||
spdlog::warn(get_failed_message("Invalid version of the trigger data."));
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -3207,8 +3207,16 @@ TEST_P(CypherMainVisitorTest, CreateSnapshotQuery) {
|
||||
ASSERT_TRUE(dynamic_cast<CreateSnapshotQuery *>(ast_generator.ParseQuery("CREATE SNAPSHOT")));
|
||||
}
|
||||
|
||||
void CheckOptionalExpression(Base &ast_generator, Expression *expression, const std::optional<TypedValue> &expected) {
|
||||
EXPECT_EQ(expression != nullptr, expected.has_value());
|
||||
if (expected.has_value()) {
|
||||
EXPECT_NO_FATAL_FAILURE(ast_generator.CheckLiteral(expression, *expected));
|
||||
}
|
||||
};
|
||||
|
||||
void ValidateMostlyEmptyStreamQuery(Base &ast_generator, const std::string &query_string,
|
||||
const StreamQuery::Action action, const std::string_view stream_name) {
|
||||
const StreamQuery::Action action, const std::string_view stream_name,
|
||||
const std::optional<TypedValue> &batch_limit = std::nullopt) {
|
||||
auto *parsed_query = dynamic_cast<StreamQuery *>(ast_generator.ParseQuery(query_string));
|
||||
ASSERT_NE(parsed_query, nullptr);
|
||||
EXPECT_EQ(parsed_query->action_, action);
|
||||
@ -3218,6 +3226,7 @@ void ValidateMostlyEmptyStreamQuery(Base &ast_generator, const std::string &quer
|
||||
EXPECT_TRUE(parsed_query->consumer_group_.empty());
|
||||
EXPECT_EQ(parsed_query->batch_interval_, nullptr);
|
||||
EXPECT_EQ(parsed_query->batch_size_, nullptr);
|
||||
EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_limit_, batch_limit));
|
||||
}
|
||||
|
||||
TEST_P(CypherMainVisitorTest, DropStream) {
|
||||
@ -3292,18 +3301,13 @@ void ValidateCreateStreamQuery(Base &ast_generator, const std::string &query_str
|
||||
ASSERT_NO_THROW(parsed_query = dynamic_cast<StreamQuery *>(ast_generator.ParseQuery(query_string))) << query_string;
|
||||
ASSERT_NE(parsed_query, nullptr);
|
||||
EXPECT_EQ(parsed_query->stream_name_, stream_name);
|
||||
auto check_expression = [&](Expression *expression, const std::optional<TypedValue> &expected) {
|
||||
EXPECT_EQ(expression != nullptr, expected.has_value());
|
||||
if (expected.has_value()) {
|
||||
EXPECT_NO_FATAL_FAILURE(ast_generator.CheckLiteral(expression, *expected));
|
||||
}
|
||||
};
|
||||
|
||||
EXPECT_EQ(parsed_query->topic_names_, topic_names);
|
||||
EXPECT_EQ(parsed_query->transform_name_, transform_name);
|
||||
EXPECT_EQ(parsed_query->consumer_group_, consumer_group);
|
||||
EXPECT_NO_FATAL_FAILURE(check_expression(parsed_query->batch_interval_, batch_interval));
|
||||
EXPECT_NO_FATAL_FAILURE(check_expression(parsed_query->batch_size_, batch_size));
|
||||
EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_interval_, batch_interval));
|
||||
EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_size_, batch_size));
|
||||
EXPECT_EQ(parsed_query->batch_limit_, nullptr);
|
||||
}
|
||||
|
||||
TEST_P(CypherMainVisitorTest, CreateStream) {
|
||||
@ -3380,4 +3384,22 @@ TEST_P(CypherMainVisitorTest, CreateStream) {
|
||||
EXPECT_NO_FATAL_FAILURE(check_topic_names({topic_name1, topic_name2}));
|
||||
}
|
||||
|
||||
TEST_P(CypherMainVisitorTest, CheckStream) {
|
||||
auto &ast_generator = *GetParam();
|
||||
|
||||
TestInvalidQuery("CHECK STREAM", ast_generator);
|
||||
TestInvalidQuery("CHECK STREAMS", ast_generator);
|
||||
TestInvalidQuery("CHECK STREAMS something", ast_generator);
|
||||
TestInvalidQuery("CHECK STREAM something,something", ast_generator);
|
||||
TestInvalidQuery("CHECK STREAM something BATCH LIMIT 1", ast_generator);
|
||||
TestInvalidQuery("CHECK STREAM something BATCH_LIMIT", ast_generator);
|
||||
TestInvalidQuery<SemanticException>("CHECK STREAM something BATCH_LIMIT 'it should be an integer'", ast_generator);
|
||||
TestInvalidQuery<SemanticException>("CHECK STREAM something BATCH_LIMIT 2.5", ast_generator);
|
||||
|
||||
ValidateMostlyEmptyStreamQuery(ast_generator, "CHECK STREAM checkedStream", StreamQuery::Action::CHECK_STREAM,
|
||||
"checkedStream");
|
||||
ValidateMostlyEmptyStreamQuery(ast_generator, "CHECK STREAM checkedStream bAtCH_LIMIT 42",
|
||||
StreamQuery::Action::CHECK_STREAM, "checkedStream", TypedValue(42));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -32,8 +32,7 @@ StreamInfo CreateDefaultStreamInfo() {
|
||||
.consumer_group = "ConsumerGroup " + GetDefaultStreamName(),
|
||||
.batch_interval = std::nullopt,
|
||||
.batch_size = std::nullopt,
|
||||
// TODO(antaljanosbenjamin) Add proper reference once Streams supports that
|
||||
.transformation_name = "not yet used",
|
||||
.transformation_name = "not used in the tests",
|
||||
};
|
||||
}
|
||||
|
||||
@ -79,8 +78,7 @@ class StreamsTest : public ::testing::Test {
|
||||
EXPECT_EQ(check_data.info.consumer_group, status.info.consumer_group);
|
||||
EXPECT_EQ(check_data.info.batch_interval, status.info.batch_interval);
|
||||
EXPECT_EQ(check_data.info.batch_size, status.info.batch_size);
|
||||
// TODO(antaljanosbenjamin) Add proper reference once Streams supports that
|
||||
// EXPECT_EQ(check_data.info.transformation_name, status.info.transformation_name);
|
||||
EXPECT_EQ(check_data.info.transformation_name, status.info.transformation_name);
|
||||
EXPECT_EQ(check_data.is_running, status.is_running);
|
||||
}
|
||||
|
||||
@ -227,5 +225,3 @@ TEST_F(StreamsTest, RestoreStreams) {
|
||||
check_restore_logic();
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(antaljanosbenjamin) Add tests for Streams::Test method and transformation
|
||||
|
Loading…
Reference in New Issue
Block a user