Add procedure for setting a kafka stream offset ()

Kostas Kyrimis 2021-11-11 12:07:58 +01:00 committed by GitHub
parent 636c551047
commit 47c0c629c7
14 changed files with 486 additions and 140 deletions


@ -1437,6 +1437,9 @@ enum mgp_error mgp_message_key_size(struct mgp_message *message, size_t *result)
/// Get the timestamp of the mgp_message.
enum mgp_error mgp_message_timestamp(struct mgp_message *message, int64_t *result);
/// Get the message offset from a message.
enum mgp_error mgp_message_offset(struct mgp_message *message, int64_t *result);
/// Get the number of messages contained in the mgp_messages list
/// Current implementation always returns without errors.
enum mgp_error mgp_messages_size(struct mgp_messages *message, size_t *result);


@ -1300,6 +1300,11 @@ class Message:
raise InvalidMessageError()
return self._message.timestamp()
def offset(self) -> int:
if not self.is_valid():
raise InvalidMessageError()
return self._message.offset()
class InvalidMessagesError(Exception):
"""Signals using a messages instance outside of the registered transformation."""


@ -78,7 +78,7 @@ utils::BasicResult<std::string, std::vector<Message>> GetBatch(RdKafka::KafkaCon
start = now;
}
return {std::move(batch)};
return std::move(batch);
}
} // namespace
@ -109,8 +109,13 @@ int64_t Message::Timestamp() const {
return rd_kafka_message_timestamp(c_message, nullptr);
}
int64_t Message::Offset() const {
const auto *c_message = message_->c_ptr();
return c_message->offset;
}
Consumer::Consumer(const std::string &bootstrap_servers, ConsumerInfo info, ConsumerFunction consumer_function)
: info_{std::move(info)}, consumer_function_(std::move(consumer_function)) {
: info_{std::move(info)}, consumer_function_(std::move(consumer_function)), cb_(info_.consumer_name) {
MG_ASSERT(consumer_function_, "Empty consumer function for Kafka consumer");
// NOLINTNEXTLINE (modernize-use-nullptr)
if (info.batch_interval.value_or(kMinimumInterval) < kMinimumInterval) {
@ -131,6 +136,10 @@ Consumer::Consumer(const std::string &bootstrap_servers, ConsumerInfo info, Cons
throw ConsumerFailedToInitializeException(info_.consumer_name, error);
}
if (conf->set("rebalance_cb", &cb_, error) != RdKafka::Conf::CONF_OK) {
throw ConsumerFailedToInitializeException(info_.consumer_name, error);
}
if (conf->set("enable.partition.eof", "false", error) != RdKafka::Conf::CONF_OK) {
throw ConsumerFailedToInitializeException(info_.consumer_name, error);
}
@ -339,7 +348,18 @@ void Consumer::StartConsuming() {
try {
consumer_function_(batch);
if (const auto err = consumer_->commitSync(); err != RdKafka::ERR_NO_ERROR) {
std::vector<RdKafka::TopicPartition *> partitions;
utils::OnScopeExit clear_partitions([&]() { RdKafka::TopicPartition::destroy(partitions); });
if (const auto err = consumer_->assignment(partitions); err != RdKafka::ERR_NO_ERROR) {
throw ConsumerCheckFailedException(
info_.consumer_name, fmt::format("Couldn't get assignment to commit offsets: {}", RdKafka::err2str(err)));
}
if (const auto err = consumer_->position(partitions); err != RdKafka::ERR_NO_ERROR) {
throw ConsumerCheckFailedException(
info_.consumer_name, fmt::format("Couldn't get offsets from librdkafka {}", RdKafka::err2str(err)));
}
if (const auto err = consumer_->commitSync(partitions); err != RdKafka::ERR_NO_ERROR) {
spdlog::warn("Committing offset of consumer {} failed: {}", info_.consumer_name, RdKafka::err2str(err));
break;
}
@ -358,4 +378,51 @@ void Consumer::StopConsuming() {
if (thread_.joinable()) thread_.join();
}
utils::BasicResult<std::string> Consumer::SetConsumerOffsets(int64_t offset) {
if (is_running_) {
throw ConsumerRunningException(info_.consumer_name);
}
if (offset == -1) {
offset = RD_KAFKA_OFFSET_BEGINNING;
} else if (offset == -2) {
offset = RD_KAFKA_OFFSET_END;
}
cb_.set_offset(offset);
if (const auto err = consumer_->subscribe(info_.topics); err != RdKafka::ERR_NO_ERROR) {
return fmt::format("Could not set offset of consumer: {}. Error: {}", info_.consumer_name, RdKafka::err2str(err));
}
return {};
}
Consumer::ConsumerRebalanceCb::ConsumerRebalanceCb(std::string consumer_name)
: consumer_name_(std::move(consumer_name)) {}
void Consumer::ConsumerRebalanceCb::rebalance_cb(RdKafka::KafkaConsumer *consumer, RdKafka::ErrorCode err,
std::vector<RdKafka::TopicPartition *> &partitions) {
if (err == RdKafka::ERR__REVOKE_PARTITIONS) {
consumer->unassign();
return;
}
if (err != RdKafka::ERR__ASSIGN_PARTITIONS) {
spdlog::critical("Consumer {} received an unexpected error {}", consumer_name_, RdKafka::err2str(err));
return;
}
if (offset_) {
for (auto &partition : partitions) {
partition->set_offset(*offset_);
}
offset_.reset();
}
auto maybe_error = consumer->assign(partitions);
if (maybe_error != RdKafka::ErrorCode::ERR_NO_ERROR) {
spdlog::warn("Assigning offset of consumer {} failed: {}", consumer_name_, RdKafka::err2str(err));
}
maybe_error = consumer->commitSync(partitions);
if (maybe_error != RdKafka::ErrorCode::ERR_NO_ERROR) {
spdlog::warn("Commiting offsets of consumer {} failed: {}", consumer_name_, RdKafka::err2str(err));
}
}
void Consumer::ConsumerRebalanceCb::set_offset(int64_t offset) { offset_ = offset; }
} // namespace integrations::kafka
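
Taken together, Consumer::SetConsumerOffsets and the new ConsumerRebalanceCb work as a pair: the public call translates -1/-2 into RD_KAFKA_OFFSET_BEGINNING/RD_KAFKA_OFFSET_END, refuses to run while the consumer is consuming, stores the value in the callback, and re-subscribes; the callback then applies the stored value to every partition on the next assignment. For readers who know the Python client better, roughly the same pattern with confluent-kafka looks like the sketch below (an illustration only, not what the commit uses):

from confluent_kafka import Consumer, OFFSET_BEGINNING

# Pending seek position; in the commit this lives in ConsumerRebalanceCb::offset_.
pending_offset = OFFSET_BEGINNING

consumer = Consumer({
    "bootstrap.servers": "localhost:9092",
    "group.id": "mg_consumer",
    "enable.auto.commit": False,
})


def on_assign(consumer, partitions):
    # Apply the pending offset to every newly assigned partition, then make
    # the assignment effective, which is what rebalance_cb does above
    # (the C++ callback additionally commits the assignment synchronously).
    global pending_offset
    if pending_offset is not None:
        for partition in partitions:
            partition.offset = pending_offset
        pending_offset = None
    consumer.assign(partitions)


consumer.subscribe(["my_topic"], on_assign=on_assign)
# Polling drives the group rebalance, so the callback fires on the next assignment.
consumer.poll(timeout=5.0)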


@ -17,6 +17,7 @@
#include <memory>
#include <optional>
#include <span>
#include <string>
#include <thread>
#include <utility>
#include <vector>
@ -68,6 +69,9 @@ class Message final {
/// can be implemented knowing that.
int64_t Timestamp() const;
/// Returns the offset of the message
int64_t Offset() const;
private:
std::unique_ptr<RdKafka::Message> message_;
};
@ -137,6 +141,13 @@ class Consumer final : public RdKafka::EventCb {
/// Returns true if the consumer is actively consuming messages.
bool IsRunning() const;
/// Sets the consumer's offset.
///
/// This function returns the empty string on success or an error message otherwise.
///
/// @param offset: the offset to set.
[[nodiscard]] utils::BasicResult<std::string> SetConsumerOffsets(int64_t offset);
const ConsumerInfo &Info() const;
private:
@ -146,6 +157,20 @@ class Consumer final : public RdKafka::EventCb {
void StopConsuming();
class ConsumerRebalanceCb : public RdKafka::RebalanceCb {
public:
ConsumerRebalanceCb(std::string consumer_name);
void rebalance_cb(RdKafka::KafkaConsumer *consumer, RdKafka::ErrorCode err,
std::vector<RdKafka::TopicPartition *> &partitions) override final;
void set_offset(int64_t offset);
private:
std::optional<int64_t> offset_;
std::string consumer_name_;
};
ConsumerInfo info_;
ConsumerFunction consumer_function_;
mutable std::atomic<bool> is_running_{false};
@ -153,5 +178,6 @@ class Consumer final : public RdKafka::EventCb {
std::optional<int64_t> limit_batches_{std::nullopt};
std::unique_ptr<RdKafka::KafkaConsumer, std::function<void(RdKafka::KafkaConsumer *)>> consumer_;
std::thread thread_;
ConsumerRebalanceCb cb_;
};
} // namespace integrations::kafka


@ -2516,6 +2516,10 @@ mgp_error mgp_message_timestamp(mgp_message *message, int64_t *result) {
return WrapExceptions([message] { return message->msg->Timestamp(); }, result);
}
mgp_error mgp_message_offset(struct mgp_message *message, int64_t *result) {
return WrapExceptions([message] { return message->msg->Offset(); }, result);
}
mgp_error mgp_messages_size(mgp_messages *messages, size_t *result) {
static_assert(noexcept(messages->messages.size()));
*result = messages->messages.size();


@ -676,6 +676,17 @@ struct mgp_proc {
results(memory),
is_write_procedure(is_write_procedure) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_proc(const std::string_view name, std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb,
utils::MemoryResource *memory, bool is_write_procedure)
: name(name, memory),
cb(cb),
args(memory),
opt_args(memory),
results(memory),
is_write_procedure(is_write_procedure) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_proc(const mgp_proc &other, utils::MemoryResource *memory)


@ -644,6 +644,16 @@ void ModuleRegistry::UnloadAllModules() {
utils::MemoryResource &ModuleRegistry::GetSharedMemoryResource() noexcept { return *shared_; }
bool ModuleRegistry::RegisterMgProcedure(const std::string_view name, mgp_proc proc) {
std::unique_lock<utils::RWLock> guard(lock_);
if (auto module = modules_.find("mg"); module != modules_.end()) {
auto *builtin_module = dynamic_cast<BuiltinModule *>(module->second.get());
builtin_module->AddProcedure(name, std::move(proc));
return true;
}
return false;
}
namespace {
/// This function returns a pair of either


@ -117,6 +117,8 @@ class ModuleRegistry final {
/// Returns the shared memory allocator used by modules
utils::MemoryResource &GetSharedMemoryResource() noexcept;
bool RegisterMgProcedure(std::string_view name, mgp_proc proc);
private:
std::vector<std::filesystem::path> modules_dirs_;
};
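
ModuleRegistry::RegisterMgProcedure lets other subsystems append procedures to the built-in "mg" module at startup instead of shipping them in a query module file; the streams constructor below uses it to register kafka_set_stream_offset. A quick client-side check, assuming the built-in mg.procedures() listing is available (that listing is not part of this commit):

import mgclient

# Sketch: list procedures and check that the one registered by the streams
# constructor (see streams.cpp below) is visible under the "mg" module.
# Assumption: mg.procedures() reports fully qualified names.
connection = mgclient.connect(host="127.0.0.1", port=7687)
cursor = connection.cursor()
cursor.execute("CALL mg.procedures() YIELD name RETURN name")
names = [row[0] for row in cursor.fetchall()]
assert "mg.kafka_set_stream_offset" in names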


@ -622,6 +622,21 @@ PyObject *PyMessageGetTimestamp(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
return py_int;
}
PyObject *PyMessageGetOffset(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
MG_ASSERT(self->message);
MG_ASSERT(self->memory);
int64_t offset{0};
if (RaiseExceptionFromErrorCode(mgp_message_offset(self->message, &offset))) {
return nullptr;
}
auto *py_int = PyLong_FromLongLong(offset);
if (!py_int) {
PyErr_SetString(PyExc_IndexError, "Unable to get offset");
return nullptr;
}
return py_int;
}
// NOLINTNEXTLINE
static PyMethodDef PyMessageMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
@ -631,6 +646,7 @@ static PyMethodDef PyMessageMethods[] = {
{"topic_name", reinterpret_cast<PyCFunction>(PyMessageGetTopicName), METH_NOARGS, "Get topic name."},
{"key", reinterpret_cast<PyCFunction>(PyMessageGetKey), METH_NOARGS, "Get message key."},
{"timestamp", reinterpret_cast<PyCFunction>(PyMessageGetTimestamp), METH_NOARGS, "Get message timestamp."},
{"offset", reinterpret_cast<PyCFunction>(PyMessageGetOffset), METH_NOARGS, "Get message offset."},
{nullptr},
};


@ -20,6 +20,7 @@
#include "query/db_accessor.hpp"
#include "query/discard_value_stream.hpp"
#include "query/interpreter.hpp"
#include "query/procedure//mg_procedure_helpers.hpp"
#include "query/procedure/mg_procedure_impl.hpp"
#include "query/procedure/module.hpp"
#include "query/typed_value.hpp"
@ -185,7 +186,27 @@ Streams::Streams(InterpreterContext *interpreter_context, std::string bootstrap_
std::filesystem::path directory)
: interpreter_context_(interpreter_context),
bootstrap_servers_(std::move(bootstrap_servers)),
storage_(std::move(directory)) {}
storage_(std::move(directory)) {
constexpr std::string_view proc_name = "kafka_set_stream_offset";
auto set_stream_offset = [ictx = interpreter_context, proc_name](mgp_list *args, mgp_graph * /*graph*/,
mgp_result *result, mgp_memory * /*memory*/) {
auto *arg_stream_name = procedure::Call<mgp_value *>(mgp_list_at, args, 0);
const auto *stream_name = procedure::Call<const char *>(mgp_value_get_string, arg_stream_name);
auto *arg_offset = procedure::Call<mgp_value *>(mgp_list_at, args, 1);
const auto offset = procedure::Call<int64_t>(mgp_value_get_int, arg_offset);
const auto error = ictx->streams.SetStreamOffset(stream_name, offset);
if (error.HasError()) {
MG_ASSERT(mgp_result_set_error_msg(result, error.GetError().c_str()) == MGP_ERROR_NO_ERROR,
"Unable to set procedure error message of procedure: {}", proc_name);
}
};
mgp_proc proc(proc_name, set_stream_offset, utils::NewDeleteResource(), false);
MG_ASSERT(mgp_proc_add_arg(&proc, "stream_name", procedure::Call<mgp_type *>(mgp_type_string)) == MGP_ERROR_NO_ERROR);
MG_ASSERT(mgp_proc_add_arg(&proc, "offset", procedure::Call<mgp_type *>(mgp_type_int)) == MGP_ERROR_NO_ERROR);
procedure::gModuleRegistry.RegisterMgProcedure(proc_name, std::move(proc));
}
void Streams::RestoreStreams() {
spdlog::info("Loading streams...");
@ -434,4 +455,12 @@ void Streams::Persist(StreamStatus &&status) {
}
std::string_view Streams::BootstrapServers() const { return bootstrap_servers_; }
utils::BasicResult<std::string> Streams::SetStreamOffset(const std::string_view stream_name, int64_t offset) {
auto lock_ptr = streams_.Lock();
auto it = GetStream(*lock_ptr, std::string(stream_name));
auto consumer_lock_ptr = it->second.consumer->Lock();
return consumer_lock_ptr->SetConsumerOffsets(offset);
}
} // namespace query
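
Streams::Streams now registers mg.kafka_set_stream_offset when the streams subsystem is constructed, and Streams::SetStreamOffset locks the named stream and forwards the value to its consumer. A usage sketch with pymgclient follows (assumptions: a Memgraph instance on localhost:7687 and an existing, stopped stream named test; setting the offset of a running stream is rejected with ConsumerRunningException):

import mgclient

connection = mgclient.connect(host="127.0.0.1", port=7687)
connection.autocommit = True
cursor = connection.cursor()

# -1 maps to RD_KAFKA_OFFSET_BEGINNING, -2 to RD_KAFKA_OFFSET_END, and any
# other value is taken as an absolute offset (see Consumer::SetConsumerOffsets above).
cursor.execute("CALL mg.kafka_set_stream_offset('test', -1)")

# The seeded offset is applied on the next partition assignment, i.e. once the
# stream is started again.
cursor.execute("START STREAM test")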


@ -143,6 +143,12 @@ class Streams final {
/// Return the configuration value passed to memgraph.
std::string_view BootstrapServers() const;
/// Sets the stream's consumer offset.
///
/// @param stream_name name of the stream whose consumer offset should be set.
/// @param offset the offset to set.
[[nodiscard]] utils::BasicResult<std::string> SetStreamOffset(std::string_view stream_name, int64_t offset);
private:
using StreamsMap = std::unordered_map<std::string, StreamData>;
using SynchronizedStreamsMap = utils::Synchronized<StreamsMap, utils::WritePrioritizedRWLock>;


@ -23,20 +23,21 @@ import common
QUERY = 0
PARAMS = 1
TRANSFORMATIONS_TO_CHECK = [
"transform.simple", "transform.with_parameters"]
TRANSFORMATIONS_TO_CHECK = ["transform.simple", "transform.with_parameters"]
SIMPLE_MSG = b'message'
SIMPLE_MSG = b"message"
@pytest.mark.parametrize("transformation", TRANSFORMATIONS_TO_CHECK)
def test_simple(producer, topics, connection, transformation):
assert len(topics) > 0
cursor = connection.cursor()
common.execute_and_fetch_all(cursor,
"CREATE STREAM test "
f"TOPICS {','.join(topics)} "
f"TRANSFORM {transformation}")
common.execute_and_fetch_all(
cursor,
"CREATE STREAM test "
f"TOPICS {','.join(topics)} "
f"TRANSFORM {transformation}",
)
common.start_stream(cursor, "test")
time.sleep(5)
@ -45,7 +46,8 @@ def test_simple(producer, topics, connection, transformation):
for topic in topics:
common.check_vertex_exists_with_topic_and_payload(
cursor, topic, SIMPLE_MSG)
cursor, topic, SIMPLE_MSG
)
@pytest.mark.parametrize("transformation", TRANSFORMATIONS_TO_CHECK)
@ -57,10 +59,12 @@ def test_separate_consumers(producer, topics, connection, transformation):
for topic in topics:
stream_name = "stream_" + topic
stream_names.append(stream_name)
common.execute_and_fetch_all(cursor,
f"CREATE STREAM {stream_name} "
f"TOPICS {topic} "
f"TRANSFORM {transformation}")
common.execute_and_fetch_all(
cursor,
f"CREATE STREAM {stream_name} "
f"TOPICS {topic} "
f"TRANSFORM {transformation}",
)
for stream_name in stream_names:
common.start_stream(cursor, stream_name)
@ -72,7 +76,8 @@ def test_separate_consumers(producer, topics, connection, transformation):
for topic in topics:
common.check_vertex_exists_with_topic_and_payload(
cursor, topic, SIMPLE_MSG)
cursor, topic, SIMPLE_MSG
)
def test_start_from_last_committed_offset(producer, topics, connection):
@ -84,17 +89,20 @@ def test_start_from_last_committed_offset(producer, topics, connection):
# restarting Memgraph during a single workload cannot be done currently.
assert len(topics) > 0
cursor = connection.cursor()
common.execute_and_fetch_all(cursor,
"CREATE STREAM test "
f"TOPICS {topics[0]} "
"TRANSFORM transform.simple")
common.execute_and_fetch_all(
cursor,
"CREATE STREAM test "
f"TOPICS {topics[0]} "
"TRANSFORM transform.simple",
)
common.start_stream(cursor, "test")
time.sleep(1)
producer.send(topics[0], SIMPLE_MSG).get(timeout=60)
common.check_vertex_exists_with_topic_and_payload(
cursor, topics[0], SIMPLE_MSG)
cursor, topics[0], SIMPLE_MSG
)
common.stop_stream(cursor, "test")
common.drop_stream(cursor, "test")
@ -108,30 +116,36 @@ def test_start_from_last_committed_offset(producer, topics, connection):
cursor,
"MATCH (n: MESSAGE {"
f"payload: '{message.decode('utf-8')}'"
"}) RETURN n")
"}) RETURN n",
)
assert len(vertices_with_msg) == 0
common.execute_and_fetch_all(cursor,
"CREATE STREAM test "
f"TOPICS {topics[0]} "
"TRANSFORM transform.simple")
common.execute_and_fetch_all(
cursor,
"CREATE STREAM test "
f"TOPICS {topics[0]} "
"TRANSFORM transform.simple",
)
common.start_stream(cursor, "test")
for message in messages:
common.check_vertex_exists_with_topic_and_payload(
cursor, topics[0], message)
cursor, topics[0], message
)
@pytest.mark.parametrize("transformation", TRANSFORMATIONS_TO_CHECK)
def test_check_stream(producer, topics, connection, transformation):
assert len(topics) > 0
cursor = connection.cursor()
common.execute_and_fetch_all(cursor,
"CREATE STREAM test "
f"TOPICS {topics[0]} "
f"TRANSFORM {transformation} "
"BATCH_SIZE 1")
common.execute_and_fetch_all(
cursor,
"CREATE STREAM test "
f"TOPICS {topics[0]} "
f"TRANSFORM {transformation} "
"BATCH_SIZE 1",
)
common.start_stream(cursor, "test")
time.sleep(1)
@ -143,23 +157,28 @@ def test_check_stream(producer, topics, connection, transformation):
producer.send(topics[0], message).get(timeout=60)
def check_check_stream(batch_limit):
assert transformation == "transform.simple" \
assert (
transformation == "transform.simple"
or transformation == "transform.with_parameters"
)
test_results = common.execute_and_fetch_all(
cursor, f"CHECK STREAM test BATCH_LIMIT {batch_limit}")
cursor, f"CHECK STREAM test BATCH_LIMIT {batch_limit}"
)
assert len(test_results) == batch_limit
for i in range(batch_limit):
message_as_str = messages[i].decode('utf-8')
message_as_str = messages[i].decode("utf-8")
if transformation == "transform.simple":
assert f"payload: '{message_as_str}'" in \
test_results[i][QUERY]
assert f"payload: '{message_as_str}'" in test_results[i][QUERY]
assert test_results[i][PARAMS] is None
else:
assert test_results[i][QUERY] == ("CREATE (n:MESSAGE "
"{timestamp: $timestamp, "
"payload: $payload, "
"topic: $topic})")
assert test_results[i][QUERY] == (
"CREATE (n:MESSAGE "
"{timestamp: $timestamp, "
"payload: $payload, "
"topic: $topic, "
"offset: $offset})"
)
parameters = test_results[i][PARAMS]
# this is not a very sophisticated test, but checks if
# timestamp has some kind of value
@ -174,45 +193,67 @@ def test_check_stream(producer, topics, connection, transformation):
for message in messages:
common.check_vertex_exists_with_topic_and_payload(
cursor, topics[0], message)
cursor, topics[0], message
)
def test_show_streams(producer, topics, connection):
assert len(topics) > 1
cursor = connection.cursor()
common.execute_and_fetch_all(cursor,
"CREATE STREAM default_values "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple "
f"BOOTSTRAP_SERVERS \'localhost:9092\'")
common.execute_and_fetch_all(
cursor,
"CREATE STREAM default_values "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple "
f"BOOTSTRAP_SERVERS 'localhost:9092'",
)
consumer_group = "my_special_consumer_group"
batch_interval = 42
batch_size = 3
common.execute_and_fetch_all(cursor,
"CREATE STREAM complex_values "
f"TOPICS {','.join(topics)} "
f"TRANSFORM transform.with_parameters "
f"CONSUMER_GROUP {consumer_group} "
f"BATCH_INTERVAL {batch_interval} "
f"BATCH_SIZE {batch_size} ")
common.execute_and_fetch_all(
cursor,
"CREATE STREAM complex_values "
f"TOPICS {','.join(topics)} "
f"TRANSFORM transform.with_parameters "
f"CONSUMER_GROUP {consumer_group} "
f"BATCH_INTERVAL {batch_interval} "
f"BATCH_SIZE {batch_size} ",
)
assert len(common.execute_and_fetch_all(cursor, "SHOW STREAMS")) == 2
common.check_stream_info(cursor, "default_values", ("default_values", [
topics[0]], "mg_consumer", None, None,
"transform.simple", None, "localhost:9092", False))
common.check_stream_info(
cursor,
"default_values",
(
"default_values",
[topics[0]],
"mg_consumer",
None,
None,
"transform.simple",
None,
"localhost:9092",
False,
),
)
common.check_stream_info(cursor, "complex_values", (
common.check_stream_info(
cursor,
"complex_values",
topics,
consumer_group,
batch_interval,
batch_size,
"transform.with_parameters",
None,
"localhost:9092",
False))
(
"complex_values",
topics,
consumer_group,
batch_interval,
batch_size,
"transform.with_parameters",
None,
"localhost:9092",
False,
),
)
@pytest.mark.parametrize("operation", ["START", "STOP"])
@ -229,14 +270,16 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
assert len(topics) > 1
assert operation == "START" or operation == "STOP"
cursor = connection.cursor()
common.execute_and_fetch_all(cursor,
"CREATE STREAM test_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
common.execute_and_fetch_all(
cursor,
"CREATE STREAM test_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple",
)
check_counter = Value('i', 0)
check_result_len = Value('i', 0)
operation_counter = Value('i', 0)
check_counter = Value("i", 0)
check_result_len = Value("i", 0)
operation_counter = Value("i", 0)
CHECK_BEFORE_EXECUTE = 1
CHECK_AFTER_FETCHALL = 2
@ -250,7 +293,8 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
cursor = connection.cursor()
counter.value = CHECK_BEFORE_EXECUTE
result = common.execute_and_fetch_all(
cursor, "CHECK STREAM test_stream")
cursor, "CHECK STREAM test_stream"
)
result_len.value = len(result)
counter.value = CHECK_AFTER_FETCHALL
if len(result) > 0 and "payload: 'message'" in result[0][QUERY]:
@ -272,7 +316,8 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
counter.value = OP_BEFORE_EXECUTE
try:
common.execute_and_fetch_all(
cursor, f"{operation} STREAM test_stream")
cursor, f"{operation} STREAM test_stream"
)
counter.value = OP_AFTER_FETCHALL
except mgclient.DatabaseError as e:
if "Kafka consumer test_stream is already stopped" in str(e):
@ -283,9 +328,11 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
counter.value = OP_UNEXPECTED_EXCEPTION
check_stream_proc = Process(
target=call_check, daemon=True, args=(check_counter, check_result_len))
operation_proc = Process(target=call_operation,
daemon=True, args=(operation_counter,))
target=call_check, daemon=True, args=(check_counter, check_result_len)
)
operation_proc = Process(
target=call_operation, daemon=True, args=(operation_counter,)
)
try:
check_stream_proc.start()
@ -293,18 +340,23 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
time.sleep(0.5)
assert common.timed_wait(
lambda: check_counter.value == CHECK_BEFORE_EXECUTE)
lambda: check_counter.value == CHECK_BEFORE_EXECUTE
)
assert common.timed_wait(
lambda: common.get_is_running(cursor, "test_stream"))
assert check_counter.value == CHECK_BEFORE_EXECUTE, "SHOW STREAMS " \
"was blocked until the end of CHECK STREAM"
lambda: common.get_is_running(cursor, "test_stream")
)
assert check_counter.value == CHECK_BEFORE_EXECUTE, (
"SHOW STREAMS " "was blocked until the end of CHECK STREAM"
)
operation_proc.start()
assert common.timed_wait(
lambda: operation_counter.value == OP_BEFORE_EXECUTE)
lambda: operation_counter.value == OP_BEFORE_EXECUTE
)
producer.send(topics[0], SIMPLE_MSG).get(timeout=60)
assert common.timed_wait(
lambda: check_counter.value > CHECK_AFTER_FETCHALL)
lambda: check_counter.value > CHECK_AFTER_FETCHALL
)
assert check_counter.value == CHECK_CORRECT_RESULT
assert check_result_len.value == 1
check_stream_proc.join()
@ -330,10 +382,12 @@ def test_check_already_started_stream(topics, connection):
assert len(topics) > 0
cursor = connection.cursor()
common.execute_and_fetch_all(cursor,
"CREATE STREAM started_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
common.execute_and_fetch_all(
cursor,
"CREATE STREAM started_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple",
)
common.start_stream(cursor, "started_stream")
with pytest.raises(mgclient.DatabaseError):
@ -342,52 +396,61 @@ def test_check_already_started_stream(topics, connection):
def test_start_checked_stream_after_timeout(topics, connection):
cursor = connection.cursor()
common.execute_and_fetch_all(cursor,
"CREATE STREAM test_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
common.execute_and_fetch_all(
cursor,
"CREATE STREAM test_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple",
)
timeout_ms = 2000
def call_check():
common.execute_and_fetch_all(
common.connect().cursor(),
f"CHECK STREAM test_stream TIMEOUT {timeout_ms}")
f"CHECK STREAM test_stream TIMEOUT {timeout_ms}",
)
check_stream_proc = Process(target=call_check, daemon=True)
start = time.time()
check_stream_proc.start()
assert common.timed_wait(
lambda: common.get_is_running(cursor, "test_stream"))
lambda: common.get_is_running(cursor, "test_stream")
)
common.start_stream(cursor, "test_stream")
end = time.time()
assert (end - start) < 1.3 * \
timeout_ms, "The START STREAM was blocked too long"
assert (
end - start
) < 1.3 * timeout_ms, "The START STREAM was blocked too long"
assert common.get_is_running(cursor, "test_stream")
common.stop_stream(cursor, "test_stream")
def test_restart_after_error(producer, topics, connection):
cursor = connection.cursor()
common.execute_and_fetch_all(cursor,
"CREATE STREAM test_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.query")
common.execute_and_fetch_all(
cursor,
"CREATE STREAM test_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.query",
)
common.start_stream(cursor, "test_stream")
time.sleep(1)
producer.send(topics[0], SIMPLE_MSG).get(timeout=60)
assert common.timed_wait(
lambda: not common.get_is_running(cursor, "test_stream"))
lambda: not common.get_is_running(cursor, "test_stream")
)
common.start_stream(cursor, "test_stream")
time.sleep(1)
producer.send(topics[0], b'CREATE (n:VERTEX { id : 42 })')
producer.send(topics[0], b"CREATE (n:VERTEX { id : 42 })")
assert common.check_one_result_row(
cursor, "MATCH (n:VERTEX { id : 42 }) RETURN n")
cursor, "MATCH (n:VERTEX { id : 42 }) RETURN n"
)
@pytest.mark.parametrize("transformation", TRANSFORMATIONS_TO_CHECK)
@ -395,11 +458,13 @@ def test_bootstrap_server(producer, topics, connection, transformation):
assert len(topics) > 0
cursor = connection.cursor()
local = "localhost:9092"
common.execute_and_fetch_all(cursor,
"CREATE STREAM test "
f"TOPICS {','.join(topics)} "
f"TRANSFORM {transformation} "
f"BOOTSTRAP_SERVERS \'{local}\'")
common.execute_and_fetch_all(
cursor,
"CREATE STREAM test "
f"TOPICS {','.join(topics)} "
f"TRANSFORM {transformation} "
f"BOOTSTRAP_SERVERS '{local}'",
)
common.start_stream(cursor, "test")
time.sleep(5)
@ -408,7 +473,8 @@ def test_bootstrap_server(producer, topics, connection, transformation):
for topic in topics:
common.check_vertex_exists_with_topic_and_payload(
cursor, topic, SIMPLE_MSG)
cursor, topic, SIMPLE_MSG
)
@pytest.mark.parametrize("transformation", TRANSFORMATIONS_TO_CHECK)
@ -416,11 +482,83 @@ def test_bootstrap_server_empty(producer, topics, connection, transformation):
assert len(topics) > 0
cursor = connection.cursor()
with pytest.raises(mgclient.DatabaseError):
common.execute_and_fetch_all(cursor,
"CREATE STREAM test "
f"TOPICS {','.join(topics)} "
f"TRANSFORM {transformation} "
"BOOTSTRAP_SERVERS ''")
common.execute_and_fetch_all(
cursor,
"CREATE STREAM test "
f"TOPICS {','.join(topics)} "
f"TRANSFORM {transformation} "
"BOOTSTRAP_SERVERS ''",
)
@pytest.mark.parametrize("transformation", TRANSFORMATIONS_TO_CHECK)
def test_set_offset(producer, topics, connection, transformation):
assert len(topics) > 0
cursor = connection.cursor()
common.execute_and_fetch_all(
cursor,
"CREATE STREAM test "
f"TOPICS {topics[0]} "
f"TRANSFORM {transformation} "
"BATCH_SIZE 1",
)
messages = [f"{i} message" for i in range(1, 21)]
for message in messages:
producer.send(topics[0], message.encode()).get(timeout=60)
def consume(expected_msgs):
common.start_stream(cursor, "test")
if len(expected_msgs) == 0:
time.sleep(2)
else:
assert common.check_one_result_row(
cursor,
(
f"MATCH (n: MESSAGE {{payload: '{expected_msgs[-1]}'}})"
"RETURN n"
),
)
common.stop_stream(cursor, "test")
res = common.execute_and_fetch_all(
cursor, "MATCH (n) RETURN n.payload"
)
return res
def execute_set_offset_and_consume(id, expected_msgs):
common.execute_and_fetch_all(
cursor, f"CALL mg.kafka_set_stream_offset('test', {id})"
)
return consume(expected_msgs)
with pytest.raises(mgclient.DatabaseError):
res = common.execute_and_fetch_all(
cursor, "CALL mg.kafka_set_stream_offset('foo', 10)"
)
def comparison_check(a, b):
return a == str(b).strip("'(,)")
res = execute_set_offset_and_consume(10, messages[10:])
assert len(res) == 10
assert all([comparison_check(a, b) for a, b in zip(messages[10:], res)])
common.execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
res = execute_set_offset_and_consume(-1, messages)
assert len(res) == len(messages)
assert all([comparison_check(a, b) for a, b in zip(messages, res)])
res = common.execute_and_fetch_all(cursor, "MATCH (n) return n.offset")
assert all([comparison_check(str(i), res[i]) for i in range(1, 20)])
res = common.execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
res = execute_set_offset_and_consume(-2, [])
assert len(res) == 0
last_msg = "Final Message"
producer.send(topics[0], last_msg.encode()).get(timeout=60)
res = consume([last_msg])
assert len(res) == 1
assert comparison_check("Final Message", res[0])
common.execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
if __name__ == "__main__":


@ -13,50 +13,74 @@ import mgp
@mgp.transformation
def simple(context: mgp.TransCtx,
messages: mgp.Messages
) -> mgp.Record(query=str, parameters=mgp.Map):
def simple(
context: mgp.TransCtx, messages: mgp.Messages
) -> mgp.Record(query=str, parameters=mgp.Map):
result_queries = []
for i in range(0, messages.total_messages()):
message = messages.message_at(i)
payload_as_str = message.payload().decode("utf-8")
result_queries.append(mgp.Record(
query=f"CREATE (n:MESSAGE {{timestamp: '{message.timestamp()}', payload: '{payload_as_str}', topic: '{message.topic_name()}'}})",
parameters=None))
offset = message.offset()
result_queries.append(
mgp.Record(
query=(
f"CREATE (n:MESSAGE {{timestamp: '{message.timestamp()}', "
f"payload: '{payload_as_str}', "
f"topic: '{message.topic_name()}', "
f"offset: {offset}}})"
),
parameters=None,
)
)
return result_queries
@mgp.transformation
def with_parameters(context: mgp.TransCtx,
messages: mgp.Messages
) -> mgp.Record(query=str, parameters=mgp.Map):
def with_parameters(
context: mgp.TransCtx, messages: mgp.Messages
) -> mgp.Record(query=str, parameters=mgp.Map):
result_queries = []
for i in range(0, messages.total_messages()):
message = messages.message_at(i)
payload_as_str = message.payload().decode("utf-8")
result_queries.append(mgp.Record(
query="CREATE (n:MESSAGE {timestamp: $timestamp, payload: $payload, topic: $topic})",
parameters={"timestamp": message.timestamp(),
"payload": payload_as_str,
"topic": message.topic_name()}))
offset = message.offset()
result_queries.append(
mgp.Record(
query=(
"CREATE (n:MESSAGE "
"{timestamp: $timestamp, "
"payload: $payload, "
"topic: $topic, "
"offset: $offset})"
),
parameters={
"timestamp": message.timestamp(),
"payload": payload_as_str,
"topic": message.topic_name(),
"offset": offset,
},
)
)
return result_queries
@mgp.transformation
def query(messages: mgp.Messages
) -> mgp.Record(query=str, parameters=mgp.Nullable[mgp.Map]):
def query(
messages: mgp.Messages,
) -> mgp.Record(query=str, parameters=mgp.Nullable[mgp.Map]):
result_queries = []
for i in range(0, messages.total_messages()):
message = messages.message_at(i)
payload_as_str = message.payload().decode("utf-8")
result_queries.append(mgp.Record(
query=payload_as_str, parameters=None))
result_queries.append(
mgp.Record(query=payload_as_str, parameters=None)
)
return result_queries
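
Since both transformations now store the offset on the created vertices, it can be read back from the graph, for instance to monitor progress or to pick a value to pass to mg.kafka_set_stream_offset. A small sketch, assuming the with_parameters transformation above has been running:

import mgclient

connection = mgclient.connect(host="127.0.0.1", port=7687)
cursor = connection.cursor()
# Highest stored offset per topic, i.e. the last message each stream processed.
cursor.execute(
    "MATCH (n:MESSAGE) "
    "RETURN n.topic AS topic, max(n.offset) AS last_offset "
    "ORDER BY topic"
)
for topic, last_offset in cursor.fetchall():
    print(topic, last_offset)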


@ -31,12 +31,12 @@
/// [[noreturn]] and throw an std::logic_error exception.
class MockedRdKafkaMessage : public RdKafka::Message {
public:
explicit MockedRdKafkaMessage(std::string key, std::string payload)
explicit MockedRdKafkaMessage(std::string key, std::string payload, int64_t offset)
: key_(std::move(key)), payload_(std::move(payload)) {
message_.err = rd_kafka_resp_err_t::RD_KAFKA_RESP_ERR__BEGIN;
message_.key = static_cast<void *>(key_.data());
message_.key_len = key_.size();
message_.offset = 0;
message_.offset = offset;
message_.payload = static_cast<void *>(payload_.data());
message_.len = payload_.size();
rd_kafka_ = rd_kafka_new(rd_kafka_type_t::RD_KAFKA_CONSUMER, nullptr, nullptr, 0);
@ -122,16 +122,19 @@ class MgpApiTest : public ::testing::Test {
const char key;
const char *topic_name;
const size_t payload_size;
const int64_t offset;
};
static constexpr std::array<ExpectedResult, 2> expected = {ExpectedResult{"payload1", '1', "Topic1", 8},
ExpectedResult{"payload2", '2', "Topic1", 8}};
static constexpr std::array<ExpectedResult, 2> expected = {ExpectedResult{"payload1", '1', "Topic1", 8, 0},
ExpectedResult{"payload2", '2', "Topic1", 8, 1}};
private:
utils::pmr::vector<mgp_message> CreateMockedBatch() {
std::transform(expected.begin(), expected.end(), std::back_inserter(msgs_storage_), [](const auto expected) {
return Message(std::make_unique<KafkaMessage>(std::string(1, expected.key), expected.payload));
});
std::transform(
expected.begin(), expected.end(), std::back_inserter(msgs_storage_),
[i = int64_t(0)](const auto expected) mutable {
return Message(std::make_unique<KafkaMessage>(std::string(1, expected.key), expected.payload, i++));
});
auto v = utils::pmr::vector<mgp_message>(utils::NewDeleteResource());
v.reserve(expected.size());
std::transform(msgs_storage_.begin(), msgs_storage_.end(), std::back_inserter(v),
@ -160,6 +163,8 @@ TEST_F(MgpApiTest, TestAllMgpKafkaCApi) {
// Test for topic name
EXPECT_FALSE(
std::strcmp(EXPECT_MGP_NO_ERROR(const char *, mgp_message_topic_name, message), expected[i].topic_name));
// Test for offset
EXPECT_EQ(EXPECT_MGP_NO_ERROR(int64_t, mgp_message_offset, message), expected[i].offset);
}
// Unfortunately, we can't test timestamp here because we can't mock (as explained above)