Use transformations in streams and CHECK STREAM (#185)

* Use the correct transformation result type

* Execute the result queries in streams

* Change the result type of parameters to nullable map

* Serialize transformation name

* Fix order of transformation parameters

* Use actual transformation in Streams

* Clear the Python transformations under GIL

* Add CHECK STREAM query

* Handle missing record fields properly
This commit is contained in:
János Benjamin Antal 2021-07-05 11:42:40 +02:00 committed by Antonio Andelic
parent a37755ce43
commit ad32db5168
22 changed files with 289 additions and 115 deletions

View File

@ -190,7 +190,8 @@ class Edge:
def __init__(self, edge):
if not isinstance(edge, _mgp.Edge):
raise TypeError("Expected '_mgp.Edge', got '{}'".format(type(edge)))
raise TypeError(
"Expected '_mgp.Edge', got '{}'".format(type(edge)))
self._edge = edge
def __deepcopy__(self, memo):
@ -268,7 +269,8 @@ class Vertex:
def __init__(self, vertex):
if not isinstance(vertex, _mgp.Vertex):
raise TypeError("Expected '_mgp.Vertex', got '{}'".format(type(vertex)))
raise TypeError(
"Expected '_mgp.Vertex', got '{}'".format(type(vertex)))
self._vertex = vertex
def __deepcopy__(self, memo):
@ -404,7 +406,8 @@ class Path:
passed in edge is invalid.
'''
if not isinstance(edge, Edge):
raise TypeError("Expected '_mgp.Edge', got '{}'".format(type(edge)))
raise TypeError(
"Expected '_mgp.Edge', got '{}'".format(type(edge)))
if not self.is_valid() or not edge.is_valid():
raise InvalidContextError()
self._path.expand(edge._edge)
@ -454,7 +457,8 @@ class Vertices:
def __init__(self, graph):
if not isinstance(graph, _mgp.Graph):
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
raise TypeError(
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = graph
self._len = None
@ -499,7 +503,8 @@ class Graph:
def __init__(self, graph):
if not isinstance(graph, _mgp.Graph):
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
raise TypeError(
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = graph
def __deepcopy__(self, memo):
@ -557,7 +562,8 @@ class ProcCtx:
def __init__(self, graph):
if not isinstance(graph, _mgp.Graph):
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
raise TypeError(
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = Graph(graph)
def is_valid(self) -> bool:
@ -676,11 +682,13 @@ def _typing_to_cypher_type(type_):
type_args_as_str = parse_type_args(type_as_str)
none_type_as_str = type(None).__name__
if none_type_as_str in type_args_as_str:
types = tuple(t for t in type_args_as_str if t != none_type_as_str)
types = tuple(
t for t in type_args_as_str if t != none_type_as_str)
if len(types) == 1:
type_arg_as_str, = types
else:
type_arg_as_str = 'typing.Union[' + ', '.join(types) + ']'
type_arg_as_str = 'typing.Union[' + \
', '.join(types) + ']'
simple_type = get_simple_type(type_arg_as_str)
if simple_type is not None:
return _mgp.type_nullable(simple_type)
@ -726,6 +734,7 @@ def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]):
if inspect.isgeneratorfunction(func):
raise NotImplementedError("Generator functions are not supported")
def read_proc(func: typing.Callable[..., Record]):
'''
Register `func` as a read-only procedure of the current module.
@ -803,17 +812,20 @@ def read_proc(func: typing.Callable[..., Record]):
mgp_proc.add_result(name, _typing_to_cypher_type(type_))
return func
class InvalidMessageError(Exception):
'''Signals using a message instance outside of the registered transformation.'''
pass
class Message:
'''Represents a message from a stream.'''
__slots__ = ('_message',)
def __init__(self, message):
if not isinstance(message, _mgp.Message):
raise TypeError("Expected '_mgp.Message', got '{}'".format(type(message)))
raise TypeError(
"Expected '_mgp.Message', got '{}'".format(type(message)))
self._message = message
def __deepcopy__(self, memo):
@ -829,34 +841,37 @@ class Message:
def payload(self) -> bytes:
if not self.is_valid():
raise InvalidMessageError()
return self._messages._payload(_message)
return self._message.payload()
def topic_name(self) -> str:
if not self.is_valid():
raise InvalidMessageError()
return self._messages._topic_name(_message)
return self._message.topic_name()
def key() -> bytes:
def key(self) -> bytes:
if not self.is_valid():
raise InvalidMessageError()
return self._messages.key(_message)
return self._message.key()
def timestamp() -> int:
def timestamp(self) -> int:
if not self.is_valid():
raise InvalidMessageError()
return self._messages.timestamp(_message)
return self._message.timestamp()
class InvalidMessagesError(Exception):
'''Signals using a messages instance outside of the registered transformation.'''
pass
class Messages:
'''Represents a list of messages from a stream.'''
__slots__ = ('_messages',)
def __init__(self, messages):
if not isinstance(messages, _mgp.Messages):
raise TypeError("Expected '_mgp.Messages', got '{}'".format(type(messages)))
raise TypeError(
"Expected '_mgp.Messages', got '{}'".format(type(messages)))
self._messages = messages
def __deepcopy__(self, memo):
@ -869,18 +884,19 @@ class Messages:
'''Return True if `self` is in valid context and may be used.'''
return self._messages.is_valid()
def message_at(self, id : int) -> Message:
def message_at(self, id: int) -> Message:
'''Raise InvalidMessagesError if context is invalid.'''
if not self.is_valid():
raise InvalidMessagesError()
return Message(self._messages.message_at(id))
def total_messages() -> int:
def total_messages(self) -> int:
'''Raise InvalidMessagesError if context is invalid.'''
if not self.is_valid():
raise InvalidMessagesError()
return self._messages.total_messages()
class TransCtx:
'''Context of a transformation being executed.
@ -891,7 +907,8 @@ class TransCtx:
def __init__(self, graph):
if not isinstance(graph, _mgp.Graph):
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
raise TypeError(
"Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = Graph(graph)
def is_valid(self) -> bool:
@ -904,13 +921,15 @@ class TransCtx:
raise InvalidContextError()
return self._graph
def transformation(func: typing.Callable[..., Record]):
raise_if_does_not_meet_requirements(func)
sig = inspect.signature(func)
params = tuple(sig.parameters.values())
if not params or not params[0].annotation is Messages:
if not len(params) == 2 or not params[1].annotation is Messages:
raise NotImplementedError("Valid signatures for transformations are (TransCtx, Messages) or (Messages)")
raise NotImplementedError(
"Valid signatures for transformations are (TransCtx, Messages) or (Messages)")
if params[0].annotation is TransCtx:
@functools.wraps(func)
def wrapper(graph, messages):

View File

@ -44,6 +44,7 @@ utils::BasicResult<std::string, std::vector<Message>> GetBatch(RdKafka::KafkaCon
break;
default:
// TODO(antaljanosbenjamin): handle RD_KAFKA_RESP_ERR__MAX_POLL_EXCEEDED
auto error = msg->errstr();
spdlog::warn("Unexpected error while consuming message in consumer {}, error: {}!", info.consumer_name,
msg->errstr());
@ -307,6 +308,7 @@ void Consumer::StartConsuming() {
spdlog::warn("Error happened in consumer {} while processing a batch: {}!", info_.consumer_name, e.what());
break;
}
spdlog::info("Kafka consumer {} finished processing", info_.consumer_name);
}
is_running_.store(false);
});

View File

@ -23,6 +23,7 @@
#include "communication/bolt/v1/constants.hpp"
#include "helpers.hpp"
#include "py/py.hpp"
#include "query/discard_value_stream.hpp"
#include "query/exceptions.hpp"
#include "query/interpreter.hpp"
#include "query/plan/operator.hpp"
@ -433,7 +434,7 @@ class BoltSession final : public communication::bolt::Session<communication::Inp
}
std::map<std::string, communication::bolt::Value> Discard(std::optional<int> n, std::optional<int> qid) override {
DiscardValueResultStream stream;
query::DiscardValueResultStream stream;
return PullResults(stream, n, qid);
}
@ -517,12 +518,6 @@ class BoltSession final : public communication::bolt::Session<communication::Inp
const storage::Storage *db_;
};
struct DiscardValueResultStream {
void Result(const std::vector<query::TypedValue> &) {
// do nothing
}
};
// NOTE: Needed only for ToBoltValue conversions
const storage::Storage *db_;
query::Interpreter interpreter_;

View File

@ -29,8 +29,8 @@ class EnsureGIL final {
PyGILState_STATE gil_state_;
public:
EnsureGIL() : gil_state_(PyGILState_Ensure()) {}
~EnsureGIL() { PyGILState_Release(gil_state_); }
EnsureGIL() noexcept : gil_state_(PyGILState_Ensure()) {}
~EnsureGIL() noexcept { PyGILState_Release(gil_state_); }
EnsureGIL(const EnsureGIL &) = delete;
EnsureGIL(EnsureGIL &&) = delete;
EnsureGIL &operator=(const EnsureGIL &) = delete;

View File

@ -0,0 +1,13 @@
#pragma once
#include <vector>
#include "query/typed_value.hpp"
namespace query {
// Result stream whose only operation accepts a row of values and discards it.
// It exposes the `Result(const std::vector<query::TypedValue> &)` member so it
// can be handed to code that pushes query results into a stream object when
// the caller has no use for the produced rows.
struct DiscardValueResultStream final {
// Receives one row of result values; the parameter is unnamed because the
// row is intentionally ignored.
void Result(const std::vector<query::TypedValue> & /*values*/) {
// do nothing
}
};
} // namespace query

View File

@ -2476,12 +2476,15 @@ cpp<#
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression"))
(batch_size "Expression *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression"))
(batch_limit "Expression *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression")))
(:public
(lcp:define-enum action
(create-stream drop-stream start-stream stop-stream start-all-streams stop-all-streams show-streams test-stream)
(create-stream drop-stream start-stream stop-stream start-all-streams stop-all-streams show-streams check-stream)
(:serialize))
#>cpp
StreamQuery() = default;

View File

@ -472,6 +472,7 @@ antlrcpp::Any CypherMainVisitor::visitCreateStream(MemgraphCypher::CreateStreamC
auto *topic_names_ctx = ctx->topicNames();
MG_ASSERT(topic_names_ctx != nullptr);
// TODO(antaljanosbenjamin): Add dash
auto topic_names = topic_names_ctx->symbolicNameWithDots();
MG_ASSERT(!topic_names.empty());
stream_query->topic_names_.reserve(topic_names.size());
@ -540,6 +541,20 @@ antlrcpp::Any CypherMainVisitor::visitShowStreams(MemgraphCypher::ShowStreamsCon
return stream_query;
}
/**
 * Builds the AST node for `CHECK STREAM <name> [BATCH_LIMIT <integer literal>]`.
 *
 * The stream name is always filled in; `batch_limit_` is only assigned when
 * the BATCH_LIMIT clause is present, otherwise the node keeps its default.
 *
 * @return StreamQuery*
 * @throw SemanticException if the batch limit is not an integer literal.
 */
antlrcpp::Any CypherMainVisitor::visitCheckStream(MemgraphCypher::CheckStreamContext *ctx) {
  auto *check_query = storage_->Create<StreamQuery>();
  check_query->action_ = StreamQuery::Action::CHECK_STREAM;
  check_query->stream_name_ = ctx->streamName()->symbolicName()->accept(this).as<std::string>();

  // Without the optional clause there is nothing more to parse.
  if (!ctx->BATCH_LIMIT()) {
    return check_query;
  }

  // The grammar accepts any literal here, so narrow it down to integers.
  auto *number_literal = ctx->batchLimit->numberLiteral();
  if (number_literal == nullptr || number_literal->integerLiteral() == nullptr) {
    throw SemanticException("Batch limit should be an integer literal!");
  }

  check_query->batch_limit_ = ctx->batchLimit->accept(this);
  return check_query;
}
antlrcpp::Any CypherMainVisitor::visitCypherUnion(MemgraphCypher::CypherUnionContext *ctx) {
bool distinct = !ctx->ALL();
auto *cypher_union = storage_->Create<CypherUnion>(distinct);

View File

@ -288,6 +288,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
*/
antlrcpp::Any visitShowStreams(MemgraphCypher::ShowStreamsContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitCheckStream(MemgraphCypher::CheckStreamContext *ctx) override;
/**
* @return CypherUnion*
*/

View File

@ -12,10 +12,11 @@ memgraphCypherKeyword : cypherKeyword
| ASYNC
| AUTH
| BAD
| BATCHES
| BATCH_INTERVAL
| BATCH_LIMIT
| BATCH_SIZE
| BEFORE
| CHECK
| CLEAR
| COMMIT
| COMMITTED
@ -141,7 +142,8 @@ clause : cypherMatch
| loadCsv
;
streamQuery : createStream
streamQuery : checkStream
| createStream
| dropStream
| startStream
| startAllStreams
@ -289,3 +291,5 @@ stopStream : STOP STREAM streamName ;
stopAllStreams : STOP ALL STREAMS ;
showStreams : SHOW STREAMS ;
checkStream : CHECK STREAM streamName ( BATCH_LIMIT batchLimit=literal ) ? ;

View File

@ -17,10 +17,11 @@ ALTER : A L T E R ;
ASYNC : A S Y N C ;
AUTH : A U T H ;
BAD : B A D ;
BATCHES : B A T C H E S ;
BATCH_INTERVAL : B A T C H UNDERSCORE I N T E R V A L ;
BATCH_LIMIT : B A T C H UNDERSCORE L I M I T ;
BATCH_SIZE : B A T C H UNDERSCORE S I Z E ;
BEFORE : B E F O R E ;
CHECK : C H E C K ;
CLEAR : C L E A R ;
COMMIT : C O M M I T ;
COMMITTED : C O M M I T T E D ;

View File

@ -127,11 +127,11 @@ const trie::Trie kKeywords = {"union", "all",
"level", "next",
"read", "session",
"snapshot", "transaction",
"batches", "batch_interval",
"batch_limit", "batch_interval",
"batch_size", "consumer_group",
"start", "stream",
"streams", "transform",
"topics"};
"topics", "check"};
// Unicode codepoints that are allowed at the start of the unescaped name.
const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts(

View File

@ -490,7 +490,7 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters &paramete
.consumer_group = std::move(consumer_group),
.batch_interval = batch_interval,
.batch_size = batch_size,
.transformation_name = "transform.trans"});
.transformation_name = transformation_name});
return std::vector<std::vector<TypedValue>>{};
};
return callback;
@ -575,8 +575,14 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters &paramete
return results;
};
return callback;
case StreamQuery::Action::TEST_STREAM:
throw std::logic_error("not implemented");
}
case StreamQuery::Action::CHECK_STREAM: {
callback.header = {"query", "parameters"};
callback.fn = [interpreter_context, stream_name = stream_query->stream_name_,
batch_limit = GetOptionalValue<int64_t>(stream_query->batch_limit_, evaluator)]() mutable {
return interpreter_context->streams.Test(stream_name, batch_limit);
};
return callback;
}
}
}

View File

@ -3761,7 +3761,7 @@ class CallProcedureCursor : public Cursor {
std::string_view field_name(self_->result_fields_[i]);
auto result_it = values.find(field_name);
if (result_it == values.end()) {
throw QueryRuntimeException("Procedure '{}' does not yield a record with '{}' field.", self_->procedure_name_,
throw QueryRuntimeException("Procedure '{}' did not yield a record with '{}' field.", self_->procedure_name_,
field_name);
}
frame[self_->result_symbols_[i]] = result_it->second;

View File

@ -1372,7 +1372,7 @@ bool MgpTransAddFixedResult(mgp_trans *trans) {
if (int err = AddResultToProp(trans, "query", mgp_type_string(), false); err != 1) {
return err;
}
return AddResultToProp(trans, "parameters", mgp_type_nullable(mgp_type_list(mgp_type_any())), false);
return AddResultToProp(trans, "parameters", mgp_type_nullable(mgp_type_map()), false);
}
int mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name, const mgp_type *type) {

View File

@ -549,7 +549,7 @@ bool IsValidIdentifierName(const char *name);
} // namespace query::procedure
struct mgp_message {
integrations::kafka::Message *msg;
const integrations::kafka::Message *msg;
};
struct mgp_messages {

View File

@ -418,11 +418,11 @@ bool PythonModule::Load(const std::filesystem::path &file_path) {
bool PythonModule::Close() {
MG_ASSERT(py_module_, "Attempting to close a module that has not been loaded...");
spdlog::info("Closing module {}...", file_path_);
// The procedures are closures which hold references to the Python callbacks.
// Releasing these references might result in deallocations so we need to take
// the GIL.
// The procedures and transformations are closures which hold references to the Python callbacks.
// Releasing these references might result in deallocations so we need to take the GIL.
auto gil = py::EnsureGIL();
procedures_.clear();
transformations_.clear();
// Delete the module from the `sys.modules` directory so that the module will
// be properly imported if imported again.
py::Object sys(PyImport_ImportModule("sys"));

View File

@ -782,7 +782,7 @@ void CallPythonTransformation(const py::Object &py_cb, const mgp_messages *msgs,
};
auto call = [&](py::Object py_graph, py::Object py_messages) -> std::optional<py::ExceptionInfo> {
auto py_res = py_cb.Call(py_messages, py_graph);
auto py_res = py_cb.Call(py_graph, py_messages);
if (!py_res) return py::FetchError();
if (PySequence_Check(py_res.Ptr())) {
return AddMultipleRecordsFromPython(result, py_res);

View File

@ -6,23 +6,92 @@
#include <spdlog/spdlog.h>
#include <json/json.hpp>
#include "query/db_accessor.hpp"
#include "query/discard_value_stream.hpp"
#include "query/interpreter.hpp"
#include "query/procedure/mg_procedure_impl.hpp"
#include "query/procedure/module.hpp"
#include "query/typed_value.hpp"
#include "utils/memory.hpp"
#include "utils/on_scope_exit.hpp"
#include "utils/pmr/string.hpp"
namespace query {
using Consumer = integrations::kafka::Consumer;
using ConsumerInfo = integrations::kafka::ConsumerInfo;
using Message = integrations::kafka::Message;
namespace {
constexpr auto kExpectedTransformationResultSize = 2;
const utils::pmr::string query_param_name{"query", utils::NewDeleteResource()};
const utils::pmr::string params_param_name{"parameters", utils::NewDeleteResource()};
const std::map<std::string, storage::PropertyValue> empty_parameters{};
// Looks up `stream_name` in `map` and returns the iterator pointing at it.
// Throws StreamsException when the map has no entry with that name, so the
// returned iterator is never `map.end()`.
auto GetStream(auto &map, const std::string &stream_name) {
  auto found = map.find(stream_name);
  if (found == map.end()) {
    throw StreamsException("Couldn't find stream '{}'", stream_name);
  }
  return found;
}
} // namespace
using Consumer = integrations::kafka::Consumer;
using ConsumerInfo = integrations::kafka::ConsumerInfo;
using Message = integrations::kafka::Message;
// Runs the transformation registered under `transformation_name` on a batch of
// Kafka messages and leaves the produced rows in `result`.
//
// Parameters:
//   transformation_name - name looked up in procedure::gModuleRegistry.
//   messages            - batch handed to the transformation; only pointers to
//                         its elements are passed on, the messages are not copied.
//   result              - reused output object; its rows and error message are
//                         reset before the call and its signature is pointed at
//                         the transformation's declared results.
//   storage_accessor    - wrapped in a DbAccessor that backs the mgp_graph.
//   memory_resource     - used for the mgp_messages container and mgp_memory.
//   stream_name         - only used in log and error messages.
//
// Throws StreamsException if the transformation cannot be found or if the
// transformation call left an error message in `result`.
void CallCustomTransformation(const std::string &transformation_name, const std::vector<Message> &messages,
mgp_result &result, storage::Storage::Accessor &storage_accessor,
utils::MemoryResource &memory_resource, const std::string &stream_name) {
DbAccessor db_accessor{&storage_accessor};
// NOTE(review): the nested scope presumably exists so that the handle returned
// by FindTransformation is destroyed before the error is reported below --
// confirm against procedure::FindTransformation.
{
auto maybe_transformation =
procedure::FindTransformation(procedure::gModuleRegistry, transformation_name, utils::NewDeleteResource());
if (!maybe_transformation) {
throw StreamsException("Couldn't find transformation {} for stream '{}'", transformation_name, stream_name);
};
const auto &trans = *maybe_transformation->second;
// Wrap every message in a non-owning mgp_message that points into `messages`.
mgp_messages mgp_messages{mgp_messages::storage_type{&memory_resource}};
std::transform(messages.begin(), messages.end(), std::back_inserter(mgp_messages.messages),
[](const integrations::kafka::Message &message) { return mgp_message{&message}; });
mgp_graph graph{&db_accessor, storage::View::OLD, nullptr};
mgp_memory memory{&memory_resource};
// `result` is reused between calls, so drop whatever the previous call left.
result.rows.clear();
result.error_msg.reset();
result.signature = &trans.results;
// The transformation must declare exactly the two result fields
// named by query_param_name and params_param_name.
MG_ASSERT(result.signature->size() == kExpectedTransformationResultSize);
MG_ASSERT(result.signature->contains(query_param_name));
MG_ASSERT(result.signature->contains(params_param_name));
spdlog::trace("Calling transformation in stream '{}'", stream_name);
trans.cb(&mgp_messages, &graph, &result, &memory);
}
// The callback reports failures through `result.error_msg` instead of throwing.
if (result.error_msg.has_value()) {
throw StreamsException(result.error_msg->c_str());
}
}
// Splits one row produced by a transformation into its (query, parameters) pair.
//
// `values` is the field-name -> value map of a single result row; the two
// extracted values are moved out of it. `transformation_name` and
// `stream_name` are only used to build error messages.
//
// Throws StreamsException if the row does not have exactly the expected number
// of fields or if one of the two expected fields is missing. The value types
// (string query, null-or-map parameters) are enforced with MG_ASSERT.
std::pair<TypedValue /*query*/, TypedValue /*parameters*/> ExtractTransformationResult(
    utils::pmr::map<utils::pmr::string, TypedValue> &&values, const std::string_view transformation_name,
    const std::string_view stream_name) {
  if (values.size() != kExpectedTransformationResultSize) {
    throw StreamsException(
        "Transformation '{}' in stream '{}' did not yield all fields (query, parameters) as required.",
        transformation_name, stream_name);
  }

  // Returns a mutable reference to the requested field so it can be moved from.
  auto field_of_row = [&](const utils::pmr::string &field_name) -> TypedValue & {
    if (auto found = values.find(field_name); found != values.end()) {
      return found->second;
    }
    throw StreamsException{"Transformation '{}' in stream '{}' did not yield a record with '{}' field.",
                           transformation_name, stream_name, field_name};
  };

  auto &query = field_of_row(query_param_name);
  MG_ASSERT(query.IsString());

  auto &parameters = field_of_row(params_param_name);
  MG_ASSERT(parameters.IsNull() || parameters.IsMap());

  return {std::move(query), std::move(parameters)};
}
} // namespace
// nlohmann::json doesn't support string_view access yet
const std::string kStreamName{"name"};
@ -31,6 +100,7 @@ const std::string kConsumerGroupKey{"consumer_group"};
const std::string kBatchIntervalKey{"batch_interval"};
const std::string kBatchSizeKey{"batch_size"};
const std::string kIsRunningKey{"is_running"};
const std::string kTransformationName{"transformation_name"};
void to_json(nlohmann::json &data, StreamStatus &&status) {
auto &info = status.info;
@ -51,6 +121,7 @@ void to_json(nlohmann::json &data, StreamStatus &&status) {
}
data[kIsRunningKey] = status.is_running;
data[kTransformationName] = status.info.transformation_name;
}
void from_json(const nlohmann::json &data, StreamStatus &status) {
@ -75,6 +146,7 @@ void from_json(const nlohmann::json &data, StreamStatus &status) {
}
data.at(kIsRunningKey).get_to(status.is_running);
data.at(kTransformationName).get_to(status.info.transformation_name);
}
Streams::Streams(InterpreterContext *interpreter_context, std::string bootstrap_servers,
@ -89,8 +161,8 @@ void Streams::RestoreStreams() {
MG_ASSERT(locked_streams_map->empty(), "Cannot restore streams when some streams already exist!");
for (const auto &[stream_name, stream_data] : storage_) {
const auto get_failed_message = [](const std::string_view stream_name, const std::string_view message,
const std::string_view nested_message) {
const auto get_failed_message = [&stream_name = stream_name](const std::string_view message,
const std::string_view nested_message) {
return fmt::format("Failed to load stream '{}', because: {} caused by {}", stream_name, message, nested_message);
};
@ -98,21 +170,22 @@ void Streams::RestoreStreams() {
try {
nlohmann::json::parse(stream_data).get_to(status);
} catch (const nlohmann::json::type_error &exception) {
spdlog::warn(get_failed_message(stream_name, "invalid type conversion", exception.what()));
spdlog::warn(get_failed_message("invalid type conversion", exception.what()));
continue;
} catch (const nlohmann::json::out_of_range &exception) {
spdlog::warn(get_failed_message(stream_name, "non existing field", exception.what()));
spdlog::warn(get_failed_message("non existing field", exception.what()));
continue;
}
MG_ASSERT(status.name == stream_name, "Expected stream name is '{}', but got '{}'", stream_name, status.name);
MG_ASSERT(status.name == stream_name, "Expected stream name is '{}', but got '{}'", status.name, stream_name);
try {
auto it = CreateConsumer(*locked_streams_map, stream_name, std::move(status.info));
if (status.is_running) {
it->second.consumer->Lock()->Start();
}
spdlog::info("Stream '{}' is loaded", stream_name);
} catch (const utils::BasicException &exception) {
spdlog::warn(get_failed_message(stream_name, "unexpected error", exception.what()));
spdlog::warn(get_failed_message("unexpected error", exception.what()));
}
}
}
@ -200,27 +273,38 @@ std::vector<StreamStatus> Streams::GetStreamInfo() const {
}
TransformationResult Streams::Test(const std::string &stream_name, std::optional<int64_t> batch_limit) const {
TransformationResult result;
auto consumer_function = [&result](const std::vector<Message> &messages) {
for (const auto &message : messages) {
// TODO(antaljanosbenjamin) Update the logic with using the transform from modules
const auto payload = message.Payload();
const std::string_view payload_as_string_view{payload.data(), payload.size()};
spdlog::info("CREATE (n:MESSAGE {{payload: '{}'}})", payload_as_string_view);
result[fmt::format("CREATE (n:MESSAGE {{payload: '{}'}})", payload_as_string_view)] = "replace with params";
// This depends on the fact that Drop will first acquire a write lock to the consumer, and erase it only after that
auto [locked_consumer,
transformation_name] = [this, &stream_name]() -> std::pair<SynchronizedConsumer::ReadLockedPtr, std::string> {
auto locked_streams = streams_.ReadLock();
auto it = GetStream(*locked_streams, stream_name);
return {it->second.consumer->ReadLock(), it->second.transformation_name};
}();
auto *memory_resource = utils::NewDeleteResource();
mgp_result result{nullptr, memory_resource};
TransformationResult test_result;
auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, &stream_name,
&transformation_name = transformation_name, &result,
&test_result](const std::vector<Message> &messages) mutable {
auto accessor = interpreter_context->db->Access();
CallCustomTransformation(transformation_name, messages, result, accessor, *memory_resource, stream_name);
for (auto &row : result.rows) {
auto [query, parameters] = ExtractTransformationResult(std::move(row.values), transformation_name, stream_name);
std::vector<TypedValue> result_row;
result_row.reserve(kExpectedTransformationResultSize);
result_row.push_back(std::move(query));
result_row.push_back(std::move(parameters));
test_result.push_back(std::move(result_row));
}
};
// This depends on the fact that Drop will first acquire a write lock to the consumer, and erase it only after that
auto locked_consumer = [this, &stream_name] {
auto locked_streams = streams_.ReadLock();
auto it = GetStream(*locked_streams, stream_name);
return it->second.consumer->ReadLock();
}();
locked_consumer->Test(batch_limit, consumer_function);
return result;
return test_result;
}
StreamStatus Streams::CreateStatus(const std::string &name, const std::string &transformation_name,
@ -243,24 +327,37 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
throw StreamsException{"Stream already exists with name '{}'", stream_name};
}
auto consumer_function = [interpreter_context =
interpreter_context_](const std::vector<integrations::kafka::Message> &messages) {
Interpreter interpreter = Interpreter{interpreter_context};
TransformationResult result;
auto *memory_resource = utils::NewDeleteResource();
for (const auto &message : messages) {
// TODO(antaljanosbenjamin) Update the logic with using the transform from modules
const auto payload = message.Payload();
const std::string_view payload_as_string_view{payload.data(), payload.size()};
result[fmt::format("CREATE (n:MESSAGE {{payload: '{}'}})", payload_as_string_view)] = "replace with params";
auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, stream_name,
transformation_name = stream_info.transformation_name,
interpreter = std::make_shared<Interpreter>(interpreter_context_),
result = mgp_result{nullptr, memory_resource}](
const std::vector<integrations::kafka::Message> &messages) mutable {
auto accessor = interpreter_context->db->Access();
CallCustomTransformation(transformation_name, messages, result, accessor, *memory_resource, stream_name);
DiscardValueResultStream stream;
spdlog::trace("Start transaction in stream '{}'", stream_name);
interpreter->BeginTransaction();
for (auto &row : result.rows) {
spdlog::trace("Processing row in stream '{}'", stream_name);
auto [query_value, params_value] =
ExtractTransformationResult(std::move(row.values), transformation_name, stream_name);
storage::PropertyValue params_prop{params_value};
std::string query{query_value.ValueString()};
spdlog::trace("Executing query '{}' in stream '{}'", query, stream_name);
auto prepare_result =
interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap());
interpreter->PullAll(&stream);
}
for (const auto &[query, params] : result) {
// auto prepared_query = interpreter.Prepare(query, {});
spdlog::info("Executing query '{}'", query);
// TODO(antaljanosbenjamin) run the query in real life, try not to copy paste the whole execution code, but
// extract it to a function that can be called from multiple places (e.g: triggers)
}
spdlog::trace("Commit transaction in stream '{}'", stream_name);
interpreter->CommitTransaction();
result.rows.clear();
};
ConsumerInfo consumer_info{

View File

@ -8,6 +8,7 @@
#include "integrations/kafka/consumer.hpp"
#include "kvstore/kvstore.hpp"
#include "query/typed_value.hpp"
#include "utils/exceptions.hpp"
#include "utils/rw_lock.hpp"
#include "utils/synchronized.hpp"
@ -19,8 +20,7 @@ class StreamsException : public utils::BasicException {
using BasicException::BasicException;
};
// TODO(antaljanosbenjamin) Replace this with mgp_trans related thing
using TransformationResult = std::map<std::string, std::string>;
using TransformationResult = std::vector<std::vector<TypedValue>>;
using TransformFunction = std::function<TransformationResult(const std::vector<integrations::kafka::Message> &)>;
struct StreamInfo {
@ -28,7 +28,6 @@ struct StreamInfo {
std::string consumer_group;
std::optional<std::chrono::milliseconds> batch_interval;
std::optional<int64_t> batch_size;
// TODO(antaljanosbenjamin) How to reference the transformation in a better way?
std::string transformation_name;
};
@ -41,9 +40,7 @@ struct StreamStatus {
using SynchronizedConsumer = utils::Synchronized<integrations::kafka::Consumer, utils::WritePrioritizedRWLock>;
struct StreamData {
// TODO(antaljanosbenjamin) How to reference the transformation in a better way?
std::string transformation_name;
// TODO(antaljanosbenjamin) consider propagate_const
std::unique_ptr<SynchronizedConsumer> consumer;
};
@ -120,8 +117,8 @@ class Streams final {
/// @param stream_name name of the stream we want to test
/// @param batch_limit number of batches we want to test before stopping
///
/// TODO(antaljanosbenjamin) add type of parameters
/// @returns A vector of pairs consisting of the query (std::string) and its parameters ...
/// @returns A vector of vectors of TypedValue. Each subvector contains two elements, the query string and the
/// nullable parameters map.
///
/// @throws StreamsException if the stream doesn't exist
/// @throws ConsumerRunningException if the consumer is already running

View File

@ -244,12 +244,11 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache
spdlog::info("Loading triggers...");
for (const auto &[trigger_name, trigger_data] : storage_) {
// structured binding cannot be captured in lambda
const auto get_failed_message = [](const std::string_view trigger_name, const std::string_view message) {
const auto get_failed_message = [&trigger_name = trigger_name](const std::string_view message) {
return fmt::format("Failed to load trigger '{}'. {}", trigger_name, message);
};
const auto invalid_state_message = get_failed_message(trigger_name, "Invalid state of the trigger data.");
const auto invalid_state_message = get_failed_message("Invalid state of the trigger data.");
spdlog::debug("Loading trigger '{}'", trigger_name);
auto json_trigger_data = nlohmann::json::parse(trigger_data);
@ -259,7 +258,7 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache
continue;
}
if (json_trigger_data["version"] != kVersion) {
spdlog::warn(get_failed_message(trigger_name, "Invalid version of the trigger data."));
spdlog::warn(get_failed_message("Invalid version of the trigger data."));
continue;
}

View File

@ -3207,8 +3207,16 @@ TEST_P(CypherMainVisitorTest, CreateSnapshotQuery) {
ASSERT_TRUE(dynamic_cast<CreateSnapshotQuery *>(ast_generator.ParseQuery("CREATE SNAPSHOT")));
}
// Verifies that an optional AST expression matches an optional expected value:
// the expression must be present exactly when `expected` holds a value, and in
// that case it must be the literal `*expected`.
void CheckOptionalExpression(Base &ast_generator, Expression *expression, const std::optional<TypedValue> &expected) {
  EXPECT_EQ(expression != nullptr, expected.has_value());
  if (!expected.has_value()) {
    // Nothing to compare the expression against.
    return;
  }
  EXPECT_NO_FATAL_FAILURE(ast_generator.CheckLiteral(expression, *expected));
}
void ValidateMostlyEmptyStreamQuery(Base &ast_generator, const std::string &query_string,
const StreamQuery::Action action, const std::string_view stream_name) {
const StreamQuery::Action action, const std::string_view stream_name,
const std::optional<TypedValue> &batch_limit = std::nullopt) {
auto *parsed_query = dynamic_cast<StreamQuery *>(ast_generator.ParseQuery(query_string));
ASSERT_NE(parsed_query, nullptr);
EXPECT_EQ(parsed_query->action_, action);
@ -3218,6 +3226,7 @@ void ValidateMostlyEmptyStreamQuery(Base &ast_generator, const std::string &quer
EXPECT_TRUE(parsed_query->consumer_group_.empty());
EXPECT_EQ(parsed_query->batch_interval_, nullptr);
EXPECT_EQ(parsed_query->batch_size_, nullptr);
EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_limit_, batch_limit));
}
TEST_P(CypherMainVisitorTest, DropStream) {
@ -3292,18 +3301,13 @@ void ValidateCreateStreamQuery(Base &ast_generator, const std::string &query_str
ASSERT_NO_THROW(parsed_query = dynamic_cast<StreamQuery *>(ast_generator.ParseQuery(query_string))) << query_string;
ASSERT_NE(parsed_query, nullptr);
EXPECT_EQ(parsed_query->stream_name_, stream_name);
auto check_expression = [&](Expression *expression, const std::optional<TypedValue> &expected) {
EXPECT_EQ(expression != nullptr, expected.has_value());
if (expected.has_value()) {
EXPECT_NO_FATAL_FAILURE(ast_generator.CheckLiteral(expression, *expected));
}
};
EXPECT_EQ(parsed_query->topic_names_, topic_names);
EXPECT_EQ(parsed_query->transform_name_, transform_name);
EXPECT_EQ(parsed_query->consumer_group_, consumer_group);
EXPECT_NO_FATAL_FAILURE(check_expression(parsed_query->batch_interval_, batch_interval));
EXPECT_NO_FATAL_FAILURE(check_expression(parsed_query->batch_size_, batch_size));
EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_interval_, batch_interval));
EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_size_, batch_size));
EXPECT_EQ(parsed_query->batch_limit_, nullptr);
}
TEST_P(CypherMainVisitorTest, CreateStream) {
@ -3380,4 +3384,22 @@ TEST_P(CypherMainVisitorTest, CreateStream) {
EXPECT_NO_FATAL_FAILURE(check_topic_names({topic_name1, topic_name2}));
}
// Covers parsing of the CHECK STREAM query: rejected spellings first, then the
// two accepted forms (without and with a batch limit).
TEST_P(CypherMainVisitorTest, CheckStream) {
auto &ast_generator = *GetParam();
// Queries that must be rejected: missing stream name, plural keyword,
// more than one stream name, misspelled or incomplete BATCH_LIMIT clause.
TestInvalidQuery("CHECK STREAM", ast_generator);
TestInvalidQuery("CHECK STREAMS", ast_generator);
TestInvalidQuery("CHECK STREAMS something", ast_generator);
TestInvalidQuery("CHECK STREAM something,something", ast_generator);
TestInvalidQuery("CHECK STREAM something BATCH LIMIT 1", ast_generator);
TestInvalidQuery("CHECK STREAM something BATCH_LIMIT", ast_generator);
// A batch limit that is a literal but not an integer is expected to fail with
// SemanticException specifically.
TestInvalidQuery<SemanticException>("CHECK STREAM something BATCH_LIMIT 'it should be an integer'", ast_generator);
TestInvalidQuery<SemanticException>("CHECK STREAM something BATCH_LIMIT 2.5", ast_generator);
// Valid form without a batch limit: the helper's default expects no batch limit.
ValidateMostlyEmptyStreamQuery(ast_generator, "CHECK STREAM checkedStream", StreamQuery::Action::CHECK_STREAM,
"checkedStream");
// Valid form with a batch limit; the mixed-case keyword is expected to parse too.
ValidateMostlyEmptyStreamQuery(ast_generator, "CHECK STREAM checkedStream bAtCH_LIMIT 42",
StreamQuery::Action::CHECK_STREAM, "checkedStream", TypedValue(42));
}
} // namespace

View File

@ -32,8 +32,7 @@ StreamInfo CreateDefaultStreamInfo() {
.consumer_group = "ConsumerGroup " + GetDefaultStreamName(),
.batch_interval = std::nullopt,
.batch_size = std::nullopt,
// TODO(antaljanosbenjamin) Add proper reference once Streams supports that
.transformation_name = "not yet used",
.transformation_name = "not used in the tests",
};
}
@ -79,8 +78,7 @@ class StreamsTest : public ::testing::Test {
EXPECT_EQ(check_data.info.consumer_group, status.info.consumer_group);
EXPECT_EQ(check_data.info.batch_interval, status.info.batch_interval);
EXPECT_EQ(check_data.info.batch_size, status.info.batch_size);
// TODO(antaljanosbenjamin) Add proper reference once Streams supports that
// EXPECT_EQ(check_data.info.transformation_name, status.info.transformation_name);
EXPECT_EQ(check_data.info.transformation_name, status.info.transformation_name);
EXPECT_EQ(check_data.is_running, status.is_running);
}
@ -227,5 +225,3 @@ TEST_F(StreamsTest, RestoreStreams) {
check_restore_logic();
}
}
// TODO(antaljanosbenjamin) Add tests for Streams::Test method and transformation